diff --git a/src/modnet/models/modnet.py b/src/modnet/models/modnet.py index 46b2cf2..76b157f 100644 --- a/src/modnet/models/modnet.py +++ b/src/modnet/models/modnet.py @@ -56,7 +56,7 @@ class Conv2dIBNormRelu(nn.Module): class SEBlock(nn.Module): - """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf + """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf """ def __init__(self, in_channels, out_channels, reduction=1): @@ -68,7 +68,7 @@ class SEBlock(nn.Module): nn.Linear(int(in_channels // reduction), out_channels, bias=False), nn.Sigmoid() ) - + def forward(self, x): b, c, _, _ = x.size() w = self.pool(x).view(b, c)