diff --git a/onnx/export_onnx.py b/onnx/export_onnx.py index a9cb864..b3598f9 100644 --- a/onnx/export_onnx.py +++ b/onnx/export_onnx.py @@ -1,7 +1,7 @@ """ Export ONNX model of MODNet with: input shape: (batch_size, 3, height, width) - output shape: (batch_size, 1, height, width) + output shape: (batch_size, 1, height, width) Arguments: --ckpt-path: path of the checkpoint that will be converted @@ -50,6 +50,6 @@ if __name__ == '__main__': # export to onnx model torch.onnx.export( - modnet.module, dummy_input, args.output_path, export_params = True, - input_names = ['input'], output_names = ['output'], + modnet.module, dummy_input, args.output_path, export_params = True, + input_names = ['input'], output_names = ['output'], dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'}, 'output': {0: 'batch_size', 2: 'height', 3: 'width'}}) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ce0d196 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Synthesia Limited - All Rights Reserved +# +# Unauthorized copying of this file, via any medium is strictly prohibited. +# Proprietary and confidential. + +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "modnet" +version = "0.0.1" + +requires-python = ">=3.10" + +dependencies = [ + "torch", +] + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/src/__init__.py b/src/modnet/__init__.py similarity index 100% rename from src/__init__.py rename to src/modnet/__init__.py diff --git a/src/models/__init__.py b/src/modnet/models/__init__.py similarity index 100% rename from src/models/__init__.py rename to src/modnet/models/__init__.py diff --git a/src/models/backbones/__init__.py b/src/modnet/models/backbones/__init__.py similarity index 100% rename from src/models/backbones/__init__.py rename to src/modnet/models/backbones/__init__.py diff --git a/src/models/backbones/mobilenetv2.py b/src/modnet/models/backbones/mobilenetv2.py similarity index 100% rename from src/models/backbones/mobilenetv2.py rename to src/modnet/models/backbones/mobilenetv2.py diff --git a/src/models/backbones/wrapper.py b/src/modnet/models/backbones/wrapper.py similarity index 95% rename from src/models/backbones/wrapper.py rename to src/modnet/models/backbones/wrapper.py index 72b8f17..c622497 100644 --- a/src/models/backbones/wrapper.py +++ b/src/modnet/models/backbones/wrapper.py @@ -4,7 +4,7 @@ from functools import reduce import torch import torch.nn as nn -from .mobilenetv2 import MobileNetV2 +from modnet.models.backbones.mobilenetv2 import MobileNetV2 class BaseBackbone(nn.Module): @@ -26,7 +26,7 @@ class BaseBackbone(nn.Module): class MobileNetV2Backbone(BaseBackbone): - """ MobileNetV2 Backbone + """ MobileNetV2 Backbone """ def __init__(self, in_channels): @@ -72,11 +72,11 @@ class MobileNetV2Backbone(BaseBackbone): return [enc2x, enc4x, enc8x, enc16x, enc32x] def load_pretrained_ckpt(self): - # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch + # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' if not os.path.exists(ckpt_path): print('cannot find the pretrained mobilenetv2 backbone') exit() - + ckpt = torch.load(ckpt_path) self.model.load_state_dict(ckpt) diff --git a/src/models/modnet.py b/src/modnet/models/modnet.py similarity index 95% rename from src/models/modnet.py rename to src/modnet/models/modnet.py index 9e268e7..00ced37 100644 --- a/src/models/modnet.py +++ b/src/modnet/models/modnet.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .backbones import SUPPORTED_BACKBONES +from modnet.models.backbones import SUPPORTED_BACKBONES #------------------------------------------------------------------------------ @@ -21,7 +21,7 @@ class IBNorm(nn.Module): self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) - + def forward(self, x): bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) @@ -33,18 +33,18 @@ class Conv2dIBNormRelu(nn.Module): """ Convolution + IBNorm + ReLu """ - def __init__(self, in_channels, out_channels, kernel_size, - stride=1, padding=0, dilation=1, groups=1, bias=True, + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, with_ibn=True, with_relu=True): super(Conv2dIBNormRelu, self).__init__() layers = [ - nn.Conv2d(in_channels, out_channels, kernel_size, - stride=stride, padding=padding, dilation=dilation, + nn.Conv2d(in_channels, out_channels, kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) ] - if with_ibn: + if with_ibn: layers.append(IBNorm(out_channels)) if with_relu: layers.append(nn.ReLU(inplace=True)) @@ -52,11 +52,11 @@ class Conv2dIBNormRelu(nn.Module): self.layers = nn.Sequential(*layers) def forward(self, x): - return self.layers(x) + return self.layers(x) class SEBlock(nn.Module): - """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf + """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf """ def __init__(self, in_channels, out_channels, reduction=1): @@ -68,7 +68,7 @@ class SEBlock(nn.Module): nn.Linear(int(in_channels // reduction), out_channels, bias=False), nn.Sigmoid() ) - + def forward(self, x): b, c, _, _ = x.size() w = self.pool(x).view(b, c) @@ -89,7 +89,7 @@ class LRBranch(nn.Module): super(LRBranch, self).__init__() enc_channels = backbone.enc_channels - + self.backbone = backbone self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) @@ -111,7 +111,7 @@ class LRBranch(nn.Module): lr = self.conv_lr(lr8x) pred_semantic = torch.sigmoid(lr) - return pred_semantic, lr8x, [enc2x, enc4x] + return pred_semantic, lr8x, [enc2x, enc4x] class HRBranch(nn.Module): @@ -177,7 +177,7 @@ class FusionBranch(nn.Module): def __init__(self, hr_channels, enc_channels): super(FusionBranch, self).__init__() self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) - + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) self.conv_f = nn.Sequential( Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), @@ -226,7 +226,7 @@ class MODNet(nn.Module): self._init_norm(m) if self.backbone_pretrained: - self.backbone.load_pretrained_ckpt() + self.backbone.load_pretrained_ckpt() def forward(self, img, inference): pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) @@ -234,7 +234,7 @@ class MODNet(nn.Module): pred_matte = self.f_branch(img, lr8x, hr2x) return pred_semantic, pred_detail, pred_matte - + def freeze_norm(self): norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] for m in self.modules(): diff --git a/src/trainer.py b/src/modnet/trainer.py similarity index 100% rename from src/trainer.py rename to src/modnet/trainer.py