mirror of https://github.com/ZHKKKe/MODNet.git

Commit a070aaeeed (parent 77118b2a4e): full train demo

README.md (138 changed lines)
@@ -1,126 +1,24 @@

Added (new README; prose translated from Chinese):

# About

The code is forked from the [official MODNet repository](https://github.com/ZHKKKe/MODNet). This project completes the code for data preparation, model evaluation, and model training.

# Model training, evaluation, and inference

```bash
# 1. Clone the code and enter the working directory
git clone https://github.com/actboy/MODNet
cd MODNet

# 2. Install the dependencies
pip install -r src/requirements.txt

# 3. Download and unpack the dataset
wget -c https://paddleseg.bj.bcebos.com/matting/datasets/PPM-100.zip -O src/datasets/PPM-100.zip
unzip src/datasets/PPM-100.zip -d src/datasets

# 4. Train the model
python src/trainer.py

# 5. Evaluate the model
python src/eval.py

# 6. Run inference
python src/infer.py
```

Removed (the upstream English README):

<h2 align="center">MODNet: Trimap-Free Portrait Matting in Real Time</h2>

<div align="center"><i>MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition (AAAI 2022)</i></div>

<br />

<img src="doc/gif/homepage_demo.gif" width="100%">

<div align="center">MODNet is a model for <b>real-time</b> portrait matting with <b>only RGB image input</b></div>

<br />

<p align="center">
  <a href="#online-application-在线应用">Online Application (在线应用)</a> |
  <a href="#research-demo">Research Demo</a> |
  <a href="https://arxiv.org/pdf/2011.11961.pdf">AAAI 2022 Paper</a> |
  <a href="https://youtu.be/PqJ3BRHX3Lc">Supplementary Video</a>
</p>

<p align="center">
  <a href="#community">Community</a> |
  <a href="#code">Code</a> |
  <a href="#ppm-benchmark">PPM Benchmark</a> |
  <a href="#license">License</a> |
  <a href="#acknowledgement">Acknowledgement</a> |
  <a href="#citation">Citation</a> |
  <a href="#contact">Contact</a>
</p>

---

## Online Application (在线应用)

A **single** model! Only **7M**! **Fast** processing of **2K**-resolution images on common PCs or mobiles! **Better** than the research demos!
Please try online portrait image matting via [this website](https://sight-x.cn/portrait_matting)!

## Research Demo

All the models behind the following demos are trained on the datasets mentioned in [our paper](https://arxiv.org/pdf/2011.11961.pdf).

### Portrait Image Matting

We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
It allows you to upload portrait images and predict/visualize/download the alpha mattes.

<!-- <img src="doc/gif/image_matting_demo.gif" width='40%'> -->

### Portrait Video Matting

We provide two real-time portrait video matting demos based on WebCam. When using the demo, you can move the WebCam around at will.
If you have an Ubuntu system, we recommend trying the [offline demo](demo/video_matting/webcam) to get a higher *fps*. Otherwise, you can access the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
We also provide an [offline demo](demo/video_matting/custom) that allows you to process custom videos.

<!-- <img src="doc/gif/video_matting_demo.gif" width='60%'> -->

## Community

We share some cool applications/extensions of MODNet built by the community.

- **WebGUI for Portrait Image Matting**
  You can try [this WebGUI](https://www.gradio.app/hub/aliabd/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait image matting from your browser without code!

- **Colab Demo of Bokeh (Blur Background)**
  You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background based on MODNet!

- **ONNX Version of MODNet**
  You can convert the pre-trained MODNet to an ONNX model by using [this code](onnx) (provided by [@manthan3C273](https://github.com/manthan3C273)). You can also try [this Colab demo](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing) for MODNet image matting (ONNX version).

- **TorchScript Version of MODNet**
  You can convert the pre-trained MODNet to a TorchScript model by using [this code](torchscript) (provided by [@yarkable](https://github.com/yarkable)).

- **TensorRT Version of MODNet**
  You can access [this GitHub repository](https://github.com/jkjung-avt/tensorrt_demos) to try the TensorRT version of MODNet (provided by [@jkjung-avt](https://github.com/jkjung-avt)).

There are also some community resources about MODNet:

- [Video from the What's AI YouTube channel](https://youtu.be/rUo0wuVyefU)
- [Article from Louis Bouchard's blog](https://www.louisbouchard.ai/remove-background/)

## Code

We provide the [code](src/trainer.py) for the MODNet training iteration, including:

- **Supervised Training**: train MODNet on a labeled matting dataset
- **SOC Adaptation**: adapt a trained MODNet to an unlabeled dataset

The code comments include examples of how to use these functions.

## PPM Benchmark

The PPM benchmark is released in a separate repository: [PPM](https://github.com/ZHKKKe/PPM).

## License

The code, models, and demos in this repository (excluding the GIF files under `doc/gif`) are released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).

## Acknowledgement

We thank [@yzhou0919](https://github.com/yzhou0919), [@eyaler](https://github.com/eyaler), [@manthan3C273](https://github.com/manthan3C273), [@yarkable](https://github.com/yarkable), [@jkjung-avt](https://github.com/jkjung-avt), [@manzke](https://github.com/manzke), [the Gradio team](https://github.com/gradio-app/gradio), the [What's AI YouTube channel](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg), and [Louis Bouchard's blog](https://www.louisbouchard.ai) for their contributions to this repository and their cool applications/extensions/resources of MODNet.

## Citation

If this work helps your research, please consider citing:

```bibtex
@InProceedings{MODNet,
  author = {Zhanghan Ke and Jiayu Sun and Kaican Li and Qiong Yan and Rynson W.H. Lau},
  title = {MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition},
  booktitle = {AAAI},
  year = {2022},
}
```

## Contact

This repository is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
For questions, please contact `kezhanghan@outlook.com`.

<img src="doc/gif/commercial_image_matting_model_result.gif" width='100%'>
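For reference, the dataset paths used elsewhere in this commit (`train/fg`, `train/alpha`, `val/fg`, `val/alpha`) imply that the PPM-100 archive from step 3 unpacks to the layout below; this is inferred from the diff rather than documented in it:

```
src/datasets/PPM-100/
├── train/
│   ├── fg/      # foreground portrait images
│   └── alpha/   # ground-truth alpha mattes
└── val/
    ├── fg/
    └── alpha/
```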
src/eval.py

@@ -1,8 +1,8 @@
 import numpy as np
 from glob import glob
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
-from src.infer import predit_matte
+from infer import predit_matte
 import torch.nn as nn
 import torch
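Dropping the `src.` prefix from these imports means the scripts are meant to run with `src/` itself on the import path; invoking `python src/eval.py` from the repository root does this automatically, since Python puts the script's own directory at the front of `sys.path`.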
@@ -22,8 +22,8 @@ def cal_mse(pred, gt):


 def load_eval_dataset(dataset_root_dir='src/datasets/PPM-100'):
-    image_path = dataset_root_dir + '/image/*'
-    matte_path = dataset_root_dir + '/matte/*'
+    image_path = dataset_root_dir + '/val/fg/*'
+    matte_path = dataset_root_dir + '/val/alpha/*'
     image_file_name_list = glob(image_path)
     image_file_name_list = sorted(image_file_name_list)
     matte_file_name_list = glob(matte_path)
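Two remarks on this hunk. First, only `image_file_name_list` is sorted; `glob` returns files in filesystem order, so if the two lists are paired by index, sorting `matte_file_name_list` as well (or matching by basename) would be safer. Second, the hunk context names `cal_mse`; a minimal sketch of it and of an assumed MAD companion (MSE and MAD are the figures usually reported on PPM-100), assuming mattes are float arrays in [0, 1]:

```python
import numpy as np

def cal_mse(pred: np.ndarray, gt: np.ndarray) -> float:
    # mean squared error between predicted and ground-truth mattes
    return float(np.mean((pred - gt) ** 2))

def cal_mad(pred: np.ndarray, gt: np.ndarray) -> float:
    # mean absolute difference; assumed companion metric, not shown in the diff
    return float(np.mean(np.abs(pred - gt)))
```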
src/infer.py

@@ -1,4 +1,4 @@
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
 import numpy as np
 from torchvision import transforms
@@ -76,7 +76,7 @@ if __name__ == '__main__':
     weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
     modnet.load_state_dict(weights)

-    pth = 'src/datasets/PPM-100/image/13179159164_1a4ae8d085_o.jpg'
+    pth = 'src/datasets/PPM-100/val/fg/5588688353_3426d4b5d9_o.jpg'
     img = Image.open(pth)

     matte = predit_matte(modnet, img)
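Since `modnet.load_state_dict(weights)` expects a bare state dict, `ckp_pth` (defined above this hunk) presumably points at a weights-only file such as `modnet_custom_portrait_matting_last_epoch_weight.ckpt` from the trainer, not at one of the full checkpoint dicts. Once the matte is predicted, a common follow-up is alpha compositing; a hedged sketch, assuming `predit_matte` returns an H×W float array in [0, 1]:

```python
import numpy as np
from PIL import Image

# composite = alpha * foreground + (1 - alpha) * background
fg = np.asarray(img, dtype=np.float32)   # img is the PIL image opened above
alpha = matte[:, :, None]                # H x W x 1, broadcasts over RGB
bg = np.full_like(fg, 255.0)             # plain white background
comp = alpha * fg + (1.0 - alpha) * bg
Image.fromarray(np.uint8(comp)).save('composite.png')
```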
src/dataset.py (file name assumed; the diff view omits it)

@@ -13,8 +13,8 @@ class MattingDataset(Dataset):
     def __init__(self,
                  dataset_root_dir='src/datasets/PPM-100',
                  transform=None):
-        image_path = dataset_root_dir + '/image/*'
-        matte_path = dataset_root_dir + '/matte/*'
+        image_path = dataset_root_dir + '/train/fg/*'
+        matte_path = dataset_root_dir + '/train/alpha/*'
         image_file_name_list = glob(image_path)
         matte_file_name_list = glob(matte_path)
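The same glob-ordering caveat as in `src/eval.py` applies here, and neither list is sorted. If `MattingDataset.__getitem__` pairs the lists by index, a defensive tweak (an assumption, not part of the commit) is to sort both:

```python
image_file_name_list = sorted(glob(image_path))
matte_file_name_list = sorted(glob(matte_path))
```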
@@ -75,7 +75,7 @@ class GenTrimap(object):
         k_size = random.choice(range(2, 5))
         iterations = np.random.randint(5, 15)
         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                           (k_size, k_size))
+                                           (k_size, k_size))  # cv2.MORPH_RECT, cv2.MORPH_CROSS
         dilated = cv2.dilate(matte, kernel, iterations=iterations)
         eroded = cv2.erode(matte, kernel, iterations=iterations)
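For context, this dilate/erode pair is the standard way to derive a trimap from an alpha matte: erosion keeps only pixels that are confidently foreground, dilation bounds the band that may contain hair and other soft edges, and everything in between is marked unknown. A minimal sketch, assuming float mattes in [0, 1] and the 0 / 0.5 / 1 trimap encoding implied by the `* 255` visualization check later in this file (the thresholds are assumptions):

```python
import cv2
import numpy as np

def gen_trimap_sketch(matte: np.ndarray) -> np.ndarray:
    k_size = np.random.randint(2, 5)      # mirrors random.choice(range(2, 5))
    iterations = np.random.randint(5, 15)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
    dilated = cv2.dilate(matte, kernel, iterations=iterations)
    eroded = cv2.erode(matte, kernel, iterations=iterations)
    trimap = np.full_like(matte, 0.5)     # unknown band by default
    trimap[eroded > 0.95] = 1.0           # confident foreground
    trimap[dilated < 0.05] = 0.0          # confident background
    return trimap
```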
@@ -141,7 +141,7 @@ class ToTrainArray(object):
 if __name__ == '__main__':

     # test MattingDataset.gen_trimap
-    matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
+    matte = Image.open('src/datasets/PPM-100/train/alpha/6146816_556eaff97f_o.jpg')
     trimap1 = GenTrimap().gen_trimap(matte)
     trimap1 = np.array(trimap1) * 255
     trimap1 = np.uint8(trimap1)
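The `* 255` step rescales a trimap valued in {0, 0.5, 1} to {0, 127, 255} (0.5 · 255 truncates to 127 under `np.uint8`), so the three regions can be inspected as an 8-bit grayscale image.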
src/setting.py (new file)

@@ -0,0 +1,9 @@
+BS = 16     # batch size
+LR = 0.01   # learning rate
+EPOCHS = 4  # total epochs
+
+SEMANTIC_SCALE = 10.0
+DETAIL_SCALE = 10.0
+MATTE_SCALE = 1.0
+
+SAVE_EPOCH_STEP = 3
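A note on how these values interact with the scheduler in `src/trainer.py`: with `EPOCHS = 4`, `StepLR(optimizer, step_size=int(0.25 * EPOCHS), gamma=0.1)` gets `step_size = int(0.25 * 4) = 1`, so the learning rate decays tenfold after every epoch (0.01 → 0.001 → 1e-4 → 1e-5); and with `SAVE_EPOCH_STEP = 3`, the condition `epoch % SAVE_EPOCH_STEP == 0 and epoch > 1` holds only at epoch 3, so exactly one intermediate checkpoint is written.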
src/trainer.py

@@ -307,6 +307,7 @@ if __name__ == '__main__':
     from torchvision import transforms
     from torch.utils.data import DataLoader
     from models.modnet import MODNet
+    from setting import BS, LR, EPOCHS, SEMANTIC_SCALE, DETAIL_SCALE, MATTE_SCALE, SAVE_EPOCH_STEP
     transform = transforms.Compose([
         Rescale(512),
         GenTrimap(),
@@ -317,41 +318,39 @@ if __name__ == '__main__':
     ])
     mattingDataset = MattingDataset(transform=transform)

-    bs = 4       # batch size
-    lr = 0.01    # learning rate
-    epochs = 40  # total epochs
-
     modnet = torch.nn.DataParallel(MODNet())  # .cuda()
     if torch.cuda.is_available():
         modnet = modnet.cuda()

-    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
+    optimizer = torch.optim.SGD(modnet.parameters(), lr=LR, momentum=0.9)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * EPOCHS), gamma=0.1)

     dataloader = DataLoader(mattingDataset,
-                            batch_size=bs,
+                            batch_size=BS,
                             shuffle=True)

-    for epoch in range(0, epochs):
+    for epoch in range(0, EPOCHS):
+        print(f'epoch: {epoch}/{EPOCHS-1}')
         for idx, (image, trimap, gt_matte) in enumerate(dataloader):
             semantic_loss, detail_loss, matte_loss = \
-                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
-            break
-        if epoch % 4 == 0 and epoch > 1:
+                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte,
+                                         semantic_scale=SEMANTIC_SCALE,
+                                         detail_scale=DETAIL_SCALE,
+                                         matte_scale=MATTE_SCALE)
+            print(f'{(idx+1) * BS}/{len(mattingDataset)} --- '
+                  f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}\r',
+                  end='')
+        lr_scheduler.step()
+        # save an intermediate training checkpoint
+        if epoch % SAVE_EPOCH_STEP == 0 and epoch > 1:
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': modnet.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
-            }, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
-        lr_scheduler.step()
-        print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
-        if epoch == 4:
-            break
+            }, f'pretrained/modnet_custom_portrait_matting_{epoch}_th.ckpt')
+        print(f'{len(mattingDataset)}/{len(mattingDataset)} --- '
+              f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')

+    torch.save({
+        'epoch': EPOCHS,  # was `epochs`, a name that no longer exists after this commit
+        'model_state_dict': modnet.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
+    }, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')
+    # save only the model weight parameters
+    torch.save(modnet.state_dict(), f'pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt')
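The full checkpoints above bundle the epoch index with model and optimizer state, so an interrupted run can be resumed; a hedged sketch (not part of the commit), reusing the `modnet` and `optimizer` objects built in the loop above:

```python
ckpt = torch.load('pretrained/modnet_custom_portrait_matting_3_th.ckpt',
                  map_location='cpu')
modnet.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1   # then train with range(start_epoch, EPOCHS)
```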