diff --git a/README.md b/README.md
index c30477b..e4ed531 100644
--- a/README.md
+++ b/README.md
@@ -1,126 +1,24 @@
-
+# 4. Train the model
+python src/trainer.py
-  Online Application (在线应用) |
-  Research Demo |
-  AAAI 2022 Paper |
-  Supplementary Video
-
-
-  Community |
-  Code |
-  PPM Benchmark |
-  License |
-  Acknowledgement |
-  Citation |
-  Contact
-
-
----
-
-
-## Online Application (在线应用)
-
-A **Single** model! Only **7M**! Processes **2K**-resolution images at a **Fast** speed on common PCs or mobiles! **Better** than the research demos!
-Please try online portrait image matting via [this website](https://sight-x.cn/portrait_matting)!
-
-**单个**模型！大小仅为**7M**！可以在普通PC或移动设备上**快速**处理具有**2K**分辨率的图像！效果比研究示例**更好**！
-请通过[此网站](https://sight-x.cn/portrait_matting)在线尝试图片抠像！
-
-
-## Research Demo
-
-All the models behind the following demos are trained on the datasets mentioned in [our paper](https://arxiv.org/pdf/2011.11961.pdf).
-
-### Portrait Image Matting
-We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
-It allows you to upload portrait images and predict/visualize/download the alpha mattes.
-
-### Portrait Video Matting
-We provide two real-time portrait video matting demos based on WebCam. When using the demo, you can move the WebCam around at will.
-If you have an Ubuntu system, we recommend you try the [offline demo](demo/video_matting/webcam) to get a higher *fps*. Otherwise, you can access the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
-We also provide an [offline demo](demo/video_matting/custom) that allows you to process custom videos.
-
-
-## Community
-
-We share some cool applications/extensions of MODNet built by the community.
-
-- **WebGUI for Portrait Image Matting**
-You can try [this WebGUI](https://www.gradio.app/hub/aliabd/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait image matting from your browser without code!
-
-- **Colab Demo of Bokeh (Blur Background)**
-You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background based on MODNet!
-
-- **ONNX Version of MODNet**
-You can convert the pre-trained MODNet to an ONNX model by using [this code](onnx) (provided by [@manthan3C273](https://github.com/manthan3C273)). You can also try [this Colab demo](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing) for MODNet image matting (ONNX version).
-
-- **TorchScript Version of MODNet**
-You can convert the pre-trained MODNet to a TorchScript model by using [this code](torchscript) (provided by [@yarkable](https://github.com/yarkable)).
-
-- **TensorRT Version of MODNet**
-You can access [this GitHub repository](https://github.com/jkjung-avt/tensorrt_demos) to try the TensorRT version of MODNet (provided by [@jkjung-avt](https://github.com/jkjung-avt)).
-
-There are some resources about MODNet from the community.
-- [Video from What's AI YouTube Channel](https://youtu.be/rUo0wuVyefU)
-- [Article from Louis Bouchard's Blog](https://www.louisbouchard.ai/remove-background/)
-
-
-## Code
-We provide the [code](src/trainer.py) of the MODNet training iteration, including:
-- **Supervised Training**: train MODNet on a labeled matting dataset
-- **SOC Adaptation**: adapt a trained MODNet to an unlabeled dataset
-
-In the code comments, we provide examples for using the functions.
-
-
-## PPM Benchmark
-The PPM benchmark is released in a separate repository: [PPM](https://github.com/ZHKKKe/PPM).
-
-
-## License
-The code, models, and demos in this repository (excluding GIF files under the folder `doc/gif`) are released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
-
-
-## Acknowledgement
-- We thank
-  [@yzhou0919](https://github.com/yzhou0919), [@eyaler](https://github.com/eyaler), [@manthan3C273](https://github.com/manthan3C273), [@yarkable](https://github.com/yarkable), [@jkjung-avt](https://github.com/jkjung-avt), [@manzke](https://github.com/manzke),
-  [the Gradio team](https://github.com/gradio-app/gradio), [What's AI YouTube Channel](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg), [Louis Bouchard's Blog](https://www.louisbouchard.ai),
-  for their contributions to this repository or their cool applications/extensions/resources of MODNet.
-
-
-## Citation
-If this work helps your research, please consider citing:
-
-```bibtex
-@InProceedings{MODNet,
-  author = {Zhanghan Ke and Jiayu Sun and Kaican Li and Qiong Yan and Rynson W.H. Lau},
-  title = {MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition},
-  booktitle = {AAAI},
-  year = {2022},
-}
+# 6. Model inference
+python src/infer.py
 ```
-
-
-## Contact
-This repository is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
-For questions, please contact `kezhanghan@outlook.com`.
-
diff --git a/src/eval.py b/src/eval.py
index e9e5305..7106051 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -1,8 +1,8 @@
 import numpy as np
 from glob import glob
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
-from src.infer import predit_matte
+from infer import predit_matte
 import torch.nn as nn
 import torch
@@ -22,8 +22,8 @@ def cal_mse(pred, gt):
 def load_eval_dataset(dataset_root_dir='src/datasets/PPM-100'):
-    image_path = dataset_root_dir + '/image/*'
-    matte_path = dataset_root_dir + '/matte/*'
+    image_path = dataset_root_dir + '/val/fg/*'
+    matte_path = dataset_root_dir + '/val/alpha/*'
     image_file_name_list = glob(image_path)
     image_file_name_list = sorted(image_file_name_list)
     matte_file_name_list = glob(matte_path)
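
For reference, a minimal evaluation sketch over the new `val/fg` + `val/alpha` layout, wired to the `cal_mse` helper shown above. It assumes `predit_matte` returns the alpha matte as a float array in [0, 1] at the input resolution, and reuses the weights-only checkpoint that trainer.py (below) writes:

```python
import numpy as np
import torch
from glob import glob
from PIL import Image

from models.modnet import MODNet
from infer import predit_matte
from eval import cal_mse

# Load the weights-only checkpoint saved by trainer.py (DataParallel keys).
modnet = torch.nn.DataParallel(MODNet())
modnet.load_state_dict(torch.load(
    'pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt',
    map_location='cpu'))
modnet.eval()

# Pair images with their ground-truth alphas by sorted filename.
img_paths = sorted(glob('src/datasets/PPM-100/val/fg/*'))
gt_paths = sorted(glob('src/datasets/PPM-100/val/alpha/*'))

scores = []
for img_pth, gt_pth in zip(img_paths, gt_paths):
    pred = predit_matte(modnet, Image.open(img_pth))
    gt = np.asarray(Image.open(gt_pth).convert('L'), dtype=np.float32) / 255.0
    scores.append(cal_mse(pred, gt))
print(f'mean MSE over {len(scores)} images: {np.mean(scores):.6f}')
```
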
diff --git a/src/infer.py b/src/infer.py
index 1fe99cc..116807b 100644
--- a/src/infer.py
+++ b/src/infer.py
@@ -1,4 +1,4 @@
-from src.models.modnet import MODNet
+from models.modnet import MODNet
 from PIL import Image
 import numpy as np
 from torchvision import transforms
@@ -76,7 +76,7 @@ if __name__ == '__main__':
     weights = torch.load(ckp_pth, map_location=torch.device('cpu'))
     modnet.load_state_dict(weights)
-    pth = 'src/datasets/PPM-100/image/13179159164_1a4ae8d085_o.jpg'
+    pth = 'src/datasets/PPM-100/val/fg/5588688353_3426d4b5d9_o.jpg'
     img = Image.open(pth)
     matte = predit_matte(modnet, img)
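
A quick way to inspect the matte computed in this `__main__` block is to save it and composite the portrait onto a plain background. A sketch, again assuming `matte` is a float array in [0, 1] matching the image size:

```python
# Continues from the __main__ block above: img is the input PIL image,
# matte is the alpha predicted by predit_matte (assumed float, [0, 1]).
import numpy as np
from PIL import Image

alpha = np.asarray(matte, dtype=np.float32)
Image.fromarray(np.uint8(alpha * 255)).save('matte.png')

# Alpha-blend the portrait over a white background.
fg = np.asarray(img.convert('RGB'), dtype=np.float32)
white = np.full_like(fg, 255.0)
comp = alpha[..., None] * fg + (1.0 - alpha[..., None]) * white
Image.fromarray(np.uint8(comp)).save('composite.png')
```
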
diff --git a/src/matting_dataset.py b/src/matting_dataset.py
index 9838f3f..4e404fb 100644
--- a/src/matting_dataset.py
+++ b/src/matting_dataset.py
@@ -13,8 +13,8 @@ class MattingDataset(Dataset):
     def __init__(self,
                  dataset_root_dir='src/datasets/PPM-100',
                  transform=None):
-        image_path = dataset_root_dir + '/image/*'
-        matte_path = dataset_root_dir + '/matte/*'
+        image_path = dataset_root_dir + '/train/fg/*'
+        matte_path = dataset_root_dir + '/train/alpha/*'
         image_file_name_list = glob(image_path)
         matte_file_name_list = glob(matte_path)
@@ -75,7 +75,7 @@ class GenTrimap(object):
         k_size = random.choice(range(2, 5))
         iterations = np.random.randint(5, 15)
         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
-                                           (k_size, k_size))
+                                           (k_size, k_size))  # kernel shape alternatives: cv2.MORPH_RECT, cv2.MORPH_CROSS
         dilated = cv2.dilate(matte, kernel, iterations=iterations)
         eroded = cv2.erode(matte, kernel, iterations=iterations)
@@ -141,7 +141,7 @@ class ToTrainArray(object):
 if __name__ == '__main__':
     # test GenTrimap.gen_trimap
-    matte = Image.open('src/datasets/PPM-100/matte/6146816_556eaff97f_o.jpg')
+    matte = Image.open('src/datasets/PPM-100/train/alpha/6146816_556eaff97f_o.jpg')
     trimap1 = GenTrimap().gen_trimap(matte)
     trimap1 = np.array(trimap1) * 255
     trimap1 = np.uint8(trimap1)
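
For context on the dilate/erode pair above: `gen_trimap` typically fuses the two maps into a three-value trimap, with 1 for pixels that survive erosion (definite foreground), 0 for pixels outside the dilation (definite background), and 0.5 for the unknown band in between. A sketch of that fusion, assuming the matte is normalized to [0, 1]; the exact thresholds in this repo may differ:

```python
import numpy as np

def fuse_trimap(dilated: np.ndarray, eroded: np.ndarray) -> np.ndarray:
    # Start with the unknown band everywhere dilation and erosion disagree.
    trimap = np.full_like(dilated, 0.5, dtype=np.float32)
    trimap[eroded > 0.95] = 1.0    # survives erosion: definite foreground
    trimap[dilated < 0.05] = 0.0   # outside dilation: definite background
    return trimap
```
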
diff --git a/src/setting.py b/src/setting.py
new file mode 100644
index 0000000..9e2070d
--- /dev/null
+++ b/src/setting.py
@@ -0,0 +1,9 @@
+BS = 16 # BATCH SIZE
+LR = 0.01 # LEARNING RATE
+EPOCHS = 4 # TOTAL EPOCHS
+
+SEMANTIC_SCALE = 10.0
+DETAIL_SCALE = 10.0
+MATTE_SCALE = 1.0
+
+SAVE_EPOCH_STEP = 3
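
One interaction worth checking with these defaults: trainer.py builds its scheduler as `StepLR(optimizer, step_size=int(0.25 * EPOCHS), gamma=0.1)`, so with `EPOCHS = 4` the step size is 1 and the learning rate decays 10x after every epoch. A quick sanity check of the resulting schedule:

```python
# Prints the LR used in each epoch under the settings above:
# 0.01, 0.001, 0.0001, 1e-05 -- a very aggressive decay for a 4-epoch run.
EPOCHS, LR, GAMMA = 4, 0.01, 0.1
step_size = int(0.25 * EPOCHS)  # = 1, as computed in trainer.py
for epoch in range(EPOCHS):
    print(f'epoch {epoch}: lr = {LR * GAMMA ** (epoch // step_size):g}')
```
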
diff --git a/src/trainer.py b/src/trainer.py
index aa71ef0..becb4af 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -307,6 +307,7 @@ if __name__ == '__main__':
     from torchvision import transforms
     from torch.utils.data import DataLoader
     from models.modnet import MODNet
+    from setting import BS, LR, EPOCHS, SEMANTIC_SCALE, DETAIL_SCALE, MATTE_SCALE, SAVE_EPOCH_STEP
     transform = transforms.Compose([
         Rescale(512),
         GenTrimap(),
@@ -317,41 +318,39 @@ if __name__ == '__main__':
     ])
     mattingDataset = MattingDataset(transform=transform)
-    bs = 4  # batch size
-    lr = 0.01  # learn rate
-    epochs = 40  # total epochs
-
     modnet = torch.nn.DataParallel(MODNet())  # .cuda()
     if torch.cuda.is_available():
         modnet = modnet.cuda()
-    optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
+    optimizer = torch.optim.SGD(modnet.parameters(), lr=LR, momentum=0.9)
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * EPOCHS), gamma=0.1)
     dataloader = DataLoader(mattingDataset,
-                            batch_size=bs,
+                            batch_size=BS,
                             shuffle=True)
-    for epoch in range(0, epochs):
+    for epoch in range(0, EPOCHS):
+        print(f'epoch: {epoch}/{EPOCHS-1}')
         for idx, (image, trimap, gt_matte) in enumerate(dataloader):
             semantic_loss, detail_loss, matte_loss = \
-                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
-            break
-        if epoch % 4 == 0 and epoch > 1:
+                supervised_training_iter(modnet, optimizer, image, trimap, gt_matte,
+                                         semantic_scale=SEMANTIC_SCALE,
+                                         detail_scale=DETAIL_SCALE,
+                                         matte_scale=MATTE_SCALE)
+            print(f'{(idx+1) * BS}/{len(mattingDataset)} --- '
+                  f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}\r',
+                  end='')
+        lr_scheduler.step()
+        # save an intermediate training checkpoint
+        if epoch % SAVE_EPOCH_STEP == 0 and epoch > 1:
             torch.save({
                 'epoch': epoch,
                 'model_state_dict': modnet.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
-            }, f'pretrained/modnet_custom_portrait_matting_{epoch+1}.ckpt')
-        lr_scheduler.step()
-        print(f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
-        if epoch == 4:
-            break
+            }, f'pretrained/modnet_custom_portrait_matting_{epoch}_th.ckpt')
+        print(f'{len(mattingDataset)}/{len(mattingDataset)} --- '
+              f'semantic_loss: {semantic_loss:f}, detail_loss: {detail_loss:f}, matte_loss: {matte_loss:f}')
-    torch.save({
-        'epoch': epochs,
-        'model_state_dict': modnet.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
-        'loss': {'semantic_loss': semantic_loss, 'detail_loss': detail_loss, 'matte_loss': matte_loss},
-    }, f'pretrained/modnet_custom_portrait_matting_last_epoch.ckpt')
+    # save only the model weights
+    torch.save(modnet.state_dict(), 'pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt')
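
Since the final checkpoint now stores only the `state_dict()` of the `DataParallel`-wrapped model, its keys carry a `module.` prefix; the simplest way to reload it is to wrap the model the same way before loading. A minimal sketch:

```python
import torch
from models.modnet import MODNet

# Wrap in DataParallel so the 'module.'-prefixed keys line up.
modnet = torch.nn.DataParallel(MODNet())
state = torch.load('pretrained/modnet_custom_portrait_matting_last_epoch_weight.ckpt',
                   map_location='cpu')
modnet.load_state_dict(state)
modnet.eval()
```

The intermediate checkpoints are full dicts (epoch, model and optimizer state, last losses), so resuming training would instead index into them, e.g. `ckpt['model_state_dict']`.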