diff --git a/README.md b/README.md
index c30477b..e4ed531 100644
--- a/README.md
+++ b/README.md
@@ -1,126 +1,24 @@
-MODNet: Trimap-Free Portrait Matting in Real Time
-
-MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition (AAAI 2022)
-
-MODNet is a model for real-time portrait matting with only RGB image input
-MODNet是一个仅需RGB图片输入的实时人像抠图模型
-
-Online Application (在线应用) | Research Demo | AAAI 2022 Paper | Supplementary Video
-
-Community | Code | PPM Benchmark | License | Acknowledgement | Citation | Contact
-
----
-
-## Online Application (在线应用)
-
-A **single** model! Only **7M**! Processes **2K**-resolution images at a **fast** speed on common PCs or mobiles! **Better** than research demos!
-Please try online portrait image matting via [this website](https://sight-x.cn/portrait_matting)!
-
-**单个**模型!大小仅为**7M**!可以在普通PC或移动设备上**快速**处理具有**2K**分辨率的图像!效果比研究示例**更好**!
-请通过[此网站](https://sight-x.cn/portrait_matting)在线尝试图片抠像!
-
-## Research Demo
-
-All the models behind the following demos are trained on the datasets mentioned in [our paper](https://arxiv.org/pdf/2011.11961.pdf).
-
-### Portrait Image Matting
-We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
-It allows you to upload portrait images and predict/visualize/download the alpha mattes.
-
-### Portrait Video Matting
-We provide two real-time portrait video matting demos based on WebCam. When using the demo, you can move the WebCam around at will.
-If you have an Ubuntu system, we recommend you try the [offline demo](demo/video_matting/webcam) to get a higher *fps*. Otherwise, you can access the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
-We also provide an [offline demo](demo/video_matting/custom) that allows you to process custom videos.
-
-## Community
-
-We share some cool applications/extensions of MODNet built by the community.
-
-- **WebGUI for Portrait Image Matting**
-You can try [this WebGUI](https://www.gradio.app/hub/aliabd/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait image matting from your browser without code!
-
-- **Colab Demo of Bokeh (Blur Background)**
-You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background based on MODNet!
-
-- **ONNX Version of MODNet**
-You can convert the pre-trained MODNet to an ONNX model by using [this code](onnx) (provided by [@manthan3C273](https://github.com/manthan3C273)). You can also try [this Colab demo](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing) for MODNet image matting (ONNX version).
-
-- **TorchScript Version of MODNet**
-You can convert the pre-trained MODNet to a TorchScript model by using [this code](torchscript) (provided by [@yarkable](https://github.com/yarkable)).
-
-- **TensorRT Version of MODNet**
-You can access [this GitHub repository](https://github.com/jkjung-avt/tensorrt_demos) to try the TensorRT version of MODNet (provided by [@jkjung-avt](https://github.com/jkjung-avt)).
-
-There are also some resources about MODNet from the community:
-- [Video from the What's AI YouTube Channel](https://youtu.be/rUo0wuVyefU)
-- [Article from Louis Bouchard's Blog](https://www.louisbouchard.ai/remove-background/)
-
-## Code
-We provide the [code](src/trainer.py) for the MODNet training iteration, including:
-- **Supervised Training**: Train MODNet on a labeled matting dataset
-- **SOC Adaptation**: Adapt a trained MODNet to an unlabeled dataset
-
-In code comments, we provide examples for using the functions.
-
-## PPM Benchmark
-The PPM benchmark is released in a separate repository: [PPM](https://github.com/ZHKKKe/PPM).
-
-## License
-The code, models, and demos in this repository (excluding the GIF files under the folder `doc/gif`) are released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
-
-## Acknowledgement
-We thank [@yzhou0919](https://github.com/yzhou0919), [@eyaler](https://github.com/eyaler), [@manthan3C273](https://github.com/manthan3C273), [@yarkable](https://github.com/yarkable), [@jkjung-avt](https://github.com/jkjung-avt), [@manzke](https://github.com/manzke), [the Gradio team](https://github.com/gradio-app/gradio), the [What's AI YouTube Channel](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg), and [Louis Bouchard's Blog](https://www.louisbouchard.ai) for their contributions to this repository and their cool applications/extensions/resources of MODNet.
-
-## Citation
-If this work helps your research, please consider citing:
-
-```bibtex
-@InProceedings{MODNet,
-  author = {Zhanghan Ke and Jiayu Sun and Kaican Li and Qiong Yan and Rynson W.H. Lau},
-  title = {MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition},
-  booktitle = {AAAI},
-  year = {2022},
-}
-```
-
-## Contact
-This repository is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
-For questions, please contact `kezhanghan@outlook.com`.
+# About
+The code is forked from the [official MODNet repo](https://github.com/ZHKKKe/MODNet). This project fills in the code for data preparation, model evaluation, and model training.
+# Model Training, Evaluation, and Inference
+```bash
+# 1. Clone the code and enter the working directory
+git clone https://github.com/actboy/MODNet
+cd MODNet
+
+# 2. Install the dependencies
+pip install -r src/requirements.txt
+
+# 3. Download and unpack the dataset
+wget -c https://paddleseg.bj.bcebos.com/matting/datasets/PPM-100.zip -O src/datasets/PPM-100.zip
+unzip src/datasets/PPM-100.zip -d src/datasets
+
+# 4. Train the model
+python src/trainer.py
+
+# 5. Evaluate the model
+python src/eval.py
+
+# 6. Run inference
+python src/infer.py
+```
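Note on the evaluation changes below: the second hunk of `src/eval.py` only shows the tail of `cal_mse(pred, gt)` as context, so its body is not part of this diff. A minimal sketch of such a pixel-wise MSE metric, assuming both mattes are float arrays normalized to [0, 1] (the actual implementation in the repository may differ):

```python
import numpy as np

def cal_mse(pred: np.ndarray, gt: np.ndarray) -> float:
    """Pixel-wise mean squared error between a predicted matte and its ground truth.

    Assumes both arrays have the same shape and are normalized to [0, 1].
    Sketch only: the real body of cal_mse lies outside this diff.
    """
    return float(np.mean((pred - gt) ** 2))
```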
diff --git a/src/eval.py b/src/eval.py
index e9e5305..7106051 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -1,8 +1,8 @@
 import numpy as np
 from glob import glob
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
-from src.infer import predit_matte
+from infer import predit_matte
 import torch.nn as nn
 import torch
@@ -22,8 +22,8 @@ def cal_mse(pred, gt):
 
 def load_eval_dataset(dataset_root_dir='src/datasets/PPM-100'):
-    image_path = dataset_root_dir + '/image/*'
-    matte_path = dataset_root_dir + '/matte/*'
+    image_path = dataset_root_dir + '/val/fg/*'
+    matte_path = dataset_root_dir + '/val/alpha/*'
     image_file_name_list = glob(image_path)
     image_file_name_list = sorted(image_file_name_list)
     matte_file_name_list = glob(matte_path)
diff --git a/src/infer.py b/src/infer.py
index 1fe99cc..116807b 100644
--- a/src/infer.py
+++ b/src/infer.py
@@ -1,4 +1,4 @@
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
 import numpy as np
 from torchvision import transforms
@@ -76,7 +76,7 @@ if __name__ == '__main__':
     weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
     modnet.load_state_dict(weights)
 
-    pth = 'src/datasets/PPM-100/image/13179159164_1a4ae8d085_o.jpg'
+    pth = 'src/datasets/PPM-100/val/fg/5588688353_3426d4b5d9_o.jpg'
     img = Image.open(pth)
 
     matte = predit_matte(modnet, img)
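For context on the `src/matting_dataset.py` hunks below: `GenTrimap` derives a trimap from the ground-truth matte by eroding it (what survives is sure foreground) and dilating it (what it never reaches is sure background), with the band in between marked unknown. A standalone sketch of that recipe, with hypothetical fixed thresholds and kernel size (the repository draws `k_size` and `iterations` randomly, as the hunk shows):

```python
import cv2
import numpy as np

def make_trimap(matte: np.ndarray, k_size: int = 3, iterations: int = 10) -> np.ndarray:
    """Derive a foreground/background/unknown trimap from an alpha matte.

    `matte` is a float array in [0, 1]. Returns 1.0 for sure foreground,
    0.0 for sure background, and 0.5 for the unknown band. The thresholds
    and the fixed kernel size are illustrative assumptions, not the repo's values.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
    eroded = cv2.erode(matte, kernel, iterations=iterations)    # shrinks the foreground
    dilated = cv2.dilate(matte, kernel, iterations=iterations)  # grows into the background
    trimap = np.full_like(matte, 0.5, dtype=np.float32)
    trimap[eroded > 0.95] = 1.0   # survives erosion: certain foreground
    trimap[dilated < 0.05] = 0.0  # untouched by dilation: certain background
    return trimap
```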
diff --git a/src/matting_dataset.py b/src/matting_dataset.py
index 9838f3f..4e404fb 100644
--- a/src/matting_dataset.py
+++ b/src/matting_dataset.py
@@ -13,8 +13,8 @@ class MattingDataset(Dataset):
     def __init__(self,
                  dataset_root_dir='src/datasets/PPM-100',
                  transform=None):
-        image_path = dataset_root_dir + '/image/*'
-        matte_path = dataset_root_dir + '/matte/*'
+        image_path = dataset_root_dir + '/train/fg/*'
+        matte_path = dataset_root_dir + '/train/alpha/*'
         image_file_name_list = glob(image_path)
         matte_file_name_list = glob(matte_path)
@@ -75,7 +75,7 @@ class GenTrimap(object):
         k_size = random.choice(range(2, 5))
         iterations = np.random.randint(5, 15)
         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                           (k_size, k_size))
+                                           (k_size, k_size))  # alternatives: cv2.MORPH_RECT, cv2.MORPH_CROSS
         dilated = cv2.dilate(matte, kernel, iterations=iterations)
         eroded = cv2.erode(matte, kernel, iterations=iterations)
@@ -141,7 +141,7 @@ class ToTrainArray(object):
 
 if __name__ == '__main__':
     # test MattingDataset.gen_trimap
-    matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
+    matte = Image.open('src/datasets/PPM-100/train/alpha/6146816_556eaff97f_o.jpg')
     trimap1 = GenTrimap().gen_trimap(matte)
     trimap1 = np.array(trimap1) * 255
    trimap1 = np.uint8(trimap1)
diff --git a/src/setting.py b/src/setting.py
new file mode 100644
index 0000000..9e2070d
--- /dev/null
+++ b/src/setting.py
@@ -0,0 +1,9 @@
+BS = 16      # batch size
+LR = 0.01    # learning rate
+EPOCHS = 4   # total number of epochs
+
+SEMANTIC_SCALE = 10.0  # weight of the semantic loss in supervised_training_iter
+DETAIL_SCALE = 10.0    # weight of the detail loss
+MATTE_SCALE = 1.0      # weight of the matte loss
+
+SAVE_EPOCH_STEP = 3    # save an intermediate checkpoint every SAVE_EPOCH_STEP epochs
diff --git a/src/trainer.py b/src/trainer.py
index aa71ef0..becb4af 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -307,6 +307,7 @@ if __name__ == '__main__':
     from torchvision import transforms
     from torch.utils.data import DataLoader
     from models.modnet import MODNet
+    from setting import BS, LR, EPOCHS, SEMANTIC_SCALE, DETAIL_SCALE, MATTE_SCALE, SAVE_EPOCH_STEP
     transform = transforms.Compose([
         Rescale(512),
         GenTrimap(),
         ...
     ])
@@ -317,41 +318,39 @@
     mattingDataset = MattingDataset(transform=transform)
 
-    bs = 4       # batch size
-    lr = 0.01    # learning rate
-    epochs = 40  # total epochs
-
     modnet = torch.nn.DataParallel(MODNet())  # .cuda()
     if torch.cuda.is_available():
         modnet = modnet.cuda()
 
-    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
+    optimizer = torch.optim.SGD(modnet.parameters(), lr=LR, momentum=0.9)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * EPOCHS), gamma=0.1)
 
     dataloader = DataLoader(mattingDataset,
-                            batch_size=bs,
+                            batch_size=BS,
                             shuffle=True)
 
-    for epoch in range(0, epochs):
+    for epoch in range(0, EPOCHS):
+        print(f'epoch: {epoch}/{EPOCHS-1}')
         for idx, (image, trimap, gt_matte) in enumerate(dataloader):
             semantic_loss, detail_loss, matte_loss = \
-                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
-            break
-        if epoch % 4 == 0 and epoch > 1:
+                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte,
+                                         semantic_scale=SEMANTIC_SCALE,
+                                         detail_scale=DETAIL_SCALE,
+                                         matte_scale=MATTE_SCALE)
+            print(f'{(idx + 1) * BS}/{len(mattingDataset)} --- '
+                  f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}\r',
+                  end='')
+        lr_scheduler.step()
+        # save an intermediate training checkpoint
+        if epoch % SAVE_EPOCH_STEP == 0 and epoch > 1:
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': modnet.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
-            }, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
-        lr_scheduler.step()
-        print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
-        if epoch == 4:
-            break
+            }, f'pretrained/modnet_custom_portrait_matting_{epoch}_th.ckpt')
+        print(f'{len(mattingDataset)}/{len(mattingDataset)} --- '
+              f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
 
-    torch.save({
-        'epoch': epochs,
-        'model_state_dict': modnet.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
-        'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
-    }, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')
+    # save only the model weight parameters
+    torch.save(modnet.state_dict(), f'pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt')
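One consequence of the trainer changes worth noting: intermediate saves are full training checkpoints (dicts holding model state, optimizer state, and losses), while the final save is a bare `state_dict`. A minimal loading sketch for both flavors, mirroring the `torch.load` / `load_state_dict` usage already in `src/infer.py`; the intermediate filename assumes the `EPOCHS = 4` and `SAVE_EPOCH_STEP = 3` defaults from `src/setting.py`, under which epoch 3 produces `..._3_th.ckpt`:

```python
import torch
from models.modnet import MODNet

# Same DataParallel wrapper as the trainer, so the 'module.' key prefix matches.
modnet = torch.nn.DataParallel(MODNet())

# The final save is a bare state_dict and loads directly:
weights = torch.load('pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt',
                     map_location=torch.device('cpu'))
modnet.load_state_dict(weights)

# Intermediate saves are full checkpoints; unpack the model state first:
ckpt = torch.load('pretrained/modnet_custom_portrait_matting_3_th.ckpt',
                  map_location=torch.device('cpu'))
modnet.load_state_dict(ckpt['model_state_dict'])
print(f"epoch {ckpt['epoch']}, losses: {ckpt['loss']}")
```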