From 26033da8f628251b1bcf3c969527094e1aa118fd Mon Sep 17 00:00:00 2001 From: kevin Date: Sat, 27 Feb 2021 19:26:06 +0800 Subject: [PATCH] add torchscript --- TorchScript/README.md | 12 ++ TorchScript/__init__.py | 0 TorchScript/export_torchscript.py | 42 ++++++ TorchScript/modnet_torchscript.py | 275 ++++++++++++++++++++++++++++++++++++ src/models/backbones/mobilenetv2.py | 24 +++- src/models/backbones/wrapper.py | 33 ++++- 6 files changed, 376 insertions(+), 10 deletions(-) create mode 100644 TorchScript/README.md create mode 100644 TorchScript/__init__.py create mode 100644 TorchScript/export_torchscript.py create mode 100644 TorchScript/modnet_torchscript.py diff --git a/TorchScript/README.md b/TorchScript/README.md new file mode 100644 index 0000000..5233420 --- /dev/null +++ b/TorchScript/README.md @@ -0,0 +1,12 @@ +## Usage: + +```shell + +python export_torchscript.py \ + --ckpt-path pretrained/modnet_photographic_portrait_matting.ckpt\ + --out-dir scripted_model +``` + +## Official TorchScript model: + +[BaiduCloudDisk](https://pan.baidu.com/s/1kOmmmbG7lSZiSmDdE7CaRw), extract_code=dm9e \ No newline at end of file diff --git a/TorchScript/__init__.py b/TorchScript/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TorchScript/export_torchscript.py b/TorchScript/export_torchscript.py new file mode 100644 index 0000000..ccae1a7 --- /dev/null +++ b/TorchScript/export_torchscript.py @@ -0,0 +1,42 @@ +import os +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict +from . import modnet_torchscript + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet') + parser.add_argument('--out-dir', type=str, required=True, help='path for saving the TorchScript model') + args = parser.parse_args() + + # check input arguments + if not os.path.exists(args.ckpt_path): + print('Cannot find checkpoint path: {0}'.format(args.ckpt_path)) + exit() + + if not os.path.exists(args.out_dir): + os.mkdir(args.out_dir) + + # create MODNet and load the pre-trained ckpt + modnet = MODNet(backbone_pretrained=True) + # modnet = nn.DataParallel(modnet).cuda() + modnet = modnet.cuda() + ckpt = torch.load(args.ckpt) + + # if use more than one GPU + if 'module.' in ckpt.keys(): + ckpt = OrderedDict() + for k, v in ckpt.items(): + k = k.replace('module.', '') + ckpt[k] = v + + modnet.load_state_dict(ckpt) + modnet.eval() + + scripted_model = torch.jit.script(modnet) + torch.jit.save(scripted_model, os.path.join(args.out_dir,'modnet.pt')) + diff --git a/TorchScript/modnet_torchscript.py b/TorchScript/modnet_torchscript.py new file mode 100644 index 0000000..3df13e5 --- /dev/null +++ b/TorchScript/modnet_torchscript.py @@ -0,0 +1,275 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# from .backbones import SUPPORTED_BACKBONES +from .backbones import SUPPORTED_BACKBONES + + +#------------------------------------------------------------------------------ +# MODNet Basic Modules +#------------------------------------------------------------------------------ + +class IBNorm(nn.Module): + """ Combine Instance Norm and Batch Norm into One Layer + 对一半channel做BN,一半做IN + """ + + def __init__(self, in_channels): + super(IBNorm, self).__init__() + in_channels = in_channels + self.bnorm_channels = int(in_channels / 2) + self.inorm_channels = in_channels - self.bnorm_channels + + self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) + self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) + + def forward(self, x): + bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) + in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) + + return torch.cat((bn_x, in_x), 1) + + +class Conv2dIBNormRelu(nn.Module): + """ Convolution + IBNorm + ReLu + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, bias=True, + with_ibn=True, with_relu=True): + super(Conv2dIBNormRelu, self).__init__() + + layers = [ + nn.Conv2d(in_channels, out_channels, kernel_size, + stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + ] + + if with_ibn: + layers.append(IBNorm(out_channels)) + if with_relu: + layers.append(nn.ReLU(inplace=True)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class SEBlock(nn.Module): + """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf + 通道 Attention + """ + + def __init__(self, in_channels, out_channels, reduction=1): + super(SEBlock, self).__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(in_channels, int(in_channels // reduction), bias=False), + nn.ReLU(inplace=True), + nn.Linear(int(in_channels // reduction), out_channels, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + w = self.pool(x).view(b, c) + w = self.fc(w).view(b, c, 1, 1) + + return x * w.expand_as(x) + + +#------------------------------------------------------------------------------ +# MODNet Branches +#------------------------------------------------------------------------------ + +class LRBranch(nn.Module): + """ Low Resolution Branch of MODNet + """ + + def __init__(self, backbone): + super(LRBranch, self).__init__() + + enc_channels = backbone.enc_channels + # ==> self.enc_channels = [16, 24, 32, 96, 1280] + + self.backbone = backbone + self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) + self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) + self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) + self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) + + def forward(self, img, inference): + enc_features = self.backbone.forward(img) + enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] + + # 对最后一层进行通道注意力 + enc32x = self.se_block(enc32x) + # 再上采样4倍 + lr16x = F.interpolate(enc32x, scale_factor=2.0, mode='bilinear', align_corners=False) + lr16x = self.conv_lr16x(lr16x) + lr8x = F.interpolate(lr16x, scale_factor=2.0, mode='bilinear', align_corners=False) + lr8x = self.conv_lr8x(lr8x) + + pred_semantic = torch.tensor([]) # None + if not inference: + lr = self.conv_lr(lr8x) + pred_semantic = torch.sigmoid(lr) + + return pred_semantic, lr8x, [enc2x, enc4x] + + +class HRBranch(nn.Module): + """ High Resolution Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(HRBranch, self).__init__() + + self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) + self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) + + self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) + self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) + + self.conv_hr4x = nn.Sequential( + Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr2x = nn.Sequential( + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, enc2x, enc4x, lr8x, inference): + img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) + img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) + + enc2x = self.tohr_enc2x(enc2x) + # 把原图叠加到通道上 + hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) + + # 把两个 featmap 连接 + enc4x = self.tohr_enc4x(enc4x) + hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) + + lr4x = F.interpolate(lr8x, scale_factor=2.0, mode='bilinear', align_corners=False) + hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) + + hr2x = F.interpolate(hr4x, scale_factor=2.0, mode='bilinear', align_corners=False) + hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) + + pred_detail = torch.tensor([]) # None + if not inference: + hr = F.interpolate(hr2x, scale_factor=2.0, mode='bilinear', align_corners=False) + hr = self.conv_hr(torch.cat((hr, img), dim=1)) + pred_detail = torch.sigmoid(hr) + + return pred_detail, hr2x + + +class FusionBranch(nn.Module): + """ Fusion Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(FusionBranch, self).__init__() + self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) + + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) + self.conv_f = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), + Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, lr8x, hr2x): + lr4x = F.interpolate(lr8x, scale_factor=2.0, mode='bilinear', align_corners=False) + lr4x = self.conv_lr4x(lr4x) + lr2x = F.interpolate(lr4x, scale_factor=2.0, mode='bilinear', align_corners=False) + + f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) + f = F.interpolate(f2x, scale_factor=2.0, mode='bilinear', align_corners=False) + f = self.conv_f(torch.cat((f, img), dim=1)) + pred_matte = torch.sigmoid(f) + + return pred_matte + + +#------------------------------------------------------------------------------ +# MODNet +#------------------------------------------------------------------------------ + +class MODNet(nn.Module): + """ Architecture of MODNet + """ + + def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): + super(MODNet, self).__init__() + + self.in_channels = in_channels + self.hr_channels = hr_channels + self.backbone_arch = backbone_arch + self.backbone_pretrained = backbone_pretrained + + self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) + + self.lr_branch = LRBranch(self.backbone) + self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) + self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + self._init_conv(m) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): + self._init_norm(m) + + if self.backbone_pretrained: + self.backbone.load_pretrained_ckpt() + + def forward(self, img, inference): + pred_semantic = self.lr_branch(img, inference)[0] + lr8x = self.lr_branch(img, inference)[1] + enc2x = self.lr_branch(img, inference)[2][0] + enc4x = self.lr_branch(img, inference)[2][1] + + pred_detail = self.hr_branch(img, enc2x, enc4x, lr8x, inference)[0] + hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)[1] + + pred_matte = self.f_branch(img, lr8x, hr2x) + + return pred_semantic, pred_detail, pred_matte + + def freeze_norm(self): + norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] + for m in self.modules(): + for n in norm_types: + if isinstance(m, n): + m.eval() + continue + + def _init_conv(self, conv): + nn.init.kaiming_uniform_( + conv.weight, a=0, mode='fan_in', nonlinearity='relu') + if conv.bias is not None: + nn.init.constant_(conv.bias, 0) + + def _init_norm(self, norm): + if norm.weight is not None: + nn.init.constant_(norm.weight, 1) + nn.init.constant_(norm.bias, 0) + + +if __name__ == "__main__": + IbNorm = IBNorm(20) + out = IbNorm(torch.randn((1,3,224,224))) + print(out.shape) \ No newline at end of file diff --git a/src/models/backbones/mobilenetv2.py b/src/models/backbones/mobilenetv2.py index 67cc138..2b72db7 100644 --- a/src/models/backbones/mobilenetv2.py +++ b/src/models/backbones/mobilenetv2.py @@ -138,15 +138,29 @@ class MobileNetV2(nn.Module): def forward(self, x, feature_names=None): # Stage1 - x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x) + x = self.features[0](x) + x = self.features[1](x) # Stage2 - x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x) + x = self.features[2](x) + x = self.features[3](x) # Stage3 - x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x) + x = self.features[4](x) + x = self.features[5](x) + x = self.features[6](x) # Stage4 - x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x) + x = self.features[7](x) + x = self.features[8](x) + x = self.features[9](x) + x = self.features[10](x) + x = self.features[11](x) + x = self.features[12](x) + x = self.features[13](x) # Stage5 - x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x) + x = self.features[14](x) + x = self.features[15](x) + x = self.features[16](x) + x = self.features[17](x) + x = self.features[18](x) # Classification if self.num_classes is not None: diff --git a/src/models/backbones/wrapper.py b/src/models/backbones/wrapper.py index 36817ba..72b8f17 100644 --- a/src/models/backbones/wrapper.py +++ b/src/models/backbones/wrapper.py @@ -36,15 +36,38 @@ class MobileNetV2Backbone(BaseBackbone): self.enc_channels = [16, 24, 32, 96, 1280] def forward(self, x): - x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) + # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) + x = self.model.features[0](x) + x = self.model.features[1](x) enc2x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) + x = self.model.features[2](x) + x = self.model.features[3](x) enc4x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) + x = self.model.features[4](x) + x = self.model.features[5](x) + x = self.model.features[6](x) enc8x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) + x = self.model.features[7](x) + x = self.model.features[8](x) + x = self.model.features[9](x) + x = self.model.features[10](x) + x = self.model.features[11](x) + x = self.model.features[12](x) + x = self.model.features[13](x) enc16x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) + x = self.model.features[14](x) + x = self.model.features[15](x) + x = self.model.features[16](x) + x = self.model.features[17](x) + x = self.model.features[18](x) enc32x = x return [enc2x, enc4x, enc8x, enc16x, enc32x]