diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d64f109 --- /dev/null +++ b/.gitignore @@ -0,0 +1,96 @@ +# Temporary directories and files +*.ckpt + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# IPython Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# dotenv +.env + +# virtualenv +venv/ +ENV/ + +# Spyder project settings +.spyderproject + +# Rope project settings +.ropeproject + + +# Project files +.vscode \ No newline at end of file diff --git a/README.md b/README.md index b011385..2eb8f98 100644 --- a/README.md +++ b/README.md @@ -4,27 +4,44 @@

Arxiv Preprint | - Supplementary Video + Supplementary Video | + Video Matting Demo :fire: | + Image Matting Demo :fire:

This is the official project of our paper Is a Green Screen Really Necessary for Real-Time Portrait Matting?
MODNet is a trimap-free model for portrait matting in real time (on a single GPU).
-
Our amazing demo, code, pre-trained model, and validation benchmark are coming soon!
 ---
-## Announcement
-I have received some requests for accessing our code. I am sorry that we need some time to get everything ready since this repository is now supported by Zhanghan Ke alone. Our plans in the next few months are:
-- We will publish an online image/video matting demo along with the pre-trained model **in these two weeks (approximately Dec. 7, 2020 to Dec. 18, 2020)**.
-- We then plan to release the code of supervised training and unsupervised SOC in Jan. 2021.
-- We finally plan to open source the PPM-100 validation benchmark in Feb. 2021.
-
-We look forward to your continued attention to this project. Thanks.
 ## News
+- [Dec 10 2020] Release [Video Matting Demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing) and [Image Matting Demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).
 - [Nov 24 2020] Release [Arxiv Preprint](https://arxiv.org/pdf/2011.11961.pdf) and [Supplementary Video](https://youtu.be/PqJ3BRHX3Lc).
+
+## Video Matting Demo
+We provide two real-time portrait video matting demos based on WebCam.
+If you have an Ubuntu system, we recommend trying the [offline demo](demo/video_matting/README.md) for a higher *fps*. Otherwise, you can use the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
+
+
+
+## Image Matting Demo
+We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
+It allows you to upload portrait images and to predict, visualize, and download the alpha mattes.
+
+
+
+
+## TO DO
+- Release training code (scheduled for **Jan. 2021**)
+- Release PPM-100 validation benchmark (scheduled for **Feb. 2021**)
+
+## Acknowledgement
+We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support of this project.
+
+
 ## Citation
 If this work helps your research, please consider to cite:
@@ -37,3 +54,7 @@ If this work helps your research, please consider to cite:
   year = {2020},
 }
 ```
+
+## Contact
+This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
+If you have any questions, please feel free to contact `kezhanghan@outlook.com`.
diff --git a/demo/image_matting/README.md b/demo/image_matting/README.md
new file mode 100644
index 0000000..6642651
--- /dev/null
+++ b/demo/image_matting/README.md
@@ -0,0 +1,2 @@
+## MODNet - Portrait Image Matting Demo
+Please try the MODNet portrait image matting demo through our [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).
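The mattes downloaded from the image matting demo are single-channel alpha maps, so compositing a portrait onto a new background is a one-line alpha blend (the same formula the webcam demo later in this diff uses to composite against a white background). A minimal sketch, assuming hypothetical file names `portrait.jpg`, `background.jpg`, and `portrait_matte.png`:

```python
import numpy as np
from PIL import Image

# hypothetical inputs: a portrait, a new background, and the matte downloaded from the demo
portrait = np.asarray(Image.open('portrait.jpg').convert('RGB'), dtype=np.float32)
matte = np.asarray(Image.open('portrait_matte.png').convert('L'), dtype=np.float32) / 255.0
background = Image.open('background.jpg').convert('RGB')

# match the background to the portrait size; PIL expects (width, height)
background = background.resize((portrait.shape[1], portrait.shape[0]))
background = np.asarray(background, dtype=np.float32)

# alpha blending: foreground * alpha + background * (1 - alpha)
matte = matte[:, :, None]  # H x W x 1, broadcast over the RGB channels
composite = portrait * matte + background * (1.0 - matte)
Image.fromarray(composite.astype(np.uint8)).save('composite.png')
```

Since the matte is predicted at the resolution of the input portrait, only the background needs resizing.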
diff --git a/demo/image_matting/inference.py b/demo/image_matting/inference.py new file mode 100644 index 0000000..2cdaff7 --- /dev/null +++ b/demo/image_matting/inference.py @@ -0,0 +1,99 @@ +import os +import sys +import argparse +import numpy as np +from PIL import Image + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms + +from src.models.modnet import MODNet + + +if __name__ == '__main__': + # define cmd arguments + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help='path of input images') + parser.add_argument('--output-path', type=str, help='path of output images') + parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet') + args = parser.parse_args() + + # check input arguments + if not os.path.exists(args.input_path): + print('Cannot find input path: {0}'.format(args.input_path)) + exit() + if not os.path.exists(args.output_path): + print('Cannot find output path: {0}'.format(args.output_path)) + exit() + if not os.path.exists(args.ckpt_path): + print('Cannot find ckpt path: {0}'.format(args.ckpt_path)) + exit() + + # define hyper-parameters + ref_size = 512 + + # define image to tensor transform + im_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ] + ) + + # create MODNet and load the pre-trained ckpt + modnet = MODNet(backbone_pretrained=False) + modnet = nn.DataParallel(modnet).cuda() + modnet.load_state_dict(torch.load(args.ckpt_path)) + modnet.eval() + + # inference images + im_names = os.listdir(args.input_path) + for im_name in im_names: + print('Process image: {0}'.format(im_name)) + + # read image + im = Image.open(os.path.join(args.input_path, im_name)) + + # unify image channels to 3 + im = np.asarray(im) + if len(im.shape) == 2: + im = im[:, :, None] + if im.shape[2] == 1: + im = np.repeat(im, 3, axis=2) + elif im.shape[2] == 4: + im = im[:, :, 0:3] + + # convert image to PyTorch tensor + im = Image.fromarray(im) + im = im_transform(im) + + # add mini-batch dim + im = im[None, :, :, :] + + # resize image for input + im_b, im_c, im_h, im_w = im.shape + if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: + if im_w >= im_h: + im_rh = ref_size + im_rw = int(im_w / im_h * ref_size) + elif im_w < im_h: + im_rw = ref_size + im_rh = int(im_h / im_w * ref_size) + else: + im_rh = im_h + im_rw = im_w + + im_rw = im_rw - im_rw % 32 + im_rh = im_rh - im_rh % 32 + im = F.interpolate(im, size=(im_rh, im_rw), mode='area') + + # inference + _, _, matte = modnet(im.cuda(), inference=False) + + # resize and save matte + matte = F.interpolate(matte, size=(im_h, im_w), mode='area') + matte = matte[0][0].data.cpu().numpy() + matte_name = im_name.split('.')[0] + '.png' + Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name)) diff --git a/demo/video_matting/README.md b/demo/video_matting/README.md new file mode 100644 index 0000000..f622201 --- /dev/null +++ b/demo/video_matting/README.md @@ -0,0 +1,50 @@ +## MODNet - WebCam-Based Portrait Video Matting Demo +This is a MODNet portrait video matting demo based on WebCam. It will call your local WebCam and display the matting results in real time. 
+
+### Requirements
+The basic requirements for this demo are:
+- Ubuntu System
+- WebCam
+- Nvidia GPU with CUDA
+- Python 3+
+
+**NOTE**: If your device does not satisfy the above conditions, please try our [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
+
+
+### Introduction
+We use ~400 unlabeled video clips (~50,000 frames in total) downloaded from the internet to perform SOC, adapting MODNet to the video domain. Nonetheless, due to insufficient labeled training data (~3k labeled foregrounds), our model may still make errors in portrait semantic estimation in challenging scenes. In addition, this demo does not yet support the OFD trick; it will be added soon.
+
+For a better experience, please:
+
+* make sure the portrait and the background are distinguishable, i.e., not too similar
+* run in soft and bright ambient lighting
+* do not get too close to or too far from the WebCam
+* do not move too fast
+
+### Run Demo
+We recommend creating a new conda virtual environment to run this demo, as follows:
+
+1. Clone the MODNet repository:
+    ```
+    git clone https://github.com/ZHKKKe/MODNet.git
+    cd MODNet
+    ```
+
+2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/`.
+
+
+3. Create a conda virtual environment named `modnet-webcam` and activate it:
+    ```
+    conda create -n modnet-webcam python=3.6
+    source activate modnet-webcam
+    ```
+
+4. Install the required Python dependencies (here we use PyTorch==1.0.0):
+    ```
+    pip install -r demo/video_matting/requirements.txt
+    ```
+
+5. Execute the main code:
+    ```
+    python -m demo.video_matting.webcam
+    ```
diff --git a/demo/video_matting/requirements.txt b/demo/video_matting/requirements.txt
new file mode 100644
index 0000000..eb0d67b
--- /dev/null
+++ b/demo/video_matting/requirements.txt
@@ -0,0 +1,5 @@
+numpy
+Pillow
+opencv-python
+torch == 1.0.0
+torchvision
\ No newline at end of file
diff --git a/demo/video_matting/webcam.py b/demo/video_matting/webcam.py
new file mode 100644
index 0000000..fc09aeb
--- /dev/null
+++ b/demo/video_matting/webcam.py
@@ -0,0 +1,56 @@
+import cv2
+import numpy as np
+from PIL import Image
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+
+from src.models.modnet import MODNet
+
+
+torch_transforms = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+    ]
+)
+
+print('Load pre-trained MODNet...')
+pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
+modnet = MODNet(backbone_pretrained=False)
+modnet = nn.DataParallel(modnet).cuda()
+modnet.load_state_dict(torch.load(pretrained_ckpt))
+modnet.eval()
+
+print('Init WebCam...')
+cap = cv2.VideoCapture(0)
+cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
+cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
+
+print('Start matting...')
+while(True):
+    _, frame_np = cap.read()
+    frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
+    frame_np = cv2.resize(frame_np, (910, 512), interpolation=cv2.INTER_AREA)
+    frame_np = frame_np[:, 120:792, :]
+    frame_np = cv2.flip(frame_np, 1)
+
+    frame_PIL = Image.fromarray(frame_np)
+    frame_tensor = torch_transforms(frame_PIL)
+    frame_tensor = frame_tensor[None, :, :, :].cuda()
+
+    with torch.no_grad():
+        _, _, matte_tensor = modnet(frame_tensor, inference=True)
+
+    matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
+    matte_np = 
matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0) + fg_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0) + view_np = np.uint8(np.concatenate((frame_np, fg_np), axis=1)) + view_np = cv2.cvtColor(view_np, cv2.COLOR_RGB2BGR) + + cv2.imshow('MODNet - WebCam [Press \'Q\' To Exit]', view_np) + if cv2.waitKey(1) & 0xFF == ord('q'): + break + +print('Exit...') diff --git a/doc/gif/image_matting_demo.gif b/doc/gif/image_matting_demo.gif new file mode 100644 index 0000000..a1a5e87 Binary files /dev/null and b/doc/gif/image_matting_demo.gif differ diff --git a/pretrained/README.md b/pretrained/README.md new file mode 100644 index 0000000..7eaa227 --- /dev/null +++ b/pretrained/README.md @@ -0,0 +1,2 @@ +## MODNet - Pre-Trained Models +This folder is used to save the official pre-trained models of MODNet. You can download them from this [link](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing). \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/backbones/__init__.py b/src/models/backbones/__init__.py new file mode 100644 index 0000000..4cbeee5 --- /dev/null +++ b/src/models/backbones/__init__.py @@ -0,0 +1,10 @@ +from .wrapper import * + + +#------------------------------------------------------------------------------ +# Replaceable Backbones +#------------------------------------------------------------------------------ + +SUPPORTED_BACKBONES = { + 'mobilenetv2': MobileNetV2Backbone, +} diff --git a/src/models/backbones/mobilenetv2.py b/src/models/backbones/mobilenetv2.py new file mode 100644 index 0000000..67cc138 --- /dev/null +++ b/src/models/backbones/mobilenetv2.py @@ -0,0 +1,185 @@ +""" This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch""" + +import math +import json +from functools import reduce + +import torch +from torch import nn + + +#------------------------------------------------------------------------------ +# Useful functions +#------------------------------------------------------------------------------ + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +#------------------------------------------------------------------------------ +# Class of Inverted Residual block +#------------------------------------------------------------------------------ + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expansion, dilation=1): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expansion) + self.use_res_connect = self.stride == 1 and inp == oup + + if expansion == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +#------------------------------------------------------------------------------ +# Class of MobileNetV2 +#------------------------------------------------------------------------------ + +class MobileNetV2(nn.Module): + def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000): + super(MobileNetV2, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1 , 16, 1, 1], + [expansion, 24, 2, 2], + [expansion, 32, 3, 2], + [expansion, 64, 4, 2], + [expansion, 96, 3, 1], + [expansion, 160, 3, 2], + [expansion, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel*alpha, 8) + self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel + self.features = [conv_bn(self.in_channels, input_channel, 2)] + + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = _make_divisible(int(c*alpha), 8) + for i in range(n): + if i == 0: + self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t)) + else: + self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t)) + input_channel = output_channel + + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + if self.num_classes is not None: + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, num_classes), + ) + + # Initialize weights + self._init_weights() + + def forward(self, x, feature_names=None): + # Stage1 + x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x) + # Stage2 + x = reduce(lambda x, n: 
self.features[n](x), list(range(2,4)), x) + # Stage3 + x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x) + # Stage4 + x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x) + # Stage5 + x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x) + + # Classification + if self.num_classes is not None: + x = x.mean(dim=(2,3)) + x = self.classifier(x) + + # Output + return x + + def _load_pretrained_model(self, pretrained_file): + pretrain_dict = torch.load(pretrained_file, map_location='cpu') + model_dict = {} + state_dict = self.state_dict() + print("[MobileNetV2] Loading pretrained model...") + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + else: + print(k, "is ignored") + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() diff --git a/src/models/backbones/wrapper.py b/src/models/backbones/wrapper.py new file mode 100644 index 0000000..36817ba --- /dev/null +++ b/src/models/backbones/wrapper.py @@ -0,0 +1,59 @@ +import os +from functools import reduce + +import torch +import torch.nn as nn + +from .mobilenetv2 import MobileNetV2 + + +class BaseBackbone(nn.Module): + """ Superclass of Replaceable Backbone Model for Semantic Estimation + """ + + def __init__(self, in_channels): + super(BaseBackbone, self).__init__() + self.in_channels = in_channels + + self.model = None + self.enc_channels = [] + + def forward(self, x): + raise NotImplementedError + + def load_pretrained_ckpt(self): + raise NotImplementedError + + +class MobileNetV2Backbone(BaseBackbone): + """ MobileNetV2 Backbone + """ + + def __init__(self, in_channels): + super(MobileNetV2Backbone, self).__init__(in_channels) + + self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None) + self.enc_channels = [16, 24, 32, 96, 1280] + + def forward(self, x): + x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) + enc2x = x + x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) + enc4x = x + x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) + enc8x = x + x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) + enc16x = x + x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) + enc32x = x + return [enc2x, enc4x, enc8x, enc16x, enc32x] + + def load_pretrained_ckpt(self): + # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch + ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' + if not os.path.exists(ckpt_path): + print('cannot find the pretrained mobilenetv2 backbone') + exit() + + ckpt = torch.load(ckpt_path) + self.model.load_state_dict(ckpt) diff --git a/src/models/modnet.py b/src/models/modnet.py new file mode 100644 index 0000000..16609b3 --- /dev/null +++ b/src/models/modnet.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbones import SUPPORTED_BACKBONES + + +#------------------------------------------------------------------------------ +# MODNet Basic Modules 
+#------------------------------------------------------------------------------
+
+class IBNorm(nn.Module):
+    """ Combine Instance Norm and Batch Norm into One Layer
+    """
+
+    def __init__(self, in_channels):
+        super(IBNorm, self).__init__()
+        in_channels = in_channels
+        self.bnorm_channels = int(in_channels / 2)
+        self.inorm_channels = in_channels - self.bnorm_channels
+
+        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
+        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
+
+    def forward(self, x):
+        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
+        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())  # remaining channels go to InstanceNorm
+
+        return torch.cat((bn_x, in_x), 1)
+
+
+class Conv2dIBNormRelu(nn.Module):
+    """ Convolution + IBNorm + ReLU
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride=1, padding=0, dilation=1, groups=1, bias=True,
+                 with_ibn=True, with_relu=True):
+        super(Conv2dIBNormRelu, self).__init__()
+
+        layers = [
+            nn.Conv2d(in_channels, out_channels, kernel_size,
+                      stride=stride, padding=padding, dilation=dilation,
+                      groups=groups, bias=bias)
+        ]
+
+        if with_ibn:
+            layers.append(IBNorm(out_channels))
+        if with_relu:
+            layers.append(nn.ReLU(inplace=True))
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class SEBlock(nn.Module):
+    """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
+    """
+
+    def __init__(self, in_channels, out_channels, reduction=1):
+        super(SEBlock, self).__init__()
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(in_channels, int(in_channels // reduction), bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(int(in_channels // reduction), out_channels, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        w = self.pool(x).view(b, c)
+        w = self.fc(w).view(b, c, 1, 1)
+
+        return x * w.expand_as(x)
+
+
+#------------------------------------------------------------------------------
+# MODNet Branches
+#------------------------------------------------------------------------------
+
+class LRBranch(nn.Module):
+    """ Low Resolution Branch of MODNet
+    """
+
+    def __init__(self, backbone):
+        super(LRBranch, self).__init__()
+
+        enc_channels = backbone.enc_channels
+
+        self.backbone = backbone
+        self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
+        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
+        self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
+        self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)
+
+    def forward(self, img, inference):
+        enc_features = self.backbone.forward(img)
+        enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
+
+        enc32x = self.se_block(enc32x)
+        lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
+        lr16x = self.conv_lr16x(lr16x)
+        lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
+        lr8x = self.conv_lr8x(lr8x)
+
+        pred_semantic = None
+        if not inference:
+            lr = self.conv_lr(lr8x)
+            pred_semantic = torch.sigmoid(lr)
+
+        return pred_semantic, lr8x, [enc2x, enc4x]
+
+
+class HRBranch(nn.Module):
+    """ High Resolution Branch of MODNet
+    """
+
+    def __init__(self, hr_channels, enc_channels):
+        super(HRBranch, self).__init__()
+
+        self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
+        
self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) + + self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) + self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) + + self.conv_hr4x = nn.Sequential( + Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr2x = nn.Sequential( + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, enc2x, enc4x, lr8x, inference): + img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) + img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) + + enc2x = self.tohr_enc2x(enc2x) + hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) + + enc4x = self.tohr_enc4x(enc4x) + hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) + + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) + + hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) + hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) + + pred_detail = None + if not inference: + hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False) + hr = self.conv_hr(torch.cat((hr, img), dim=1)) + pred_detail = torch.sigmoid(hr) + + return pred_detail, hr2x + + +class FusionBranch(nn.Module): + """ Fusion Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(FusionBranch, self).__init__() + self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) + + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) + self.conv_f = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), + Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, lr8x, hr2x): + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + lr4x = self.conv_lr4x(lr4x) + lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) + + f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) + f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) + f = self.conv_f(torch.cat((f, img), dim=1)) + pred_matte = torch.sigmoid(f) + + return pred_matte + + +#------------------------------------------------------------------------------ +# MODNet +#------------------------------------------------------------------------------ + +class MODNet(nn.Module): + """ Architecture of MODNet + """ + + def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): + super(MODNet, self).__init__() + + self.in_channels = in_channels + self.hr_channels = hr_channels + self.backbone_arch = 
backbone_arch + self.backbone_pretrained = backbone_pretrained + + self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) + + self.lr_branch = LRBranch(self.backbone) + self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) + self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + self._init_conv(m) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): + self._init_norm(m) + + if self.backbone_pretrained: + self.backbone.load_pretrained_ckpt() + + def forward(self, img, inference): + pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) + pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference) + pred_matte = self.f_branch(img, lr8x, hr2x) + + return pred_semantic, pred_detail, pred_matte + + def freeze_norm(self): + norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] + for m in self.modules(): + for n in norm_types: + if isinstance(m, n): + m.eval() + continue + + def _init_conv(self, conv): + nn.init.kaiming_uniform_( + conv.weight, a=0, mode='fan_in', nonlinearity='relu') + if conv.bias is not None: + nn.init.constant_(conv.bias, 0) + + def _init_norm(self, norm): + if norm.weight is not None: + nn.init.constant_(norm.weight, 1) + nn.init.constant_(norm.bias, 0)
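
To make the three outputs of `MODNet.forward` concrete, here is a minimal usage sketch (not part of the PR itself) that runs a randomly initialized model on a dummy batch from the repository root, assuming only the code added in this diff. In inference mode the semantic and detail heads are skipped and return `None`; otherwise all three predictions are returned, with the coarse semantic map at 1/16 of the input resolution due to the stride-2 `conv_lr` head.

```python
import torch
from src.models.modnet import MODNet

# backbone_pretrained=False avoids loading ./pretrained/mobilenetv2_human_seg.ckpt
modnet = MODNet(backbone_pretrained=False)
modnet.eval()

# dummy RGB batch; both spatial sides must be multiples of 32 (see the resizing in the demos)
img = torch.rand(1, 3, 512, 672)

with torch.no_grad():
    # inference=True: only the matte is computed
    pred_semantic, pred_detail, pred_matte = modnet(img, inference=True)
    print(pred_semantic, pred_detail, pred_matte.shape)  # None None torch.Size([1, 1, 512, 672])

    # inference=False: coarse semantics (1/16 resolution) and boundary detail are also returned
    pred_semantic, pred_detail, pred_matte = modnet(img, inference=False)
    print(pred_semantic.shape, pred_detail.shape)  # torch.Size([1, 1, 32, 42]) torch.Size([1, 1, 512, 672])
```

Input sides that are not multiples of 32 can break the skip connections, since the upsampled low-resolution features would no longer match the high-resolution features they are concatenated with; this is why both demos resize or crop their inputs accordingly.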