diff --git a/src/models/modnet.py b/src/models/modnet.py index 16609b3..9e268e7 100644 --- a/src/models/modnet.py +++ b/src/models/modnet.py @@ -24,7 +24,7 @@ class IBNorm(nn.Module): def forward(self, x): bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) - in_x = self.inorm(x[:, self.inorm_channels:, ...].contiguous()) + in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) return torch.cat((bn_x, in_x), 1)