diff --git a/src/modnet/models/backbones/wrapper.py b/src/modnet/models/backbones/wrapper.py index c622497..7afe666 100644 --- a/src/modnet/models/backbones/wrapper.py +++ b/src/modnet/models/backbones/wrapper.py @@ -26,7 +26,7 @@ class BaseBackbone(nn.Module): class MobileNetV2Backbone(BaseBackbone): - """ MobileNetV2 Backbone + """ MobileNetV2 Backbone """ def __init__(self, in_channels): @@ -72,11 +72,11 @@ class MobileNetV2Backbone(BaseBackbone): return [enc2x, enc4x, enc8x, enc16x, enc32x] def load_pretrained_ckpt(self): - # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch + # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' if not os.path.exists(ckpt_path): print('cannot find the pretrained mobilenetv2 backbone') exit() - + ckpt = torch.load(ckpt_path) self.model.load_state_dict(ckpt) diff --git a/src/modnet/models/modnet.py b/src/modnet/models/modnet.py index 00ced37..46b2cf2 100644 --- a/src/modnet/models/modnet.py +++ b/src/modnet/models/modnet.py @@ -21,7 +21,7 @@ class IBNorm(nn.Module): self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) - + def forward(self, x): bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) @@ -33,18 +33,18 @@ class Conv2dIBNormRelu(nn.Module): """ Convolution + IBNorm + ReLu """ - def __init__(self, in_channels, out_channels, kernel_size, - stride=1, padding=0, dilation=1, groups=1, bias=True, + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, with_ibn=True, with_relu=True): super(Conv2dIBNormRelu, self).__init__() layers = [ - nn.Conv2d(in_channels, out_channels, kernel_size, - stride=stride, padding=padding, dilation=dilation, + nn.Conv2d(in_channels, out_channels, kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) ] - if with_ibn: + if with_ibn: layers.append(IBNorm(out_channels)) if with_relu: layers.append(nn.ReLU(inplace=True)) @@ -52,7 +52,7 @@ class Conv2dIBNormRelu(nn.Module): self.layers = nn.Sequential(*layers) def forward(self, x): - return self.layers(x) + return self.layers(x) class SEBlock(nn.Module): @@ -89,7 +89,7 @@ class LRBranch(nn.Module): super(LRBranch, self).__init__() enc_channels = backbone.enc_channels - + self.backbone = backbone self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) @@ -111,7 +111,7 @@ class LRBranch(nn.Module): lr = self.conv_lr(lr8x) pred_semantic = torch.sigmoid(lr) - return pred_semantic, lr8x, [enc2x, enc4x] + return pred_semantic, lr8x, [enc2x, enc4x] class HRBranch(nn.Module): @@ -177,7 +177,7 @@ class FusionBranch(nn.Module): def __init__(self, hr_channels, enc_channels): super(FusionBranch, self).__init__() self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) - + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) self.conv_f = nn.Sequential( Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), @@ -226,7 +226,7 @@ class MODNet(nn.Module): self._init_norm(m) if self.backbone_pretrained: - self.backbone.load_pretrained_ckpt() + self.backbone.load_pretrained_ckpt() def forward(self, img, inference): pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) @@ -234,7 +234,7 @@ class MODNet(nn.Module): pred_matte = self.f_branch(img, lr8x, hr2x) return pred_semantic, pred_detail, pred_matte - + def freeze_norm(self): norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] for m in self.modules():