mirror of https://github.com/ZHKKKe/MODNet.git
pip install-able
parent
28165a451e
commit
72ad182ea2
|
|
@ -1,7 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Export ONNX model of MODNet with:
|
Export ONNX model of MODNet with:
|
||||||
input shape: (batch_size, 3, height, width)
|
input shape: (batch_size, 3, height, width)
|
||||||
output shape: (batch_size, 1, height, width)
|
output shape: (batch_size, 1, height, width)
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
--ckpt-path: path of the checkpoint that will be converted
|
--ckpt-path: path of the checkpoint that will be converted
|
||||||
|
|
@ -50,6 +50,6 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# export to onnx model
|
# export to onnx model
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
modnet.module, dummy_input, args.output_path, export_params = True,
|
modnet.module, dummy_input, args.output_path, export_params = True,
|
||||||
input_names = ['input'], output_names = ['output'],
|
input_names = ['input'], output_names = ['output'],
|
||||||
dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'}, 'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
|
dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'}, 'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) 2024 Synthesia Limited - All Rights Reserved
|
||||||
|
#
|
||||||
|
# Unauthorized copying of this file, via any medium is strictly prohibited.
|
||||||
|
# Proprietary and confidential.
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools", "setuptools-scm"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "modnet"
|
||||||
|
version = "0.0.1"
|
||||||
|
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"torch",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
|
@ -4,7 +4,7 @@ from functools import reduce
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .mobilenetv2 import MobileNetV2
|
from modnet.models.backbones.mobilenetv2 import MobileNetV2
|
||||||
|
|
||||||
|
|
||||||
class BaseBackbone(nn.Module):
|
class BaseBackbone(nn.Module):
|
||||||
|
|
@ -26,7 +26,7 @@ class BaseBackbone(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class MobileNetV2Backbone(BaseBackbone):
|
class MobileNetV2Backbone(BaseBackbone):
|
||||||
""" MobileNetV2 Backbone
|
""" MobileNetV2 Backbone
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels):
|
||||||
|
|
@ -72,11 +72,11 @@ class MobileNetV2Backbone(BaseBackbone):
|
||||||
return [enc2x, enc4x, enc8x, enc16x, enc32x]
|
return [enc2x, enc4x, enc8x, enc16x, enc32x]
|
||||||
|
|
||||||
def load_pretrained_ckpt(self):
|
def load_pretrained_ckpt(self):
|
||||||
# the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
|
# the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
|
||||||
ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
|
ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
|
||||||
if not os.path.exists(ckpt_path):
|
if not os.path.exists(ckpt_path):
|
||||||
print('cannot find the pretrained mobilenetv2 backbone')
|
print('cannot find the pretrained mobilenetv2 backbone')
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
ckpt = torch.load(ckpt_path)
|
ckpt = torch.load(ckpt_path)
|
||||||
self.model.load_state_dict(ckpt)
|
self.model.load_state_dict(ckpt)
|
||||||
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .backbones import SUPPORTED_BACKBONES
|
from modnet.models.backbones import SUPPORTED_BACKBONES
|
||||||
|
|
||||||
|
|
||||||
#------------------------------------------------------------------------------
|
#------------------------------------------------------------------------------
|
||||||
|
|
@ -21,7 +21,7 @@ class IBNorm(nn.Module):
|
||||||
|
|
||||||
self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
|
self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
|
||||||
self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
|
self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
|
bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
|
||||||
in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
|
in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
|
||||||
|
|
@ -33,18 +33,18 @@ class Conv2dIBNormRelu(nn.Module):
|
||||||
""" Convolution + IBNorm + ReLu
|
""" Convolution + IBNorm + ReLu
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size,
|
def __init__(self, in_channels, out_channels, kernel_size,
|
||||||
stride=1, padding=0, dilation=1, groups=1, bias=True,
|
stride=1, padding=0, dilation=1, groups=1, bias=True,
|
||||||
with_ibn=True, with_relu=True):
|
with_ibn=True, with_relu=True):
|
||||||
super(Conv2dIBNormRelu, self).__init__()
|
super(Conv2dIBNormRelu, self).__init__()
|
||||||
|
|
||||||
layers = [
|
layers = [
|
||||||
nn.Conv2d(in_channels, out_channels, kernel_size,
|
nn.Conv2d(in_channels, out_channels, kernel_size,
|
||||||
stride=stride, padding=padding, dilation=dilation,
|
stride=stride, padding=padding, dilation=dilation,
|
||||||
groups=groups, bias=bias)
|
groups=groups, bias=bias)
|
||||||
]
|
]
|
||||||
|
|
||||||
if with_ibn:
|
if with_ibn:
|
||||||
layers.append(IBNorm(out_channels))
|
layers.append(IBNorm(out_channels))
|
||||||
if with_relu:
|
if with_relu:
|
||||||
layers.append(nn.ReLU(inplace=True))
|
layers.append(nn.ReLU(inplace=True))
|
||||||
|
|
@ -52,11 +52,11 @@ class Conv2dIBNormRelu(nn.Module):
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.layers(x)
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
class SEBlock(nn.Module):
|
class SEBlock(nn.Module):
|
||||||
""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
|
""" SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, reduction=1):
|
def __init__(self, in_channels, out_channels, reduction=1):
|
||||||
|
|
@ -68,7 +68,7 @@ class SEBlock(nn.Module):
|
||||||
nn.Linear(int(in_channels // reduction), out_channels, bias=False),
|
nn.Linear(int(in_channels // reduction), out_channels, bias=False),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
b, c, _, _ = x.size()
|
b, c, _, _ = x.size()
|
||||||
w = self.pool(x).view(b, c)
|
w = self.pool(x).view(b, c)
|
||||||
|
|
@ -89,7 +89,7 @@ class LRBranch(nn.Module):
|
||||||
super(LRBranch, self).__init__()
|
super(LRBranch, self).__init__()
|
||||||
|
|
||||||
enc_channels = backbone.enc_channels
|
enc_channels = backbone.enc_channels
|
||||||
|
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
|
self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
|
||||||
self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
|
self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
|
||||||
|
|
@ -111,7 +111,7 @@ class LRBranch(nn.Module):
|
||||||
lr = self.conv_lr(lr8x)
|
lr = self.conv_lr(lr8x)
|
||||||
pred_semantic = torch.sigmoid(lr)
|
pred_semantic = torch.sigmoid(lr)
|
||||||
|
|
||||||
return pred_semantic, lr8x, [enc2x, enc4x]
|
return pred_semantic, lr8x, [enc2x, enc4x]
|
||||||
|
|
||||||
|
|
||||||
class HRBranch(nn.Module):
|
class HRBranch(nn.Module):
|
||||||
|
|
@ -177,7 +177,7 @@ class FusionBranch(nn.Module):
|
||||||
def __init__(self, hr_channels, enc_channels):
|
def __init__(self, hr_channels, enc_channels):
|
||||||
super(FusionBranch, self).__init__()
|
super(FusionBranch, self).__init__()
|
||||||
self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
|
self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
|
||||||
|
|
||||||
self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
|
self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
|
||||||
self.conv_f = nn.Sequential(
|
self.conv_f = nn.Sequential(
|
||||||
Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
|
Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
|
||||||
|
|
@ -226,7 +226,7 @@ class MODNet(nn.Module):
|
||||||
self._init_norm(m)
|
self._init_norm(m)
|
||||||
|
|
||||||
if self.backbone_pretrained:
|
if self.backbone_pretrained:
|
||||||
self.backbone.load_pretrained_ckpt()
|
self.backbone.load_pretrained_ckpt()
|
||||||
|
|
||||||
def forward(self, img, inference):
|
def forward(self, img, inference):
|
||||||
pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
|
pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
|
||||||
|
|
@ -234,7 +234,7 @@ class MODNet(nn.Module):
|
||||||
pred_matte = self.f_branch(img, lr8x, hr2x)
|
pred_matte = self.f_branch(img, lr8x, hr2x)
|
||||||
|
|
||||||
return pred_semantic, pred_detail, pred_matte
|
return pred_semantic, pred_detail, pred_matte
|
||||||
|
|
||||||
def freeze_norm(self):
|
def freeze_norm(self):
|
||||||
norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
|
norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
Loading…
Reference in New Issue