|
|
@@ -63,14 +63,16 @@ class BasicBlock(nn.Module):
|
|
|
stride=1,
|
|
|
cardinality=1,
|
|
|
downsample=None,
|
|
|
+ fused_se=True,
|
|
|
last_bn_0_init=False,
|
|
|
+ trt=False,
|
|
|
):
|
|
|
super(BasicBlock, self).__init__()
|
|
|
- self.conv1 = builder.conv3x3(inplanes, planes, stride, cardinality=cardinality)
|
|
|
+ self.conv1 = builder.conv3x3(inplanes, planes, stride, groups=cardinality)
|
|
|
self.bn1 = builder.batchnorm(planes)
|
|
|
self.relu = builder.activation()
|
|
|
self.conv2 = builder.conv3x3(
|
|
|
- planes, planes * expansion, cardinality=cardinality
|
|
|
+ planes, planes * expansion, groups=cardinality
|
|
|
)
|
|
|
self.bn2 = builder.batchnorm(planes * expansion, zero_init=last_bn_0_init)
|
|
|
self.downsample = downsample
|