full train demo

pull/177/head
actboy 2022-02-13 16:14:29 +08:00
parent 77118b2a4e
commit a070aaeeed
6 changed files with 58 additions and 152 deletions

README.md

@@ -1,126 +1,24 @@
-<h2 align="center">MODNet: Trimap-Free Portrait Matting in Real Time</h2>
-<div align="center"><i>MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition (AAAI 2022)</i></div>
-<br />
-<img src="doc/gif/homepage_demo.gif" width="100%">
-<div align="center">MODNet is a model for <b>real-time</b> portrait matting with <b>only RGB image input</b></div>
-<br />
-<p align="center">
-  <a href="#online-application-在线应用">Online Application (在线应用)</a> |
-  <a href="#research-demo">Research Demo</a> |
-  <a href="https://arxiv.org/pdf/2011.11961.pdf">AAAI 2022 Paper</a> |
-  <a href="https://youtu.be/PqJ3BRHX3Lc">Supplementary Video</a>
-</p>
-<p align="center">
-  <a href="#community">Community</a> |
-  <a href="#code">Code</a> |
-  <a href="#ppm-benchmark">PPM Benchmark</a> |
-  <a href="#license">License</a> |
-  <a href="#acknowledgement">Acknowledgement</a> |
-  <a href="#citation">Citation</a> |
-  <a href="#contact">Contact</a>
-</p>
----
-## Online Application (在线应用)
-A **single** model! Only **7M**! Processes **2K**-resolution images at a **fast** speed on common PCs or mobiles! **Better** than the research demos!
-Please try online portrait image matting via [this website](https://sight-x.cn/portrait_matting)!
-## Research Demo
-All the models behind the following demos are trained on the datasets mentioned in [our paper](https://arxiv.org/pdf/2011.11961.pdf).
-### Portrait Image Matting
-We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
-It allows you to upload portrait images and predict/visualize/download the alpha mattes.
-### Portrait Video Matting
-We provide two real-time portrait video matting demos based on a webcam. When using the demos, you can move the webcam around at will.
-If you have an Ubuntu system, we recommend trying the [offline demo](demo/video_matting/webcam) for a higher *fps*. Otherwise, you can access the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
-We also provide an [offline demo](demo/video_matting/custom) that allows you to process custom videos.
-## Community
-We share some cool applications/extensions of MODNet built by the community.
-- **WebGUI for Portrait Image Matting**
-  You can try [this WebGUI](https://www.gradio.app/hub/aliabd/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait image matting from your browser, no code required!
-- **Colab Demo of Bokeh (Blur Background)**
-  You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background based on MODNet!
-- **ONNX Version of MODNet**
-  You can convert the pre-trained MODNet to an ONNX model by using [this code](onnx) (provided by [@manthan3C273](https://github.com/manthan3C273)). You can also try [this Colab demo](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing) for MODNet image matting (ONNX version).
-- **TorchScript Version of MODNet**
-  You can convert the pre-trained MODNet to a TorchScript model by using [this code](torchscript) (provided by [@yarkable](https://github.com/yarkable)).
-- **TensorRT Version of MODNet**
-  You can access [this GitHub repository](https://github.com/jkjung-avt/tensorrt_demos) to try the TensorRT version of MODNet (provided by [@jkjung-avt](https://github.com/jkjung-avt)).
-There are also some community resources about MODNet:
-- [Video from the What's AI YouTube Channel](https://youtu.be/rUo0wuVyefU)
-- [Article from Louis Bouchard's blog](https://www.louisbouchard.ai/remove-background/)
-## Code
-We provide the [code](src/trainer.py) of the MODNet training iteration, including:
-- **Supervised Training**: Train MODNet on a labeled matting dataset
-- **SOC Adaptation**: Adapt a trained MODNet to an unlabeled dataset
-In the code comments, we provide examples for using the functions.
-## PPM Benchmark
-The PPM benchmark is released in a separate repository: [PPM](https://github.com/ZHKKKe/PPM).
-## License
-The code, models, and demos in this repository (excluding the GIF files under the folder `doc/gif`) are released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
-## Acknowledgement
-We thank [@yzhou0919](https://github.com/yzhou0919), [@eyaler](https://github.com/eyaler), [@manthan3C273](https://github.com/manthan3C273), [@yarkable](https://github.com/yarkable), [@jkjung-avt](https://github.com/jkjung-avt), [@manzke](https://github.com/manzke), [the Gradio team](https://github.com/gradio-app/gradio), the [What's AI YouTube Channel](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg), and [Louis Bouchard's blog](https://www.louisbouchard.ai) for their contributions to this repository and their cool applications/extensions/resources of MODNet.
-## Citation
-If this work helps your research, please consider citing:
-```bibtex
-@InProceedings{MODNet,
-  author = {Zhanghan Ke and Jiayu Sun and Kaican Li and Qiong Yan and Rynson W.H. Lau},
-  title = {MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition},
-  booktitle = {AAAI},
-  year = {2022},
-}
-```
-## Contact
-This repository is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
-For questions, please contact `kezhanghan@outlook.com`.
-<img src="doc/gif/commercial_image_matting_model_result.gif" width='100%'>
+# About
+This code is forked from the [official MODNet repository](https://github.com/ZHKKKe/MODNet). This project completes the code for data preparation, model evaluation, and model training.
+# Training, Evaluation, and Inference
+```bash
+# 1. Clone the code and enter the working directory
+git clone https://github.com/actboy/MODNet
+cd MODNet
+# 2. Install the dependencies
+pip install -r src/requirements.txt
+# 3. Download and unzip the dataset
+wget -c https://paddleseg.bj.bcebos.com/matting/datasets/PPM-100.zip -O src/datasets/PPM-100.zip
+unzip src/datasets/PPM-100.zip -d src/datasets
+# 4. Train the model
+python src/trainer.py
+# 5. Evaluate the model
+python src/eval.py
+# 6. Run inference
+python src/infer.py
+```

src/eval.py

@@ -1,8 +1,8 @@
 import numpy as np
 from glob import glob
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
-from src.infer import predit_matte
+from infer import predit_matte
 import torch.nn as nn
 import torch
@@ -22,8 +22,8 @@ def cal_mse(pred, gt):
 def load_eval_dataset(dataset_root_dir='src/datasets/PPM-100'):
-    image_path = dataset_root_dir + '/image/*'
-    matte_path = dataset_root_dir + '/matte/*'
+    image_path = dataset_root_dir + '/val/fg/*'
+    matte_path = dataset_root_dir + '/val/alpha/*'
     image_file_name_list = glob(image_path)
     image_file_name_list = sorted(image_file_name_list)
     matte_file_name_list = glob(matte_path)
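The hunk header above references `cal_mse`; for orientation, a minimal sketch of what such matting metrics typically compute (the `cal_mad` companion and the [0, 1] scaling are assumptions, not code from this commit):

```python
import numpy as np

def cal_mse(pred: np.ndarray, gt: np.ndarray) -> float:
    # Mean squared error between the predicted matte and the ground truth,
    # both assumed to be float arrays scaled to [0, 1].
    return float(np.mean((pred - gt) ** 2))

def cal_mad(pred: np.ndarray, gt: np.ndarray) -> float:
    # Mean absolute difference, a common companion metric for matting.
    return float(np.mean(np.abs(pred - gt)))
```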

src/infer.py

@@ -1,4 +1,4 @@
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
 import numpy as np
 from torchvision import transforms
@@ -76,7 +76,7 @@ if __name__ == '__main__':
     weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
     modnet.load_state_dict(weights)
-    pth = 'src/datasets/PPM-100/image/13179159164_1a4ae8d085_o.jpg'
+    pth = 'src/datasets/PPM-100/val/fg/5588688353_3426d4b5d9_o.jpg'
     img = Image.open(pth)
     matte = predit_matte(modnet, img)
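As a follow-up to the inference hunk above, a minimal end-to-end usage sketch; the weight path reuses the file the trainer below writes, and the white-background compositing step is an illustrative assumption (it presumes `predit_matte` returns an HxW float matte in [0, 1] at the input resolution):

```python
import numpy as np
import torch
from PIL import Image
from models.modnet import MODNet
from infer import predit_matte

# Load trained weights as in the __main__ block above (path is illustrative).
modnet = torch.nn.DataParallel(MODNet())
weights = torch.load('pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt',
                     map_location=torch.device('cpu'))
modnet.load_state_dict(weights)

img = Image.open('src/datasets/PPM-100/val/fg/5588688353_3426d4b5d9_o.jpg')
matte = predit_matte(modnet, img)  # assumed: HxW float array in [0, 1]

# Save the matte and a white-background composite.
Image.fromarray(np.uint8(matte * 255)).save('matte.png')
fg = np.asarray(img, dtype=np.float32)
comp = fg * matte[..., None] + 255.0 * (1.0 - matte[..., None])
Image.fromarray(np.uint8(comp)).save('composite.png')
```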


@@ -13,8 +13,8 @@ class MattingDataset(Dataset):
     def __init__(self,
                  dataset_root_dir='src/datasets/PPM-100',
                  transform=None):
-        image_path = dataset_root_dir + '/image/*'
-        matte_path = dataset_root_dir + '/matte/*'
+        image_path = dataset_root_dir + '/train/fg/*'
+        matte_path = dataset_root_dir + '/train/alpha/*'
         image_file_name_list = glob(image_path)
         matte_file_name_list = glob(matte_path)
@@ -75,7 +75,7 @@ class GenTrimap(object):
         k_size = random.choice(range(2, 5))
         iterations = np.random.randint(5, 15)
         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                           (k_size, k_size))
+                                           (k_size, k_size))  # cv2.MORPH_RECT, cv2.MORPH_CROSS
         dilated = cv2.dilate(matte, kernel, iterations=iterations)
         eroded = cv2.erode(matte, kernel, iterations=iterations)
@@ -141,7 +141,7 @@ class ToTrainArray(object):
 if __name__ == '__main__':
     # test MattingDataset.gen_trimap
-    matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
+    matte = Image.open('src/datasets/PPM-100/train/alpha/6146816_556eaff97f_o.jpg')
     trimap1 = GenTrimap().gen_trimap(matte)
     trimap1 = np.array(trimap1) * 255
     trimap1 = np.uint8(trimap1)
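The `GenTrimap` hunk above shows the core of trimap generation: dilate and erode the alpha matte and treat the band between them as unknown. A minimal self-contained sketch of that idea, with values 1 for foreground, 0.5 for unknown, and 0 for background (matching the `* 255` visualization in the test block); the exact thresholds are assumptions, not this repository's code:

```python
import cv2
import numpy as np

def gen_trimap_sketch(matte, k_size=3, iterations=10):
    """matte: float32 array in [0, 1]; returns a trimap with values {0.0, 0.5, 1.0}."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
    dilated = cv2.dilate(matte, kernel, iterations=iterations)
    eroded = cv2.erode(matte, kernel, iterations=iterations)
    trimap = np.full_like(matte, 0.5)   # the eroded-to-dilated band stays unknown
    trimap[eroded > 0.95] = 1.0         # safely inside the foreground
    trimap[dilated < 0.05] = 0.0        # safely outside the foreground
    return trimap
```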

src/setting.py (new file)

@@ -0,0 +1,9 @@
+BS = 16              # BATCH SIZE
+LR = 0.01            # LEARNING RATE
+EPOCHS = 4           # TOTAL EPOCHS
+SEMANTIC_SCALE = 10.0
+DETAIL_SCALE = 10.0
+MATTE_SCALE = 1.0
+SAVE_EPOCH_STEP = 3
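A hypothetical snippet showing how these constants drive the training loop below; `N` is a placeholder dataset size, not a value from this commit:

```python
import math
from setting import BS, LR, EPOCHS, SAVE_EPOCH_STEP

N = 90  # placeholder: number of training samples after the split
steps_per_epoch = math.ceil(N / BS)
print(f'{EPOCHS} epochs x {steps_per_epoch} steps/epoch at initial lr={LR}')
# Epochs that write an intermediate checkpoint, per the trainer's
# `epoch % SAVE_EPOCH_STEP == 0 and epoch > 1` condition:
print([e for e in range(EPOCHS) if e % SAVE_EPOCH_STEP == 0 and e > 1])  # -> [3]
```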

src/trainer.py

@@ -307,6 +307,7 @@ if __name__ == '__main__':
     from torchvision import transforms
     from torch.utils.data import DataLoader
     from models.modnet import MODNet
+    from setting import BS, LR, EPOCHS, SEMANTIC_SCALE, DETAIL_SCALE, MATTE_SCALE, SAVE_EPOCH_STEP
     transform = transforms.Compose([
         Rescale(512),
         GenTrimap(),
@@ -317,41 +318,39 @@ if __name__ == '__main__':
     ])
     mattingDataset = MattingDataset(transform=transform)
-    bs = 4          # batch size
-    lr = 0.01       # learning rate
-    epochs = 40     # total epochs
     modnet = torch.nn.DataParallel(MODNet())  # .cuda()
     if torch.cuda.is_available():
         modnet = modnet.cuda()
-    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
+    optimizer = torch.optim.SGD(modnet.parameters(), lr=LR, momentum=0.9)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * EPOCHS), gamma=0.1)
     dataloader = DataLoader(mattingDataset,
-                            batch_size=bs,
+                            batch_size=BS,
                             shuffle=True)
-    for epoch in range(0, epochs):
+    for epoch in range(0, EPOCHS):
+        print(f'epoch: {epoch}/{EPOCHS-1}')
         for idx, (image, trimap, gt_matte) in enumerate(dataloader):
             semantic_loss, detail_loss, matte_loss = \
-                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
-            break
-        if epoch % 4 == 0 and epoch > 1:
+                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte,
+                                         semantic_scale=SEMANTIC_SCALE,
+                                         detail_scale=DETAIL_SCALE,
+                                         matte_scale=MATTE_SCALE)
+            print(f'{(idx+1) * BS}/{len(mattingDataset)} --- '
+                  f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}\r',
+                  end='')
+        lr_scheduler.step()
+        # save intermediate training checkpoints
+        if epoch % SAVE_EPOCH_STEP == 0 and epoch > 1:
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': modnet.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
-            }, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
-        lr_scheduler.step()
-        print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
-        if epoch == 4:
-            break
-    torch.save({
-        'epoch': epochs,
-        'model_state_dict': modnet.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
-        'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
-    }, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')
+            }, f'pretrained/modnet_custom_portrait_matting_{epoch}_th.ckpt')
+        print(f'{len(mattingDataset)}/{len(mattingDataset)} --- '
+              f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
+    # save only the model weights
+    torch.save(modnet.state_dict(), f'pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt')
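Since the intermediate checkpoints above store the epoch, model state, and optimizer state, training can be resumed from one; a minimal sketch, assuming the same `MODNet`/`DataParallel` setup as in the loop above (the filename is one the loop actually writes with EPOCHS=4 and SAVE_EPOCH_STEP=3):

```python
import torch
from models.modnet import MODNet

# Rebuild the model and optimizer exactly as in the training script above.
modnet = torch.nn.DataParallel(MODNet())
optimizer = torch.optim.SGD(modnet.parameters(), lr=0.01, momentum=0.9)

# Load an intermediate checkpoint saved by the loop above.
ckpt = torch.load('pretrained/modnet_custom_portrait_matting_3_th.ckpt',
                  map_location=torch.device('cpu'))
modnet.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1  # continue from the next epoch
```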