Просмотр исходного кода

[Transformer/PyT] 22.06 release

Jan Baczek 3 года назад
Родитель
Commit
d666f14553
45 измененных файлов с 851 добавлено и 2543 удалено
  1. 5 3
      PyTorch/Translation/Transformer/Dockerfile
  2. 0 6
      PyTorch/Translation/Transformer/NOTICE
  3. 34 63
      PyTorch/Translation/Transformer/README.md
  4. 0 63
      PyTorch/Translation/Transformer/distributed_train.py
  5. 1 1
      PyTorch/Translation/Transformer/fairseq/data/__init__.py
  6. 1 1
      PyTorch/Translation/Transformer/fairseq/data/data_utils.py
  7. 0 35
      PyTorch/Translation/Transformer/fairseq/data/fairseq_dataset.py
  8. 1 1
      PyTorch/Translation/Transformer/fairseq/data/language_pair_dataset.py
  9. 0 108
      PyTorch/Translation/Transformer/fairseq/data/token_block_dataset.py
  10. 1 1
      PyTorch/Translation/Transformer/fairseq/ddp_trainer.py
  11. 16 21
      PyTorch/Translation/Transformer/fairseq/distributed_utils.py
  12. 3 3
      PyTorch/Translation/Transformer/fairseq/log_helper.py
  13. 0 159
      PyTorch/Translation/Transformer/fairseq/models/fused_layer_norm.py
  14. 1 1
      PyTorch/Translation/Transformer/fairseq/models/transformer.py
  15. 0 138
      PyTorch/Translation/Transformer/fairseq/modules/adaptive_softmax.py
  16. 0 38
      PyTorch/Translation/Transformer/fairseq/modules/conv_tbc.py
  17. 0 258
      PyTorch/Translation/Transformer/fairseq/modules/downsampled_multihead_attention.py
  18. 0 20
      PyTorch/Translation/Transformer/fairseq/modules/grad_multiply.py
  19. 0 89
      PyTorch/Translation/Transformer/fairseq/modules/linearized_convolution.py
  20. 1 1
      PyTorch/Translation/Transformer/fairseq/modules/multihead_attention.py
  21. 0 33
      PyTorch/Translation/Transformer/fairseq/modules/scalar_bias.py
  22. 3 3
      PyTorch/Translation/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm.cpp
  23. 582 212
      PyTorch/Translation/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm_cuda.cu
  24. 1 1
      PyTorch/Translation/Transformer/fairseq/optim/adam.py
  25. 1 1
      PyTorch/Translation/Transformer/fairseq/optim/fairseq_optimizer.py
  26. 1 24
      PyTorch/Translation/Transformer/fairseq/options.py
  27. 16 30
      PyTorch/Translation/Transformer/fairseq/sequence_generator.py
  28. 1 1
      PyTorch/Translation/Transformer/fairseq/tokenizer.py
  29. 1 1
      PyTorch/Translation/Transformer/fairseq/utils.py
  30. 3 8
      PyTorch/Translation/Transformer/inference.py
  31. 0 123
      PyTorch/Translation/Transformer/scripts/deployer.py
  32. 0 969
      PyTorch/Translation/Transformer/scripts/deployer_lib.py
  33. 0 1
      PyTorch/Translation/Transformer/scripts/docker/build.sh
  34. 0 15
      PyTorch/Translation/Transformer/scripts/docker/launch.sh
  35. 0 54
      PyTorch/Translation/Transformer/scripts/export_model.sh
  36. 12 11
      PyTorch/Translation/Transformer/scripts/run_DGX1_AMP.sh
  37. 10 10
      PyTorch/Translation/Transformer/scripts/run_DGX1_FP32.sh
  38. 58 0
      PyTorch/Translation/Transformer/scripts/run_DGX2_AMP.sh
  39. 10 19
      PyTorch/Translation/Transformer/scripts/run_DGX2_FP32.sh
  40. 10 10
      PyTorch/Translation/Transformer/scripts/run_DGXA100_AMP.sh
  41. 57 0
      PyTorch/Translation/Transformer/scripts/run_DGXA100_TF32.sh
  42. 15 0
      PyTorch/Translation/Transformer/scripts/run_inference.sh
  43. 2 2
      PyTorch/Translation/Transformer/scripts/run_training.sh
  44. 1 1
      PyTorch/Translation/Transformer/setup.py
  45. 3 4
      PyTorch/Translation/Transformer/train.py

+ 5 - 3
PyTorch/Translation/Transformer/Dockerfile

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.05-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.06-py3
 FROM ${FROM_IMAGE_NAME}
 
 WORKDIR /workspace
@@ -20,7 +20,8 @@ WORKDIR /workspace
 # && cd apex \
 # && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
 # Install Python dependencies
-RUN pip install --no-cache-dir \
+RUN pip install --upgrade --no-cache-dir pip \
+ && pip install --no-cache-dir \
       sacrebleu \
       sentencepiece
 RUN pip install jupyter
@@ -44,4 +45,5 @@ RUN git clone https://github.com/rsennrich/subword-nmt.git /workspace/translatio
 RUN git clone https://github.com/NVIDIA/cutlass.git && cd cutlass && git checkout ed2ed4d6 && cd ..
 COPY . .
 RUN pip install -e .
-RUN pip install git+https://github.com/NVIDIA/dllogger#egg=dllogger
+RUN pip install nvidia-pyindex
+RUN pip install nvidia-dllogger

+ 0 - 6
PyTorch/Translation/Transformer/NOTICE

@@ -1,6 +0,0 @@
-Transformer PyTorch
-
-This repository includes software from https://github.com/facebookresearch/fairseq
-licensed under the BSD License.
-
-

+ 34 - 63
PyTorch/Translation/Transformer/README.md

@@ -152,7 +152,7 @@ The following section lists the requirements in order to start training the Tran
 This repository contains Dockerfile which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 
 -   [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
--   [PyTorch 20.03-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
+-   [PyTorch 22.06-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
 -   GPU-based architecture:
 	- [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
 	- [NVIDIA Turing](https://www.nvidia.com/en-us/geforce/turing/)
@@ -196,7 +196,7 @@ After running this command, data will be downloaded to `/workspace/translation/e
 
 4. Start training
 ```bash
-python -m torch.distributed.launch --nproc_per_node 8 /workspace/translation/train.py /data/wmt14_en_de_joined_dict \
+python -m torch.distributed.run --nproc_per_node 8 /workspace/translation/train.py /data/wmt14_en_de_joined_dict \
   --arch transformer_wmt_en_de_big_t2t \
   --share-all-embeddings \
   --optimizer adam \
@@ -217,8 +217,7 @@ python -m torch.distributed.launch --nproc_per_node 8 /workspace/translation/tra
   --fuse-layer-norm \
   --amp \
   --amp-level O2 \
-  --save-dir /workspace/checkpoints \
-  --distributed-init-method env:// 
+  --save-dir /workspace/checkpoints
 ```
 
 The script saves checkpoints every epoch to the directory specified in the `--save-dir` option. In addition, the best performing checkpoint (in terms of loss) and the latest checkpoints are saved separately.
@@ -363,8 +362,6 @@ sacrebleu -t wmt14/full -l en-de --echo src | python inference.py --buffer-size
 
 ## Performance
 
-The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
-
 ### Benchmarking
 
 The following section shows how to run benchmarks measuring the model performance in training and inference modes.
@@ -403,31 +400,31 @@ Mean BLEU score after reaching 4.02 validation loss is 27.38. We observe varianc
 </p>
 
 ##### Training accuracy: NVIDIA DGX A100 (8x A100 40GB)
-Our results were obtained by running the `run_DGXA100_AMP_8GPU.sh` and `run_DGXA100_TF32_8GPU.sh` training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX A100 (8x A100 40GB) GPUs. We report average accuracy over 6 runs. We consider a model trained when it reaches minimal validation loss. Time to train contains only training time without validation. Depending on a configuration and frequency of validation it can take up to additional minute per epoch. 
+Our results were obtained by running the `run_DGXA100_AMP_8GPU.sh` and `run_DGXA100_TF32_8GPU.sh` training scripts in the pytorch-22.06-py3 NGC container on NVIDIA DGX A100 (8x A100 40GB) GPUs. We report average accuracy over 6 runs. We consider a model trained when it reaches minimal validation loss. Time to train contains only training time without validation. Depending on a configuration and frequency of validation it can take up to additional minute per epoch. 
 
 | GPUs    | Batch size / GPU    | Accuracy - TF32  | Accuracy - mixed precision  |   Time to train - TF32  |  Time to train - mixed precision | Time to train speedup (TF32 to mixed precision)        
 |---------|---------------------|------------------|-----------------------------|-------------------------|----------------------------------|------------------------------------
-| 8       | 10240               | 27.92            | 27.76                       | 2.87 hours              | 2.79 hours                       | x1.03
+| 8       | 10240               | 27.92            | 27.76                       | 2.74 hours              | 2.64 hours                       | x1.04
 
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
 
-Our results were obtained by running the `run_DGX1_AMP_8GPU.sh` and `run_DGX1_FP32_8GPU.sh` training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX-1 (8x V100 16GB) GPUs. We report average accuracy over 6 runs. We consider a model trained when it reaches minimal validation loss. Time to train contains only training time without validation. Depending on a configuration and frequency of validation it can take up to additional minute per epoch. Using mixed precision we could fit a larger batch size in the memory, further speeding up the training.
+Our results were obtained by running the `run_DGX1_AMP_8GPU.sh` and `run_DGX1_FP32_8GPU.sh` training scripts in the pytorch-22.06-py3 NGC container on NVIDIA DGX-1 (8x V100 16GB) GPUs. We report average accuracy over 6 runs. We consider a model trained when it reaches minimal validation loss. Time to train contains only training time without validation. Depending on a configuration and frequency of validation it can take up to additional minute per epoch. Using mixed precision we could fit a larger batch size in the memory, further speeding up the training.
 
 | GPUs    | Batch size / GPU    | Accuracy - FP32  | Accuracy - mixed precision  |   Time to train - FP32  |  Time to train - mixed precision | Time to train speedup (FP32 to mixed precision)        
 |---------|---------------------|------------------|-----------------------------|-------------------------|----------------------------------|------------------------------------
-| 8       | 5120/2560           | 27.66            | 27.82                       | 12 hours                | 4.6  hours                       | x2.64
+| 8       | 5120/2560           | 27.66            | 27.82                       | 11.8 hours              | 4.5  hours                       | x2.62
 
 #### Training performance results
 
 ##### Training performance: NVIDIA DGX A100 (8x A100 40GB)
 
-Our results were obtained by running the `run_DGXA100_AMP_8GPU.sh` and `run_DGXA100_TF32_8GPU.sh` training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX A100 (8x A100 40GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch.
+Our results were obtained by running the `run_DGXA100_AMP_8GPU.sh` and `run_DGXA100_TF32_8GPU.sh` training scripts in the pytorch-22.06-py3 NGC container on NVIDIA DGX A100 (8x A100 40GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch.
 
 | GPUs   | Batch size / GPU   | Throughput - TF32    | Throughput - mixed precision    | Throughput speedup (TF32 - mixed precision)   | Weak scaling - TF32    | Weak scaling - mixed precision        
 |--------|--------------------|----------------------|---------------------------------|-----------------------------------------------|------------------------|-----
-| 8      | 10240              | 316913               | 582721                          | x1.84                                         | 6.93                   | 7.05 
-| 4      | 10240              | 161980               | 298741                          | x1.84                                         | 3.54                   | 3.62
-| 1      | 10240              | 45755                | 82618                           | x1.81                                         | 1                      | 1
+| 8      | 10240              | 347936               | 551599                          | x1.59                                         | 6.81                   | 6.72 
+| 4      | 10240              | 179245               | 286081                          | x1.60                                         | 3.51                   | 3.49
+| 1      | 10240              | 51057                | 82059                           | x1.60                                         | 1                      | 1
 
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -444,27 +441,27 @@ The following plot shows average validation loss curves for different configs. W
 
 ##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
 
-Our results were obtained by running the `run_DGX1_AMP_8GPU.sh` and `run_DGX1_FP32_8GPU.sh` training scripts in the pytorch-20.06-py3 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch. Using mixed precision we could fit a larger batch size in the memory, further speeding up the training.
+Our results were obtained by running the `run_DGX1_AMP_8GPU.sh` and `run_DGX1_FP32_8GPU.sh` training scripts in the pytorch-22.06-py3 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch. Using mixed precision we could fit a larger batch size in the memory, further speeding up the training.
 
 | GPUs   | Batch size / GPU   | Throughput - FP32    | Throughput - mixed precision    | Throughput speedup (FP32 - mixed precision)   | Weak scaling - FP32    | Weak scaling - mixed precision        
 |--------|--------------------|----------------------|---------------------------------|-----------------------------------------------|------------------------|-----
-| 8      | 5120/2560          | 58742                | 223245                          | x3.80                                         | 6.91                   | 6.67
-| 4      | 5120/2560          | 29674                | 115269                          | x3.88                                         | 3.49                   | 3.44
-| 1      | 5120/2560          | 8498                 | 33468                           | x3.94                                         | 1                      | 1
+| 8      | 5120/2560          | 59316                | 214656                          | x3.62                                         | 6.79                   | 6.52
+| 4      | 5120/2560          | 30204                | 109726                          | x3.63                                         | 3.46                   | 3.33
+| 1      | 5120/2560          | 8742                 | 32942                           | x3.77                                         | 1                      | 1
 
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
 
 ##### Training performance: NVIDIA DGX-2 (16x V100 32GB)
 
-Our results were obtained by running the `run_DGX1_AMP_8GPU.sh` and `run_DGX1_FP32_8GPU.sh` training scripts setting number of GPUs to 16 in the pytorch-20.06-py3 NGC container on NVIDIA DGX-2 with (16x V100 32GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch. Using mixed precision we could fit a larger batch size in the memory, further speeding up the training.
+Our results were obtained by running the `run_DGX1_AMP_8GPU.sh` and `run_DGX1_FP32_8GPU.sh` training scripts setting number of GPUs to 16 in the pytorch-22.06-py3 NGC container on NVIDIA DGX-2 with (16x V100 32GB) GPUs. Performance numbers (in tokens per second) were averaged over an entire training epoch. Using mixed precision we could fit a larger batch size in the memory, further speeding up the training.
 
 | GPUs   | Batch size / GPU   | Throughput - FP32    | Throughput - mixed precision    | Throughput speedup (FP32 - mixed precision)   | Weak scaling - FP32    | Weak scaling - mixed precision        
 |--------|--------------------|----------------------|---------------------------------|-----------------------------------------------|------------------------|-----
-| 16     | 10240/5120         | 130867               | 510267                          | x3.9                                          | 13.38                  | 12.7
-| 8      | 10240/5120         | 68829                | 269464                          | x3.91                                         | 7.04                   | 6.71
-| 4      | 10240/5120         | 35168                | 141143                          | x4.01                                         | 3.6                    | 3.51
-| 1      | 10240/5120         | 9779                 | 40163                           | x4.11                                         | 1                      | 1   
+| 16     | 10240/5120         | 136253               | 517227                          | x3.80                                         | 13.87                  | 12.96
+| 8      | 10240/5120         | 68929                | 267815                          | x3.89                                         | 7.01                   | 6.71
+| 4      | 10240/5120         | 35216                | 137703                          | x3.91                                         | 3.58                   | 3.45
+| 1      | 10240/5120         | 9827                 | 39911                           | x4.06                                         | 1                      | 1   
 
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -475,53 +472,23 @@ Our implementation of the Transformer has dynamic batching algorithm, which batc
 
 ##### Inference performance: NVIDIA DGX A100 (1x A100 40GB)
 
-Our results were obtained by running the `inference.py` inferencing benchmarking script in the pytorch-20.06-py3 NGC container on NVIDIA DGX A100 (1x A100 40GB) GPU.
-
-FP16
-
-| Batch size |  Throughput Avg | Latency Avg | Latency 90% |Latency 95% |Latency 99% |
-|------------|-----------------|-------------|-------------|------------|------------|
-| 10240      | 9653            | 0.986s      | 1.291s      | 2.157s     | 2.167s     |
-| 2560       | 5092            | 0.504s      | 0.721s      | 0.830s     | 1.752s     |
-| 1024       | 2590            | 0.402s      | 0.587s      | 0.666s     | 0.918s     |
-| 512        | 1357            | 0.380s      | 0.561s      | 0.633s     | 0.788s     |
-| 256        | 721             | 0.347s      | 0.513s      | 0.576s     | 0.698s     | 
-
-TF32
+Our results were obtained by running the `inference.py` inferencing benchmarking script in the pytorch-22.06-py3 NGC container on NVIDIA DGX A100 (1x A100 40GB) GPU.
 
-| Batch size | Throughput Avg | Latency Avg | Latency 90% |Latency 95% |Latency 99% |
-|------------|----------------|-------------|-------------|------------|------------|
-|  10240     | 7755           | 1.227s      | 1.592s      | 2.512s     | 2.525s     |
-|  2560      | 4624           | 0.555s      | 0.786s      | 0.872s     | 1.886s     |
-|  1024      | 2394           | 0.435s      | 0.627s      | 0.702s     | 0.881s     |
-|  512       | 1275           | 0.405s      | 0.586s      | 0.663s     | 0.821s     |
-|  256       | 677            | 0.370s      | 0.546s      | 0.613s     | 0.733s     |    
+| Precision | Batch size | Throughput Avg | Latency Avg | Latency 90% |Latency 95% |Latency 99% |
+|-----------|------------|----------------|-------------|-------------|------------|------------|
+| TF32      |  10240     | 7105           | 1.22s       | 1.67s       | 1.67s      | 1.67s      |
+| FP16      |  10240     | 7988           | 1.09s       | 1.73s       | 1.73s      | 1.73s      |
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
 
 ##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
 
-Our results were obtained by running the `inference.py` inferencing benchmarking script in the pytorch-20.06-py3 NGC container on NVIDIA DGX-1 with (1x V100 16GB) GPU.
+Our results were obtained by running the `inference.py` inferencing benchmarking script in the pytorch-22.06-py3 NGC container on NVIDIA DGX-1 with (1x V100 16GB) GPU.
 
-FP16
-
-| Batch size | Throughput Avg | Latency Avg | Latency 90% |Latency 95% |Latency 99% |
-|------------|----------------|-------------|-------------|------------|------------|
-| 10240      | 7464           | 1.283s      | 1.704s      | 1.792s     | 1.801s     |
-| 2560       | 3596           | 0.719s      | 1.066s      | 1.247s     | 1.423s     |
-| 1024       | 1862           | 0.563s      | 0.857s      | 0.936s     | 1.156s     |
-| 512        | 1003           | 0.518s      | 0.782s      | 0.873s     | 1.103s     |
-| 256        | 520            | 0.484s      | 0.723s      | 0.813s     | 0.992s     |
-
-FP32
-
-| Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
-|------------|----------------|-------------|-------------|-------------|-------------|
-| 10240      | 3782           | 2.531s      | 3.091s      | 3.121s      | 3.136s      |
-| 2560       | 2910           | 0.888s      | 1.221s      | 1.252s      | 1.432s      |
-| 1024       | 1516           | 0.692s      | 1.001s      | 1.126s      | 1.297s      |
-| 512        | 941            | 0.551s      | 0.812s      | 0.893s      | 1.133s      |
-| 256        | 502            | 0.501s      | 0.734s      | 0.822s      | 0.978s      |
+| Precision | Batch size | Throughput Avg | Latency Avg | Latency 90% | Latency 95% | Latency 99% |
+|-----------|------------|----------------|-------------|-------------|-------------|-------------|
+| FP32      | 10240      | 3461           | 2.51s       | 3.19 s      | 3.19s       | 3.19s       |
+| FP16      | 10240      | 5983           | 1.45s       | 2.03 s      | 2.03s       | 2.03s       |
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
 
@@ -530,6 +497,10 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
 
 ### Changelog
 
+February 2022:
+- Update deprecated calls in PyTorch CPP and Python API
+- Update README with latest performance measurements
+
 June 2020
 - add TorchScript support
 - Ampere support

+ 0 - 63
PyTorch/Translation/Transformer/distributed_train.py

@@ -1,63 +0,0 @@
-#!/usr/bin/env python3 -u
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-#
-#-------------------------------------------------------------------------
-#
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-
-import os
-import socket
-import subprocess
-
-from train import main as single_process_main
-from fairseq import distributed_utils, options
-
-
-def main(args):
-    if args.distributed_init_method is None and args.distributed_port > 0:
-        # We can determine the init method automatically for Slurm.
-        node_list = os.environ.get('SLURM_JOB_NODELIST')
-        if node_list is not None:
-            try:
-                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
-                args.distributed_init_method = 'tcp://{host}:{port}'.format(
-                    host=hostnames.split()[0].decode('utf-8'),
-                    port=args.distributed_port)
-                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
-                args.device_id = int(os.environ.get('SLURM_LOCALID'))
-            except subprocess.CalledProcessError as e:  # scontrol failed
-                raise e
-            except FileNotFoundError as e:  # Slurm is not installed
-                pass
-    if args.distributed_init_method is None:
-        raise ValueError('--distributed-init-method or --distributed-port '
-                         'must be specified for distributed training')
-
-    args.distributed_rank = distributed_utils.distributed_init(args)
-    args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
-    print('| initialized host {} as rank {} and device id {}'.format(socket.gethostname(), args.distributed_rank, args.device_id))
-    single_process_main(args)
-
-
-if __name__ == '__main__':
-    parser = options.get_training_parser()
-    args = options.parse_args_and_arch(parser)
-    main(args)

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/data/__init__.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/data/data_utils.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 0 - 35
PyTorch/Translation/Transformer/fairseq/data/fairseq_dataset.py

@@ -1,35 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-import torch.utils.data
-
-
-class FairseqDataset(torch.utils.data.Dataset):
-    """A dataset that provides helpers for batching."""
-
-    def __getitem__(self, index):
-        raise NotImplementedError
-
-    def __len__(self):
-        raise NotImplementedError
-
-    def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
-        raise NotImplementedError
-
-
-    def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
-        raise NotImplementedError
-
-    def ordered_indices(self, seed=None, epoch=0):
-        """Ordered indices for batching."""
-        raise NotImplementedError
-
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        raise NotImplementedError

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/data/language_pair_dataset.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 0 - 108
PyTorch/Translation/Transformer/fairseq/data/token_block_dataset.py

@@ -1,108 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-#
-#-------------------------------------------------------------------------
-#
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-
-import numpy as np
-import torch
-
-
-class TokenBlockDataset(torch.utils.data.Dataset):
-    """Break a 1d tensor of tokens into blocks.
-
-    The blocks are fetched from the original tensor so no additional memory is allocated.
-
-    Args:
-        tokens: 1d tensor of tokens to break into blocks
-        sizes: sentence lengths (required for 'complete' and 'eos')
-        block_size: maximum block size (ignored in 'eos' break mode)
-        break_mode: Mode used for breaking tokens. Values can be one of:
-            - 'none': break tokens into equally sized blocks (up to block_size)
-            - 'complete': break tokens into blocks (up to block_size) such that
-                blocks contains complete sentences, although block_size may be
-                exceeded if some sentences exceed block_size
-            - 'eos': each block contains one sentence (block_size is ignored)
-        include_targets: return next tokens as targets
-    """
-
-    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
-        super().__init__()
-
-        self.tokens = tokens
-        self.total_size = len(tokens)
-        self.include_targets = include_targets
-        self.slice_indices = []
-
-        if break_mode is None or break_mode == 'none':
-            length = math.ceil(len(tokens) / block_size)
-
-            def block_at(i):
-                start = i * block_size
-                end = min(start + block_size, len(tokens))
-                return (start, end)
-
-            self.slice_indices = [block_at(i) for i in range(length)]
-        elif break_mode == 'complete':
-            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
-            tok_idx = 0
-            sz_idx = 0
-            curr_size = 0
-            while sz_idx < len(sizes):
-                if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
-                    curr_size += sizes[sz_idx]
-                    sz_idx += 1
-                else:
-                    self.slice_indices.append((tok_idx, tok_idx + curr_size))
-                    tok_idx += curr_size
-                    curr_size = 0
-            if curr_size > 0:
-                self.slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == 'eos':
-            assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
-            curr = 0
-            for sz in sizes:
-                # skip samples with just 1 example (which would be just the eos token)
-                if sz > 1:
-                    self.slice_indices.append((curr, curr + sz))
-                curr += sz
-        else:
-            raise ValueError('Invalid break_mode: ' + break_mode)
-
-        self.sizes = np.array([e - s for s, e in self.slice_indices])
-
-    def __getitem__(self, index):
-        s, e = self.slice_indices[index]
-
-        item = torch.LongTensor(self.tokens[s:e])
-
-        if self.include_targets:
-            # target is the sentence, for source, rotate item one token to the left (would start with eos)
-            if s == 0:
-                source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
-            else:
-                source = self.tokens[s - 1:e - 1]
-
-            return torch.LongTensor(source), item
-        return item
-
-    def __len__(self):
-        return len(self.slice_indices)

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/ddp_trainer.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 16 - 21
PyTorch/Translation/Transformer/fairseq/distributed_utils.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -34,26 +34,21 @@ def is_master(args):
 
 
 def distributed_init(args):
-    if args.distributed_world_size == 1:
-        raise ValueError('Cannot initialize distributed with distributed_world_size=1')
-
-    print('| distributed init (rank {}): {}'.format(
-        args.distributed_rank, args.distributed_init_method), flush=True)
-    print("| distributed env init. MASTER_ADDR: " + os.environ['MASTER_ADDR'] +
-          ", MASTER_PORT: " + os.environ['MASTER_PORT'] +
-          ", WORLD_SIZE: " + os.environ['WORLD_SIZE'] + ", RANK: " + os.environ['RANK'], flush=True)
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend, init_method='env://')
-    print("| distributed init done!", flush=True)
-    args.distributed_world_size = int(os.environ['WORLD_SIZE'])
-
-    args.distributed_rank = torch.distributed.get_rank()
-    args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
-    suppress_output(args)
-    print('| initialized host {} as rank {} and device id {}'
-          .format(socket.gethostname(), args.distributed_rank, args.device_id))
-
-    return args.distributed_rank
+    args.distributed_world_size = int(os.environ.get('WORLD_SIZE',1))
+    args.distributed_rank = int(os.environ.get('RANK',0))
+    args.local_rank = int(os.environ.get('LOCAL_RANK', 0))
+
+    if args.distributed_world_size > 1:
+
+        print('| distributed init (rank {}): env://'.format(args.distributed_rank), flush=True)
+        print(f"| distributed env init. MASTER_ADDR: {os.environ['MASTER_ADDR']}, MASTER_PORT: {os.environ['MASTER_PORT']}" +
+              f", WORLD_SIZE: {os.environ['WORLD_SIZE']}, RANK: {os.environ['RANK']}", flush=True)
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+        print("| distributed init done!", flush=True)
+
+        suppress_output(args)
+        print('| initialized host {} as rank {} and device id {}'
+              .format(socket.gethostname(), args.distributed_rank, args.local_rank))
 
 
 def suppress_output(main_args):

+ 3 - 3
PyTorch/Translation/Transformer/fairseq/log_helper.py

@@ -177,10 +177,10 @@ def setup_logger(args):
     container_setup_info = get_framework_env_vars()
     dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)
 
-    dllogger.metadata('loss', {'unit': None, 'GOAL': 'MINIMIZE', 'STAGE': 'TRAIN'})
-    dllogger.metadata('val_loss', {'unit': None, 'GOAL': 'MINIMIZE', 'STAGE': 'VAL'})
+    dllogger.metadata('loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'TRAIN'})
+    dllogger.metadata('val_loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'VAL'})
     dllogger.metadata('speed', {'unit': 'tokens/s', 'format': ':.3f', 'GOAL': 'MAXIMIZE', 'STAGE': 'TRAIN'})
-    dllogger.metadata('accuracy', {'unit': None, 'format': ':.2f', 'GOAL': 'MAXIMIZE', 'STAGE': 'VAL'})
+    dllogger.metadata('accuracy', {'unit': 'bleu', 'format': ':.2f', 'GOAL': 'MAXIMIZE', 'STAGE': 'VAL'})
 
 
 def get_framework_env_vars():

+ 0 - 159
PyTorch/Translation/Transformer/fairseq/models/fused_layer_norm.py

@@ -1,159 +0,0 @@
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.import math
-
-import math
-import torch
-import numbers
-from torch.nn.parameter import Parameter
-from torch.nn import init
-
-import fused_layer_norm_cuda
-
-class FusedLayerNormAffineFunction(torch.autograd.Function):
-  def __init__(self, normalized_shape, eps=1e-6):
-    self.normalized_shape = normalized_shape
-    self.eps = eps
-
-  def forward(self, input, weight, bias):
-    input_ = input.contiguous()
-    weight_ = weight.contiguous()
-    bias_ = bias.contiguous()
-    output, mean, invvar = fused_layer_norm_cuda.forward_affine(
-        input_, self.normalized_shape, weight_, bias_, self.eps)
-    self.save_for_backward(input_, weight_, bias_, mean, invvar)
-    return output
-
-  def backward(self, grad_output):
-    input_, weight_, bias_, mean, invvar = self.saved_tensors
-    grad_input = grad_weight = grad_bias = None
-    grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
-        grad_output.contiguous(), mean, invvar,
-        input_, self.normalized_shape, 
-        weight_, bias_, self.eps)
-    return grad_input, grad_weight, grad_bias;
-    
-class FusedLayerNormFunction(torch.autograd.Function):
-  def __init__(self, normalized_shape, eps=1e-6):
-    self.normalized_shape = normalized_shape
-    self.eps = eps
-
-  def forward(self, input):
-    input_ = input.contiguous()
-    output, mean, invvar = fused_layer_norm_cuda.forward(
-        input_, self.normalized_shape, self.eps)
-    self.save_for_backward(input_, mean, invvar)
-    return output
-
-  def backward(self, grad_output):
-    input_, mean, invvar = self.saved_tensors
-    grad_input = None
-    grad_input = fused_layer_norm_cuda.backward(
-        grad_output.contiguous(), mean, invvar,
-        input_, self.normalized_shape,
-        self.eps)
-    return grad_input
-
-def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction(normalized_shape,eps)(input, weight, bias)
-
-def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction(normalized_shape,eps)(input)
-
-class FusedLayerNorm(torch.nn.Module):
-    r"""Applies Layer Normalization over a mini-batch of inputs as described in
-    the paper `Layer Normalization`_ .
-
-    .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-
-    The mean and standard-deviation are calculated separately over the last
-    certain number dimensions which have to be of the shape specified by
-    :attr:`normalized_shape`.
-    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
-    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
-
-    .. note::
-        Unlike Batch Normalization and Instance Normalization, which applies
-        scalar scale and bias for each entire channel/plane with the
-        :attr:`affine` option, Layer Normalization applies per-element scale and
-        bias with :attr:`elementwise_affine`.
-
-    This layer uses statistics computed from input data in both training and
-    evaluation modes.
-
-    Args:
-        normalized_shape (int or list or torch.Size): input shape from an expected input
-            of size
-
-            .. math::
-                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
-                    \times \ldots \times \text{normalized\_shape}[-1]]
-
-            If a single integer is used, it is treated as a singleton list, and this module will
-            normalize over the last dimension which is expected to be of that specific size.
-        eps: a value added to the denominator for numerical stability. Default: 1e-5
-        elementwise_affine: a boolean value that when set to ``True``, this module
-            has learnable per-element affine parameters initialized to ones (for weights)
-            and zeros (for biases). Default: ``True``.
-
-    Shape:
-        - Input: :math:`(N, *)`
-        - Output: :math:`(N, *)` (same shape as input)
-
-    Examples::
-
-        >>> input = torch.randn(20, 5, 10, 10)
-        >>> # With Learnable Parameters
-        >>> m = nn.LayerNorm(input.size()[1:])
-        >>> # Without Learnable Parameters
-        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
-        >>> # Normalize over last two dimensions
-        >>> m = nn.LayerNorm([10, 10])
-        >>> # Normalize over last dimension of size 10
-        >>> m = nn.LayerNorm(10)
-        >>> # Activating the module
-        >>> output = m(input)
-
-    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
-    """
-    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
-        super(FusedLayerNorm, self).__init__()
-        if isinstance(normalized_shape, numbers.Integral):
-            normalized_shape = (normalized_shape,)
-        self.normalized_shape = torch.Size(normalized_shape)
-        self.eps = eps
-        self.elementwise_affine = elementwise_affine
-        if self.elementwise_affine:
-            self.weight = Parameter(torch.Tensor(*normalized_shape))
-            self.bias = Parameter(torch.Tensor(*normalized_shape))
-        else:
-            self.register_parameter('weight', None)
-            self.register_parameter('bias', None)
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        if self.elementwise_affine:
-            init.ones_(self.weight)
-            init.zeros_(self.bias)
-
-    def forward(self, input):
-        if self.elementwise_affine:
-          return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
-              input, self.weight, self.bias)
-        else:
-          return FusedLayerNormFunction(self.normalized_shape,self.eps)(
-              input)
-
-    def extra_repr(self):
-        return '{normalized_shape}, eps={eps}, ' \
-            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/models/transformer.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 0 - 138
PyTorch/Translation/Transformer/fairseq/modules/adaptive_softmax.py

@@ -1,138 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-
-import torch.nn.functional as F
-from torch import nn
-
-
-class AdaptiveSoftmax(nn.Module):
-    """
-    This is an implementation of the efficient softmax approximation for
-    graphical processing units (GPU), described in the paper "Efficient softmax
-    approximation for GPUs" (http://arxiv.org/abs/1609.04309).
-    """
-
-    def __init__(self, vocab_size, input_dim, cutoff, dropout):
-        super().__init__()
-
-        if vocab_size > cutoff[-1]:
-            cutoff = cutoff + [vocab_size]
-        else:
-            assert vocab_size == cutoff[
-                -1], 'cannot specify cutoff smaller than vocab size'
-
-        output_dim = cutoff[0] + len(cutoff) - 1
-
-        self.vocab_size = vocab_size
-        self.cutoff = cutoff
-        self.dropout = dropout
-
-        self.lsm = nn.LogSoftmax(dim=1)
-        self.head = nn.Linear(input_dim, output_dim, bias=False)
-        self.tail = nn.ModuleList()
-
-        for i in range(len(cutoff) - 1):
-            self.tail.append(
-                nn.Sequential(
-                    nn.Linear(input_dim, input_dim // 4 ** i, bias=False),
-                    nn.Dropout(dropout),
-                    nn.Linear(input_dim // 4 ** i, cutoff[i + 1] - cutoff[i], bias=False)
-                )
-            )
-
-        def init_weights(m):
-            if hasattr(m, 'weight'):
-                nn.init.xavier_uniform_(m.weight)
-
-        self.apply(init_weights)
-
-    def adapt_target(self, target):
-        """
-        In order to be efficient, the AdaptiveSoftMax does not compute the
-        scores for all the word of the vocabulary for all the examples. It is
-        thus necessary to call the method adapt_target of the AdaptiveSoftMax
-        layer inside each forward pass.
-        """
-
-        target = target.view(-1)
-        new_target = [target.clone()]
-        target_idxs = []
-
-        for i in range(len(self.cutoff) - 1):
-            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
-            new_target[0][mask] = self.cutoff[0] + i - 1
-
-            if mask.any():
-                target_idxs.append(mask.nonzero().squeeze(1))
-                new_target.append(target[mask].add(-self.cutoff[i]))
-            else:
-                target_idxs.append(None)
-                new_target.append(None)
-
-        return new_target, target_idxs
-
-    def forward(self, input, target):
-        """
-        Args:
-            input: (b x t x d)
-            target: (b x t)
-        Returns:
-            2 lists: output for each cutoff section and new targets by cut off
-        """
-
-        input = input.contiguous().view(-1, input.size(-1))
-        input = F.dropout(input, p=self.dropout, training=self.training)
-
-        new_target, target_idxs = self.adapt_target(target)
-        output = [self.head(input)]
-
-        for i in range(len(target_idxs)):
-            if target_idxs[i] is not None:
-                output.append(self.tail[i](input.index_select(0, target_idxs[i])))
-            else:
-                output.append(None)
-
-        return output, new_target
-
-    def get_log_prob(self, input, target):
-        """
-        Computes the log probabilities for all the words of the vocabulary,
-        given a 2D tensor of hidden vectors.
-        """
-
-        bsz, length, dim = input.size()
-        input = input.contiguous().view(-1, dim)
-
-        if target is not None:
-            _, target_idxs = self.adapt_target(target)
-        else:
-            target_idxs = None
-
-        head_y = self.head(input)
-        log_probs = head_y.new_zeros(input.size(0), self.vocab_size)
-
-        head_sz = self.cutoff[0] + len(self.tail)
-        log_probs[:, :head_sz] = self.lsm(head_y)
-        tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()
-
-        for i in range(len(self.tail)):
-            start = self.cutoff[i]
-            end = self.cutoff[i + 1]
-
-            if target_idxs is None:
-                tail_out = log_probs[:, start:end]
-                tail_out.copy_(self.tail[i](input))
-                log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None])
-            elif target_idxs[i] is not None:
-                idxs = target_idxs[i]
-                tail_out = log_probs[idxs, start:end]
-                tail_out.copy_(self.tail[i](input[idxs]))
-                log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None])
-
-        log_probs = log_probs.view(bsz, length, -1)
-        return log_probs

+ 0 - 38
PyTorch/Translation/Transformer/fairseq/modules/conv_tbc.py

@@ -1,38 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-import torch
-from torch.nn.modules.utils import _single
-
-
-class ConvTBC(torch.nn.Module):
-    """1D convolution over an input of shape (time x batch x channel)
-
-    The implementation uses gemm to perform the convolution. This implementation
-    is faster than cuDNN for small kernel sizes.
-    """
-    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
-        super(ConvTBC, self).__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = _single(kernel_size)
-        self.padding = _single(padding)
-
-        self.weight = torch.nn.Parameter(torch.Tensor(
-            self.kernel_size[0], in_channels, out_channels))
-        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
-
-    def forward(self, input):
-        return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])
-
-    def __repr__(self):
-        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
-             ', padding={padding}')
-        if self.bias is None:
-            s += ', bias=False'
-        s += ')'
-        return s.format(name=self.__class__.__name__, **self.__dict__)

+ 0 - 258
PyTorch/Translation/Transformer/fairseq/modules/downsampled_multihead_attention.py

@@ -1,258 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-#
-
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from fairseq.modules.scalar_bias import scalar_bias
-
-
-class SingleHeadAttention(nn.Module):
-    """
-    Single-head attention that supports Gating and Downsampling
-    """
-    def __init__(
-        self, out_channels, embed_dim, head_dim, head_index, dropout=0.,
-        bias=True, project_input=True, gated=False, downsample=False,
-        num_heads=1,
-    ):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.dropout = dropout
-        self.head_index = head_index
-        self.head_dim = head_dim
-        self.project_input = project_input
-        self.gated = gated
-        self.downsample = downsample
-        self.num_heads = num_heads
-        self.projection = None
-
-        k_layers = []
-        v_layers = []
-        if self.downsample:
-            k_layers.append(Downsample(self.head_index))
-            v_layers.append(Downsample(self.head_index))
-            out_proj_size = self.head_dim
-        else:
-            out_proj_size = self.head_dim * self.num_heads
-        if self.gated:
-            k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
-            self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
-            v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
-        else:
-            k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
-            self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
-            v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
-
-        self.in_proj_k = nn.Sequential(*k_layers)
-        self.in_proj_v = nn.Sequential(*v_layers)
-
-        if self.downsample:
-            self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
-        else:
-            self.out_proj = Linear(out_proj_size, out_channels, bias=bias)
-
-        self.scaling = self.head_dim**-0.5
-
-    def forward(
-        self, query, key, value, mask_future_timesteps=False,
-        key_padding_mask=None, use_scalar_bias=False,
-    ):
-        """Input shape: Time x Batch x Channel
-        Self-attention can be implemented by passing in the same arguments for
-        query, key and value. Future timesteps can be masked with the
-        `mask_future_timesteps` argument. Padding elements can be excluded from
-        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
-        batch x src_len, where padding elements are indicated by 1s.
-        """
-        src_len, bsz, out_channels = key.size()
-        tgt_len = query.size(0)
-        assert list(query.size()) == [tgt_len, bsz, out_channels]
-        assert key.size() == value.size()
-
-        if key_padding_mask is not None:
-            assert key_padding_mask.size(0) == bsz
-            assert key_padding_mask.size(1) == src_len
-
-        if self.downsample:
-            size = bsz
-        else:
-            size = bsz * self.num_heads
-
-        k = key
-        v = value
-        q = query
-        if self.project_input:
-            q = self.in_proj_q(q)
-            k = self.in_proj_k(k)
-            v = self.in_proj_v(v)
-            src_len = k.size()[0]
-        q *= self.scaling
-
-        if not self.downsample:
-            q = q.view(tgt_len, size, self.head_dim)
-            k = k.view(src_len, size, self.head_dim)
-            v = v.view(src_len, size, self.head_dim)
-
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-
-        attn_weights = torch.bmm(q, k.transpose(1, 2))
-        if mask_future_timesteps:
-            assert query.size() == key.size(), \
-                'mask_future_timesteps only applies to self-attention'
-            attn_weights *= torch.tril(
-                attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
-                diagonal=-1,
-            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
-            attn_weights += torch.triu(
-                attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
-                diagonal=0
-            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
-        tgt_size = tgt_len
-        if use_scalar_bias:
-            attn_weights = scalar_bias(attn_weights, 2)
-            v = scalar_bias(v, 1)
-            tgt_size += 1
-
-        if key_padding_mask is not None:
-            # don't attend to padding symbols
-            if key_padding_mask.max() > 0:
-                if self.downsample:
-                    attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
-                else:
-                    attn_weights = attn_weights.view(size, self.num_heads, tgt_len, src_len)
-                attn_weights = attn_weights.masked_fill(
-                    key_padding_mask.unsqueeze(1).unsqueeze(2),
-                    -math.inf,
-                )
-                attn_weights = attn_weights.view(size, tgt_len, src_len)
-        attn_weights = F.softmax(attn_weights, dim=-1)
-        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
-
-        attn = torch.bmm(attn_weights, v)
-        if self.downsample:
-            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
-        else:
-            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
-
-        attn = self.out_proj(attn)
-
-        return attn, attn_weights
-
-
-class DownsampledMultiHeadAttention(nn.ModuleList):
-    """
-    Multi-headed attention with Gating and Downsampling
-    """
-    def __init__(
-        self, out_channels, embed_dim, num_heads, dropout=0., bias=True,
-        project_input=True, gated=False, downsample=False,
-    ):
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-        self.downsample = downsample
-        self.gated = gated
-        self.project_input = project_input
-        assert self.head_dim * num_heads == embed_dim
-
-        if self.downsample:
-            attention_heads = []
-            for index in range(self.num_heads):
-                attention_heads.append(
-                    SingleHeadAttention(
-                        out_channels, self.embed_dim, self.head_dim, index,
-                        self.dropout, bias, self.project_input, self.gated,
-                        self.downsample, self.num_heads,
-                    )
-                )
-            super().__init__(modules=attention_heads)
-            self.out_proj = Linear(embed_dim, out_channels, bias=bias)
-        else:
-            # either we have a list of attention heads, or just one attention head
-            # if not being downsampled, we can do the heads with one linear layer instead of separate ones
-            super().__init__()
-            self.attention_module = SingleHeadAttention(
-                out_channels, self.embed_dim, self.head_dim, 1, self.dropout,
-                bias, self.project_input, self.gated, self.downsample, self.num_heads,
-            )
-
-    def forward(
-        self, query, key, value, mask_future_timesteps=False,
-        key_padding_mask=None, use_scalar_bias=False,
-    ):
-        src_len, bsz, embed_dim = key.size()
-        tgt_len = query.size(0)
-        assert embed_dim == self.embed_dim
-        assert list(query.size()) == [tgt_len, bsz, embed_dim]
-        assert key.size() == value.size()
-
-        tgt_size = tgt_len
-        if use_scalar_bias:
-            tgt_size += 1
-
-        attn = []
-        attn_weights = []
-        if self.downsample:
-            for attention_head_number in range(self.num_heads):
-                # call the forward of each attention head
-                _attn, _attn_weight = self[attention_head_number](
-                    query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
-                )
-                attn.append(_attn)
-                attn_weights.append(_attn_weight)
-            full_attn = torch.cat(attn, dim=2)
-            full_attn = self.out_proj(full_attn)
-            return full_attn, attn_weights[0].clone()
-        else:
-            _attn, _attn_weight = self.attention_module(
-                query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
-            )
-            attn.append(_attn)
-            attn_weights.append(_attn_weight)
-            full_attn = torch.cat(attn, dim=2)
-            full_attn_weights = torch.cat(attn_weights)
-            full_attn_weights = full_attn_weights.view(bsz, self.num_heads, tgt_size, src_len)
-            full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
-            return full_attn, full_attn_weights
-
-
-class Downsample(nn.Module):
-    """
-    Selects every nth element, where n is the index
-    """
-    def __init__(self, index):
-        super().__init__()
-        self.index = index
-
-    def forward(self, x):
-        return x[::self.index+1]
-
-
-def Linear(in_features, out_features, dropout=0., bias=True):
-    """Weight-normalized Linear layer (input: B x T x C)"""
-    m = nn.Linear(in_features, out_features, bias=bias)
-    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
-    m.bias.data.zero_()
-    return nn.utils.weight_norm(m)
-
-
-def GatedLinear(in_features, out_features, dropout=0., bias=True):
-    """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
-    return nn.Sequential(
-        Linear(in_features, out_features*4, dropout, bias),
-        nn.GLU(),
-        Linear(out_features*2, out_features*2, dropout, bias),
-        nn.GLU(),
-        Linear(out_features, out_features, dropout, bias)
-    )

+ 0 - 20
PyTorch/Translation/Transformer/fairseq/modules/grad_multiply.py

@@ -1,20 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-import torch
-
-
-class GradMultiply(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, scale):
-        ctx.scale = scale
-        res = x.new(x)
-        return res
-
-    @staticmethod
-    def backward(ctx, grad):
-        return grad * ctx.scale, None

+ 0 - 89
PyTorch/Translation/Transformer/fairseq/modules/linearized_convolution.py

@@ -1,89 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-import torch
-import torch.nn.functional as F
-
-from fairseq import utils
-
-from .conv_tbc import ConvTBC
-
-
-class LinearizedConvolution(ConvTBC):
-    """An optimized version of nn.Conv1d.
-
-    At training time, this module uses ConvTBC, which is an optimized version
-    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
-    one time step at a time) by replacing the convolutions with linear layers.
-    Note that the input order changes from training to inference.
-    """
-
-    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
-        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-        self._linearized_weight = None
-        self.register_backward_hook(self._clear_linearized_weight)
-
-    def forward(self, input, incremental_state=None):
-        """
-        Input:
-            Time x Batch x Channel during training
-            Batch x Time x Channel during inference
-        Args:
-            incremental_state: Used to buffer signal; if not None, then input is
-                expected to contain a single frame. If the input order changes
-                between time steps, call reorder_incremental_state.
-        """
-        if incremental_state is None:
-            output = super().forward(input)
-            if self.kernel_size[0] > 1 and self.padding[0] > 0:
-                # remove future timesteps added by padding
-                output = output[:-self.padding[0], :, :]
-            return output
-
-        # reshape weight
-        weight = self._get_linearized_weight()
-        kw = self.kernel_size[0]
-
-        bsz = input.size(0)  # input: bsz x len x dim
-        if kw > 1:
-            input = input.data
-            input_buffer = self._get_input_buffer(incremental_state)
-            if input_buffer is None:
-                input_buffer = input.new(bsz, kw, input.size(2)).zero_()
-                self._set_input_buffer(incremental_state, input_buffer)
-            else:
-                # shift buffer
-                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
-            # append next input
-            input_buffer[:, -1, :] = input[:, -1, :]
-            input = input_buffer
-        with torch.no_grad():
-            output = F.linear(input.view(bsz, -1), weight, self.bias)
-        return output.view(bsz, 1, -1)
-
-    def reorder_incremental_state(self, incremental_state, new_order):
-        input_buffer = self._get_input_buffer(incremental_state)
-        if input_buffer is not None:
-            input_buffer = input_buffer.index_select(0, new_order)
-            self._set_input_buffer(incremental_state, input_buffer)
-
-    def _get_input_buffer(self, incremental_state):
-        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
-
-    def _set_input_buffer(self, incremental_state, new_buffer):
-        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
-
-    def _get_linearized_weight(self):
-        if self._linearized_weight is None:
-            kw = self.kernel_size[0]
-            weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
-            assert weight.size() == (self.out_channels, kw, self.in_channels)
-            self._linearized_weight = weight.view(self.out_channels, -1)
-        return self._linearized_weight
-
-    def _clear_linearized_weight(self, *args):
-        self._linearized_weight = None

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/modules/multihead_attention.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 0 - 33
PyTorch/Translation/Transformer/fairseq/modules/scalar_bias.py

@@ -1,33 +0,0 @@
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-#
-
-import torch
-
-
-class ScalarBias(torch.autograd.Function):
-    """
-    Adds a vector of scalars, used in self-attention mechanism to allow
-    the model to optionally attend to this vector instead of the past
-    """
-
-    @staticmethod
-    def forward(ctx, input, dim, bias_init):
-        size = list(input.size())
-        size[dim] += 1
-        output = input.new(*size).fill_(bias_init)
-        output.narrow(dim, 1, size[dim] - 1).copy_(input)
-        ctx.dim = dim
-        return output
-
-    @staticmethod
-    def backward(ctx, grad):
-        return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
-
-
-def scalar_bias(input, dim, bias_init=0):
-    return ScalarBias.apply(input, dim, bias_init)

+ 3 - 3
PyTorch/Translation/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm.cpp

@@ -48,9 +48,9 @@ at::Tensor strided_batched_gemm(
   AT_ASSERTM(in_result.size(2) == batch2.size(2), "wrong matrix size");
   AT_ASSERTM(batch1.size(2)    == batch2.size(1), "wrong matrix size");
 
-  AT_ASSERTM(batch1.type().scalarType()    == at::ScalarType::Half, "Only HALF is supported");
-  AT_ASSERTM(batch2.type().scalarType()    == at::ScalarType::Half, "Only HALF is supported");
-  AT_ASSERTM(in_result.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
+  AT_ASSERTM(batch1.dtype()    == at::ScalarType::Half, "Only HALF is supported");
+  AT_ASSERTM(batch2.dtype()    == at::ScalarType::Half, "Only HALF is supported");
+  AT_ASSERTM(in_result.dtype() == at::ScalarType::Half, "Only HALF is supported");
   
   return strided_batched_gemm_cuda(beta, in_result, alpha, batch1, batch2);
 }

+ 582 - 212
PyTorch/Translation/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm_cuda.cu

@@ -1,270 +1,643 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <vector>
+#pragma once
 #include <iostream>
+#include <vector>
 
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
-#include <cuda_runtime.h>
 #include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <cuda_runtime.h>
 
-#include "THC/THC.h"
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
 
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/gemm/wmma_gemm_traits.h"
 
-// symbol to be automatically resolved by PyTorch libs
-extern THCState *state;
-
+namespace {
 cublasOperation_t convertTransToCublasOperation(char trans) {
-  if (trans == 't') return CUBLAS_OP_T;
-  else if (trans == 'n') return CUBLAS_OP_N;
-  else if (trans == 'c') return CUBLAS_OP_C;
+  if (trans == 't')
+    return CUBLAS_OP_T;
+  else if (trans == 'n')
+    return CUBLAS_OP_N;
+  else if (trans == 'c')
+    return CUBLAS_OP_C;
   else {
-    THError("trans must be one of: t, n, c");
+    AT_ERROR("trans must be one of: t, n, c");
     return CUBLAS_OP_T;
   }
 }
 
-void CublasGemm(THCState *state, char transa, char transb, long m, long n, long k,
-                    float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-                    float beta, half *c, long ldc, long strideC, long batchCount) {
-    cublasOperation_t opa = convertTransToCublasOperation(transa);
-    cublasOperation_t opb = convertTransToCublasOperation(transb);
- 
-    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-    //cublasSetStream(handle, THCState_getCurrentStream(state));
-    float fAlpha = alpha;
-    float fBeta = beta;
-    THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-    THCublasCheck(cublasGemmStridedBatchedEx(handle,
-                                     opa, opb, (int)m, (int)n, (int)k,
-                                     (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
-                                     b, CUDA_R_16F, (int)ldb, strideB,
-                                     (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
-                                     (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-    THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+void CublasStridedBatchedGemm(
+    char transa, char transb, long m, long n, long k,
+    float alpha, const half *a, long lda, long strideA, const half *b, long ldb,
+    long strideB, float beta, half *c, long ldc, long strideC, long batchCount,
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
+  cublasOperation_t opa = convertTransToCublasOperation(transa);
+  cublasOperation_t opb = convertTransToCublasOperation(transb);
+
+  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  cublasSetStream(handle, stream);
+  float fAlpha = alpha;
+  float fBeta = beta;
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
+      handle, opa, opb, (int)m, (int)n, (int)k, (void *)&fAlpha, a, CUDA_R_16F,
+      (int)lda, strideA, b, CUDA_R_16F, (int)ldb, strideB, (void *)&fBeta, c,
+      CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo));
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 }
+} // namespace
 
-template<cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
+template <cutlass::MatrixLayout::Kind A_LAYOUT,
+          cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
 void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
-                          float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-                          float beta, half *c, long ldc, long strideC, long batchCount) {
-  //printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
+                           float alpha, const half *a, long lda, long strideA,
+                           const half *b, long ldb, long strideB, float beta,
+                           half *c, long ldc, long strideC, long batchCount) {
+  // printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC:
+  // %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n",
+  // ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k,
+  // SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
   typedef cutlass::gemm::WmmaGemmTraits<
-    A_LAYOUT,
-    B_LAYOUT,
-    cutlass::Shape<32, 16, 16>,
-    half,
-    half,
-    half,
-    cutlass::gemm::LinearScaling<float>,
-    float,
-    typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
-    typename cutlass::Shape<16, 16, 16>,
-    SRC_A,   //kScalarsPerLdgA_
-    SRC_B,   //kScalarsPerLdgB_
-    SRC_A,   //KScalarsPerLdsA_
-    SRC_B,   //KScalarsPerLdsB_
-    DST_C,   //kScalarsPerLdgCAndStgD_
-    DST_C/2, //kScalarsPerStsD_
-    DST_C/2  //kScalarsPerLdsD_
-  >
-    WmmaGemmTraits;
+      A_LAYOUT, B_LAYOUT, cutlass::Shape<32, 16, 16>, half, half, half,
+      cutlass::gemm::LinearScaling<float>, float,
+      typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<
+          typename cutlass::Shape<32, 16, 16>>::Shape,
+      typename cutlass::Shape<16, 16, 16>,
+      SRC_A,     // kScalarsPerLdgA_
+      SRC_B,     // kScalarsPerLdgB_
+      SRC_A,     // KScalarsPerLdsA_
+      SRC_B,     // KScalarsPerLdsB_
+      DST_C,     // kScalarsPerLdgCAndStgD_
+      DST_C / 2, // kScalarsPerStsD_
+      DST_C / 2  // kScalarsPerLdsD_
+      >
+      WmmaGemmTraits;
 
   typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
   typename Gemm::Params params;
 
-
   int result = params.initialize(
-    m,                  // M dimension for each batch
-    n,                  // N dimension for each batch
-    k,                  // K dimension for each batch
-    alpha,              // scalar alpha
-    a,
-    lda,
-    strideA,     // distance in memory between the first element of neighboring batch
-    b,
-    ldb,
-    strideB,     // distance in memory between the first element of neighboring batch
-    beta,               // scalar beta
-    c,                  // source matrix C
-    ldc,
-    strideC,     // distance in memory between the first element of neighboring batch
-    c,                  // destination matrix C (may be different memory than source C matrix)
-    ldc,
-    strideC,    // distance in memory between the first element of neighboring batch
-    batchCount
-  );
+      m,     // M dimension for each batch
+      n,     // N dimension for each batch
+      k,     // K dimension for each batch
+      alpha, // scalar alpha
+      a, lda,
+      strideA, // distance in memory between the first element of neighboring
+               // batch
+      b, ldb,
+      strideB, // distance in memory between the first element of neighboring
+               // batch
+      beta,    // scalar beta
+      c,       // source matrix C
+      ldc,
+      strideC, // distance in memory between the first element of neighboring
+               // batch
+      c, // destination matrix C (may be different memory than source C matrix)
+      ldc,
+      strideC, // distance in memory between the first element of neighboring
+               // batch
+      batchCount);
 
   AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
-  
-  // Launch the CUTLASS GEMM kernel.
-  THCudaCheck(Gemm::launch(params));
 
+  // batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is
+  // limited to 16 bits. To implement batched GEMM with larger batch size, we
+  // fragment it into smaller batched GEMMs of gridDim.z <= 64k
+  long batchesLeft = batchCount;
+  long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
+
+  do {
+    // printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC:
+    // %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f
+    // TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'),
+    // ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb,
+    // ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
+    int result =
+        params.initialize(m,     // M dimension for each batch
+                          n,     // N dimension for each batch
+                          k,     // K dimension for each batch
+                          alpha, // scalar alpha
+                          a, lda,
+                          strideA, // distance in memory between the first
+                                   // element of neighboring batch
+                          b, ldb,
+                          strideB, // distance in memory between the first
+                                   // element of neighboring batch
+                          beta,    // scalar beta
+                          c,       // source matrix C
+                          ldc,
+                          strideC, // distance in memory between the first
+                                   // element of neighboring batch
+                          c, // destination matrix C (may be different memory
+                             // than source C matrix)
+                          ldc,
+                          strideC, // distance in memory between the first
+                                   // element of neighboring batch
+                          iterBatchCount);
+
+    AT_ASSERTM(result == 0,
+               "Failed to initialize CUTLASS Gemm::Params object.");
+    // Launch the CUTLASS GEMM kernel.
+    C10_CUDA_CHECK(Gemm::launch(params, stream));
+
+    // Update batched GEMM params based on completed work
+    batchesLeft = batchesLeft - iterBatchCount;
+    a += iterBatchCount * strideA;
+    b += iterBatchCount * strideB;
+    c += iterBatchCount * strideC;
+    ;
+
+    iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
+
+  } while (batchesLeft > 0);
 }
 
-void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
-                           float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-                           float beta, half *c, long ldc, long strideC, long batchCount) {
-  //cudaStream_t stream = THCState_getCurrentStream(state);
-  //printf("GEMM   -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
+namespace {
+void gemm_switch_fp32accum(char transa, char transb, long m,
+                           long n, long k, float alpha, const half *a, long lda,
+                           long strideA, const half *b, long ldb, long strideB,
+                           float beta, half *c, long ldc, long strideC,
+                           long batchCount) {
   auto stream = c10::cuda::getCurrentCUDAStream();
-  if        ( (transa == 't') && (transb == 'n') ) { 
-    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else                                                   { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-  } else if ( (transa == 'n') && (transb == 'n') ) {
-    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else                                                   { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-  } else if ( (transa == 'n') && (transb == 't') ) {
-    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
-    else                                                   { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
+  // printf("GEMM   -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa ==
+  // 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
+  if ((transa == 't') && (transb == 'n')) {
+    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                               strideA, b, ldb, strideB, beta, c, ldc, strideC,
+                               batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
+    }
+    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 8, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 8, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else {
+      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                               strideA, b, ldb, strideB, beta, c, ldc, strideC,
+                               batchCount);
+    }
+  } else if ((transa == 'n') && (transb == 'n')) {
+    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                               strideA, b, ldb, strideB, beta, c, ldc, strideC,
+                               batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
+    }
+    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 8, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 8, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 4, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 8, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kColumnMajor, 2, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else {
+      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                               strideA, b, ldb, strideB, beta, c, ldc, strideC,
+                               batchCount);
+    }
+  } else if ((transa == 'n') && (transb == 't')) {
+    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                               strideA, b, ldb, strideB, beta, c, ldc, strideC,
+                               batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
+    }
+    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 8, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 8, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 4, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 8, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 8, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 8, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 4, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 4, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 4, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 2, 8>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 2, 4>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) {
+      CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,
+                            cutlass::MatrixLayout::kRowMajor, 2, 2, 2>(
+          stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c,
+          ldc, strideC, batchCount);
+    } else {
+      CublasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda,
+                               strideA, b, ldb, strideB, beta, c, ldc, strideC,
+                               batchCount);
+    }
   } else {
     AT_ASSERTM(false, "TransA and TransB are invalid");
   }
 }
 
-void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc)
-{
+void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
+                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
   int transa_ = ((transa == 't') || (transa == 'T'));
   int transb_ = ((transb == 't') || (transb == 'T'));
 
-  // Note: leading dimensions generally are checked that they are > 0 and at least as big the result
-  // requires (even if the value won't be used).
-  if(n <= 1)
+  // Note: leading dimensions generally are checked that they are > 0 and at
+  // least as big the result requires (even if the value won't be used).
+  if (n <= 1)
     *ldc = std::max<int64_t>(m, 1);
 
-  if(transa_)
-  {
-    if(m <= 1)
+  if (transa_) {
+    if (m <= 1)
       *lda = std::max<int64_t>(k, 1);
-  }
-  else
-  {
-    if(k <= 1)
+  } else {
+    if (k <= 1)
       *lda = std::max<int64_t>(m, 1);
   }
 
-  if(transb_)
-  {
-    if(k <= 1)
+  if (transb_) {
+    if (k <= 1)
       *ldb = std::max<int64_t>(n, 1);
-  }
-  else
-  {
-    if(n <= 1)
+  } else {
+    if (n <= 1)
       *ldb = std::max<int64_t>(k, 1);
   }
-
 }
 
-void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k,
-                             float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-                             float beta, half *c, long ldc, long strideC, long batchCount)
-{
-  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX)  || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
+void HgemmStridedBatched(char transa, char transb, long m,
+                         long n, long k, float alpha, const half *a, long lda,
+                         long strideA, const half *b, long ldb, long strideB,
+                         float beta, half *c, long ldc, long strideC,
+                         long batchCount) {
+  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
+      (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX))
 
   {
-    THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
-            "with the bound [val] <= %d", INT_MAX);
+    AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, "
+             "batchCount"
+             "with the bound [val] <= %d",
+             INT_MAX);
   }
 
   adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
 
-  //gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
-  gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
+  gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
+                        b, ldb, strideB, beta, c, ldc, strideC, batchCount);
 }
 
+} // namespace
+
 at::Tensor strided_batched_gemm_cuda(
     float beta,
     at::Tensor in_result,
@@ -326,7 +699,6 @@ at::Tensor strided_batched_gemm_cuda(
   int64_t num_batches = result.size(0);
 
   HgemmStridedBatched(
-      state,
       transpose_batch1,
       transpose_batch2,
       result.size(transpose_result ? 2 : 1),
@@ -341,5 +713,3 @@ at::Tensor strided_batched_gemm_cuda(
 
   return in_result;
 }
-
-

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/optim/adam.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/optim/fairseq_optimizer.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 1 - 24
PyTorch/Translation/Transformer/fairseq/options.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -34,7 +34,6 @@ from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
 def get_training_parser():
     parser = get_parser('Trainer')
     add_dataset_args(parser, train=True, gen=True)
-    add_distributed_training_args(parser)
     add_model_args(parser)
     add_optimization_args(parser)
     add_checkpoint_args(parser)
@@ -161,28 +160,6 @@ def add_dataset_args(parser, train=False, gen=False):
                            help='id of the shard to generate (id < num_shards)')
     return group
 
-
-def add_distributed_training_args(parser):
-    group = parser.add_argument_group('Distributed training')
-    group.add_argument('--distributed-world-size', type=int, metavar='N',
-                       default=torch.cuda.device_count(),
-                       help='total number of GPUs across all nodes (default: all visible GPUs)')
-    group.add_argument('--distributed-rank', default=os.getenv('LOCAL_RANK', 0), type=int,
-                       help='rank of the current worker')
-    group.add_argument('--local_rank', default=0, type=int,
-                       help='rank of the current worker')
-    group.add_argument('--distributed-backend', default='nccl', type=str,
-                       help='distributed backend')
-    group.add_argument('--distributed-init-method', default=None, type=str,
-                       help='typically tcp://hostname:port that will be used to '
-                            'establish initial connetion')
-    group.add_argument('--distributed-port', default=-1, type=int,
-                       help='port number (not required if using --distributed-init-method)')
-    group.add_argument('--device-id', default=0, type=int,
-                       help='which GPU to use (usually configured automatically)')
-    return group
-
-
 def add_optimization_args(parser):
     group = parser.add_argument_group('Optimization')
     group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',

+ 16 - 30
PyTorch/Translation/Transformer/fairseq/sequence_generator.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -315,11 +315,7 @@ class SequenceGenerator(object):
                     nonpad_idxs = src_tokens.ne(self.pad)
                 attn[:, :, step + 1].copy_(avg_attn_scores)
 
-            cand_scores = buffer('cand_scores', type_of=scores)
-            cand_indices = buffer('cand_indices')
             cand_beams = buffer('cand_beams')
-            eos_bbsz_idx = buffer('eos_bbsz_idx')
-            eos_scores = buffer('eos_scores', type_of=scores)
             if step < maxlen:
                 if prefix_tokens is not None and step < prefix_tokens.size(1):
                     probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
@@ -336,23 +332,23 @@ class SequenceGenerator(object):
                         values, indices = probs[:, 2:].topk(self.sampling_topk)
                         exp_probs = values.div_(self.sampling_temperature).exp()
                         if step == 0:
-                            torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
+                            cand_indices = torch.multinomial(exp_probs, beam_size, replacement=True)
                         else:
-                            torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
-                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
-                        torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
+                            cand_indices = torch.multinomial(exp_probs, 1, replacement=True)
+                        cand_scores = torch.gather(exp_probs, dim=1, index=cand_indices)
+                        cand_indices = torch.gather(indices, dim=1, index=cand_indices)
                         cand_indices.add_(2)
                     else:
                         exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
 
                         if step == 0:
                             # we exclude the first two vocab items, one of which is pad
-                            torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
+                            cand_indices = torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True)
                         else:
-                            torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
+                            cand_indices = torch.multinomial(exp_probs[:, 2:], 1, replacement=True)
 
                         cand_indices.add_(2)
-                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
+                        cand_scores = torch.gather(exp_probs, dim=1, index=cand_indices)
 
                     cand_scores.log_()
                     cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
@@ -371,20 +367,18 @@ class SequenceGenerator(object):
                 else:
                     # take the best 2 x beam_size predictions. We'll choose the first
                     # beam_size of these which don't predict eos to continue with.
-                    torch.topk(
+                    cand_scores, cand_indices = torch.topk(
                         probs.view(bsz, -1),
                         k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
-                        out=(cand_scores, cand_indices),
                     )
-                    torch.div(cand_indices, self.vocab_size, out=cand_beams, rounding_mode='trunc')
+                    cand_beams = torch.div(cand_indices, self.vocab_size, rounding_mode='trunc')
                     cand_indices.fmod_(self.vocab_size)
             else:
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now
-                torch.sort(
+                eos_scores, eos_bbsz_idx = torch.sort(
                     probs[:, self.eos],
                     descending=True,
-                    out=(eos_scores, eos_bbsz_idx),
                 )
                 num_remaining_sent -= len(finalize_hypos(
                     step, eos_bbsz_idx, eos_scores))
@@ -402,16 +396,14 @@ class SequenceGenerator(object):
             finalized_sents = set()
             if step >= self.minlen:
                 # only consider eos when it's among the top beam_size indices
-                torch.masked_select(
+                eos_bbsz_idx = torch.masked_select(
                     cand_bbsz_idx[:, :beam_size],
                     mask=eos_mask[:, :beam_size],
-                    out=eos_bbsz_idx,
                 )
                 if eos_bbsz_idx.numel() > 0:
-                    torch.masked_select(
+                    eos_scores = torch.masked_select(
                         cand_scores[:, :beam_size],
                         mask=eos_mask[:, :beam_size],
-                        out=eos_scores,
                     )
                     finalized_sents = finalize_hypos(
                         step, eos_bbsz_idx, eos_scores, cand_scores)
@@ -454,24 +446,18 @@ class SequenceGenerator(object):
             # set active_mask so that values > cand_size indicate eos hypos
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
-            active_mask = buffer('active_mask')
-            torch.add(
+            active_mask = torch.add(
                 eos_mask.type_as(cand_offsets) * cand_size,
                 cand_offsets[:eos_mask.size(1)],
-                out=active_mask,
             )
 
             # get the top beam_size active hypotheses, which are just the hypos
             # with the smallest values in active_mask
-            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
-            torch.topk(
+            _ignore, active_hypos = torch.topk(
                 active_mask, k=beam_size, dim=1, largest=False,
-                out=(_ignore, active_hypos)
             )
-            active_bbsz_idx = buffer('active_bbsz_idx')
-            torch.gather(
+            active_bbsz_idx = torch.gather(
                 cand_bbsz_idx, dim=1, index=active_hypos,
-                out=active_bbsz_idx,
             )
             active_scores = torch.gather(
                 cand_scores, dim=1, index=active_hypos,

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/tokenizer.py

@@ -7,7 +7,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 1 - 1
PyTorch/Translation/Transformer/fairseq/utils.py

@@ -7,7 +7,7 @@
 #
 #--------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 3 - 8
PyTorch/Translation/Transformer/inference.py

@@ -8,7 +8,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -118,12 +118,8 @@ def setup_logger(args):
             dllogger.log(step='PARAMETER', data={k:v}, verbosity=0)
         container_setup_info = log_helper.get_framework_env_vars()
         dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)
-        dllogger.metadata('throughput', {'unit':'tokens/s', 'format':':/3f', 'GOAL':'MAXIMIZE', 'STAGE':'INFER'})
-        dllogger.metadata('latency_avg', {'unit':'s', 'format':':/3f', 'GOAL':'MINIMIZE', 'STAGE':'INFER'})
-        dllogger.metadata('latency_p90', {'unit':'s', 'format':':/3f', 'GOAL':'MINIMIZE', 'STAGE':'INFER'})
-        dllogger.metadata('latency_p95', {'unit':'s', 'format':':/3f', 'GOAL':'MINIMIZE', 'STAGE':'INFER'})
-        dllogger.metadata('latency_p99', {'unit':'s', 'format':':/3f', 'GOAL':'MINIMIZE', 'STAGE':'INFER'})
-        dllogger.metadata('total_infernece_time', {'unit':'s', 'format':':/3f', 'GOAL':'MINIMIZE', 'STAGE':'INFER'})
+        dllogger.metadata('throughput',
+                          {'unit':'tokens/s', 'format':':/3f', 'GOAL':'MAXIMIZE', 'STAGE':'INFER'})
     else:
         dllogger.init(backends=[])
 
@@ -260,7 +256,6 @@ def main(args):
             for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
                 print(f'Score {hypo[0]}', file=sys.stderr)
                 print(hypo[1])
-                print(pos_scores, file=sys.stderr)
                 if align is not None:
                     print(align, file=sys.stderr)
 

+ 0 - 123
PyTorch/Translation/Transformer/scripts/deployer.py

@@ -1,123 +0,0 @@
-#!/usr/bin/python
-
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, 
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and
-# limitations under the License. 
-
-
-import sys
-import torch
-import argparse
-import deployer_lib
-# 
-import torch
-from fairseq import data
-from fairseq.data import load_dataset_splits, data_utils
-from fairseq.models.transformer import TransformerModel
-from copy import deepcopy
-
-def get_model_and_args(model_args):
-    ''' the arguments initialize_model will receive '''
-    parser = argparse.ArgumentParser()
-    ## Required parameters by the model. 
-    parser.add_argument("--checkpoint", 
-                        default=None, 
-                        type=str, 
-                        required=True, 
-                        help="The checkpoint of the model. ")
-    parser.add_argument('--batch-size', 
-                        default=10240, 
-                        type=int, 
-                        help='Batch size for inference')
-    parser.add_argument('--num-batches',
-                        default=2,
-                        type=int,
-                        help='Number of batches to check accuracy on')
-    parser.add_argument("--data",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="Path to the dataset")
-    parser.add_argument('--part',
-                        choices=['encoder', 'decoder', 'model'],
-                        default='model',
-                        type=str,
-                        help='Choose the part of the model to export')
-
-    args = parser.parse_args(model_args)
-
-    state_dict = torch.load(args.checkpoint, map_location='cpu')
-
-    model_args = state_dict['args']
-    model_args.data = args.data
-    model_args.num_batches = args.num_batches
-    model_args.max_tokens = args.batch_size
-    model_args.fuse_layer_norm = False
-    model_args.part = args.part
-
-    model = TransformerModel.build_model(model_args)
-    model.load_state_dict(state_dict['model'], strict=True)
-    model.make_generation_fast_(need_attn=False)
-
-    return model, model_args
-
-def get_dataloader(args, encoder=None):
-    ''' return dataloader for inference '''
-    assert not(args.part == 'decoder' and encoder is None), "Cannot export decoder without providing encoder"
-    src_dict, tgt_dict = data_utils.load_dictionaries(args)
-    datasets = load_dataset_splits(args, ['valid'], src_dict, tgt_dict)
-    itr = data.EpochBatchIterator(
-        dataset=datasets['valid'],
-        max_tokens=args.max_tokens,
-        max_positions=args.max_positions,
-    ).next_epoch_itr(shuffle=False)
-
-    def input_itr():
-        for batch in itr:
-            if itr.count > args.num_batches:
-                break
-            ni = batch['net_input']
-            if args.part == 'decoder': #this part works only on GPU
-                with torch.no_grad():
-                    encoder_out = encoder(ni['src_tokens'].cuda(), ni['src_lengths'].cuda()) 
-                yield ni['prev_output_tokens'], encoder_out[0], encoder_out[1]
-            elif args.part == 'encoder':
-                yield ni['src_tokens'], ni['src_lengths']
-            else:
-                yield ni['src_tokens'], ni['src_lengths'], ni['prev_output_tokens']
-
-    return input_itr()
-
-
-if __name__=='__main__':
-    # don't touch this! 
-    deployer, model_argv = deployer_lib.create_deployer(sys.argv[1:]) # deployer and returns removed deployer arguments
-    
-    model, model_args = get_model_and_args(model_argv)
-
-    if model_args.part == 'decoder':
-        encoder = model.encoder
-        encoder.embed_tokens = deepcopy(encoder.embed_tokens)
-        encoder.cuda()
-    else:
-        encoder = None
-    
-    dataloader = get_dataloader(model_args, encoder=encoder)
-
-    if model_args.part == 'encoder':
-        model = model.encoder
-    elif model_args.part == 'decoder':
-        model = model.decoder
-    
-    deployer.deploy(dataloader, model)
-

+ 0 - 969
PyTorch/Translation/Transformer/scripts/deployer_lib.py

@@ -1,969 +0,0 @@
-#!/usr/bin/python
-
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, 
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and
-# limitations under the License. 
-
-
-import os
-import sys
-import time
-import json
-import torch
-import argparse
-import statistics
-from collections import Counter
-
-
-torch_type_to_triton_type = {
-    torch.bool:     'TYPE_BOOL', 
-    torch.int8:     'TYPE_INT8', 
-    torch.int16:    'TYPE_INT16', 
-    torch.int32:    'TYPE_INT32', 
-    torch.int64:    'TYPE_INT64', 
-    torch.uint8:    'TYPE_UINT8', 
-    torch.float16:  'TYPE_FP16', 
-    torch.float32:  'TYPE_FP32', 
-    torch.float64:  'TYPE_FP64'
-}
-
-
-CONFIG_TEMPLATE = r"""
-name: "{model_name}"
-platform: "{platform}"
-max_batch_size: {max_batch_size}
-input [
-    {spec_inputs}
-]
-output [
-    {spec_outputs}
-]
-{dynamic_batching}
-{model_optimizations}
-instance_group [
-    {{
-        count: {engine_count}
-        kind: KIND_GPU
-        gpus: [ {gpu_list} ]
-    }}
-]"""
-
-
-INPUT_TEMPLATE = r"""
-{{
-    name: "input__{num}"
-    data_type: {type}
-    dims: {dims}
-    {reshape}
-}},"""
-
-
-OUTPUT_TEMPLATE = r""" 
-{{
-    name: "output__{num}"
-    data_type: {type}
-    dims: {dims}
-    {reshape}
-}},"""
-
-
-MODEL_OPTIMIZATION_TEMPLATE = r"""
-optimization {{
-  {execution_accelerator}
-  cuda {{
-    graphs: {capture_cuda_graph}
-  }}
-}}"""
-
-
-EXECUTION_ACCELERATOR_TEMPLATE = r"""
-  execution_accelerators {{
-    gpu_execution_accelerator: [
-      {{
-        name: "tensorrt"
-      }}
-    ]
-  }},"""
-
-
-def remove_empty_lines(text):
-    ''' removes empty lines from text, returns the result '''
-    ret = "".join([s for s in text.strip().splitlines(True) if s.strip()])
-    return ret
-
-
-def create_deployer(argv):
-    ''' takes a list of arguments, returns a deployer object and the list of unused arguments '''
-    parser = argparse.ArgumentParser()
-    # required args
-    method = parser.add_mutually_exclusive_group(required=True)
-    method.add_argument('--ts-script',
-                        action='store_true',
-                        help='convert to torchscript using torch.jit.script')
-    method.add_argument('--ts-trace',
-                        action='store_true',
-                        help='convert to torchscript using torch.jit.trace')
-    method.add_argument('--onnx',
-                        action='store_true',
-                        help='convert to onnx using torch.onnx.export')
-    method.add_argument('--trt',
-                        action='store_true',
-                        help='convert to trt using tensorrt')
-    # triton related args
-    arguments = parser.add_argument_group('triton related flags')
-    arguments.add_argument('--triton-no-cuda',
-                            action='store_true',
-                            help='Use the CPU for tracing.')
-    arguments.add_argument('--triton-model-name',
-                            type=str,
-                            default="model",
-                            help="exports to appropriate directory structure for TRTIS")
-    arguments.add_argument("--triton-model-version",
-                            type=int,
-                            default=1,
-                            help="exports to appropriate directory structure for TRTIS")
-    arguments.add_argument("--triton-server-url",
-                            type=str,
-                            default="localhost:8001",
-                            help="exports to appropriate directory structure for TRTIS")
-    arguments.add_argument("--triton-max-batch-size",
-                            type=int,
-                            default=8,
-                            help="Specifies the 'max_batch_size' in the TRTIS model config.\
-                                  See the TRTIS documentation for more info.")
-    arguments.add_argument("--triton-dyn-batching-delay",
-                            type=float,
-                            default=0,
-                            help="Determines the dynamic_batching queue delay in milliseconds(ms) for\
-                                  the TRTIS model config. Use '0' or '-1' to specify static batching.\
-                                  See the TRTIS documentation for more info.")
-    arguments.add_argument("--triton-engine-count",
-                            type=int,
-                            default=1,
-                            help="Specifies the 'instance_group' count value in the TRTIS model config.\
-                                  See the TRTIS documentation for more info.")
-    arguments.add_argument('--save-dir', type=str, default='./triton_models', help='Saved model directory')
-    # optimization args
-    arguments = parser.add_argument_group('optimization flags')
-    arguments.add_argument("--max_workspace_size",
-                            type=int,
-                            default=512*1024*1024,
-                            help="set the size of the workspace for trt export")
-    arguments.add_argument("--trt-fp16",
-                            action='store_true',
-                            help="trt flag ---- export model in mixed precision mode")
-    arguments.add_argument("--capture-cuda-graph",
-                            type=int,
-                            default=1,
-                            help="capture cuda graph for obtaining speedup. possible values: 0, 1. default: 1. ")
-    arguments.add_argument('--quantize',
-                            action='store_true',
-                            help='apply quantization for supported nodes')
-    arguments.add_argument('--calibrate',
-                            action='store_true',
-                            help='apply calibration for supported nodes')
-    # remainder args
-    arguments.add_argument('model_arguments', nargs=argparse.REMAINDER, help='arguments that will be ignored by deployer lib and will be forwarded to your deployer script')
-    # 
-    args = parser.parse_args(argv)
-    deployer = Deployer(args)
-    # 
-    return deployer, args.model_arguments[1:]
-
-
-class DeployerLibrary:
-    def __init__(self, args):
-        self.args = args
-        self.platform = None
-    
-    def set_platform(self, platform):
-        ''' sets the platform
-            :: platform :: "pytorch_libtorch" or "onnxruntime_onnx" or "tensorrt_plan"
-        '''
-        self.platform = platform
-    
-    def build_trt_engine(self, model_file, shapes):
-        ''' takes a path to an onnx file, and shape information, returns a trt engine
-            :: model_file :: path to an onnx model
-            :: shapes :: dictionary containing min shape, max shape, opt shape for the trt engine
-        '''
-        import tensorrt as trt
-        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
-        builder = trt.Builder(TRT_LOGGER)
-        builder.fp16_mode = self.args.trt_fp16
-        builder.max_batch_size = self.args.triton_max_batch_size
-        # 
-        config = builder.create_builder_config()
-        config.max_workspace_size = self.args.max_workspace_size
-        if self.args.trt_fp16:
-            config.flags |= 1 << int(trt.BuilderFlag.FP16)
-        profile = builder.create_optimization_profile()
-        for s in shapes:
-            profile.set_shape(s['name'], min=s['min'], opt=s['opt'], max=s['max'])
-        config.add_optimization_profile(profile)
-        explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-        network = builder.create_network(explicit_batch)
-        # 
-        with trt.OnnxParser(network, TRT_LOGGER) as parser:
-            with open(model_file, 'rb') as model:
-                parser.parse(model.read())
-                for i in range(parser.num_errors):
-                    e = parser.get_error(i)
-                    print("||||e", e)
-                engine = builder.build_engine(network, config=config)
-        return engine
-    
-    def load_engine(self, engine_filepath):
-        ''' loads a trt engine from engine_filepath, returns it '''
-        import tensorrt as trt
-        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
-        with open(engine_filepath, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
-            engine = runtime.deserialize_cuda_engine(f.read())
-        return engine
-    
-    def prepare_inputs(self, dataloader, device):
-        ''' load sample inputs to device '''
-        def _move_to_device(maybe_tensor):
-            if torch.is_tensor(maybe_tensor):
-                return maybe_tensor.to(device)
-            elif isinstance(maybe_tensor, dict):
-                return {
-                    key: _move_to_device(value)
-                    for key, value in maybe_tensor.items()
-                }
-            elif isinstance(maybe_tensor, list) or isinstance(maybe_tensor, tuple):
-                return [_move_to_device(x) for x in maybe_tensor]
-            else:
-                return maybe_tensor
-
-        inputs = []
-        for batch in dataloader:
-            batch_d = _move_to_device(batch)
-            if not hasattr(batch_d, '__iter__'):
-                batch_d = (batch_d,)
-            inputs.append(batch_d)
-
-        return inputs
-    
-    def get_list_of_shapes(self, l, fun):
-        ''' returns the list of min/max shapes, depending on fun
-            :: l :: list of tuples of tensors
-            :: fun :: min or max
-        '''
-        tensor_tuple = l[0]
-        shapes = [list(x.shape) for x in tensor_tuple]
-        for tensor_tuple in l:
-            assert len(tensor_tuple) == len(shapes), "tensors with varying shape lengths are not supported"
-            for i,x in enumerate(tensor_tuple):
-                for j in range(len(x.shape)):
-                    shapes[i][j] = fun(shapes[i][j], x.shape[j])
-        return shapes # a list of shapes
-    
-    def get_tuple_of_min_shapes(self, l):
-        ''' returns the tuple of min shapes 
-            :: l :: list of tuples of tensors '''
-        shapes = self.get_list_of_shapes(l, min)
-        min_batch = 1
-        shapes = [[min_batch,*shape[1:]] for shape in shapes]
-        shapes = tuple(shapes)
-        return shapes # tuple of min shapes
-    
-    def get_tuple_of_max_shapes(self, l):
-        ''' returns the tuple of max shapes 
-            :: l :: list of tuples of tensors '''
-        shapes = self.get_list_of_shapes(l, max)
-        max_batch = max(2,shapes[0][0])
-        shapes = [[max_batch,*shape[1:]] for shape in shapes]
-        shapes = tuple(shapes)
-        return shapes # tuple of max shapes
-    
-    def get_tuple_of_opt_shapes(self, l):
-        ''' returns the tuple of opt shapes 
-            :: l :: list of tuples of tensors '''
-        counter = Counter()
-        for tensor_tuple in l:
-            shapes = [tuple(x.shape) for x in tensor_tuple]
-            shapes = tuple(shapes)
-            counter[shapes] += 1
-        shapes = counter.most_common(1)[0][0]
-        return shapes # tuple of most common occuring shapes
-    
-    def get_tuple_of_dynamic_shapes(self, l):
-        ''' returns a tuple of dynamic shapes: variable tensor dimensions 
-            (for ex. batch size) occur as -1 in the tuple
-            :: l :: list of tuples of tensors '''
-        tensor_tuple = l[0]
-        shapes = [list(x.shape) for x in tensor_tuple]
-        for tensor_tuple in l:
-            err_msg = "tensors with varying shape lengths are not supported"
-            assert len(tensor_tuple) == len(shapes), err_msg
-            for i,x in enumerate(tensor_tuple):
-                for j in range(len(x.shape)):
-                    if shapes[i][j] != x.shape[j] or j == 0:
-                        shapes[i][j] = -1
-        shapes = tuple(shapes)
-        return shapes # tuple of dynamic shapes
-    
-    def run_models(self, models, inputs):
-        ''' run the models on inputs, return the outputs and execution times '''
-        ret = []
-        for model in models:
-            torch.cuda.synchronize()
-            time_start = time.time()
-            outputs = []
-            for input in inputs:
-                with torch.no_grad():
-                    output = model(*input)
-                if type(output) is torch.Tensor:
-                    output = [output]
-                elif type(output) is dict:
-                    output = list(output.items())
-                    output.sort(key=lambda x: x[0])
-                    output = [x[0] for x in output]
-                outputs.append(output)
-            torch.cuda.synchronize()
-            time_end = time.time()
-            t = time_end - time_start
-            ret.append(outputs)
-            ret.append(t)
-        return ret
-
-    def compute_tensor_stats(self, tensor):
-        #if tensor is not empty
-        if tensor.numel():
-            return {'std': tensor.std().item(),
-                    'mean': tensor.mean().item(),
-                    'max': tensor.max().item(),
-                    'min': tensor.min().item(),
-            }
-        else:
-            return {'std': 0,
-                    'mean':0,
-                    'max': 0,
-                    'min': 0,
-            }
-
-    def compute_errors(self, outputs_A, outputs_B):
-        ''' returns dictionary with errors statistics '''
-        device = outputs_A[0][0][0].device
-        dtype = outputs_A[0][0][0].dtype
-        num_outputs = len(outputs_A[0])
-        x_values = [torch.zeros(0, device = device, dtype = dtype) for _ in range(num_outputs)]
-        y_values = [torch.zeros(0, device = device, dtype = dtype) for _ in range(num_outputs)]
-        d_values = [torch.zeros(0, device = device, dtype = dtype) for _ in range(num_outputs)]
-        for output_A,output_B in zip(outputs_A,outputs_B):
-            for i,(x,y) in enumerate(zip(output_A, output_B)):
-                x = x.view(-1).float()
-                y = y.view(-1).float()
-                d = abs(x - y)
-                x_values[i] = torch.cat((x_values[i], x), 0)
-                y_values[i] = torch.cat((y_values[i], y), 0)
-                d_values[i] = torch.cat((d_values[i], d), 0)
-        Error_stats = [{'Original': self.compute_tensor_stats(x),
-                       'Converted': self.compute_tensor_stats(y),
-                       'Absolute difference': self.compute_tensor_stats(d),
-                           } for x,y,z in zip(x_values, y_values, d_values)]
-        return Error_stats
-    
-    def print_errors(self, Error_stats):
-        ''' print various statistcs of Linf errors '''
-        print()
-        print("conversion correctness test results")
-        print("-----------------------------------")
-        import pandas as pd
-        for i,e in enumerate(Error_stats):
-            print(f'Output {i}:')
-            print(pd.DataFrame(e))
-    
-    def write_config(self, config_filename, 
-                     input_shapes, input_types, 
-                     output_shapes, output_types):
-        ''' writes TRTIS config file 
-            :: config_filename :: the file to write the config file into
-            :: input_shapes :: tuple of dynamic shapes of the input tensors
-            :: input_types :: tuple of torch types of the input tensors
-            :: output_shapes :: tuple of dynamic shapes of the output tensors
-            :: output_types :: tuple of torch types of the output tensors
-        '''
-        assert self.platform is not None, "error - platform is not set"
-        
-        config_template = CONFIG_TEMPLATE
-        input_template = INPUT_TEMPLATE
-        optimization_template = MODEL_OPTIMIZATION_TEMPLATE
-        accelerator_template = EXECUTION_ACCELERATOR_TEMPLATE
-        
-        spec_inputs = r""""""
-        for i,(shape,typ) in enumerate(zip(input_shapes,input_types)):
-            d = {
-                'num' : str(i), 
-                'type': torch_type_to_triton_type[typ],
-                'dims': str([1]) if len(shape) == 1 else str(list(shape)[1:]) # first dimension is the batch size 
-            }
-            d['reshape'] = 'reshape: { shape: [ ] }' if len(shape) == 1 else ''
-            spec_inputs += input_template.format_map(d)
-        spec_inputs = spec_inputs[:-1]
-        
-        output_template = OUTPUT_TEMPLATE
-        spec_outputs = r""""""
-        for i,(shape,typ) in enumerate(zip(output_shapes,output_types)):
-            d = {
-                'num' : str(i), 
-                'type': torch_type_to_triton_type[typ],
-                'dims': str([1]) if len(shape) == 1 else str(list(shape)[1:]) # first dimension is the batch size 
-            }
-            d['reshape'] = 'reshape: { shape: [ ] }' if len(shape) == 1 else ''
-            spec_outputs += output_template.format_map(d)
-        spec_outputs = spec_outputs[:-1]
-        
-        batching_str = ""
-        max_batch_size = self.args.triton_max_batch_size
-        
-        if (self.args.triton_dyn_batching_delay > 0):
-            # Use only full and half full batches 
-            pref_batch_size = [int(max_batch_size / 2.0), max_batch_size]
-            
-            batching_str = r"""
-dynamic_batching {{
-    preferred_batch_size: [{0}]
-    max_queue_delay_microseconds: {1}
-}}""".format(", ".join([str(x) for x in pref_batch_size]), 
-                        int(self.args.triton_dyn_batching_delay * 1000.0))
-        
-        accelerator_str = ""
-        if self.platform == 'onnxruntime_onnx':
-            accelerator_str = accelerator_template.format_map({})
-        
-        d = {
-            "execution_accelerator":  accelerator_str, 
-            "capture_cuda_graph":     str(self.args.capture_cuda_graph)
-        }
-        optimization_str = optimization_template.format_map(d)
-        
-        config_values = {
-            "model_name":           self.args.triton_model_name, 
-            "platform":             self.platform, 
-            "max_batch_size":       max_batch_size, 
-            "spec_inputs":          spec_inputs, 
-            "spec_outputs":         spec_outputs, 
-            "dynamic_batching":     batching_str, 
-            "model_optimizations" : optimization_str, 
-            "gpu_list":         ", ".join([str(x) for x in range(torch.cuda.device_count())]), 
-            "engine_count":     self.args.triton_engine_count
-        }
-        
-        # write config 
-        with open(config_filename, "w") as file:
-            final_config_str = config_template.format_map(config_values)
-            final_config_str = remove_empty_lines(final_config_str)
-            file.write(final_config_str)
-
-
-class Deployer:
-    def __init__(self, args):
-        self.args = args
-        self.lib = DeployerLibrary(args)
-    
-    def deploy(self, dataloader, model):
-        ''' deploy the model and test for correctness with dataloader '''
-        if self.args.ts_script or self.args.ts_trace:
-            self.lib.set_platform("pytorch_libtorch")
-            print("deploying model " + self.args.triton_model_name + " in format " + self.lib.platform)
-            self.to_triton_torchscript(dataloader, model)
-        elif self.args.onnx:
-            self.lib.set_platform("onnxruntime_onnx")
-            print("deploying model " + self.args.triton_model_name + " in format " + self.lib.platform)
-            self.to_triton_onnx(dataloader, model)
-        elif self.args.trt:
-            self.lib.set_platform("tensorrt_plan")
-            print("deploying model " + self.args.triton_model_name + " in format " + self.lib.platform)
-            self.to_triton_trt(dataloader, model)
-        else:
-            assert False, "error"
-        print("done")
-    
-    def to_triton_trt(self, dataloader, model):
-        ''' export the model to trt and test correctness on dataloader '''
-        import tensorrt as trt
-        # setup device
-        if self.args.triton_no_cuda:
-            device = torch.device('cpu')
-        else:
-            device = torch.device('cuda')
-
-        assert not self.args.quantize, 'quantize flag not supported by trt'
-        assert not self.args.calibrate, 'calibrate flag not supported by trt'
-
-        # prepare model 
-        model.to(device)
-        model.eval()
-        assert not model.training, "internal error - model should be in eval() mode! "
-        
-        # prepare inputs
-        inputs = self.lib.prepare_inputs(dataloader, device)
-        
-        # generate outputs
-        outputs = []
-        for input in inputs:
-            with torch.no_grad():
-                output = model(*input)
-            if type(output) is torch.Tensor:
-                output = [output]
-            outputs.append(output)
-        
-        # generate input shapes - dynamic tensor shape support 
-        input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
-        
-        # generate output shapes - dynamic tensor shape support 
-        output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
-        
-        # generate input types 
-        input_types = [x.dtype for x in inputs[0]]
-        
-        # generate output types
-        output_types = [x.dtype for x in outputs[0]]
-        
-        # get input names
-        rng = range(len(input_types))
-        input_names = ["input__" + str(num) for num in rng]
-        
-        # get output names
-        rng = range(len(output_types))
-        output_names = ["output__" + str(num) for num in rng]
-        
-        # prepare save path
-        model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
-        version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
-        if not os.path.exists(version_folder):
-            os.makedirs(version_folder)
-        final_model_path = os.path.join(version_folder, 'model.plan')
-        
-        # get indices of dynamic input and output shapes
-        dynamic_axes = {}
-        for input_name,shape in zip(input_names,input_shapes):
-            dynamic_axes[input_name] = [i for i,x in enumerate(shape) if x == -1]
-        for output_name,shape in zip(output_names,output_shapes):
-            dynamic_axes[output_name] = [i for i,x in enumerate(shape) if x == -1]
-        
-        # export the model to onnx first
-        with torch.no_grad():
-            torch.onnx.export(model, inputs[0], final_model_path, verbose=False, 
-                              input_names=input_names, output_names=output_names, 
-                              dynamic_axes=dynamic_axes, opset_version=11)
-      
-        # get shapes
-        min_shapes = self.lib.get_tuple_of_min_shapes(inputs)
-        opt_shapes = self.lib.get_tuple_of_opt_shapes(inputs)
-        max_shapes = self.lib.get_tuple_of_max_shapes(inputs)
-        
-        zipped = zip(input_names, min_shapes, opt_shapes, max_shapes)
-        shapes = []
-        for name,min_shape,opt_shape,max_shape in zipped:
-            d = {
-                 "name":name, 
-                 "min": min_shape, 
-                 "opt": opt_shape, 
-                 "max": max_shape
-                }
-            shapes.append(d)
-        
-        # build trt engine
-        engine = self.lib.build_trt_engine(final_model_path, shapes)
-        assert engine is not None, " trt export failure "
-        
-        # write trt engine
-        with open(final_model_path, 'wb') as f:
-            f.write(engine.serialize())
-        
-        # load the model
-        engine = self.lib.load_engine(final_model_path)
-        
-        class TRT_model:
-            def __init__(self, engine, input_names, output_names, output_types, device):
-                self.engine = engine
-                self.context = self.engine.create_execution_context()
-                self.input_names = input_names
-                self.output_names = output_names
-                self.output_types = output_types
-                self.device = device
-            
-            def is_dimension_dynamic(self, dim):
-                return dim is None or dim <= 0
-            
-            def is_shape_dynamic(self, shape):
-                return any([self.is_dimension_dynamic(dim) for dim in shape])
-            
-            def __call__(self, *inputs):
-                # get input shapes
-                input_shapes = [x.shape for x in inputs]
-                # bindings
-                bindings = [None] * self.engine.num_bindings
-                # set input shapes, bind input tensors
-                zipped = zip(self.input_names, inputs)
-                for key,input in zipped:
-                    idx = self.engine.get_binding_index(key)
-                    bindings[idx] = input.data_ptr()
-                    if self.engine.is_shape_binding(idx) and self.is_shape_dynamic(self.context.get_shape(idx)):
-                        self.context.set_shape_input(idx, input)
-                    elif self.is_shape_dynamic(self.engine.get_binding_shape(idx)):
-                        self.context.set_binding_shape(idx, input.shape)
-                assert self.context.all_binding_shapes_specified, "trt error"
-                assert self.context.all_shape_inputs_specified, "trt error"
-                # calculate output shapes, allocate output tensors and bind them
-                outputs = []
-                zipped = zip(self.output_names, self.output_types)
-                for key,dtype in zipped:
-                    idx = self.engine.get_binding_index(key)
-                    shape = self.context.get_binding_shape(idx)
-                    shape = tuple(shape)
-                    assert -1 not in shape, "trt error"
-                    tensor = torch.zeros(shape, dtype=dtype, device=self.device)
-                    outputs.append(tensor)
-                    bindings[idx] = outputs[-1].data_ptr()
-                # run inference
-                self.context.execute_v2(bindings=bindings)
-                # return the result
-                if len(outputs) == 1:
-                    outputs = outputs[0]
-                return outputs
-        
-        model_trt = TRT_model(engine, input_names, output_names, output_types, device)
-        
-        # run both models on inputs
-        assert not model.training, "internal error - model should be in eval() mode! "
-        models = (model, model_trt)
-        outputs, time_model, outputs_trt, time_model_trt = self.lib.run_models(models, inputs)
-        
-        # check for errors
-        Error_stats = self.lib.compute_errors(outputs, outputs_trt)
-        self.lib.print_errors(Error_stats)
-        print('time of error check of native model: ', time_model, 'seconds')
-        print('time of error check of trt model: ', time_model_trt, 'seconds')
-        print()
-        
-        # write TRTIS config
-        config_filename = os.path.join(model_folder, "config.pbtxt")
-        self.lib.write_config(config_filename, 
-                              input_shapes, input_types, 
-                              output_shapes, output_types)
-    
-    def name_onnx_nodes(self, model_path):
-        '''
-        Name all unnamed nodes in ONNX model
-            parameter model_path: path  ONNX model
-            return: none
-        '''
-        model = onnx.load(model_path)
-        node_id = 0
-        for node in model.graph.node:
-            if len(node.name) == 0:
-                node.name = "unnamed_node_%d" % node_id
-            node_id += 1
-        # This check partially validates model
-        onnx.checker.check_model(model)
-        onnx.save(model, model_path)
-        # Only inference really checks ONNX model for some issues
-        # like duplicated node names
-        onnxruntime.InferenceSession(model_path, None)
-    
-    def to_triton_onnx(self, dataloader, model):
-        ''' export the model to onnx and test correctness on dataloader '''
-        import onnx as local_onnx
-        global onnx
-        onnx = local_onnx
-        import onnxruntime as local_onnxruntime
-        global onnxruntime
-        onnxruntime = local_onnxruntime
-        # setup device
-        if self.args.triton_no_cuda:
-            device = torch.device('cpu')
-        else:
-            device = torch.device('cuda')
-        
-        if self.args.calibrate:
-            assert self.args.quantize, ("calibrate flag not supported "
-                                        "without quantize")
-        if self.args.quantize:
-           try:
-               from quantize import quantize, QuantizationMode
-           except ImportError as error:
-               print('quantize scripts are not present')
-               raise error
-        
-        if self.args.calibrate:
-            try:
-                import calibrate
-            except ImportError as error:
-                print('calibrate scripts are not present')
-                raise error
-        
-        # prepare model 
-        model.to(device)
-        model.eval()
-        assert not model.training, "internal error - model should be in eval() mode! "
-        
-        # prepare inputs
-        inputs = self.lib.prepare_inputs(dataloader, device)
-        
-        # generate outputs
-        outputs = []
-        for input in inputs:
-            with torch.no_grad():
-                output = model(*input)
-            if type(output) is torch.Tensor:
-                output = [output]
-            outputs.append(output)
-        
-        # generate input shapes - dynamic tensor shape support 
-        input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
-        
-        # generate output shapes - dynamic tensor shape support 
-        output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
-        
-        # generate input types 
-        input_types = [x.dtype for x in inputs[0]]
-        
-        # generate output types
-        output_types = [x.dtype for x in outputs[0]]
-        
-        # get input names
-        rng = range(len(input_types))
-        input_names = ["input__" + str(num) for num in rng]
-        
-        # get output names
-        rng = range(len(output_types))
-        output_names = ["output__" + str(num) for num in rng]
-        
-        # prepare save path
-        model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
-        version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
-        if not os.path.exists(version_folder):
-            os.makedirs(version_folder)
-        final_model_path = os.path.join(version_folder, 'model.onnx')
-        
-        # get indices of dynamic input and output shapes
-        dynamic_axes = {}
-        for input_name,input_shape in zip(input_names,input_shapes):
-            dynamic_axes[input_name] = [i for i,x in enumerate(input_shape) if x == -1]
-        for output_name,output_shape in zip(output_names,output_shapes):
-            dynamic_axes[output_name] = [i for i,x in enumerate(output_shape) if x == -1]
-        
-        # export the model
-        assert not model.training, "internal error - model should be in eval() mode! "
-        with torch.no_grad():
-            torch.onnx.export(model, inputs[0], final_model_path, verbose=False, 
-                              input_names=input_names, output_names=output_names, 
-                              dynamic_axes=dynamic_axes, opset_version=11)
-        
-        # syntactic error check
-        converted_model = onnx.load(final_model_path)
-        # check that the IR is well formed
-        onnx.checker.check_model(converted_model)
-
-        # Name unnamed nodes - it helps for some other processing tools
-        self.name_onnx_nodes(final_model_path)
-        converted_model = onnx.load(final_model_path)
-        
-        # quantize model
-        if self.args.quantize:
-            if not self.args.calibrate:
-                quantized_model = quantize(
-                    converted_model,
-                    quantization_mode = QuantizationMode.IntegerOps,
-                )
-                # check that the IR is well formed
-                try:
-                    onnx.checker.check_model(quantized_model)
-                except onnx.onnx_cpp2py_export.checker.ValidationError as error:
-                    # FIXME: It is unclear, why checker fails for quantized model so
-                    # this error is ignored currently. Inference works for
-                    # some quantized models so lets show warning here
-                    print("model check failed with warning: [", error, "]")
-                    print("Warning during onnx.checker.check_model in quantized model ignored")
-                onnx.save(quantized_model, final_model_path)
-            else:
-
-                #assert not self.args.calibrate, 'calibrate flag not supported by ONNX'
-                # Parsing command-line arguments
-                #parser = argparse.ArgumentParser(description='parsing model and test data set paths')
-                #parser.add_argument('--model_path', required=True)
-                #parser.add_argument('--dataset_path', required=True)
-                #parser.add_argument('--output_model_path', type=str, default='calibrated_quantized_model.onnx')
-                #parser.add_argument('--dataset_size', type=int, default=0, help="Number of images or tensors to load. Default is 0 which means all samples")
-                #parser.add_argument('--data_preprocess', type=str, required=True, choices=['preprocess_method1', 'preprocess_method2', 'None'], help="Refer to Readme.md for guidance on choosing this option.")
-                #args = parser.parse_args()
-                #model_path = args.model_path
-                #output_model_path = args.output_model_path
-                #images_folder = args.dataset_path
-                calib_mode = "naive"
-                size_limit = 0 # int(args.dataset_size)
-                
-                # Generating augmented ONNX model
-                # FIXME: use proper temporary file path
-                augmented_model_path = 'augmented_model.onnx'
-                #model = onnx.load(model_path)
-                augmented_model = calibrate.augment_graph(converted_model)
-                onnx.checker.check_model(augmented_model)
-                #onnx.save(augmented_model, final_model_path)
-                onnx.save(augmented_model, augmented_model_path)
-                
-                # Conducting inference
-                #session = onnxruntime.InferenceSession(final_model_path, None)
-                print(augmented_model_path)
-                session = onnxruntime.InferenceSession(augmented_model_path, None)
-                #session = onnxruntime.InferenceSession('augmented_modelv3.onnx', None)
-                (samples, channels, height, width) = session.get_inputs()[0].shape
-                print(session.get_inputs()[0].shape)
-                #return
-                
-                # Generating inputs for quantization
-                #if args.data_preprocess == "None":
-                #    inputs = load_pb_file(images_folder, args.dataset_size, samples, channels, height, width)
-                #else:
-                #    inputs = load_batch(images_folder, height, width, args.data_preprocess, size_limit)
-                
-                import numpy as np
-                inputs_calibrate_tmp = inputs[0][0].cpu().numpy()
-                
-                dict_for_quantization = calibrate.get_intermediate_outputs(
-                    final_model_path,
-                    session,
-                    inputs_calibrate_tmp,
-                    calib_mode,
-                )
-                quantization_params_dict = calibrate.calculate_quantization_params(
-                    augmented_model,
-                    quantization_thresholds = dict_for_quantization,
-                )
-                calibrated_quantized_model = quantize(
-                    converted_model,
-                    quantization_mode = QuantizationMode.QLinearOps,
-                    quantization_params = quantization_params_dict,
-                )
-                onnx.save(calibrated_quantized_model, final_model_path)
-                
-                print("Calibrated, quantized model saved.")
-        
-        # load the model
-        session = onnxruntime.InferenceSession(final_model_path, None)
-        
-        class ONNX_model:
-            def __init__(self, session, input_names, device):
-                self.session = session
-                self.input_names = input_names
-                        
-            def to_numpy(self, tensor):
-                return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
-            
-            def __call__(self, *inputs):
-                inp = [(input_name, inputs[i]) for i,input_name in enumerate(self.input_names)]
-                inp = {input_name : self.to_numpy(x) for input_name,x in inp}
-                outputs = self.session.run(None, inp)
-                outputs = [torch.from_numpy(output) for output in outputs]
-                outputs = [output.to(device) for output in outputs]
-                if len(outputs) == 1:
-                    outputs = outputs[0]
-                return outputs
-        
-        # switch to eval mode
-        model_onnx = ONNX_model(session, input_names, device)
-        
-        # run both models on inputs
-        assert not model.training, "internal error - model should be in eval() mode! "
-        models = (model, model_onnx)
-        outputs, time_model, outputs_onnx, time_model_onnx = self.lib.run_models(models, inputs)
-        
-        # check for errors
-        Error_stats = self.lib.compute_errors(outputs, outputs_onnx)
-        self.lib.print_errors(Error_stats)
-        print('time of error check of native model: ', time_model, 'seconds')
-        print('time of error check of onnx model: ', time_model_onnx, 'seconds')
-        print()
-        
-        # write TRTIS config
-        config_filename = os.path.join(model_folder, "config.pbtxt")
-        self.lib.write_config(config_filename, 
-                              input_shapes, input_types, 
-                              output_shapes, output_types)
-    
-    def to_triton_torchscript(self, dataloader, model):
-        ''' export the model to torchscript and test correctness on dataloader '''
-        # setup device
-        if self.args.triton_no_cuda:
-            device = torch.device('cpu')
-        else:
-            device = torch.device('cuda')
-        
-        # prepare model 
-        model.to(device)
-        model.eval()
-        assert not model.training, "internal error - model should be in eval() mode! "
-
-        #TODO: support quantize
-        assert not self.args.quantize, 'quantize flag not supported by torchscript yet'
-        
-        # prepare inputs
-        inputs = self.lib.prepare_inputs(dataloader, device)
-        
-        # generate input shapes - dynamic tensor shape support 
-        input_shapes = self.lib.get_tuple_of_dynamic_shapes(inputs)
-        
-        # generate input types 
-        input_types = [x.dtype for x in inputs[0]]
-        
-        # prepare save path 
-        model_folder = os.path.join(self.args.save_dir, self.args.triton_model_name)
-        version_folder = os.path.join(model_folder, str(self.args.triton_model_version))
-        if not os.path.exists(version_folder):
-            os.makedirs(version_folder)
-        final_model_path = os.path.join(version_folder, 'model.pt')
-        
-        # convert the model 
-        with torch.no_grad():
-            if self.args.ts_trace: # trace it 
-                model_ts = torch.jit.trace(model, inputs[0])
-            if self.args.ts_script: # script it 
-                model_ts = torch.jit.script(model)
-        
-        # save the model 
-        torch.jit.save(model_ts, final_model_path)
-        
-        # load the model 
-        model_ts = torch.jit.load(final_model_path)
-        model_ts.eval() # WAR for bug : by default, model_ts gets loaded in training mode
-        
-        # run both models on inputs
-        assert not model.training, "internal error - model should be in eval() mode! "
-        assert not model_ts.training, "internal error - converted model should be in eval() mode! "
-        models = (model, model_ts)
-        outputs, time_model, outputs_ts, time_model_ts = self.lib.run_models(models, inputs)
-        
-        # check for errors
-        Error_stats = self.lib.compute_errors(outputs, outputs_ts)
-        self.lib.print_errors(Error_stats)
-        print('time of error check of native model: ', time_model, 'seconds')
-        print('time of error check of ts model: ', time_model_ts, 'seconds')
-        print()
-        
-        # generate output shapes - dynamic tensor shape support 
-        output_shapes = self.lib.get_tuple_of_dynamic_shapes(outputs)
-        
-        # generate output types 
-        output_types = [x.dtype for x in outputs[0]]
-        
-        # now we build the config for TRTIS 
-        config_filename = os.path.join(model_folder, "config.pbtxt")
-        self.lib.write_config(config_filename, 
-                              input_shapes, input_types, 
-                              output_shapes, output_types)
-
- 

+ 0 - 1
PyTorch/Translation/Transformer/scripts/docker/build.sh

@@ -1 +0,0 @@
-docker build . --network=host -t transformer_pyt

+ 0 - 15
PyTorch/Translation/Transformer/scripts/docker/launch.sh

@@ -1,15 +0,0 @@
-#!/bin/bash
-
-CMD=${1:-/bin/bash}
-NV_VISIBLE_DEVICES=${2:-"0,1,2,3,4,5,6,7,8"}
-DOCKER_BRIDGE=${3:-"host"}
-
-nvidia-docker run -it --rm \
-  --net=$DOCKER_BRIDGE \
-  --shm-size=1g \
-  --ulimit memlock=-1 \
-  --ulimit stack=67108864 \
-  -e NVIDIA_VISIBLE_DEVICES=${NV_VISIBLE_DEVICES} \
-  -v $PWD/results:/results \
-  -v $PWD/data:/data \
-  transformer_pyt $CMD

+ 0 - 54
PyTorch/Translation/Transformer/scripts/export_model.sh

@@ -1,54 +0,0 @@
-#!/bin/bash
-
-# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License. 
-
-NV_VISIBLE_DEVICES=${1:-"0"}
-DOCKER_BRIDGE=${2:-"host"}
-checkpoint=${3:-"/checkpoints/checkpoint_jit.pt"}
-batch_size=${4:-"5120"}
-WORKSPACE=${5:-"/workspace/translation"}
-triton_model_version=${6:-1}
-triton_model_name=${7:-"transformer"}
-triton_dyn_batching_delay=${8:-0}
-triton_engine_count=${9:-1}
-triton_model_overwrite=${10:-"False"}
-
-DEPLOYER="deployer.py"
-
-#TODO: add fp16 option
-
-CMD="python triton/${DEPLOYER} \
-    --ts-script \
-    --save-dir ${WORKSPACE}/triton/triton_models \
-    --triton-model-name ${triton_model_name} \
-    --triton-model-version ${triton_model_version} \
-    --triton-max-batch-size ${batch_size} \
-    --triton-dyn-batching-delay ${triton_dyn_batching_delay} \
-    --triton-engine-count ${triton_engine_count} "
-
-ENCODER_EXPORT_CMD="$CMD --triton-model-name ${triton_model_name}-encoder"
-DECODER_EXPORT_CMD="$CMD --triton-model-name ${triton_model_name}-decoder"
-
-MODEL_ARGS=" -- --checkpoint ${checkpoint} \
-    --batch-size=${batch_size} \
-    --num-batches=2 \
-    --data /data "
-
-ENCODER_EXPORT_CMD+="${MODEL_ARGS} --part encoder"
-DECODER_EXPORT_CMD+="${MODEL_ARGS} --part decoder"
-
-echo Exporting encoder...
-bash scripts/docker/launch.sh "${ENCODER_EXPORT_CMD}" ${NV_VISIBLE_DEVICES} ${DOCKER_BRIDGE}
-echo Exporting decoder...
-bash scripts/docker/launch.sh "${DECODER_EXPORT_CMD}" ${NV_VISIBLE_DEVICES} ${DOCKER_BRIDGE}

+ 12 - 11
PyTorch/Translation/Transformer/scripts/run_DGX1_AMP_8GPU.sh → PyTorch/Translation/Transformer/scripts/run_DGX1_AMP.sh

@@ -17,20 +17,20 @@ nvidia-smi
 
 RESULTS_DIR='/results'
 CHECKPOINTS_DIR='/results/checkpoints'
-STAT_FILE=${RESULTS_DIR}/DGX1_amp_8GPU.json
 mkdir -p $CHECKPOINTS_DIR
 
-SEED=${1:-1}
-LR=${2:-0.000846}
-WARMUP=${3:-4000}
-NUM_EPOCHS=${4:-40}
-BATCH_SIZE=${5:-10240}
-NUM_GPU=${6:-8}
+: ${SEED:=1}
+: ${LR:=0.0006}
+: ${WARMUP:=4000}
+: ${NUM_EPOCHS:=30}
+: ${BS:=5120}
+: ${NUM_GPU:=8}
 
-DISTRIBUTED="-m torch.distributed.launch --nproc_per_node=${NUM_GPU}"
+STAT_FILE=${RESULTS_DIR}/DGX1_amp_${NUM_GPU}GPU.json
+DISTRIBUTED="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"
 
 python ${DISTRIBUTED} /workspace/translation/train.py \
-  /data/wmt14_en_de_joined_dict \
+  /data/ \
   --arch transformer_wmt_en_de_big_t2t \
   --share-all-embeddings \
   --optimizer adam \
@@ -41,11 +41,12 @@ python ${DISTRIBUTED} /workspace/translation/train.py \
   --warmup-init-lr 0.0 \
   --warmup-updates ${WARMUP} \
   --lr $LR \
-  --min-lr 0.0 \ --dropout 0.1 \
+  --min-lr 0.0 \
+  --dropout 0.1 \
   --weight-decay 0.0 \
   --criterion label_smoothed_cross_entropy \
   --label-smoothing 0.1 \
-  --max-tokens ${BATCH_SIZE} \
+  --max-tokens ${BS} \
   --seed ${SEED} \
   --max-epoch ${NUM_EPOCHS} \
   --no-epoch-checkpoints \

+ 10 - 10
PyTorch/Translation/Transformer/scripts/run_DGX1_FP32_8GPU.sh → PyTorch/Translation/Transformer/scripts/run_DGX1_FP32.sh

@@ -17,20 +17,20 @@ nvidia-smi
 
 RESULTS_DIR='/results'
 CHECKPOINTS_DIR='/results/checkpoints'
-STAT_FILE=${RESULTS_DIR}/DGX1_fp32_8GPU.json
 mkdir -p $CHECKPOINTS_DIR
 
-SEED=${1:-1}
-LR=${2:-0.0006}
-WARMUP=${3:-4000}
-NUM_EPOCHS=${4:-40}
-BATCH_SIZE=${5:-5120}
-NUM_GPU=${6:-8}
+: ${SEED:=1}
+: ${LR:=0.000846}
+: ${WARMUP:=4000}
+: ${NUM_EPOCHS:=30}
+: ${BS:=2560}
+: ${NUM_GPU:=8}
 
-DISTRIBUTED="-m torch.distributed.launch --nproc_per_node=${NUM_GPU}"
+STAT_FILE=${RESULTS_DIR}/DGX1_fp32_${NUM_GPU}GPU.json
+DISTRIBUTED="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"
 
 python ${DISTRIBUTED} /workspace/translation/train.py \
-  /data/wmt14_en_de_joined_dict \
+  /data/ \
   --arch transformer_wmt_en_de_big_t2t \
   --share-all-embeddings \
   --optimizer adam \
@@ -46,7 +46,7 @@ python ${DISTRIBUTED} /workspace/translation/train.py \
   --weight-decay 0.0 \
   --criterion label_smoothed_cross_entropy \
   --label-smoothing 0.1 \
-  --max-tokens ${BATCH_SIZE} \
+  --max-tokens ${BS} \
   --seed ${SEED} \
   --max-epoch ${NUM_EPOCHS} \
   --no-epoch-checkpoints \

+ 58 - 0
PyTorch/Translation/Transformer/scripts/run_DGX2_AMP.sh

@@ -0,0 +1,58 @@
+#! /bin/bash
+#
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+nvidia-smi
+
+RESULTS_DIR='/results'
+CHECKPOINTS_DIR='/results/checkpoints'
+mkdir -p $CHECKPOINTS_DIR
+
+: ${SEED:=1}
+: ${LR:=0.001}
+: ${WARMUP:=4000}
+: ${NUM_EPOCHS:=30}
+: ${BS:=10240}
+: ${NUM_GPU:=16}
+
+STAT_FILE=${RESULTS_DIR}/DGX2_amp_${NUM_GPU}GPU.json
+DISTRIBUTED="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"
+
+python ${DISTRIBUTED} /workspace/translation/train.py \
+  /data/ \
+  --arch transformer_wmt_en_de_big_t2t \
+  --share-all-embeddings \
+  --optimizer adam \
+  --adam-betas 0.9 0.997 \
+  --adam-eps 1e-9 \
+  --clip-norm 0.0 \
+  --lr-scheduler inverse_sqrt \
+  --warmup-init-lr 0.0 \
+  --warmup-updates ${WARMUP} \
+  --lr $LR \
+  --min-lr 0.0 \
+  --dropout 0.1 \
+  --weight-decay 0.0 \
+  --criterion label_smoothed_cross_entropy \
+  --label-smoothing 0.1 \
+  --max-tokens ${BS} \
+  --seed ${SEED} \
+  --max-epoch ${NUM_EPOCHS} \
+  --no-epoch-checkpoints \
+  --fuse-layer-norm \
+  --online-eval \
+  --log-interval 500 \
+  --save-dir ${RESULTS_DIR} \
+  --stat-file ${STAT_FILE} \
+  --amp 

+ 10 - 19
PyTorch/Translation/Transformer/scripts/run_DGXA100_TF32_8GPU.sh → PyTorch/Translation/Transformer/scripts/run_DGX2_FP32.sh

@@ -17,29 +17,20 @@ nvidia-smi
 
 RESULTS_DIR='/results'
 CHECKPOINTS_DIR='/results/checkpoints'
-STAT_FILE=${RESULTS_DIR}/DGXA100_tf32_8GPU_log.json
 mkdir -p $CHECKPOINTS_DIR
 
-PREC=${1:-'tf32'}
-SEED=${2:-1}
-LR=${3:-0.000846}
-WARMUP=${4:-4000}
-NUM_EPOCHS=${5:-40}
-BATCH_SIZE=${6:-10240}
-NUM_GPU=${7:-8}
+: ${SEED:=1}
+: ${LR:=0.000846}
+: ${WARMUP:=4000}
+: ${NUM_EPOCHS:=30}
+: ${BS:=5120}
+: ${NUM_GPU:=16}
 
-DISTRIBUTED="-m torch.distributed.launch --nproc_per_node=${NUM_GPU}"
-
-if [ "$PREC" = "fp32" ];
-then
-    PREC=''
-    export NVIDIA_TF32_OVERRIDE=0
-else
-    PREC=''
-fi
+STAT_FILE=${RESULTS_DIR}/DGX2_fp32_${NUM_GPU}GPU.json
+DISTRIBUTED="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"
 
 python ${DISTRIBUTED} /workspace/translation/train.py \
-  /data/wmt14_en_de_joined_dict \
+  /data/ \
   --arch transformer_wmt_en_de_big_t2t \
   --share-all-embeddings \
   --optimizer adam \
@@ -55,7 +46,7 @@ python ${DISTRIBUTED} /workspace/translation/train.py \
   --weight-decay 0.0 \
   --criterion label_smoothed_cross_entropy \
   --label-smoothing 0.1 \
-  --max-tokens ${BATCH_SIZE} \
+  --max-tokens ${BS} \
   --seed ${SEED} \
   --max-epoch ${NUM_EPOCHS} \
   --no-epoch-checkpoints \

+ 10 - 10
PyTorch/Translation/Transformer/scripts/run_DGXA100_AMP_8GPU.sh → PyTorch/Translation/Transformer/scripts/run_DGXA100_AMP.sh

@@ -17,20 +17,20 @@ nvidia-smi
 
 RESULTS_DIR='/results'
 CHECKPOINTS_DIR='/results/checkpoints'
-STAT_FILE=${RESULTS_DIR}/DGXA100_amp_8GPU_log.json
 mkdir -p $CHECKPOINTS_DIR
 
-SEED=${1:-1}
-LR=${2:-0.000846}
-WARMUP=${3:-4000}
-NUM_EPOCHS=${4:-40}
-BATCH_SIZE=${5:-10240}
-NUM_GPU=${6:-8}
+: ${SEED:=1}
+: ${LR:=0.000846}
+: ${WARMUP:=4000}
+: ${NUM_EPOCHS:=30}
+: ${BS:=10240}
+: ${NUM_GPU:=8}
 
-DISTRIBUTED="-m torch.distributed.launch --nproc_per_node=${NUM_GPU}"
+STAT_FILE=${RESULTS_DIR}/DGXA100_amp_${NUM_GPU}GPU_log.json
+DISTRIBUTED="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"
 
 python ${DISTRIBUTED} /workspace/translation/train.py \
-  /data/wmt14_en_de_joined_dict \
+  /data/ \
   --arch transformer_wmt_en_de_big_t2t \
   --share-all-embeddings \
   --optimizer adam \
@@ -46,7 +46,7 @@ python ${DISTRIBUTED} /workspace/translation/train.py \
   --weight-decay 0.0 \
   --criterion label_smoothed_cross_entropy \
   --label-smoothing 0.1 \
-  --max-tokens ${BATCH_SIZE} \
+  --max-tokens ${BS} \
   --seed ${SEED} \
   --max-epoch ${NUM_EPOCHS} \
   --no-epoch-checkpoints \

+ 57 - 0
PyTorch/Translation/Transformer/scripts/run_DGXA100_TF32.sh

@@ -0,0 +1,57 @@
+#! /bin/bash
+#
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+nvidia-smi
+
+RESULTS_DIR='/results'
+CHECKPOINTS_DIR='/results/checkpoints'
+mkdir -p $CHECKPOINTS_DIR
+
+: ${SEED:=1}
+: ${LR:=0.000846}
+: ${WARMUP:=4000}
+: ${NUM_EPOCHS:=30}
+: ${BS:=10240}
+: ${NUM_GPU:=8}
+
+STAT_FILE=${RESULTS_DIR}/DGXA100_tf32_${NUM_GPU}GPU_log.json
+DISTRIBUTED="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"
+
+python ${DISTRIBUTED} /workspace/translation/train.py \
+  /data/ \
+  --arch transformer_wmt_en_de_big_t2t \
+  --share-all-embeddings \
+  --optimizer adam \
+  --adam-betas 0.9 0.997 \
+  --adam-eps 1e-9 \
+  --clip-norm 0.0 \
+  --lr-scheduler inverse_sqrt \
+  --warmup-init-lr 0.0 \
+  --warmup-updates ${WARMUP} \
+  --lr $LR \
+  --min-lr 0.0 \
+  --dropout 0.1 \
+  --weight-decay 0.0 \
+  --criterion label_smoothed_cross_entropy \
+  --label-smoothing 0.1 \
+  --max-tokens ${BS} \
+  --seed ${SEED} \
+  --max-epoch ${NUM_EPOCHS} \
+  --no-epoch-checkpoints \
+  --fuse-layer-norm \
+  --online-eval \
+  --log-interval 500 \
+  --save-dir ${RESULTS_DIR} \
+  --stat-file ${STAT_FILE}

+ 15 - 0
PyTorch/Translation/Transformer/scripts/run_inference.sh

@@ -0,0 +1,15 @@
+: ${FP16:=0}
+
+[ ${FP16} -ne 0 ] && PREC="--fp16"
+
+sacrebleu -t wmt14/full -l en-de --echo src | \
+python inference.py \
+    --buffer-size 5000 \
+    --path /checkpoints/transformer_pyt_20.06.pt \
+    --max-tokens 10240 \
+    --fuse-dropout-add \
+    --remove-bpe \
+    --bpe-codes /checkpoints/bpe_codes \
+    ${PREC} \
+    | sacrebleu -t wmt14/full -l en-de -lc
+

+ 2 - 2
PyTorch/Translation/Transformer/scripts/run_training.sh

@@ -24,14 +24,14 @@ mkdir -p $CHECKPOINTS_DIR
 : ${SEED:=1}
 : ${LR:=0.000846}
 : ${WARMUP:=4000}
-: ${NUM_EPOCHS:=40}
+: ${NUM_EPOCHS:=30}
 : ${BS:=5120}
 : ${NUM_GPU:=8}
 : ${USE_SLURM:=0}
 : ${USE_DISTRIBUTED:=1}
 
 DISTRIBUTED=""
-[ ${USE_DISTRIBUTED} = 1 ] && DISTRIBUTED+="-m torch.distributed.launch --nproc_per_node=${NUM_GPU}"
+[ ${USE_DISTRIBUTED} = 1 ] && DISTRIBUTED+="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"
 [ ${USE_DISTRIBUTED} = 1 ] && [ ${USE_SLURM} = 1 ] && DISTRIBUTED+=" --nnodes ${WORLD_SIZE} --node_rank ${SLURM_NODEID}  \
             --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} "
 

+ 1 - 1
PyTorch/Translation/Transformer/setup.py

@@ -8,7 +8,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

+ 3 - 4
PyTorch/Translation/Transformer/train.py

@@ -8,7 +8,7 @@
 #
 #-------------------------------------------------------------------------
 #
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -48,7 +48,7 @@ def main(args):
 
     if not torch.cuda.is_available():
         raise NotImplementedError('Training on CPU is not supported')
-    torch.cuda.set_device(args.device_id)
+    torch.cuda.set_device(args.local_rank)
     if args.distributed_world_size > 1:
         assert torch.distributed.is_initialized()
         torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
@@ -424,7 +424,6 @@ if __name__ == '__main__':
     parser = options.get_training_parser()
     ARGS = options.parse_args_and_arch(parser)
 
-    if ARGS.distributed_world_size > 1:
-        distributed_utils.distributed_init(ARGS)
+    distributed_utils.distributed_init(ARGS)
 
     main(ARGS)