undo whitespace changes

pull/217/head
Richard Brown 2024-05-06 16:27:42 +02:00
parent 67a565bcd1
commit f95e1236bc
2 changed files with 15 additions and 15 deletions

View File

@@ -26,7 +26,7 @@ class BaseBackbone(nn.Module):
class MobileNetV2Backbone(BaseBackbone):
    """ MobileNetV2 Backbone
    """

    def __init__(self, in_channels):
@@ -72,11 +72,11 @@ class MobileNetV2Backbone(BaseBackbone):
        return [enc2x, enc4x, enc8x, enc16x, enc32x]

    def load_pretrained_ckpt(self):
        # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
        ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
        if not os.path.exists(ckpt_path):
            print('cannot find the pretrained mobilenetv2 backbone')
            exit()

        ckpt = torch.load(ckpt_path)
        self.model.load_state_dict(ckpt)
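For readers running this backbone on a CPU-only machine, here is a minimal standalone sketch of the same checkpoint-loading step with an explicit map_location added; the raise in place of print/exit is a readability tweak, and `model` is a stand-in name for the backbone's inner network (both are assumptions, not part of the diff):

import os
import torch

ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
if not os.path.exists(ckpt_path):
    raise FileNotFoundError('cannot find the pretrained mobilenetv2 backbone')

# map_location lets the CUDA-trained checkpoint load on machines without a GPU
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
# model.load_state_dict(ckpt)  # `model` would be the MobileNetV2 instance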

View File

@@ -21,7 +21,7 @@ class IBNorm(nn.Module):
        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)

    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
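The hunk above only shows the two normalization calls; for orientation, here is a self-contained sketch of the whole IBNorm idea: BatchNorm over the first half of the channels, InstanceNorm over the rest, concatenated back together. The half/half split and the final concat are reconstructions for illustration, not lines from this diff:

import torch
import torch.nn as nn

class IBNormSketch(nn.Module):
    """Sketch: BatchNorm on the first half of the channels, InstanceNorm on
    the second half, concatenated so the output channel count matches the input."""
    def __init__(self, in_channels):
        super().__init__()
        self.bnorm_channels = in_channels // 2
        self.inorm_channels = in_channels - self.bnorm_channels
        self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
        self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)

    def forward(self, x):
        bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
        in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
        return torch.cat((bn_x, in_x), 1)

# quick shape check
print(IBNormSketch(32)(torch.randn(2, 32, 16, 16)).shape)  # torch.Size([2, 32, 16, 16])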
@@ -33,18 +33,18 @@ class Conv2dIBNormRelu(nn.Module):
    """ Convolution + IBNorm + ReLu
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 with_ibn=True, with_relu=True):
        super(Conv2dIBNormRelu, self).__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation,
                      groups=groups, bias=bias)
        ]
        if with_ibn:
            layers.append(IBNorm(out_channels))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))
@@ -52,7 +52,7 @@ class Conv2dIBNormRelu(nn.Module):
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
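A hedged usage sketch of the block defined across the two hunks above; the import path is an assumption based on the repository layout, so adjust it to wherever Conv2dIBNormRelu actually lives:

import torch
from src.models.modnet import Conv2dIBNormRelu  # module path is an assumption

# kernel 3 with padding 1 preserves spatial size; channels go 16 -> 32
block = Conv2dIBNormRelu(16, 32, kernel_size=3, stride=1, padding=1)
out = block(torch.randn(1, 16, 64, 64))  # Conv2d -> IBNorm -> ReLU
print(out.shape)                          # torch.Size([1, 32, 64, 64])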
class SEBlock(nn.Module):
@@ -89,7 +89,7 @@ class LRBranch(nn.Module):
        super(LRBranch, self).__init__()

        enc_channels = backbone.enc_channels

        self.backbone = backbone
        self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
        self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
@@ -111,7 +111,7 @@ class LRBranch(nn.Module):
            lr = self.conv_lr(lr8x)
            pred_semantic = torch.sigmoid(lr)

        return pred_semantic, lr8x, [enc2x, enc4x]
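The two pred_semantic lines above sit one indent level deeper than the return, which, together with the `inference` argument visible in the MODNet.forward hunk below, suggests they are gated so the coarse semantic head only runs during training. A minimal sketch of that gating, with assumed names:

import torch

def lr_head(conv_lr, lr8x, inference):
    # names are assumptions: the semantic map is only computed for training;
    # at inference time the branch returns None in its place
    pred_semantic = None
    if not inference:
        pred_semantic = torch.sigmoid(conv_lr(lr8x))
    return pred_semantic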
class HRBranch(nn.Module):
@@ -177,7 +177,7 @@ class FusionBranch(nn.Module):
    def __init__(self, hr_channels, enc_channels):
        super(FusionBranch, self).__init__()
        self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)

        self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
        self.conv_f = nn.Sequential(
            Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
@@ -226,7 +226,7 @@ class MODNet(nn.Module):
                self._init_norm(m)

        if self.backbone_pretrained:
            self.backbone.load_pretrained_ckpt()

    def forward(self, img, inference):
        pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
@@ -234,7 +234,7 @@ class MODNet(nn.Module):
        pred_matte = self.f_branch(img, lr8x, hr2x)

        return pred_semantic, pred_detail, pred_matte

    def freeze_norm(self):
        norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
        for m in self.modules():
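For orientation, a hedged sketch of driving the three-branch forward pass touched by the last two hunks; the import path and the availability of a no-checkpoint constructor flag are assumptions based on the upstream MODNet repository, not lines from this diff:

import torch
from src.models.modnet import MODNet  # module path is an assumption

# backbone_pretrained=False (assumed flag) skips loading the .ckpt for a smoke test
modnet = MODNet(backbone_pretrained=False)
img = torch.randn(1, 3, 512, 512)  # NCHW image batch

# inference=False returns all three outputs: semantic, detail, and matte
pred_semantic, pred_detail, pred_matte = modnet(img, inference=False)
print(pred_matte.shape)  # expected: torch.Size([1, 1, 512, 512])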