Просмотр исходного кода

[ConvNets/PyT] Adding checkpoints for EfficientNet/PyT, Squeeze&Excitation can use Conv or Linear layer depending on `--trt` switch.

Andrzej Sulecki 4 лет назад
Родитель
Сommit
555b84b3b1

+ 84 - 0
PyTorch/Classification/ConvNets/configs.yml

@@ -309,6 +309,48 @@ models:
                 arch: efficientnet-b0
                 batch_size: 256
     # }}}
+    efficientnet-quant-b0: # {{{
+        T4:
+            AMP:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 128
+                memory_format: nhwc
+            FP32:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 64
+        DGX1V-16G:
+            AMP:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 128
+                memory_format: nhwc
+            FP32:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 64
+        DGX1V-32G:
+            AMP:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 256
+                memory_format: nhwc
+            FP32:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 128
+        DGXA100:
+            AMP:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 256
+                memory_format: nhwc
+            TF32:
+                <<: *efficientnet_b0_params_4k
+                arch: efficientnet-quant-b0
+                batch_size: 256
+    # }}}
     efficientnet-widese-b4: # {{{
         T4:
             AMP:
@@ -393,3 +435,45 @@ models:
                 arch: efficientnet-b4
                 batch_size: 64
     # }}}
+    efficientnet-quant-b4: # {{{
+        T4:
+            AMP:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 32
+                memory_format: nhwc
+            FP32:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 16
+        DGX1V-16G:
+            AMP:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 32
+                memory_format: nhwc
+            FP32:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 16
+        DGX1V-32G:
+            AMP:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 64
+                memory_format: nhwc
+            FP32:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 32
+        DGXA100:
+            AMP:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 128
+                memory_format: nhwc
+            TF32:
+                <<: *efficientnet_b4_params_4k
+                arch: efficientnet-quant-b4
+                batch_size: 64
+    # }}}

+ 10 - 2
PyTorch/Classification/ConvNets/efficientnet/README.md

@@ -434,11 +434,19 @@ You can also run the ImageNet validation on pretrained weights:
 Pretrained weights can be downloaded from NGC:
 
 ```bash
-wget --content-disposition <ngc weights url>
+wget <ngc weights url>
 ```
 
+URL for each model can be found in the following table:
 
-
+| **Model** | **NGC weights URL** |
+|:---------:|:-------------------:|
+| efficientnet-b0 | https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b0_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-b0_210412.pth | 
+| efficientnet-b4 | https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b4_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-b4_210412.pth | 
+| efficientnet-widese-b0 | https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_widese_b0_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-widese-b0_210412.pth | 
+| efficientnet-widese-b4 | https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_widese_b4_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-widese-b4_210412.pth | 
+| efficientnet-quant-b0 | https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b0_pyt_qat_ckpt_fp32/versions/21.03.0/files/nvidia-efficientnet-quant-b0-130421.pth | 
+| efficientnet-quant-b4 | https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b4_pyt_qat_ckpt_fp32/versions/21.03.0/files/nvidia-efficientnet-quant-b4-130421.pth | 
 
 To run inference on ImageNet, run:
 

+ 36 - 10
PyTorch/Classification/ConvNets/image_classification/models/common.py

@@ -3,8 +3,17 @@ from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Optional
 import torch
+import warnings
 from torch import nn
-from pytorch_quantization import nn as quant_nn
+
+try:
+    from pytorch_quantization import nn as quant_nn
+except ImportError as e:
+    warnings.warn(
+        "pytorch_quantization module not found, quantization will not be available"
+    )
+    quant_nn = None
+
 
 # LayerBuilder {{{
 class LayerBuilder(object):
@@ -134,20 +143,30 @@ class LambdaLayer(nn.Module):
 
 # SqueezeAndExcitation {{{
 class SqueezeAndExcitation(nn.Module):
-    def __init__(self, in_channels, squeeze, activation):
+    def __init__(self, in_channels, squeeze, activation, use_conv=False):
         super(SqueezeAndExcitation, self).__init__()
-        self.pooling = nn.AdaptiveAvgPool2d(1)
-        self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
-        self.expand = nn.Conv2d(squeeze, in_channels, 1)
+        if use_conv:
+            self.pooling = nn.AdaptiveAvgPool2d(1)
+            self.squeeze = nn.Conv2d(in_channels, squeeze, 1)
+            self.expand = nn.Conv2d(squeeze, in_channels, 1)
+        else:
+            self.squeeze = nn.Linear(in_channels, squeeze)
+            self.expand = nn.Linear(squeeze, in_channels)
         self.activation = activation
         self.sigmoid = nn.Sigmoid()
+        self.use_conv = use_conv
 
     def forward(self, x):
-        out = self.pooling(x)
+        if self.use_conv:
+            out = self.pooling(x)
+        else:
+            out = torch.mean(x, [2, 3])
         out = self.squeeze(out)
         out = self.activation(out)
         out = self.expand(out)
         out = self.sigmoid(out)
+        if not self.use_conv:
+            out = out.unsqueeze(2).unsqueeze(3)
         return out
 
 
@@ -199,12 +218,19 @@ class ONNXSiLU(nn.Module):
 
 
 class SequentialSqueezeAndExcitation(SqueezeAndExcitation):
-    def __init__(self, in_channels, squeeze, activation, quantized=False):
-        super().__init__(in_channels, squeeze, activation,)
+    def __init__(
+        self, in_channels, squeeze, activation, quantized=False, use_conv=False
+    ):
+        super().__init__(in_channels, squeeze, activation, use_conv=use_conv)
         self.quantized = quantized
         if quantized:
-            self.mul_a_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
-            self.mul_b_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
+            assert quant_nn is not None, "pytorch_quantization is not available"
+            self.mul_a_quantizer = quant_nn.TensorQuantizer(
+                quant_nn.QuantConv2d.default_quant_desc_input
+            )
+            self.mul_b_quantizer = quant_nn.TensorQuantizer(
+                quant_nn.QuantConv2d.default_quant_desc_input
+            )
 
     def forward(self, x):
         if not self.quantized:

+ 61 - 16
PyTorch/Classification/ConvNets/image_classification/models/efficientnet.py

@@ -1,6 +1,7 @@
 import argparse
 import random
 import math
+import warnings
 from typing import List, Any, Optional
 from collections import namedtuple, OrderedDict
 from dataclasses import dataclass, replace
@@ -8,7 +9,26 @@ from dataclasses import dataclass, replace
 import torch
 from torch import nn
 from functools import partial
-from pytorch_quantization import nn as quant_nn
+
+try:
+    from pytorch_quantization import nn as quant_nn
+    from ..quantization import switch_on_quantization
+except ImportError as e:
+    warnings.warn(
+        "pytorch_quantization module not found, quantization will not be available"
+    )
+    quant_nn = None
+
+    import contextlib
+
+    @contextlib.contextmanager
+    def switch_on_quantization(do_quantization=False):
+        assert not do_quantization, "quantization is not available"
+        try:
+            yield
+        finally:
+            pass
+
 
 from .common import (
     SqueezeAndExcitation,
@@ -27,7 +47,6 @@ from .model import (
     EntryPoint,
 )
 
-from ..quantization import switch_on_quantization
 
 # EffNetArch {{{
 @dataclass
@@ -107,6 +126,7 @@ class EffNetParams(ModelParams):
     bn_epsilon: float = 1e-3
     survival_prob: float = 1
     quantized: bool = False
+    trt: bool = False
 
     def parser(self, name):
         p = super().parser(name)
@@ -145,6 +165,7 @@ class EffNetParams(ModelParams):
         p.add_argument(
             "--dropout", default=self.dropout, type=float, help="Dropout drop prob"
         )
+        p.add_argument("--trt", metavar="True|False", default=self.trt, type=bool)
         return p
 
 
@@ -162,7 +183,8 @@ class EfficientNet(nn.Module):
         bn_momentum: float = 1 - 0.99,
         bn_epsilon: float = 1e-3,
         survival_prob: float = 1,
-        quantized: bool = False
+        quantized: bool = False,
+        trt: bool = False,
     ):
         self.quantized = quantized
         with switch_on_quantization(self.quantized):
@@ -195,6 +217,7 @@ class EfficientNet(nn.Module):
                     out_channels=c,
                     squeeze_excitation_ratio=arch.squeeze_excitation_ratio,
                     prev_layer_count=plc,
+                    trt=trt,
                 )
                 plc = plc + r
                 setattr(self, f"layer{i+1}", layer)
@@ -293,6 +316,7 @@ class EfficientNet(nn.Module):
         out_channels,
         squeeze_excitation_ratio,
         prev_layer_count,
+        trt,
     ):
         layers = []
 
@@ -307,7 +331,8 @@ class EfficientNet(nn.Module):
             stride,
             self.arch.squeeze_excitation_ratio,
             survival_prob if stride == 1 and in_channels == out_channels else 1.0,
-            self.quantized
+            self.quantized,
+            trt=trt,
         )
         layers.append((f"block{idx}", blk))
 
@@ -322,7 +347,8 @@ class EfficientNet(nn.Module):
                 1,  # stride
                 squeeze_excitation_ratio,
                 survival_prob,
-                self.quantized
+                self.quantized,
+                trt=trt,
             )
             layers.append((f"block{idx}", blk))
         return nn.Sequential(OrderedDict(layers)), out_channels
@@ -343,7 +369,8 @@ class MBConvBlock(nn.Module):
         squeeze_excitation_ratio: int,
         squeeze_hidden=False,
         survival_prob: float = 1.0,
-        quantized: bool = False
+        quantized: bool = False,
+        trt: bool = False,
     ):
         super().__init__()
         self.quantized = quantized
@@ -361,14 +388,17 @@ class MBConvBlock(nn.Module):
             depsep_kernel_size, hidden_dim, hidden_dim, stride, bn=True, act=True
         )
         self.se = SequentialSqueezeAndExcitation(
-            hidden_dim, squeeze_dim, builder.activation(), self.quantized
+            hidden_dim, squeeze_dim, builder.activation(), self.quantized, use_conv=trt
         )
         self.proj = builder.conv1x1(hidden_dim, out_channels, bn=True)
 
         self.survival_prob = survival_prob
 
         if self.quantized and self.residual:
-            self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)  # TODO QuantConv2d ?!?
+            assert quant_nn is not None, "pytorch_quantization is not available"
+            self.residual_quantizer = quant_nn.TensorQuantizer(
+                quant_nn.QuantConv2d.default_quant_desc_input
+            )  # TODO QuantConv2d ?!?
 
     def drop(self):
         if self.survival_prob == 1.0:
@@ -406,6 +436,7 @@ def original_mbconv(
     squeeze_excitation_ratio: int,
     survival_prob: float,
     quantized: bool,
+    trt: bool
 ):
     return MBConvBlock(
         builder,
@@ -417,7 +448,8 @@ def original_mbconv(
         squeeze_excitation_ratio,
         squeeze_hidden=False,
         survival_prob=survival_prob,
-        quantized=quantized
+        quantized=quantized,
+        trt=trt,
     )
 
 
@@ -431,6 +463,7 @@ def widese_mbconv(
     squeeze_excitation_ratio: int,
     survival_prob: float,
     quantized: bool,
+    trt: bool,
 ):
     return MBConvBlock(
         builder,
@@ -442,7 +475,8 @@ def widese_mbconv(
         squeeze_excitation_ratio,
         squeeze_hidden=True,
         survival_prob=survival_prob,
-        quantized=False
+        quantized=quantized,
+        trt=trt,
     )
 
 
@@ -469,31 +503,42 @@ effnet_b5_layers=effnet_b0_layers.scale(wc=1.6, dc=2.2, dis=456)
 effnet_b6_layers=effnet_b0_layers.scale(wc=1.8, dc=2.6, dis=528)
 effnet_b7_layers=effnet_b0_layers.scale(wc=2.0, dc=3.1, dis=600)
 
+
+
+urls = {
+    "efficientnet-b0": "https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b0_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-b0_210412.pth",
+    "efficientnet-b4": "https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b4_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-b4_210412.pth",
+    "efficientnet-widese-b0": "https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_widese_b0_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-widese-b0_210412.pth",
+    "efficientnet-widese-b4": "https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_widese_b4_pyt_amp/versions/20.12.0/files/nvidia_efficientnet-widese-b4_210412.pth",
+    "efficientnet-quant-b0": "https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b0_pyt_qat_ckpt_fp32/versions/21.03.0/files/nvidia-efficientnet-quant-b0-130421.pth",
+    "efficientnet-quant-b4": "https://api.ngc.nvidia.com/v2/models/nvidia/efficientnet_b4_pyt_qat_ckpt_fp32/versions/21.03.0/files/nvidia-efficientnet-quant-b4-130421.pth",
+}
+
 def _m(*args, **kwargs):
     return Model(constructor=EfficientNet, *args, **kwargs)
 
 architectures = {
-    "efficientnet-b0": _m(arch=effnet_b0_layers, params=EffNetParams(dropout=0.2)),
+    "efficientnet-b0": _m(arch=effnet_b0_layers, params=EffNetParams(dropout=0.2), checkpoint_url=urls["efficientnet-b0"]),
     "efficientnet-b1": _m(arch=effnet_b1_layers, params=EffNetParams(dropout=0.2)),
     "efficientnet-b2": _m(arch=effnet_b2_layers, params=EffNetParams(dropout=0.3)),
     "efficientnet-b3": _m(arch=effnet_b3_layers, params=EffNetParams(dropout=0.3)),
-    "efficientnet-b4": _m(arch=effnet_b4_layers, params=EffNetParams(dropout=0.4, survival_prob=0.8)),
+    "efficientnet-b4": _m(arch=effnet_b4_layers, params=EffNetParams(dropout=0.4, survival_prob=0.8), checkpoint_url=urls["efficientnet-b4"]),
     "efficientnet-b5": _m(arch=effnet_b5_layers, params=EffNetParams(dropout=0.4)),
     "efficientnet-b6": _m(arch=effnet_b6_layers, params=EffNetParams(dropout=0.5)),
     "efficientnet-b7": _m(arch=effnet_b7_layers, params=EffNetParams(dropout=0.5)),
-    "efficientnet-widese-b0": _m(arch=replace(effnet_b0_layers, block=widese_mbconv), params=EffNetParams(dropout=0.2)),
+    "efficientnet-widese-b0": _m(arch=replace(effnet_b0_layers, block=widese_mbconv), params=EffNetParams(dropout=0.2), checkpoint_url=urls["efficientnet-widese-b0"]),
     "efficientnet-widese-b1": _m(arch=replace(effnet_b1_layers, block=widese_mbconv), params=EffNetParams(dropout=0.2)),
     "efficientnet-widese-b2": _m(arch=replace(effnet_b2_layers, block=widese_mbconv), params=EffNetParams(dropout=0.3)),
     "efficientnet-widese-b3": _m(arch=replace(effnet_b3_layers, block=widese_mbconv), params=EffNetParams(dropout=0.3)),
-    "efficientnet-widese-b4": _m(arch=replace(effnet_b4_layers, block=widese_mbconv), params=EffNetParams(dropout=0.4, survival_prob=0.8)),
+    "efficientnet-widese-b4": _m(arch=replace(effnet_b4_layers, block=widese_mbconv), params=EffNetParams(dropout=0.4, survival_prob=0.8), checkpoint_url=urls["efficientnet-widese-b4"]),
     "efficientnet-widese-b5": _m(arch=replace(effnet_b5_layers, block=widese_mbconv), params=EffNetParams(dropout=0.4)),
     "efficientnet-widese-b6": _m(arch=replace(effnet_b6_layers, block=widese_mbconv), params=EffNetParams(dropout=0.5)),
     "efficientnet-widese-b7": _m(arch=replace(effnet_b7_layers, block=widese_mbconv), params=EffNetParams(dropout=0.5)),
-    "efficientnet-quant-b0": _m(arch=effnet_b0_layers, params=EffNetParams(dropout=0.2, quantized=True)),
+    "efficientnet-quant-b0": _m(arch=effnet_b0_layers, params=EffNetParams(dropout=0.2, quantized=True), checkpoint_url=urls["efficientnet-quant-b0"]),
     "efficientnet-quant-b1": _m(arch=effnet_b1_layers, params=EffNetParams(dropout=0.2, quantized=True)),
     "efficientnet-quant-b2": _m(arch=effnet_b2_layers, params=EffNetParams(dropout=0.3, quantized=True)),
     "efficientnet-quant-b3": _m(arch=effnet_b3_layers, params=EffNetParams(dropout=0.3, quantized=True)),
-    "efficientnet-quant-b4": _m(arch=effnet_b4_layers, params=EffNetParams(dropout=0.4, survival_prob=0.8, quantized=True)),
+    "efficientnet-quant-b4": _m(arch=effnet_b4_layers, params=EffNetParams(dropout=0.4, survival_prob=0.8, quantized=True), checkpoint_url=urls["efficientnet-quant-b4"]),
     "efficientnet-quant-b5": _m(arch=effnet_b5_layers, params=EffNetParams(dropout=0.4, quantized=True)),
     "efficientnet-quant-b6": _m(arch=effnet_b6_layers, params=EffNetParams(dropout=0.5, quantized=True)),
     "efficientnet-quant-b7": _m(arch=effnet_b7_layers, params=EffNetParams(dropout=0.5, quantized=True)),

+ 38 - 10
PyTorch/Classification/ConvNets/image_classification/models/model.py

@@ -13,7 +13,9 @@ class ModelArch:
 @dataclass
 class ModelParams:
     def parser(self, name):
-        return argparse.ArgumentParser(description=f"{name} arguments", add_help=False, usage="")
+        return argparse.ArgumentParser(
+            description=f"{name} arguments", add_help=False, usage=""
+        )
 
 
 @dataclass
@@ -44,7 +46,9 @@ class EntryPoint:
         state_dict = None
         if pretrained:
             assert self.model.checkpoint_url is not None
-            state_dict = torch.hub.load_state_dict_from_url(self.model.checkpoint_url, map_location=torch.device('cpu'))
+            state_dict = torch.hub.load_state_dict_from_url(
+                self.model.checkpoint_url, map_location=torch.device("cpu")
+            )
 
         if pretrained_from_file is not None:
             if os.path.isfile(pretrained_from_file):
@@ -53,7 +57,9 @@ class EntryPoint:
                         pretrained_from_file
                     )
                 )
-                state_dict = torch.load(pretrained_from_file, map_location=torch.device('cpu'))
+                state_dict = torch.load(
+                    pretrained_from_file, map_location=torch.device("cpu")
+                )
             else:
                 print(
                     "=> no pretrained weights found at '{}'".format(
@@ -63,17 +69,40 @@ class EntryPoint:
         # Temporary fix to allow NGC checkpoint loading
         if state_dict is not None:
             state_dict = {
-                k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()
+                k[len("module.") :] if k.startswith("module.") else k: v
+                for k, v in state_dict.items()
             }
+
+            def reshape(t, conv):
+                if conv:
+                    if len(t.shape) == 4:
+                        return t
+                    else:
+                        return t.view(t.shape[0], -1, 1, 1)
+                else:
+                    if len(t.shape) == 4:
+                        return t.view(t.shape[0], t.shape[1])
+                    else:
+                        return t
+
             state_dict = {
-                k: v.view(v.shape[0], -1, 1, 1) if is_linear_se_weight(k, v) else v for k, v in state_dict.items()
+                k: reshape(
+                    v,
+                    conv=dict(model.named_modules())[
+                        ".".join(k.split(".")[:-2])
+                    ].use_conv,
+                )
+                if is_se_weight(k, v)
+                else v
+                for k, v in state_dict.items()
             }
 
             model.load_state_dict(state_dict)
         return model
 
     def parser(self):
-        if self.model.params is None: return None
+        if self.model.params is None:
+            return None
         parser = self.model.params.parser(self.name)
         parser.add_argument(
             "--pretrained-from-file",
@@ -87,15 +116,14 @@ class EntryPoint:
                 "--pretrained",
                 default=False,
                 action="store_true",
-                help="load pretrained weights from NGC"
+                help="load pretrained weights from NGC",
             )
 
         return parser
 
 
-def is_linear_se_weight(key, value):
-    return (key.endswith('squeeze.weight') or key.endswith('expand.weight')) and len(value.shape) == 2
-
+def is_se_weight(key, value):
+    return (key.endswith("squeeze.weight") or key.endswith("expand.weight"))
 
 def create_entrypoint(m: Model):
     def _ep(**kwargs):

+ 10 - 3
PyTorch/Classification/ConvNets/image_classification/models/resnet.py

@@ -114,6 +114,7 @@ class Bottleneck(nn.Module):
         downsample=None,
         fused_se=True,
         last_bn_0_init=False,
+        trt=False,
     ):
         super(Bottleneck, self).__init__()
         self.conv1 = builder.conv1x1(inplanes, planes)
@@ -128,7 +129,7 @@ class Bottleneck(nn.Module):
 
         self.fused_se = fused_se
         self.squeeze = (
-            SqueezeAndExcitation(planes * expansion, se_squeeze, builder.activation())
+            SqueezeAndExcitation(planes * expansion, se_squeeze, builder.activation(), use_conv=trt)
             if se
             else None
         )
@@ -175,6 +176,7 @@ class SEBottleneck(Bottleneck):
         downsample=None,
         fused_se=True,
         last_bn_0_init=False,
+        trt=False,
     ):
         super(SEBottleneck, self).__init__(
             builder,
@@ -188,6 +190,7 @@ class SEBottleneck(Bottleneck):
             downsample=downsample,
             fused_se=fused_se,
             last_bn_0_init=last_bn_0_init,
+            trt=trt,
         )
 
 
@@ -211,6 +214,7 @@ class ResNet(nn.Module):
         num_classes: int = 1000
         last_bn_0_init: bool = False
         conv_init: str = "fan_in"
+        trt: bool = False
 
         def parser(self, name):
             p = super().parser(name)
@@ -235,7 +239,7 @@ class ResNet(nn.Module):
                 type=str,
                 help="initialization mode for convolutional layers, see https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_",
             )
-
+            p.add_argument("--trt", metavar="True|False", default=self.trt, type=bool)
             return p
 
     def __init__(
@@ -244,6 +248,7 @@ class ResNet(nn.Module):
         num_classes: int = 1000,
         last_bn_0_init: bool = False,
         conv_init: str = "fan_in",
+        trt: bool = False,
     ):
 
         super(ResNet, self).__init__()
@@ -269,6 +274,7 @@ class ResNet(nn.Module):
                 l,
                 cardinality=arch.cardinality,
                 stride=1 if i == 0 else 2,
+                trt=trt,
             )
             setattr(self, f"layer{i+1}", layer)
 
@@ -326,7 +332,7 @@ class ResNet(nn.Module):
 
     # helper functions {{{
     def _make_layer(
-        self, block, expansion, inplanes, planes, blocks, stride=1, cardinality=1
+        self, block, expansion, inplanes, planes, blocks, stride=1, cardinality=1, trt=False,
     ):
         downsample = None
         if stride != 1 or inplanes != planes * expansion:
@@ -350,6 +356,7 @@ class ResNet(nn.Module):
                     downsample=downsample if i == 0 else None,
                     fused_se=True,
                     last_bn_0_init=self.last_bn_0_init,
+                    trt = trt,
                 )
             )
             inplanes = planes * expansion

+ 4 - 0
README.md

@@ -18,6 +18,10 @@ These examples, along with our NVIDIA deep learning software stack, are provided
 | [ResNet-50](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5)  |PyTorch  | Yes  | Yes  | Yes  | -  | Yes  | -  | [Yes](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/triton/resnet50)  | Yes  | - |
 | [ResNeXt-101](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnext101-32x4d)  |PyTorch  | Yes  | Yes  | Yes  | -  | Yes  |   -  | [Yes](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/triton/resnext101-32x4d)  | Yes  | - |
 | [SE-ResNeXt-101](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/se-resnext101-32x4d)  |PyTorch  | Yes  | Yes  | Yes  | -  | Yes  | -  | [Yes](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d)  | Yes  | - |
+| [EfficientNet-B0](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/efficientnet)  |PyTorch  | Yes  | Yes  | Yes  | -  | - | - | - | Yes  | - |
+| [EfficientNet-B4](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/efficientnet)  |PyTorch  | Yes  | Yes  | Yes  | -  | - | - | - | Yes  | - |
+| [EfficientNet-WideSE-B0](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/efficientnet)  |PyTorch  | Yes  | Yes  | Yes  | -  | - | - | - | Yes  | - |
+| [EfficientNet-WideSE-B4](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/efficientnet)  |PyTorch  | Yes  | Yes  | Yes  | -  | - | - | - | Yes  | - |
 | [Mask R-CNN](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/MaskRCNN) |PyTorch  | Yes  | Yes  | Yes  | -  | -  |   -  | -  | -  | [Yes](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Segmentation/MaskRCNN/pytorch/notebooks/pytorch_MaskRCNN_pyt_train_and_inference.ipynb) |
 | [nnUNet](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Segmentation/nnUNet) |PyTorch  | Yes  | Yes  | Yes  | -  | -  |   -  | -  | Yes  | - |
 | [SSD](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Detection/SSD) |PyTorch  | Yes  | Yes  | Yes  | -  | -  |   -  | -  | -  | [Yes](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Detection/SSD/examples/inference.ipynb) |