Browse Source

[ResNet/PyT] Fix Resnet BasicBlock constructor

smichniak 3 years ago
parent
commit
0c4310bf17

+ 4 - 2
PyTorch/Classification/ConvNets/image_classification/models/resnet.py

@@ -63,14 +63,16 @@ class BasicBlock(nn.Module):
         stride=1,
         cardinality=1,
         downsample=None,
+        fused_se=True,
         last_bn_0_init=False,
+        trt=False,
     ):
         super(BasicBlock, self).__init__()
-        self.conv1 = builder.conv3x3(inplanes, planes, stride, cardinality=cardinality)
+        self.conv1 = builder.conv3x3(inplanes, planes, stride, groups=cardinality)
         self.bn1 = builder.batchnorm(planes)
         self.relu = builder.activation()
         self.conv2 = builder.conv3x3(
-            planes, planes * expansion, cardinality=cardinality
+            planes, planes * expansion, groups=cardinality
         )
         self.bn2 = builder.batchnorm(planes * expansion, zero_init=last_bn_0_init)
         self.downsample = downsample