diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d64f109
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,96 @@
+# Temporary directories and files
+*.ckpt
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# IPython Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# dotenv
+.env
+
+# virtualenv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+
+# Rope project settings
+.ropeproject
+
+
+# Project files
+.vscode
\ No newline at end of file
diff --git a/README.md b/README.md
index b011385..2eb8f98 100644
--- a/README.md
+++ b/README.md
@@ -4,27 +4,44 @@
  Arxiv Preprint |
-  Supplementary Video
+  Supplementary Video |
+  Video Matting Demo :fire: |
+  Image Matting Demo :fire:
+
+
+## TO DO
+- Release training code (scheduled in **Jan. 2021**)
+- Release PPM-100 validation benchmark (scheduled in **Feb. 2021**)
+
+## Acknowledgement
+We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support of this project.
+
+
## Citation
 If this work helps your research, please consider citing:
@@ -37,3 +54,7 @@ If this work helps your research, please consider citing:
year = {2020},
}
```
+
+## Contact
+This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
+If you have any questions, please feel free to contact `kezhanghan@outlook.com`.
diff --git a/demo/image_matting/README.md b/demo/image_matting/README.md
new file mode 100644
index 0000000..6642651
--- /dev/null
+++ b/demo/image_matting/README.md
@@ -0,0 +1,2 @@
+## MODNet - Portrait Image Matting Demo
+Please try the MODNet portrait image matting demo through our [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).
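+
+For local runs, this folder also provides `inference.py` (a CUDA-capable GPU is required, since the script calls `.cuda()`). A minimal invocation sketch, assuming the repository root as the working directory, existing input/output folders, and a downloaded checkpoint (the paths below are placeholders):
+```
+python -m demo.image_matting.inference \
+    --input-path PATH/TO/INPUT/FOLDER \
+    --output-path PATH/TO/OUTPUT/FOLDER \
+    --ckpt-path ./pretrained/MODNET_CKPT_NAME.ckpt
+```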
diff --git a/demo/image_matting/inference.py b/demo/image_matting/inference.py
new file mode 100644
index 0000000..2cdaff7
--- /dev/null
+++ b/demo/image_matting/inference.py
@@ -0,0 +1,99 @@
+import os
+import argparse
+import numpy as np
+from PIL import Image
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+
+from src.models.modnet import MODNet
+
+
+if __name__ == '__main__':
+ # define cmd arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input-path', type=str, help='path of input images')
+ parser.add_argument('--output-path', type=str, help='path of output images')
+ parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet')
+ args = parser.parse_args()
+
+ # check input arguments
+ if not os.path.exists(args.input_path):
+ print('Cannot find input path: {0}'.format(args.input_path))
+ exit()
+ if not os.path.exists(args.output_path):
+ print('Cannot find output path: {0}'.format(args.output_path))
+ exit()
+ if not os.path.exists(args.ckpt_path):
+ print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
+ exit()
+
+ # define hyper-parameters
+ ref_size = 512
+
+ # define image to tensor transform
+ im_transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ ]
+ )
+
+ # create MODNet and load the pre-trained ckpt
+ modnet = MODNet(backbone_pretrained=False)
+ modnet = nn.DataParallel(modnet).cuda()
+ modnet.load_state_dict(torch.load(args.ckpt_path))
+ modnet.eval()
+
+ # inference images
+ im_names = os.listdir(args.input_path)
+ for im_name in im_names:
+ print('Process image: {0}'.format(im_name))
+
+ # read image
+ im = Image.open(os.path.join(args.input_path, im_name))
+
+ # unify image channels to 3
+ im = np.asarray(im)
+ if len(im.shape) == 2:
+ im = im[:, :, None]
+ if im.shape[2] == 1:
+ im = np.repeat(im, 3, axis=2)
+ elif im.shape[2] == 4:
+ im = im[:, :, 0:3]
+
+ # convert image to PyTorch tensor
+ im = Image.fromarray(im)
+ im = im_transform(im)
+
+ # add mini-batch dim
+ im = im[None, :, :, :]
+
+ # resize image for input
+ im_b, im_c, im_h, im_w = im.shape
+ if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
+ if im_w >= im_h:
+ im_rh = ref_size
+ im_rw = int(im_w / im_h * ref_size)
+ elif im_w < im_h:
+ im_rw = ref_size
+ im_rh = int(im_h / im_w * ref_size)
+ else:
+ im_rh = im_h
+ im_rw = im_w
+
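+        # the encoder downsamples by 32x, so both sides must be multiples of 32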
+ im_rw = im_rw - im_rw % 32
+ im_rh = im_rh - im_rh % 32
+ im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
+
+        # inference (no_grad avoids building the autograd graph;
+        # inference=True skips the training-only semantic/detail outputs)
+        with torch.no_grad():
+            _, _, matte = modnet(im.cuda(), inference=True)
+
+ # resize and save matte
+ matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
+ matte = matte[0][0].data.cpu().numpy()
+        matte_name = os.path.splitext(im_name)[0] + '.png'
+ Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name))
diff --git a/demo/video_matting/README.md b/demo/video_matting/README.md
new file mode 100644
index 0000000..f622201
--- /dev/null
+++ b/demo/video_matting/README.md
@@ -0,0 +1,50 @@
+## MODNet - WebCam-Based Portrait Video Matting Demo
+This is a WebCam-based MODNet portrait video matting demo. It accesses your local WebCam and displays the matting results in real time.
+
+### Requirements
+The basic requirements for this demo are:
+- Ubuntu System
+- WebCam
+- Nvidia GPU with CUDA
+- Python 3+
+
+**NOTE**: If your device does not satisfy the above conditions, please try our [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
+
+
+### Introduction
+We use ~400 unlabeled video clips (~50,000 frames in total) downloaded from the internet and apply SOC (self-supervised sub-objectives consistency) to adapt MODNet to the video domain. Nonetheless, because the labeled training data is limited (~3k labeled foregrounds), our model may still make errors in portrait semantic estimation under challenging scenes. In addition, this demo does not yet support the OFD (one-frame delay) trick, which will be provided soon.
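+
+For reference, OFD replaces a flickering pixel whose two neighboring frames agree with each other but disagree with the current frame by the average of its neighbors. A minimal sketch of the idea (our own illustration with placeholder names and threshold, not the official implementation):
+```
+import numpy as np
+
+def ofd_smooth(prev_m, cur_m, next_m, eps=0.1):
+    # flicker: the neighbors agree, but the current frame deviates from both
+    flicker = (np.abs(prev_m - next_m) < eps) \
+            & (np.abs(cur_m - prev_m) > eps) \
+            & (np.abs(cur_m - next_m) > eps)
+    out = cur_m.copy()
+    out[flicker] = 0.5 * (prev_m[flicker] + next_m[flicker])
+    return out
+```
+Because frame `t` needs frame `t+1`, applying it delays the output by one frame (hence the name).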
+
+For a better experience, please:
+
+* make sure the portrait and background are distinguishable, i.e., are not similar
+* run in soft and bright ambient lighting
+* do not sit too close to or too far from the WebCam
+* do not move too fast
+
+### Run Demo
+We recommend creating a new conda virtual environment to run this demo, as follows:
+
+1. Clone the MODNet repository:
+ ```
+ git clone https://github.com/ZHKKKe/MODNet.git
+ cd MODNet
+ ```
+
+2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/`.
+
+
+3. Create a conda virtual environment named `modnet-webcam` and activate it:
+ ```
+ conda create -n modnet-webcam python=3.6
+ source activate modnet-webcam
+ ```
+
+4. Install the required Python dependencies (here we use PyTorch==1.0.0):
+ ```
+ pip install -r demo/video_matting/requirements.txt
+ ```
+
+5. Execute the main code:
+ ```
+ python -m demo.video_matting.webcam
+ ```
diff --git a/demo/video_matting/requirements.txt b/demo/video_matting/requirements.txt
new file mode 100644
index 0000000..eb0d67b
--- /dev/null
+++ b/demo/video_matting/requirements.txt
@@ -0,0 +1,5 @@
+numpy
+Pillow
+opencv-python
+torch==1.0.0
+torchvision
\ No newline at end of file
diff --git a/demo/video_matting/webcam.py b/demo/video_matting/webcam.py
new file mode 100644
index 0000000..fc09aeb
--- /dev/null
+++ b/demo/video_matting/webcam.py
@@ -0,0 +1,56 @@
+import cv2
+import numpy as np
+from PIL import Image
+
+import torch
+import torch.nn as nn
+import torchvision.transforms as transforms
+
+from src.models.modnet import MODNet
+
+
+torch_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+)
+
+print('Load pre-trained MODNet...')
+pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
+modnet = MODNet(backbone_pretrained=False)
+modnet = nn.DataParallel(modnet).cuda()
+modnet.load_state_dict(torch.load(pretrained_ckpt))
+modnet.eval()
+
+print('Init WebCam...')
+cap = cv2.VideoCapture(0)
+cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
+cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
+
+print('Start matting...')
+while True:
+    ret, frame_np = cap.read()
+    if not ret:
+        break
+    frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
+    # resize to height 512, then center-crop the width to 672 (a multiple of 32)
+    frame_np = cv2.resize(frame_np, (910, 512), interpolation=cv2.INTER_AREA)
+    frame_np = frame_np[:, 120:792, :]
+ frame_np = cv2.flip(frame_np, 1)
+
+ frame_PIL = Image.fromarray(frame_np)
+ frame_tensor = torch_transforms(frame_PIL)
+ frame_tensor = frame_tensor[None, :, :, :].cuda()
+
+ with torch.no_grad():
+ _, _, matte_tensor = modnet(frame_tensor, inference=True)
+
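+    # composite the frame over a white background using the predicted matte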
+ matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
+ matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0)
+ fg_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0)
+ view_np = np.uint8(np.concatenate((frame_np, fg_np), axis=1))
+ view_np = cv2.cvtColor(view_np, cv2.COLOR_RGB2BGR)
+
+ cv2.imshow('MODNet - WebCam [Press \'Q\' To Exit]', view_np)
+ if cv2.waitKey(1) & 0xFF == ord('q'):
+ break
+
+cap.release()
+cv2.destroyAllWindows()
+print('Exit...')
diff --git a/doc/gif/image_matting_demo.gif b/doc/gif/image_matting_demo.gif
new file mode 100644
index 0000000..a1a5e87
Binary files /dev/null and b/doc/gif/image_matting_demo.gif differ
diff --git a/pretrained/README.md b/pretrained/README.md
new file mode 100644
index 0000000..7eaa227
--- /dev/null
+++ b/pretrained/README.md
@@ -0,0 +1,2 @@
+## MODNet - Pre-Trained Models
+This folder stores the official pre-trained models of MODNet. You can download them from this [link](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing).
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/models/backbones/__init__.py b/src/models/backbones/__init__.py
new file mode 100644
index 0000000..4cbeee5
--- /dev/null
+++ b/src/models/backbones/__init__.py
@@ -0,0 +1,10 @@
+from .wrapper import *
+
+
+#------------------------------------------------------------------------------
+# Replaceable Backbones
+#------------------------------------------------------------------------------
+
+SUPPORTED_BACKBONES = {
+ 'mobilenetv2': MobileNetV2Backbone,
+}
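+
+# A new backbone can be plugged in by subclassing BaseBackbone (see wrapper.py),
+# returning the five encoder feature maps, and registering it here, e.g.
+# (hypothetical): SUPPORTED_BACKBONES['resnet50'] = ResNet50Backbone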
diff --git a/src/models/backbones/mobilenetv2.py b/src/models/backbones/mobilenetv2.py
new file mode 100644
index 0000000..67cc138
--- /dev/null
+++ b/src/models/backbones/mobilenetv2.py
@@ -0,0 +1,185 @@
+""" This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch"""
+
+import math
+from functools import reduce
+
+import torch
+from torch import nn
+
+
+#------------------------------------------------------------------------------
+# Useful functions
+#------------------------------------------------------------------------------
+
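+# round channel counts to the nearest multiple of `divisor` (8 below), as in the
+# official MobileNet implementations; never rounds down by more than 10% of v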
+def _make_divisible(v, divisor, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+def conv_bn(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+def conv_1x1_bn(inp, oup):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+#------------------------------------------------------------------------------
+# Class of Inverted Residual block
+#------------------------------------------------------------------------------
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expansion, dilation=1):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = round(inp * expansion)
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ if expansion == 1:
+ self.conv = nn.Sequential(
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+ else:
+ self.conv = nn.Sequential(
+ # pw
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
+ nn.BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+#------------------------------------------------------------------------------
+# Class of MobileNetV2
+#------------------------------------------------------------------------------
+
+class MobileNetV2(nn.Module):
+ def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
+ super(MobileNetV2, self).__init__()
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ input_channel = 32
+ last_channel = 1280
+        inverted_residual_setting = [
+ # t, c, n, s
+ [1 , 16, 1, 1],
+ [expansion, 24, 2, 2],
+ [expansion, 32, 3, 2],
+ [expansion, 64, 4, 2],
+ [expansion, 96, 3, 1],
+ [expansion, 160, 3, 2],
+ [expansion, 320, 1, 1],
+ ]
+
+ # building first layer
+ input_channel = _make_divisible(input_channel*alpha, 8)
+ self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel
+ self.features = [conv_bn(self.in_channels, input_channel, 2)]
+
+ # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+ output_channel = _make_divisible(int(c*alpha), 8)
+ for i in range(n):
+ if i == 0:
+ self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
+ else:
+ self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
+ input_channel = output_channel
+
+ # building last several layers
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
+
+ # make it nn.Sequential
+ self.features = nn.Sequential(*self.features)
+
+ # building classifier
+ if self.num_classes is not None:
+ self.classifier = nn.Sequential(
+ nn.Dropout(0.2),
+ nn.Linear(self.last_channel, num_classes),
+ )
+
+ # Initialize weights
+ self._init_weights()
+
+    def forward(self, x, feature_names=None):
+        # Stage1: 1/2 resolution
+        x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x)
+        # Stage2: 1/4 resolution
+        x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x)
+        # Stage3: 1/8 resolution
+        x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x)
+        # Stage4: 1/16 resolution
+        x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x)
+        # Stage5: 1/32 resolution
+        x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x)
+
+ # Classification
+ if self.num_classes is not None:
+ x = x.mean(dim=(2,3))
+ x = self.classifier(x)
+
+ # Output
+ return x
+
+ def _load_pretrained_model(self, pretrained_file):
+ pretrain_dict = torch.load(pretrained_file, map_location='cpu')
+ model_dict = {}
+ state_dict = self.state_dict()
+ print("[MobileNetV2] Loading pretrained model...")
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ else:
+ print(k, "is ignored")
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ n = m.weight.size(1)
+ m.weight.data.normal_(0, 0.01)
+ m.bias.data.zero_()
diff --git a/src/models/backbones/wrapper.py b/src/models/backbones/wrapper.py
new file mode 100644
index 0000000..36817ba
--- /dev/null
+++ b/src/models/backbones/wrapper.py
@@ -0,0 +1,59 @@
+import os
+from functools import reduce
+
+import torch
+import torch.nn as nn
+
+from .mobilenetv2 import MobileNetV2
+
+
+class BaseBackbone(nn.Module):
+ """ Superclass of Replaceable Backbone Model for Semantic Estimation
+ """
+
+ def __init__(self, in_channels):
+ super(BaseBackbone, self).__init__()
+ self.in_channels = in_channels
+
+ self.model = None
+ self.enc_channels = []
+
+ def forward(self, x):
+ raise NotImplementedError
+
+ def load_pretrained_ckpt(self):
+ raise NotImplementedError
+
+
+class MobileNetV2Backbone(BaseBackbone):
+ """ MobileNetV2 Backbone
+ """
+
+ def __init__(self, in_channels):
+ super(MobileNetV2Backbone, self).__init__(in_channels)
+
+ self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
+ self.enc_channels = [16, 24, 32, 96, 1280]
+
+ def forward(self, x):
+ x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
+ enc2x = x
+ x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
+ enc4x = x
+ x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
+ enc8x = x
+ x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
+ enc16x = x
+ x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
+ enc32x = x
+ return [enc2x, enc4x, enc8x, enc16x, enc32x]
+
+ def load_pretrained_ckpt(self):
+ # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
+ ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
+ if not os.path.exists(ckpt_path):
+ print('cannot find the pretrained mobilenetv2 backbone')
+ exit()
+
+ ckpt = torch.load(ckpt_path)
+ self.model.load_state_dict(ckpt)
diff --git a/src/models/modnet.py b/src/models/modnet.py
new file mode 100644
index 0000000..16609b3
--- /dev/null
+++ b/src/models/modnet.py
@@ -0,0 +1,255 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .backbones import SUPPORTED_BACKBONES
+
+
+#------------------------------------------------------------------------------
+# MODNet Basic Modules
+#------------------------------------------------------------------------------
+
+class IBNorm(nn.Module):
+ """ Combine Instance Norm and Batch Norm into One Layer
+ """
+
+ def __init__(self, in_channels):
+ super(IBNorm, self).__init__()
+        # split the channels: first half -> BatchNorm, second half -> InstanceNorm
+        self.bnorm_channels = in_channels // 2
+        self.inorm_channels = in_channels - self.bnorm_channels
+
+ self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
+ self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
+
+ def forward(self, x):
+        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
+        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
+
+ return torch.cat((bn_x, in_x), 1)
+
+
+class Conv2dIBNormRelu(nn.Module):
+    """ Convolution + IBNorm + ReLU
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride=1, padding=0, dilation=1, groups=1, bias=True,
+ with_ibn=True, with_relu=True):
+ super(Conv2dIBNormRelu, self).__init__()
+
+ layers = [
+ nn.Conv2d(in_channels, out_channels, kernel_size,
+ stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias)
+ ]
+
+ if with_ibn:
+ layers.append(IBNorm(out_channels))
+ if with_relu:
+ layers.append(nn.ReLU(inplace=True))
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class SEBlock(nn.Module):
+ """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
+ """
+
+ def __init__(self, in_channels, out_channels, reduction=1):
+ super(SEBlock, self).__init__()
+ self.pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(in_channels, int(in_channels // reduction), bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(int(in_channels // reduction), out_channels, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x):
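+        # squeeze: global-average-pool to (B, C); excite: FC bottleneck + sigmoid,
+        # yielding per-channel weights in [0, 1] that rescale the input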
+ b, c, _, _ = x.size()
+ w = self.pool(x).view(b, c)
+ w = self.fc(w).view(b, c, 1, 1)
+
+ return x * w.expand_as(x)
+
+
+#------------------------------------------------------------------------------
+# MODNet Branches
+#------------------------------------------------------------------------------
+
+class LRBranch(nn.Module):
+ """ Low Resolution Branch of MODNet
+ """
+
+ def __init__(self, backbone):
+ super(LRBranch, self).__init__()
+
+ enc_channels = backbone.enc_channels
+
+ self.backbone = backbone
+ self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
+ self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
+ self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
+ self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)
+
+ def forward(self, img, inference):
+ enc_features = self.backbone.forward(img)
+ enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
+
+ enc32x = self.se_block(enc32x)
+ lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
+ lr16x = self.conv_lr16x(lr16x)
+ lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
+ lr8x = self.conv_lr8x(lr8x)
+
+ pred_semantic = None
+ if not inference:
+ lr = self.conv_lr(lr8x)
+ pred_semantic = torch.sigmoid(lr)
+
+ return pred_semantic, lr8x, [enc2x, enc4x]
+
+
+class HRBranch(nn.Module):
+ """ High Resolution Branch of MODNet
+ """
+
+ def __init__(self, hr_channels, enc_channels):
+ super(HRBranch, self).__init__()
+
+ self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
+ self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
+
+ self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
+ self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
+
+ self.conv_hr4x = nn.Sequential(
+ Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
+ )
+
+ self.conv_hr2x = nn.Sequential(
+ Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
+ Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
+ Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
+ )
+
+ self.conv_hr = nn.Sequential(
+ Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
+ Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
+ )
+
+ def forward(self, img, enc2x, enc4x, lr8x, inference):
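+        # naming convention: a suffix of Nx marks a tensor at 1/N of the input resolution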
+ img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
+ img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)
+
+ enc2x = self.tohr_enc2x(enc2x)
+ hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
+
+ enc4x = self.tohr_enc4x(enc4x)
+ hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
+
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
+ hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
+
+ hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
+ hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
+
+ pred_detail = None
+ if not inference:
+ hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
+ hr = self.conv_hr(torch.cat((hr, img), dim=1))
+ pred_detail = torch.sigmoid(hr)
+
+ return pred_detail, hr2x
+
+
+class FusionBranch(nn.Module):
+ """ Fusion Branch of MODNet
+ """
+
+ def __init__(self, hr_channels, enc_channels):
+ super(FusionBranch, self).__init__()
+ self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
+
+ self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
+ self.conv_f = nn.Sequential(
+ Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
+ Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
+ )
+
+ def forward(self, img, lr8x, hr2x):
+ lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
+ lr4x = self.conv_lr4x(lr4x)
+ lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
+
+ f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
+ f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
+ f = self.conv_f(torch.cat((f, img), dim=1))
+ pred_matte = torch.sigmoid(f)
+
+ return pred_matte
+
+
+#------------------------------------------------------------------------------
+# MODNet
+#------------------------------------------------------------------------------
+
+class MODNet(nn.Module):
+ """ Architecture of MODNet
+ """
+
+ def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
+ super(MODNet, self).__init__()
+
+ self.in_channels = in_channels
+ self.hr_channels = hr_channels
+ self.backbone_arch = backbone_arch
+ self.backbone_pretrained = backbone_pretrained
+
+ self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
+
+ self.lr_branch = LRBranch(self.backbone)
+ self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
+ self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ self._init_conv(m)
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
+ self._init_norm(m)
+
+ if self.backbone_pretrained:
+ self.backbone.load_pretrained_ckpt()
+
+ def forward(self, img, inference):
+ pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
+ pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
+ pred_matte = self.f_branch(img, lr8x, hr2x)
+
+ return pred_semantic, pred_detail, pred_matte
+
+    def freeze_norm(self):
+        # keep all norm layers in eval mode so their running stats stay frozen
+        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d)
+        for m in self.modules():
+            if isinstance(m, norm_types):
+                m.eval()
+
+ def _init_conv(self, conv):
+ nn.init.kaiming_uniform_(
+ conv.weight, a=0, mode='fan_in', nonlinearity='relu')
+ if conv.bias is not None:
+ nn.init.constant_(conv.bias, 0)
+
+ def _init_norm(self, norm):
+ if norm.weight is not None:
+ nn.init.constant_(norm.weight, 1)
+ nn.init.constant_(norm.bias, 0)
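+
+
+if __name__ == '__main__':
+    # A minimal shape sanity check (an illustrative sketch, not part of the
+    # released demos); backbone_pretrained=False avoids needing the MobileNetV2
+    # checkpoint on disk, and the run stays on CPU.
+    model = MODNet(backbone_pretrained=False)
+    model.eval()
+    with torch.no_grad():
+        dummy = torch.rand(1, 3, 512, 512)
+        pred_semantic, pred_detail, pred_matte = model(dummy, inference=False)
+    # expected: semantic at 1/16 resolution, detail and matte at full resolution
+    print(pred_semantic.shape, pred_detail.shape, pred_matte.shape)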