mirror of https://github.com/ZHKKKe/MODNet.git
upload portrait image/video matting demos of MODNet
parent 7358e486b3
commit 0bdc3d1ddf
.gitignore
@@ -0,0 +1,96 @@
# Temporary directories and files
*.ckpt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# Project files
.vscode

README.md
@@ -4,27 +4,44 @@
<p align="center">
  <a href="https://arxiv.org/pdf/2011.11961.pdf">Arxiv Preprint</a> |
  <a href="https://youtu.be/PqJ3BRHX3Lc">Supplementary Video</a> |
  <a href="https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing">Video Matting Demo</a> :fire: |
  <a href="https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing">Image Matting Demo</a> :fire:
</p>

<div align="center">This is the official project of our paper <b>Is a Green Screen Really Necessary for Real-Time Portrait Matting?</b></div>
<div align="center">MODNet is a <b>trimap-free</b> model for portrait matting in <b>real time</b> (on a single GPU).</div>
<div align="center">Our amazing demo, code, pre-trained model, and validation benchmark are coming soon!</div>

---

## Announcement
I have received some requests for access to our code. I am sorry that we need some time to get everything ready, since this repository is currently supported by Zhanghan Ke alone. Our plans for the next few months are:
- We will publish an online image/video matting demo along with the pre-trained model **in these two weeks (approximately Dec. 7, 2020 to Dec. 18, 2020)**.
- We then plan to release the code of supervised training and unsupervised SOC in Jan. 2021.
- We finally plan to open-source the PPM-100 validation benchmark in Feb. 2021.

We look forward to your continued interest in this project. Thanks.

## News
- [Dec 10 2020] Release [Video Matting Demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing) and [Image Matting Demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).
- [Nov 24 2020] Release [Arxiv Preprint](https://arxiv.org/pdf/2011.11961.pdf) and [Supplementary Video](https://youtu.be/PqJ3BRHX3Lc).

## Video Matting Demo
We provide two real-time portrait video matting demos based on WebCam.
If you have an Ubuntu system, we recommend trying the [offline demo](demo/video_matting/README.md) to get a higher *fps*. Otherwise, you can access the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).

## Image Matting Demo
We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
It allows you to upload portrait images and predict/visualize/download the alpha mattes.

<img src="doc/gif/image_matting_demo.gif">

## TO DO
- Release training code (scheduled in **Jan. 2021**)
- Release PPM-100 validation benchmark (scheduled in **Feb. 2021**)

## Acknowledgement
We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support of this project.

## Citation
If this work helps your research, please consider citing:

@@ -37,3 +54,7 @@ If this work helps your research, please consider citing:
year = {2020},
}
```

## Contact
This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
If you have any questions, please feel free to contact `kezhanghan@outlook.com`.

demo/image_matting/README.md
@@ -0,0 +1,2 @@
## MODNet - Portrait Image Matting Demo
Please try the MODNet portrait image matting demo through our [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).

@@ -0,0 +1,99 @@
import os
import argparse
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from src.models.modnet import MODNet


if __name__ == '__main__':
    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, help='path of input images')
    parser.add_argument('--output-path', type=str, help='path of output images')
    parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet')
    args = parser.parse_args()

    # check input arguments
    if not os.path.exists(args.input_path):
        print('Cannot find input path: {0}'.format(args.input_path))
        exit()
    if not os.path.exists(args.output_path):
        print('Cannot find output path: {0}'.format(args.output_path))
        exit()
    if not os.path.exists(args.ckpt_path):
        print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
        exit()

    # define hyper-parameters
    ref_size = 512

    # define image to tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet).cuda()
    modnet.load_state_dict(torch.load(args.ckpt_path))
    modnet.eval()

    # inference images
    im_names = os.listdir(args.input_path)
    for im_name in im_names:
        print('Process image: {0}'.format(im_name))

        # read image
        im = Image.open(os.path.join(args.input_path, im_name))

        # unify image channels to 3
        im = np.asarray(im)
        if len(im.shape) == 2:
            im = im[:, :, None]
        if im.shape[2] == 1:
            im = np.repeat(im, 3, axis=2)
        elif im.shape[2] == 4:
            im = im[:, :, 0:3]

        # convert image to PyTorch tensor
        im = Image.fromarray(im)
        im = im_transform(im)

        # add mini-batch dim
        im = im[None, :, :, :]

        # resize image for input (keep the aspect ratio, short side close to ref_size)
        im_b, im_c, im_h, im_w = im.shape
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            else:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        # both sides must be multiples of 32 because the encoder downsamples by up to 32x
        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32
        im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

        # inference (only the matte is needed, so run in inference mode without gradients)
        with torch.no_grad():
            _, _, matte = modnet(im.cuda(), inference=True)

        # resize and save matte
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
        matte = matte[0][0].data.cpu().numpy()
        matte_name = im_name.split('.')[0] + '.png'
        Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name))

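For reference, a typical invocation of this script from the repository root might look like the following (the module path and checkpoint filename here are assumptions; adjust them to your local setup, and note that the output directory must already exist):
```
python -m demo.image_matting.inference \
    --input-path input/ \
    --output-path output/ \
    --ckpt-path pretrained/modnet_photographic_portrait_matting.ckpt
```
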
demo/video_matting/README.md
@@ -0,0 +1,50 @@
## MODNet - WebCam-Based Portrait Video Matting Demo
This is a MODNet portrait video matting demo based on WebCam. It calls your local WebCam and displays the matting results in real time.

### Requirements
The basic requirements for this demo are:
- Ubuntu System
- WebCam
- Nvidia GPU with CUDA
- Python 3+

**NOTE**: If your device does not satisfy the above conditions, please try our [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).

### Introduction
We use ~400 unlabeled video clips (divided into ~50,000 frames) downloaded from the internet to perform SOC, adapting MODNet to the video domain. Nonetheless, due to insufficient labeled training data (~3k labeled foregrounds), our model may still make errors in portrait semantic estimation under challenging scenes. In addition, this demo does not currently support the OFD trick; it will be provided soon (a sketch of the idea follows).
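For context, OFD (One-Frame Delay) is the paper's post-processing trick for suppressing flicker: a pixel in frame *t* is treated as a flicker error when the mattes of frames *t-1* and *t+1* agree but frame *t* disagrees with both, and is then replaced by their average. A minimal NumPy sketch of this idea (our reading of the paper, not the official implementation; the threshold `eps` is an assumption):
```
import numpy as np

def ofd_smooth(prev_matte, cur_matte, next_matte, eps=0.05):
    # flickering pixel: the neighboring frames agree,
    # but the current frame deviates from both
    stable = np.abs(prev_matte - next_matte) <= eps
    flicker = (np.abs(cur_matte - prev_matte) > eps) & (np.abs(cur_matte - next_matte) > eps)
    fix = stable & flicker
    smoothed = cur_matte.copy()
    smoothed[fix] = (prev_matte[fix] + next_matte[fix]) / 2.0
    return smoothed
```
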
For a better experience, please:
* make sure the portrait and the background are distinguishable, <i>i.e.</i>, not similar
* run in soft and bright ambient lighting
* do not stay too close to or too far from the WebCam
* do not move too fast

### Run Demo
We recommend creating a new conda virtual environment to run this demo, as follows:

1. Clone the MODNet repository:
```
git clone https://github.com/ZHKKKe/MODNet.git
cd MODNet
```

2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/`.

3. Create a conda virtual environment named `modnet-webcam` and activate it:
```
conda create -n modnet-webcam python=3.6
source activate modnet-webcam
```

4. Install the required Python dependencies (here we use PyTorch==1.0.0):
```
pip install -r demo/video_matting/requirements.txt
```

5. Execute the main code:
```
python -m demo.video_matting.webcam
```

demo/video_matting/requirements.txt
@@ -0,0 +1,5 @@
numpy
Pillow
opencv-python
torch==1.0.0
torchvision

demo/video_matting/webcam.py
@@ -0,0 +1,56 @@
import cv2
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torchvision.transforms as transforms

from src.models.modnet import MODNet


torch_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

print('Load pre-trained MODNet...')
pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()
modnet.load_state_dict(torch.load(pretrained_ckpt))
modnet.eval()

print('Init WebCam...')
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

print('Start matting...')
while True:
    _, frame_np = cap.read()
    frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
    # resize and center-crop to 672 x 512 (both sides are multiples of 32,
    # as required by the 32x downsampling of the encoder)
    frame_np = cv2.resize(frame_np, (910, 512), interpolation=cv2.INTER_AREA)
    frame_np = frame_np[:, 120:792, :]
    frame_np = cv2.flip(frame_np, 1)

    frame_PIL = Image.fromarray(frame_np)
    frame_tensor = torch_transforms(frame_PIL)
    frame_tensor = frame_tensor[None, :, :, :].cuda()

    with torch.no_grad():
        _, _, matte_tensor = modnet(frame_tensor, inference=True)

    # composite the foreground over a white background for display
    matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
    matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0)
    fg_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0)
    view_np = np.uint8(np.concatenate((frame_np, fg_np), axis=1))
    view_np = cv2.cvtColor(view_np, cv2.COLOR_RGB2BGR)

    cv2.imshow('MODNet - WebCam [Press \'Q\' To Exit]', view_np)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

print('Exit...')

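As a side note, the loop above composites the foreground onto a plain white background (`np.full(frame_np.shape, 255.0)`). To composite onto a custom background image instead, the blend line could be replaced as in this illustrative sketch (`bg.jpg` is a placeholder path):
```
bg_np = cv2.cvtColor(cv2.imread('bg.jpg'), cv2.COLOR_BGR2RGB)
bg_np = cv2.resize(bg_np, (frame_np.shape[1], frame_np.shape[0])).astype(np.float32)
fg_np = matte_np * frame_np + (1 - matte_np) * bg_np
```
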
doc/gif/image_matting_demo.gif
(binary file not shown; size 9.2 MiB)

pretrained/README.md
@@ -0,0 +1,2 @@
## MODNet - Pre-Trained Models
This folder is used to save the official pre-trained models of MODNet. You can download them from this [link](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing).

src/models/backbones/__init__.py
@@ -0,0 +1,10 @@
from .wrapper import *


#------------------------------------------------------------------------------
# Replaceable Backbones
#------------------------------------------------------------------------------

SUPPORTED_BACKBONES = {
    'mobilenetv2': MobileNetV2Backbone,
}

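For illustration, the registry maps an architecture name to a backbone class, and MODNet resolves the backbone by this name (see `src/models/modnet.py` below). Run from the repository root:
```
from src.models.backbones import SUPPORTED_BACKBONES

backbone = SUPPORTED_BACKBONES['mobilenetv2'](in_channels=3)
print(backbone.enc_channels)  # [16, 24, 32, 96, 1280]
```
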
src/models/backbones/mobilenetv2.py
@@ -0,0 +1,185 @@
""" This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch """

import math
from functools import reduce

import torch
from torch import nn


#------------------------------------------------------------------------------
# Useful functions
#------------------------------------------------------------------------------

def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # make sure that rounding down does not go down by more than 10%
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )

#------------------------------------------------------------------------------
# Class of Inverted Residual block
#------------------------------------------------------------------------------

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expansion, dilation=1):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

#------------------------------------------------------------------------------
# Class of MobileNetV2
#------------------------------------------------------------------------------

class MobileNetV2(nn.Module):
    def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
        super(MobileNetV2, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s
            [1        , 16, 1, 1],
            [expansion, 24, 2, 2],
            [expansion, 32, 3, 2],
            [expansion, 64, 4, 2],
            [expansion, 96, 3, 1],
            [expansion, 160, 3, 2],
            [expansion, 320, 1, 1],
        ]

        # building first layer
        input_channel = _make_divisible(input_channel * alpha, 8)
        self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
        self.features = [conv_bn(self.in_channels, input_channel, 2)]

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(int(c * alpha), 8)
            for i in range(n):
                if i == 0:
                    self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
                else:
                    self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
                input_channel = output_channel

        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))

        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        if self.num_classes is not None:
            self.classifier = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(self.last_channel, num_classes),
            )

        # initialize weights
        self._init_weights()

    def forward(self, x, feature_names=None):
        # Stage1
        x = reduce(lambda x, n: self.features[n](x), list(range(0, 2)), x)
        # Stage2
        x = reduce(lambda x, n: self.features[n](x), list(range(2, 4)), x)
        # Stage3
        x = reduce(lambda x, n: self.features[n](x), list(range(4, 7)), x)
        # Stage4
        x = reduce(lambda x, n: self.features[n](x), list(range(7, 14)), x)
        # Stage5
        x = reduce(lambda x, n: self.features[n](x), list(range(14, 19)), x)

        # Classification
        if self.num_classes is not None:
            x = x.mean(dim=(2, 3))
            x = self.classifier(x)

        # Output
        return x

    def _load_pretrained_model(self, pretrained_file):
        pretrain_dict = torch.load(pretrained_file, map_location='cpu')
        model_dict = {}
        state_dict = self.state_dict()
        print("[MobileNetV2] Loading pretrained model...")
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
            else:
                print(k, "is ignored")
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

src/models/backbones/wrapper.py
@@ -0,0 +1,59 @@
import os
from functools import reduce

import torch
import torch.nn as nn

from .mobilenetv2 import MobileNetV2


class BaseBackbone(nn.Module):
    """ Superclass of Replaceable Backbone Model for Semantic Estimation
    """

    def __init__(self, in_channels):
        super(BaseBackbone, self).__init__()
        self.in_channels = in_channels

        self.model = None
        self.enc_channels = []

    def forward(self, x):
        raise NotImplementedError

    def load_pretrained_ckpt(self):
        raise NotImplementedError


class MobileNetV2Backbone(BaseBackbone):
    """ MobileNetV2 Backbone
    """

    def __init__(self, in_channels):
        super(MobileNetV2Backbone, self).__init__(in_channels)

        self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
        # channels of the encoder features at strides 2/4/8/16/32
        self.enc_channels = [16, 24, 32, 96, 1280]

    def forward(self, x):
        x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
        enc2x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
        enc4x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
        enc8x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
        enc16x = x
        x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
        enc32x = x
        return [enc2x, enc4x, enc8x, enc16x, enc32x]

    def load_pretrained_ckpt(self):
        # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
        ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
        if not os.path.exists(ckpt_path):
            print('cannot find the pretrained mobilenetv2 backbone')
            exit()

        ckpt = torch.load(ckpt_path)
        self.model.load_state_dict(ckpt)

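As a quick sanity check, the backbone returns a five-level feature pyramid whose channel counts match `enc_channels`. A minimal sketch (shapes assume a 512x512 input, run from the repository root):
```
import torch
from src.models.backbones.wrapper import MobileNetV2Backbone

backbone = MobileNetV2Backbone(in_channels=3)
feats = backbone(torch.randn(1, 3, 512, 512))
for f, c in zip(feats, backbone.enc_channels):
    assert f.shape[1] == c
# spatial sizes: 256, 128, 64, 32, 16 (strides 2/4/8/16/32)
```
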
src/models/modnet.py
@@ -0,0 +1,255 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbones import SUPPORTED_BACKBONES


#------------------------------------------------------------------------------
# MODNet Basic Modules
#------------------------------------------------------------------------------

class IBNorm(nn.Module):
    """ Combine Instance Norm and Batch Norm into One Layer
    """

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        self.bnorm_channels = int(in_channels / 2)
        self.inorm_channels = in_channels - self.bnorm_channels

        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)

    def forward(self, x):
        # batch-normalize the first half of the channels and
        # instance-normalize the remaining half
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())

        return torch.cat((bn_x, in_x), 1)

class Conv2dIBNormRelu(nn.Module):
    """ Convolution + IBNorm + ReLU
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 with_ibn=True, with_relu=True):
        super(Conv2dIBNormRelu, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation,
                      groups=groups, bias=bias)
        ]

        if with_ibn:
            layers.append(IBNorm(out_channels))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

class SEBlock(nn.Module):
    """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels, out_channels, reduction=1):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels // reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels // reduction), out_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # squeeze to per-channel statistics, excite to per-channel weights
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)

        return x * w.expand_as(x)


#------------------------------------------------------------------------------
# MODNet Branches
#------------------------------------------------------------------------------

class LRBranch(nn.Module):
    """ Low Resolution Branch of MODNet
    """

    def __init__(self, backbone):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels

        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)

    def forward(self, img, inference):
        enc_features = self.backbone.forward(img)
        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]

        # upsample the 1/32 encoder features back to 1/8 resolution
        enc32x = self.se_block(enc32x)
        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
        lr16x = self.conv_lr16x(lr16x)
        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
        lr8x = self.conv_lr8x(lr8x)

        # the coarse semantic prediction is only needed for the training losses
        pred_semantic = None
        if not inference:
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)

        return pred_semantic, lr8x, [enc2x, enc4x]

class HRBranch(nn.Module):
    """ High Resolution Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(HRBranch, self).__init__()

        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
        self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)

        self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
        self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)

        self.conv_hr4x = nn.Sequential(
            Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr2x = nn.Sequential(
            Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
        )

        self.conv_hr = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
            Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, enc2x, enc4x, lr8x, inference):
        img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
        img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)

        enc2x = self.tohr_enc2x(enc2x)
        hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))

        enc4x = self.tohr_enc4x(enc4x)
        hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))

        # fuse the low-resolution semantics (lr8x) into the high-resolution path
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))

        hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
        hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))

        # the boundary detail prediction is only needed for the training losses
        pred_detail = None
        if not inference:
            hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
            hr = self.conv_hr(torch.cat((hr, img), dim=1))
            pred_detail = torch.sigmoid(hr)

        return pred_detail, hr2x

class FusionBranch(nn.Module):
    """ Fusion Branch of MODNet
    """

    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)

        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
            Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
        )

    def forward(self, img, lr8x, hr2x):
        lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
        lr4x = self.conv_lr4x(lr4x)
        lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)

        # fuse semantics and details into the final alpha matte
        f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
        f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
        f = self.conv_f(torch.cat((f, img), dim=1))
        pred_matte = torch.sigmoid(f)

        return pred_matte

#------------------------------------------------------------------------------
# MODNet
#------------------------------------------------------------------------------

class MODNet(nn.Module):
    """ Architecture of MODNet
    """

    def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
        super(MODNet, self).__init__()

        self.in_channels = in_channels
        self.hr_channels = hr_channels
        self.backbone_arch = backbone_arch
        self.backbone_pretrained = backbone_pretrained

        self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)

        self.lr_branch = LRBranch(self.backbone)
        self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
        self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                self._init_conv(m)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()

    def forward(self, img, inference):
        # pred_semantic and pred_detail are None when inference is True
        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
        pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte

    def freeze_norm(self):
        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d)
        for m in self.modules():
            if isinstance(m, norm_types):
                m.eval()

    def _init_conv(self, conv):
        nn.init.kaiming_uniform_(
            conv.weight, a=0, mode='fan_in', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)

    def _init_norm(self, norm):
        if norm.weight is not None:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)

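For reference, a quick shape check of the three outputs (a minimal sketch; sizes assume a 512x512 RGB input, run from the repository root):
```
import torch
from src.models.modnet import MODNet

model = MODNet(backbone_pretrained=False)
x = torch.randn(1, 3, 512, 512)
pred_semantic, pred_detail, pred_matte = model(x, inference=False)
# pred_semantic: (1, 1, 32, 32)   - coarse semantics at 1/16 resolution
# pred_detail:   (1, 1, 512, 512) - boundary detail map
# pred_matte:    (1, 1, 512, 512) - final alpha matte
```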