Эх сурвалжийг харах

Merge pull request #714 from NVIDIA/gh/release

[ConvNets/Pyt] Triton Deployment
nv-kkudrynski 5 жил өмнө
parent
commit
ac059022a5
24 өөрчлөгдсөн 2046 нэмэгдсэн , 11 устгасан
  1. 1 1
      PyTorch/Classification/ConvNets/Dockerfile
  2. 18 3
      PyTorch/Classification/ConvNets/image_classification/dataloaders.py
  3. 25 5
      PyTorch/Classification/ConvNets/image_classification/resnet.py
  4. 1 1
      PyTorch/Classification/ConvNets/image_classification/utils.py
  5. 1 1
      PyTorch/Classification/ConvNets/main.py
  6. 36 0
      PyTorch/Classification/ConvNets/triton/Dockerfile
  7. 115 0
      PyTorch/Classification/ConvNets/triton/client.py
  8. 105 0
      PyTorch/Classification/ConvNets/triton/deployer.py
  9. 911 0
      PyTorch/Classification/ConvNets/triton/deployer_lib.py
  10. BIN
      PyTorch/Classification/ConvNets/triton/resnet50/Latency-vs-Throughput-TensorRT.png
  11. BIN
      PyTorch/Classification/ConvNets/triton/resnet50/Performance-analysis-TensorRT-FP16.png
  12. BIN
      PyTorch/Classification/ConvNets/triton/resnet50/Performance-analysis-TensorRT-FP32.png
  13. 248 0
      PyTorch/Classification/ConvNets/triton/resnet50/README.md
  14. BIN
      PyTorch/Classification/ConvNets/triton/resnext101-32x4d/Latency-vs-Throughput-TensorRT.png
  15. BIN
      PyTorch/Classification/ConvNets/triton/resnext101-32x4d/Performance-analysis-TensorRT-FP16.png
  16. BIN
      PyTorch/Classification/ConvNets/triton/resnext101-32x4d/Performance-analysis-TensorRT-FP32.png
  17. 243 0
      PyTorch/Classification/ConvNets/triton/resnext101-32x4d/README.md
  18. 53 0
      PyTorch/Classification/ConvNets/triton/scripts/benchmark.sh
  19. 24 0
      PyTorch/Classification/ConvNets/triton/scripts/get_metrics_static.sh
  20. 20 0
      PyTorch/Classification/ConvNets/triton/scripts/process_output.sh
  21. BIN
      PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/Latency-vs-Throughput-TensorRT.png
  22. BIN
      PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/Performance-analysis-TensorRT-FP16.png
  23. BIN
      PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/Performance-analysis-TensorRT-FP32.png
  24. 245 0
      PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/README.md

+ 1 - 1
PyTorch/Classification/ConvNets/Dockerfile

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.06-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.07-py3
 FROM ${FROM_IMAGE_NAME}
 
 ADD requirements.txt /workspace/

+ 18 - 3
PyTorch/Classification/ConvNets/image_classification/dataloaders.py

@@ -200,7 +200,7 @@ class DALIWrapper(object):
 
     def __iter__(self):
         return DALIWrapper.gen_wrapper(
-                self.dalipipeline, self.num_classes, self.one_hot, self.memory_format
+            self.dalipipeline, self.num_classes, self.one_hot, self.memory_format
         )
 
 
@@ -472,7 +472,10 @@ class SynteticDataLoader(object):
         memory_format=torch.contiguous_format,
     ):
         input_data = (
-            torch.empty(batch_size, num_channels, height, width).contiguous(memory_format=memory_format).cuda().normal_(0, 1.0)
+            torch.empty(batch_size, num_channels, height, width)
+            .contiguous(memory_format=memory_format)
+            .cuda()
+            .normal_(0, 1.0)
         )
         if one_hot:
             input_target = torch.empty(batch_size, num_classes).cuda()
@@ -502,4 +505,16 @@ def get_syntetic_loader(
     fp16=False,
     memory_format=torch.contiguous_format,
 ):
-    return SynteticDataLoader(fp16, batch_size, num_classes, 3, 224, 224, one_hot, memory_format=memory_format), -1
+    return (
+        SynteticDataLoader(
+            fp16,
+            batch_size,
+            num_classes,
+            3,
+            224,
+            224,
+            one_hot,
+            memory_format=memory_format,
+        ),
+        -1,
+    )

+ 25 - 5
PyTorch/Classification/ConvNets/image_classification/resnet.py

@@ -100,6 +100,7 @@ class ResNetBuilder(object):
 
 # ResNetBuilder }}}
 
+
 # BasicBlock {{{
 class BasicBlock(nn.Module):
     def __init__(self, builder, inplanes, planes, expansion, stride=1, downsample=None):
@@ -137,6 +138,7 @@ class BasicBlock(nn.Module):
 
 # BasicBlock }}}
 
+
 # SqueezeAndExcitation {{{
 class SqueezeAndExcitation(nn.Module):
     def __init__(self, planes, squeeze):
@@ -159,6 +161,7 @@ class SqueezeAndExcitation(nn.Module):
 
 # }}}
 
+
 # Bottleneck {{{
 class Bottleneck(nn.Module):
     def __init__(
@@ -171,6 +174,7 @@ class Bottleneck(nn.Module):
         se=False,
         se_squeeze=16,
         downsample=None,
+        fused_se=True,
     ):
         super(Bottleneck, self).__init__()
         self.conv1 = builder.conv1x1(inplanes, planes)
@@ -182,6 +186,8 @@ class Bottleneck(nn.Module):
         self.relu = builder.activation()
         self.downsample = downsample
         self.stride = stride
+
+        self.fused_se = fused_se
         self.squeeze = (
             SqueezeAndExcitation(planes * expansion, se_squeeze) if se else None
         )
@@ -206,14 +212,19 @@ class Bottleneck(nn.Module):
         if self.squeeze is None:
             out += residual
         else:
-            out = torch.addcmul(residual, 1.0, out, self.squeeze(out))
+            if self.fused_se:
+                out = torch.addcmul(residual, out, self.squeeze(out), value=1)
+            else:
+                out = residual + out * self.squeeze(out)
 
         out = self.relu(out)
 
         return out
 
 
-def SEBottleneck(builder, inplanes, planes, expansion, stride=1, downsample=None):
+def SEBottleneck(
+    builder, inplanes, planes, expansion, stride=1, downsample=None, fused_se=True
+):
     return Bottleneck(
         builder,
         inplanes,
@@ -223,15 +234,20 @@ def SEBottleneck(builder, inplanes, planes, expansion, stride=1, downsample=None
         se=True,
         se_squeeze=16,
         downsample=downsample,
+        fused_se=fused_se,
     )
 
 
 # Bottleneck }}}
 
+
 # ResNet {{{
 class ResNet(nn.Module):
-    def __init__(self, builder, block, expansion, layers, widths, num_classes=1000):
+    def __init__(
+        self, builder, block, expansion, layers, widths, num_classes=1000, fused_se=True
+    ):
         self.inplanes = 64
+        self.fused_se = fused_se
         super(ResNet, self).__init__()
         self.conv1 = builder.conv7x7(3, 64, stride=2)
         self.bn1 = builder.batchnorm(64)
@@ -269,11 +285,14 @@ class ResNet(nn.Module):
                 expansion,
                 stride=stride,
                 downsample=downsample,
+                fused_se=self.fused_se,
             )
         )
         self.inplanes = planes * expansion
         for i in range(1, blocks):
-            layers.append(block(builder, self.inplanes, planes, expansion))
+            layers.append(
+                block(builder, self.inplanes, planes, expansion, fused_se=self.fused_se)
+            )
 
         return nn.Sequential(*layers)
 
@@ -384,7 +403,7 @@ resnet_versions = {
 }
 
 
-def build_resnet(version, config, num_classes, verbose=True):
+def build_resnet(version, config, num_classes, verbose=True, fused_se=True):
     version = resnet_versions[version]
     config = resnet_configs[config]
 
@@ -400,6 +419,7 @@ def build_resnet(version, config, num_classes, verbose=True):
         version["layers"],
         version["widths"],
         num_classes,
+        fused_se,
     )
 
     return model

+ 1 - 1
PyTorch/Classification/ConvNets/image_classification/utils.py

@@ -89,7 +89,7 @@ def accuracy(output, target, topk=(1,)):
 
     res = []
     for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        correct_k = correct[:k].float().sum()
         res.append(correct_k.mul_(100.0 / batch_size))
     return res
 

+ 1 - 1
PyTorch/Classification/ConvNets/main.py

@@ -276,7 +276,7 @@ def add_parser_arguments(parser):
     )
 
     parser.add_argument("--checkpoint-filename", default="checkpoint.pth.tar", type=str)
-    
+
     parser.add_argument(
         "--workspace",
         type=str,

+ 36 - 0
PyTorch/Classification/ConvNets/triton/Dockerfile

@@ -0,0 +1,36 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.07-py3
+ARG TRITON_BASE_IMAGE=nvcr.io/nvidia/tritonserver:20.07-py3-clientsdk
+
+FROM ${TRITON_BASE_IMAGE} as trt
+FROM ${FROM_IMAGE_NAME}
+
+ADD requirements.txt .
+RUN pip install -r requirements.txt
+RUN pip install onnxruntime
+
+COPY --from=trt /workspace/v2.1.0.clients.tar.gz ./v2.1.0.clients.tar.gz
+COPY --from=trt /workspace/install/bin/perf_client /bin/perf_client
+
+RUN tar -xzf v2.1.0.clients.tar.gz \
+    && pip install ./python/tritonclientutils-2.1.0-py3-none-any.whl
+
+
+RUN apt update && apt install -y libb64-0d
+
+WORKDIR /workspace/rn50
+COPY . .
+

+ 115 - 0
PyTorch/Classification/ConvNets/triton/client.py

@@ -0,0 +1,115 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import os
+
+import numpy as np
+import torch
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+
+from image_classification.dataloaders import get_pytorch_val_loader
+
+from tqdm import tqdm
+
+import tritongrpcclient
+from tritonclientutils import InferenceServerException
+
+
+def get_data_loader(batch_size, *, data_path):
+    valdir = os.path.join(data_path, "val-jpeg")
+    val_dataset = datasets.ImageFolder(
+        valdir,
+        transforms.Compose(
+            [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
+        ),
+    )
+
+    val_loader = torch.utils.data.DataLoader(
+        val_dataset, batch_size=batch_size, shuffle=False
+    )
+
+    return val_loader
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--triton-server-url",
+        type=str,
+        required=True,
+        help="URL adress of trtion server (with port)",
+    )
+    parser.add_argument(
+        "--triton-model-name",
+        type=str,
+        required=True,
+        help="Triton deployed model name",
+    )
+    parser.add_argument(
+        "-v", "--verbose", action="store_true", default=False, help="Verbose mode."
+    )
+
+    parser.add_argument(
+        "--inference_data", type=str, help="Path to file with inference data."
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=1, help="Inference request batch size"
+    )
+    parser.add_argument(
+        "--fp16",
+        action="store_true",
+        default=False,
+        help="Use fp16 precision for input data",
+    )
+    FLAGS = parser.parse_args()
+
+    triton_client = tritongrpcclient.InferenceServerClient(
+        url=FLAGS.triton_server_url, verbose=FLAGS.verbose
+    )
+    dataloader = get_data_loader(FLAGS.batch_size, data_path=FLAGS.inference_data)
+
+    inputs = []
+    inputs.append(
+        tritongrpcclient.InferInput(
+            "input__0",
+            [FLAGS.batch_size, 3, 224, 224],
+            "FP16" if FLAGS.fp16 else "FP32",
+        )
+    )
+
+    outputs = []
+    outputs.append(tritongrpcclient.InferRequestedOutput("output__0"))
+
+    all_img = 0
+    cor_img = 0
+
+    result_prev = None
+    for image, target in tqdm(dataloader):
+        if FLAGS.fp16:
+            image = image.half()
+        inputs[0].set_data_from_numpy(image.numpy())
+
+        result = triton_client.infer(
+            FLAGS.triton_model_name, inputs, outputs=outputs, headers=None
+        )
+        result = result.as_numpy("output__0")
+        result = np.argmax(result, axis=1)
+        cor_img += np.sum(result == target.numpy())
+        all_img += result.shape[0]
+
+    acc = cor_img / all_img
+    print(f"Final accuracy {acc:.04f}")

+ 105 - 0
PyTorch/Classification/ConvNets/triton/deployer.py

@@ -0,0 +1,105 @@
+#!/usr/bin/python
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import torch
+import argparse
+import triton.deployer_lib as deployer_lib
+
+
+def get_model_args(model_args):
+    """ the arguments initialize_model will receive """
+    parser = argparse.ArgumentParser()
+    ## Required parameters by the model.
+    parser.add_argument(
+        "--config",
+        default="resnet50",
+        type=str,
+        required=True,
+        help="Network to deploy",
+    )
+    parser.add_argument(
+        "--checkpoint", default=None, type=str, help="The checkpoint of the model. "
+    )
+    parser.add_argument(
+        "--batch_size", default=1000, type=int, help="Batch size for inference"
+    )
+    parser.add_argument(
+        "--fp16", default=False, action="store_true", help="FP16 inference"
+    )
+    parser.add_argument(
+        "--dump_perf_data",
+        type=str,
+        default=None,
+        help="Directory to dump perf data sample for testing",
+    )
+    return parser.parse_args(model_args)
+
+
+def initialize_model(args):
+    """ return model, ready to trace """
+    from image_classification.resnet import build_resnet
+
+    model = build_resnet(args.config, "fanin", 1000, fused_se=False)
+
+    if args.checkpoint:
+        state_dict = torch.load(args.checkpoint, map_location="cpu")
+        model.load_state_dict(
+            {k.replace("module.", ""): v for k, v in state_dict.items()}
+        )
+        model.load_state_dict(state_dict)
+    return model.half() if args.fp16 else model
+
+
+def get_dataloader(args):
+    """ return dataloader for inference """
+    from image_classification.dataloaders import get_syntetic_loader
+
+    def data_loader():
+        loader, _ = get_syntetic_loader(None, 128, 1000, True, fp16=args.fp16)
+        processed = 0
+        for inp, _ in loader:
+            yield inp
+            processed += 1
+            if processed > 10:
+                break
+
+    return data_loader()
+
+
+if __name__ == "__main__":
+    # don't touch this!
+    deployer, model_argv = deployer_lib.create_deployer(
+        sys.argv[1:]
+    )  # deployer and returns removed deployer arguments
+
+    model_args = get_model_args(model_argv)
+
+    model = initialize_model(model_args)
+    dataloader = get_dataloader(model_args)
+
+    if model_args.dump_perf_data:
+        input_0 = next(iter(dataloader))
+        if model_args.fp16:
+            input_0 = input_0.half()
+
+        os.makedirs(model_args.dump_perf_data, exist_ok=True)
+        input_0.detach().cpu().numpy()[0].tofile(
+            os.path.join(model_args.dump_perf_data, "input__0")
+        )
+
+    deployer.deploy(dataloader, model)

+ 911 - 0
PyTorch/Classification/ConvNets/triton/deployer_lib.py

@@ -0,0 +1,911 @@
+#!/usr/bin/python
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import time
+import json
+import torch
+import argparse
+import statistics
+from collections import Counter
+
+torch_type_to_triton_type = {
+    torch.bool: "TYPE_BOOL",
+    torch.int8: "TYPE_INT8",
+    torch.int16: "TYPE_INT16",
+    torch.int32: "TYPE_INT32",
+    torch.int64: "TYPE_INT64",
+    torch.uint8: "TYPE_UINT8",
+    torch.float16: "TYPE_FP16",
+    torch.float32: "TYPE_FP32",
+    torch.float64: "TYPE_FP64",
+}
+
+CONFIG_TEMPLATE = r"""
+name: "{model_name}"
+platform: "{platform}"
+max_batch_size: {max_batch_size}
+input [
+    {spec_inputs}
+]
+output [
+    {spec_outputs}
+]
+{dynamic_batching}
+{model_optimizations}
+instance_group [
+    {{
+        count: {engine_count}
+        kind: KIND_GPU
+        gpus: [ {gpu_list} ]
+    }}
+]"""
+
+INPUT_TEMPLATE = r"""
+{{
+    name: "input__{num}"
+    data_type: {type}
+    dims: {dims}
+    {reshape}
+}},"""
+
+OUTPUT_TEMPLATE = r""" 
+{{
+    name: "output__{num}"
+    data_type: {type}
+    dims: {dims}
+    {reshape}
+}},"""
+
+MODEL_OPTIMIZATION_TEMPLATE = r"""
+optimization {{
+  {execution_accelerator}
+  cuda {{
+    graphs: {capture_cuda_graph}
+  }}
+}}"""
+
+EXECUTION_ACCELERATOR_TEMPLATE = r"""
+  execution_accelerators {{
+    gpu_execution_accelerator: [
+      {{
+        name: "tensorrt"
+      }}
+    ]
+  }},"""
+
+
+def remove_empty_lines(text):
+    """ removes empty lines from text, returns the result """
+    ret = "".join([s for s in text.strip().splitlines(True) if s.strip()])
+    return ret
+
+
+def create_deployer(argv):
+    """ takes a list of arguments, returns a deployer object and the list of unused arguments """
+    parser = argparse.ArgumentParser()
+    # required args
+    method = parser.add_mutually_exclusive_group(required=True)
+    method.add_argument(
+        "--ts-script",
+        action="store_true",
+        help="convert to torchscript using torch.jit.script",
+    )
+    method.add_argument(
+        "--ts-trace",
+        action="store_true",
+        help="convert to torchscript using torch.jit.trace",
+    )
+    method.add_argument(
+        "--onnx", action="store_true", help="convert to onnx using torch.onnx.export"
+    )
+    method.add_argument(
+        "--trt", action="store_true", help="convert to trt using tensorrt"
+    )
+    # triton related args
+    arguments = parser.add_argument_group("triton related flags")
+    arguments.add_argument(
+        "--triton-no-cuda", action="store_true", help="Use the CPU for tracing."
+    )
+    arguments.add_argument(
+        "--triton-model-name",
+        type=str,
+        default="model",
+        help="exports to appropriate directory structure for TRITON",
+    )
+    arguments.add_argument(
+        "--triton-model-version",
+        type=int,
+        default=1,
+        help="exports to appropriate directory structure for TRITON",
+    )
+    arguments.add_argument(
+        "--triton-max-batch-size",
+        type=int,
+        default=8,
+        help="Specifies the 'max_batch_size' in the TRITON model config.\
+                                  See the TRITON documentation for more info.",
+    )
+    arguments.add_argument(
+        "--triton-dyn-batching-delay",
+        type=float,
+        default=0,
+        help="Determines the dynamic_batching queue delay in milliseconds(ms) for\
+                                  the TRITON model config. Use '0' or '-1' to specify static batching.\
+                                  See the TRITON documentation for more info.",
+    )
+    arguments.add_argument(
+        "--triton-engine-count",
+        type=int,
+        default=1,
+        help="Specifies the 'instance_group' count value in the TRITON model config.\
+                                  See the TRITON documentation for more info.",
+    )
+    arguments.add_argument(
+        "--save-dir", type=str, default="./triton_models", help="Saved model directory"
+    )
+    # optimization args
+    arguments = parser.add_argument_group("optimization flags")
+    arguments.add_argument(
+        "--max_workspace_size",
+        type=int,
+        default=512 * 1024 * 1024,
+        help="set the size of the workspace for trt export",
+    )
+    arguments.add_argument(
+        "--trt-fp16",
+        action="store_true",
+        help="trt flag ---- export model in mixed precision mode",
+    )
+    arguments.add_argument(
+        "--capture-cuda-graph",
+        type=int,
+        default=1,
+        help="capture cuda graph for obtaining speedup. possible values: 0, 1. default: 1. ",
+    )
+
+    # remainder args
+    arguments.add_argument(
+        "model_arguments",
+        nargs=argparse.REMAINDER,
+        help="arguments that will be ignored by deployer lib and will be forwarded to your deployer script",
+    )
+    #
+    args = parser.parse_args(argv)
+    deployer = Deployer(args)
+    #
+    return deployer, args.model_arguments[1:]
+
+
+class DeployerLibrary:
+    def __init__(self, args):
+        self.args = args
+        self.platform = None
+
+    def set_platform(self, platform):
+        """ sets the platform
+            :: platform :: "pytorch_libtorch" or "onnxruntime_onnx" or "tensorrt_plan"
+        """
+        self.platform = platform
+
+    def build_trt_engine(self, model_file, shapes):
+        """ takes a path to an onnx file, and shape information, returns a trt engine
+            :: model_file :: path to an onnx model
+            :: shapes :: dictionary containing min shape, max shape, opt shape for the trt engine
+        """
+        import tensorrt as trt
+
+        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+        builder = trt.Builder(TRT_LOGGER)
+        builder.fp16_mode = self.args.trt_fp16
+        builder.max_batch_size = self.args.triton_max_batch_size
+        #
+        config = builder.create_builder_config()
+        config.max_workspace_size = self.args.max_workspace_size
+        if self.args.trt_fp16:
+            config.flags |= 1 << int(trt.BuilderFlag.FP16)
+        profile = builder.create_optimization_profile()
+        for s in shapes:
+            profile.set_shape(s["name"], min=s["min"], opt=s["opt"], max=s["max"])
+        config.add_optimization_profile(profile)
+        explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+        network = builder.create_network(explicit_batch)
+        #
+        with trt.OnnxParser(network, TRT_LOGGER) as parser:
+            with open(model_file, "rb") as model:
+                parser.parse(model.read())
+                for i in range(parser.num_errors):
+                    e = parser.get_error(i)
+                    print("||||e", e)
+                engine = builder.build_engine(network, config=config)
+        return engine
+
+    def load_engine(self, engine_filepath):
+        """ loads a trt engine from engine_filepath, returns it """
+        import tensorrt as trt
+
+        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+        with open(engine_filepath, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
+            engine = runtime.deserialize_cuda_engine(f.read())
+        return engine
+
+    def prepare_inputs(self, dataloader, device):
+        """ load sample inputs to device """
+        inputs = []
+        for batch in dataloader:
+            if type(batch) is torch.Tensor:
+                batch_d = batch.to(device)
+                batch_d = (batch_d,)
+                inputs.append(batch_d)
+            else:
+                batch_d = []
+                for x in batch:
+                    assert type(x) is torch.Tensor, "input is not a tensor"
+                    batch_d.append(x.to(device))
+                batch_d = tuple(batch_d)
+                inputs.append(batch_d)
+        return inputs
+
+    def get_list_of_shapes(self, l, fun):
+        """ returns the list of min/max shapes, depending on fun
+            :: l :: list of tuples of tensors
+            :: fun :: min or max
+        """
+        tensor_tuple = l[0]
+        shapes = [list(x.shape) for x in tensor_tuple]
+        for tensor_tuple in l:
+            assert len(tensor_tuple) == len(
+                shapes
+            ), "tensors with varying shape lengths are not supported"
+            for i, x in enumerate(tensor_tuple):
+                for j in range(len(x.shape)):
+                    shapes[i][j] = fun(shapes[i][j], x.shape[j])
+        return shapes  # a list of shapes
+
+    def get_tuple_of_min_shapes(self, l):
+        """ returns the tuple of min shapes 
+            :: l :: list of tuples of tensors """
+        shapes = self.get_list_of_shapes(l, min)
+        min_batch = 1
+        shapes = [[min_batch, *shape[1:]] for shape in shapes]
+        shapes = tuple(shapes)
+        return shapes  # tuple of min shapes
+
+    def get_tuple_of_max_shapes(self, l):
+        """ returns the tuple of max shapes 
+            :: l :: list of tuples of tensors """
+        shapes = self.get_list_of_shapes(l, max)
+        max_batch = max(2, shapes[0][0])
+        shapes = [[max_batch, *shape[1:]] for shape in shapes]
+        shapes = tuple(shapes)
+        return shapes  # tuple of max shapes
+
+    def get_tuple_of_opt_shapes(self, l):
+        """ returns the tuple of opt shapes 
+            :: l :: list of tuples of tensors """
+        counter = Counter()
+        for tensor_tuple in l:
+            shapes = [tuple(x.shape) for x in tensor_tuple]
+            shapes = tuple(shapes)
+            counter[shapes] += 1
+        shapes = counter.most_common(1)[0][0]
+        return shapes  # tuple of most common occuring shapes
+
+    def get_tuple_of_dynamic_shapes(self, l):
+        """ returns a tuple of dynamic shapes: variable tensor dimensions 
+            (for ex. batch size) occur as -1 in the tuple
+            :: l :: list of tuples of tensors """
+        tensor_tuple = l[0]
+        shapes = [list(x.shape) for x in tensor_tuple]
+        for tensor_tuple in l:
+            err_msg = "tensors with varying shape lengths are not supported"
+            assert len(tensor_tuple) == len(shapes), err_msg
+            for i, x in enumerate(tensor_tuple):
+                for j in range(len(x.shape)):
+                    if shapes[i][j] != x.shape[j] or j == 0:
+                        shapes[i][j] = -1
+        shapes = tuple(shapes)
+        return shapes  # tuple of dynamic shapes
+
+    def run_models(self, models, inputs):
+        """ run the models on inputs, return the outputs and execution times """
+        ret = []
+        for model in models:
+            torch.cuda.synchronize()
+            time_start = time.time()
+            outputs = []
+            for input in inputs:
+                with torch.no_grad():
+                    output = model(*input)
+                if type(output) is torch.Tensor:
+                    output = [output]
+                outputs.append(output)
+            torch.cuda.synchronize()
+            time_end = time.time()
+            t = time_end - time_start
+            ret.append(outputs)
+            ret.append(t)
+        return ret
+
+    def compute_tensor_stats(self, tensor):
+        return {
+            "std": tensor.std().item(),
+            "mean": tensor.mean().item(),
+            "max": tensor.max().item(),
+            "min": tensor.min().item(),
+        }
+
+    def compute_errors(self, outputs_A, outputs_B):
+        """ returns dictionary with errors statistics """
+        device = outputs_A[0][0][0].device
+        dtype = outputs_A[0][0][0].dtype
+        x_values = torch.zeros(0, device=device, dtype=dtype)
+        y_values = torch.zeros(0, device=device, dtype=dtype)
+        d_values = torch.zeros(0, device=device, dtype=dtype)
+        for output_A, output_B in zip(outputs_A, outputs_B):
+            for x, y in zip(output_A, output_B):
+                d = abs(x - y)
+                x_values = torch.cat((x_values, x), 0)
+                y_values = torch.cat((y_values, y), 0)
+                d_values = torch.cat((d_values, d), 0)
+        Error_stats = {
+            "Original": self.compute_tensor_stats(x_values),
+            "Converted": self.compute_tensor_stats(y_values),
+            "Absolute difference": self.compute_tensor_stats(d_values),
+        }
+        return Error_stats
+
+    def print_errors(self, Error_stats):
+        """ print various statistcs of Linf errors """
+        print()
+        print("conversion correctness test results")
+        print("-----------------------------------")
+        import pandas as pd
+
+        print(pd.DataFrame(Error_stats))
+
+    def write_config(
+        self, config_filename, input_shapes, input_types, output_shapes, output_types
+    ):
+        """ writes TRTIS config file 
+            :: config_filename :: the file to write the config file into
+            :: input_shapes :: tuple of dynamic shapes of the input tensors
+            :: input_types :: tuple of torch types of the input tensors
+            :: output_shapes :: tuple of dynamic shapes of the output tensors
+            :: output_types :: tuple of torch types of the output tensors
+        """
+        assert self.platform is not None, "error - platform is not set"
+
+        config_template = CONFIG_TEMPLATE
+        input_template = INPUT_TEMPLATE
+        optimization_template = MODEL_OPTIMIZATION_TEMPLATE
+        accelerator_template = EXECUTION_ACCELERATOR_TEMPLATE
+
+        spec_inputs = r""""""
+        for i, (shape, typ) in enumerate(zip(input_shapes, input_types)):
+            d = {
+                "num": str(i),
+                "type": torch_type_to_triton_type[typ],
+                "dims": str([1])
+                if len(shape) == 1
+                else str(list(shape)[1:]),  # first dimension is the batch size
+            }
+            d["reshape"] = "reshape: { shape: [ ] }" if len(shape) == 1 else ""
+            spec_inputs += input_template.format_map(d)
+        spec_inputs = spec_inputs[:-1]
+
+        output_template = OUTPUT_TEMPLATE
+        spec_outputs = r""""""
+        for i, (shape, typ) in enumerate(zip(output_shapes, output_types)):
+            d = {
+                "num": str(i),
+                "type": torch_type_to_triton_type[typ],
+                "dims": str([1])
+                if len(shape) == 1
+                else str(list(shape)[1:]),  # first dimension is the batch size
+            }
+            d["reshape"] = "reshape: { shape: [ ] }" if len(shape) == 1 else ""
+            spec_outputs += output_template.format_map(d)
+        spec_outputs = spec_outputs[:-1]
+
+        batching_str = ""
+        max_batch_size = self.args.triton_max_batch_size
+
+        if self.args.triton_dyn_batching_delay >= 0:
+            # Use only full and half full batches
+            pref_batch_size = [int(max_batch_size / 2.0), max_batch_size]
+
+            if self.args.triton_dyn_batching_delay > 0:
+                dyn_batch_delay_str = f"max_queue_delay_microseconds: {int(self.args.triton_dyn_batching_delay * 1000.0)}"
+            else:
+                dyn_batch_delay_str = ""
+
+            batching_str = r"""
+dynamic_batching {{
+    preferred_batch_size: [{0}]
+    {1}
+}}""".format(
+                ", ".join([str(x) for x in pref_batch_size]), dyn_batch_delay_str
+            )
+
+        accelerator_str = ""
+
+        d = {
+            "execution_accelerator": accelerator_str,
+            "capture_cuda_graph": str(self.args.capture_cuda_graph),
+        }
+        optimization_str = optimization_template.format_map(d)
+
+        config_values = {
+            "model_name": self.args.triton_model_name,
+            "platform": self.platform,
+            "max_batch_size": max_batch_size,
+            "spec_inputs": spec_inputs,
+            "spec_outputs": spec_outputs,
+            "dynamic_batching": batching_str,
+            "model_optimizations": optimization_str,
+            "gpu_list": ", ".join([str(x) for x in range(torch.cuda.device_count())]),
+            "engine_count": self.args.triton_engine_count,
+        }
+
+        # write config
+        with open(config_filename, "w") as file:
+            final_config_str = config_template.format_map(config_values)
+            final_config_str = remove_empty_lines(final_config_str)
+            file.write(final_config_str)
+
+
+class Deployer:
+    """Export a PyTorch model for the Triton Inference Server.
+
+    Drives export to one of three backends (TorchScript, ONNX, TensorRT),
+    verifies the exported model against the native model on inputs drawn
+    from a dataloader, and writes the Triton model repository layout
+    (versioned model file plus config.pbtxt).
+    """
+
+    def __init__(self, args):
+        # args: parsed argparse namespace from the deployer CLI.
+        # DeployerLibrary provides the shared export/verification helpers
+        # (input preparation, shape inference, error stats, config writing).
+        self.args = args
+        self.lib = DeployerLibrary(args)
+
+    def deploy(self, dataloader, model):
+        """ deploy the model and test for correctness with dataloader """
+        # Exactly one export-format flag is expected (the CLI declares them
+        # mutually exclusive); if none is set we fall through to the assert.
+        if self.args.ts_script or self.args.ts_trace:
+            self.lib.set_platform("pytorch_libtorch")
+            print(
+                "deploying model "
+                + self.args.triton_model_name
+                + " in format "
+                + self.lib.platform
+            )
+            self.to_triton_torchscript(dataloader, model)
+        elif self.args.onnx:
+            self.lib.set_platform("onnxruntime_onnx")
+            print(
+                "deploying model "
+                + self.args.triton_model_name
+                + " in format "
+                + self.lib.platform
+            )
+            self.to_triton_onnx(dataloader, model)
+        elif self.args.trt:
+            self.lib.set_platform("tensorrt_plan")
+            print(
+                "deploying model "
+                + self.args.triton_model_name
+                + " in format "
+                + self.lib.platform
+            )
+            self.to_triton_trt(dataloader, model)
+        else:
+            assert False, "error"
+        print("done")
+
+    def to_triton_trt(self, dataloader, model):
+        """ export the model to trt and test correctness on dataloader """
+        # TensorRT is imported lazily so the other export paths work on
+        # machines without a TensorRT installation.
+        import tensorrt as trt
+
+        # setup device
+        if self.args.triton_no_cuda:
+            device = torch.device("cpu")
+        else:
+            device = torch.device("cuda")
+
+        # prepare model
+        model.to(device)
+        model.eval()
+        assert not model.training, "internal error - model should be in eval() mode! "
+
+        # prepare inputs
+        inputs = self.lib.prepare_inputs(dataloader, device)
+
+        # generate outputs
+        outputs = []
+        for input in inputs:
+            with torch.no_grad():
+                output = model(*input)
+            if type(output) is torch.Tensor:
+                output = [output]
+            outputs.append(output)
+
+        # generate input shapes - dynamic tensor shape support
+        input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
+
+        # generate output shapes - dynamic tensor shape support
+        output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
+
+        # generate input types
+        input_types = [x.dtype for x in inputs[0]]
+
+        # generate output types
+        output_types = [x.dtype for x in outputs[0]]
+
+        # get input names (Triton convention: input__<N> / output__<N>)
+        rng = range(len(input_types))
+        input_names = ["input__" + str(num) for num in rng]
+
+        # get output names
+        rng = range(len(output_types))
+        output_names = ["output__" + str(num) for num in rng]
+
+        # prepare save path: <save_dir>/<model_name>/<version>/model.plan
+        model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
+        version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
+        if not os.path.exists(version_folder):
+            os.makedirs(version_folder)
+        final_model_path = os.path.join(version_folder, "model.plan")
+
+        # get indices of dynamic input and output shapes (-1 marks a dynamic axis)
+        dynamic_axes = {}
+        for input_name, shape in zip(input_names, input_shapes):
+            dynamic_axes[input_name] = [i for i, x in enumerate(shape) if x == -1]
+        for output_name, shape in zip(output_names, output_shapes):
+            dynamic_axes[output_name] = [i for i, x in enumerate(shape) if x == -1]
+
+        # export the model to onnx first
+        # NOTE(review): the intermediate ONNX graph is written to
+        # final_model_path ("model.plan") and is overwritten below by the
+        # serialized TensorRT engine - no separate .onnx file remains.
+        with torch.no_grad():
+            torch.onnx.export(
+                model,
+                inputs[0],
+                final_model_path,
+                verbose=False,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=11,
+            )
+
+        # get min/opt/max shapes for the TensorRT optimization profile
+        min_shapes = self.lib.get_tuple_of_min_shapes(inputs)
+        opt_shapes = self.lib.get_tuple_of_opt_shapes(inputs)
+        max_shapes = self.lib.get_tuple_of_max_shapes(inputs)
+
+        zipped = zip(input_names, min_shapes, opt_shapes, max_shapes)
+        shapes = []
+        for name, min_shape, opt_shape, max_shape in zipped:
+            d = {"name": name, "min": min_shape, "opt": opt_shape, "max": max_shape}
+            shapes.append(d)
+
+        # build trt engine from the ONNX file written above
+        engine = self.lib.build_trt_engine(final_model_path, shapes)
+        assert engine is not None, " trt export failure "
+
+        # write trt engine (replaces the temporary ONNX content)
+        with open(final_model_path, "wb") as f:
+            f.write(engine.serialize())
+
+        # load the model back from disk to validate the serialized plan
+        engine = self.lib.load_engine(final_model_path)
+
+        class TRT_model:
+            """Callable wrapper that runs a TensorRT engine on torch tensors,
+            mimicking the native model's __call__ interface for comparison."""
+
+            def __init__(self, engine, input_names, output_names, output_types, device):
+                self.engine = engine
+                self.context = self.engine.create_execution_context()
+                self.input_names = input_names
+                self.output_names = output_names
+                self.output_types = output_types
+                self.device = device
+
+            def is_dimension_dynamic(self, dim):
+                # TensorRT marks dynamic dimensions as -1 (or None)
+                return dim is None or dim <= 0
+
+            def is_shape_dynamic(self, shape):
+                return any([self.is_dimension_dynamic(dim) for dim in shape])
+
+            def __call__(self, *inputs):
+                # get input shapes
+                input_shapes = [x.shape for x in inputs]
+                # bindings: one device pointer per engine binding slot
+                bindings = [None] * self.engine.num_bindings
+                # set input shapes, bind input tensors
+                zipped = zip(self.input_names, inputs)
+                for key, input in zipped:
+                    idx = self.engine.get_binding_index(key)
+                    bindings[idx] = input.data_ptr()
+                    if self.engine.is_shape_binding(idx) and self.is_shape_dynamic(
+                        self.context.get_shape(idx)
+                    ):
+                        self.context.set_shape_input(idx, input)
+                    elif self.is_shape_dynamic(self.engine.get_binding_shape(idx)):
+                        self.context.set_binding_shape(idx, input.shape)
+                assert self.context.all_binding_shapes_specified, "trt error"
+                assert self.context.all_shape_inputs_specified, "trt error"
+                # calculate output shapes, allocate output tensors and bind them
+                outputs = []
+                zipped = zip(self.output_names, self.output_types)
+                for key, dtype in zipped:
+                    idx = self.engine.get_binding_index(key)
+                    shape = self.context.get_binding_shape(idx)
+                    shape = tuple(shape)
+                    assert -1 not in shape, "trt error"
+                    tensor = torch.zeros(shape, dtype=dtype, device=self.device)
+                    outputs.append(tensor)
+                    bindings[idx] = outputs[-1].data_ptr()
+                # run inference (synchronous execute on the bound pointers)
+                self.context.execute_v2(bindings=bindings)
+                # return the result; unwrap single-output models
+                if len(outputs) == 1:
+                    outputs = outputs[0]
+                return outputs
+
+        model_trt = TRT_model(engine, input_names, output_names, output_types, device)
+
+        # run both models on inputs
+        assert not model.training, "internal error - model should be in eval() mode! "
+        models = (model, model_trt)
+        outputs, time_model, outputs_trt, time_model_trt = self.lib.run_models(
+            models, inputs
+        )
+
+        # check for errors (numerical differences between native and TRT outputs)
+        Error_stats = self.lib.compute_errors(outputs, outputs_trt)
+        self.lib.print_errors(Error_stats)
+        print("time of error check of native model: ", time_model, "seconds")
+        print("time of error check of trt model: ", time_model_trt, "seconds")
+        print()
+
+        # write TRTIS config
+        config_filename = os.path.join(model_folder, "config.pbtxt")
+        self.lib.write_config(
+            config_filename, input_shapes, input_types, output_shapes, output_types
+        )
+
+    def name_onnx_nodes(self, model_path):
+        """
+        Name all unnamed nodes in an ONNX model (in place).
+            parameter model_path: path to the ONNX model file
+            return: none
+        """
+        model = onnx.load(model_path)
+        node_id = 0
+        for node in model.graph.node:
+            if len(node.name) == 0:
+                node.name = "unnamed_node_%d" % node_id
+            node_id += 1
+        # This check partially validates model
+        onnx.checker.check_model(model)
+        onnx.save(model, model_path)
+        # Only inference really checks ONNX model for some issues
+        # like duplicated node names
+        onnxruntime.InferenceSession(model_path, None)
+
+    def to_triton_onnx(self, dataloader, model):
+        """ export the model to onnx and test correctness on dataloader """
+        # onnx / onnxruntime are imported lazily and published as module
+        # globals so that name_onnx_nodes() (which references the bare
+        # names) can use them without importing at module load time.
+        import onnx as local_onnx
+
+        global onnx
+        onnx = local_onnx
+        import onnxruntime as local_onnxruntime
+
+        global onnxruntime
+        onnxruntime = local_onnxruntime
+        # setup device
+        if self.args.triton_no_cuda:
+            device = torch.device("cpu")
+        else:
+            device = torch.device("cuda")
+
+        # prepare model
+        model.to(device)
+        model.eval()
+        assert not model.training, "internal error - model should be in eval() mode! "
+
+        # prepare inputs
+        inputs = self.lib.prepare_inputs(dataloader, device)
+
+        # generate outputs
+        outputs = []
+        for input in inputs:
+            with torch.no_grad():
+                output = model(*input)
+            if type(output) is torch.Tensor:
+                output = [output]
+            outputs.append(output)
+
+        # generate input shapes - dynamic tensor shape support
+        input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
+
+        # generate output shapes - dynamic tensor shape support
+        output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
+
+        # generate input types
+        input_types = [x.dtype for x in inputs[0]]
+
+        # generate output types
+        output_types = [x.dtype for x in outputs[0]]
+
+        # get input names (Triton convention: input__<N> / output__<N>)
+        rng = range(len(input_types))
+        input_names = ["input__" + str(num) for num in rng]
+
+        # get output names
+        rng = range(len(output_types))
+        output_names = ["output__" + str(num) for num in rng]
+
+        # prepare save path: <save_dir>/<model_name>/<version>/model.onnx
+        model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
+        version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
+        if not os.path.exists(version_folder):
+            os.makedirs(version_folder)
+        final_model_path = os.path.join(version_folder, "model.onnx")
+
+        # get indices of dynamic input and output shapes (-1 marks a dynamic axis)
+        dynamic_axes = {}
+        for input_name, input_shape in zip(input_names, input_shapes):
+            dynamic_axes[input_name] = [i for i, x in enumerate(input_shape) if x == -1]
+        for output_name, output_shape in zip(output_names, output_shapes):
+            dynamic_axes[output_name] = [
+                i for i, x in enumerate(output_shape) if x == -1
+            ]
+
+        # export the model
+        assert not model.training, "internal error - model should be in eval() mode! "
+        with torch.no_grad():
+            torch.onnx.export(
+                model,
+                inputs[0],
+                final_model_path,
+                verbose=True,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=11,
+            )
+
+        # syntactic error check
+        converted_model = onnx.load(final_model_path)
+        # check that the IR is well formed
+        onnx.checker.check_model(converted_model)
+
+        # Name unnamed nodes - it helps for some other processing tools
+        self.name_onnx_nodes(final_model_path)
+        converted_model = onnx.load(final_model_path)
+
+        # load the model
+        session = onnxruntime.InferenceSession(final_model_path, None)
+
+        class ONNX_model:
+            """Callable wrapper running an onnxruntime session on torch
+            tensors, mimicking the native model's __call__ interface."""
+
+            def __init__(self, session, input_names, device):
+                # NOTE(review): the `device` argument is accepted but never
+                # stored; __call__ resolves `device` from the enclosing
+                # to_triton_onnx() scope instead.
+                self.session = session
+                self.input_names = input_names
+
+            def to_numpy(self, tensor):
+                return (
+                    tensor.detach().cpu().numpy()
+                    if tensor.requires_grad
+                    else tensor.cpu().numpy()
+                )
+
+            def __call__(self, *inputs):
+                inp = [
+                    (input_name, inputs[i])
+                    for i, input_name in enumerate(self.input_names)
+                ]
+                inp = {input_name: self.to_numpy(x) for input_name, x in inp}
+                outputs = self.session.run(None, inp)
+                outputs = [torch.from_numpy(output) for output in outputs]
+                # `device` is captured from the enclosing function scope here
+                outputs = [output.to(device) for output in outputs]
+                if len(outputs) == 1:
+                    outputs = outputs[0]
+                return outputs
+
+        # switch to eval mode
+        model_onnx = ONNX_model(session, input_names, device)
+
+        # run both models on inputs
+        assert not model.training, "internal error - model should be in eval() mode! "
+        models = (model, model_onnx)
+        outputs, time_model, outputs_onnx, time_model_onnx = self.lib.run_models(
+            models, inputs
+        )
+
+        # check for errors (numerical differences between native and ONNX outputs)
+        Error_stats = self.lib.compute_errors(outputs, outputs_onnx)
+        self.lib.print_errors(Error_stats)
+        print("time of error check of native model: ", time_model, "seconds")
+        print("time of error check of onnx model: ", time_model_onnx, "seconds")
+        print()
+
+        # write TRTIS config
+        config_filename = os.path.join(model_folder, "config.pbtxt")
+        self.lib.write_config(
+            config_filename, input_shapes, input_types, output_shapes, output_types
+        )
+
+    def to_triton_torchscript(self, dataloader, model):
+        """ export the model to torchscript and test correctness on dataloader """
+        # setup device
+        if self.args.triton_no_cuda:
+            device = torch.device("cpu")
+        else:
+            device = torch.device("cuda")
+
+        # prepare model
+        model.to(device)
+        model.eval()
+        assert not model.training, "internal error - model should be in eval() mode! "
+
+        # prepare inputs
+        inputs = self.lib.prepare_inputs(dataloader, device)
+
+        # generate input shapes - dynamic tensor shape support
+        input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
+
+        # generate input types
+        input_types = [x.dtype for x in inputs[0]]
+
+        # prepare save path: <save_dir>/<model_name>/<version>/model.pt
+        model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
+        version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
+        if not os.path.exists(version_folder):
+            os.makedirs(version_folder)
+        final_model_path = os.path.join(version_folder, "model.pt")
+
+        # convert the model
+        # NOTE(review): model_ts is unbound if neither flag is set; callers
+        # reach this method only via deploy(), which guarantees one of
+        # ts_trace / ts_script is true.
+        with torch.no_grad():
+            if self.args.ts_trace:  # trace it
+                model_ts = torch.jit.trace(model, inputs[0])
+            if self.args.ts_script:  # script it
+                model_ts = torch.jit.script(model)
+
+        # save the model
+        torch.jit.save(model_ts, final_model_path)
+
+        # load the model
+        model_ts = torch.jit.load(final_model_path)
+        model_ts.eval()  # WAR for bug : by default, model_ts gets loaded in training mode
+
+        # run both models on inputs
+        assert not model.training, "internal error - model should be in eval() mode! "
+        assert (
+            not model_ts.training
+        ), "internal error - converted model should be in eval() mode! "
+        models = (model, model_ts)
+        outputs, time_model, outputs_ts, time_model_ts = self.lib.run_models(
+            models, inputs
+        )
+
+        # check for errors (numerical differences between native and TS outputs)
+        Error_stats = self.lib.compute_errors(outputs, outputs_ts)
+        self.lib.print_errors(Error_stats)
+        print("time of error check of native model: ", time_model, "seconds")
+        print("time of error check of ts model: ", time_model_ts, "seconds")
+        print()
+
+        # generate output shapes - dynamic tensor shape support
+        output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
+
+        # generate output types
+        output_types = [x.dtype for x in outputs[0]]
+
+        # now we build the config for TRTIS
+        config_filename = os.path.join(model_folder, "config.pbtxt")
+        self.lib.write_config(
+            config_filename, input_shapes, input_types, output_shapes, output_types
+        )

BIN
PyTorch/Classification/ConvNets/triton/resnet50/Latency-vs-Throughput-TensorRT.png


BIN
PyTorch/Classification/ConvNets/triton/resnet50/Performance-analysis-TensorRT-FP16.png


BIN
PyTorch/Classification/ConvNets/triton/resnet50/Performance-analysis-TensorRT-FP32.png


+ 248 - 0
PyTorch/Classification/ConvNets/triton/resnet50/README.md

@@ -0,0 +1,248 @@
+# Deploying the ResNet-50 v1.5 model using Triton Inference Server
+
+The [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) provides a datacenter and cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or gRPC endpoint, allowing remote clients to request inferencing for any number of GPU or CPU models being managed by the server. 
+
+This folder contains instructions on how to deploy and run inference on
+Triton Inference Server as well as gather detailed performance analysis.
+
+## Table Of Contents
+
+* [Model overview](#model-overview)
+* [Setup](#setup)
+  * [Inference container](#inference-container)
+  * [Deploying the model](#deploying-the-model)
+  * [Running the Triton Inference Server](#running-the-triton-inference-server)
+* [Quick Start Guide](#quick-start-guide)
+  * [Running the client](#running-the-client)
+  * [Gathering performance data](#gathering-performance-data)
+* [Advanced](#advanced)
+  * [Automated benchmark script](#automated-benchmark-script)
+* [Performance](#performance)
+  * [Dynamic batching performance](#dynamic-batching-performance)
+  * [TensorRT backend inference performance (1x V100 16GB)](#tensorrt-backend-inference-performance-1x-v100-16gb)
+* [Release notes](#release-notes)
+  * [Changelog](#changelog)
+  * [Known issues](#known-issues)
+
+## Model overview
+The ResNet50 v1.5 model is a modified version of the [original ResNet50 v1 model](https://arxiv.org/abs/1512.03385).
+
+The difference between v1 and v1.5 is that, in the bottleneck blocks which require
+downsampling, v1 has stride = 2 in the first 1x1 convolution, whereas v1.5 has stride = 2 in the 3x3 convolution.
+
+This difference makes ResNet50 v1.5 slightly more accurate (~0.5% top1) than v1, but comes with a small performance drawback (~5% imgs/sec).
+
+The ResNet50 v1.5 model can be deployed for inference on the [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) using
+TorchScript, ONNX Runtime or TensorRT as an execution backend.
+
+## Setup
+
+This script requires trained ResNet50 v1.5 model checkpoint that can be used for deployment. 
+
+### Inference container
+
+For easy-to-use deployment, a build script for special inference container was prepared. To build that container, go to the main repository folder and run:
+
+`docker build -t rn50_inference . -f triton/Dockerfile`
+
+This command will download the dependencies and build the inference containers. Then, run shell inside the container:
+
+`docker run -it --rm --gpus device=0 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --net=host -v <PATH_TO_MODEL_REPOSITORY>:/repository rn50_inference bash`
+
+Here `device=0,1,2,3` selects the GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates location to where the
+deployed models were stored.
+
+### Deploying the model
+
+To deploy the ResNet-50 v1.5 model into the Triton Inference Server, you must run the `deployer.py` script from inside the deployment Docker container to achieve a compatible format. 
+
+```
+usage: deployer.py [-h] (--ts-script | --ts-trace | --onnx | --trt)
+                   [--triton-no-cuda] [--triton-model-name TRITON_MODEL_NAME]
+                   [--triton-model-version TRITON_MODEL_VERSION]
+                   [--triton-server-url TRITON_SERVER_URL]
+                   [--triton-max-batch-size TRITON_MAX_BATCH_SIZE]
+                   [--triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY]
+                   [--triton-engine-count TRITON_ENGINE_COUNT]
+                   [--save-dir SAVE_DIR]
+                   [--max_workspace_size MAX_WORKSPACE_SIZE] [--trt-fp16]
+                   [--capture-cuda-graph CAPTURE_CUDA_GRAPH]
+                   ...
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --ts-script           convert to torchscript using torch.jit.script
+  --ts-trace            convert to torchscript using torch.jit.trace
+  --onnx                convert to onnx using torch.onnx.export
+  --trt                 convert to trt using tensorrt
+
+triton related flags:
+  --triton-no-cuda      Use the CPU for tracing.
+  --triton-model-name TRITON_MODEL_NAME
+                        exports to appropriate directory structure for TRITON
+  --triton-model-version TRITON_MODEL_VERSION
+                        exports to appropriate directory structure for TRITON
+  --triton-server-url TRITON_SERVER_URL
+                        exports to appropriate directory structure for TRITON
+  --triton-max-batch-size TRITON_MAX_BATCH_SIZE
+                        Specifies the 'max_batch_size' in the TRITON model
+                        config. See the TRITON documentation for more info.
+  --triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY
+                        Determines the dynamic_batching queue delay in
+                        milliseconds(ms) for the TRITON model config. Use '0'
+                        or '-1' to specify static batching. See the TRITON
+                        documentation for more info.
+  --triton-engine-count TRITON_ENGINE_COUNT
+                        Specifies the 'instance_group' count value in the
+                        TRITON model config. See the TRITON documentation for
+                        more info.
+  --save-dir SAVE_DIR   Saved model directory
+
+optimization flags:
+  --max_workspace_size MAX_WORKSPACE_SIZE
+                        set the size of the workspace for trt export
+  --trt-fp16            trt flag ---- export model in mixed precision mode
+  --capture-cuda-graph CAPTURE_CUDA_GRAPH
+                        capture cuda graph for obtaining speedup. possible
+                        values: 0, 1. default: 1.
+  model_arguments       arguments that will be ignored by deployer lib and
+                        will be forwarded to your deployer script
+```
+
+Following model specific arguments have to be specified for model deployment:
+  
+```
+  --config CONFIG        Network architecture to use for deployment (eg. resnet50, 
+                         resnext101-32x4d or se-resnext101-32x4d)
+  --checkpoint CHECKPOINT
+                         Path to stored model weight. If not specified, model will be 
+                         randomly initialized
+  --batch_size BATCH_SIZE
+                         Batch size used for dummy dataloader
+  --fp16                 Use model with half-precision calculations
+```
+
+For example, to deploy model into TensorRT format, using half precision and max batch size 64 called
+`rn-trt-16` execute:
+
+`python -m triton.deployer --trt --trt-fp16 --triton-model-name rn-trt-16 --triton-max-batch-size 64 --save-dir /repository -- --config resnet50 --checkpoint model_checkpoint --batch_size 64 --fp16`
+
+Where `model_checkpoint` is a checkpoint for a trained model with the same architecture (resnet50) as used during export.
+
+### Running the Triton Inference Server
+
+**NOTE: This step is executed outside the inference container.**
+
+Pull the Triton Inference Server container from our repository:
+
+`docker pull nvcr.io/nvidia/tritonserver:20.07-py3`
+
+Run the command to start the Triton Inference Server:
+
+`docker run -d --rm --gpus device=0 --ipc=host --network=host -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <PATH_TO_MODEL_REPOSITORY>:/models nvcr.io/nvidia/tritonserver:20.07-py3 trtserver --model-store=/models --log-verbose=1 --model-control-mode=poll --repository-poll-secs=5`
+
+Here `device=0,1,2,3` selects GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where the 
+deployed models were stored. An additional `--model-control-mode` option allows reloading the model when it changes in the filesystem. It is a required option for benchmark scripts that work with multiple model versions on a single Triton Inference Server instance.
+
+## Quick Start Guide
+
+### Running the client
+
+The client `client.py` checks the model accuracy against synthetic or real validation
+data. The client connects to Triton Inference Server and performs inference. 
+
+```
+usage: client.py [-h] --triton-server-url TRITON_SERVER_URL
+                 --triton-model-name TRITON_MODEL_NAME [-v]
+                 [--inference_data INFERENCE_DATA] [--batch_size BATCH_SIZE]
+                 [--fp16]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --triton-server-url TRITON_SERVER_URL
+                        URL address of Triton server (with port)
+  --triton-model-name TRITON_MODEL_NAME
+                        Triton deployed model name
+  -v, --verbose         Verbose mode.
+  --inference_data INFERENCE_DATA
+                        Path to file with inference data.
+  --batch_size BATCH_SIZE
+                        Inference request batch size
+  --fp16                Use fp16 precision for input data
+
+```
+
+To run inference on the model exported in the previous steps, using the data located under
+`/dataset`, run:
+
+`python -m triton.client --triton-server-url localhost:8001 --triton-model-name rn-trt-16 --inference_data /data/test_data.bin --batch_size 16 --fp16`
+
+
+### Gathering performance data
+Performance data can be gathered using the `perf_client` tool. To use this tool to measure performance for batch_size=32, the following command can be used:
+
+`/workspace/bin/perf_client --max-threads 10 -m rn-trt-16 -x 1 -p 10000 -v -i gRPC -u localhost:8001 -b 32 -l 5000 --concurrency-range 1 -f result.csv`
+
+For more information about `perf_client`, refer to the [documentation](https://docs.nvidia.com/deeplearning/sdk/triton-inference-server-master-branch-guide/docs/optimization.html#perf-client).
+
+## Advanced
+
+### Automated benchmark script
+To automate benchmarks of different model configurations, a special benchmark script is located in `triton/scripts/benchmark.sh`. To use this script,
+run Triton Inference Server and then execute the script as follows:
+
+`bash triton/scripts/benchmark.sh <MODEL_REPOSITORY> <LOG_DIRECTORY> <ARCHITECTURE> (<CHECKPOINT_PATH>)`
+
+The benchmark script tests all supported backends with different batch sizes and server configuration. Logs from execution will be stored in `<LOG DIRECTORY>`.
+To process static configuration logs, `triton/scripts/process_output.sh` script can be used.
+
+## Performance
+
+### Dynamic batching performance
+The Triton Inference Server has a dynamic batching mechanism built-in that can be enabled. When it is enabled, the server creates inference batches from multiple received requests. This allows us to achieve better performance than doing inference on each single request. The single request is assumed to be a single image that needs to be inferenced. With dynamic batching enabled, the server will concatenate single image requests into an inference batch. The upper bound of the size of the inference batch is set to 64. All these parameters are configurable.
+
+Our results were obtained by running automated benchmark script. 
+Throughput is measured in images/second, and latency in milliseconds.
+
+### TensorRT backend inference performance (1x V100 16GB)
+**FP32 Inference Performance**
+
+|**Concurrent requests**|**Throughput (img/s)**|**Avg. Latency (ms)**|**90% Latency (ms)**|**95% Latency (ms)**|**99% Latency (ms)**|
+|-----|--------|-------|--------|-------|-------|
+| 1 | 133.6 | 7.48 | 7.56 | 7.59 | 7.68 |
+| 2 | 156.6 | 12.77 | 12.84 | 12.86 | 12.93 |
+| 4 | 193.3 | 20.70 | 20.82 | 20.85 | 20.92 | 
+| 8 | 357.4 | 22.38 | 22.53 | 22.57 | 22.67 |
+| 16 | 627.3 | 25.49 | 25.64 | 25.69 | 25.80 |
+| 32 | 1003 | 31.87 | 32.43 | 32.61 | 32.91 |
+| 64 | 1394.7 | 45.85 | 46.13 | 46.22 | 46.86 |
+| 128 | 1604.4 | 79.70 | 80.50 | 80.96 | 83.09 |
+| 256 | 1670.7 | 152.21 | 186.78 | 188.36 | 190.52 |
+
+**FP16 Inference Performance**
+
+|**Concurrent requests**|**Throughput (img/s)**|**Avg. Latency (ms)**|**90% Latency (ms)**|**95% Latency (ms)**|**99% Latency (ms)**|
+|-----|--------|-------|--------|-------|-------|
+| 1 | 250.1 | 3.99 | 4.08 | 4.11 | 4.16 |
+| 2 | 314.8 | 6.35 | 6.42 | 6.44 | 6.49 |
+| 4 | 384.8 | 10.39 | 10.51 | 10.54 | 10.60 |
+| 8 | 693.8 | 11.52 | 11.78 | 11.88 | 12.09 |
+| 16 | 1132.9 | 14.13 | 14.31 | 14.41 | 14.65 |
+| 32 | 1689.7 | 18.93 | 19.11 | 19.20 | 19.44 |
+| 64 | 2226.3 | 28.74 | 29.53 | 29.74 | 31.09 |
+| 128 | 2521.5 | 50.74 | 51.97 | 52.30 | 53.61 |
+| 256 | 2738 | 93.76 | 97.14 | 115.19 | 117.21 |
+
+
+![Latency vs Throughput](./Latency-vs-Throughput-TensorRT.png)
+
+![Performance analysis - TensorRT FP32](./Performance-analysis-TensorRT-FP32.png)
+
+![Performance analysis - TensorRT FP16](./Performance-analysis-TensorRT-FP16.png)
+
+
+## Release notes
+
+### Changelog
+September 2020
+- Initial release

BIN
PyTorch/Classification/ConvNets/triton/resnext101-32x4d/Latency-vs-Throughput-TensorRT.png


BIN
PyTorch/Classification/ConvNets/triton/resnext101-32x4d/Performance-analysis-TensorRT-FP16.png


BIN
PyTorch/Classification/ConvNets/triton/resnext101-32x4d/Performance-analysis-TensorRT-FP32.png


+ 243 - 0
PyTorch/Classification/ConvNets/triton/resnext101-32x4d/README.md

@@ -0,0 +1,243 @@
+# Deploying the ResNeXt101-32x4d model using Triton Inference Server
+
+The [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) provides a datacenter and cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or gRPC endpoint, allowing remote clients to request inferencing for any number of GPU or CPU models being managed by the server. 
+
+This folder contains instructions on how to deploy and run inference on
+Triton Inference Server as well as gather detailed performance analysis.
+
+## Table Of Contents
+
+* [Model overview](#model-overview)
+* [Setup](#setup)
+  * [Inference container](#inference-container)
+  * [Deploying the model](#deploying-the-model)
+  * [Running the Triton Inference Server](#running-the-triton-inference-server)
+* [Quick Start Guide](#quick-start-guide)
+  * [Running the client](#running-the-client)
+  * [Gathering performance data](#gathering-performance-data)
+* [Advanced](#advanced)
+  * [Automated benchmark script](#automated-benchmark-script)
+* [Performance](#performance)
+  * [Dynamic batching performance](#dynamic-batching-performance)
+  * [TensorRT backend inference performance (1x V100 16GB)](#tensorrt-backend-inference-performance-1x-v100-16gb)
+* [Release notes](#release-notes)
+  * [Changelog](#changelog)
+  * [Known issues](#known-issues)
+
+## Model overview
+The ResNeXt101-32x4d is a model introduced in the [Aggregated Residual Transformations for Deep Neural Networks](https://arxiv.org/pdf/1611.05431.pdf) paper.
+It is based on regular ResNet model, substituting 3x3 convolutions inside the bottleneck block for 3x3 grouped convolutions.
+
+The ResNeXt101-32x4d model can be deployed for inference on the [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) using
+TorchScript, ONNX Runtime or TensorRT as an execution backend.
+
+## Setup
+
+This script requires trained ResNeXt101-32x4d model checkpoint that can be used for deployment. 
+
+### Inference container
+
+For easy-to-use deployment, a build script for special inference container was prepared. To build that container, go to the main repository folder and run:
+
+`docker build -t rnxt_inference . -f triton/Dockerfile`
+
+This command will download the dependencies and build the inference containers. Then, run shell inside the container:
+
+`docker run -it --rm --gpus device=0 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --net=host -v <PATH_TO_MODEL_REPOSITORY>:/repository rnxt_inference bash`
+
+Here `device=0,1,2,3` selects the GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where the
+deployed models are stored.
+
+### Deploying the model
+
+To deploy the ResNeXt101-32x4d model into the Triton Inference Server, you must run the `deployer.py` script from inside the deployment Docker container to achieve a compatible format. 
+
+```
+usage: deployer.py [-h] (--ts-script | --ts-trace | --onnx | --trt)
+                   [--triton-no-cuda] [--triton-model-name TRITON_MODEL_NAME]
+                   [--triton-model-version TRITON_MODEL_VERSION]
+                   [--triton-server-url TRITON_SERVER_URL]
+                   [--triton-max-batch-size TRITON_MAX_BATCH_SIZE]
+                   [--triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY]
+                   [--triton-engine-count TRITON_ENGINE_COUNT]
+                   [--save-dir SAVE_DIR]
+                   [--max_workspace_size MAX_WORKSPACE_SIZE] [--trt-fp16]
+                   [--capture-cuda-graph CAPTURE_CUDA_GRAPH]
+                   ...
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --ts-script           convert to torchscript using torch.jit.script
+  --ts-trace            convert to torchscript using torch.jit.trace
+  --onnx                convert to onnx using torch.onnx.export
+  --trt                 convert to trt using tensorrt
+
+triton related flags:
+  --triton-no-cuda      Use the CPU for tracing.
+  --triton-model-name TRITON_MODEL_NAME
+                        exports to appropriate directory structure for TRITON
+  --triton-model-version TRITON_MODEL_VERSION
+                        exports to appropriate directory structure for TRITON
+  --triton-server-url TRITON_SERVER_URL
+                        exports to appropriate directory structure for TRITON
+  --triton-max-batch-size TRITON_MAX_BATCH_SIZE
+                        Specifies the 'max_batch_size' in the TRITON model
+                        config. See the TRITON documentation for more info.
+  --triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY
+                        Determines the dynamic_batching queue delay in
+                        milliseconds(ms) for the TRITON model config. Use '0'
+                        or '-1' to specify static batching. See the TRITON
+                        documentation for more info.
+  --triton-engine-count TRITON_ENGINE_COUNT
+                        Specifies the 'instance_group' count value in the
+                        TRITON model config. See the TRITON documentation for
+                        more info.
+  --save-dir SAVE_DIR   Saved model directory
+
+optimization flags:
+  --max_workspace_size MAX_WORKSPACE_SIZE
+                        set the size of the workspace for trt export
+  --trt-fp16            trt flag ---- export model in mixed precision mode
+  --capture-cuda-graph CAPTURE_CUDA_GRAPH
+                        capture cuda graph for obtaining speedup. possible
+                        values: 0, 1. default: 1.
+  model_arguments       arguments that will be ignored by deployer lib and
+                        will be forwarded to your deployer script
+```
+
+The following model-specific arguments have to be specified for model deployment:
+  
+```
+  --config CONFIG        Network architecture to use for deployment (eg. resnet50, 
+                         resnext101-32x4d or se-resnext101-32x4d)
+  --checkpoint CHECKPOINT
+                         Path to stored model weight. If not specified, model will be 
+                         randomly initialized
+  --batch_size BATCH_SIZE
+                         Batch size used for dummy dataloader
+  --fp16                 Use model with half-precision calculations
+```
+
+For example, to deploy model into TensorRT format, using half precision and max batch size 64 called
+`rnxt-trt-16` execute:
+
+`python -m triton.deployer --trt --trt-fp16 --triton-model-name rnxt-trt-16 --triton-max-batch-size 64 --save-dir /repository -- --config resnext101-32x4d --checkpoint model_checkpoint --batch_size 64 --fp16`
+
+Where `model_checkpoint` is a checkpoint for a trained model with the same architecture (resnext101-32x4d) as used during export.
+
+### Running the Triton Inference Server
+
+**NOTE: This step is executed outside the inference container.**
+
+Pull the Triton Inference Server container from our repository:
+
+`docker pull nvcr.io/nvidia/tritonserver:20.07-py3`
+
+Run the command to start the Triton Inference Server:
+
+`docker run -d --rm --gpus device=0 --ipc=host --network=host -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <PATH_TO_MODEL_REPOSITORY>:/models nvcr.io/nvidia/tritonserver:20.07-py3 trtserver --model-store=/models --log-verbose=1 --model-control-mode=poll --repository-poll-secs=5`
+
+Here `device=0,1,2,3` selects GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where the 
+deployed models were stored. An additional `--model-control-mode` option allows the server to reload a model when it changes in the filesystem. It is a required option for benchmark scripts that work with multiple model versions on a single Triton Inference Server instance.
+
+## Quick Start Guide
+
+### Running the client
+
+The client `client.py` checks the model accuracy against synthetic or real validation
+data. The client connects to Triton Inference Server and performs inference. 
+
+```
+usage: client.py [-h] --triton-server-url TRITON_SERVER_URL
+                 --triton-model-name TRITON_MODEL_NAME [-v]
+                 [--inference_data INFERENCE_DATA] [--batch_size BATCH_SIZE]
+                 [--fp16]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --triton-server-url TRITON_SERVER_URL
+                        URL adress of trtion server (with port)
+  --triton-model-name TRITON_MODEL_NAME
+                        Triton deployed model name
+  -v, --verbose         Verbose mode.
+  --inference_data INFERENCE_DATA
+                        Path to file with inference data.
+  --batch_size BATCH_SIZE
+                        Inference request batch size
+  --fp16                Use fp16 precision for input data
+
+```
+
+To run inference on the model exported in the previous steps, using the data located under
+`/dataset`, run:
+
+`python -m triton.client --triton-server-url localhost:8001 --triton-model-name rnxt-trt-16 --inference_data /data/test_data.bin --batch_size 16 --fp16`
+
+
+### Gathering performance data
+Performance data can be gathered using the `perf_client` tool. To use this tool to measure performance for batch_size=32, the following command can be used:
+
+`/workspace/bin/perf_client --max-threads 10 -m rnxt-trt-16 -x 1 -p 10000 -v -i gRPC -u localhost:8001 -b 32 -l 5000 --concurrency-range 1 -f result.csv`
+
+For more information about `perf_client`, refer to the [documentation](https://docs.nvidia.com/deeplearning/sdk/triton-inference-server-master-branch-guide/docs/optimization.html#perf-client).
+
+## Advanced
+
+### Automated benchmark script
+To automate benchmarks of different model configurations, a special benchmark script is located in `triton/scripts/benchmark.sh`. To use this script,
+run Triton Inference Server and then execute the script as follows:
+
+`bash triton/scripts/benchmark.sh <MODEL_REPOSITORY> <LOG_DIRECTORY> <ARCHITECTURE> (<CHECKPOINT_PATH>)`
+
+The benchmark script tests all supported backends with different batch sizes and server configuration. Logs from execution will be stored in `<LOG DIRECTORY>`.
+To process static configuration logs, `triton/scripts/process_output.sh` script can be used.
+
+## Performance
+
+### Dynamic batching performance
+The Triton Inference Server has a dynamic batching mechanism built-in that can be enabled. When it is enabled, the server creates inference batches from multiple received requests. This allows us to achieve better performance than doing inference on each single request. The single request is assumed to be a single image that needs to be inferenced. With dynamic batching enabled, the server will concatenate single image requests into an inference batch. The upper bound of the size of the inference batch is set to 64. All these parameters are configurable.
+
+Our results were obtained by running automated benchmark script. 
+Throughput is measured in images/second, and latency in milliseconds.
+
+### TensorRT backend inference performance (1x V100 16GB)
+**FP32 Inference Performance**
+
+|**Concurrent requests**|**Throughput (img/s)**|**Avg. Latency (ms)**|**90% Latency (ms)**|**95% Latency (ms)**|**99% Latency (ms)**|
+|-----|--------|-------|--------|-------|-------|
+| 1 | 62.6 | 15.96 | 16.06 | 16.12 | 16.46|
+|2 | 69.5 | 28.74 | 28.81 | 28.84 | 28.88|
+|4 | 114.1 | 35.08 | 35.13 | 35.16 | 35.33|
+|8 | 180 | 44.41 | 44.21 | 49.83 | 50.16|
+|16 | 240 | 66.66 | 67.02 | 67.10 | 67.26|
+|32 | 342.2 | 93.75 | 108.43 | 109.48 | 125.68|
+|64 | 450.9 | 141.60 | 167.91 | 170.35 | 175.99|
+|128 | 545.5 | 234.40 | 248.57 | 250.87 | 254.69|
+|256 | 652.8 | 395.46 | 397.43 | 399.69 | 403.24|
+
+**FP16 Inference Performance**
+
+|**Concurrent requests**|**Throughput (img/s)**|**Avg. Latency (ms)**|**90% Latency (ms)**|**95% Latency (ms)**|**99% Latency (ms)**|
+|-----|--------|-------|--------|-------|-------|
+|1 | 85.7 | 11.68 | 11.76 | 11.79 | 11.85|
+|2 | 92 | 21.74 | 21.83 | 21.86 | 21.91|
+|4 | 141.7 | 28.22 | 35.01 | 35.38 | 35.51|
+|8 | 235.4 | 33.98 | 38.05 | 38.67 | 38.85|
+|16 | 393 | 40.67 | 42.90 | 43.28 | 43.50|
+|32 | 624.8 | 51.18 | 51.71 | 51.82 | 52.08|
+|64 | 874.6 | 73.39 | 74.39 | 74.60 | 75.12|
+|128 | 1126.4 | 113.73 | 114.16 | 114.54 | 115.99|
+|256 | 1312 | 195.87 | 196.87 | 197.75 | 199.06|
+
+![Latency vs Throughput](./Latency-vs-Throughput-TensorRT.png)
+
+![Performance analysis - TensorRT FP32](./Performance-analysis-TensorRT-FP32.png)
+
+![Performance analysis - TensorRT FP16](./Performance-analysis-TensorRT-FP16.png)
+
+
+## Release notes
+
+### Changelog
+September 2020
+- Initial release

+ 53 - 0
PyTorch/Classification/ConvNets/triton/scripts/benchmark.sh

@@ -0,0 +1,53 @@
+#!/bin/bash
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Automated Triton benchmark. For every supported backend (TorchScript script,
+# ONNX, TensorRT) and both precisions (FP16/FP32), the script exports the model
+# into the Triton model repository, waits for the polling server to pick it up,
+# and sweeps perf_client over increasing concurrency levels, appending the
+# results into one CSV per backend/precision under $OUTPUT.
+#
+# Usage: benchmark.sh [MODEL_REPO] [OUTPUT_LOG_DIR] [MODEL_ARCH] [CHECKPOINT_PATH]
+# Requires a Triton server already running with --model-control-mode=poll.
+MODEL_REPO=${1:-"/repo"}
+OUTPUT=${2:-"/logs"}
+MODEL_ARCH=${3:-"resnet50"}
+# NOTE(review): MODEL_CHECKPOINT is accepted but never forwarded to
+# triton.deployer below, so models are exported with random weights — fine for
+# pure throughput/latency benchmarking, but confirm this is intended.
+MODEL_CHECKPOINT=${4:-"/checkpoint.pth"}
+
+for backend in ts onnx trt; do
+    # The deployer flag for TorchScript is "--ts-script"; other backends use
+    # their name verbatim ("--onnx", "--trt").
+    if [[ "$backend" = "ts" ]]; then
+        EXPORT_NAME="ts-script"
+    else
+        EXPORT_NAME="${backend}"
+    fi
+
+    for precision in 16 32; do
+        # FP16 adds --fp16 for the model and TRT-specific export flags
+        # (2 GiB workspace). "CUSTON" is a consistent misspelling — harmless,
+        # since definition and use match.
+        if [[ $precision -eq 16 ]]; then
+            CUSTOM_FLAGS="--fp16"
+            CUSTON_TRTFLAGS="--trt-fp16 --max_workspace_size 2147483648"
+        else
+            CUSTOM_FLAGS=""
+            CUSTON_TRTFLAGS=""
+        fi
+
+        echo "Exporting model as ${EXPORT_NAME} with precision ${precision}"
+
+        # Export the model into the repository; the polling Triton server will
+        # load it. The sleep gives the server time to pick up the new model.
+        python -m triton.deployer --${EXPORT_NAME} --triton-model-name model_${backend} --triton-max-batch-size 64 \
+            --triton-engine-count 2 --save-dir ${MODEL_REPO} ${CUSTON_TRTFLAGS} -- --config ${MODEL_ARCH} ${CUSTOM_FLAGS}
+        sleep 30
+
+        # First perf_client run (concurrency 1:2) writes the CSV including its
+        # header row.
+        /workspace/bin/perf_client --max-threads 10 -m model_${backend} -x 1 -p 10000 -v -i gRPC -u localhost:8001 -b 1 \
+            -l 5000 --concurrency-range 1:2 -f ${OUTPUT}/${backend}_dynamic_${precision}.csv
+        # Subsequent runs append via process substitution; `tail -n +2` strips
+        # the duplicate CSV header before appending.
+        for CONCURENCY_LEVEL in 4 8 16 32 64 128 256; do
+            /workspace/bin/perf_client --max-threads 10 -m model_${backend} -x 1 -p 10000 -v -i gRPC -u localhost:8001 -b 1 \
+                -l 5000 --concurrency-range $CONCURENCY_LEVEL:$CONCURENCY_LEVEL -f >(tail -n +2 >> ${OUTPUT}/${backend}_dynamic_${precision}.csv)
+        done
+        # Remove the deployed model so the next export starts clean.
+        rm -rf ${MODEL_REPO}/model_${backend}
+    done
+    cat ${OUTPUT}/*_dynamic_*.csv
+done

+ 24 - 0
PyTorch/Classification/ConvNets/triton/scripts/get_metrics_static.sh

@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Static-batching metrics sweep: runs perf_client against an already-deployed
+# model at fixed concurrency 1, varying the request batch size, and collects
+# all results into a single CSV.
+#
+# Usage: get_metrics_static.sh MODEL_NAME [OUTPUT_FILE]
+MODEL_NAME=${1}
+OUTPUT_FILE=${2:-result.csv}
+
+for i in 1 2 4 8 16 32 64 128; do
+    echo "Model $MODEL_NAME evaluation with BS $i"
+    # Process substitution post-processes each perf_client CSV: `tail -n +2`
+    # drops the header row and `sed` prepends a "BS=<i>," column so rows from
+    # all batch sizes can be appended into one file.
+    /workspace/bin/perf_client --max-threads 10 -m $MODEL_NAME -x 1 -p 10000 -v -i gRPC -u localhost:8001 -b $i -l 5000 \
+        --concurrency-range 1 -f >(tail -n +2 | sed -e 's/^/BS='${i}',/' >> $OUTPUT_FILE)
+done

+ 20 - 0
PyTorch/Classification/ConvNets/triton/scripts/process_output.sh

@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Pretty-prints a perf_client CSV ($1): extracts throughput and latency columns
+# and converts latencies from microseconds to milliseconds with bc (scale=2).
+# Column layout assumed: f3 = throughput, f4-f10 = per-stage latency components
+# (summed to get average end-to-end latency), f12/f13/f14 = p90/p95/p99.
+echo "Processing file $1"
+echo "Throughput:"
+
+cat $1 | cut -d ',' -f3
+
+echo ""
+echo "Average latency: "
+# Sum the latency-component columns (join with '+') and divide by 1000 (us->ms).
+cat $1 | cut -d ',' -f4-10 | sed "s/,/\+/g" | sed "s/.*/scale=2; (\0) \/ 1000/g" | bc
+
+echo ""
+echo "p90 latency: "
+cat $1 | cut -d ',' -f12 | sed "s/.*/scale=2; \0 \/ 1000/g" | bc
+echo ""
+echo "p95 latency: "
+cat $1 | cut -d ',' -f13 | sed "s/.*/scale=2; \0 \/ 1000/g" | bc
+echo ""
+echo "p99 latency: "
+cat $1 | cut -d ',' -f14 | sed "s/.*/scale=2; \0 \/ 1000/g" | bc

BIN
PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/Latency-vs-Throughput-TensorRT.png


BIN
PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/Performance-analysis-TensorRT-FP16.png


BIN
PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/Performance-analysis-TensorRT-FP32.png


+ 245 - 0
PyTorch/Classification/ConvNets/triton/se-resnext101-32x4d/README.md

@@ -0,0 +1,245 @@
+# Deploying the SE-ResNeXt101-32x4d model using Triton Inference Server
+
+The [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) provides a datacenter and cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or gRPC endpoint, allowing remote clients to request inferencing for any number of GPU or CPU models being managed by the server. 
+
+This folder contains instructions on how to deploy and run inference on
+Triton Inference Server as well as gather detailed performance analysis.
+
+## Table Of Contents
+
+* [Model overview](#model-overview)
+* [Setup](#setup)
+  * [Inference container](#inference-container)
+  * [Deploying the model](#deploying-the-model)
+  * [Running the Triton Inference Server](#running-the-triton-inference-server)
+* [Quick Start Guide](#quick-start-guide)
+  * [Running the client](#running-the-client)
+  * [Gathering performance data](#gathering-performance-data)
+* [Advanced](#advanced)
+  * [Automated benchmark script](#automated-benchmark-script)
+* [Performance](#performance)
+  * [Dynamic batching performance](#dynamic-batching-performance)
+  * [TensorRT backend inference performance (1x V100 16GB)](#tensorrt-backend-inference-performance-1x-v100-16gb)
+* [Release notes](#release-notes)
+  * [Changelog](#changelog)
+  * [Known issues](#known-issues)
+
+## Model overview
+The SE-ResNeXt101-32x4d is a [ResNeXt101-32x4d](https://arxiv.org/pdf/1611.05431.pdf)
+model with added Squeeze-and-Excitation module introduced
+in [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/1709.01507.pdf) paper.
+
+The SE-ResNeXt101-32x4d model can be deployed for inference on the [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) using
+TorchScript, ONNX Runtime or TensorRT as an execution backend.
+
+## Setup
+
+This script requires trained SE-ResNeXt101-32x4d model checkpoint that can be used for deployment. 
+
+### Inference container
+
+For easy-to-use deployment, a build script for special inference container was prepared. To build that container, go to the main repository folder and run:
+
+`docker build -t sernxt_inference . -f triton/Dockerfile`
+
+This command will download the dependencies and build the inference containers. Then, run shell inside the container:
+
+`docker run -it --rm --gpus device=0 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --net=host -v <PATH_TO_MODEL_REPOSITORY>:/repository sernxt_inference bash`
+
+Here `device=0,1,2,3` selects the GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where the
+deployed models are stored.
+
+### Deploying the model
+
+To deploy the SE-ResNeXt101-32x4d model into the Triton Inference Server, you must run the `deployer.py` script from inside the deployment Docker container to achieve a compatible format. 
+
+```
+usage: deployer.py [-h] (--ts-script | --ts-trace | --onnx | --trt)
+                   [--triton-no-cuda] [--triton-model-name TRITON_MODEL_NAME]
+                   [--triton-model-version TRITON_MODEL_VERSION]
+                   [--triton-server-url TRITON_SERVER_URL]
+                   [--triton-max-batch-size TRITON_MAX_BATCH_SIZE]
+                   [--triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY]
+                   [--triton-engine-count TRITON_ENGINE_COUNT]
+                   [--save-dir SAVE_DIR]
+                   [--max_workspace_size MAX_WORKSPACE_SIZE] [--trt-fp16]
+                   [--capture-cuda-graph CAPTURE_CUDA_GRAPH]
+                   ...
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --ts-script           convert to torchscript using torch.jit.script
+  --ts-trace            convert to torchscript using torch.jit.trace
+  --onnx                convert to onnx using torch.onnx.export
+  --trt                 convert to trt using tensorrt
+
+triton related flags:
+  --triton-no-cuda      Use the CPU for tracing.
+  --triton-model-name TRITON_MODEL_NAME
+                        exports to appropriate directory structure for TRITON
+  --triton-model-version TRITON_MODEL_VERSION
+                        exports to appropriate directory structure for TRITON
+  --triton-server-url TRITON_SERVER_URL
+                        exports to appropriate directory structure for TRITON
+  --triton-max-batch-size TRITON_MAX_BATCH_SIZE
+                        Specifies the 'max_batch_size' in the TRITON model
+                        config. See the TRITON documentation for more info.
+  --triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY
+                        Determines the dynamic_batching queue delay in
+                        milliseconds(ms) for the TRITON model config. Use '0'
+                        or '-1' to specify static batching. See the TRITON
+                        documentation for more info.
+  --triton-engine-count TRITON_ENGINE_COUNT
+                        Specifies the 'instance_group' count value in the
+                        TRITON model config. See the TRITON documentation for
+                        more info.
+  --save-dir SAVE_DIR   Saved model directory
+
+optimization flags:
+  --max_workspace_size MAX_WORKSPACE_SIZE
+                        set the size of the workspace for trt export
+  --trt-fp16            trt flag ---- export model in mixed precision mode
+  --capture-cuda-graph CAPTURE_CUDA_GRAPH
+                        capture cuda graph for obtaining speedup. possible
+                        values: 0, 1. default: 1.
+  model_arguments       arguments that will be ignored by deployer lib and
+                        will be forwarded to your deployer script
+```
+
+The following model-specific arguments have to be specified for model deployment:
+  
+```
+  --config CONFIG        Network architecture to use for deployment (eg. resnet50, 
+                         resnext101-32x4d or se-resnext101-32x4d)
+  --checkpoint CHECKPOINT
+                         Path to stored model weight. If not specified, model will be 
+                         randomly initialized
+  --batch_size BATCH_SIZE
+                         Batch size used for dummy dataloader
+  --fp16                 Use model with half-precision calculations
+```
+
+For example, to deploy model into TensorRT format, using half precision and max batch size 64 called
+`sernxt-trt-16` execute:
+
+`python -m triton.deployer --trt --trt-fp16 --triton-model-name sernxt-trt-16 --triton-max-batch-size 64 --save-dir /repository -- --config se-resnext101-32x4d --checkpoint model_checkpoint --batch_size 64 --fp16`
+
+Where `model_checkpoint` is a checkpoint for a trained model with the same architecture (se-resnext101-32x4d) as used during export.
+
+### Running the Triton Inference Server
+
+**NOTE: This step is executed outside the inference container.**
+
+Pull the Triton Inference Server container from our repository:
+
+`docker pull nvcr.io/nvidia/tritonserver:20.07-py3`
+
+Run the command to start the Triton Inference Server:
+
+`docker run -d --rm --gpus device=0 --ipc=host --network=host -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <PATH_TO_MODEL_REPOSITORY>:/models nvcr.io/nvidia/tritonserver:20.07-py3 trtserver --model-store=/models --log-verbose=1 --model-control-mode=poll --repository-poll-secs=5`
+
+Here `device=0,1,2,3` selects GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where the 
+deployed models were stored. An additional `--model-control-mode` option allows the server to reload a model when it changes in the filesystem. It is a required option for benchmark scripts that work with multiple model versions on a single Triton Inference Server instance.
+
+## Quick Start Guide
+
+### Running the client
+
+The client `client.py` checks the model accuracy against synthetic or real validation
+data. The client connects to Triton Inference Server and performs inference. 
+
+```
+usage: client.py [-h] --triton-server-url TRITON_SERVER_URL
+                 --triton-model-name TRITON_MODEL_NAME [-v]
+                 [--inference_data INFERENCE_DATA] [--batch_size BATCH_SIZE]
+                 [--fp16]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --triton-server-url TRITON_SERVER_URL
+                        URL adress of trtion server (with port)
+  --triton-model-name TRITON_MODEL_NAME
+                        Triton deployed model name
+  -v, --verbose         Verbose mode.
+  --inference_data INFERENCE_DATA
+                        Path to file with inference data.
+  --batch_size BATCH_SIZE
+                        Inference request batch size
+  --fp16                Use fp16 precision for input data
+
+```
+
+To run inference on the model exported in the previous steps, using the data located under
+`/dataset`, run:
+
+`python -m triton.client --triton-server-url localhost:8001 --triton-model-name sernxt-trt-16 --inference_data /data/test_data.bin --batch_size 16 --fp16`
+
+
+### Gathering performance data
+Performance data can be gathered using the `perf_client` tool. To use this tool to measure performance for batch_size=32, the following command can be used:
+
+`/workspace/bin/perf_client --max-threads 10 -m sernxt-trt-16 -x 1 -p 10000 -v -i gRPC -u localhost:8001 -b 32 -l 5000 --concurrency-range 1 -f result.csv`
+
+For more information about `perf_client`, refer to the [documentation](https://docs.nvidia.com/deeplearning/sdk/triton-inference-server-master-branch-guide/docs/optimization.html#perf-client).
+
+## Advanced
+
+### Automated benchmark script
+To automate benchmarks of different model configurations, a special benchmark script is located in `triton/scripts/benchmark.sh`. To use this script,
+run Triton Inference Server and then execute the script as follows:
+
+`bash triton/scripts/benchmark.sh <MODEL_REPOSITORY> <LOG_DIRECTORY> <ARCHITECTURE> (<CHECKPOINT_PATH>)`
+
+The benchmark script tests all supported backends with different batch sizes and server configuration. Logs from execution will be stored in `<LOG DIRECTORY>`.
+To process static configuration logs, `triton/scripts/process_output.sh` script can be used.
+
+## Performance
+
+### Dynamic batching performance
+The Triton Inference Server has a dynamic batching mechanism built-in that can be enabled. When it is enabled, the server creates inference batches from multiple received requests. This allows us to achieve better performance than doing inference on each single request. The single request is assumed to be a single image that needs to be inferenced. With dynamic batching enabled, the server will concatenate single image requests into an inference batch. The upper bound of the size of the inference batch is set to 64. All these parameters are configurable.
+
+Our results were obtained by running automated benchmark script. 
+Throughput is measured in images/second, and latency in milliseconds.
+
+### TensorRT backend inference performance (1x V100 16GB)
+**FP32 Inference Performance**
+
+|**Concurrent requests**|**Throughput (img/s)**|**Avg. Latency (ms)**|**90% Latency (ms)**|**95% Latency (ms)**|**99% Latency (ms)**|
+|-----|--------|-------|--------|-------|-------|
+|1 | 62.1 | 16.10 | 16.20 | 16.23 | 16.33|
+|2 | 66.2 | 30.23 | 30.26 | 30.27 | 30.32|
+|4 | 124.6 | 32.13 | 32.19 | 32.21 | 32.28|
+|8 | 151.1 | 52.91 | 53.10 | 53.15 | 53.21|
+|16 | 240 | 66.51 | 66.82 | 66.91 | 67.05|
+|32 | 326.8 | 98.00 | 132.41 | 134.00 | 137.71|
+|64 | 412.6 | 154.74 | 182.47 | 185.90 | 195.43|
+|128 | 506.7 | 252.58 | 275.03 | 277.56 | 279.86|
+|256 | 588.8 | 434.40 | 435.82 | 436.59 | 444.09|
+
+
+**FP16 Inference Performance**
+
+|**Concurrent requests**|**Throughput (img/s)**|**Avg. Latency (ms)**|**90% Latency (ms)**|**95% Latency (ms)**|**99% Latency (ms)**|
+|-----|--------|-------|--------|-------|-------|
+|1 | 77.5 | 12.90 | 12.98 | 13.01 | 13.05|
+|2 | 82.8 | 24.15 | 24.23 | 24.25 | 24.30|
+|4 | 128.8 | 31.06 | 38.81 | 39.15 | 39.31|
+|8 | 212 | 37.68 | 42.28 | 43.06 | 43.17|
+|16 | 351.3 | 45.52 | 48.41 | 48.52 | 48.92|
+|32 | 548 | 58.38 | 59.09 | 59.38 | 59.80|
+|64 | 774 | 82.63 | 84.40 | 84.88 | 86.49|
+|128 | 985.7 | 130.30 | 130.83 | 131.26 | 132.86|
+|256 | 1132.8 | 225.56 | 226.34 | 227.31 | 229.30 |
+
+![Latency vs Throughput](./Latency-vs-Throughput-TensorRT.png)
+
+![Performance analysis - TensorRT FP32](./Performance-analysis-TensorRT-FP32.png)
+
+![Performance analysis - TensorRT FP16](./Performance-analysis-TensorRT-FP16.png)
+
+
+## Release notes
+
+### Changelog
+September 2020
+- Initial release