From 57104c242e0e00d7ecbb21ea1531074a93abb552 Mon Sep 17 00:00:00 2001 From: kevin Date: Fri, 5 Mar 2021 14:23:49 +0800 Subject: [PATCH] add torchscript --- src/models/backbones/mobilenetv2.py | 26 ++++++++++++++++++++------ src/models/backbones/wrapper.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/models/backbones/mobilenetv2.py b/src/models/backbones/mobilenetv2.py index 67cc138..709d352 100644 --- a/src/models/backbones/mobilenetv2.py +++ b/src/models/backbones/mobilenetv2.py @@ -136,17 +136,31 @@ class MobileNetV2(nn.Module): # Initialize weights self._init_weights() - def forward(self, x, feature_names=None): + def forward(self, x): # Stage1 - x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x) + x = self.features[0](x) + x = self.features[1](x) # Stage2 - x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x) + x = self.features[2](x) + x = self.features[3](x) # Stage3 - x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x) + x = self.features[4](x) + x = self.features[5](x) + x = self.features[6](x) # Stage4 - x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x) + x = self.features[7](x) + x = self.features[8](x) + x = self.features[9](x) + x = self.features[10](x) + x = self.features[11](x) + x = self.features[12](x) + x = self.features[13](x) # Stage5 - x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x) + x = self.features[14](x) + x = self.features[15](x) + x = self.features[16](x) + x = self.features[17](x) + x = self.features[18](x) # Classification if self.num_classes is not None: diff --git a/src/models/backbones/wrapper.py b/src/models/backbones/wrapper.py index 36817ba..72b8f17 100644 --- a/src/models/backbones/wrapper.py +++ b/src/models/backbones/wrapper.py @@ -36,15 +36,38 @@ class MobileNetV2Backbone(BaseBackbone): self.enc_channels = [16, 24, 32, 96, 1280] def forward(self, x): - x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) + # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) + x = self.model.features[0](x) + x = self.model.features[1](x) enc2x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) + x = self.model.features[2](x) + x = self.model.features[3](x) enc4x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) + x = self.model.features[4](x) + x = self.model.features[5](x) + x = self.model.features[6](x) enc8x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) + x = self.model.features[7](x) + x = self.model.features[8](x) + x = self.model.features[9](x) + x = self.model.features[10](x) + x = self.model.features[11](x) + x = self.model.features[12](x) + x = self.model.features[13](x) enc16x = x - x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) + x = self.model.features[14](x) + x = self.model.features[15](x) + x = self.model.features[16](x) + x = self.model.features[17](x) + x = self.model.features[18](x) enc32x = x return [enc2x, enc4x, enc8x, enc16x, enc32x]