upload portrait image/video matting demos of MODNet

pull/21/head
ZHKKKe 2020-12-07 13:55:12 +08:00 committed by kezhanghan
parent 7358e486b3
commit 0bdc3d1ddf
15 changed files with 849 additions and 9 deletions

.gitignore vendored Normal file

@@ -0,0 +1,96 @@
# Temporary directories and files
*.ckpt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
# Project files
.vscode


@@ -4,27 +4,44 @@
<p align="center">
<a href="https://arxiv.org/pdf/2011.11961.pdf">Arxiv Preprint</a> |
<a href="https://youtu.be/PqJ3BRHX3Lc">Supplementary Video</a>
<a href="https://youtu.be/PqJ3BRHX3Lc">Supplementary Video</a> |
<a href="https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing">Video Matting Demo</a> :fire: |
<a href="https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing">Image Matting Demo</a> :fire:
</p>
<div align="center">This is the official project of our paper <b>Is a Green Screen Really Necessary for Real-Time Portrait Matting?</b></div>
<div align="center">MODNet is a <b>trimap-free</b> model for portrait matting in <b>real time</b> (on a single GPU).</div>
<div align="center">Our amazing demo, code, pre-trained model, and validation benchmark are coming soon!</div>
---
## Announcement
I have received some requests for access to our code. I am sorry that we need some time to get everything ready, since this repository is currently maintained by Zhanghan Ke alone. Our plans for the next few months are:
- We will publish an online image/video matting demo along with the pre-trained model **within these two weeks (approximately Dec. 7, 2020 to Dec. 18, 2020)**.
- We then plan to release the code for supervised training and unsupervised SOC in Jan. 2021.
- We finally plan to open-source the PPM-100 validation benchmark in Feb. 2021.
We look forward to your continued interest in this project. Thanks.
## News
- [Dec 10 2020] Release [Video Matting Demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing) and [Image Matting Demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).
- [Nov 24 2020] Release [Arxiv Preprint](https://arxiv.org/pdf/2011.11961.pdf) and [Supplementary Video](https://youtu.be/PqJ3BRHX3Lc).
## Video Matting Demo
We provide two real-time portrait video matting demos based on WebCam.
If you have an Ubuntu system, we recommend trying the [offline demo](demo/video_matting/README.md) for a higher *fps*. Otherwise, you can use the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
## Image Matting Demo
We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
It allows you to upload portrait images and predict/visualize/download the alpha mattes.
<img src="doc/gif/image_matting_demo.gif">
## TO DO
- Release training code (scheduled in **Jan. 2021**)
- Release PPM-100 validation benchmark (scheduled in **Feb. 2021**)
## Acknowledgement
We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support of this project.
## Citation
If this work helps your research, please consider citing:
@@ -37,3 +54,7 @@ If this work helps your research, please consider citing:
year = {2020},
}
```
## Contact
This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
If you have any questions, please feel free to contact `kezhanghan@outlook.com`.


@@ -0,0 +1,2 @@
## MODNet - Portrait Image Matting Demo
Please try the MODNet portrait image matting demo via our [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).


@@ -0,0 +1,99 @@
import os
import sys
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from src.models.modnet import MODNet
if __name__ == '__main__':
# define cmd arguments
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help='path of input images')
parser.add_argument('--output-path', type=str, help='path of output images')
parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet')
args = parser.parse_args()
# check input arguments
if not os.path.exists(args.input_path):
print('Cannot find input path: {0}'.format(args.input_path))
exit()
if not os.path.exists(args.output_path):
print('Cannot find output path: {0}'.format(args.output_path))
exit()
if not os.path.exists(args.ckpt_path):
print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
exit()
# define hyper-parameters
ref_size = 512
# define image to tensor transform
im_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
)
# create MODNet and load the pre-trained ckpt
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()
modnet.load_state_dict(torch.load(args.ckpt_path))
modnet.eval()
# inference images
im_names = os.listdir(args.input_path)
for im_name in im_names:
print('Process image: {0}'.format(im_name))
# read image
im = Image.open(os.path.join(args.input_path, im_name))
# unify image channels to 3
im = np.asarray(im)
if len(im.shape) == 2:
im = im[:, :, None]
if im.shape[2] == 1:
im = np.repeat(im, 3, axis=2)
elif im.shape[2] == 4:
im = im[:, :, 0:3]
# convert image to PyTorch tensor
im = Image.fromarray(im)
im = im_transform(im)
# add mini-batch dim
im = im[None, :, :, :]
# resize image for input
im_b, im_c, im_h, im_w = im.shape
if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
if im_w >= im_h:
im_rh = ref_size
im_rw = int(im_w / im_h * ref_size)
elif im_w < im_h:
im_rw = ref_size
im_rh = int(im_h / im_w * ref_size)
else:
im_rh = im_h
im_rw = im_w
im_rw = im_rw - im_rw % 32
im_rh = im_rh - im_rh % 32
im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
        # inference (torch.no_grad avoids building the autograd graph;
        # inference=True skips the unused semantic/detail outputs)
        with torch.no_grad():
            _, _, matte = modnet(im.cuda(), inference=True)
# resize and save matte
matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
matte = matte[0][0].data.cpu().numpy()
matte_name = im_name.split('.')[0] + '.png'
Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name))
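The saved mattes can be reused outside this script. Below is a minimal sketch (not part of this PR; the file names are placeholders) that takes an input image and its predicted matte and writes a transparent foreground cut-out, using the matte as the alpha channel:

```python
import numpy as np
from PIL import Image

# placeholder paths; substitute files from your --input-path / --output-path
image = np.asarray(Image.open('input/portrait.jpg').convert('RGB'))
matte = np.asarray(Image.open('output/portrait.png').convert('L'))

# stack the matte as a fourth (alpha) channel and save a transparent cut-out
rgba = np.dstack((image, matte))
Image.fromarray(rgba, mode='RGBA').save('portrait_cutout.png')
```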


@@ -0,0 +1,50 @@
## MODNet - WebCam-Based Portrait Video Matting Demo
This is a MODNet portrait video matting demo based on WebCam. It accesses your local WebCam and displays the matting results in real time.
### Requirements
The basic requirements for this demo are:
- Ubuntu System
- WebCam
- Nvidia GPU with CUDA
- Python 3+
**NOTE**: If your device does not satisfy the above conditions, please try our [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
### Introduction
We use ~400 unlabeled video clips (~50,000 frames in total) downloaded from the internet to perform SOC and adapt MODNet to the video domain. Nonetheless, because the labeled training data is limited (~3k labeled foregrounds), our model may still make errors in portrait semantic estimation in challenging scenes. In addition, this demo does not yet support the OFD trick, which will be provided soon (a rough sketch of the idea is given after the tips below).
For a better experience, please:
* make sure the portrait and background are distinguishable, <i>i.e.</i>, are not similar
* run in soft and bright ambient lighting
* do not stand too close to or too far from the WebCam
* do not move too fast
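For readers who want to experiment before the official OFD support lands, here is a rough sketch of the one-frame-delay idea, roughly as described in the paper (an illustration only, not the official implementation): a pixel whose alpha value flickers in frame *t*, i.e. the values in frames *t-1* and *t+1* agree with each other but differ from frame *t*, is replaced by the average of its two neighbors. The threshold below is a placeholder.

```python
import numpy as np

def ofd_smooth(prev_matte, curr_matte, next_matte, tol=0.1):
    # flickering pixels: the two neighboring frames agree, but the current one deviates
    neighbors_close = np.abs(prev_matte - next_matte) < tol
    curr_deviates = (np.abs(curr_matte - prev_matte) > tol) & (np.abs(curr_matte - next_matte) > tol)
    flicker = neighbors_close & curr_deviates
    # replace flickering values by the average of the two neighboring frames
    smoothed = curr_matte.copy()
    smoothed[flicker] = 0.5 * (prev_matte[flicker] + next_matte[flicker])
    return smoothed
```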
### Run Demo
We recommend creating a new conda virtual environment to run this demo, as follows:
1. Clone the MODNet repository:
```
git clone https://github.com/ZHKKKe/MODNet.git
cd MODNet
```
2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/`.
3. Create a conda virtual environment named `modnet-webcam` and activate it:
```
conda create -n modnet-webcam python=3.6
source activate modnet-webcam
```
4. Install the required Python dependencies (here we use PyTorch 1.0.0):
```
pip install -r demo/video_matting/requirements.txt
```
5. Execute the main code:
```
python -m demo.video_matting.webcam
```


@@ -0,0 +1,5 @@
numpy
Pillow
opencv-python
torch == 1.0.0
torchvision


@@ -0,0 +1,56 @@
import cv2
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from src.models.modnet import MODNet
torch_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
print('Load pre-trained MODNet...')
pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()
modnet.load_state_dict(torch.load(pretrained_ckpt))
modnet.eval()
print('Init WebCam...')
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
print('Start matting...')
while(True):
_, frame_np = cap.read()
frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
    frame_np = cv2.resize(frame_np, (910, 512), interpolation=cv2.INTER_AREA)
frame_np = frame_np[:, 120:792, :]
frame_np = cv2.flip(frame_np, 1)
frame_PIL = Image.fromarray(frame_np)
frame_tensor = torch_transforms(frame_PIL)
frame_tensor = frame_tensor[None, :, :, :].cuda()
with torch.no_grad():
_, _, matte_tensor = modnet(frame_tensor, inference=True)
matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0)
fg_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0)
view_np = np.uint8(np.concatenate((frame_np, fg_np), axis=1))
view_np = cv2.cvtColor(view_np, cv2.COLOR_RGB2BGR)
cv2.imshow('MODNet - WebCam [Press \'Q\' To Exit]', view_np)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
# release the WebCam and close the preview window before exiting
cap.release()
cv2.destroyAllWindows()
print('Exit...')
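The loop above composites the foreground over a plain white canvas. Swapping in a custom background only changes the compositing line; a minimal sketch (the background path is a placeholder, and 672x512 matches the crop used in the loop above):

```python
import cv2
import numpy as np

# placeholder background image, resized once to the 672x512 crop used in the loop
bg_np = cv2.cvtColor(cv2.imread('background.jpg'), cv2.COLOR_BGR2RGB)
bg_np = cv2.resize(bg_np, (672, 512), interpolation=cv2.INTER_AREA).astype(np.float32)

# inside the while-loop, replace the white canvas with the background image:
# fg_np = matte_np * frame_np + (1 - matte_np) * bg_np
```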

Binary file not shown. (Size: 9.2 MiB)

pretrained/README.md Normal file

@@ -0,0 +1,2 @@
## MODNet - Pre-Trained Models
This folder is used to save the official pre-trained models of MODNet. You can download them from this [link](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing).
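For reference, a minimal sketch of loading one of the official checkpoints, following the demo scripts in this PR: the released weights appear to be saved from an `nn.DataParallel`-wrapped model, so the model is wrapped the same way before calling `load_state_dict`.

```python
import torch
import torch.nn as nn
from src.models.modnet import MODNet

modnet = nn.DataParallel(MODNet(backbone_pretrained=False))
state_dict = torch.load('./pretrained/modnet_webcam_portrait_matting.ckpt', map_location='cpu')
modnet.load_state_dict(state_dict)
modnet.eval()
```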

src/__init__.py Normal file

src/models/__init__.py Normal file


@@ -0,0 +1,10 @@
from .wrapper import *
#------------------------------------------------------------------------------
# Replaceable Backbones
#------------------------------------------------------------------------------
SUPPORTED_BACKBONES = {
'mobilenetv2': MobileNetV2Backbone,
}
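MODNet instantiates its encoder by looking up `backbone_arch` in this dictionary (see `src/models/modnet.py` below), so supporting a new encoder only requires registering another `BaseBackbone` subclass here. A minimal usage sketch:

```python
from src.models.backbones import SUPPORTED_BACKBONES

# build the encoder by name, as MODNet does internally (backbone_arch='mobilenetv2' by default)
backbone = SUPPORTED_BACKBONES['mobilenetv2'](in_channels=3)
print(backbone.enc_channels)  # [16, 24, 32, 96, 1280]
```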


@@ -0,0 +1,185 @@
""" This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch"""
import math
import json
from functools import reduce
import torch
from torch import nn
#------------------------------------------------------------------------------
# Useful functions
#------------------------------------------------------------------------------
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)
#------------------------------------------------------------------------------
# Class of Inverted Residual block
#------------------------------------------------------------------------------
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expansion, dilation=1):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = round(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
if expansion == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
#------------------------------------------------------------------------------
# Class of MobileNetV2
#------------------------------------------------------------------------------
class MobileNetV2(nn.Module):
def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
super(MobileNetV2, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
input_channel = 32
last_channel = 1280
interverted_residual_setting = [
# t, c, n, s
[1 , 16, 1, 1],
[expansion, 24, 2, 2],
[expansion, 32, 3, 2],
[expansion, 64, 4, 2],
[expansion, 96, 3, 1],
[expansion, 160, 3, 2],
[expansion, 320, 1, 1],
]
# building first layer
input_channel = _make_divisible(input_channel*alpha, 8)
self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel
self.features = [conv_bn(self.in_channels, input_channel, 2)]
# building inverted residual blocks
for t, c, n, s in interverted_residual_setting:
output_channel = _make_divisible(int(c*alpha), 8)
for i in range(n):
if i == 0:
self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
else:
self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
input_channel = output_channel
# building last several layers
self.features.append(conv_1x1_bn(input_channel, self.last_channel))
# make it nn.Sequential
self.features = nn.Sequential(*self.features)
# building classifier
if self.num_classes is not None:
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
# Initialize weights
self._init_weights()
def forward(self, x, feature_names=None):
# Stage1
x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x)
# Stage2
x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x)
# Stage3
x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x)
# Stage4
x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x)
# Stage5
x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x)
# Classification
if self.num_classes is not None:
x = x.mean(dim=(2,3))
x = self.classifier(x)
# Output
return x
def _load_pretrained_model(self, pretrained_file):
pretrain_dict = torch.load(pretrained_file, map_location='cpu')
model_dict = {}
state_dict = self.state_dict()
print("[MobileNetV2] Loading pretrained model...")
for k, v in pretrain_dict.items():
if k in state_dict:
model_dict[k] = v
else:
print(k, "is ignored")
state_dict.update(model_dict)
self.load_state_dict(state_dict)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
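As an illustration of the `_make_divisible` helper above (not part of the PR): channel counts are rounded to multiples of the divisor, and the result is bumped up one step whenever rounding down would lose more than 10%.

```python
from src.models.backbones.mobilenetv2 import _make_divisible

assert _make_divisible(32, 8) == 32          # already a multiple of 8: unchanged
assert _make_divisible(24 * 0.75, 8) == 24   # 18 -> 16 would lose more than 10%, so bump to 24
assert _make_divisible(91, 8) == 88          # 88 is within 10% of 91, so it is kept
```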


@@ -0,0 +1,59 @@
import os
from functools import reduce
import torch
import torch.nn as nn
from .mobilenetv2 import MobileNetV2
class BaseBackbone(nn.Module):
""" Superclass of Replaceable Backbone Model for Semantic Estimation
"""
def __init__(self, in_channels):
super(BaseBackbone, self).__init__()
self.in_channels = in_channels
self.model = None
self.enc_channels = []
def forward(self, x):
raise NotImplementedError
def load_pretrained_ckpt(self):
raise NotImplementedError
class MobileNetV2Backbone(BaseBackbone):
""" MobileNetV2 Backbone
"""
def __init__(self, in_channels):
super(MobileNetV2Backbone, self).__init__(in_channels)
self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
self.enc_channels = [16, 24, 32, 96, 1280]
def forward(self, x):
x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
enc2x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
enc4x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
enc8x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
enc16x = x
x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
enc32x = x
return [enc2x, enc4x, enc8x, enc16x, enc32x]
def load_pretrained_ckpt(self):
# the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
if not os.path.exists(ckpt_path):
print('cannot find the pretrained mobilenetv2 backbone')
exit()
ckpt = torch.load(ckpt_path)
self.model.load_state_dict(ckpt)
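As a quick sanity check (illustrative only, random weights), the five feature maps returned by `forward` are downsampled by factors of 2, 4, 8, 16, and 32 relative to the input, with the channel counts listed in `enc_channels`:

```python
import torch
from src.models.backbones.wrapper import MobileNetV2Backbone

backbone = MobileNetV2Backbone(in_channels=3)
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 512, 512))
print([tuple(f.shape) for f in feats])
# expected: (1, 16, 256, 256), (1, 24, 128, 128), (1, 32, 64, 64), (1, 96, 32, 32), (1, 1280, 16, 16)
```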

src/models/modnet.py Normal file

@@ -0,0 +1,255 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbones import SUPPORTED_BACKBONES
#------------------------------------------------------------------------------
# MODNet Basic Modules
#------------------------------------------------------------------------------
class IBNorm(nn.Module):
""" Combine Instance Norm and Batch Norm into One Layer
"""
def __init__(self, in_channels):
super(IBNorm, self).__init__()
        self.in_channels = in_channels
self.bnorm_channels = int(in_channels / 2)
self.inorm_channels = in_channels - self.bnorm_channels
self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
def forward(self, x):
bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
return torch.cat((bn_x, in_x), 1)
class Conv2dIBNormRelu(nn.Module):
""" Convolution + IBNorm + ReLu
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, groups=1, bias=True,
with_ibn=True, with_relu=True):
super(Conv2dIBNormRelu, self).__init__()
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
]
if with_ibn:
layers.append(IBNorm(out_channels))
if with_relu:
layers.append(nn.ReLU(inplace=True))
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class SEBlock(nn.Module):
""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
"""
def __init__(self, in_channels, out_channels, reduction=1):
super(SEBlock, self).__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, int(in_channels // reduction), bias=False),
nn.ReLU(inplace=True),
nn.Linear(int(in_channels // reduction), out_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
w = self.pool(x).view(b, c)
w = self.fc(w).view(b, c, 1, 1)
return x * w.expand_as(x)
#------------------------------------------------------------------------------
# MODNet Branches
#------------------------------------------------------------------------------
class LRBranch(nn.Module):
""" Low Resolution Branch of MODNet
"""
def __init__(self, backbone):
super(LRBranch, self).__init__()
enc_channels = backbone.enc_channels
self.backbone = backbone
self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)
def forward(self, img, inference):
enc_features = self.backbone.forward(img)
enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
enc32x = self.se_block(enc32x)
lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
lr16x = self.conv_lr16x(lr16x)
lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
lr8x = self.conv_lr8x(lr8x)
pred_semantic = None
if not inference:
lr = self.conv_lr(lr8x)
pred_semantic = torch.sigmoid(lr)
return pred_semantic, lr8x, [enc2x, enc4x]
class HRBranch(nn.Module):
""" High Resolution Branch of MODNet
"""
def __init__(self, hr_channels, enc_channels):
super(HRBranch, self).__init__()
self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
self.conv_hr4x = nn.Sequential(
Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
)
self.conv_hr2x = nn.Sequential(
Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
)
self.conv_hr = nn.Sequential(
Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
)
def forward(self, img, enc2x, enc4x, lr8x, inference):
img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)
enc2x = self.tohr_enc2x(enc2x)
hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
enc4x = self.tohr_enc4x(enc4x)
hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
pred_detail = None
if not inference:
hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
hr = self.conv_hr(torch.cat((hr, img), dim=1))
pred_detail = torch.sigmoid(hr)
return pred_detail, hr2x
class FusionBranch(nn.Module):
""" Fusion Branch of MODNet
"""
def __init__(self, hr_channels, enc_channels):
super(FusionBranch, self).__init__()
self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
self.conv_f = nn.Sequential(
Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
)
def forward(self, img, lr8x, hr2x):
lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
lr4x = self.conv_lr4x(lr4x)
lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
f = self.conv_f(torch.cat((f, img), dim=1))
pred_matte = torch.sigmoid(f)
return pred_matte
#------------------------------------------------------------------------------
# MODNet
#------------------------------------------------------------------------------
class MODNet(nn.Module):
""" Architecture of MODNet
"""
def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
super(MODNet, self).__init__()
self.in_channels = in_channels
self.hr_channels = hr_channels
self.backbone_arch = backbone_arch
self.backbone_pretrained = backbone_pretrained
self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
self.lr_branch = LRBranch(self.backbone)
self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
for m in self.modules():
if isinstance(m, nn.Conv2d):
self._init_conv(m)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
self._init_norm(m)
if self.backbone_pretrained:
self.backbone.load_pretrained_ckpt()
def forward(self, img, inference):
pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
pred_matte = self.f_branch(img, lr8x, hr2x)
return pred_semantic, pred_detail, pred_matte
def freeze_norm(self):
norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
for m in self.modules():
for n in norm_types:
if isinstance(m, n):
m.eval()
continue
def _init_conv(self, conv):
nn.init.kaiming_uniform_(
conv.weight, a=0, mode='fan_in', nonlinearity='relu')
if conv.bias is not None:
nn.init.constant_(conv.bias, 0)
def _init_norm(self, norm):
if norm.weight is not None:
nn.init.constant_(norm.weight, 1)
nn.init.constant_(norm.bias, 0)
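Finally, a small CPU smoke test of the full model (illustrative only, random weights). With `inference=True`, the semantic and detail predictions are skipped and returned as `None`; the matte is a single-channel map at the input resolution with values in [0, 1]:

```python
import torch
from src.models.modnet import MODNet

modnet = MODNet(backbone_pretrained=False)
modnet.eval()
with torch.no_grad():
    img = torch.randn(1, 3, 512, 512)  # height and width should be multiples of 32
    pred_semantic, pred_detail, pred_matte = modnet(img, inference=False)
print(pred_semantic.shape, pred_detail.shape, pred_matte.shape)
```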