diff --git a/README.md b/README.md
index 569cec4..7c9cab4 100644
--- a/README.md
+++ b/README.md
@@ -20,13 +20,15 @@ WebCam Video Demo [Offline][
-
-## Image Matting Demo
+### Image Matting
 
 We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
 It allows you to upload portrait images and predict/visualize/download the alpha mattes.
 
-You can also use this [WebGUI](https://gradio.app/g/modnet) (hosted on [Gradio](https://github.com/gradio-app/gradio)) for portrait image matting from your browser without any code!
-
+### Community
+Here we share some cool applications of MODNet built by the community.
+
+- **WebGUI for Image Matting**
+You can try [this WebGUI](https://gradio.app/g/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait matting from your browser without any code!
+
+- **Colab Demo of Bokeh (Blur Background)**
+You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background based on MODNet!
 
-## TO DO
-- Release training code (scheduled in **Jan. 2021**)
-- Release PPM-100 validation benchmark (scheduled in **Feb. 2021**)
+## Code
+We provide the [code](src/trainer.py) for the MODNet training iterations, including:
+- **Supervised Training**: Train MODNet on a labeled matting dataset
+- **SOC Adaptation**: Adapt a trained MODNet to an unlabeled dataset
+
+In the comments of each function, we provide an example of how to call it.
+
+
+## TODO
+- Release the code of One-Frame Delay (OFD)
+- Release the PPM-100 validation benchmark (scheduled in **Feb 2021**)
+
+**NOTE**: PPM-100 is a **validation set**. Our training set will not be published.
 
 ## License
@@ -56,8 +73,12 @@ This project (code, pre-trained models, demos, *etc.*) is released under the [Cr
 
 ## Acknowledgement
-We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support to this project.
-We thank the [Gradio](https://github.com/gradio-app/gradio) team for their contributions to building the demos.
+- We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support of this project.
+- We thank [the Gradio team](https://github.com/gradio-app/gradio) and [@eyaler](https://github.com/eyaler) for their cool applications based on MODNet.
+
+
 ## Citation
 If this work helps your research, please consider to cite:
@@ -71,6 +92,7 @@ If this work helps your research, please consider to cite:
 }
 ```
 
+
 ## Contact
 This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)). If you have any questions, please feel free to contact `kezhanghan@outlook.com`.
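The supervised trainer added below consumes trimaps encoded as foreground=1, background=0, unknown=0.5. As a minimal sketch of how such a trimap might be derived from a ground-truth alpha matte, using the same morphology helpers that `src/trainer.py` imports (`matte_to_trimap` and its thresholds are illustrative, not part of this patch):

```python
import numpy as np
from scipy.ndimage import grey_dilation, grey_erosion

def matte_to_trimap(matte, side=25):
    """matte: float array in [0, 1]. Returns a trimap with fg=1, bg=0, unknown=0.5."""
    fg = (matte > 0.95).astype(np.float32)           # confident foreground mask
    dilated = grey_dilation(fg, size=(side, side))   # grow the foreground mask
    eroded = grey_erosion(fg, size=(side, side))     # shrink the foreground mask
    trimap = np.full_like(matte, 0.5)                # default: unknown
    trimap[eroded > 0.5] = 1.0                       # well inside the mask: foreground
    trimap[dilated < 0.5] = 0.0                      # well outside the mask: background
    return trimap
```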
diff --git a/src/trainer.py b/src/trainer.py
new file mode 100644
index 0000000..bff6c82
--- /dev/null
+++ b/src/trainer.py
@@ -0,0 +1,295 @@
+import math
+
+import numpy as np
+import scipy
+from scipy.ndimage import grey_dilation, grey_erosion
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+__all__ = [
+    'supervised_training_iter',
+    'soc_adaptation_iter',
+]
+
+
+# ----------------------------------------------------------------------------------
+# Tool Classes/Functions
+# ----------------------------------------------------------------------------------
+
+class GaussianBlurLayer(nn.Module):
+    """ Apply Gaussian blur to a 4D tensor
+    This layer takes a 4D tensor of {N, C, H, W} as input.
+    The Gaussian blur is applied to each of the C channels separately.
+    """
+
+    def __init__(self, channels, kernel_size):
+        """
+        Arguments:
+            channels (int): number of channels of the input tensor
+            kernel_size (int): size of the blur kernel (must be odd)
+        """
+
+        super(GaussianBlurLayer, self).__init__()
+        self.channels = channels
+        self.kernel_size = kernel_size
+        assert self.kernel_size % 2 != 0
+
+        # a depth-wise convolution with reflection padding keeps the output size unchanged
+        self.op = nn.Sequential(
+            nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
+            nn.Conv2d(channels, channels, self.kernel_size,
+                      stride=1, padding=0, bias=False, groups=channels)
+        )
+
+        self._init_kernel()
+
+    def forward(self, x):
+        """
+        Arguments:
+            x (torch.Tensor): input 4D tensor
+
+        Returns:
+            torch.Tensor: blurred version of the input
+        """
+
+        if not len(list(x.shape)) == 4:
+            print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
+            exit()
+        elif not x.shape[1] == self.channels:
+            print('In \'GaussianBlurLayer\', the required channel ({0}) is '
+                  'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
+            exit()
+
+        return self.op(x)
+
+    def _init_kernel(self):
+        # sigma follows the OpenCV rule of thumb for the given kernel size
+        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
+
+        # blur a centered unit impulse to obtain the Gaussian kernel
+        n = np.zeros((self.kernel_size, self.kernel_size))
+        i = math.floor(self.kernel_size / 2)
+        n[i, i] = 1
+        kernel = scipy.ndimage.gaussian_filter(n, sigma)
+
+        for name, param in self.named_parameters():
+            param.data.copy_(torch.from_numpy(kernel))
+
+# ----------------------------------------------------------------------------------
+
+
+# ----------------------------------------------------------------------------------
+# MODNet Training Functions
+# ----------------------------------------------------------------------------------
+
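+# a single-channel Gaussian blur shared by the two training functions below; it
+# softens the 1/16-downsampled mattes before the semantic losses are computed
+# (NOTE: created with `.cuda()` at import time, so a CUDA device is required)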
+blurer = GaussianBlurLayer(1, 3).cuda()
+
+
+def supervised_training_iter(
+        modnet, optimizer, image, trimap, gt_matte,
+        semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
+    """ Supervised training iteration of MODNet
+    This function trains MODNet for one iteration on a labeled dataset.
+
+    Arguments:
+        modnet (torch.nn.Module): instance of MODNet
+        optimizer (torch.optim.Optimizer): optimizer for supervised training
+        image (torch.autograd.Variable): input RGB image
+        trimap (torch.autograd.Variable): trimap used to calculate the losses
+                                          NOTE: foreground=1, background=0, unknown=0.5
+        gt_matte (torch.autograd.Variable): ground truth alpha matte
+        semantic_scale (float): scale of the semantic loss
+                                NOTE: please adjust according to your dataset
+        detail_scale (float): scale of the detail loss
+                              NOTE: please adjust according to your dataset
+        matte_scale (float): scale of the matte loss
+                             NOTE: please adjust according to your dataset
+
+    Returns:
+        semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
+        detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
+        matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
+
+    Example:
+        import torch
+        from src.models.modnet import MODNet
+        from src.trainer import supervised_training_iter
+
+        bs = 16         # batch size
+        lr = 0.01       # learning rate
+        epochs = 40     # total epochs
+
+        modnet = torch.nn.DataParallel(MODNet()).cuda()
+        optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
+        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
+
+        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function
+
+        for epoch in range(0, epochs):
+            for idx, (image, trimap, gt_matte) in enumerate(dataloader):
+                semantic_loss, detail_loss, matte_loss = \
+                    supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
+            lr_scheduler.step()
+    """
+
+    global blurer
+
+    # set the model to train mode and clear the optimizer
+    modnet.train()
+    optimizer.zero_grad()
+
+    # forward the model
+    pred_semantic, pred_detail, pred_matte = modnet(image, False)
+
+    # calculate the boundary mask from the trimap
+    # (True in the known foreground/background, False in the unknown transition region)
+    boundaries = (trimap < 0.5) + (trimap > 0.5)
+
+    # calculate the semantic loss on the blurred, 1/16-downsampled matte
+    gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear')
+    gt_semantic = blurer(gt_semantic)
+    semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
+    semantic_loss = semantic_scale * semantic_loss
+
+    # calculate the detail loss
+    # (copying the trimap into the known regions restricts the L1 loss to the transition region)
+    pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
+    gt_detail = torch.where(boundaries, trimap, gt_matte)
+    detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
+    detail_loss = detail_scale * detail_loss
+
+    # calculate the matte loss (L1 + compositional losses, with extra weight on the transition region)
+    pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
+    matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
+    matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
+        + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
+    matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
+    matte_loss = matte_scale * matte_loss
+
+    # calculate the final loss, backward the loss, and update the model
+    loss = semantic_loss + detail_loss + matte_loss
+    loss.backward()
+    optimizer.step()
+
+    # return the individual losses for logging
+    return semantic_loss, detail_loss, matte_loss
+
+
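+# NOTE: SOC adaptation below needs no labels. The converged model's own predictions
+# serve as pseudo ground truths (the semantic and matte branches supervise each
+# other), while a frozen backup copy of the model anchors the detail and matte
+# outputs inside the estimated boundary region.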
+def soc_adaptation_iter(
+        modnet, backup_modnet, optimizer, image,
+        soc_semantic_scale=100.0, soc_detail_scale=1.0):
+    """ Self-supervised sub-objective consistency (SOC) adaptation iteration of MODNet
+    This function fine-tunes MODNet for one iteration on an unlabeled dataset.
+    Note that SOC can only fine-tune a converged MODNet, i.e., a MODNet that has
+    already been trained on a labeled dataset.
+
+    Arguments:
+        modnet (torch.nn.Module): instance of MODNet
+        backup_modnet (torch.nn.Module): backup of the trained MODNet
+        optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC
+        image (torch.autograd.Variable): input RGB image
+        soc_semantic_scale (float): scale of the SOC semantic loss
+                                    NOTE: please adjust according to your dataset
+        soc_detail_scale (float): scale of the SOC detail loss
+                                  NOTE: please adjust according to your dataset
+
+    Returns:
+        soc_semantic_loss (torch.Tensor): loss of the semantic SOC
+        soc_detail_loss (torch.Tensor): loss of the detail SOC
+
+    Example:
+        import copy
+        import torch
+        from src.models.modnet import MODNet
+        from src.trainer import soc_adaptation_iter
+
+        bs = 1          # batch size
+        lr = 0.00001    # learning rate
+        epochs = 10     # total epochs
+
+        modnet = torch.nn.DataParallel(MODNet()).cuda()
+        modnet = LOAD_TRAINED_CKPT()    # NOTE: please finish this function
+
+        optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
+        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function
+
+        for epoch in range(0, epochs):
+            backup_modnet = copy.deepcopy(modnet)
+            for idx, (image) in enumerate(dataloader):
+                soc_semantic_loss, soc_detail_loss = \
+                    soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
+    """
+
+    global blurer
+
+    # set the backup model to eval mode
+    backup_modnet.eval()
+
+    # set the main model to train mode and freeze its norm layers
+    modnet.train()
+    modnet.module.freeze_norm()
+
+    # clear the optimizer
+    optimizer.zero_grad()
+
+    # forward the main model
+    pred_semantic, pred_detail, pred_matte = modnet(image, False)
+
+    # forward the backup model
+    with torch.no_grad():
+        _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)
+
+    # calculate the boundary mask from `pred_matte` and `pred_semantic`
+    pred_matte_fg = (pred_matte.detach() > 0.1).float()
+    pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
+    pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
+    pred_fg = pred_matte_fg * pred_semantic_fg
+
+    n, c, h, w = pred_matte.shape
+    np_pred_fg = pred_fg.data.cpu().numpy()
+    np_boundaries = np.zeros([n, c, h, w])
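+    # the boundary band is a morphological gradient of the predicted foreground:
+    # dilation grows the mask, erosion shrinks it, and the pixels where the two
+    # disagree form a band (~5% of the mean image side) around the contour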
+    for sdx in range(0, n):
+        sample_np_boundaries = np_boundaries[sdx, 0, ...]
+        sample_np_pred_fg = np_pred_fg[sdx, 0, ...]
+
+        side = int((h + w) / 2 * 0.05)
+        dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
+        eroded = grey_erosion(sample_np_pred_fg, size=(side, side))
+
+        sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
+        np_boundaries[sdx, 0, ...] = sample_np_boundaries
+
+    boundaries = torch.tensor(np_boundaries).float().cuda()
+
+    # sub-objectives consistency between `pred_semantic` and `pred_matte`
+    # generate pseudo ground truth for `pred_semantic`
+    downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1/16, mode='bilinear'))
+    pseudo_gt_semantic = downsampled_pred_matte.detach()
+    pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()
+
+    # generate pseudo ground truth for `pred_matte`
+    pseudo_gt_matte = pred_semantic.detach()
+    pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()
+
+    # calculate the SOC semantic loss
+    soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte, pseudo_gt_matte)
+    soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)
+
+    # NOTE: using the formulas in our paper to calculate the following losses gives similar results
+    # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
+    backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail, reduction='none')
+    backup_detail_loss = torch.sum(backup_detail_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
+    backup_detail_loss = torch.mean(backup_detail_loss)
+
+    # sub-objectives consistency between `pred_matte` and `pred_backup_matte` (on boundaries only)
+    backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte, reduction='none')
+    backup_matte_loss = torch.sum(backup_matte_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
+    backup_matte_loss = torch.mean(backup_matte_loss)
+
+    soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)
+
+    # calculate the final loss, backward the loss, and update the model
+    loss = soc_semantic_loss + soc_detail_loss
+    loss.backward()
+    optimizer.step()
+
+    return soc_semantic_loss, soc_detail_loss
+
+# ----------------------------------------------------------------------------------
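For reference, the two iterations above are meant to be run in sequence. A minimal sketch assembled from the two docstring examples (hyperparameters copied from those examples; `CREATE_YOUR_DATALOADER` remains the placeholder the docstrings ask you to implement):

```python
import copy
import torch
from src.models.modnet import MODNet
from src.trainer import supervised_training_iter, soc_adaptation_iter

# Stage 1: supervised training on a labeled matting dataset
modnet = torch.nn.DataParallel(MODNet()).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=0.01, momentum=0.9)
dataloader = CREATE_YOUR_DATALOADER(16)  # placeholder: yields (image, trimap, gt_matte)
for epoch in range(40):
    for image, trimap, gt_matte in dataloader:
        supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)

# Stage 2: SOC adaptation of the converged model on an unlabeled dataset
optimizer = torch.optim.Adam(modnet.parameters(), lr=0.00001, betas=(0.9, 0.99))
dataloader = CREATE_YOUR_DATALOADER(1)   # placeholder: yields image only
for epoch in range(10):
    backup_modnet = copy.deepcopy(modnet)  # refresh the frozen anchor each epoch
    for image in dataloader:
        soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
```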