mirror of https://github.com/ZHKKKe/MODNet.git

commit 9106a62265 (parent ffe384255c): upload the code of MODNet training iteration

README.md:
## News

- [Jan 28 2021] Release the [code](src/trainer.py) of MODNet training iteration.
- [Dec 25 2020] ***Merry Christmas!*** :christmas_tree: Release Custom Video Matting Demo [[Offline](demo/video_matting/custom)] for user videos.
- [Dec 15 2020] A cool [WebGUI](https://gradio.app/g/modnet) for image matting based on MODNet is built by the [Gradio](https://github.com/gradio-app/gradio) team!
- [Dec 10 2020] Release WebCam Video Matting Demo [[Offline](demo/video_matting/webcam)][[Colab](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing)] and Image Matting Demo [[Colab](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing)].
- [Nov 24 2020] Release [Arxiv Preprint](https://arxiv.org/pdf/2011.11961.pdf) and [Supplementary Video](https://youtu.be/PqJ3BRHX3Lc).
## Demos

### Video Matting

We provide two real-time portrait video matting demos based on WebCam. When using the demo, you can move the WebCam around at will.

If you have an Ubuntu system, we recommend you try the [offline demo](demo/video_matting/webcam) to get a higher *fps*. Otherwise, you can access the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).

We also provide an [offline demo](demo/video_matting/custom) that allows you to process custom videos.

<img src="doc/gif/video_matting_demo.gif" width='60%'>
### Image Matting

We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
It allows you to upload portrait images and predict/visualize/download the alpha mattes.

<img src="doc/gif/image_matting_demo.gif" width='40%'>
### Community

Here we share some cool applications of MODNet built by the community.

- **WebGUI for Image Matting**

  You can try [this WebGUI](https://gradio.app/g/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait matting from your browser without any code!

  <!-- <img src="https://i.ibb.co/9gLxFXF/modnet.gif" width='40%'> -->

- **Colab Demo of Bokeh (Blur Background)**

  You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background based on MODNet!
## Code

We provide the [code](src/trainer.py) of MODNet training iteration, including:

- **Supervised Training**: Train MODNet on a labeled matting dataset
- **SOC Adaptation**: Adapt a trained MODNet to an unlabeled dataset

In the function comments, we provide examples of how to call each function.
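As a rough, self-contained illustration of how SOC builds its targets (in `soc_adaptation_iter`, the model's own detached predictions are reused as pseudo ground truth after zeroing low-confidence values), here is a NumPy sketch; the array values below are invented for illustration, not real model output:

```python
import numpy as np

# Invented low-resolution "semantic" prediction in [0, 1];
# stands in for a detached model output.
pred_semantic = np.array([0.0, 0.005, 0.02, 0.7, 1.0])

# SOC-style pseudo ground truth: keep the prediction where it is
# above the 0.01 confidence threshold, zero it elsewhere.
pseudo_gt = pred_semantic * (pred_semantic > 0.01).astype(pred_semantic.dtype)

print(pseudo_gt)  # entries at or below 0.01 are zeroed
```

In the trainer itself, the same thresholding is applied to torch tensors via `x * (x > 0.01).float()`.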
## TODO

- Release the code of One-Frame Delay (OFD)
- Release PPM-100 validation benchmark (scheduled in **Feb 2021**)

**NOTE**: PPM-100 is a **validation set**. Our training set will not be published.
## License

This project (code, pre-trained models, demos, *etc.*) is released under the [Cr…
## Acknowledgement

- We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support to this project.
- We thank [the Gradio team](https://github.com/gradio-app/gradio) and [@eyaler](https://github.com/eyaler) for their cool applications based on MODNet.
## Citation

If this work helps your research, please consider citing:
## Contact

This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
If you have any questions, please feel free to contact `kezhanghan@outlook.com`.
---

**src/trainer.py** (new file, 295 lines):

```python
import math
import scipy
import numpy as np
from scipy.ndimage import grey_dilation, grey_erosion

import torch
import torch.nn as nn
import torch.nn.functional as F


__all__ = [
    'supervised_training_iter',
    'soc_adaptation_iter',
]


# ----------------------------------------------------------------------------------
# Tool Classes/Functions
# ----------------------------------------------------------------------------------

class GaussianBlurLayer(nn.Module):
    """ Add Gaussian Blur to a 4D tensor
    This layer takes a 4D tensor of {N, C, H, W} as input.
    The Gaussian blur is performed on each of the given channels (C) separately.
    """

    def __init__(self, channels, kernel_size):
        """
        Arguments:
            channels (int): Channel for input tensor
            kernel_size (int): Size of the kernel used in blurring
        """

        super(GaussianBlurLayer, self).__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        assert self.kernel_size % 2 != 0

        self.op = nn.Sequential(
            nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
            nn.Conv2d(channels, channels, self.kernel_size,
                      stride=1, padding=0, bias=None, groups=channels)
        )

        self._init_kernel()

    def forward(self, x):
        """
        Arguments:
            x (torch.Tensor): input 4D tensor
        Returns:
            torch.Tensor: Blurred version of the input
        """

        if not len(list(x.shape)) == 4:
            print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
            exit()
        elif not x.shape[1] == self.channels:
            print('In \'GaussianBlurLayer\', the required channel ({0}) is '
                  'not the same as input ({1})\n'.format(self.channels, x.shape[1]))
            exit()

        return self.op(x)

    def _init_kernel(self):
        sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8

        n = np.zeros((self.kernel_size, self.kernel_size))
        i = math.floor(self.kernel_size / 2)
        n[i, i] = 1
        kernel = scipy.ndimage.gaussian_filter(n, sigma)

        for name, param in self.named_parameters():
            param.data.copy_(torch.from_numpy(kernel))

# ----------------------------------------------------------------------------------


# ----------------------------------------------------------------------------------
# MODNet Training Functions
# ----------------------------------------------------------------------------------

blurer = GaussianBlurLayer(1, 3).cuda()


def supervised_training_iter(
        modnet, optimizer, image, trimap, gt_matte,
        semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
    """ Supervised training iteration of MODNet
    This function trains MODNet for one iteration on a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        optimizer (torch.optim.Optimizer): optimizer for supervised training
        image (torch.autograd.Variable): input RGB image
        trimap (torch.autograd.Variable): trimap used to calculate the losses
                                          NOTE: foreground=1, background=0, unknown=0.5
        gt_matte (torch.autograd.Variable): ground truth alpha matte
        semantic_scale (float): scale of the semantic loss
                                NOTE: please adjust according to your dataset
        detail_scale (float): scale of the detail loss
                              NOTE: please adjust according to your dataset
        matte_scale (float): scale of the matte loss
                             NOTE: please adjust according to your dataset

    Returns:
        semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
        detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
        matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]

    Example:
        import torch
        from src.models.modnet import MODNet
        from src.trainer import supervised_training_iter

        bs = 16         # batch size
        lr = 0.01       # learning rate
        epochs = 40     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            for idx, (image, trimap, gt_matte) in enumerate(dataloader):
                semantic_loss, detail_loss, matte_loss = \
                    supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
            lr_scheduler.step()
    """

    global blurer

    # set the model to train mode and clear the optimizer
    modnet.train()
    optimizer.zero_grad()

    # forward the model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # calculate the boundary mask from the trimap
    boundaries = (trimap < 0.5) + (trimap > 0.5)

    # calculate the semantic loss
    gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear')
    gt_semantic = blurer(gt_semantic)
    semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
    semantic_loss = semantic_scale * semantic_loss

    # calculate the detail loss
    pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
    gt_detail = torch.where(boundaries, trimap, gt_matte)
    detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
    detail_loss = detail_scale * detail_loss

    # calculate the matte loss
    pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
    matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
    matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
        + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
    matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
    matte_loss = matte_scale * matte_loss

    # calculate the final loss, backward the loss, and update the model
    loss = semantic_loss + detail_loss + matte_loss
    loss.backward()
    optimizer.step()

    # for test
    return semantic_loss, detail_loss, matte_loss


def soc_adaptation_iter(
        modnet, backup_modnet, optimizer, image,
        soc_semantic_scale=100.0, soc_detail_scale=1.0):
    """ Self-supervised sub-objective consistency (SOC) adaptation iteration of MODNet
    This function fine-tunes MODNet for one iteration on an unlabeled dataset.
    Note that SOC can only fine-tune a converged MODNet, i.e., a MODNet that has been
    trained on a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        backup_modnet (torch.nn.Module): backup of the trained MODNet
        optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC
        image (torch.autograd.Variable): input RGB image
        soc_semantic_scale (float): scale of the SOC semantic loss
                                    NOTE: please adjust according to your dataset
        soc_detail_scale (float): scale of the SOC detail loss
                                  NOTE: please adjust according to your dataset

    Returns:
        soc_semantic_loss (torch.Tensor): loss of the semantic SOC
        soc_detail_loss (torch.Tensor): loss of the detail SOC

    Example:
        import copy
        import torch
        from src.models.modnet import MODNet
        from src.trainer import soc_adaptation_iter

        bs = 1          # batch size
        lr = 0.00001    # learning rate
        epochs = 10     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        modnet = LOAD_TRAINED_CKPT()    # NOTE: please finish this function

        optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            backup_modnet = copy.deepcopy(modnet)
            for idx, (image) in enumerate(dataloader):
                soc_semantic_loss, soc_detail_loss = \
                    soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
    """

    global blurer

    # set the backup model to eval mode
    backup_modnet.eval()

    # set the main model to train mode and freeze its norm layers
    modnet.train()
    modnet.module.freeze_norm()

    # clear the optimizer
    optimizer.zero_grad()

    # forward the main model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # forward the backup model
    with torch.no_grad():
        _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)

    # calculate the boundary mask from `pred_matte` and `pred_semantic`
    pred_matte_fg = (pred_matte.detach() > 0.1).float()
    pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
    pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
    pred_fg = pred_matte_fg * pred_semantic_fg

    n, c, h, w = pred_matte.shape
    np_pred_fg = pred_fg.data.cpu().numpy()
    np_boundaries = np.zeros([n, c, h, w])
    for sdx in range(0, n):
        sample_np_boundaries = np_boundaries[sdx, 0, ...]
        sample_np_pred_fg = np_pred_fg[sdx, 0, ...]

        side = int((h + w) / 2 * 0.05)
        dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
        eroded = grey_erosion(sample_np_pred_fg, size=(side, side))

        sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
        np_boundaries[sdx, 0, ...] = sample_np_boundaries

    boundaries = torch.tensor(np_boundaries).float().cuda()

    # sub-objectives consistency between `pred_semantic` and `pred_matte`
    # generate pseudo ground truth for `pred_semantic`
    downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1/16, mode='bilinear'))
    pseudo_gt_semantic = downsampled_pred_matte.detach()
    pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()

    # generate pseudo ground truth for `pred_matte`
    pseudo_gt_matte = pred_semantic.detach()
    pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()

    # calculate the SOC semantic loss
    soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte, pseudo_gt_matte)
    soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)

    # NOTE: using the formulas in our paper to calculate the following losses gives similar results
    # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
    backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail)
    backup_detail_loss = torch.sum(backup_detail_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_detail_loss = torch.mean(backup_detail_loss)

    # sub-objectives consistency between `pred_matte` and `pred_backup_matte` (on boundaries only)
    backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte)
    backup_matte_loss = torch.sum(backup_matte_loss, dim=(1, 2, 3)) / torch.sum(boundaries, dim=(1, 2, 3))
    backup_matte_loss = torch.mean(backup_matte_loss)

    soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)

    # calculate the final loss, backward the loss, and update the model
    loss = soc_semantic_loss + soc_detail_loss
    loss.backward()
    optimizer.step()

    return soc_semantic_loss, soc_detail_loss

# ----------------------------------------------------------------------------------
```
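The boundary mask that `soc_adaptation_iter` builds (where grey dilation and grey erosion of the predicted foreground disagree) can be exercised on its own with NumPy/SciPy. A sketch with a made-up 7x7 foreground map and an arbitrary structuring-element size (the trainer derives `side` as roughly 5% of `(h + w) / 2` instead):

```python
import numpy as np
from scipy.ndimage import grey_dilation, grey_erosion

# Made-up 7x7 binary foreground map (1 = foreground), standing in
# for one sample of `pred_fg` inside `soc_adaptation_iter`.
fg = np.zeros((7, 7))
fg[2:5, 2:5] = 1.0

side = 3  # arbitrary size for this toy example
dilated = grey_dilation(fg, size=(side, side))
eroded = grey_erosion(fg, size=(side, side))

# The boundary is wherever dilation and erosion disagree:
# here, a one-pixel-wide ring around the 3x3 foreground block.
boundary = np.zeros_like(fg)
boundary[np.where(dilated - eroded != 0)] = 1
```

In the training loop this mask is then converted to a CUDA tensor and restricts the detail and matte consistency losses to the transition region around the foreground.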