Browse Source

[DLRM/PyT] Updates for Ampere

Przemek Strzelczyk 5 năm trước
mục cha
commit
36f3b1b670
57 tập tin đã thay đổi với 7545 bổ sung và 886 xóa
  1. 3 6
      PyTorch/Recommendation/DLRM/Dockerfile
  2. 192 71
      PyTorch/Recommendation/DLRM/README.md
  3. 3 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_ext/__init__.py
  4. 31 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_ext/dot_based_interact.py
  5. 29 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_ext/fused_gather_embedding.py
  6. 69 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_ext/sparse_embedding.py
  7. 771 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact.cu
  8. 361 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_fp32.cu
  9. 74 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_pytorch_types.cu
  10. 833 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_tf32.cu
  11. 13 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/pytorch_ops.cpp
  12. 22 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/shared_utils.cuh
  13. 1137 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_volta/dot_based_interact.cu
  14. 68 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_volta/dot_based_interact_pytorch_types.cu
  15. 13 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_volta/pytorch_ops.cpp
  16. 293 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/gather_gpu_fused.cu
  17. 100 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/gather_gpu_fused_pytorch_impl.cu
  18. 22 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/pytorch_embedding_ops.cpp
  19. 21 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/sparse_gather/common.h
  20. 171 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/sparse_gather/gather_gpu.cu
  21. 14 0
      PyTorch/Recommendation/DLRM/dlrm/cuda_src/sparse_gather/sparse_pytorch_ops.cpp
  22. 14 54
      PyTorch/Recommendation/DLRM/dlrm/data/data_loader.py
  23. 242 0
      PyTorch/Recommendation/DLRM/dlrm/data/datasets.py
  24. 219 0
      PyTorch/Recommendation/DLRM/dlrm/data/factories.py
  25. 37 0
      PyTorch/Recommendation/DLRM/dlrm/data/samplers.py
  26. 0 42
      PyTorch/Recommendation/DLRM/dlrm/data/synthetic_dataset.py
  27. 164 0
      PyTorch/Recommendation/DLRM/dlrm/data/utils.py
  28. 0 224
      PyTorch/Recommendation/DLRM/dlrm/model.py
  29. 0 0
      PyTorch/Recommendation/DLRM/dlrm/model/__init__.py
  30. 156 0
      PyTorch/Recommendation/DLRM/dlrm/model/distributed.py
  31. 81 0
      PyTorch/Recommendation/DLRM/dlrm/model/single.py
  32. 0 0
      PyTorch/Recommendation/DLRM/dlrm/nn/__init__.py
  33. 248 0
      PyTorch/Recommendation/DLRM/dlrm/nn/embeddings.py
  34. 64 0
      PyTorch/Recommendation/DLRM/dlrm/nn/factories.py
  35. 113 0
      PyTorch/Recommendation/DLRM/dlrm/nn/interactions.py
  36. 117 0
      PyTorch/Recommendation/DLRM/dlrm/nn/mlps.py
  37. 135 0
      PyTorch/Recommendation/DLRM/dlrm/nn/parts.py
  38. 401 0
      PyTorch/Recommendation/DLRM/dlrm/scripts/dist_main.py
  39. 135 161
      PyTorch/Recommendation/DLRM/dlrm/scripts/main.py
  40. 26 0
      PyTorch/Recommendation/DLRM/dlrm/scripts/prepare_synthetic_dataset.py
  41. 105 78
      PyTorch/Recommendation/DLRM/dlrm/scripts/utils.py
  42. 0 0
      PyTorch/Recommendation/DLRM/dlrm/utils/__init__.py
  43. 139 0
      PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing.py
  44. 0 0
      PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/__init__.py
  45. 105 0
      PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/distributed.py
  46. 132 0
      PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/model.py
  47. 66 0
      PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/serial.py
  48. 153 0
      PyTorch/Recommendation/DLRM/dlrm/utils/distributed.py
  49. 12 4
      PyTorch/Recommendation/DLRM/preproc/prepare_dataset.sh
  50. 119 0
      PyTorch/Recommendation/DLRM/preproc/split_dataset.py
  51. 52 2
      PyTorch/Recommendation/DLRM/setup.py
  52. 9 8
      PyTorch/Recommendation/DLRM/triton/Dockerfile
  53. 88 94
      PyTorch/Recommendation/DLRM/triton/README.md
  54. 113 85
      PyTorch/Recommendation/DLRM/triton/client.py
  55. 45 36
      PyTorch/Recommendation/DLRM/triton/deployer.py
  56. 15 21
      PyTorch/Recommendation/DLRM/triton/deployer_lib.py
  57. BIN
      PyTorch/Recommendation/DLRM/triton/img/lat_vs_thr.png

+ 3 - 6
PyTorch/Recommendation/DLRM/Dockerfile

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.06-py3
 FROM ${FROM_IMAGE_NAME}
 
 RUN apt update && \
@@ -24,11 +24,8 @@ RUN apt update && \
 ADD requirements.txt .
 RUN pip install -r requirements.txt
 
-RUN pip uninstall -y apex && \
-    git clone https://github.com/NVIDIA/apex && \
-    cd apex && \
-    pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
-
 WORKDIR /workspace/dlrm
 
 COPY . .
+
+RUN pip install --no-cache-dir -e .

+ 192 - 71
PyTorch/Recommendation/DLRM/README.md

@@ -15,6 +15,7 @@ This repository provides a script and recipe to train the Deep Learning Recommen
         * [Features](#features)
      * [Mixed precision training](#mixed-precision-training)
         * [Enabling mixed precision](#enabling-mixed-precision)
+        * [Enabling TF32](#enabling-tf32)
   * [Setup](#setup)
      * [Requirements](#requirements)
   * [Quick Start Guide](#quick-start-guide)
@@ -28,16 +29,20 @@ This repository provides a script and recipe to train the Deep Learning Recommen
         * [Preprocess with Spark](#preprocess-with-spark)
      * [Training process](#training-process)
      * [Inference process](#inference-process)
+     * [Deploying DLRM Using NVIDIA Triton Inference Server](#deploying-dlrm-using-nvidia-triton-inference-server)
   * [Performance](#performance)
      * [Benchmarking](#benchmarking)
         * [Training performance benchmark](#training-performance-benchmark)
         * [Inference performance benchmark](#inference-performance-benchmark)
      * [Results](#results)
         * [Training accuracy results](#training-accuracy-results)
-           * [Training accuracy: NVIDIA DGX-1 (8x V100 32G)](#training-accuracy-nvidia-dgx-1-8x-v100-32g)
+           * [Training accuracy: NVIDIA DGX A100 (8x A100 40GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-40gb)  
+           * [Training accuracy: NVIDIA DGX-1 (8x V100 32GB)](#training-accuracy-nvidia-dgx-1-8x-v100-32gb)
            * [Training stability test](#training-stability-test)
         * [Training performance results](#training-performance-results)
-           * [Training performance: NVIDIA DGX-1 (8x V100 32G)](#training-performance-nvidia-dgx-1-8x-v100-32g)
+           * [Training performance: NVIDIA DGX A100 (8x A100 40GB)](#training-performance-nvidia-dgx-a100-8x-a100-40gb)
+           * [Training performance: NVIDIA DGX-1 (8x V100 32GB)](#training-performance-nvidia-dgx-1-8x-v100-32gb)
+           * [Training performance: NVIDIA DGX-2 (16x V100 32GB)](#training-performance-nvidia-dgx-2-16x-v100-32gb)
   * [Release notes](#release-notes)
      * [Changelog](#changelog)
      * [Known issues](#known-issues)
@@ -54,7 +59,7 @@ This model uses a slightly different preprocessing procedure than the one found
 
 Using DLRM you can train a high-quality general model for providing recommendations.
 
-This model is trained with mixed precision using Tensor Cores on NVIDIA Volta and Turing GPUs. Therefore, researchers can get results 1.77x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. It is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
+This model is trained with mixed precision using Tensor Cores on Volta, Turing and NVIDIA Ampere GPU architectures. Therefore, researchers can get results 3.4x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. It is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
 
 
 
@@ -90,15 +95,18 @@ The following features are supported by this model:
 | Feature               | DLRM                
 |----------------------|--------------------------
 |Automatic mixed precision (AMP)   | yes
+|PyTorch Multi-GPU (NCCL)   | yes
          
 #### Features
 
 Automatic Mixed Precision (AMP) - enables mixed precision training without any changes to the code-base by performing automatic graph rewrites and loss scaling controlled by an environmental variable.
 
+Multi-GPU training with PyTorch distributed - our model uses `torch.distributed` to implement efficient multi-GPU training with NCCL. For details, see example sources in this repository or see the [PyTorch Tutorial](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
+
 
 ### Mixed precision training
 
-Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in the Volta and Turing architecture, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
+Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in Volta, and following with both the Turing and Ampere architectures, significant training speedups are experienced by switching to mixed precision -- up to 3.4x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
 1.  Porting the model to use the FP16 data type where appropriate.    
 2.  Adding loss scaling to preserve small gradient values.
 
@@ -111,7 +119,18 @@ For information about:
 
 #### Enabling mixed precision
 
-Mixed precision training is enabled by default. To turn it off issue the `--nofp16` flag to the `main.py` script.
+Mixed precision training is turned off by default. To turn it on issue the `--amp` flag to the `main.py` script.
+
+
+#### Enabling TF32
+
+TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs. 
+
+TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models which require high dynamic range for weights or activations.
+
+For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
+
+TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
 
 
 ## Setup
@@ -122,8 +141,12 @@ The following section lists the requirements for training DLRM.
 
 This repository contains Dockerfile which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 -   [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
--   [PyTorch 20.03-py3+] NGC container
--   [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU
+-   [PyTorch 20.06-py3] NGC container
+-   Supported GPUs:
+    - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
+    - [NVIDIA Turing architecture](https://www.nvidia.com/en-us/geforce/turing/)
+    - [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
+
 
 For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
 -   [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
@@ -134,7 +157,7 @@ For those unable to use the PyTorch NGC container, to set up the required enviro
 
 ## Quick Start Guide
 
-To train your model using mixed precision with Tensor Cores or using FP32, perform the following steps using
+To train your model using mixed or TF32 precision with Tensor Cores or using FP32, perform the following steps using
 the default parameters of DLRM on the Criteo Terabyte dataset. For the specifics concerning training and inference,
 see the [Advanced](#advanced) section.
 
@@ -168,15 +191,28 @@ cd -
 ```
 
 5. Start training.
+
+- single-GPU:
 ```
 python -m dlrm.scripts.main --mode train --dataset /data/dlrm/binary_dataset/
 ```
 
+- multi-GPU:
+```
+python -u -m torch.distributed.launch --use_env --nproc_per_node 8 -m dlrm.scripts.dist_main --mode train --dataset /data/dlrm/binary_dataset
+```
+
 6. Start validation/evaluation.
+
+- single-GPU:
 ```
 python -m dlrm.scripts.main --mode test --dataset /data/dlrm/binary_dataset/
 ```
 
+- multi-GPU:
+```
+python -u -m torch.distributed.launch --use_env --nproc_per_node 8 -m dlrm.scripts.dist_main --mode test --dataset /data/dlrm/binary_dataset
+```
 
 ## Advanced
 
@@ -184,12 +220,13 @@ The following sections provide greater details of the dataset, running training
 
 ### Scripts and sample code
 
-The `dlrm/scripts/main.py` script provides an entry point to most of the functionality. Using different command-line flags allows you to run training, validation and benchmark both training and inference on real or synthetic data. 
+The `dlrm/scripts/main.py` script provides an entry point to most of the functionality in single-GPU setting. Using different command-line flags allows you to run training, validation and benchmark both training and inference on real or synthetic data.
 
-The `dlrm/model.py` file provides the definition of the DLRM neural network.
+Analogously, the `dlrm/scripts/dist_main.py` script provides an entry point for the functionality in multi-GPU setting. It uses the same flags as in single-GPU case with the defaults tuned to large model training.
 
-Utilities connected to loading the data reside in the `data` directory.
+The `dlrm/model/single.py` file provides the definition of the DLRM neural network for single-GPU, whereas `dlrm/model/distributed.py` contains DLRM definition for multi-GPU case.
 
+Utilities connected to loading the data reside in the `data` directory.
 
 ### Parameters
 
@@ -198,10 +235,12 @@ Utilities connected to loading the data reside in the `data` directory.
 The `dlrm/scripts/main.py` script supports a number of command-line flags. You can get the descriptions of those by running `python -m dlrm.scripts.main --help`. Running this command will output:
 
 ```        
-       USAGE: /workspace/dlrm/dlrm/scripts/main.py [flags]
+       USAGE: /workspace/dlrm/scripts/main.py [flags]
 flags:
 
-/workspace/dlrm/dlrm/scripts/main.py:
+/workspace/dlrm/scripts/main.py:
+  --[no]amp: If True the script will use Automatic Mixed Precision
+    (default: 'false')
   --auc_threshold: Stop the training after achieving this AUC
     (a number)
   --base_device: Device to run the majority of the model operations
@@ -209,89 +248,89 @@ flags:
   --batch_size: Batch size used for training
     (default: '32768')
     (an integer)
-  --benchmark_warmup_steps: Number of initial iterations to exclude from
-    throughput measurements
+  --benchmark_warmup_steps: Number of initial iterations to exclude from throughput measurements
     (default: '0')
     (an integer)
   --bottom_mlp_sizes: Linear layer sizes for the bottom MLP
     (default: '512,256,128')
     (a comma separated list)
-  --dataset: Full path to binary dataset. Must include files such as:
-    train_data.bin, test_data.bin
-  --dataset_subset: Use only a subset of the training data. If None (default)
-    will use all of it. Must be either None, or a float in range [0,1]
+  --dataset: Full path to binary dataset. Must include files such as: train_data.bin, test_data.bin
+  --dataset_subset: Use only a subset of the training data. If None (default) will use all of it. Must be either None, or a float in
+    range [0,1]
+    (a number)
+  --dataset_type: <binary|memmap|split|synthetic_gpu|synthetic_disk>: The type of the dataset to use
+    (default: 'split')
+  --decay_end_lr: LR after the decay ends
+    (default: '0.0')
     (a number)
-  --decay_start_step: Optimization step after which to start decaying the
-    learning rate, if None will start decaying right after the warmup phase is
-    completed
+  --decay_power: Polynomial learning rate decay power
+    (default: '2')
+    (an integer)
+  --decay_start_step: Optimization step after which to start decaying the learning rate, if None will start decaying right after the
+    warmup phase is completed
     (default: '64000')
     (an integer)
-  --decay_steps: Polynomial learning rate decay steps. If equal to 0 will not do
-    any decaying
+  --decay_steps: Polynomial learning rate decay steps. If equal to 0 will not do any decaying
     (default: '80000')
     (an integer)
   --embedding_dim: Dimensionality of embedding space for categorical features
     (default: '128')
     (an integer)
+  --embedding_type: <joint|joint_fused|joint_sparse|multi_table>: The type of the embedding operation to use
+    (default: 'joint_fused')
   --epochs: Number of epochs to train for
     (default: '1')
     (an integer)
-  --[no]fp16: If True (default) the script will use Automatic Mixed Precision
-    (default: 'true')
-  --[no]hash_indices: If True the model will compute `index := index % table
-    size` to ensure that the indices match table sizes
+  --[no]hash_indices: If True the model will compute `index := index % table size` to ensure that the indices match table sizes
     (default: 'false')
-  --inference_benchmark_batch_sizes: Batch sizes for inference throughput and
-    latency measurements
+  --inference_benchmark_batch_sizes: Batch sizes for inference throughput and latency measurements
     (default: '1,64,4096')
     (a comma separated list)
-  --inference_benchmark_steps: Number of steps for measuring inference latency
-    and throughput
+  --inference_benchmark_steps: Number of steps for measuring inference latency and throughput
     (default: '200')
     (an integer)
-  --interaction_op: Type of interaction operation to perform. Supported choices:
-    'dot' or 'cat'
-    (default: 'dot')
+  --interaction_op: <cuda_dot|dot|cat>: Type of interaction operation to perform.
+    (default: 'cuda_dot')
   --load_checkpoint_path: Path from which to load a checkpoint
   --log_path: Destination for the log file with various results and statistics
     (default: './log.json')
   --loss_scale: Static loss scale for Mixed Precision Training
-    (default: '8192.0')
+    (default: '1024.0')
     (a number)
   --lr: Base learning rate
     (default: '28.0')
     (a number)
   --max_steps: Stop training after doing this many optimization steps
     (an integer)
-  --max_table_size: Maximum number of rows per embedding table, by default equal
-    to the number of unique values for each categorical variable
+  --max_table_size: Maximum number of rows per embedding table, by default equal to the number of unique values for each categorical
+    variable
     (an integer)
   --mode: <train|test|inference_benchmark>: Select task to be performed
     (default: 'train')
-  --num_numerical_features: Number of numerical features in the dataset.
-    Defaults to 13 for the Criteo Terabyte Dataset
+  --num_numerical_features: Number of numerical features in the dataset. Defaults to 13 for the Criteo Terabyte Dataset
     (default: '13')
     (an integer)
+  --[no]optimized_mlp: Use an optimized implementation of MLP from apex
+    (default: 'true')
   --output_dir: Path where to save the checkpoints
     (default: '/tmp')
-  --print_freq: Number of optimizations steps between printing training status
-    to stdout
+  --print_freq: Number of optimizations steps between printing training status to stdout
     (default: '200')
     (an integer)
   --save_checkpoint_path: Path to which to save the training checkpoints
   --seed: Random seed
     (default: '12345')
     (an integer)
-  --[no]self_interaction: Set to True to use self-interaction
+  -shuffle,--[no]shuffle_batch_order: Read batch in train dataset by random order
     (default: 'false')
-  -shuffle,--[no]shuffle_batch_order: Read batch in train dataset by random
-    order
-    (default: 'false')
-  --[no]synthetic_dataset: Use synthetic instead of real data for benchmarking
-    purposes
-    (default: 'false')
-  --synthetic_dataset_table_sizes: Embedding table sizes to use with the
-    synthetic dataset
+  --synthetic_dataset_dir: Default synthetic disk dataset directory
+    (default: '/tmp/dlrm_sythetic_dataset')
+  --synthetic_dataset_num_entries: Number of samples per epoch for the synthetic dataset
+    (default: '33554432')
+    (an integer)
+  --synthetic_dataset_table_sizes: Embedding table sizes to use with the synthetic dataset
+    (default: '100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,100000,10
+    0000,100000,100000,100000,100000,100000,100000,100000,100000')
     (a comma separated list)
   --test_after: Don't test the model unless this many epochs has been completed
     (default: '0.0')
@@ -299,8 +338,7 @@ flags:
   --test_batch_size: Batch size used for testing/validation
     (default: '32768')
     (an integer)
-  --test_freq: Number of optimization steps between validations. If None will
-    test after each epoch
+  --test_freq: Number of optimization steps between validations. If None will test after each epoch
     (an integer)
   --top_mlp_sizes: Linear layer sizes for the top MLP
     (default: '1024,1024,512,256,1')
@@ -407,13 +445,16 @@ of samples processed per second. We use mixed precision training with static los
 
 ### Inference process
 
-This section describes inference with PyTorch in Python. If you're interested in inference using the Triton Inference Server, refer to `triton/README.md` file.
+This section describes inference with PyTorch in Python. If you're interested in inference using the Triton Inference Server, refer to [triton/README.md](triton/README.md) file.
 
 Two modes for inference are currently supported by the `dlrm/scripts/main.py` script:
 
-1. Inference benchmark – this mode will measure and print out throughput and latency numbers for multiple batch sizes. You can activate it by setting the batch sizes to be tested with the `inference_benchmark_batch_sizes` command-line argument. It will use the default test dataset unless the `--synthetic_dataset` flag is passed.
+1. Inference benchmark – this mode will measure and print out throughput and latency numbers for multiple batch sizes. You can activate it by setting the batch sizes to be tested with the `inference_benchmark_batch_sizes` command-line argument. It will use the default test dataset unless the `--dataset_type synthetic_disk` flag is passed.
 2. Test-only – this mode can be used to run a full validation on a checkpoint to measure ROC AUC . You can enable it by passing the `--mode test` flag.
 
+### Deploying DLRM Using NVIDIA Triton Inference Server
+The NVIDIA Triton Inference Server provides a cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or GRPC endpoint, allowing remote clients to request inferencing for any model being managed by the server. More information on how to perform inference using NVIDIA Triton Inference Server can be found in [triton/README.md](triton/README.md).
+
 ## Performance
 
 ### Benchmarking
@@ -425,10 +466,10 @@ The following section shows how to run benchmarks measuring the model performanc
 To benchmark the training performance on a specific batch size, run:
 
 ```
-python -m dlrm.scripts.main --mode train --max_steps 500 --benchmark_warmup_steps 250 --dataset /data
+python -m dlrm.scripts.main --mode train --max_steps 1000 --benchmark_warmup_steps 500 --dataset /data
 ```
 
-You can also pass the `--synthetic_dataset` flag if you haven't yet downloaded the dataset.
+You can also pass the `--dataset_type synthetic_disk` flag if you haven't yet downloaded the dataset.
 
 #### Inference performance benchmark
 
@@ -438,22 +479,41 @@ To benchmark the inference performance on a specific batch size, run:
 python -m dlrm.scripts.main --mode inference_benchmark --dataset /data
 ```
 
-You can also pass the `--synthetic_dataset` flag if you haven't yet downloaded the dataset.
+You can also pass the `--dataset_type synthetic_disk` flag if you haven't yet downloaded the dataset.
 
 ### Results 
 
-The following sections provide details on how we achieved our performance and accuracy in training and inference.
+The following sections provide details on how we achieved our performance and accuracy in training and inference. 
+
+We used two model size variants to show memory scalability in multi-GPU setup:
+- small - refers to model trained on Criteo dataset with frequency thresholding set to 15 resulting in smaller embedding tables - total model size: ~15 GB
+- large - refers to model trained on Criteo dataset with frequency thresholding set to 3 resulting in larger embedding tables - total model size: ~82 GB
 
 #### Training accuracy results
 
 
-##### Training accuracy: NVIDIA DGX-1 (8x V100 32G)
+##### Training accuracy: NVIDIA DGX A100 (8x A100 40GB)
+
+Our results were obtained by running training scripts as described in the Quick Start Guide in the DLRM Docker container in two configurations:
+- on a single NVIDIA A100 40GB GPU (`dlrm/scripts/main.py`)
+- in multi-GPU setup on DGX A100 with 8x Ampere A100 40GB (`dlrm/scripts/dist_main.py`)
+
+| GPUs    | Model size    | Batch size / GPU    | Accuracy (AUC) - TF32  | Accuracy (AUC) - mixed precision  |   Time to train - TF32 [minutes]  |  Time to train - mixed precision [minutes] | Time to train speedup (TF32 to mixed precision)        
+|----:|----|----|----:|----:|---:|---:|---:|
+| 8 | large | 64k | 0.8027 | 0.8027 | 8.79 | 6.16 | 1.43 |
+| 1 | small | 32k | 0.8036 | 0.8036 | 28.20 | 17.45 | 1.62 |
 
-Our results were obtained by running the `dlrm/scripts/main.py` script for one epoch as described in the Quick Start Guide training script in the DLRM Docker container on a single Tesla V100 32G GPU.
 
-| GPUs    | Batch size / GPU    | Accuracy (AUC) - FP32  | Accuracy (AUC) - mixed precision  |   Time to train - FP32  [hours] |  Time to train - mixed precision  [hours] | Time to train speedup (FP32 to mixed precision)        
-|----|----|----|----|---|---|---|
-| 1 | 32k | 0.80362 | 0.80362 | 2.46 | 1.44 | 1.71 |
+##### Training accuracy: NVIDIA DGX-1 (8x V100 32GB)
+
+Our results were obtained by running training scripts as described in the Quick Start Guide in the DLRM Docker container in two configurations:
+- on a single Tesla V100 32GB GPU (`dlrm/scripts/main.py`)
+- in multi-GPU setup on DGX-1 8 x Tesla V100 32GB GPU (`dlrm/scripts/dist_main.py`)
+
+| GPUs    | Model size    | Batch size / GPU    | Accuracy (AUC) - FP32  | Accuracy (AUC) - mixed precision  |   Time to train - FP32  [minutes] |  Time to train - mixed precision  [minutes] | Time to train speedup (FP32 to mixed precision)        
+|----:|----|----|----:|----:|---:|---:|---:|
+| 8 | large | 64k | 0.8027 | 0.8027 | 46.29 | 22.72 | 2.04 |
+| 1 | small | 32k | 0.8035 | 0.8035 | 105.98 | 31.12 | 3.40 |
 
 
 
@@ -486,22 +546,78 @@ The table below shows the complete convergence data for 16 different random seed
 #### Training performance results
 
 
-##### Training performance: NVIDIA DGX-1 (8x V100 32G)
+We used throughput in items processed per second as the performance metric.
+
+
+##### Training performance: NVIDIA DGX A100 (8x A100 40GB)
+
+Our results were obtained by running the following commands:
+- for single GPU setup:
+```
+python -m dlrm.scripts.main --mode train --dataset /data [--amp]
+```
+- for multi GPU setup:
+```
+python -u -m torch.distributed.launch --use_env --nproc_per_node 8 -m dlrm.scripts.dist_main --mode train --dataset /data/ [--amp]
+```
+
+in the DLRM Docker container on NVIDIA DGX A100 (8x A100 40GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
+
+| GPUs   | Model size    | Batch size / GPU   | Throughput - TF32    | Throughput - mixed precision    | Throughput speedup (TF32 - mixed precision)      
+|----:|----|----|---:|---:|---:|
+| 8 | large | 64k | 8252438.74 | 11771969.56 | 1.43 |
+| 1 | small | 32k | 2498002.39 | 4081969.37 | 1.63 |
+
+
+To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
+
 
-Our results were obtained by running:
+##### Training performance: NVIDIA DGX-1 (8x V100 32GB)
+
+Our results were obtained by running the following commands:
+- for single GPU setup:
+```
+python -m dlrm.scripts.main --mode train --dataset /data [--amp]
 ```
-python -m dlrm.scripts.main --mode train --max_steps 200 --benchmark_warmup_steps 50 --fp16 --dataset /data
+- for multi GPU setup:
 ```
- in the DLRM Docker container on NVIDIA DGX-1 with (8x V100 32G) GPUs. Performance numbers (in items/images per second) were averaged over 150 training steps.
+python -u -m torch.distributed.launch --use_env --nproc_per_node 8 -m dlrm.scripts.dist_main --mode train --dataset /data/ [--amp]
+```
+
+ in the DLRM Docker container on NVIDIA DGX-1 with (8x V100 32GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
 
-| GPUs   | Batch size / GPU   | Throughput - FP32    | Throughput - mixed precision    | Throughput speedup (FP32 - mixed precision)   |
-|----|---|---|---|---|
-| 1 | 32k |  494k | 875k | 1.773 |
+| GPUs   | Model size    | Batch size / GPU   | Throughput - FP32    | Throughput - mixed precision    | Throughput speedup (FP32 - mixed precision)   |     
+|----:|----|----|---:|---:|---:|
+| 8 | large | 64k | 1538759.56 | 3257414.75 | 2.12 |
+| 1 | small | 32k | 670238.82 | 2281278.45 | 3.40 |
 
 
 We used throughput in items processed per second as the performance metric.
 
 
+##### Training performance: NVIDIA DGX-2 (16x V100 32GB)
+
+Our results were obtained by running the following commands:
+- for single GPU setup:
+```
+python -m dlrm.scripts.main --mode train --dataset /data [--amp] 
+```
+- for multi GPU setup:
+```
+python -u -m torch.distributed.launch --use_env --nproc_per_node 16 -m dlrm.scripts.dist_main --mode train --dataset /data/ [--amp]
+```
+ in the DLRM Docker container on NVIDIA DGX-2 with (16x V100 32GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
+
+| GPUs   | Model size   | Batch size / GPU   | Throughput - FP32    | Throughput - mixed precision    | Throughput speedup (FP32 - mixed precision)     
+|----:|----|---|---:|---:|---:|
+| 16 | large | 64k | 4343127.59 | 9454627.44 | 2.18 |
+| 8 | large | 64k | 2948808.82 | 7057842.56 | 2.39 |
+| 1 | small | 32k | 706933.08 | 2417584.57 | 3.42 |
+
+
+To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
+
+
 ## Release notes
 
 ### Changelog
@@ -509,6 +625,11 @@ We used throughput in items processed per second as the performance metric.
 April 2020
 - Initial release
 
+May 2020
+- Performance optimizations
+
+June 2020
+- Updated performance tables to include A100 results and multi-GPU setup
 
 ### Known issues
 

+ 3 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_ext/__init__.py

@@ -0,0 +1,3 @@
+from .dot_based_interact import dotBasedInteract
+from .fused_gather_embedding import buckle_embedding_fused_gather
+from .sparse_embedding import JointSparseEmbedding

+ 31 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_ext/dot_based_interact.py

@@ -0,0 +1,31 @@
+import torch
+from torch.autograd import Function
+from apex import amp
+
+if torch.cuda.get_device_capability()[0] >= 8:
+    print('Using the Ampere-optimized dot interaction kernels')
+    from dlrm.cuda_ext import interaction_ampere as interaction
+else:
+    print('Using the Volta-optimized dot interaction kernels')
+    from dlrm.cuda_ext import interaction_volta as interaction
+
+
+class DotBasedInteract(Function):
+    """Autograd wrapper for the custom CUDA dot-based feature-interaction kernels.
+
+    Dispatches to the Ampere- or Volta-optimized extension chosen at import time
+    above.  The ``amp.half_function`` decorator tells Apex AMP to cast the
+    arguments to fp16 before invoking the kernels.
+    """
+
+    @staticmethod
+    @amp.half_function
+    def forward(ctx, input, bottom_mlp_output):
+        # Only `input` is saved: the backward kernel recovers both gradients
+        # (w.r.t. `input` and w.r.t. the bottom-MLP output) from it plus the
+        # upstream gradient, so `bottom_mlp_output` need not be kept.
+        output = interaction.dotBasedInteractFwd(input, bottom_mlp_output)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    @amp.half_function
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        # Two gradients returned, matching forward()'s two tensor arguments.
+        grad, mlp_grad = interaction.dotBasedInteractBwd(input, grad_output)
+        return grad, mlp_grad
+
+
+# Functional entry point used by model code.
+dotBasedInteract = DotBasedInteract.apply

+ 29 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_ext/fused_gather_embedding.py

@@ -0,0 +1,29 @@
+"""
+Fused Buckle Embedding
+"""
+
+from absl import logging
+from apex import amp
+from torch.autograd import Function
+
+from dlrm.cuda_ext import fused_embedding
+
+
+class BuckleEmbeddingFusedGatherFunction(Function):
+    """Autograd Function wrapping the fused multi-table embedding gather kernel.
+
+    ``embedding`` is a single concatenated weight tensor; ``offsets`` marks the
+    start of each table within it and ``indices`` selects rows per table.
+    ``amp_train`` is forwarded to the CUDA op (presumably toggles fp16 output
+    when training under AMP — confirm against the extension source).
+    """
+    @staticmethod
+    def forward(ctx, embedding, indices, offsets, amp_train):
+        output = fused_embedding.gather_gpu_fused_fwd(embedding, indices, offsets, amp_train)
+        ctx.save_for_backward(embedding, indices, offsets)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        embedding, indices, offsets = ctx.saved_tensors
+
+        # The backward kernel only supports embedding_dim == 128; warn once.
+        logging.log_first_n(logging.WARNING, "Highly specialized embedding for embedding_dim 128", 1)
+        grad_weights = fused_embedding.gather_gpu_fused_bwd(embedding, indices, offsets, grad_output)
+        # Gradients only for the weights; indices/offsets/flag are non-differentiable.
+        return grad_weights, None, None, None
+
+
+# amp.float_function keeps this op in fp32 under Apex AMP.
+buckle_embedding_fused_gather = amp.float_function(BuckleEmbeddingFusedGatherFunction.apply)

+ 69 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_ext/sparse_embedding.py

@@ -0,0 +1,69 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import torch
+from apex import amp
+from dlrm.cuda_ext import sparse_gather
+from torch import nn
+from torch.autograd import Function
+
+
+class EmbeddingGatherFunction(Function):
+    """Customized embedding gather with fused plain SGD.
+
+    NOTE(review): only the gather/scatter is visible here; any SGD fusion would
+    live inside the ``sparse_gather`` CUDA extension — confirm there.
+    """
+    @staticmethod
+    def forward(ctx, embedding, indices):
+        output = sparse_gather.gather_gpu_fwd(embedding, indices)
+        ctx.save_for_backward(indices)
+        # Row count of the full table, needed to size the gradient in backward.
+        ctx.num_features = embedding.size(0)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        indices = ctx.saved_tensors[0]
+
+        # Scatter the upstream gradient back into a table-shaped gradient.
+        grad_embedding = sparse_gather.gather_gpu_bwd(grad_output, indices, ctx.num_features)
+
+        # No gradient for `indices` (non-differentiable).
+        return grad_embedding, None
+
+
+class JointSparseEmbedding(nn.Module):
+    """Joint multiple one hot embedding together
+
+    Multiple one hot embedding can be done as one embedding (indexing).
+    All tables are stored in a single weight matrix; per-table indices are
+    shifted by the cumulative table sizes so a single gather serves all tables.
+
+    Args:
+        categorical_feature_sizes (list): A list of integer indicating number of features of each embedding table
+        embedding_dim (int): the size of each embedding vector
+        device (torch.device): where to create the embedding. Default "cuda"
+    """
+    def __init__(self, categorical_feature_sizes, embedding_dim, device="cuda"):
+        super(JointSparseEmbedding, self).__init__()
+        self.embedding_dim = embedding_dim
+        self.categorical_feature_sizes = copy.copy(categorical_feature_sizes)
+
+        # offsets[i] is the first row of table i in the concatenated weights;
+        # offsets[-1] is the total row count.
+        self.register_buffer("offsets", torch.tensor([0] + categorical_feature_sizes).cumsum(0).to(device))
+        self.weights = torch.nn.Parameter(torch.rand((self.offsets[-1].item(), embedding_dim), device=device))
+
+    def forward(self, categorical_inputs):
+        # Check input has the right shape: one index column per table.
+        assert categorical_inputs.shape[1] == len(self.categorical_feature_sizes)
+
+        # Shift each table's indices into the concatenated row space, then gather.
+        embedding_out = embedding_gather(self.weights, categorical_inputs + self.offsets[:-1])
+
+        return embedding_out
+
+
+# amp.float_function keeps the gather in fp32 under Apex AMP.
+embedding_gather = amp.float_function(EmbeddingGatherFunction.apply)

+ 771 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact.cu

@@ -0,0 +1,771 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <device_launch_parameters.h>
+#include <mma.h>
+#include <cuda_fp16.hpp>
+
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "shared_utils.cuh"
+
+struct __align__(8) half4 {
+  half2 vals[2];
+};
+
+using namespace nvcuda;
+
+// Forward interaction kernel, scalar fallback path: used by the launcher when
+// num_cols or output_size is not a multiple of 8, so 128-bit vector loads are
+// not safe.  One warp handles one sample: the (num_rows x num_cols) input
+// block is staged into shared memory, zero-padded up to
+// (num_rows_after_padding x num_cols_after_padding) so full 16x16 tiles can be
+// fed to the tensor cores, and X * X^T is accumulated with wmma tiles.
+// Per-sample output layout: row 0 of the input copied through (presumably the
+// bottom-MLP output row — confirm with the caller), then the
+// strictly-lower-triangular entries of the interaction matrix, then one zero
+// padding element.
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint M_BLOCKS,
+          uint K_BLOCKS,
+          uint SMEM_STRIDE,
+          uint SMEM_STRIDE_ACC,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernelNonAligned(const __half *__restrict input,
+                                                                                        __half *__restrict output,
+                                                                                        uint batch_size,
+                                                                                        uint num_rows,
+                                                                                        uint num_cols,
+                                                                                        uint num_rows_after_padding,
+                                                                                        uint num_cols_after_padding,
+                                                                                        uint smem_elems_per_warp,
+                                                                                        uint smem_rows_per_warp,
+                                                                                        uint output_size,
+                                                                                        uint num_row_steps,
+                                                                                        uint num_col_steps) {
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  int lane_id = threadIdx.x & (WARP_SIZE - 1);
+
+  // Each warp owns a private slice of the dynamic shared memory.
+  extern __shared__ half shmem_dynamic[];
+  half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp);
+
+  // Stage the sample's input rows into shared memory (scalar copies).
+  const half *sample_input = input + num_rows * num_cols * sample_id;
+  for (uint i = 0; i < num_rows; ++i, sample_input += num_cols) {
+    for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+      (shmem + i * SMEM_STRIDE)[idx] = sample_input[idx];
+    }
+  }
+
+  // Zero-pad the extra columns of the real rows.
+  uint idx = lane_id + num_cols;
+  if (idx < num_cols_after_padding) {
+    for (int i = 0; i < num_rows; ++i) {
+      (shmem + i * SMEM_STRIDE)[idx] = __float2half(0);
+    }
+  }
+
+  // Zero-fill the padding rows, 8 bytes (4 halves) per lane.
+  half4 zeros;
+  zeros.vals[0].x = __float2half(0);
+  zeros.vals[0].y = __float2half(0);
+  zeros.vals[1].x = __float2half(0);
+  zeros.vals[1].y = __float2half(0);
+  if (lane_id < (num_cols_after_padding >> 2)) {
+    for (int i = num_rows; i < num_rows_after_padding; i++) {
+      ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros;
+    }
+  }
+  __syncwarp();
+  half *gmem_output = output + output_size * sample_id;
+
+  // Pass the first input row straight through to the output.
+  for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+    gmem_output[idx] = shmem[idx];
+  }
+
+  // Accumulate X * X^T tile-by-tile; fp32 accumulators.
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[M_BLOCKS][M_BLOCKS];
+
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      wmma::fill_fragment(acc[i][j], 0);
+    }
+  }
+
+  for (int k_step = 0; k_step < num_col_steps; k_step++) {
+    wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[M_BLOCKS];
+    wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::col_major> b[M_BLOCKS];
+    for (int j = 0; j < M_BLOCKS; j++) {
+      // The last row-tile is anchored to the bottom so tiles may overlap when
+      // smem_rows_per_warp is not a multiple of 16.
+      int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16;
+      const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16);
+      // Same tile loaded as A (row-major) and B (col-major) -> effectively X * X^T.
+      wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE);
+      wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE);
+    }
+    for (int i = 0; i < M_BLOCKS; i++) {
+      for (int j = 0; j < M_BLOCKS; j++) {
+        wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]);
+      }
+    }
+  }
+  // Spill the fp32 accumulator tiles back to shared memory (reusing the buffer).
+  float *shmem_store = reinterpret_cast<float *>(shmem);
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16);
+      wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major);
+    }
+  }
+
+  // Emit the strictly-lower triangle (j < i); offset i*(i-1)/2 flattens row i.
+  // srcLine compensates for the bottom-anchored last row-tile.
+  half *gmem_interact_output = gmem_output + num_cols;
+  int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp;
+  int srcLine = 0;
+  for (int i = 0; i < num_rows; ++i, ++srcLine) {
+    if (i == ((M_BLOCKS - 1) * 16)) {
+      srcLine += lastRowBlockOffset;
+    }
+    if (lane_id < i) {
+      uint offset = (i * (i - 1)) >> 1;
+      gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]);
+    }
+  }
+  // Padding
+  if (lane_id == 0) {
+    gmem_output[output_size - 1] = __float2half(0);
+  }
+}
+
+// Forward interaction kernel, vectorized path: chosen by the launcher when
+// num_cols and output_size are both multiples of 8, so float2 (4-half) loads
+// and stores are aligned.  Identical math to the NonAligned variant above:
+// one warp per sample, input staged and zero-padded in shared memory, X * X^T
+// via wmma tiles, output = passthrough row 0 + lower triangle + one zero pad.
+// NOTE(review): the single `lane_id < (num_cols >> 2)` float2 copy implies
+// num_cols <= 128 (one float2 per lane) — confirm with the caller.
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint M_BLOCKS,
+          uint K_BLOCKS,
+          uint SMEM_STRIDE,
+          uint SMEM_STRIDE_ACC,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernel(const __half *__restrict input,
+                                                                              __half *__restrict output,
+                                                                              uint batch_size,
+                                                                              uint num_rows,
+                                                                              uint num_cols,
+                                                                              uint num_rows_after_padding,
+                                                                              uint num_cols_after_padding,
+                                                                              uint smem_elems_per_warp,
+                                                                              uint smem_rows_per_warp,
+                                                                              uint output_size,
+                                                                              uint num_row_steps,
+                                                                              uint num_col_steps) {
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  int lane_id = threadIdx.x & (WARP_SIZE - 1);
+
+  extern __shared__ half shmem_dynamic[];
+  half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp);
+
+  // Stage the sample's rows with 64-bit vector copies (4 halves per lane).
+  const half *sample_input = input + num_rows * num_cols * sample_id;
+  if (lane_id < (num_cols >> 2)) {
+    for (int i = 0; i < num_rows; ++i, sample_input += num_cols) {
+      ((float2 *)(shmem + i * SMEM_STRIDE))[lane_id] = ((float2 *)sample_input)[lane_id];
+    }
+  }
+
+  // Zero-pad extra columns of the real rows.
+  uint idx = lane_id + num_cols;
+  if (idx < num_cols_after_padding) {
+    for (int i = 0; i < num_rows; ++i) {
+      (shmem + i * SMEM_STRIDE)[idx] = __float2half(0);
+    }
+  }
+
+  // Zero-fill the padding rows.
+  half4 zeros;
+  zeros.vals[0].x = __float2half(0);
+  zeros.vals[0].y = __float2half(0);
+  zeros.vals[1].x = __float2half(0);
+  zeros.vals[1].y = __float2half(0);
+  if (lane_id < (num_cols_after_padding >> 2)) {
+    for (int i = num_rows; i < num_rows_after_padding; i++) {
+      ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros;
+    }
+  }
+  __syncwarp();
+  half *gmem_output = output + output_size * sample_id;
+  // Pass the first input row straight through to the output (vectorized).
+  if (lane_id < (num_cols >> 2)) {
+    ((float2 *)gmem_output)[lane_id] = ((float2 *)shmem)[lane_id];
+  }
+
+  // Accumulate X * X^T tile-by-tile; fp32 accumulators.
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[M_BLOCKS][M_BLOCKS];
+
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      wmma::fill_fragment(acc[i][j], 0);
+    }
+  }
+
+  for (int k_step = 0; k_step < num_col_steps; k_step++) {
+    wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[M_BLOCKS];
+    wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::col_major> b[M_BLOCKS];
+    for (int j = 0; j < M_BLOCKS; j++) {
+      // Last row-tile anchored to the bottom; tiles may overlap.
+      int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16;
+      const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16);
+      // Same tile as row-major A and col-major B -> X * X^T.
+      wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE);
+      wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE);
+    }
+    for (int i = 0; i < M_BLOCKS; i++) {
+      for (int j = 0; j < M_BLOCKS; j++) {
+        wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]);
+      }
+    }
+  }
+  // Spill fp32 accumulator tiles to shared memory.
+  float *shmem_store = reinterpret_cast<float *>(shmem);
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16);
+      wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major);
+    }
+  }
+
+  // Emit the strictly-lower triangle (j < i); offset i*(i-1)/2 flattens row i.
+  half *gmem_interact_output = gmem_output + num_cols;
+  int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp;
+  int srcLine = 0;
+  for (int i = 0; i < num_rows; ++i, ++srcLine) {
+    if (i == ((M_BLOCKS - 1) * 16)) {
+      srcLine += lastRowBlockOffset;
+    }
+    if (lane_id < i) {
+      uint offset = (i * (i - 1)) >> 1;
+      gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]);
+    }
+  }
+  // Padding
+  if (lane_id == 0) {
+    gmem_output[output_size - 1] = __float2half(0);
+  }
+}
+
+// Backward interaction kernel, scalar fallback path (unaligned sizes).
+// One warp per sample.  Steps:
+//   1. Load the flattened triangular upstream gradient and expand it into a
+//      symmetric 2D matrix U in shared memory (zero diagonal / padding).
+//   2. Reload the forward input X into shared memory with zero padding.
+//   3. grad = U * X via wmma tiles (fp32 accumulate), written per column step.
+//   4. The first num_cols of the upstream gradient are copied out unchanged
+//      as the bottom-MLP gradient (the forward pass-through row).
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint ROW_TILES_PER_STEP,
+          uint COL_TILES_PER_STEP,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__
+    void dotBasedInteractBwdKernelNonAligned(const __half *__restrict input,
+                                             const __half *__restrict upstream_grad,
+                                             half __restrict *grad,
+                                             half __restrict *bottom_mlp_grad,
+                                             uint batch_size,
+                                             uint num_rows,
+                                             uint num_cols,
+                                             uint num_rows_after_padding,
+                                             uint num_cols_after_padding,
+                                             uint sample_size,
+                                             uint interaction_ugrad_size,
+                                             uint interaction_ugrad_size_with_padding,
+                                             uint interaction_ugrad_2D_size_elems,
+                                             uint interaction_ugrad_2D_stride,
+                                             uint input_size_elems,
+                                             uint input_stride,
+                                             uint num_row_steps,
+                                             uint num_col_steps,
+                                             uint row_tiles_per_step,
+                                             uint shared_mem_per_warp_size_byte) {
+  extern __shared__ half shared_mem[];
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  uint lane_id = threadIdx.x & (WARP_SIZE - 1);
+  // ">> 1" to convert to half pointer
+  uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1);
+
+  // Shared-memory layout per warp: [input X][2D ugrad U / fp32 output scratch].
+  half *smem_in = &shared_mem[smem_warp_offset];
+  half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems];
+  float *smem_out = reinterpret_cast<float *>(smem_temp);
+
+  // Global memory pointers for the current sample
+  // Input
+  uint gmem_input_sample_offset = sample_id * sample_size;
+  const half *gmem_input = &input[gmem_input_sample_offset];
+
+  // Interaction Gradient
+  const uint &gmem_grad_sample_offset = gmem_input_sample_offset;
+  half *gmem_grad = &grad[gmem_grad_sample_offset];
+
+  // Bottom MLP gradient
+  half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols];
+
+  // Upstream gradient vector
+  uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding);
+  const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset];
+
+  // Upstream gradient vector for interactions
+  const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols];
+
+  // upstream grad -> shared memory (place in input section temporarily)
+#pragma unroll
+  for (uint idx = lane_id; idx < interaction_ugrad_size; idx += WARP_SIZE) {
+    smem_in[idx] = gmem_ugrad_interactions[idx];
+  }
+  __syncwarp();
+  // Form the 2D ugrad matrix: mirror the flat triangle into both (lane,row)
+  // and (row,lane) so U is symmetric; diagonal and padding entries are zero.
+  if (lane_id < num_rows_after_padding) {
+    uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1);
+    uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride;
+    for (uint row = 0; row < num_rows; row++) {
+      half ugrad_val = __float2half(0.0f);
+      if (row < lane_id && lane_id < num_rows) {
+        ugrad_val = smem_in[ugrad_flat_index + row];
+        smem_temp[ugrad_offset_1 + row] = ugrad_val;
+      }
+      if (row <= lane_id && lane_id < num_rows_after_padding) {
+        smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val;
+      }
+    }
+    for (uint row = num_rows; row < num_rows_after_padding; row++) {
+      smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f);
+    }
+  }
+  __syncwarp();
+
+  // Input -> Shared Memory
+
+  for (uint row = 0; row < num_rows; row++) {
+    half *smem_row_ptr = &smem_in[row * input_stride];
+    const half *gmem_row_ptr = &gmem_input[row * num_cols];
+    for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+      smem_row_ptr[idx] = gmem_row_ptr[idx];
+    }
+    uint idx = lane_id + num_cols;
+    if (idx < num_cols_after_padding) {
+      smem_row_ptr[idx] = __float2half(0);
+    }
+  }
+
+  // Zero-fill the padding rows.
+#pragma unroll 2
+  for (uint row = num_rows; row < num_rows_after_padding; row++) {
+    half *smem_row_ptr = &smem_in[row * input_stride];
+    for (uint idx = lane_id; idx < num_cols_after_padding; idx += WARP_SIZE) {
+      smem_row_ptr[idx] = __float2half(0);
+    }
+  }
+  __syncwarp();
+
+  // Load all tiles of the symmetric ugrad matrix U once as matrix A.
+  wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[ROW_TILES_PER_STEP]
+                                                                                       [ROW_TILES_PER_STEP];
+  for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+    for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+      const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2);
+      wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride);
+    }
+  }
+
+  // grad = U * X, one 16-column step at a time; spill each step via smem_out.
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[ROW_TILES_PER_STEP];
+  wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> b[ROW_TILES_PER_STEP];
+  for (int col_step = 0; col_step < num_col_steps; col_step++) {
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2);
+      wmma::fill_fragment(acc[i], 0);
+      wmma::load_matrix_sync(b[i], tile_ptr, input_stride);
+    }
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+        wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]);
+      }
+    }
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM;
+      wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major);
+    }
+    __syncwarp();
+    uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id;
+    if (gmem_grad_col < num_cols) {
+      for (uint i = 0; i < num_rows; i++) {
+        gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]);
+      }
+    }
+  }
+
+  // Bottom-MLP gradient: copy the pass-through part of the upstream gradient.
+  for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+    gmem_mlp_grad[idx] = gmem_ugrad[idx];
+  }
+}
+
+// Backward interaction kernel, vectorized path (sizes aligned for 128-bit
+// float4 / 64-bit float2 accesses).  Same algorithm as the NonAligned variant:
+// expand the flat triangular upstream gradient into a symmetric matrix U,
+// stage input X, compute grad = U * X with wmma, and copy the first num_cols
+// of the upstream gradient as the bottom-MLP gradient.
+// NOTE(review): the single `lane_id < (num_cols >> 2)` float2 copies imply
+// num_cols <= 128 — confirm with the caller.
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint ROW_TILES_PER_STEP,
+          uint COL_TILES_PER_STEP,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractBwdKernel(const __half *__restrict input,
+                                                                              const __half *__restrict upstream_grad,
+                                                                              half __restrict *grad,
+                                                                              half __restrict *bottom_mlp_grad,
+                                                                              uint batch_size,
+                                                                              uint num_rows,
+                                                                              uint num_cols,
+                                                                              uint num_rows_after_padding,
+                                                                              uint num_cols_after_padding,
+                                                                              uint sample_size,
+                                                                              uint interaction_ugrad_size,
+                                                                              uint interaction_ugrad_size_with_padding,
+                                                                              uint interaction_ugrad_2D_size_elems,
+                                                                              uint interaction_ugrad_2D_stride,
+                                                                              uint input_size_elems,
+                                                                              uint input_stride,
+                                                                              uint num_row_steps,
+                                                                              uint num_col_steps,
+                                                                              uint row_tiles_per_step,
+                                                                              uint shared_mem_per_warp_size_byte) {
+  extern __shared__ half shared_mem[];
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  uint lane_id = threadIdx.x & (WARP_SIZE - 1);
+  // ">> 1" to convert to half pointer
+  uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1);
+
+  // Shared-memory layout per warp: [input X][2D ugrad U / fp32 output scratch].
+  half *smem_in = &shared_mem[smem_warp_offset];
+  half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems];
+  float *smem_out = reinterpret_cast<float *>(smem_temp);
+
+  // Global memory pointers for the current sample
+  // Input
+  uint gmem_input_sample_offset = sample_id * sample_size;
+  const half *gmem_input = &input[gmem_input_sample_offset];
+
+  // Interaction Gradient
+  const uint &gmem_grad_sample_offset = gmem_input_sample_offset;
+  half *gmem_grad = &grad[gmem_grad_sample_offset];
+
+  // Bottom MLP gradient
+  half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols];
+
+  // Upstream gradient vector
+  uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding);
+  const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset];
+
+  // Upstream gradient vector for interactions
+  const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols];
+
+  // upstream grad -> shared memory (place in input section temporarily):
+  // bulk float4 copies (8 halves each), then a scalar tail loop.
+#pragma unroll
+  for (uint idx = lane_id; idx < (interaction_ugrad_size >> 3); idx += WARP_SIZE) {
+    ((float4 *)smem_in)[idx] = ((float4 *)gmem_ugrad_interactions)[idx];
+  }
+  uint offset = (interaction_ugrad_size >> 3) << 3;
+  for (uint idx = lane_id + offset; idx < interaction_ugrad_size; idx += WARP_SIZE) {
+    smem_in[idx] = gmem_ugrad_interactions[idx];
+  }
+  __syncwarp();
+  // Form the 2D ugrad matrix: mirror the flat triangle into both (lane,row)
+  // and (row,lane) so U is symmetric; diagonal and padding entries are zero.
+  if (lane_id < num_rows_after_padding) {
+    uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1);
+    uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride;
+    for (uint row = 0; row < num_rows; row++) {
+      half ugrad_val = __float2half(0.0f);
+      if (row < lane_id && lane_id < num_rows) {
+        ugrad_val = smem_in[ugrad_flat_index + row];
+        smem_temp[ugrad_offset_1 + row] = ugrad_val;
+      }
+      if (row <= lane_id && lane_id < num_rows_after_padding) {
+        smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val;
+      }
+    }
+    for (uint row = num_rows; row < num_rows_after_padding; row++) {
+      smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f);
+    }
+  }
+  __syncwarp();
+
+  // Input -> Shared Memory
+
+  if (lane_id < (num_cols >> 2)) {
+    for (uint row = 0; row < num_rows; row++) {
+      half *smem_row_ptr = &smem_in[row * input_stride];
+      const half *gmem_row_ptr = &gmem_input[row * num_cols];
+      ((float2 *)smem_row_ptr)[lane_id] = ((float2 *)gmem_row_ptr)[lane_id];
+    }
+  }
+
+  // Zero-pad extra columns of the real rows.
+  uint idx = lane_id + num_cols;
+  if (idx < num_cols_after_padding) {
+    for (uint row = 0; row < num_rows; row++) {
+      half *smem_row_ptr = &smem_in[row * input_stride];
+      smem_row_ptr[idx] = __float2half(0);
+    }
+  }
+
+  // Zero-fill the padding rows, 4 halves per lane.
+  half4 zeros;
+  zeros.vals[0].x = __float2half(0);
+  zeros.vals[0].y = __float2half(0);
+  zeros.vals[1].x = __float2half(0);
+  zeros.vals[1].y = __float2half(0);
+  if (lane_id < (num_cols_after_padding >> 2)) {
+#pragma unroll 2
+    for (uint row = num_rows; row < num_rows_after_padding; row++) {
+      half *smem_row_ptr = &smem_in[row * input_stride];
+      ((half4 *)smem_row_ptr)[lane_id] = zeros;
+    }
+  }
+  __syncwarp();
+
+  // Load all tiles of the symmetric ugrad matrix U once as matrix A.
+  wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[ROW_TILES_PER_STEP]
+                                                                                       [ROW_TILES_PER_STEP];
+  for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+    for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+      const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2);
+      wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride);
+    }
+  }
+
+  // grad = U * X, one 16-column step at a time; spill each step via smem_out.
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[ROW_TILES_PER_STEP];
+  wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> b[ROW_TILES_PER_STEP];
+  for (int col_step = 0; col_step < num_col_steps; col_step++) {
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2);
+      wmma::fill_fragment(acc[i], 0);
+      wmma::load_matrix_sync(b[i], tile_ptr, input_stride);
+    }
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+        wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]);
+      }
+    }
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM;
+      wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major);
+    }
+    __syncwarp();
+    uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id;
+    if (gmem_grad_col < num_cols) {
+      for (uint i = 0; i < num_rows; i++) {
+        gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]);
+      }
+    }
+  }
+  // Bottom-MLP gradient: copy the pass-through part of the upstream gradient.
+  if (lane_id < (num_cols >> 2)) {
+    ((float2 *)gmem_mlp_grad)[lane_id] = ((float2 *)gmem_ugrad)[lane_id];
+  }
+}
+
+// Launch helper for the half-precision (tensor-core) forward dot interaction.
+// Per sample, `input` holds num_rows feature rows of num_cols halves and
+// `output` receives num_cols halves of bottom-MLP output followed by the
+// triangular pairwise dot products and kPaddingSize zero element(s).
+// NOTE(review): `bottom_mlp_output` is not referenced here; the kernels appear
+// to read the MLP row from `input` itself — confirm against the kernel code.
+inline void dotBasedInteractFwd(
+    const void *input, const void *bottom_mlp_output, void *output, uint batch_size, uint num_rows, uint num_cols) {
+  const uint kWarpSize = 32;
+  const uint kWarpSizeLog2 = Log2<kWarpSize>::value;
+  const uint kTileDim = 16;
+  const uint kTileDimLog2 = Log2<kTileDim>::value;
+  const uint warps_per_threadblock = 4;
+  const uint threadblock_size = warps_per_threadblock * 32;
+  const uint kPaddingSize = 1;
+  const uint kRowTilesPerStep = 2;
+  const uint kColTilesPerStep = 1;
+
+  // num tiles (ceil-divide by the 16x16 tile dimension)
+  uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2;
+  uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2;
+
+  // number of rows and columns after padding
+  // NOTE(review): rows are padded to the fixed value 32 (kTileDim << 1), which
+  // assumes num_rows <= 32 — verify against callers.
+  uint num_rows_after_padding = kTileDim << 1;
+  uint num_cols_after_padding = num_col_tiles << kTileDimLog2;
+
+  uint num_row_steps = num_row_tiles / kRowTilesPerStep;
+  uint num_col_steps = num_col_tiles / kColTilesPerStep;
+
+  // Shared-memory layout: rows of K_BLOCKS x 16 halves plus SKEW_HALF halves
+  // of skew; the accumulator area is stored in FP32 (hence the factor of 2).
+  const uint K_BLOCKS = 8;
+  const uint M_BLOCKS = 2;
+  const uint SKEW_HALF = ((K_BLOCKS % 2) == 0) ? 8 : 0;
+  const uint SMEM_STRIDE = (K_BLOCKS * 16 + SKEW_HALF);
+  // multiple of 2 to guarantee 256-bit alignment for start of the row, at least 16 to safeload a tile
+  const uint smem_rows_per_warp = M_BLOCKS << 4;
+  const uint smem_elems_per_warp_mat = smem_rows_per_warp * SMEM_STRIDE;
+  const uint SKEW_HALF_ACC = ((M_BLOCKS % 2) == 0) ? 8 : 0;
+  const uint SMEM_STRIDE_ACC = (M_BLOCKS * 16 + SKEW_HALF_ACC);
+  const uint smem_elems_per_warp_acc = M_BLOCKS * 16 * SMEM_STRIDE_ACC * 2;  // output in FP32
+  // The matrix staging area and the accumulator area reuse the same shared
+  // memory, so size the per-warp allocation for the larger of the two.
+  const uint smem_elems_per_warp =
+      (smem_elems_per_warp_mat > smem_elems_per_warp_acc) ? smem_elems_per_warp_mat : smem_elems_per_warp_acc;
+  uint output_size = num_cols + ((num_rows * (num_rows - 1)) >> 1) + kPaddingSize;
+
+  // The vectorized kernel needs both the input row length and the output row
+  // length to be multiples of 8 halves (128-bit loads/stores).
+  bool float4_predicate = !((num_cols & 7) || (output_size & 7));
+
+  // Grid: one warp per sample, warps_per_threadblock samples per block.
+  if (float4_predicate) {
+    dotBasedInteractFwdKernel<warps_per_threadblock,
+                              threadblock_size,
+                              M_BLOCKS,
+                              K_BLOCKS,
+                              SMEM_STRIDE,
+                              SMEM_STRIDE_ACC,
+                              kWarpSize,
+                              kWarpSizeLog2,
+                              kTileDim,
+                              kTileDimLog2>
+        <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock,
+           threadblock_size,
+           warps_per_threadblock * smem_elems_per_warp * sizeof(__half)>>>((const __half *)input,
+                                                                           (half *)output,
+                                                                           batch_size,
+                                                                           num_rows,
+                                                                           num_cols,
+                                                                           num_rows_after_padding,
+                                                                           num_cols_after_padding,
+                                                                           smem_elems_per_warp,
+                                                                           smem_rows_per_warp,
+                                                                           output_size,
+                                                                           num_row_steps,
+                                                                           num_col_steps);
+  } else {
+    dotBasedInteractFwdKernelNonAligned<warps_per_threadblock,
+                                        threadblock_size,
+                                        M_BLOCKS,
+                                        K_BLOCKS,
+                                        SMEM_STRIDE,
+                                        SMEM_STRIDE_ACC,
+                                        kWarpSize,
+                                        kWarpSizeLog2,
+                                        kTileDim,
+                                        kTileDimLog2>
+        <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock,
+           threadblock_size,
+           warps_per_threadblock * smem_elems_per_warp * sizeof(__half)>>>((const __half *)input,
+                                                                           (half *)output,
+                                                                           batch_size,
+                                                                           num_rows,
+                                                                           num_cols,
+                                                                           num_rows_after_padding,
+                                                                           num_cols_after_padding,
+                                                                           smem_elems_per_warp,
+                                                                           smem_rows_per_warp,
+                                                                           output_size,
+                                                                           num_row_steps,
+                                                                           num_col_steps);
+  }
+}
+
+// Launch helper for the half-precision backward dot interaction.
+// `upstream_grad` holds, per sample, the bottom-MLP part (first num_cols
+// halves) followed by the flattened triangular interaction part; the call
+// produces the gradient w.r.t. the interaction input (`grad`) and the
+// gradient w.r.t. the bottom-MLP output (`bottom_mlp_grad`).
+inline void dotBasedInteractBwd(void *input,
+                                void *upstream_grad,
+                                void *grad,
+                                void *bottom_mlp_grad,
+                                uint batch_size,
+                                uint num_rows,
+                                uint num_cols) {
+  const uint kWarpSize = 32;
+  const uint kWarpSizeLog2 = Log2<kWarpSize>::value;
+  const uint kTileDim = 16;
+  const uint kTileDimLog2 = Log2<kTileDim>::value;
+  const uint mem_skew_size = 8;
+  const uint kPaddingSize = 1;
+  const uint kWarpsPerBlock = 4;
+  const uint kWarpsPerBlockLog2 = Log2<kWarpsPerBlock>::value;
+  const uint kNumThreads = kWarpsPerBlock * kWarpSize;
+  const uint kRowTilesPerStep = 2;
+  const uint kColTilesPerStep = 1;
+
+  // Use a single row tile per step when all rows fit into one 16-row tile.
+  uint row_tiles_per_step = num_rows > kTileDim ? kRowTilesPerStep : 1;
+
+  // num tiles (ceil-divide by the 16x16 tile dimension)
+  uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2;
+  uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2;
+
+  // number of rows and columns after padding
+  // NOTE(review): rows are padded to the fixed value 32 (kTileDim << 1), which
+  // assumes num_rows <= 32 — verify against callers.
+  uint num_rows_after_padding = kTileDim << 1;
+  uint num_cols_after_padding = num_col_tiles << kTileDimLog2;
+
+  // 2D ugrad size and stride; the stride carries mem_skew_size extra halves
+  // per row (presumably to avoid shared-memory bank conflicts — the kernel
+  // defines the exact layout)
+  uint interaction_ugrad_2D_stride = num_rows_after_padding + mem_skew_size;
+  uint interaction_ugrad_2D_size_elems = num_rows_after_padding * interaction_ugrad_2D_stride;
+  uint interaction_ugrad_2D_size_bytes = interaction_ugrad_2D_size_elems * sizeof(half);
+
+  // 1D ugrad size: strict lower triangle of the num_rows x num_rows matrix
+  uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1;
+  uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize;
+
+  // in_out place size and stride (padded sample staged in shared memory)
+  uint input_stride = num_cols_after_padding + mem_skew_size;
+  uint input_size_elems = num_rows_after_padding * input_stride;
+  uint input_size_bytes = input_size_elems * sizeof(half);
+
+  // sample size (unpadded, in elements)
+  uint sample_size = num_rows * num_cols;
+
+  // output size: one step's worth of FP32 result tiles
+  uint output_size_elems = kTileDim * kTileDim * kRowTilesPerStep * kColTilesPerStep;
+  uint output_size_bytes = output_size_elems * sizeof(float);
+
+  // staging area size: the same shared memory is reused for the 2D ugrad
+  // matrix and for the FP32 output tiles, so take the larger of the two
+  uint staging_area_size_bytes =
+      output_size_bytes > interaction_ugrad_2D_size_bytes ? output_size_bytes : interaction_ugrad_2D_size_bytes;
+
+  // Shared memory size (per warp: staged sample + staging area)
+  uint shared_mem_per_warp_size_byte = input_size_bytes + staging_area_size_bytes;
+  uint shared_mem_size_bytes = kWarpsPerBlock * shared_mem_per_warp_size_byte;
+
+  // Grid: one warp per sample, kWarpsPerBlock samples per block.
+  uint num_blocks = (batch_size + kWarpsPerBlock - 1) >> kWarpsPerBlockLog2;
+  uint num_row_steps = num_row_tiles / row_tiles_per_step;
+  uint num_col_steps = num_col_tiles / kColTilesPerStep;
+
+  // The vectorized kernel needs both row lengths to be multiples of 8 halves
+  // (128-bit loads/stores).
+  bool float4_predicate = !((interaction_ugrad_size_with_padding & 7) || (num_cols & 7));
+  if (float4_predicate) {
+    dotBasedInteractBwdKernel<kWarpsPerBlock,
+                              kNumThreads,
+                              kRowTilesPerStep,
+                              kColTilesPerStep,
+                              kWarpSize,
+                              kWarpSizeLog2,
+                              kTileDim,
+                              kTileDimLog2>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const half *)input,
+                                                             (const half *)upstream_grad,
+                                                             (half *)grad,
+                                                             (half *)bottom_mlp_grad,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             num_rows_after_padding,
+                                                             num_cols_after_padding,
+                                                             sample_size,
+                                                             interaction_ugrad_size,
+                                                             interaction_ugrad_size_with_padding,
+                                                             interaction_ugrad_2D_size_elems,
+                                                             interaction_ugrad_2D_stride,
+                                                             input_size_elems,
+                                                             input_stride,
+                                                             num_row_steps,
+                                                             num_col_steps,
+                                                             row_tiles_per_step,
+                                                             shared_mem_per_warp_size_byte);
+  } else {
+    dotBasedInteractBwdKernelNonAligned<kWarpsPerBlock,
+                                        kNumThreads,
+                                        kRowTilesPerStep,
+                                        kColTilesPerStep,
+                                        kWarpSize,
+                                        kWarpSizeLog2,
+                                        kTileDim,
+                                        kTileDimLog2>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const half *)input,
+                                                             (const half *)upstream_grad,
+                                                             (half *)grad,
+                                                             (half *)bottom_mlp_grad,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             num_rows_after_padding,
+                                                             num_cols_after_padding,
+                                                             sample_size,
+                                                             interaction_ugrad_size,
+                                                             interaction_ugrad_size_with_padding,
+                                                             interaction_ugrad_2D_size_elems,
+                                                             interaction_ugrad_2D_stride,
+                                                             input_size_elems,
+                                                             input_stride,
+                                                             num_row_steps,
+                                                             num_col_steps,
+                                                             row_tiles_per_step,
+                                                             shared_mem_per_warp_size_byte);
+  }
+}

+ 361 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_fp32.cu

@@ -0,0 +1,361 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <device_launch_parameters.h>
+#include <mma.h>
+#include <cuda_fp16.hpp>
+
+#include <math.h>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "shared_utils.cuh"
+
+// Scalar (non-vectorized) FP32 forward kernel: fallback for sizes that are
+// not a multiple of 4 floats. One thread block processes one sample, staged
+// entirely in dynamic shared memory.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__
+    void dotBasedInteractF32FwdKernelNonAligned(const float *__restrict input,
+                                                float *__restrict output,
+                                                uint batch_size,
+                                                uint num_rows,
+                                                uint num_cols,
+                                                uint input_size,
+                                                uint output_size,
+                                                uint interaction_output_size) {
+  extern __shared__ float smem_f32_fwd[];
+  float *smem_in = &smem_f32_fwd[0];
+
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  // Output row layout: bottom-MLP copy, then the triangular dot products.
+  uint output_batch_offset = blockIdx.x * output_size;
+  float *gmem_out_bottom_mlp = &output[output_batch_offset];
+  float *gmem_out_interaction = &output[output_batch_offset + num_cols];
+
+  // Load the input - one sample per block
+  for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) {
+    smem_in[idx] = gmem_in[idx];
+  }
+  __syncthreads();
+
+  // Copy bottom MLP output to output
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    gmem_out_bottom_mlp[idx] = smem_in[idx];
+  }
+
+  // Each flat idx addresses one strictly-lower-triangular pair; the while
+  // loop decodes it into (target_row, target_col) with target_col < target_row
+  // (row r of the strict lower triangle holds r elements).
+  for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) {
+    uint elems_per_row = 1;
+    uint index = idx;
+    while (index >= elems_per_row) {
+      index -= elems_per_row;
+      elems_per_row++;
+    }
+    uint target_row = elems_per_row;
+    uint target_col = index;
+
+    // Dot product of the two selected feature rows.
+    float sum = 0;
+    for (uint i = 0; i < num_cols; i++) {
+      float tmp1 = smem_in[target_row * num_cols + i];
+      float tmp2 = smem_in[target_col * num_cols + i];
+      sum = fmaf(tmp1, tmp2, sum);
+    }
+
+    gmem_out_interaction[idx] = sum;
+  }
+
+  // Zero the padding element; every thread stores the same value, so the
+  // write race is benign.
+  gmem_out_interaction[interaction_output_size] = 0;
+}
+
+// Vectorized FP32 forward kernel: same algorithm as the NonAligned variant
+// but with float4 (128-bit) loads/stores; the launcher only selects it when
+// num_cols and output_size are multiples of 4 floats. One thread block
+// processes one sample.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32FwdKernel(const float *__restrict input,
+                                                                                 float *__restrict output,
+                                                                                 uint batch_size,
+                                                                                 uint num_rows,
+                                                                                 uint num_cols,
+                                                                                 uint input_size,
+                                                                                 uint output_size,
+                                                                                 uint interaction_output_size) {
+  extern __shared__ float smem_f32_fwd[];
+  float *smem_in = &smem_f32_fwd[0];
+
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  // Output row layout: bottom-MLP copy, then the triangular dot products.
+  uint output_batch_offset = blockIdx.x * output_size;
+  float *gmem_out_bottom_mlp = &output[output_batch_offset];
+  float *gmem_out_interaction = &output[output_batch_offset + num_cols];
+
+  // Load the input - one sample per block, 4 floats at a time
+  uint input_size_float4 = input_size >> 2;
+  for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) {
+    ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx];
+  }
+  __syncthreads();
+
+  // Copy bottom MLP output to output
+  uint btm_mlp_out_size_float4 = num_cols >> 2;
+  for (uint idx = threadIdx.x; idx < btm_mlp_out_size_float4; idx += blockDim.x) {
+    ((float4 *)gmem_out_bottom_mlp)[idx] = ((float4 *)smem_in)[idx];
+  }
+
+  // Decode the flat triangular index idx into (target_row, target_col) with
+  // target_col < target_row (row r of the strict lower triangle holds r
+  // elements).
+  for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) {
+    uint elems_per_row = 1;
+    uint index = idx;
+    while (index >= elems_per_row) {
+      index -= elems_per_row;
+      elems_per_row++;
+    }
+    uint target_row = elems_per_row;
+    uint target_col = index;
+
+    // Dot product of the two rows, accumulated in four independent lanes and
+    // reduced at the end.
+    float4 sum;
+    sum.x = 0;
+    sum.y = 0;
+    sum.z = 0;
+    sum.w = 0;
+    uint num_cols_float4 = num_cols >> 2;
+    for (uint i = 0; i < num_cols_float4; i++) {
+      float4 tmp1 = ((float4 *)smem_in)[target_row * num_cols_float4 + i];
+      float4 tmp2 = ((float4 *)smem_in)[target_col * num_cols_float4 + i];
+      sum.x = fmaf(tmp1.x, tmp2.x, sum.x);
+      sum.y = fmaf(tmp1.y, tmp2.y, sum.y);
+      sum.z = fmaf(tmp1.z, tmp2.z, sum.z);
+      sum.w = fmaf(tmp1.w, tmp2.w, sum.w);
+    }
+
+    gmem_out_interaction[idx] = sum.x + sum.y + sum.z + sum.w;
+  }
+
+  // Zero the padding element; every thread stores the same value, so the
+  // write race is benign.
+  gmem_out_interaction[interaction_output_size] = 0;
+}
+
+// Launch helper for the plain-FP32 forward dot interaction (one thread block
+// per sample).
+//
+// input  - per sample, num_rows feature rows of num_cols floats.
+// output - per sample: the bottom-MLP row (num_cols floats), the triangular
+//          pairwise dot products, and kPaddingSize trailing zero element(s).
+// NOTE(review): `bottom_mlp_output` is not referenced here; the kernels read
+// the MLP row from `input` itself — confirm against the kernel code.
+//
+// Fix: `output` is written through by the kernels, so it must not be declared
+// `const void *` and have its constness cast away ((float *)output); writing
+// through such a cast is undefined behavior for a genuinely const object.
+// The writable-pointer signature also matches the sibling launchers
+// (dotBasedInteractFwd, dotBasedInteractF32Bwd).
+inline void dotBasedInteractF32Fwd(const void *input,
+                                   const void *bottom_mlp_output,
+                                   void *output,
+                                   uint batch_size,
+                                   uint num_rows,
+                                   uint num_cols) {
+  const uint kPaddingSize = 1;   // one trailing zero to pad the output row
+  const uint kNumThreads = 128;
+  uint num_blocks = batch_size;  // one sample per thread block
+
+  // Output row: bottom-MLP copy + strict-lower-triangle dot products + pad.
+  uint interaction_output_size = (num_rows * (num_rows - 1)) >> 1;
+  uint output_size = num_cols + interaction_output_size + kPaddingSize;
+
+  // Input: one sample is num_rows rows of num_cols floats.
+  uint input_size = num_rows * num_cols;
+
+  // The whole sample is staged in dynamic shared memory.
+  uint shared_mem_size_elems = input_size;
+  uint shared_mem_size_bytes = shared_mem_size_elems << 2;  // F32 Kernel
+
+  // The vectorized kernel needs both row lengths to be multiples of 4 floats
+  // (128-bit loads/stores).
+  bool float4_predicate = !((num_cols & 3) || (output_size & 3));
+
+  if (float4_predicate) {
+    dotBasedInteractF32FwdKernel<kNumThreads>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const float *)input,
+                                                             (float *)output,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             input_size,
+                                                             output_size,
+                                                             interaction_output_size);
+  } else {
+    dotBasedInteractF32FwdKernelNonAligned<kNumThreads>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const float *)input,
+                                                             (float *)output,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             input_size,
+                                                             output_size,
+                                                             interaction_output_size);
+  }
+}
+
+// Scalar FP32 backward kernel (fallback for unaligned sizes). One thread
+// block processes one sample: the sample and the flattened triangular
+// upstream gradient are staged in shared memory, the bottom-MLP part of the
+// upstream gradient is passed straight through, and each input gradient
+// element is the ugrad-weighted sum over all other feature rows.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__
+    void dotBasedInteractF32BwdKernelNonAligned(const float *__restrict input,
+                                                const float *__restrict upstream_grad,
+                                                float *__restrict grad,
+                                                float *__restrict bottom_mlp_grad,
+                                                uint batch_size,
+                                                uint num_rows,
+                                                uint num_cols,
+                                                uint input_size,
+                                                uint ugrad_size,
+                                                uint interaction_ugrad_size) {
+  extern __shared__ float smem_f32_bwd[];
+  float *smem_in = &smem_f32_bwd[0];
+  float *smem_interaction_ugrad = &smem_f32_bwd[input_size];
+
+  // Input
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  // Gradient (same per-sample layout and stride as the input)
+  const uint &grad_batch_offset = input_batch_offset;
+  float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols];
+  float *gmem_interaction_grad = &grad[grad_batch_offset];
+
+  // Upstream Gradient: bottom-MLP part first, then the flattened triangle
+  uint upstream_grad_batch_offset = blockIdx.x * ugrad_size;
+  const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset];
+  const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols];
+
+  // input -> shared memory
+  for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) {
+    smem_in[idx] = gmem_in[idx];
+  }
+
+  // Interaction Upstream Grad -> Shared Memory
+  for (uint idx = threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) {
+    smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx];
+  }
+  __syncthreads();
+
+  // Copy the upstream gradient w.r.t. the MLP to its corresponding location.
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    gmem_mlp_grad[idx] = gmem_mlp_ugrad[idx];
+  }
+
+  // The ugrad of pair (a, b) with a > b lives at flat index a*(a-1)/2 + b.
+  // Each thread owns one column and walks every row of the gradient.
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    size_t grad_idx = idx;
+    for (uint row_idx = 0; row_idx < num_rows; row_idx++) {
+      float sum = 0;
+      size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1;
+      for (int k = 0; k < row_idx; k++) {
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum);
+      }
+      for (int k = row_idx + 1; k < num_rows; k++) {
+        upstream_grad_offset = (k * (k - 1)) >> 1;  // TODO: this can become a sum
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum);
+      }
+      gmem_interaction_grad[grad_idx] = sum;
+      grad_idx += num_cols;
+    }
+  }
+}
+
+// Vectorized FP32 backward kernel: same algorithm as the NonAligned variant
+// but with float4 (128-bit) staging copies; the launcher only selects it when
+// num_cols and the padded ugrad size are multiples of 4 floats. One thread
+// block processes one sample.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32BwdKernel(const float *__restrict input,
+                                                                                 const float *__restrict upstream_grad,
+                                                                                 float *__restrict grad,
+                                                                                 float *__restrict bottom_mlp_grad,
+                                                                                 uint batch_size,
+                                                                                 uint num_rows,
+                                                                                 uint num_cols,
+                                                                                 uint input_size,
+                                                                                 uint ugrad_size,
+                                                                                 uint interaction_ugrad_size) {
+  extern __shared__ float smem_f32_bwd[];
+  float *smem_in = &smem_f32_bwd[0];
+  float *smem_interaction_ugrad = &smem_f32_bwd[input_size];
+
+  // Input
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  // Gradient (same per-sample layout and stride as the input)
+  const uint &grad_batch_offset = input_batch_offset;
+  float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols];
+  float *gmem_interaction_grad = &grad[grad_batch_offset];
+
+  // Upstream Gradient: bottom-MLP part first, then the flattened triangle
+  uint upstream_grad_batch_offset = blockIdx.x * ugrad_size;
+  const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset];
+  const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols];
+
+  // input -> shared memory, 4 floats at a time
+  uint input_size_float4 = input_size >> 2;
+  for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) {
+    ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx];
+  }
+
+  // Interaction Upstream Grad -> Shared Memory (vectorized part)
+  uint upstream_grad_size_float4 = interaction_ugrad_size >> 2;
+  for (uint idx = threadIdx.x; idx < upstream_grad_size_float4; idx += blockDim.x) {
+    ((float4 *)smem_interaction_ugrad)[idx] = ((float4 *)gmem_interaction_ugrad)[idx];
+  }
+
+  // Tail: scalar copy of the remainder that did not fill a full float4.
+  uint vectorized_load_offset = (upstream_grad_size_float4 << 2);
+  for (uint idx = vectorized_load_offset + threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) {
+    smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx];
+  }
+  __syncthreads();
+
+  // Copy the upstream gradient w.r.t. the MLP to its corresponding location.
+  for (uint idx = threadIdx.x; idx < (num_cols >> 2); idx += blockDim.x) {
+    ((float4 *)gmem_mlp_grad)[idx] = ((float4 *)gmem_mlp_ugrad)[idx];
+  }
+
+  // The ugrad of pair (a, b) with a > b lives at flat index a*(a-1)/2 + b.
+  // Each thread owns one column and walks every row of the gradient.
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    size_t grad_idx = idx;
+    for (uint row_idx = 0; row_idx < num_rows; row_idx++) {
+      float sum = 0;
+      size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1;
+      for (int k = 0; k < row_idx; k++) {
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum);
+      }
+      for (int k = row_idx + 1; k < num_rows; k++) {
+        upstream_grad_offset = (k * (k - 1)) >> 1;  // TODO: this can become a sum
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum);
+      }
+      gmem_interaction_grad[grad_idx] = sum;
+      grad_idx += num_cols;
+    }
+  }
+}
+
+// Launch helper for the plain-FP32 backward dot interaction.
+// Dispatches one thread block per sample to either the float4-vectorized
+// kernel or the scalar fallback, depending on alignment of the row lengths.
+inline void dotBasedInteractF32Bwd(const void *input,
+                                   const void *upstream_grad,
+                                   void *grad,
+                                   void *bottom_mlp_grad,
+                                   uint batch_size,
+                                   uint num_rows,
+                                   uint num_cols) {
+  const uint kPadElems = 1;          // trailing zero slot in the ugrad row
+  const uint kThreadsPerBlock = 128;
+
+  const uint grid_size = batch_size;       // one sample per block
+  const uint sample_elems = num_rows * num_cols;
+
+  // Flattened upstream gradient of the interaction part: the strict lower
+  // triangle of the num_rows x num_rows matrix, plus padding.
+  const uint tri_size = num_rows * (num_rows - 1) >> 1;
+  const uint tri_size_padded = tri_size + kPadElems;
+  const uint ugrad_row_size = num_cols + tri_size_padded;
+
+  // Dynamic shared memory: staged sample followed by the staged interaction
+  // upstream gradient, all in FP32.
+  const uint smem_bytes = (sample_elems + tri_size) << 2;
+
+  // The vectorized kernel needs both row lengths to be multiples of 4 floats
+  // (128-bit accesses).
+  const bool use_float4 = ((num_cols & 3) == 0) && ((tri_size_padded & 3) == 0);
+
+  if (use_float4) {
+    dotBasedInteractF32BwdKernel<kThreadsPerBlock>
+        <<<grid_size, kThreadsPerBlock, smem_bytes>>>((const float *)input,
+                                                      (const float *)upstream_grad,
+                                                      (float *)grad,
+                                                      (float *)bottom_mlp_grad,
+                                                      batch_size,
+                                                      num_rows,
+                                                      num_cols,
+                                                      sample_elems,
+                                                      ugrad_row_size,
+                                                      tri_size);
+  } else {
+    dotBasedInteractF32BwdKernelNonAligned<kThreadsPerBlock>
+        <<<grid_size, kThreadsPerBlock, smem_bytes>>>((const float *)input,
+                                                      (const float *)upstream_grad,
+                                                      (float *)grad,
+                                                      (float *)bottom_mlp_grad,
+                                                      batch_size,
+                                                      num_rows,
+                                                      num_cols,
+                                                      sample_elems,
+                                                      ugrad_row_size,
+                                                      tri_size);
+  }
+}

+ 74 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_pytorch_types.cu

@@ -0,0 +1,74 @@
+#include <torch/extension.h>
+#include <torch/types.h>
+#include <stdexcept>
+
+#include "dot_based_interact.cu"
+#include "dot_based_interact_fp32.cu"
+#include "dot_based_interact_tf32.cu"
+
+
+// PyTorch entry point for the forward dot interaction.
+//
+// input:             [batch, num_rows, num_cols] tensor (Half or Float).
+// bottom_mlp_output: bottom-MLP activations, forwarded to the CUDA launcher.
+// Returns a [batch, num_cols + num_rows*(num_rows-1)/2 + 1] tensor holding
+// the bottom-MLP row, the pairwise dot products, and one padding element.
+// Throws std::invalid_argument when the dtypes are not both Half/Float.
+//
+// Fix: bind the .contiguous() results to named locals so their storage is
+// guaranteed to outlive the asynchronous kernel launch (instead of relying
+// on full-expression temporaries), and drop the no-op .contiguous() on the
+// freshly allocated (already contiguous) output tensor.
+torch::Tensor dotBasedInteractFwdTorch(torch::Tensor input, torch::Tensor bottom_mlp_output) {
+  const uint kPaddingSize = 1;
+  auto size = input.sizes();
+  auto batch_size = size[0];
+  auto num_rows = size[1];
+  auto num_cols = size[2];
+  uint output_size = ((num_rows * (num_rows - 1)) >> 1) + num_cols + kPaddingSize;
+
+  int64_t outputShape[2] = {batch_size, output_size};
+  auto output = torch::empty(c10::IntArrayRef(outputShape), input.options());
+
+  // Keep the contiguous copies alive across the kernel launch.
+  auto input_contig = input.contiguous();
+  auto bottom_mlp_contig = bottom_mlp_output.contiguous();
+
+  if (input.scalar_type() == torch::ScalarType::Half && bottom_mlp_output.scalar_type() == torch::ScalarType::Half) {
+    dotBasedInteractFwd(input_contig.data_ptr<at::Half>(),
+                        bottom_mlp_contig.data_ptr<at::Half>(),
+                        output.data_ptr<at::Half>(),
+                        batch_size,
+                        num_rows,
+                        num_cols);
+  } else if (input.scalar_type() == torch::ScalarType::Float &&
+             bottom_mlp_output.scalar_type() == torch::ScalarType::Float) {
+    // On Ampere the FP32 path runs through the TF32 tensor-core kernels.
+    dotBasedInteractTF32Fwd(input_contig.data_ptr<float>(),
+                            bottom_mlp_contig.data_ptr<float>(),
+                            output.data_ptr<float>(),
+                            batch_size,
+                            num_rows,
+                            num_cols);
+  } else {
+    throw std::invalid_argument("Invalid input type.");
+  }
+  return output;
+}
+
+// PyTorch entry point for the backward dot interaction.
+//
+// input:        [batch, num_rows, num_cols] forward input (Half or Float).
+// upstreamGrad: upstream gradient (bottom-MLP part followed by the flattened
+//               triangular interaction part), forwarded to the CUDA launcher.
+// Returns {gradient w.r.t. input ([batch, num_rows, num_cols]),
+//          gradient w.r.t. the bottom-MLP output ([batch, num_cols])}.
+// Throws std::invalid_argument when the dtypes are not both Half/Float.
+//
+// Fix: bind the .contiguous() results to named locals so their storage is
+// guaranteed to outlive the asynchronous kernel launch (instead of relying
+// on full-expression temporaries), and drop the no-op .contiguous() on the
+// freshly allocated (already contiguous) output tensors.
+std::vector<torch::Tensor> dotBasedInteractBwdTorch(torch::Tensor input, torch::Tensor upstreamGrad) {
+  auto size = input.sizes();
+  auto batch_size = size[0];
+  auto num_rows = size[1];
+  auto num_cols = size[2];
+
+  auto outputGrad = torch::empty_like(input);
+  int64_t outputShape[2] = {batch_size, num_cols};
+  auto mlp_grad = torch::empty(c10::IntArrayRef(outputShape), input.options());
+
+  // Keep the contiguous copies alive across the kernel launch.
+  auto input_contig = input.contiguous();
+  auto upstream_grad_contig = upstreamGrad.contiguous();
+
+  if (input.scalar_type() == torch::ScalarType::Half && upstreamGrad.scalar_type() == torch::ScalarType::Half) {
+    dotBasedInteractBwd(input_contig.data_ptr<at::Half>(),
+                        upstream_grad_contig.data_ptr<at::Half>(),
+                        outputGrad.data_ptr<at::Half>(),
+                        mlp_grad.data_ptr<at::Half>(),
+                        batch_size,
+                        num_rows,
+                        num_cols);
+  } else if (input.scalar_type() == torch::ScalarType::Float &&
+             upstreamGrad.scalar_type() == torch::ScalarType::Float) {
+    // On Ampere the FP32 path runs through the TF32 tensor-core kernels.
+    dotBasedInteractTF32Bwd(input_contig.data_ptr<float>(),
+                            upstream_grad_contig.data_ptr<float>(),
+                            outputGrad.data_ptr<float>(),
+                            mlp_grad.data_ptr<float>(),
+                            batch_size,
+                            num_rows,
+                            num_cols);
+  } else {
+    throw std::invalid_argument("Invalid input type.");
+  }
+  return {outputGrad, mlp_grad};
+}

+ 833 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_tf32.cu

@@ -0,0 +1,833 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <device_launch_parameters.h>
+#include <mma.h>
+#include <cuda_fp16.hpp>
+
+#include <math.h>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "shared_utils.cuh"
+
+using namespace nvcuda;
+
+using namespace nvcuda;
+
// FP16 forward kernel, non-aligned fallback path: one warp computes the full
// interaction output for one sample.
//
// Per-sample steps:
//   1. Stage the [num_rows x num_cols] input matrix into this warp's shared
//      memory slice, zero-padding columns up to num_cols_after_padding and
//      rows up to num_rows_after_padding (WMMA needs complete tiles).
//   2. Copy the first num_cols values of the sample straight to the output
//      (presumably the bottom-MLP row — TODO confirm against the caller).
//   3. Accumulate X * X^T with tensor cores: each shared-memory tile is
//      loaded once as a row-major `a` fragment and once as a col-major `b`
//      fragment, so mma_sync(acc, a, b) multiplies the matrix by its own
//      transpose.
//   4. Write the strictly lower triangle of the result after the copied row,
//      then a single zero padding element.
//
// NOTE(review): despite living in the TF32 translation unit this kernel
// operates on __half data; it looks like a generic fallback kept for inputs
// that fail the vectorized-access predicate — TODO confirm it is launched.
template <uint WARPS_PER_BLOCK,
          uint THREADBLOCK_SIZE,
          uint M_BLOCKS,
          uint K_BLOCKS,
          uint SMEM_STRIDE,
          uint SMEM_STRIDE_ACC,
          uint WARP_SIZE,
          uint WARP_SIZE_LOG_2,
          uint TILE_DIM,
          uint TILE_DIM_LOG_2>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernelNonAligned_(const __half *__restrict input,
                                                                                         __half *__restrict output,
                                                                                         uint batch_size,
                                                                                         uint num_rows,
                                                                                         uint num_cols,
                                                                                         uint num_rows_after_padding,
                                                                                         uint num_cols_after_padding,
                                                                                         uint smem_elems_per_warp,
                                                                                         uint smem_rows_per_warp,
                                                                                         uint output_size,
                                                                                         uint num_row_steps,
                                                                                         uint num_col_steps) {
  // One warp per sample; a block carries WARPS_PER_BLOCK consecutive samples.
  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
  int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
  if (sample_id >= batch_size) {
    return;
  }
  int lane_id = threadIdx.x & (WARP_SIZE - 1);

  // Per-warp slice of the dynamic shared-memory allocation.
  extern __shared__ half shmem_dynamic_[];
  half *shmem = shmem_dynamic_ + (warp_id * smem_elems_per_warp);

  // Stage the sample's rows into shared memory.  Scalar copies: this is the
  // non-aligned path, so no vector loads from global memory.
  const half *sample_input = input + num_rows * num_cols * sample_id;
  for (uint i = 0; i < num_rows; ++i, sample_input += num_cols) {
    for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
      (shmem + i * SMEM_STRIDE)[idx] = sample_input[idx];
    }
  }

  // Zero-pad the extra columns of every real row.
  uint idx = lane_id + num_cols;
  if (idx < num_cols_after_padding) {
    for (int i = 0; i < num_rows; ++i) {
      (shmem + i * SMEM_STRIDE)[idx] = __float2half(0);
    }
  }

  // Zero-fill the padding rows; half4 stores are safe in shared memory.
  half4 zeros;
  zeros.vals[0].x = __float2half(0);
  zeros.vals[0].y = __float2half(0);
  zeros.vals[1].x = __float2half(0);
  zeros.vals[1].y = __float2half(0);
  if (lane_id < (num_cols_after_padding >> 2)) {
    for (int i = num_rows; i < num_rows_after_padding; i++) {
      ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros;
    }
  }
  __syncwarp();
  half *gmem_output = output + output_size * sample_id;

  // First output section: row 0 of the sample copied through verbatim.
  for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
    gmem_output[idx] = shmem[idx];
  }

  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[M_BLOCKS][M_BLOCKS];

  for (int i = 0; i < M_BLOCKS; i++) {
    for (int j = 0; j < M_BLOCKS; j++) {
      wmma::fill_fragment(acc[i][j], 0);
    }
  }

  // acc = X * X^T, accumulated over 16-wide column slabs.
  for (int k_step = 0; k_step < num_col_steps; k_step++) {
    wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[M_BLOCKS];
    wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::col_major> b[M_BLOCKS];
    for (int j = 0; j < M_BLOCKS; j++) {
      // The last row-tile is anchored at smem_rows_per_warp - 16 so it stays
      // in bounds; it intentionally overlaps the previous tile.
      int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16;
      const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16);
      // Same tile loaded row-major (a) and col-major (b): b acts as a^T.
      wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE);
      wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE);
    }
    for (int i = 0; i < M_BLOCKS; i++) {
      for (int j = 0; j < M_BLOCKS; j++) {
        wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]);
      }
    }
  }
  // Spill the float accumulators back into shared memory, reusing the input
  // staging area (safe: all tiles were consumed by the mma loop above).
  float *shmem_store = reinterpret_cast<float *>(shmem);
  for (int i = 0; i < M_BLOCKS; i++) {
    for (int j = 0; j < M_BLOCKS; j++) {
      float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16);
      wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major);
    }
  }

  // Second output section: strictly lower triangle of X * X^T, row by row.
  // srcLine jumps over the overlap introduced by the anchored last row-tile.
  // Row i holds i elements, written by lanes 0..i-1 (assumes
  // num_rows <= WARP_SIZE + 1 — TODO confirm).
  half *gmem_interact_output = gmem_output + num_cols;
  int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp;
  int srcLine = 0;
  for (int i = 0; i < num_rows; ++i, ++srcLine) {
    if (i == ((M_BLOCKS - 1) * 16)) {
      srcLine += lastRowBlockOffset;
    }
    if (lane_id < i) {
      // offset = number of interaction elements emitted before row i.
      uint offset = (i * (i - 1)) >> 1;
      gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]);
    }
  }
  // Padding
  if (lane_id == 0) {
    gmem_output[output_size - 1] = __float2half(0);
  }
}
+
// TF32 forward kernel (aligned path): one warp computes the dot-interaction
// output for one sample using m16 x n16 x k8 TF32 tensor-core tiles.
//
// Layout of one output row (output_size elements):
//   [ num_cols MLP values | lower triangle of X * X^T | 1 zero pad element ]
//
// The sample matrix is staged into shared memory with every value rounded to
// TF32 via wmma::__float_to_tf32 and zero-padded to
// num_rows_after_padding x num_cols_after_padding so whole tiles are valid.
// The launcher only takes this path when num_cols and output_size are
// multiples of 8, which makes the float4 accesses below legal.
template <uint WARPS_PER_BLOCK,
          uint THREADBLOCK_SIZE,
          uint WARP_SIZE,
          uint WARP_SIZE_LOG_2,
          uint TILE_LENGTH,
          uint TILE_LENGTH_LOG_2,
          uint TILE_WIDTH,
          uint TILE_WIDTH_LOG_2,
          uint ROW_TILES_PER_STEP>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractTF32FwdKernel(const float *__restrict input,
                                                                                  float *__restrict output,
                                                                                  uint batch_size,
                                                                                  uint num_rows,
                                                                                  uint num_cols,
                                                                                  uint num_rows_after_padding,
                                                                                  uint num_cols_after_padding,
                                                                                  uint smem_elems_per_warp,
                                                                                  uint output_size,
                                                                                  uint num_row_steps,
                                                                                  uint num_col_steps,
                                                                                  uint smem_stride,
                                                                                  uint smem_stride_acc) {
  // The only supported WMMA tile shape for TF32.
  const uint kWmmaM = 16;
  const uint kWmmaN = 16;
  const uint kWmmaK = 8;

  // One warp per sample.
  uint warp_id = threadIdx.x >> WARP_SIZE_LOG_2;
  uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
  if (sample_id >= batch_size) {
    return;
  }
  int lane_id = threadIdx.x & (WARP_SIZE - 1);

  // Per-warp slice of the dynamic shared-memory allocation.
  extern __shared__ float shmem_dynamic_float[];
  float *shmem = shmem_dynamic_float + (warp_id * smem_elems_per_warp);

  // Stage the sample into shared memory with float4 loads, rounding each
  // element to TF32 so the tensor-core result matches explicit TF32 math.
  const float *gmem_input = input + num_rows * num_cols * sample_id;
  if (lane_id < (num_cols >> 2)) {
    for (int i = 0; i < num_rows; ++i, gmem_input += num_cols) {
      float4 tmp = ((float4 *)gmem_input)[lane_id];
      tmp.x = wmma::__float_to_tf32(tmp.x);
      tmp.y = wmma::__float_to_tf32(tmp.y);
      tmp.z = wmma::__float_to_tf32(tmp.z);
      tmp.w = wmma::__float_to_tf32(tmp.w);
      ((float4 *)(shmem + i * smem_stride))[lane_id] = tmp;
    }
  }

  float zero = wmma::__float_to_tf32(0.0f);
  float4 zero4;
  zero4.x = zero;
  zero4.y = zero;
  zero4.z = zero;
  zero4.w = zero;

  // Zero-pad the extra columns of every real row.
  uint idx = lane_id + num_cols;
  if (idx < num_cols_after_padding) {
    for (uint i = 0; i < num_rows; ++i) {
      (shmem + i * smem_stride)[idx] = zero;
    }
  }

  // Zero-fill the padding rows.
  if (lane_id < (num_cols_after_padding >> 2)) {
    for (int i = num_rows; i < num_rows_after_padding; i++) {
      ((float4 *)(shmem + i * smem_stride))[lane_id] = zero4;
    }
  }
  __syncwarp();
  // TODO: MTMD - Copy directly without using shared memory
  // First output section: row 0 of the sample copied through (note: values
  // were already rounded to TF32 above).
  float *gmem_output = output + output_size * sample_id;
  if (lane_id < (num_cols >> 2)) {
    ((float4 *)gmem_output)[lane_id] = ((float4 *)shmem)[lane_id];
  }

  wmma::fragment<wmma::accumulator, kWmmaM, kWmmaN, kWmmaK, float> acc[ROW_TILES_PER_STEP][ROW_TILES_PER_STEP];

  for (int i = 0; i < ROW_TILES_PER_STEP; i++) {
    for (int j = 0; j < ROW_TILES_PER_STEP; j++) {
      wmma::fill_fragment(acc[i][j], zero);
    }
  }

  // acc = X * X^T, accumulated over kWmmaK-wide column slabs.
  // TODO: MTMD - Loop promotion
  for (int k_step = 0; k_step < num_col_steps; k_step++) {
    wmma::fragment<wmma::matrix_a, kWmmaM, kWmmaN, kWmmaK, wmma::precision::tf32, wmma::row_major>
        a[ROW_TILES_PER_STEP];
    wmma::fragment<wmma::matrix_b, kWmmaM, kWmmaN, kWmmaK, wmma::precision::tf32, wmma::col_major>
        b[ROW_TILES_PER_STEP];
    for (int j = 0; j < ROW_TILES_PER_STEP; j++) {
      // The last row-tile is anchored at num_rows_after_padding - 16 so it
      // stays in bounds; it intentionally overlaps the previous tile.
      int base_row = (j < ROW_TILES_PER_STEP - 1) ? j * 16 : num_rows_after_padding - 16;
      const float *tile_ptr = shmem + (base_row * smem_stride + k_step * kWmmaK);
      // Same tile loaded row-major (a) and col-major (b): b acts as a^T.
      wmma::load_matrix_sync(a[j], tile_ptr, smem_stride);
      wmma::load_matrix_sync(b[j], tile_ptr, smem_stride);
    }
    for (int i = 0; i < ROW_TILES_PER_STEP; i++) {
      for (int j = 0; j < ROW_TILES_PER_STEP; j++) {
        wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]);
      }
    }
  }

  // Spill the accumulators back into shared memory, overwriting the staged
  // input (safe: all tiles were consumed by the mma loop above).
  for (int i = 0; i < ROW_TILES_PER_STEP; i++) {
    for (int j = 0; j < ROW_TILES_PER_STEP; j++) {
      float *tile_ptr = shmem + (i * kWmmaM * smem_stride_acc + j * kWmmaN);
      wmma::store_matrix_sync(tile_ptr, acc[i][j], smem_stride_acc, wmma::mem_row_major);
    }
  }

  // Second output section: strictly lower triangle of X * X^T, row by row.
  // src_line jumps over the overlap introduced by the anchored last row-tile.
  // Row i holds i elements, written by lanes 0..i-1 (assumes
  // num_rows <= WARP_SIZE + 1 — TODO confirm).
  float *gmem_interact_output = gmem_output + num_cols;
  int lastRowBlockOffset = ROW_TILES_PER_STEP * 16 - num_rows_after_padding;
  int src_line = 0;
  for (int i = 0; i < num_rows; ++i, ++src_line) {
    if (i == ((ROW_TILES_PER_STEP - 1) * 16)) {
      src_line += lastRowBlockOffset;
    }
    if (lane_id < i) {
      // offset = number of interaction elements emitted before row i.
      uint offset = (i * (i - 1)) >> 1;
      gmem_interact_output[offset + lane_id] = shmem[src_line * smem_stride_acc + lane_id];
    }
  }
  // Padding
  if (lane_id == 0) {
    gmem_output[output_size - 1] = 0;
  }
}
+
// FP16 backward kernel, non-aligned fallback path: one warp per sample.
//
// Upstream gradient layout per sample mirrors the forward output:
//   [ num_cols MLP grads | flattened strictly-lower-triangular interaction
//     grads (interaction_ugrad_size elems) | padding ]
//
// Per-sample steps:
//   1. Copy the flat interaction ugrad into shared memory, then expand it
//      into a full symmetric matrix U with a zero diagonal.
//   2. Stage the original input X into shared memory, zero-padded to the
//      tile-aligned dimensions.
//   3. Compute grad = U * X with WMMA, writing results one TILE_DIM-wide
//      column slab at a time via shared memory.
//   4. The bottom-MLP gradient is the first num_cols of the upstream grad,
//      copied straight through.
template <uint WARPS_PER_BLOCK,
          uint THREADBLOCK_SIZE,
          uint ROW_TILES_PER_STEP,
          uint COL_TILES_PER_STEP,
          uint WARP_SIZE,
          uint WARP_SIZE_LOG_2,
          uint TILE_DIM,
          uint TILE_DIM_LOG_2>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void dotBasedInteractBwdKernelNonAligned_(const __half *__restrict input,
                                              const __half *__restrict upstream_grad,
                                              half __restrict *grad,
                                              half __restrict *bottom_mlp_grad,
                                              uint batch_size,
                                              uint num_rows,
                                              uint num_cols,
                                              uint num_rows_after_padding,
                                              uint num_cols_after_padding,
                                              uint sample_size,
                                              uint interaction_ugrad_size,
                                              uint interaction_ugrad_size_with_padding,
                                              uint interaction_ugrad_2D_size_elems,
                                              uint interaction_ugrad_2D_stride,
                                              uint input_size_elems,
                                              uint input_stride,
                                              uint num_row_steps,
                                              uint num_col_steps,
                                              uint row_tiles_per_step,
                                              uint shared_mem_per_warp_size_byte) {
  extern __shared__ half shared_mem[];
  // One warp per sample.
  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
  uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
  if (sample_id >= batch_size) {
    return;
  }
  uint lane_id = threadIdx.x & (WARP_SIZE - 1);
  // ">> 1" to convert to half pointer
  uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1);

  // Shared-memory regions for this warp.  smem_out (float) aliases
  // smem_temp: the 2D ugrad area is reused for output staging once the `a`
  // fragments have been loaded into registers.
  half *smem_in = &shared_mem[smem_warp_offset];
  half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems];
  float *smem_out = reinterpret_cast<float *>(smem_temp);

  // Global memory pointers for the current sample
  // Input
  uint gmem_input_sample_offset = sample_id * sample_size;
  const half *gmem_input = &input[gmem_input_sample_offset];

  // Interaction Gradient
  const uint &gmem_grad_sample_offset = gmem_input_sample_offset;
  half *gmem_grad = &grad[gmem_grad_sample_offset];

  // Bottom MLP gradient
  half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols];

  // Upstream gradient vector
  uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding);
  const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset];

  // Upstream gradient vector for interactions
  const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols];

  // upstream grad -> shared memory (place in input section temporarily)
#pragma unroll
  for (uint idx = lane_id; idx < interaction_ugrad_size; idx += WARP_SIZE) {
    smem_in[idx] = gmem_ugrad_interactions[idx];
  }
  __syncwarp();
  // Form the 2D ugrad matrix: expand the flat lower triangle into a
  // symmetric matrix (zero diagonal), one lane per column.
  if (lane_id < num_rows_after_padding) {
    // Start of row lane_id within the flattened lower triangle.
    uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1);
    uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride;
    for (uint row = 0; row < num_rows; row++) {
      half ugrad_val = __float2half(0.0f);
      if (row < lane_id && lane_id < num_rows) {
        ugrad_val = smem_in[ugrad_flat_index + row];
        smem_temp[ugrad_offset_1 + row] = ugrad_val;
      }
      if (row <= lane_id && lane_id < num_rows_after_padding) {
        // Mirror into the upper triangle (diagonal gets zero).
        smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val;
      }
    }
    // Zero the padding rows of the 2D ugrad matrix.
    for (uint row = num_rows; row < num_rows_after_padding; row++) {
      smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f);
    }
  }
  __syncwarp();

  // Input -> Shared Memory

  for (uint row = 0; row < num_rows; row++) {
    half *smem_row_ptr = &smem_in[row * input_stride];
    const half *gmem_row_ptr = &gmem_input[row * num_cols];
    for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
      smem_row_ptr[idx] = gmem_row_ptr[idx];
    }
    // Zero-pad the extra columns.
    uint idx = lane_id + num_cols;
    if (idx < num_cols_after_padding) {
      smem_row_ptr[idx] = __float2half(0);
    }
  }

  // Zero-fill the padding rows.
#pragma unroll 2
  for (uint row = num_rows; row < num_rows_after_padding; row++) {
    half *smem_row_ptr = &smem_in[row * input_stride];
    for (uint idx = lane_id; idx < num_cols_after_padding; idx += WARP_SIZE) {
      smem_row_ptr[idx] = __float2half(0);
    }
  }
  __syncwarp();

  // Load all tiles of the 2D ugrad matrix U into register fragments now;
  // smem_temp can then be reused as the float output staging area.
  wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[ROW_TILES_PER_STEP]
                                                                                       [ROW_TILES_PER_STEP];
  for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
    for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
      const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2);
      wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride);
    }
  }

  // grad = U * X, one TILE_DIM-wide column slab of X per iteration.
  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[ROW_TILES_PER_STEP];
  wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> b[ROW_TILES_PER_STEP];
  for (int col_step = 0; col_step < num_col_steps; col_step++) {
    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
      const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2);
      wmma::fill_fragment(acc[i], 0);
      wmma::load_matrix_sync(b[i], tile_ptr, input_stride);
    }
    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
      for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
        wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]);
      }
    }
    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
      float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM;
      wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major);
    }
    __syncwarp();
    // Write this column slab of the gradient, guarding the ragged edge.
    uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id;
    if (gmem_grad_col < num_cols) {
      for (uint i = 0; i < num_rows; i++) {
        gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]);
      }
    }
  }

  // Bottom-MLP gradient: head of the upstream grad, copied through.
  for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
    gmem_mlp_grad[idx] = gmem_ugrad[idx];
  }
}
+
// TF32 backward kernel (aligned path): one warp per sample, m16 x n16 x k8
// TF32 tensor-core tiles.
//
// Upstream gradient layout per sample mirrors the forward output:
//   [ num_cols MLP grads | flattened strictly-lower-triangular interaction
//     grads (interaction_ugrad_size elems) | padding ]
//
// Per-sample steps:
//   1. Copy the flat interaction ugrad into shared memory (rounded to TF32),
//      then expand it into a full symmetric matrix U with a zero diagonal.
//   2. Stage the original input X into shared memory (rounded to TF32,
//      zero-padded to the tile-aligned dimensions).
//   3. Compute grad = U * X with WMMA, staging each
//      (FRAG_A_ROWS x FRAG_B_COLS)-tile result through smem_out before the
//      global write.
//   4. The bottom-MLP gradient is the first num_cols of the upstream grad,
//      copied through with float4 stores.
template <uint WARPS_PER_BLOCK,
          uint THREADBLOCK_SIZE,
          uint WARP_SIZE,
          uint WARP_SIZE_LOG_2,
          uint FRAG_A_ROWS,
          uint FRAG_B_COLS,
          uint TILE_LENGTH,
          uint TILE_LENGTH_LOG_2,
          uint TILE_WIDTH,
          uint TILE_WIDTH_LOG_2>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void dotBasedInteractTF32BwdKernel(const float *__restrict input,
                                       const float *__restrict upstream_grad,
                                       float *__restrict grad,
                                       float *__restrict bottom_mlp_grad,
                                       uint batch_size,
                                       uint num_rows,
                                       uint num_cols,
                                       uint num_rows_after_padding,
                                       uint num_cols_after_padding,
                                       uint sample_size,
                                       uint interaction_ugrad_size,
                                       uint interaction_ugrad_size_with_padding,
                                       uint interaction_ugrad_2D_size_elems,
                                       uint interaction_ugrad_2D_stride,
                                       uint input_size_elems,
                                       uint input_stride,
                                       uint shared_mem_per_warp_size_elems,
                                       uint num_k_steps,
                                       uint num_n_steps) {
  // The only supported WMMA tile shape for TF32.
  const uint kWmmaM = 16;
  const uint kWmmaN = 16;
  const uint kWmmaK = 8;

  extern __shared__ float shared_mem_float[];
  // One warp per sample.
  uint warp_id = threadIdx.x >> WARP_SIZE_LOG_2;
  uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
  if (sample_id >= batch_size) {
    return;
  }
  uint lane_id = threadIdx.x & (WARP_SIZE - 1);
  uint smem_warp_offset = warp_id * shared_mem_per_warp_size_elems;

  // Shared-memory regions for this warp: staged input X, the expanded 2D
  // ugrad matrix U, and an output staging buffer (all disjoint here, unlike
  // the FP16 fallback which aliases its buffers).
  float *smem_in = &shared_mem_float[smem_warp_offset];
  float *smem_ugrad = &shared_mem_float[smem_warp_offset + input_size_elems];
  float *smem_out = &shared_mem_float[smem_warp_offset + input_size_elems + interaction_ugrad_2D_size_elems];

  // Global memory pointers for the current sample
  // Input
  uint gmem_input_sample_offset = sample_id * sample_size;
  const float *gmem_input = &input[gmem_input_sample_offset];

  // Interaction Gradient
  const uint &gmem_grad_sample_offset = gmem_input_sample_offset;
  float *gmem_grad = &grad[gmem_grad_sample_offset];

  // Bottom MLP gradient
  float *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols];

  // Upstream gradient vector
  uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding);
  const float *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset];

  // Upstream gradient vector for interactions
  const float *gmem_ugrad_interactions = &gmem_ugrad[num_cols];

  // upstream grad -> shared memory (place in input section temporarily);
  // float4 loads for the bulk, then a scalar tail for the last <4 elements.
#pragma unroll
  for (uint idx = lane_id; idx < (interaction_ugrad_size >> 2); idx += WARP_SIZE) {
    float4 tmp = ((float4 *)gmem_ugrad_interactions)[idx];
    tmp.x = wmma::__float_to_tf32(tmp.x);
    tmp.y = wmma::__float_to_tf32(tmp.y);
    tmp.z = wmma::__float_to_tf32(tmp.z);
    tmp.w = wmma::__float_to_tf32(tmp.w);
    ((float4 *)smem_in)[idx] = tmp;
  }
  uint offset = (interaction_ugrad_size >> 2) << 2;
  for (uint idx = lane_id + offset; idx < interaction_ugrad_size; idx += WARP_SIZE) {
    smem_in[idx] = wmma::__float_to_tf32(gmem_ugrad_interactions[idx]);
  }
  __syncwarp();

  float zero = wmma::__float_to_tf32(0.0f);
  float4 zero4;
  zero4.x = zero;
  zero4.y = zero;
  zero4.z = zero;
  zero4.w = zero;
  // Form the 2D ugrad matrix: expand the flat lower triangle into a
  // symmetric matrix (zero diagonal), one lane per column.
  if (lane_id < num_rows_after_padding) {
    // Start of row lane_id within the flattened lower triangle.
    uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1);
    uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride;
    for (uint row = 0; row < num_rows; row++) {
      float ugrad_val = zero;
      if (row < lane_id && lane_id < num_rows) {
        ugrad_val = smem_in[ugrad_flat_index + row];
        smem_ugrad[ugrad_offset_1 + row] = ugrad_val;
      }
      if (row <= lane_id && lane_id < num_rows_after_padding) {
        // Mirror into the upper triangle (diagonal gets zero).
        smem_ugrad[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val;
      }
    }
    // Zero the padding rows of the 2D ugrad matrix.
    for (uint row = num_rows; row < num_rows_after_padding; row++) {
      smem_ugrad[row * interaction_ugrad_2D_stride + lane_id] = zero;
    }
  }
  __syncwarp();

  // Input -> Shared Memory

  // float4 loads, rounding each element to TF32 (aligned path: caller is
  // expected to guarantee num_cols % 4 == 0 — TODO confirm in the launcher).
  if (lane_id < (num_cols >> 2)) {
    for (uint row = 0; row < num_rows; row++) {
      float *smem_row_ptr = &smem_in[row * input_stride];
      const float *gmem_row_ptr = &gmem_input[row * num_cols];
      float4 tmp = ((float4 *)gmem_row_ptr)[lane_id];
      tmp.x = wmma::__float_to_tf32(tmp.x);
      tmp.y = wmma::__float_to_tf32(tmp.y);
      tmp.z = wmma::__float_to_tf32(tmp.z);
      tmp.w = wmma::__float_to_tf32(tmp.w);
      ((float4 *)smem_row_ptr)[lane_id] = tmp;
    }
  }

  // Zero-pad the extra columns of every real row.
  uint idx = lane_id + num_cols;
  if (idx < num_cols_after_padding) {
    for (uint row = 0; row < num_rows; row++) {
      float *smem_row_ptr = &smem_in[row * input_stride];
      smem_row_ptr[idx] = zero;
    }
  }

  // Zero-fill the padding rows.
  if (lane_id < (num_cols_after_padding >> 2)) {
#pragma unroll 2
    for (uint row = num_rows; row < num_rows_after_padding; row++) {
      float *smem_row_ptr = &smem_in[row * input_stride];
      ((float4 *)smem_row_ptr)[lane_id] = zero4;
    }
  }
  __syncwarp();

  // grad = U * X.  Both operands are row-major here (unlike the forward
  // pass) because this is a plain matrix product, not X * X^T.
  wmma::fragment<wmma::matrix_a, kWmmaM, kWmmaN, kWmmaK, wmma::precision::tf32, wmma::row_major> a[FRAG_A_ROWS];
  wmma::fragment<wmma::matrix_b, kWmmaM, kWmmaN, kWmmaK, wmma::precision::tf32, wmma::row_major> b[FRAG_B_COLS];
  wmma::fragment<wmma::accumulator, kWmmaM, kWmmaN, kWmmaK, float> acc[FRAG_A_ROWS][FRAG_B_COLS];
  for (uint n = 0; n < num_n_steps; n++) {
    for (uint i = 0; i < FRAG_A_ROWS; i++) {
      for (uint j = 0; j < FRAG_B_COLS; j++) {
        wmma::fill_fragment(acc[i][j], zero);
      }
    }
    for (uint k = 0; k < num_k_steps; k++) {
      for (uint i = 0; i < FRAG_A_ROWS; i++) {
        const float *mat_a_tile_ptr =
            smem_ugrad + (i << TILE_LENGTH_LOG_2) * interaction_ugrad_2D_stride + (k << TILE_WIDTH_LOG_2);
        wmma::load_matrix_sync(a[i], mat_a_tile_ptr, interaction_ugrad_2D_stride);
      }
      for (uint j = 0; j < FRAG_B_COLS; j++) {
        const float *mat_b_tile_ptr =
            smem_in + (k << TILE_WIDTH_LOG_2) * input_stride + ((2 * n + j) << TILE_LENGTH_LOG_2);
        wmma::load_matrix_sync(b[j], mat_b_tile_ptr, input_stride);
      }
      for (uint i = 0; i < FRAG_A_ROWS; i++) {
        for (uint j = 0; j < FRAG_B_COLS; j++) {
          wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]);
        }
      }
    }
    // __syncwarp(); ?
    // Stage this step's tiles in smem_out, then write one column slab of the
    // gradient.  out_stride == FRAG_B_COLS * TILE_LENGTH (== 32 for the
    // default 2x16 config), so each lane writes one column.
    uint out_stride = FRAG_B_COLS << TILE_LENGTH_LOG_2;
    for (uint i = 0; i < FRAG_A_ROWS; i++) {
      for (uint j = 0; j < FRAG_B_COLS; j++) {
        float *out_tile_ptr = smem_out + (i << TILE_LENGTH_LOG_2) * out_stride + (j << TILE_LENGTH_LOG_2);
        wmma::store_matrix_sync(out_tile_ptr, acc[i][j], out_stride, wmma::mem_row_major);
      }
    }
    // NOTE(review): unlike the FP16 fallback there is no
    // `gmem_grad_col < num_cols` guard here, so this write assumes num_cols
    // is a multiple of FRAG_B_COLS * TILE_LENGTH — TODO confirm the launcher
    // only takes this path for such shapes.
    uint gmem_grad_col = n * (FRAG_B_COLS << TILE_LENGTH_LOG_2) + lane_id;
    for (uint i = 0; i < num_rows; i++) {
      gmem_grad[i * num_cols + gmem_grad_col] = smem_out[i * out_stride + lane_id];
    }
  }

  // Bottom-MLP gradient: head of the upstream grad, copied through.
  if (lane_id < (num_cols >> 2)) {
    ((float4 *)gmem_mlp_grad)[lane_id] = ((float4 *)gmem_ugrad)[lane_id];
  }
}
+
// Host-side launcher for the TF32 forward interaction kernel.
//
// input:  batch of [num_rows x num_cols] float matrices (batch_size samples).
// output: per sample, [ num_cols | lower triangle of X*X^T | 1 pad element ].
//
// NOTE(review): `bottom_mlp_output` is never referenced here — the kernel
// copies row 0 of `input` as the MLP section, so the MLP output is
// presumably expected to be concatenated into `input` already; confirm
// against the caller.
inline void dotBasedInteractTF32Fwd(
    const void *input, const void *bottom_mlp_output, void *output, uint batch_size, uint num_rows, uint num_cols) {
  const uint kWarpSize = 32;
  const uint kWarpSizeLog2 = Log2<kWarpSize>::value;
  const uint kTileLength = 16;
  const uint kTileLengthLog2 = Log2<kTileLength>::value;
  const uint kTileWidth = 8;
  const uint kTileWidthLog2 = Log2<kTileWidth>::value;
  const uint kWarpsPerBlock = 2;
  const uint kThreadBlockSize = kWarpsPerBlock * kWarpSize;
  const uint kPaddingSize = 1;
  const uint kRowTilesPerStep = 2;
  const uint kColTilesPerStep = 1;
  const uint kSkewFloat = 4;  // Ensures we are 16 byte align as required by nvcuda::wmma::load_matrix_sync

  // num tiles
  uint mat_a_num_row_tiles = (num_rows + kTileLength - 1) >> kTileLengthLog2;
  uint mat_a_num_col_tiles = (num_cols + kTileWidth - 1) >> kTileWidthLog2;

  const uint &mat_b_num_row_tiles = mat_a_num_col_tiles;
  const uint &mat_b_num_col_tiles = mat_a_num_row_tiles;

  // number of rows and columns after padding
  uint num_rows_after_padding = mat_a_num_row_tiles << kTileLengthLog2;
  uint num_cols_after_padding = mat_a_num_col_tiles << kTileWidthLog2;

  uint num_row_steps = mat_a_num_row_tiles / kRowTilesPerStep;
  uint num_col_steps = mat_a_num_col_tiles / kColTilesPerStep;

  // Shared memory per warp: the larger of the staged-input matrix and the
  // accumulator spill area (they reuse the same buffer inside the kernel).
  const uint smem_stride = num_cols_after_padding + kSkewFloat;
  const uint smem_elems_per_warp_mat = num_rows_after_padding * smem_stride;

  const uint smem_stride_acc = num_rows_after_padding + kSkewFloat;
  const uint smem_elems_per_warp_acc = num_rows_after_padding * smem_stride_acc;

  const uint smem_elems_per_warp =
      smem_elems_per_warp_mat > smem_elems_per_warp_acc ? smem_elems_per_warp_mat : smem_elems_per_warp_acc;

  // Output row length: MLP section + strictly-lower-triangular interactions
  // + one pad element.  The kernel's float4 accesses require num_cols and
  // output_size to be multiples of 8.
  uint output_size = num_cols + ((num_rows * (num_rows - 1)) >> 1) + kPaddingSize;
  bool float4_predicate = !((num_cols & 7) || (output_size & 7));

  // TODO: MTMD - Clean Up
  // std::cout << "mat_a_num_row_tiles    " << mat_a_num_row_tiles << std::endl;
  // std::cout << "mat_a_num_col_tiles    " << mat_a_num_col_tiles << std::endl;
  // std::cout << "mat_b_num_row_tiles    " << mat_b_num_row_tiles << std::endl;
  // std::cout << "mat_b_num_col_tiles    " << mat_b_num_col_tiles << std::endl;
  // std::cout << "num_rows_after_padding " << num_rows_after_padding << std::endl;
  // std::cout << "num_cols_after_padding " << num_cols_after_padding << std::endl;
  // std::cout << "num_row_steps          " << num_row_steps << std::endl;
  // std::cout << "num_col_steps          " << num_col_steps << std::endl;
  // std::cout << "smem_stride            " << smem_stride << std::endl;
  // std::cout << "smem_elems_per_warp_mat" << smem_elems_per_warp_mat << std::endl;
  // std::cout << "smem_stride_acc        " << smem_stride_acc << std::endl;
  // std::cout << "smem_elems_per_warp_acc" << smem_elems_per_warp_acc << std::endl;
  // std::cout << "===================================================================" << std::endl;

  if (float4_predicate) {
    // One warp per sample => one block per kWarpsPerBlock samples.
    dotBasedInteractTF32FwdKernel<kWarpsPerBlock,
                                  kThreadBlockSize,
                                  kWarpSize,
                                  kWarpSizeLog2,
                                  kTileLength,
                                  kTileLengthLog2,
                                  kTileWidth,
                                  kTileWidthLog2,
                                  kRowTilesPerStep>
        <<<(batch_size + kWarpsPerBlock - 1) / kWarpsPerBlock,
           kThreadBlockSize,
           kWarpsPerBlock * smem_elems_per_warp * sizeof(float)>>>((const float *)input,
                                                                   (float *)output,
                                                                   batch_size,
                                                                   num_rows,
                                                                   num_cols,
                                                                   num_rows_after_padding,
                                                                   num_cols_after_padding,
                                                                   smem_elems_per_warp,
                                                                   output_size,
                                                                   num_row_steps,
                                                                   num_col_steps,
                                                                   smem_stride,
                                                                   smem_stride_acc);
  } else {
    // NOTE(review): the misaligned fallback is not implemented — this branch
    // prints a message and launches nothing, leaving `output` untouched.
    // The #ifdef'd code below is dead (GENERIC_IS_DONE is never defined) and
    // references `dotBasedInteractFwdKernelNonAligned` (no trailing
    // underscore) plus undeclared configuration constants.
    std::cout << "GENERIC VERSION IS UNFINISHED." << std::endl;
#ifdef GENERIC_IS_DONE
    dotBasedInteractFwdKernelNonAligned<warps_per_threadblock,
                                        threadblock_size,
                                        M_BLOCKS,
                                        K_BLOCKS,
                                        SMEM_STRIDE,
                                        SMEM_STRIDE_ACC,
                                        kWarpSize,
                                        kWarpSizeLog2,
                                        kTileDim,
                                        kTileDimLog2>
        <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock,
           threadblock_size,
           warps_per_threadblock * smem_elems_per_warp * sizeof(__half)>>>((const __half *)input,
                                                                           (half *)output,
                                                                           batch_size,
                                                                           num_rows,
                                                                           num_cols,
                                                                           num_rows_after_padding,
                                                                           num_cols_after_padding,
                                                                           smem_elems_per_warp,
                                                                           smem_rows_per_warp,
                                                                           output_size,
                                                                           num_row_steps,
                                                                           num_col_steps);
#endif
  }
}
+
+// Host-side launcher for the TF32 backward pass of the dot-based feature
+// interaction. Derives tile counts, padded matrix sizes and the per-warp
+// shared-memory layout (2D upstream-grad matrix + staged input + output
+// tiles), then launches dotBasedInteractTF32BwdKernel. With kWarpsPerBlock
+// == 1 each block holds a single warp, i.e. one warp per sample.
+inline void dotBasedInteractTF32Bwd(void *input,
+                                    void *upstream_grad,
+                                    void *grad,
+                                    void *bottom_mlp_grad,
+                                    uint batch_size,
+                                    uint num_rows,
+                                    uint num_cols) {
+  // Fragment Settings
+  const uint kFragARows = 2;
+  const uint kFragBCols = 2;
+  const uint kTileLength = 16;
+  const uint kTileLengthLog2 = Log2<kTileLength>::value;
+  const uint kTileWidth = 8;
+  const uint kTileWidthLog2 = Log2<kTileWidth>::value;
+
+  const uint kWarpSize = 32;
+  const uint kWarpSizeLog2 = Log2<kWarpSize>::value;
+  // Extra floats added to shared-memory strides, presumably to avoid
+  // bank conflicts -- TODO confirm against the kernel's access pattern.
+  const uint kSkewFloat = 4;
+  // One trailing pad element appended to the flattened 1D upstream gradient.
+  const uint kPaddingSize = 1;
+  const uint kWarpsPerBlock = 1;
+  const uint kWarpsPerBlockLog2 = Log2<kWarpsPerBlock>::value;
+  const uint kNumThreads = kWarpsPerBlock * kWarpSize;
+
+  // num tiles (ceil-divide each dimension by the tile extent)
+  uint mat_a_num_row_tiles = (num_rows + kTileLength - 1) >> kTileLengthLog2;
+  uint mat_a_num_col_tiles = (num_rows + kTileWidth - 1) >> kTileWidthLog2;
+
+  const uint &mat_b_num_row_tiles = mat_a_num_col_tiles;
+  uint mat_b_num_col_tiles = (num_cols + kTileLength - 1) >> kTileLengthLog2;
+
+  // number of rows and columns after padding
+  uint num_rows_after_padding = mat_a_num_row_tiles << kTileLengthLog2;
+  uint num_cols_after_padding = mat_b_num_col_tiles << kTileLengthLog2;
+
+  // 2D ugrad size and stride
+  uint interaction_ugrad_2D_stride = num_rows_after_padding + kSkewFloat;
+  uint interaction_ugrad_2D_size_elems = num_rows_after_padding * interaction_ugrad_2D_stride;
+
+  // 1D ugrad size: strictly-lower-triangular element count of the
+  // (num_rows x num_rows) interaction matrix.
+  uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1;
+  uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize;
+
+  // in_out place size and stride
+  uint input_stride = num_cols_after_padding + kSkewFloat;
+  uint input_size_elems = num_rows_after_padding * input_stride;
+
+  // sample size (elements per sample in the flat input/grad buffers)
+  uint sample_size = num_rows * num_cols;
+
+  // output size
+  uint output_size_elems = kTileLength * kTileLength * kFragARows * kFragBCols;
+
+  // Shared memory size: ugrad matrix + staged input + output staging, per warp
+  uint shared_mem_per_warp_size_elems = interaction_ugrad_2D_size_elems + input_size_elems + output_size_elems;
+  uint shared_mem_size_elems = kWarpsPerBlock * shared_mem_per_warp_size_elems;
+  uint shared_mem_size_bytes = shared_mem_size_elems * sizeof(float);
+
+  uint num_blocks = (batch_size + kWarpsPerBlock - 1) >> kWarpsPerBlockLog2;
+  uint num_k_steps = mat_a_num_col_tiles;
+  uint num_n_steps = mat_b_num_col_tiles / kFragBCols;
+
+  // The aligned kernel uses float4 (8-float) vector accesses, so both the
+  // padded 1D ugrad length and num_cols must be multiples of 8.
+  bool float4_predicate = !((interaction_ugrad_size_with_padding & 7) || (num_cols & 7));
+  if (float4_predicate) {
+    dotBasedInteractTF32BwdKernel<kWarpsPerBlock,
+                                  kNumThreads,
+                                  kWarpSize,
+                                  kWarpSizeLog2,
+                                  kFragARows,
+                                  kFragBCols,
+                                  kTileLength,
+                                  kTileLengthLog2,
+                                  kTileWidth,
+                                  kTileWidthLog2>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const float *)input,
+                                                             (const float *)upstream_grad,
+                                                             (float *)grad,
+                                                             (float *)bottom_mlp_grad,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             num_rows_after_padding,
+                                                             num_cols_after_padding,
+                                                             sample_size,
+                                                             interaction_ugrad_size,
+                                                             interaction_ugrad_size_with_padding,
+                                                             interaction_ugrad_2D_size_elems,
+                                                             interaction_ugrad_2D_stride,
+                                                             input_size_elems,
+                                                             input_stride,
+                                                             shared_mem_per_warp_size_elems,
+                                                             num_k_steps,
+                                                             num_n_steps);
+  } else {
+    // NOTE(review): shapes failing the alignment check are effectively
+    // unsupported -- the generic kernel below is compiled out
+    // (GENERIC_IS_DONE is never defined), so only a message is printed and
+    // no gradients are written for this batch.
+    std::cout << "GENERIC VERSION IS UNFINISHED." << std::endl;
+#ifdef GENERIC_IS_DONE
+    dotBasedInteractBwdKernelNonAligned<kWarpsPerBlock,
+                                        kNumThreads,
+                                        kRowTilesPerStep,
+                                        kColTilesPerStep,
+                                        kWarpSize,
+                                        kWarpSizeLog2,
+                                        kTileDim,
+                                        kTileDimLog2>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const half *)input,
+                                                             (const half *)upstream_grad,
+                                                             (half *)grad,
+                                                             (half *)bottom_mlp_grad,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             num_rows_after_padding,
+                                                             num_cols_after_padding,
+                                                             sample_size,
+                                                             interaction_ugrad_size,
+                                                             interaction_ugrad_size_with_padding,
+                                                             interaction_ugrad_2D_size_elems,
+                                                             interaction_ugrad_2D_stride,
+                                                             input_size_elems,
+                                                             input_stride,
+                                                             num_row_steps,
+                                                             num_col_steps,
+                                                             row_tiles_per_step,
+                                                             shared_mem_per_warp_size_byte);
+#endif
+  }
+}

+ 13 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/pytorch_ops.cpp

@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+
+// Forward declarations of the C++/CUDA entry points; their definitions live
+// in the accompanying .cu translation unit of this extension.
+torch::Tensor dotBasedInteractFwdTorch(torch::Tensor input,
+                                       torch::Tensor bottom_mlp_output);
+std::vector<torch::Tensor> dotBasedInteractBwdTorch(torch::Tensor input,
+                                                    torch::Tensor upstreamGrad);
+
+// Python bindings for the dot-based interaction ops (Ampere build):
+// exposes dotBasedInteractFwd(input, bottom_mlp_output) and
+// dotBasedInteractBwd(input, upstreamGrad) with keyword arguments.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("dotBasedInteractFwd", &dotBasedInteractFwdTorch, "", py::arg("input"),
+        py::arg("bottom_mlp_output"));
+  m.def("dotBasedInteractBwd", &dotBasedInteractBwdTorch, "", py::arg("input"),
+        py::arg("upstreamGrad"));
+}

+ 22 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_ampere/shared_utils.cuh

@@ -0,0 +1,22 @@
+#pragma once
+
+#include <math.h>
+
+// Required by CHK_CUDA below; previously this header relied on every
+// includer pulling in <iostream>/<cstdlib> itself.
+#include <cstdlib>
+#include <iostream>
+
+// Abort the process with a file/line diagnostic when a CUDA runtime call
+// fails. Wrapped in do/while(0) so the macro expands to a single statement
+// and is safe inside unbraced if/else.
+#define CHK_CUDA(expression)                                                                                        \
+  do {                                                                                                              \
+    cudaError_t status = (expression);                                                                              \
+    if (status != cudaSuccess) {                                                                                    \
+      std::cerr << "Error in file: " << __FILE__ << ", on line: " << __LINE__ << ": " << cudaGetErrorString(status) \
+                << std::endl;                                                                                       \
+      std::exit(EXIT_FAILURE);                                                                                      \
+    }                                                                                                               \
+  } while (0)
+
+// Compile-time floor(log2(x)) for unsigned template arguments (x >= 1).
+template <uint x>
+struct Log2 {
+  static constexpr uint value = 1 + Log2<x / 2>::value;
+};
+template <>
+struct Log2<1> {
+  static constexpr uint value = 0;
+};

+ 1137 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_volta/dot_based_interact.cu

@@ -0,0 +1,1137 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <device_launch_parameters.h>
+#include <mma.h>
+#include <cuda_fp16.hpp>
+
+#include <math.h>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+using namespace nvcuda;
+
+// Abort the process with a file/line diagnostic when a CUDA runtime call
+// fails. Wrapped in do/while(0) so the macro expands to a single statement
+// and is safe inside unbraced if/else.
+#define CHK_CUDA(expression)                                                                                        \
+  do {                                                                                                              \
+    cudaError_t status = (expression);                                                                              \
+    if (status != cudaSuccess) {                                                                                    \
+      std::cerr << "Error in file: " << __FILE__ << ", on line: " << __LINE__ << ": " << cudaGetErrorString(status) \
+                << std::endl;                                                                                       \
+      std::exit(EXIT_FAILURE);                                                                                      \
+    }                                                                                                               \
+  } while (0)
+
+// Compile-time floor(log2(x)) for unsigned template arguments (x >= 1).
+template <uint x>
+struct Log2 {
+  static constexpr uint value = 1 + Log2<x / 2>::value;
+};
+template <>
+struct Log2<1> {
+  static constexpr uint value = 0;
+};
+
+// Four packed half values (8-byte aligned) used for vectorized
+// shared-memory stores in the kernels below.
+struct __align__(8) half4 {
+  half2 vals[2];
+};
+
+// Forward interaction kernel, generic (non-vectorized) path for shapes that
+// fail the launcher's alignment check. One warp processes one sample:
+//  1. stage the sample's (num_rows x num_cols) matrix into shared memory,
+//     zero-padded to (num_rows_after_padding x num_cols_after_padding);
+//  2. copy the first num_cols elements of the sample verbatim to the head
+//     of the output;
+//  3. compute S = A * A^T with WMMA 16x16 tiles (fp16 inputs, fp32
+//     accumulation);
+//  4. append the strictly lower triangle of S, then one zero pad element.
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint M_BLOCKS,
+          uint K_BLOCKS,
+          uint SMEM_STRIDE,
+          uint SMEM_STRIDE_ACC,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernelNonAligned(const __half *__restrict input,
+                                                                                        __half *__restrict output,
+                                                                                        uint batch_size,
+                                                                                        uint num_rows,
+                                                                                        uint num_cols,
+                                                                                        uint num_rows_after_padding,
+                                                                                        uint num_cols_after_padding,
+                                                                                        uint smem_elems_per_warp,
+                                                                                        uint smem_rows_per_warp,
+                                                                                        uint output_size,
+                                                                                        uint num_row_steps,
+                                                                                        uint num_col_steps) {
+  // One warp per sample; excess warps past batch_size exit early.
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  int lane_id = threadIdx.x & (WARP_SIZE - 1);
+
+  // Each warp owns a private slice of the dynamic shared-memory allocation.
+  extern __shared__ half shmem_dynamic[];
+  half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp);
+
+  // Stage input rows into shared memory (scalar copies, no alignment
+  // assumed -- this is what distinguishes the NonAligned variant).
+  const half *sample_input = input + num_rows * num_cols * sample_id;
+  for (uint i = 0; i < num_rows; ++i, sample_input += num_cols) {
+    for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+      (shmem + i * SMEM_STRIDE)[idx] = sample_input[idx];
+    }
+  }
+
+  // Zero-pad the extra columns up to num_cols_after_padding.
+  uint idx = lane_id + num_cols;
+  if (idx < num_cols_after_padding) {
+    for (int i = 0; i < num_rows; ++i) {
+      (shmem + i * SMEM_STRIDE)[idx] = __float2half(0);
+    }
+  }
+
+  // Zero-fill the padding rows, four halves at a time.
+  half4 zeros;
+  zeros.vals[0].x = __float2half(0);
+  zeros.vals[0].y = __float2half(0);
+  zeros.vals[1].x = __float2half(0);
+  zeros.vals[1].y = __float2half(0);
+  if (lane_id < (num_cols_after_padding >> 2)) {
+    for (int i = num_rows; i < num_rows_after_padding; i++) {
+      ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros;
+    }
+  }
+  __syncwarp();
+  half *gmem_output = output + output_size * sample_id;
+
+  // Pass the sample's first row through to the output unchanged.
+  for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+    gmem_output[idx] = shmem[idx];
+  }
+
+  // acc[i][j] accumulates the (i, j) TILE_DIM x TILE_DIM tile of A * A^T.
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[M_BLOCKS][M_BLOCKS];
+
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      wmma::fill_fragment(acc[i][j], 0);
+    }
+  }
+
+  for (int k_step = 0; k_step < num_col_steps; k_step++) {
+    wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[M_BLOCKS];
+    wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::col_major> b[M_BLOCKS];
+    for (int j = 0; j < M_BLOCKS; j++) {
+      // The last row block is anchored so it ends at smem_rows_per_warp;
+      // its tile may overlap the previous block instead of reading past
+      // the staged rows.
+      int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16;
+      // The same tile feeds both operands: loading it row-major as A and
+      // col-major as B yields A * A^T in the accumulators.
+      const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16);
+      wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE);
+      wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE);
+    }
+    for (int i = 0; i < M_BLOCKS; i++) {
+      for (int j = 0; j < M_BLOCKS; j++) {
+        wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]);
+      }
+    }
+  }
+  // Dump the fp32 accumulators back into (reinterpreted) shared memory.
+  float *shmem_store = reinterpret_cast<float *>(shmem);
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16);
+      wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major);
+    }
+  }
+
+  // Emit the strictly lower triangle (j < i) of the interaction matrix,
+  // flattened row by row at offset i*(i-1)/2; each lane writes one column.
+  half *gmem_interact_output = gmem_output + num_cols;
+  int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp;
+  int srcLine = 0;
+  for (int i = 0; i < num_rows; ++i, ++srcLine) {
+    if (i == ((M_BLOCKS - 1) * 16)) {
+      // Skip the rows the overlapping last block duplicated.
+      srcLine += lastRowBlockOffset;
+    }
+    if (lane_id < i) {
+      uint offset = (i * (i - 1)) >> 1;
+      gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]);
+    }
+  }
+  // Padding
+  if (lane_id == 0) {
+    gmem_output[output_size - 1] = __float2half(0);
+  }
+}
+
+// Forward interaction kernel, aligned fast path. Identical algorithm to the
+// NonAligned variant (stage sample -> pad -> S = A * A^T via WMMA -> emit
+// lower triangle + zero pad), but global<->shared copies use float2 vector
+// accesses, which assumes num_cols is a multiple of 4 halves and that a
+// single warp covers num_cols/4 lanes -- presumably guaranteed by the
+// launcher's alignment predicate (TODO confirm).
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint M_BLOCKS,
+          uint K_BLOCKS,
+          uint SMEM_STRIDE,
+          uint SMEM_STRIDE_ACC,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractFwdKernel(const __half *__restrict input,
+                                                                              __half *__restrict output,
+                                                                              uint batch_size,
+                                                                              uint num_rows,
+                                                                              uint num_cols,
+                                                                              uint num_rows_after_padding,
+                                                                              uint num_cols_after_padding,
+                                                                              uint smem_elems_per_warp,
+                                                                              uint smem_rows_per_warp,
+                                                                              uint output_size,
+                                                                              uint num_row_steps,
+                                                                              uint num_col_steps) {
+  // One warp per sample; excess warps past batch_size exit early.
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  int sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  int lane_id = threadIdx.x & (WARP_SIZE - 1);
+
+  // Each warp owns a private slice of the dynamic shared-memory allocation.
+  extern __shared__ half shmem_dynamic[];
+  half *shmem = shmem_dynamic + (warp_id * smem_elems_per_warp);
+
+  // Stage input rows via float2 (4-half) vector loads; one lane per vector.
+  const half *sample_input = input + num_rows * num_cols * sample_id;
+  if (lane_id < (num_cols >> 2)) {
+    for (int i = 0; i < num_rows; ++i, sample_input += num_cols) {
+      ((float2 *)(shmem + i * SMEM_STRIDE))[lane_id] = ((float2 *)sample_input)[lane_id];
+    }
+  }
+
+  // Zero-pad the extra columns up to num_cols_after_padding.
+  uint idx = lane_id + num_cols;
+  if (idx < num_cols_after_padding) {
+    for (int i = 0; i < num_rows; ++i) {
+      (shmem + i * SMEM_STRIDE)[idx] = __float2half(0);
+    }
+  }
+
+  // Zero-fill the padding rows, four halves at a time.
+  half4 zeros;
+  zeros.vals[0].x = __float2half(0);
+  zeros.vals[0].y = __float2half(0);
+  zeros.vals[1].x = __float2half(0);
+  zeros.vals[1].y = __float2half(0);
+  if (lane_id < (num_cols_after_padding >> 2)) {
+    for (int i = num_rows; i < num_rows_after_padding; i++) {
+      ((half4 *)(shmem + i * SMEM_STRIDE))[lane_id] = zeros;
+    }
+  }
+  __syncwarp();
+  // Pass the sample's first row through to the output unchanged
+  // (vectorized copy).
+  half *gmem_output = output + output_size * sample_id;
+  if (lane_id < (num_cols >> 2)) {
+    ((float2 *)gmem_output)[lane_id] = ((float2 *)shmem)[lane_id];
+  }
+
+  // acc[i][j] accumulates the (i, j) TILE_DIM x TILE_DIM tile of A * A^T.
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[M_BLOCKS][M_BLOCKS];
+
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      wmma::fill_fragment(acc[i][j], 0);
+    }
+  }
+
+  for (int k_step = 0; k_step < num_col_steps; k_step++) {
+    wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[M_BLOCKS];
+    wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::col_major> b[M_BLOCKS];
+    for (int j = 0; j < M_BLOCKS; j++) {
+      // Last block anchored to end at smem_rows_per_warp (may overlap).
+      int base_row = (j < M_BLOCKS - 1) ? j * 16 : smem_rows_per_warp - 16;
+      // Same tile loaded row-major (A) and col-major (B) -> A * A^T.
+      const half *tile_ptr = shmem + (base_row * SMEM_STRIDE + k_step * 16);
+      wmma::load_matrix_sync(a[j], tile_ptr, SMEM_STRIDE);
+      wmma::load_matrix_sync(b[j], tile_ptr, SMEM_STRIDE);
+    }
+    for (int i = 0; i < M_BLOCKS; i++) {
+      for (int j = 0; j < M_BLOCKS; j++) {
+        wmma::mma_sync(acc[i][j], a[i], b[j], acc[i][j]);
+      }
+    }
+  }
+  // Dump the fp32 accumulators back into (reinterpreted) shared memory.
+  float *shmem_store = reinterpret_cast<float *>(shmem);
+  for (int i = 0; i < M_BLOCKS; i++) {
+    for (int j = 0; j < M_BLOCKS; j++) {
+      float *tile_ptr = shmem_store + (i * 16 * SMEM_STRIDE_ACC + j * 16);
+      wmma::store_matrix_sync(tile_ptr, acc[i][j], SMEM_STRIDE_ACC, wmma::mem_row_major);
+    }
+  }
+
+  // Emit the strictly lower triangle (j < i), flattened at offset i*(i-1)/2.
+  half *gmem_interact_output = gmem_output + num_cols;
+  int lastRowBlockOffset = M_BLOCKS * 16 - smem_rows_per_warp;
+  int srcLine = 0;
+  for (int i = 0; i < num_rows; ++i, ++srcLine) {
+    if (i == ((M_BLOCKS - 1) * 16)) {
+      // Skip the rows the overlapping last block duplicated.
+      srcLine += lastRowBlockOffset;
+    }
+    if (lane_id < i) {
+      uint offset = (i * (i - 1)) >> 1;
+      gmem_interact_output[offset + lane_id] = __float2half(shmem_store[srcLine * SMEM_STRIDE_ACC + lane_id]);
+    }
+  }
+  // Padding
+  if (lane_id == 0) {
+    gmem_output[output_size - 1] = __float2half(0);
+  }
+}
+
+// Backward interaction kernel, generic (non-vectorized) path. One warp per
+// sample:
+//  1. load the flattened (strictly lower triangular) upstream interaction
+//     gradient and expand it into a symmetric (padded) 2D matrix in shared
+//     memory, with a zero diagonal;
+//  2. stage the sample's input matrix, zero-padded;
+//  3. compute grad = ugrad2D * input with WMMA tiles and stream the result
+//     back to global memory column-block by column-block;
+//  4. copy the first num_cols of the upstream gradient straight into
+//     bottom_mlp_grad.
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint ROW_TILES_PER_STEP,
+          uint COL_TILES_PER_STEP,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__
+    void dotBasedInteractBwdKernelNonAligned(const __half *__restrict input,
+                                             const __half *__restrict upstream_grad,
+                                             half __restrict *grad,
+                                             half __restrict *bottom_mlp_grad,
+                                             uint batch_size,
+                                             uint num_rows,
+                                             uint num_cols,
+                                             uint num_rows_after_padding,
+                                             uint num_cols_after_padding,
+                                             uint sample_size,
+                                             uint interaction_ugrad_size,
+                                             uint interaction_ugrad_size_with_padding,
+                                             uint interaction_ugrad_2D_size_elems,
+                                             uint interaction_ugrad_2D_stride,
+                                             uint input_size_elems,
+                                             uint input_stride,
+                                             uint num_row_steps,
+                                             uint num_col_steps,
+                                             uint row_tiles_per_step,
+                                             uint shared_mem_per_warp_size_byte) {
+  extern __shared__ half shared_mem[];
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  uint lane_id = threadIdx.x & (WARP_SIZE - 1);
+  // ">> 1" to convert to half pointer
+  uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1);
+
+  // Per-warp shared-memory regions: staged input, the expanded 2D ugrad
+  // matrix, and (aliasing smem_temp) the fp32 output staging area.
+  half *smem_in = &shared_mem[smem_warp_offset];
+  half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems];
+  float *smem_out = reinterpret_cast<float *>(smem_temp);
+
+  // Global memory pointers for the current sample
+  // Input
+  uint gmem_input_sample_offset = sample_id * sample_size;
+  const half *gmem_input = &input[gmem_input_sample_offset];
+
+  // Interaction Gradient
+  const uint &gmem_grad_sample_offset = gmem_input_sample_offset;
+  half *gmem_grad = &grad[gmem_grad_sample_offset];
+
+  // Bottom MLP gradient
+  half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols];
+
+  // Upstream gradient vector
+  uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding);
+  const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset];
+
+  // Upstream gradient vector for interactions
+  const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols];
+
+  // upstream grad -> shared memory (place in input section temporarily)
+#pragma unroll
+  for (uint idx = lane_id; idx < interaction_ugrad_size; idx += WARP_SIZE) {
+    smem_in[idx] = gmem_ugrad_interactions[idx];
+  }
+  __syncwarp();
+  // Form the 2D ugrad matrix.
+  // Each lane owns one column; ugrad_flat_index = lane*(lane-1)/2 is the
+  // start of row `lane` in the flattened lower triangle. Both (lane, row)
+  // and its mirror (row, lane) are written, so the matrix comes out
+  // symmetric with a zero diagonal.
+  if (lane_id < num_rows_after_padding) {
+    uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1);
+    uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride;
+    for (uint row = 0; row < num_rows; row++) {
+      half ugrad_val = __float2half(0.0f);
+      if (row < lane_id && lane_id < num_rows) {
+        ugrad_val = smem_in[ugrad_flat_index + row];
+        smem_temp[ugrad_offset_1 + row] = ugrad_val;
+      }
+      if (row <= lane_id && lane_id < num_rows_after_padding) {
+        smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val;
+      }
+    }
+    // Padding rows of the 2D ugrad matrix are zeroed.
+    for (uint row = num_rows; row < num_rows_after_padding; row++) {
+      smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f);
+    }
+  }
+  __syncwarp();
+
+  // Input -> Shared Memory
+
+  // Scalar staging of the input rows, plus zero-padding of extra columns.
+  for (uint row = 0; row < num_rows; row++) {
+    half *smem_row_ptr = &smem_in[row * input_stride];
+    const half *gmem_row_ptr = &gmem_input[row * num_cols];
+    for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+      smem_row_ptr[idx] = gmem_row_ptr[idx];
+    }
+    uint idx = lane_id + num_cols;
+    if (idx < num_cols_after_padding) {
+      smem_row_ptr[idx] = __float2half(0);
+    }
+  }
+
+  // Zero the padding rows of the staged input.
+#pragma unroll 2
+  for (uint row = num_rows; row < num_rows_after_padding; row++) {
+    half *smem_row_ptr = &smem_in[row * input_stride];
+    for (uint idx = lane_id; idx < num_cols_after_padding; idx += WARP_SIZE) {
+      smem_row_ptr[idx] = __float2half(0);
+    }
+  }
+  __syncwarp();
+
+  // Load the 2D ugrad matrix tiles once; they are reused for every column
+  // step of the input.
+  wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[ROW_TILES_PER_STEP]
+                                                                                       [ROW_TILES_PER_STEP];
+  for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+    for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+      const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2);
+      wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride);
+    }
+  }
+
+  // grad = ugrad2D * input, one TILE_DIM-wide column block per step.
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[ROW_TILES_PER_STEP];
+  wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> b[ROW_TILES_PER_STEP];
+  for (int col_step = 0; col_step < num_col_steps; col_step++) {
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2);
+      wmma::fill_fragment(acc[i], 0);
+      wmma::load_matrix_sync(b[i], tile_ptr, input_stride);
+    }
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+        wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]);
+      }
+    }
+    // Stage the fp32 result tiles, then write this column block (converted
+    // to half) back to the global gradient.
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM;
+      wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major);
+    }
+    __syncwarp();
+    uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id;
+    if (gmem_grad_col < num_cols) {
+      for (uint i = 0; i < num_rows; i++) {
+        gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]);
+      }
+    }
+  }
+
+  // The bottom-MLP gradient is simply the first num_cols of the upstream
+  // gradient, copied through unchanged.
+  for (uint idx = lane_id; idx < num_cols; idx += WARP_SIZE) {
+    gmem_mlp_grad[idx] = gmem_ugrad[idx];
+  }
+}
+
+template <uint WARPS_PER_BLOCK,
+          uint THREADBLOCK_SIZE,
+          uint ROW_TILES_PER_STEP,
+          uint COL_TILES_PER_STEP,
+          uint WARP_SIZE,
+          uint WARP_SIZE_LOG_2,
+          uint TILE_DIM,
+          uint TILE_DIM_LOG_2>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractBwdKernel(const __half *__restrict input,
+                                                                              const __half *__restrict upstream_grad,
+                                                                              half __restrict *grad,
+                                                                              half __restrict *bottom_mlp_grad,
+                                                                              uint batch_size,
+                                                                              uint num_rows,
+                                                                              uint num_cols,
+                                                                              uint num_rows_after_padding,
+                                                                              uint num_cols_after_padding,
+                                                                              uint sample_size,
+                                                                              uint interaction_ugrad_size,
+                                                                              uint interaction_ugrad_size_with_padding,
+                                                                              uint interaction_ugrad_2D_size_elems,
+                                                                              uint interaction_ugrad_2D_stride,
+                                                                              uint input_size_elems,
+                                                                              uint input_stride,
+                                                                              uint num_row_steps,
+                                                                              uint num_col_steps,
+                                                                              uint row_tiles_per_step,
+                                                                              uint shared_mem_per_warp_size_byte) {
+  extern __shared__ half shared_mem[];
+  uint warp_id = (threadIdx.x >> WARP_SIZE_LOG_2);
+  uint sample_id = blockIdx.x * WARPS_PER_BLOCK + warp_id;
+  if (sample_id >= batch_size) {
+    return;
+  }
+  uint lane_id = threadIdx.x & (WARP_SIZE - 1);
+  // ">> 1" to convert to half pointer
+  uint smem_warp_offset = warp_id * (shared_mem_per_warp_size_byte >> 1);
+
+  half *smem_in = &shared_mem[smem_warp_offset];
+  half *smem_temp = &shared_mem[smem_warp_offset + input_size_elems];
+  float *smem_out = reinterpret_cast<float *>(smem_temp);
+
+  // Global memory pointers for the current sample
+  // Input
+  uint gmem_input_sample_offset = sample_id * sample_size;
+  const half *gmem_input = &input[gmem_input_sample_offset];
+
+  // Interaction Gradient
+  const uint &gmem_grad_sample_offset = gmem_input_sample_offset;
+  half *gmem_grad = &grad[gmem_grad_sample_offset];
+
+  // Bottom MLP gradient
+  half *gmem_mlp_grad = &bottom_mlp_grad[sample_id * num_cols];
+
+  // Upstream gradient vector
+  uint gmem_ugrad_sample_offset = sample_id * (num_cols + interaction_ugrad_size_with_padding);
+  const half *gmem_ugrad = &upstream_grad[gmem_ugrad_sample_offset];
+
+  // Upstream gradient vector for interactions
+  const half *gmem_ugrad_interactions = &gmem_ugrad[num_cols];
+
+  // upstream grad -> shared memory (place in input section temporarily)
+#pragma unroll
+  for (uint idx = lane_id; idx < (interaction_ugrad_size >> 3); idx += WARP_SIZE) {
+    ((float4 *)smem_in)[idx] = ((float4 *)gmem_ugrad_interactions)[idx];
+  }
+  uint offset = (interaction_ugrad_size >> 3) << 3;
+  for (uint idx = lane_id + offset; idx < interaction_ugrad_size; idx += WARP_SIZE) {
+    smem_in[idx] = gmem_ugrad_interactions[idx];
+  }
+  __syncwarp();
+  // Form the 2D ugrad matrix.
+  if (lane_id < num_rows_after_padding) {
+    uint ugrad_flat_index = ((lane_id * (lane_id - 1)) >> 1);
+    uint ugrad_offset_1 = lane_id * interaction_ugrad_2D_stride;
+    for (uint row = 0; row < num_rows; row++) {
+      half ugrad_val = __float2half(0.0f);
+      if (row < lane_id && lane_id < num_rows) {
+        ugrad_val = smem_in[ugrad_flat_index + row];
+        smem_temp[ugrad_offset_1 + row] = ugrad_val;
+      }
+      if (row <= lane_id && lane_id < num_rows_after_padding) {
+        smem_temp[row * interaction_ugrad_2D_stride + lane_id] = ugrad_val;
+      }
+    }
+    for (uint row = num_rows; row < num_rows_after_padding; row++) {
+      smem_temp[row * interaction_ugrad_2D_stride + lane_id] = __float2half(0.0f);
+    }
+  }
+  __syncwarp();
+
+  // Input -> Shared Memory
+
+  if (lane_id < (num_cols >> 2)) {
+    for (uint row = 0; row < num_rows; row++) {
+      half *smem_row_ptr = &smem_in[row * input_stride];
+      const half *gmem_row_ptr = &gmem_input[row * num_cols];
+      ((float2 *)smem_row_ptr)[lane_id] = ((float2 *)gmem_row_ptr)[lane_id];
+    }
+  }
+
+  uint idx = lane_id + num_cols;
+  if (idx < num_cols_after_padding) {
+    for (uint row = 0; row < num_rows; row++) {
+      half *smem_row_ptr = &smem_in[row * input_stride];
+      smem_row_ptr[idx] = __float2half(0);
+    }
+  }
+
+  half4 zeros;
+  zeros.vals[0].x = __float2half(0);
+  zeros.vals[0].y = __float2half(0);
+  zeros.vals[1].x = __float2half(0);
+  zeros.vals[1].y = __float2half(0);
+  if (lane_id < (num_cols_after_padding >> 2)) {
+#pragma unroll 2
+    for (uint row = num_rows; row < num_rows_after_padding; row++) {
+      half *smem_row_ptr = &smem_in[row * input_stride];
+      ((half4 *)smem_row_ptr)[lane_id] = zeros;
+    }
+  }
+  __syncwarp();
+
+  wmma::fragment<wmma::matrix_a, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> a[ROW_TILES_PER_STEP]
+                                                                                       [ROW_TILES_PER_STEP];
+  for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+    for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+      const half *tile_ptr = smem_temp + ((i * interaction_ugrad_2D_stride + j) << TILE_DIM_LOG_2);
+      wmma::load_matrix_sync(a[i][j], tile_ptr, interaction_ugrad_2D_stride);
+    }
+  }
+
+  wmma::fragment<wmma::accumulator, TILE_DIM, TILE_DIM, TILE_DIM, float> acc[ROW_TILES_PER_STEP];
+  wmma::fragment<wmma::matrix_b, TILE_DIM, TILE_DIM, TILE_DIM, half, wmma::row_major> b[ROW_TILES_PER_STEP];
+  for (int col_step = 0; col_step < num_col_steps; col_step++) {
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      const half *tile_ptr = smem_in + ((i * input_stride + col_step) << TILE_DIM_LOG_2);
+      wmma::fill_fragment(acc[i], 0);
+      wmma::load_matrix_sync(b[i], tile_ptr, input_stride);
+    }
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      for (uint j = 0; j < ROW_TILES_PER_STEP; j++) {
+        wmma::mma_sync(acc[i], a[i][j], b[j], acc[i]);
+      }
+    }
+    for (uint i = 0; i < ROW_TILES_PER_STEP; i++) {
+      float *tile_ptr = smem_out + i * TILE_DIM * TILE_DIM;
+      wmma::store_matrix_sync(tile_ptr, acc[i], TILE_DIM, wmma::mem_row_major);
+    }
+    __syncwarp();
+    uint gmem_grad_col = (col_step << TILE_DIM_LOG_2) + lane_id;
+    if (gmem_grad_col < num_cols) {
+      for (uint i = 0; i < num_rows; i++) {
+        gmem_grad[i * num_cols + gmem_grad_col] = __float2half(smem_out[(i << TILE_DIM_LOG_2) + lane_id]);
+      }
+    }
+  }
+  if (lane_id < (num_cols >> 2)) {
+    ((float2 *)gmem_mlp_grad)[lane_id] = ((float2 *)gmem_ugrad)[lane_id];
+  }
+}
+
+// Host-side launcher for the FP16 (half) forward interaction kernel.
+// Processes one sample per warp, 4 warps per threadblock. Pads the
+// [num_rows x num_cols] feature matrix up to 16x16 tile multiples, sizes the
+// dynamic shared memory, and dispatches the vectorized kernel when both
+// num_cols and the padded output row size are multiples of 8 (so float4 /
+// 8-half accesses are legal), otherwise the scalar fallback.
+// NOTE(review): bottom_mlp_output is never forwarded to the kernels -- they
+// read only from `input`; presumably the MLP row is already contained in
+// `input`. Confirm against the Python-side caller.
+inline void dotBasedInteractFwd(
+    const void *input, const void *bottom_mlp_output, void *output, uint batch_size, uint num_rows, uint num_cols) {
+  const uint kWarpSize = 32;
+  const uint kWarpSizeLog2 = Log2<kWarpSize>::value;
+  const uint kTileDim = 16;
+  const uint kTileDimLog2 = Log2<kTileDim>::value;
+  const uint warps_per_threadblock = 4;
+  const uint threadblock_size = warps_per_threadblock * 32;
+  // one zero element appended to every output row
+  const uint kPaddingSize = 1;
+  const uint kRowTilesPerStep = 2;
+  const uint kColTilesPerStep = 1;
+
+  // num tiles
+  uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2;
+  uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2;
+
+  // number of rows and columns after padding
+  // rows are always padded to exactly two 16-row tiles (32 rows)
+  uint num_rows_after_padding = kTileDim << 1;
+  uint num_cols_after_padding = num_col_tiles << kTileDimLog2;
+
+  uint num_row_steps = num_row_tiles / kRowTilesPerStep;
+  uint num_col_steps = num_col_tiles / kColTilesPerStep;
+
+  const uint K_BLOCKS = 8;
+  const uint M_BLOCKS = 2;
+  // skew widens the shared-memory row stride (bank-conflict mitigation)
+  const uint SKEW_HALF = ((K_BLOCKS % 2) == 0) ? 8 : 0;
+  const uint SMEM_STRIDE = (K_BLOCKS * 16 + SKEW_HALF);
+  // multiple of 2 to guarantee 256-bit alignment for start of the row, at least 16 to safeload a tile
+  const uint smem_rows_per_warp = M_BLOCKS << 4;
+  const uint smem_elems_per_warp_mat = smem_rows_per_warp * SMEM_STRIDE;
+  const uint SKEW_HALF_ACC = ((M_BLOCKS % 2) == 0) ? 8 : 0;
+  const uint SMEM_STRIDE_ACC = (M_BLOCKS * 16 + SKEW_HALF_ACC);
+  const uint smem_elems_per_warp_acc = M_BLOCKS * 16 * SMEM_STRIDE_ACC * 2;  // output in FP32
+  // per-warp shared memory must fit whichever is larger: the staged input
+  // tiles (half elements) or the FP32 accumulator tiles (counted in halves)
+  const uint smem_elems_per_warp =
+      (smem_elems_per_warp_mat > smem_elems_per_warp_acc) ? smem_elems_per_warp_mat : smem_elems_per_warp_acc;
+  // per-sample output: num_cols MLP values + lower-triangle dot products + padding
+  uint output_size = num_cols + ((num_rows * (num_rows - 1)) >> 1) + kPaddingSize;
+
+  // float4 (= 8 half) vector accesses require both strides to be multiples of 8
+  bool float4_predicate = !((num_cols & 7) || (output_size & 7));
+
+  if (float4_predicate) {
+    dotBasedInteractFwdKernel<warps_per_threadblock,
+                              threadblock_size,
+                              M_BLOCKS,
+                              K_BLOCKS,
+                              SMEM_STRIDE,
+                              SMEM_STRIDE_ACC,
+                              kWarpSize,
+                              kWarpSizeLog2,
+                              kTileDim,
+                              kTileDimLog2>
+        <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock,
+           threadblock_size,
+           warps_per_threadblock * smem_elems_per_warp * sizeof(__half)>>>((const __half *)input,
+                                                                           (half *)output,
+                                                                           batch_size,
+                                                                           num_rows,
+                                                                           num_cols,
+                                                                           num_rows_after_padding,
+                                                                           num_cols_after_padding,
+                                                                           smem_elems_per_warp,
+                                                                           smem_rows_per_warp,
+                                                                           output_size,
+                                                                           num_row_steps,
+                                                                           num_col_steps);
+  } else {
+    dotBasedInteractFwdKernelNonAligned<warps_per_threadblock,
+                                        threadblock_size,
+                                        M_BLOCKS,
+                                        K_BLOCKS,
+                                        SMEM_STRIDE,
+                                        SMEM_STRIDE_ACC,
+                                        kWarpSize,
+                                        kWarpSizeLog2,
+                                        kTileDim,
+                                        kTileDimLog2>
+        <<<(batch_size + warps_per_threadblock - 1) / warps_per_threadblock,
+           threadblock_size,
+           warps_per_threadblock * smem_elems_per_warp * sizeof(__half)>>>((const __half *)input,
+                                                                           (half *)output,
+                                                                           batch_size,
+                                                                           num_rows,
+                                                                           num_cols,
+                                                                           num_rows_after_padding,
+                                                                           num_cols_after_padding,
+                                                                           smem_elems_per_warp,
+                                                                           smem_rows_per_warp,
+                                                                           output_size,
+                                                                           num_row_steps,
+                                                                           num_col_steps);
+  }
+}
+
+// Host-side launcher for the FP16 (half) backward interaction kernel.
+// One sample per warp, 4 warps per threadblock. Computes all the per-warp
+// shared-memory region sizes (staged input, the upstream interaction
+// gradient expanded to a skewed 2D matrix, and the FP32 output staging
+// area), then dispatches the vectorized kernel when both the padded
+// upstream-gradient size and num_cols are multiples of 8, otherwise the
+// scalar fallback.
+// Outputs: `grad` receives d(loss)/d(input) per sample, `bottom_mlp_grad`
+// receives the first num_cols of each upstream gradient row.
+inline void dotBasedInteractBwd(void *input,
+                                void *upstream_grad,
+                                void *grad,
+                                void *bottom_mlp_grad,
+                                uint batch_size,
+                                uint num_rows,
+                                uint num_cols) {
+  const uint kWarpSize = 32;
+  const uint kWarpSizeLog2 = Log2<kWarpSize>::value;
+  const uint kTileDim = 16;
+  const uint kTileDimLog2 = Log2<kTileDim>::value;
+  // extra shared-memory stride elements (bank-conflict mitigation)
+  const uint mem_skew_size = 8;
+  // one padding element appended to each upstream interaction gradient
+  const uint kPaddingSize = 1;
+  const uint kWarpsPerBlock = 4;
+  const uint kWarpsPerBlockLog2 = Log2<kWarpsPerBlock>::value;
+  const uint kNumThreads = kWarpsPerBlock * kWarpSize;
+  const uint kRowTilesPerStep = 2;
+  const uint kColTilesPerStep = 1;
+
+  // small-row case (num_rows <= 16) only needs a single row tile per step
+  uint row_tiles_per_step = num_rows > kTileDim ? kRowTilesPerStep : 1;
+
+  // num tiles
+  uint num_row_tiles = (num_rows + kTileDim - 1) >> kTileDimLog2;
+  uint num_col_tiles = (num_cols + kTileDim - 1) >> kTileDimLog2;
+
+  // number of rows and columns after padding
+  // rows are always padded to exactly two 16-row tiles (32 rows)
+  uint num_rows_after_padding = kTileDim << 1;
+  uint num_cols_after_padding = num_col_tiles << kTileDimLog2;
+
+  // 2D ugrad size and stride
+  // the flat lower-triangle upstream gradient is expanded in-kernel into a
+  // full (symmetric) num_rows_after_padding x num_rows_after_padding matrix
+  uint interaction_ugrad_2D_stride = num_rows_after_padding + mem_skew_size;
+  uint interaction_ugrad_2D_size_elems = num_rows_after_padding * interaction_ugrad_2D_stride;
+  uint interaction_ugrad_2D_size_bytes = interaction_ugrad_2D_size_elems * sizeof(half);
+
+  // 1D ugrad size
+  uint interaction_ugrad_size = num_rows * (num_rows - 1) >> 1;
+  uint interaction_ugrad_size_with_padding = interaction_ugrad_size + kPaddingSize;
+
+  // in_out place size and stride
+  uint input_stride = num_cols_after_padding + mem_skew_size;
+  uint input_size_elems = num_rows_after_padding * input_stride;
+  uint input_size_bytes = input_size_elems * sizeof(half);
+
+  // sample size
+  uint sample_size = num_rows * num_cols;
+
+  // output size
+  uint output_size_elems = kTileDim * kTileDim * kRowTilesPerStep * kColTilesPerStep;
+  uint output_size_bytes = output_size_elems * sizeof(float);
+
+  // staging area size
+  // the staging area is reused: first for the 2D ugrad matrix, later for the
+  // FP32 accumulator output tiles -- so it must fit the larger of the two
+  uint staging_area_size_bytes =
+      output_size_bytes > interaction_ugrad_2D_size_bytes ? output_size_bytes : interaction_ugrad_2D_size_bytes;
+
+  // Shared memory size
+  uint shared_mem_per_warp_size_byte = input_size_bytes + staging_area_size_bytes;
+  uint shared_mem_size_bytes = kWarpsPerBlock * shared_mem_per_warp_size_byte;
+
+  uint num_blocks = (batch_size + kWarpsPerBlock - 1) >> kWarpsPerBlockLog2;
+  uint num_row_steps = num_row_tiles / row_tiles_per_step;
+  uint num_col_steps = num_col_tiles / kColTilesPerStep;
+
+  // float4 (= 8 half) vector accesses require both strides to be multiples of 8
+  bool float4_predicate = !((interaction_ugrad_size_with_padding & 7) || (num_cols & 7));
+  if (float4_predicate) {
+    dotBasedInteractBwdKernel<kWarpsPerBlock,
+                              kNumThreads,
+                              kRowTilesPerStep,
+                              kColTilesPerStep,
+                              kWarpSize,
+                              kWarpSizeLog2,
+                              kTileDim,
+                              kTileDimLog2>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const half *)input,
+                                                             (const half *)upstream_grad,
+                                                             (half *)grad,
+                                                             (half *)bottom_mlp_grad,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             num_rows_after_padding,
+                                                             num_cols_after_padding,
+                                                             sample_size,
+                                                             interaction_ugrad_size,
+                                                             interaction_ugrad_size_with_padding,
+                                                             interaction_ugrad_2D_size_elems,
+                                                             interaction_ugrad_2D_stride,
+                                                             input_size_elems,
+                                                             input_stride,
+                                                             num_row_steps,
+                                                             num_col_steps,
+                                                             row_tiles_per_step,
+                                                             shared_mem_per_warp_size_byte);
+  } else {
+    dotBasedInteractBwdKernelNonAligned<kWarpsPerBlock,
+                                        kNumThreads,
+                                        kRowTilesPerStep,
+                                        kColTilesPerStep,
+                                        kWarpSize,
+                                        kWarpSizeLog2,
+                                        kTileDim,
+                                        kTileDimLog2>
+        <<<num_blocks, kNumThreads, shared_mem_size_bytes>>>((const half *)input,
+                                                             (const half *)upstream_grad,
+                                                             (half *)grad,
+                                                             (half *)bottom_mlp_grad,
+                                                             batch_size,
+                                                             num_rows,
+                                                             num_cols,
+                                                             num_rows_after_padding,
+                                                             num_cols_after_padding,
+                                                             sample_size,
+                                                             interaction_ugrad_size,
+                                                             interaction_ugrad_size_with_padding,
+                                                             interaction_ugrad_2D_size_elems,
+                                                             interaction_ugrad_2D_stride,
+                                                             input_size_elems,
+                                                             input_stride,
+                                                             num_row_steps,
+                                                             num_col_steps,
+                                                             row_tiles_per_step,
+                                                             shared_mem_per_warp_size_byte);
+  }
+}
+
+// Scalar FP32 forward kernel (fallback when float4 alignment is not met).
+// One sample per threadblock: the sample is staged in dynamic shared memory,
+// the first num_cols values (bottom MLP output) are copied straight into the
+// output row, then the lower triangle of the pairwise row-dot-product matrix
+// is computed, one output element per thread (grid-stride over threads of
+// the block). A single zero padding element terminates each output row.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__
+    void dotBasedInteractF32FwdKernelNonAligned(const float *__restrict input,
+                                                float *__restrict output,
+                                                uint batch_size,
+                                                uint num_rows,
+                                                uint num_cols,
+                                                uint input_size,
+                                                uint output_size,
+                                                uint interaction_output_size) {
+  extern __shared__ float smem_f32_fwd[];
+  float *smem_in = &smem_f32_fwd[0];
+
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  uint output_batch_offset = blockIdx.x * output_size;
+  float *gmem_out_bottom_mlp = &output[output_batch_offset];
+  float *gmem_out_interaction = &output[output_batch_offset + num_cols];
+
+  // Load the input - one sample per block
+  for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) {
+    smem_in[idx] = gmem_in[idx];
+  }
+  __syncthreads();
+
+  // Copy bottom MLP output to output
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    gmem_out_bottom_mlp[idx] = smem_in[idx];
+  }
+
+  for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) {
+    // Decode the flat lower-triangle index idx into (target_row, target_col)
+    // with target_row > target_col: triangle row r holds r elements.
+    uint elems_per_row = 1;
+    uint index = idx;
+    while (index >= elems_per_row) {
+      index -= elems_per_row;
+      elems_per_row++;
+    }
+    uint target_row = elems_per_row;
+    uint target_col = index;
+
+    // Dot product of the two feature rows held in shared memory.
+    float sum = 0;
+    for (uint i = 0; i < num_cols; i++) {
+      float tmp1 = smem_in[target_row * num_cols + i];
+      float tmp2 = smem_in[target_col * num_cols + i];
+      sum = fmaf(tmp1, tmp2, sum);
+    }
+
+    gmem_out_interaction[idx] = sum;
+  }
+
+  // Zero the single padding element; every thread stores the same value.
+  gmem_out_interaction[interaction_output_size] = 0;
+}
+
+// Vectorized FP32 forward kernel: identical math to the NonAligned variant
+// but all global/shared transfers and the inner dot product use float4,
+// which requires num_cols and output_size to be multiples of 4 (the host
+// launcher guarantees this). One sample per threadblock.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32FwdKernel(const float *__restrict input,
+                                                                                 float *__restrict output,
+                                                                                 uint batch_size,
+                                                                                 uint num_rows,
+                                                                                 uint num_cols,
+                                                                                 uint input_size,
+                                                                                 uint output_size,
+                                                                                 uint interaction_output_size) {
+  extern __shared__ float smem_f32_fwd[];
+  float *smem_in = &smem_f32_fwd[0];
+
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  uint output_batch_offset = blockIdx.x * output_size;
+  float *gmem_out_bottom_mlp = &output[output_batch_offset];
+  float *gmem_out_interaction = &output[output_batch_offset + num_cols];
+
+  // Load the input - one sample per block
+  uint input_size_float4 = input_size >> 2;
+  for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) {
+    ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx];
+  }
+  __syncthreads();
+
+  // Copy bottom MLP output to output
+  uint btm_mlp_out_size_float4 = num_cols >> 2;
+  for (uint idx = threadIdx.x; idx < btm_mlp_out_size_float4; idx += blockDim.x) {
+    ((float4 *)gmem_out_bottom_mlp)[idx] = ((float4 *)smem_in)[idx];
+  }
+
+  for (uint idx = threadIdx.x; idx < (interaction_output_size); idx += blockDim.x) {
+    // Decode the flat lower-triangle index idx into (target_row, target_col)
+    // with target_row > target_col: triangle row r holds r elements.
+    uint elems_per_row = 1;
+    uint index = idx;
+    while (index >= elems_per_row) {
+      index -= elems_per_row;
+      elems_per_row++;
+    }
+    uint target_row = elems_per_row;
+    uint target_col = index;
+
+    // Dot product of the two rows, accumulated in 4 parallel lanes of a
+    // float4 and reduced at the end.
+    float4 sum;
+    sum.x = 0;
+    sum.y = 0;
+    sum.z = 0;
+    sum.w = 0;
+    uint num_cols_float4 = num_cols >> 2;
+    for (uint i = 0; i < num_cols_float4; i++) {
+      float4 tmp1 = ((float4 *)smem_in)[target_row * num_cols_float4 + i];
+      float4 tmp2 = ((float4 *)smem_in)[target_col * num_cols_float4 + i];
+      sum.x = fmaf(tmp1.x, tmp2.x, sum.x);
+      sum.y = fmaf(tmp1.y, tmp2.y, sum.y);
+      sum.z = fmaf(tmp1.z, tmp2.z, sum.z);
+      sum.w = fmaf(tmp1.w, tmp2.w, sum.w);
+    }
+
+    gmem_out_interaction[idx] = sum.x + sum.y + sum.z + sum.w;
+  }
+
+  // Zero the single padding element; every thread stores the same value.
+  gmem_out_interaction[interaction_output_size] = 0;
+}
+
+// Host-side launcher for the FP32 forward interaction kernels.
+// Runs one CUDA block (128 threads) per sample and stages the whole sample
+// in dynamic shared memory. Chooses the float4-vectorized kernel when both
+// num_cols and the padded output row size are multiples of 4, otherwise the
+// scalar fallback.
+// NOTE(review): `output` is declared const but cast to a mutable float*
+// before the launch -- the pointer is written through; the signature is
+// misleading but kept for ABI stability.
+inline void dotBasedInteractF32Fwd(const void *input,
+                                   const void *bottom_mlp_output,
+                                   const void *output,
+                                   uint batch_size,
+                                   uint num_rows,
+                                   uint num_cols) {
+  const uint kPad = 1;              // one zero element appended per output row
+  const uint kThreadsPerBlock = 128;
+  const uint grid_size = batch_size;
+
+  // Pairwise-interaction count: lower triangle of the num_rows x num_rows
+  // dot-product matrix.
+  const uint num_interactions = (num_rows * (num_rows - 1)) >> 1;
+  const uint out_row_size = num_cols + num_interactions + kPad;
+
+  const uint in_row_size = num_rows * num_cols;
+
+  // Dynamic shared memory holds the full sample in FP32.
+  const uint smem_bytes = in_row_size * sizeof(float);
+
+  // float4 transfers need 4-element alignment on both strides.
+  const bool aligned = ((num_cols & 3) == 0) && ((out_row_size & 3) == 0);
+
+  if (aligned) {
+    dotBasedInteractF32FwdKernel<kThreadsPerBlock>
+        <<<grid_size, kThreadsPerBlock, smem_bytes>>>((const float *)input,
+                                                      (float *)output,
+                                                      batch_size,
+                                                      num_rows,
+                                                      num_cols,
+                                                      in_row_size,
+                                                      out_row_size,
+                                                      num_interactions);
+  } else {
+    dotBasedInteractF32FwdKernelNonAligned<kThreadsPerBlock>
+        <<<grid_size, kThreadsPerBlock, smem_bytes>>>((const float *)input,
+                                                      (float *)output,
+                                                      batch_size,
+                                                      num_rows,
+                                                      num_cols,
+                                                      in_row_size,
+                                                      out_row_size,
+                                                      num_interactions);
+  }
+}
+
+// Scalar FP32 backward kernel (fallback when float4 alignment is not met).
+// One sample per threadblock. For input row r and column c:
+//   grad[r][c] = sum_{k != r} input[k][c] * ugrad(r, k)
+// where ugrad is the flattened lower triangle of the upstream interaction
+// gradient (the dot matrix is symmetric, so ugrad(r, k) == ugrad(k, r)).
+// The bottom-MLP gradient is the first num_cols of the upstream row, copied
+// through unchanged.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__
+    void dotBasedInteractF32BwdKernelNonAligned(const float *__restrict input,
+                                                const float *__restrict upstream_grad,
+                                                float *__restrict grad,
+                                                float *__restrict bottom_mlp_grad,
+                                                uint batch_size,
+                                                uint num_rows,
+                                                uint num_cols,
+                                                uint input_size,
+                                                uint ugrad_size,
+                                                uint interaction_ugrad_size) {
+  extern __shared__ float smem_f32_bwd[];
+  float *smem_in = &smem_f32_bwd[0];
+  float *smem_interaction_ugrad = &smem_f32_bwd[input_size];
+
+  // Input
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  // Gradient
+  const uint &grad_batch_offset = input_batch_offset;
+  float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols];
+  float *gmem_interaction_grad = &grad[grad_batch_offset];
+
+  // Upstream Gradient
+  uint upstream_grad_batch_offset = blockIdx.x * ugrad_size;
+  const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset];
+  const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols];
+
+  // input -> shared memory
+  for (uint idx = threadIdx.x; idx < input_size; idx += blockDim.x) {
+    smem_in[idx] = gmem_in[idx];
+  }
+
+  // Interaction Upstream Grad -> Shared Memory
+  for (uint idx = threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) {
+    smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx];
+  }
+  __syncthreads();
+
+  // Copy the upstream gradient w.r.t to mlp to it's corresponding memory location.
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    gmem_mlp_grad[idx] = gmem_mlp_ugrad[idx];
+  }
+
+  // Each thread owns one column and walks all rows of the gradient.
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    size_t grad_idx = idx;
+    for (uint row_idx = 0; row_idx < num_rows; row_idx++) {
+      float sum = 0;
+      size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1;
+      // rows k < row_idx: ugrad(row_idx, k) lives in row_idx's triangle segment
+      for (int k = 0; k < row_idx; k++) {
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum);
+      }
+      // rows k > row_idx: symmetric element ugrad(k, row_idx) from k's segment
+      for (int k = row_idx + 1; k < num_rows; k++) {
+        upstream_grad_offset = (k * (k - 1)) >> 1;  // TODO: this can become a sum
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum);
+      }
+      gmem_interaction_grad[grad_idx] = sum;
+      grad_idx += num_cols;
+    }
+  }
+}
+
+// Vectorized FP32 backward kernel: same math as the NonAligned variant but
+// the shared-memory staging of the input, the upstream gradient, and the
+// bottom-MLP gradient copy all use float4 transfers (the interaction ugrad
+// gets a scalar tail loop for its last <4 elements). The host launcher
+// guarantees 4-element alignment of num_cols and the padded ugrad size.
+template <uint THREADBLOCK_SIZE>
+__launch_bounds__(THREADBLOCK_SIZE) __global__ void dotBasedInteractF32BwdKernel(const float *__restrict input,
+                                                                                 const float *__restrict upstream_grad,
+                                                                                 float *__restrict grad,
+                                                                                 float *__restrict bottom_mlp_grad,
+                                                                                 uint batch_size,
+                                                                                 uint num_rows,
+                                                                                 uint num_cols,
+                                                                                 uint input_size,
+                                                                                 uint ugrad_size,
+                                                                                 uint interaction_ugrad_size) {
+  extern __shared__ float smem_f32_bwd[];
+  float *smem_in = &smem_f32_bwd[0];
+  float *smem_interaction_ugrad = &smem_f32_bwd[input_size];
+
+  // Input
+  uint input_batch_offset = blockIdx.x * input_size;
+  const float *gmem_in = &input[input_batch_offset];
+
+  // Gradient
+  const uint &grad_batch_offset = input_batch_offset;
+  float *gmem_mlp_grad = &bottom_mlp_grad[blockIdx.x * num_cols];
+  float *gmem_interaction_grad = &grad[grad_batch_offset];
+
+  // Upstream Gradient
+  uint upstream_grad_batch_offset = blockIdx.x * ugrad_size;
+  const float *gmem_mlp_ugrad = &upstream_grad[upstream_grad_batch_offset];
+  const float *gmem_interaction_ugrad = &upstream_grad[upstream_grad_batch_offset + num_cols];
+
+  // input -> shared memory
+  uint input_size_float4 = input_size >> 2;
+  for (uint idx = threadIdx.x; idx < input_size_float4; idx += blockDim.x) {
+    ((float4 *)smem_in)[idx] = ((float4 *)gmem_in)[idx];
+  }
+
+  // Interaction Upstream Grad -> Shared Memory
+  uint upstream_grad_size_float4 = interaction_ugrad_size >> 2;
+  for (uint idx = threadIdx.x; idx < upstream_grad_size_float4; idx += blockDim.x) {
+    ((float4 *)smem_interaction_ugrad)[idx] = ((float4 *)gmem_interaction_ugrad)[idx];
+  }
+
+  // scalar tail: interaction_ugrad_size itself need not be a multiple of 4
+  uint vectorized_load_offset = (upstream_grad_size_float4 << 2);
+  for (uint idx = vectorized_load_offset + threadIdx.x; idx < interaction_ugrad_size; idx += blockDim.x) {
+    smem_interaction_ugrad[idx] = gmem_interaction_ugrad[idx];
+  }
+  __syncthreads();
+
+  // Copy the upstream gradient w.r.t to mlp to it's corresponding memory location.
+  for (uint idx = threadIdx.x; idx < (num_cols >> 2); idx += blockDim.x) {
+    ((float4 *)gmem_mlp_grad)[idx] = ((float4 *)gmem_mlp_ugrad)[idx];
+  }
+
+  // Each thread owns one column and walks all rows of the gradient:
+  //   grad[row][col] = sum_{k != row} input[k][col] * ugrad(row, k)
+  for (uint idx = threadIdx.x; idx < num_cols; idx += blockDim.x) {
+    size_t grad_idx = idx;
+    for (uint row_idx = 0; row_idx < num_rows; row_idx++) {
+      float sum = 0;
+      size_t upstream_grad_offset = (row_idx * (row_idx - 1)) >> 1;
+      // rows k < row_idx: ugrad(row_idx, k) lives in row_idx's triangle segment
+      for (int k = 0; k < row_idx; k++) {
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + k], sum);
+      }
+      // rows k > row_idx: symmetric element ugrad(k, row_idx) from k's segment
+      for (int k = row_idx + 1; k < num_rows; k++) {
+        upstream_grad_offset = (k * (k - 1)) >> 1;  // TODO: this can become a sum
+        sum = fmaf(smem_in[k * num_cols + idx], smem_interaction_ugrad[upstream_grad_offset + row_idx], sum);
+      }
+      gmem_interaction_grad[grad_idx] = sum;
+      grad_idx += num_cols;
+    }
+  }
+}
+
+// Host-side launcher for the FP32 backward interaction kernels.
+// Runs one CUDA block (128 threads) per sample; shared memory stages the
+// sample plus the flattened upstream interaction gradient. Chooses the
+// float4-vectorized kernel when both the padded upstream-gradient size and
+// num_cols are multiples of 4, otherwise the scalar fallback.
+inline void dotBasedInteractF32Bwd(const void *input,
+                                   const void *upstream_grad,
+                                   void *grad,
+                                   void *bottom_mlp_grad,
+                                   uint batch_size,
+                                   uint num_rows,
+                                   uint num_cols) {
+  const uint kPad = 1;      // one padding element after the triangle
+  const uint kThreads = 128;
+
+  const uint grid_size = batch_size;
+
+  const uint sample_elems = num_rows * num_cols;
+
+  // Flattened upstream gradient of the pairwise interactions: the lower
+  // triangle of the num_rows x num_rows dot matrix, plus padding. A full
+  // upstream row is the MLP part followed by the padded triangle.
+  const uint tri_size = num_rows * (num_rows - 1) >> 1;
+  const uint tri_size_padded = tri_size + kPad;
+  const uint ugrad_row_size = num_cols + tri_size_padded;
+
+  // Dynamic shared memory: the sample plus the (unpadded) triangle, in FP32.
+  const uint smem_bytes = (sample_elems + tri_size) * sizeof(float);
+
+  // float4 transfers need 4-element alignment on both strides.
+  const bool aligned = ((tri_size_padded & 3) == 0) && ((num_cols & 3) == 0);
+  if (aligned) {
+    dotBasedInteractF32BwdKernel<kThreads>
+        <<<grid_size, kThreads, smem_bytes>>>((const float *)input,
+                                              (const float *)upstream_grad,
+                                              (float *)grad,
+                                              (float *)bottom_mlp_grad,
+                                              batch_size,
+                                              num_rows,
+                                              num_cols,
+                                              sample_elems,
+                                              ugrad_row_size,
+                                              tri_size);
+  } else {
+    dotBasedInteractF32BwdKernelNonAligned<kThreads>
+        <<<grid_size, kThreads, smem_bytes>>>((const float *)input,
+                                              (const float *)upstream_grad,
+                                              (float *)grad,
+                                              (float *)bottom_mlp_grad,
+                                              batch_size,
+                                              num_rows,
+                                              num_cols,
+                                              sample_elems,
+                                              ugrad_row_size,
+                                              tri_size);
+  }
+}
+

+ 68 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_volta/dot_based_interact_pytorch_types.cu

@@ -0,0 +1,68 @@
+#include <torch/extension.h>
+#include <torch/types.h>
+#include <stdexcept>
+#include "dot_based_interact.cu"
+
+// Forward of the dot-based feature-interaction op (Volta path).
+// input: (batch, num_rows, num_cols) tensor of feature vectors.
+// Returns a (batch, output_size) tensor, where output_size is the number of
+// distinct row pairs (num_rows*(num_rows-1)/2) plus num_cols plus one padding
+// element. Dispatches to the FP16 or FP32 launcher by dtype; any other dtype
+// combination raises std::invalid_argument.
+torch::Tensor dotBasedInteractFwdTorch(torch::Tensor input, torch::Tensor bottom_mlp_output) {
+  const uint kPaddingSize = 1;
+  auto size = input.sizes();
+  auto batch_size = size[0];
+  auto num_rows = size[1];
+  auto num_cols = size[2];
+  // lower-triangle pair count + num_cols + alignment padding
+  uint output_size = ((num_rows * (num_rows - 1)) >> 1) + num_cols + kPaddingSize;
+
+  int64_t outputShape[2] = {batch_size, output_size};
+  auto output = torch::empty(c10::IntArrayRef(outputShape), input.options());
+  // Both tensors must have the same dtype; mixed half/float is rejected below.
+  if (input.scalar_type() == torch::ScalarType::Half && bottom_mlp_output.scalar_type() == torch::ScalarType::Half) {
+    dotBasedInteractFwd(input.contiguous().data_ptr<at::Half>(),
+                        bottom_mlp_output.contiguous().data_ptr<at::Half>(),
+                        output.contiguous().data_ptr<at::Half>(),
+                        batch_size,
+                        num_rows,
+                        num_cols);
+  } else if (input.scalar_type() == torch::ScalarType::Float &&
+             bottom_mlp_output.scalar_type() == torch::ScalarType::Float) {
+    dotBasedInteractF32Fwd(input.contiguous().data_ptr<float>(),
+                           bottom_mlp_output.contiguous().data_ptr<float>(),
+                           output.contiguous().data_ptr<float>(),
+                           batch_size,
+                           num_rows,
+                           num_cols);
+  } else {
+    throw std::invalid_argument("Invalid input type.");
+  }
+  return output;
+}
+
+// Backward of the dot-based interaction op. Given the forward input
+// (batch, num_rows, num_cols) and the upstream gradient, returns
+// {outputGrad, mlp_grad}: the gradient w.r.t. the input (same shape as input)
+// and the gradient w.r.t. the bottom-MLP output (batch, num_cols).
+// Dispatches FP16/FP32 by dtype; mixed or other dtypes raise
+// std::invalid_argument.
+std::vector<torch::Tensor> dotBasedInteractBwdTorch(torch::Tensor input, torch::Tensor upstreamGrad) {
+  auto size = input.sizes();
+  auto batch_size = size[0];
+  auto num_rows = size[1];
+  auto num_cols = size[2];
+
+  auto outputGrad = torch::empty_like(input);
+  int64_t outputShape[2] = {batch_size, num_cols};
+  auto mlp_grad = torch::empty(c10::IntArrayRef(outputShape), input.options());
+
+  if (input.scalar_type() == torch::ScalarType::Half && upstreamGrad.scalar_type() == torch::ScalarType::Half) {
+    dotBasedInteractBwd(input.contiguous().data_ptr<at::Half>(),
+                        upstreamGrad.contiguous().data_ptr<at::Half>(),
+                        outputGrad.contiguous().data_ptr<at::Half>(),
+                        mlp_grad.contiguous().data_ptr<at::Half>(),
+                        batch_size,
+                        num_rows,
+                        num_cols);
+  } else if (input.scalar_type() == torch::ScalarType::Float &&
+             upstreamGrad.scalar_type() == torch::ScalarType::Float) {
+    dotBasedInteractF32Bwd(input.contiguous().data_ptr<float>(),
+                           upstreamGrad.contiguous().data_ptr<float>(),
+                           outputGrad.contiguous().data_ptr<float>(),
+                           mlp_grad.contiguous().data_ptr<float>(),
+                           batch_size,
+                           num_rows,
+                           num_cols);
+  } else {
+    throw std::invalid_argument("Invalid input type.");
+  }
+  return {outputGrad, mlp_grad};
+}

+ 13 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/dot_based_interact_volta/pytorch_ops.cpp

@@ -0,0 +1,13 @@
+#include <torch/extension.h>
+
+// Forward declarations; definitions live in dot_based_interact_pytorch_types.cu.
+torch::Tensor dotBasedInteractFwdTorch(torch::Tensor input,
+                                       torch::Tensor bottom_mlp_output);
+std::vector<torch::Tensor> dotBasedInteractBwdTorch(torch::Tensor input,
+                                                    torch::Tensor upstreamGrad);
+
+// Python bindings for the Volta dot-based-interaction extension module.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("dotBasedInteractFwd", &dotBasedInteractFwdTorch, "", py::arg("input"),
+        py::arg("bottom_mlp_output"));
+  m.def("dotBasedInteractBwd", &dotBasedInteractBwdTorch, "", py::arg("input"),
+        py::arg("upstreamGrad"));
+}

+ 293 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/gather_gpu_fused.cu

@@ -0,0 +1,293 @@
+#include <iostream>
+#include <cuda_runtime_api.h>
+#include <c10/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDAContext.h>
+
+// Evaluate a CUDA runtime call; on failure, print file/line and the CUDA error
+// string to stderr and terminate the process. (Comments cannot go inside the
+// macro body because of the line continuations.)
+#define CHK_CUDA(expression)                                                                                        \
+  {                                                                                                                 \
+    cudaError_t status = (expression);                                                                              \
+    if (status != cudaSuccess) {                                                                                    \
+      std::cerr << "Error in file: " << __FILE__ << ", on line: " << __LINE__ << ": " << cudaGetErrorString(status) \
+                << std::endl;                                                                                       \
+      std::exit(EXIT_FAILURE);                                                                                      \
+    }                                                                                                               \
+  }
+
+// Only 4-element vectorized types are implemented; other widths can be added the same way.
+// Loads/stores go through the "mask" members;
+// element-wise assignments go through the "val" members.
+// Primary template is intentionally empty; only the specializations below
+// (__half, float) are usable.
+template <class DTYPE>
+struct VecType4{};
+
+// Four __half values packed into 8 bytes. Per the note above: `mask` (float2)
+// is the type used for single vectorized loads/stores, `val` exposes the four
+// elements for per-element assignment/conversion; both alias the same storage.
+template <>
+struct VecType4<__half> {
+  typedef float2 Type;
+  typedef struct __align__(8) {
+    __half x;
+    __half y;
+    __half z;
+    __half w;
+  } half4;
+  union Data {
+    half4 val;
+    Type mask;
+  } data;
+
+  // Zero-initialize through the load/store view.
+  __device__ VecType4() {
+    data.mask = make_float2(0.0f, 0.0f);
+  }
+
+  // Element-wise float -> half conversion.
+  __device__ VecType4& operator=(float4 &in) {
+    data.val.x = __float2half(in.x);
+    data.val.y = __float2half(in.y);
+    data.val.z = __float2half(in.z);
+    data.val.w = __float2half(in.w);
+
+    return *this;
+  }
+
+  // Same-type assignment: plain copy, no conversion.
+  __device__ VecType4& operator=(half4 &in) {
+    data.val = in;
+    return *this;
+  }
+};
+
+// Four floats (16 bytes). For float, the assignment view and the load/store
+// view are the same float4 type; the union keeps the interface symmetric with
+// the __half specialization.
+template <>
+struct VecType4<float> {
+  typedef float4 Type;
+  union Data {
+    Type val;
+    Type mask;
+  } data;
+
+  __device__ VecType4() {
+    data.val.x = 0.0f;
+    data.val.y = 0.0f;
+    data.val.z = 0.0f;
+    data.val.w = 0.0f;
+  }
+
+  // Element-wise half -> float conversion.
+  __device__ VecType4& operator=(VecType4<__half>::half4 &in) {
+    data.val.x = __half2float(in.x);
+    data.val.y = __half2float(in.y);
+    data.val.z = __half2float(in.z);
+    data.val.w = __half2float(in.w);
+
+    return *this;
+  }
+  // Same-type assignment: plain copy.
+  __device__ VecType4& operator=(float4 &in) {
+    data.val = in;
+    return *this;
+  }
+};
+
+//  -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__
+// The default build flags above, passed to Torch extensions, require this extensive type juggling.
+// Scalar dtype cast helpers, selected via SFINAE on (ITYPE, OTYPE).
+// float -> __half: narrowing conversion through the CUDA intrinsic.
+template <typename ITYPE, typename OTYPE, typename std::enable_if<(std::is_same<ITYPE, float>::value &&
+                                                                    std::is_same<OTYPE, __half>::value),
+                                                                    ITYPE>::type * = nullptr>
+__device__ __host__ __forceinline__  OTYPE fp_type_cast(ITYPE input) {
+  return __float2half(input);
+}
+
+// __half -> float: widening conversion.
+template <typename ITYPE, typename OTYPE, typename std::enable_if<(std::is_same<ITYPE, __half>::value &&
+                                                                    std::is_same<OTYPE, float>::value),
+                                                                    ITYPE>::type * = nullptr>
+__device__ __host__ __forceinline__  OTYPE fp_type_cast(ITYPE input) {
+  return __half2float(input);
+}
+
+// Same type: identity pass-through.
+template <typename ITYPE, typename OTYPE, typename std::enable_if<std::is_same<ITYPE, OTYPE>::value,
+                                                                    ITYPE>::type * = nullptr>
+__device__ __host__ __forceinline__  OTYPE fp_type_cast(ITYPE input) {
+  return input;
+}
+
+// this kernel assumes embedding vector_width of 128
+// Fused embedding lookup for 26 categorical features with a hard-coded
+// 128-element embedding width. Each warp processes one sample at a time:
+// lanes 0..25 each compute one feature's global table row (raw index plus the
+// per-table base offset held in lane-local `lane_offset`), then the whole warp
+// cooperatively copies each of the 26 rows using 128-bit vector loads
+// (32 lanes x 4 elements = 128). Output layout is (batch, 26, 128), cast from
+// ITYPE to OTYPE through VecType4.
+template <typename ITYPE, typename OTYPE>
+__global__ void lookupEmbeddings(ITYPE *embeddingTable, int64_t *offsets,
+                                    int64_t *indices, OTYPE *outLookup, int batch_size) {
+
+  typedef typename VecType4<ITYPE>::Type invec4;
+  typedef typename VecType4<OTYPE>::Type outvec4;
+
+  int vector_width = 128;
+  const int fea_count = 26;
+
+  int lane_id = threadIdx.x % warpSize;
+  int warp_id = threadIdx.x / warpSize;
+  int num_warps = blockDim.x / warpSize;
+  // Flat position into the (batch * 26) index array; lane i of a warp owns
+  // feature i of the warp's current sample.
+  int start_idx = warp_id * fea_count + lane_id + blockIdx.x * (num_warps * fea_count);
+
+  // Each of lanes 0..25 caches the base offset of its embedding table once.
+  int64_t lane_offset = 0;
+  if (lane_id < fea_count)
+    lane_offset = offsets[lane_id];
+
+  while (1) {
+    // -1 marks "no work" (lane >= 26 or sample range exhausted).
+    int64_t lookup_idx = -1;
+    if (lane_id < fea_count && start_idx < (batch_size * fea_count)) {
+      lookup_idx = indices[start_idx] + lane_offset;
+  }
+  // NOTE(review): the brace above closes the `if`; the statements below are
+  // still inside `while (1)` despite the outdented formatting.
+
+  // All lanes idle => every sample assigned to this warp is done.
+  if (__all_sync(0xffffffff, lookup_idx == -1))
+    break;
+
+  for (int i = 0; i < fea_count; i++) {
+    // Broadcast feature i's table row from lane i to the whole warp.
+    int64_t table_idx = __shfl_sync(0xffffffff, lookup_idx, i);
+
+    if (table_idx != -1) {
+      invec4 *vec_embedding_table = reinterpret_cast<invec4*>(embeddingTable);
+      outvec4 *vec_embedding_out = reinterpret_cast<outvec4*>(outLookup);
+
+      // start_idx - lane_id is the warp's sample base; +i selects the feature.
+      int64_t out_idx = start_idx - lane_id + i;
+      out_idx *= vector_width;
+
+      int vector_inst_width = 4;    // 128 bit loads, 4-floats
+      int64_t vec_in_idx = ((table_idx * vector_width) + (lane_id * vector_inst_width)) >> 2;
+      int64_t vec_out_idx = (out_idx + (lane_id * vector_inst_width)) >> 2;
+
+      VecType4<ITYPE> input_elements;
+      input_elements.data.mask = vec_embedding_table[vec_in_idx];
+      VecType4<OTYPE> output_elements;
+      output_elements = input_elements.data.val;   // dtype conversion happens here
+      vec_embedding_out[vec_out_idx] = output_elements.data.mask;
+    }
+  }
+
+  // Advance to this warp's next sample (grid-stride over samples).
+  start_idx += (gridDim.x * num_warps * fea_count);
+  }
+}
+
+// Adds the per-table base offset to every raw lookup index, producing global
+// row indices into the concatenated embedding table. The 26 offsets are staged
+// in shared memory; element i of the flattened (batch * 26) index array uses
+// offsets[i % 26]. Grid-stride loop over all elements.
+__global__ void indices_offset_addition(int64_t *indices, int64_t *offsets, int64_t *output_indices,
+                                          int batch_size) {
+  const int fea_count = 26;
+  __shared__ int64_t smem_offsets[fea_count];
+
+  if (threadIdx.x < fea_count) {
+    smem_offsets[threadIdx.x] = offsets[threadIdx.x];
+  }
+  __syncthreads();
+
+  int start_idx = threadIdx.x + blockIdx.x * blockDim.x;
+  for (int i = start_idx; i < (batch_size * fea_count); i+=(gridDim.x * blockDim.x)) {
+    output_indices[i] = indices[i] + smem_offsets[i % fea_count];
+  }
+}
+
+// Element-wise dtype-converting copy of a gradient buffer (ITYPE -> OTYPE).
+// The bulk is processed four elements at a time through VecType4 with a
+// grid-stride loop; the (num_elements % 4) tail is written scalar-wise by
+// thread 0. NOTE(review): the tail branch runs in thread 0 of *every* block,
+// so those few elements are written redundantly (same values each time).
+template <typename ITYPE, typename OTYPE>
+__global__ void gradient_copy_kernel(ITYPE *input_gradient, OTYPE *output_gradient, int64_t num_elements) {
+  typedef typename VecType4<ITYPE>::Type invec4;
+  typedef typename VecType4<OTYPE>::Type outvec4;
+
+  invec4 *vec_input_gradient = reinterpret_cast<invec4*>(input_gradient);
+  outvec4 *vec_output_gradient = reinterpret_cast<outvec4*>(output_gradient);
+
+  int64_t start_idx = threadIdx.x + blockIdx.x * blockDim.x;
+  for (int64_t i = start_idx; i < num_elements / 4; i+= (gridDim.x * blockDim.x)) {
+    VecType4<ITYPE> input_elements;
+    input_elements.data.mask = vec_input_gradient[i];
+    VecType4<OTYPE> output_elements;
+    output_elements = input_elements.data.val;   // dtype conversion
+    vec_output_gradient[i] = output_elements.data.mask;
+  }
+  int elements_left = num_elements % 4;
+
+  if (threadIdx.x == 0 && elements_left != 0) {
+    while(elements_left) {
+      int64_t idx = num_elements - elements_left;
+      output_gradient[idx] = fp_type_cast<ITYPE, OTYPE>(input_gradient[idx]);
+      elements_left--;
+    }
+  }
+}
+
+// kernels are fully instantiation type compatible float<->float , float<->Half, half<->half
+// but their runner functions are not instantiated for all types
+// Host-side launchers for lookupEmbeddings. The generic template is a no-op;
+// only the explicit specializations below are implemented
+// (float->half, float->float, half->half).
+template <typename ITYPE, typename OTYPE>
+void gather_gpu_fused_fwd(ITYPE *embeddingTablePtr, int64_t *indices_offset, int64_t *lookup_indices,
+                            OTYPE *outputPtr, int batch_size) {};
+
+// FP32 table, FP16 output (mixed-precision training path).
+template <>
+void gather_gpu_fused_fwd(float *embeddingTablePtr, int64_t *indices_offset, int64_t *lookup_indices,
+                            c10::Half *outputPtr, int batch_size) {
+
+  // Launch config sized to occupy the whole device:
+  // blocks = SM count * max resident threads per SM / threads per block.
+  auto deviceProp = at::cuda::getCurrentDeviceProperties();
+  dim3 block(deviceProp->maxThreadsPerBlock, 1, 1);
+  dim3 grid((deviceProp->multiProcessorCount * deviceProp->maxThreadsPerMultiProcessor) / deviceProp->maxThreadsPerBlock,
+              1, 1);
+
+  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+
+  lookupEmbeddings<float, __half><<<grid, block, 0, stream>>>(embeddingTablePtr, indices_offset, lookup_indices, (__half*)outputPtr, batch_size);
+  CHK_CUDA(cudaGetLastError());
+}
+
+// FP32 table, FP32 output.
+template <>
+void gather_gpu_fused_fwd(float *embeddingTablePtr, int64_t *indices_offset, int64_t *lookup_indices,
+                            float *outputPtr, int batch_size) {
+
+  auto deviceProp = at::cuda::getCurrentDeviceProperties();
+  dim3 block(deviceProp->maxThreadsPerBlock, 1, 1);
+  dim3 grid((deviceProp->multiProcessorCount * deviceProp->maxThreadsPerMultiProcessor) / deviceProp->maxThreadsPerBlock,
+              1, 1);
+
+  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+
+  lookupEmbeddings<float, float><<<grid, block, 0, stream>>>(embeddingTablePtr, indices_offset, lookup_indices, outputPtr, batch_size);
+  CHK_CUDA(cudaGetLastError());
+}
+
+// FP16 table, FP16 output.
+template <>
+void gather_gpu_fused_fwd(c10::Half *embeddingTablePtr, int64_t *indices_offset, int64_t *lookup_indices,
+                            c10::Half *outputPtr, int batch_size) {
+
+  auto deviceProp = at::cuda::getCurrentDeviceProperties();
+  dim3 block(deviceProp->maxThreadsPerBlock, 1, 1);
+  dim3 grid((deviceProp->multiProcessorCount * deviceProp->maxThreadsPerMultiProcessor) / deviceProp->maxThreadsPerBlock,
+              1, 1);
+
+  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+
+  lookupEmbeddings<__half, __half><<<grid, block, 0, stream>>>((__half*)embeddingTablePtr, indices_offset, lookup_indices, (__half*)outputPtr, batch_size);
+  CHK_CUDA(cudaGetLastError());
+}
+
+// Host-side launchers for the fused backward: (1) materialize global sparse
+// gradient indices via indices_offset_addition, then (2) cast/copy the
+// upstream gradient into the sparse values buffer via gradient_copy_kernel.
+// Generic template is a no-op; only half->float and float->float exist.
+template <typename ITYPE, typename OTYPE>
+void gather_gpu_fused_bwd(ITYPE *input_gradient, int64_t *lookup_indices, int64_t *offsets, OTYPE *out_gradient,
+                            int64_t *out_indices, int batch_size, int num_features, int embed_vector_dim) {};
+
+// FP16 upstream gradient, FP32 sparse values (mixed-precision path).
+template <>
+void gather_gpu_fused_bwd(c10::Half *input_gradient, int64_t *lookup_indices, int64_t *offsets, float *out_gradient,
+                            int64_t *out_indices, int batch_size, int num_features, int embed_vector_dim) {
+  // offset addition to indices
+  auto deviceProp = at::cuda::getCurrentDeviceProperties();
+  dim3 block(deviceProp->maxThreadsPerBlock, 1, 1);
+  dim3 grid((deviceProp->multiProcessorCount * deviceProp->maxThreadsPerMultiProcessor) / deviceProp->maxThreadsPerBlock,
+              1, 1);
+  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+
+  // indices - offset addition kernel
+  indices_offset_addition<<<grid, block, 0, stream>>>(lookup_indices, offsets, out_indices, batch_size);
+  CHK_CUDA(cudaGetLastError());
+
+  gradient_copy_kernel<__half, float><<<grid, block, 0, stream>>>((__half *)input_gradient, out_gradient, (int64_t)batch_size * num_features * embed_vector_dim );
+  CHK_CUDA(cudaGetLastError());
+}
+
+// FP32 upstream gradient, FP32 sparse values.
+template <>
+void gather_gpu_fused_bwd(float *input_gradient, int64_t *lookup_indices, int64_t *offsets, float *out_gradient,
+                            int64_t *out_indices, int batch_size, int num_features, int embed_vector_dim) {
+  // offset addition to indices
+  auto deviceProp = at::cuda::getCurrentDeviceProperties();
+  dim3 block(deviceProp->maxThreadsPerBlock, 1, 1);
+  dim3 grid((deviceProp->multiProcessorCount * deviceProp->maxThreadsPerMultiProcessor) / deviceProp->maxThreadsPerBlock,
+              1, 1);
+  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
+
+  // indices - offset addition kernel
+  indices_offset_addition<<<grid, block, 0, stream>>>(lookup_indices, offsets, out_indices, batch_size);
+  CHK_CUDA(cudaGetLastError());
+
+  gradient_copy_kernel<float, float><<<grid, block, 0, stream>>>(input_gradient, out_gradient, (int64_t)batch_size * num_features * embed_vector_dim );
+  CHK_CUDA(cudaGetLastError());
+}

+ 100 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/gather_gpu_fused_pytorch_impl.cu

@@ -0,0 +1,100 @@
+#include <torch/extension.h>
+#include <torch/types.h>
+#include <stdexcept>
+#include "gather_gpu_fused.cu"
+
+// plugin functions instantiated to do only mixed-precision execution
+// Torch-facing forward wrapper for the fused embedding gather.
+// Returns a (batch_size, num_features, embedding_vector_dim) tensor of
+// gathered embedding rows. Output dtype: Half when amp_train is set or when
+// the table itself is Half; Float only for a Float table without AMP.
+torch::Tensor gatherGPUFusedFwdTorch(torch::Tensor embedding, torch::Tensor indices, torch::Tensor offsets,
+                                        bool amp_train) {
+  auto size = indices.sizes();
+  auto batch_size = size[0];
+  auto num_features = size[1];
+
+  size = embedding.sizes();
+  auto embedding_vector_dim = size[1];
+  auto embedding_table_rows = size[0];    // not really need this
+
+//   if (embedding.scalar_type() != torch::ScalarType::Float) {
+//     throw std::invalid_argument("Invalid input type.");
+//   }
+
+  int64_t outputShape[3] = {batch_size, num_features, embedding_vector_dim};
+  torch::Tensor output;
+
+  if (embedding.scalar_type() == torch::ScalarType::Float) {
+    if (amp_train) {
+        // FP32 table -> FP16 activations (mixed precision).
+        output = torch::empty(c10::IntArrayRef(outputShape), embedding.options().dtype(torch::ScalarType::Half));
+        gather_gpu_fused_fwd(embedding.contiguous().data_ptr<float>(),
+                                offsets.contiguous().data_ptr<int64_t>(),
+                                indices.contiguous().data_ptr<int64_t>(),
+                                output.contiguous().data_ptr<at::Half>(),
+                                batch_size);
+    }
+    else {
+        // Full FP32 path.
+        output = torch::empty(c10::IntArrayRef(outputShape), embedding.options().dtype(torch::ScalarType::Float));
+        gather_gpu_fused_fwd(embedding.contiguous().data_ptr<float>(),
+                                offsets.contiguous().data_ptr<int64_t>(),
+                                indices.contiguous().data_ptr<int64_t>(),
+                                output.contiguous().data_ptr<float>(),
+                                batch_size);
+    }
+  }
+  else {
+    // Non-float table: treated as Half (FP16 table -> FP16 activations).
+    output = torch::empty(c10::IntArrayRef(outputShape), embedding.options().dtype(torch::ScalarType::Half));
+    gather_gpu_fused_fwd(embedding.contiguous().data_ptr<at::Half>(),
+                            offsets.contiguous().data_ptr<int64_t>(),
+                            indices.contiguous().data_ptr<int64_t>(),
+                            output.contiguous().data_ptr<at::Half>(),
+                            batch_size);
+  }
+  return output;
+}
+
+// Torch-facing backward wrapper for the fused embedding gather.
+// Requires a Float embedding table. Converts the upstream gradient (Half or
+// Float) into FP32 sparse values plus global row indices, and returns the
+// gradient for the table as a sparse COO tensor of shape
+// (table_rows, embedding_vector_dim) with 1 sparse and 1 dense dimension.
+torch::Tensor gatherGPUFusedBwdTorch(torch::Tensor embedding, torch::Tensor indices,
+                                          torch::Tensor offsets, torch::Tensor upstreamGrad) {
+  if (embedding.scalar_type() != torch::ScalarType::Float) {
+    throw std::invalid_argument("Invalid input type.");
+  }
+
+  auto size = upstreamGrad.sizes();
+  auto batch_size = size[0];
+  auto num_features = size[1];
+  auto embedding_vector_dim = size[2];
+
+  size = indices.sizes();
+  auto sparse_tensor_indices_dim = size[0] * size[1];
+  int64_t indices_outputShape[2] = {1, sparse_tensor_indices_dim};
+
+  // One value row per (sample, feature) pair.
+  auto sparse_tensor_values_0 = batch_size * num_features;
+  auto sparse_tensor_values_1 = embedding_vector_dim;
+  int64_t values_outputShape[2] = {sparse_tensor_values_0, sparse_tensor_values_1};
+
+  auto sparse_grad_indices_tensor = torch::empty(c10::IntArrayRef(indices_outputShape), indices.options());
+
+  // Values are always accumulated in FP32, regardless of upstream dtype.
+  auto sparse_grad_values_tensor = torch::empty(c10::IntArrayRef(values_outputShape),
+                                                  upstreamGrad.options().dtype(torch::ScalarType::Float));
+
+  // this is the shape of output gradient vector
+  int64_t sparse_tensor_shape[2] = {embedding.sizes()[0], embedding_vector_dim};
+
+  if (upstreamGrad.scalar_type() == torch::ScalarType::Half) {
+    gather_gpu_fused_bwd(upstreamGrad.contiguous().data_ptr<at::Half>(),
+                            indices.contiguous().data_ptr<int64_t>(),
+                            offsets.contiguous().data_ptr<int64_t>(),
+                            sparse_grad_values_tensor.contiguous().data_ptr<float>(),
+                            sparse_grad_indices_tensor.contiguous().data_ptr<int64_t>(),
+                        (int)batch_size, (int)num_features, (int)embedding_vector_dim);
+  }
+  else {
+    gather_gpu_fused_bwd(upstreamGrad.contiguous().data_ptr<float>(),
+                            indices.contiguous().data_ptr<int64_t>(),
+                            offsets.contiguous().data_ptr<int64_t>(),
+                            sparse_grad_values_tensor.contiguous().data_ptr<float>(),
+                            sparse_grad_indices_tensor.contiguous().data_ptr<int64_t>(),
+                        (int)batch_size, (int)num_features, (int)embedding_vector_dim);
+  }
+
+  // Assemble the sparse gradient without copying the buffers filled above.
+  return torch::_sparse_coo_tensor_with_dims_and_tensors(1, 1, c10::IntArrayRef(sparse_tensor_shape),
+                                                    sparse_grad_indices_tensor, sparse_grad_values_tensor,
+                                                    sparse_grad_values_tensor.options().layout(c10::Layout::Sparse));
+}

+ 22 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/pytorch_embedding_ops.cpp

@@ -0,0 +1,22 @@
+#include <torch/extension.h>
+
+// Forward declarations; definitions live in gather_gpu_fused_pytorch_impl.cu.
+torch::Tensor gatherGPUFusedFwdTorch(torch::Tensor embedding,
+                                       torch::Tensor indices,
+                                       torch::Tensor offsets,
+                                       bool amp_train);
+
+torch::Tensor gatherGPUFusedBwdTorch(torch::Tensor embedding,
+                                       torch::Tensor indices,
+                                       torch::Tensor offsets,
+                                       torch::Tensor upstreamGrad);
+
+// Python bindings for the fused embedding-gather extension module.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("gather_gpu_fused_fwd", &gatherGPUFusedFwdTorch, "", py::arg("embedding"),
+                                                              py::arg("indices"),
+                                                              py::arg("offsets"),
+                                                              py::arg("amp_train"));
+  m.def("gather_gpu_fused_bwd", &gatherGPUFusedBwdTorch, "", py::arg("embedding"),
+                                                              py::arg("indices"),
+                                                              py::arg("offsets"),
+                                                              py::arg("upstreamGrad"));
+}

+ 21 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/sparse_gather/common.h

@@ -0,0 +1,21 @@
+// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+
+#ifndef COMMON_H_
+#define COMMON_H_
+
+using ULLInt = unsigned long long int;
+
+// Use to compute things like number of blocks.
+// NOTE(fix): arguments are now fully parenthesized. The previous form
+// ((a + b - 1) / b) mis-expanded compound arguments, e.g.
+// CEIL_DIV_INT(n, x + y) became ((n + x + y - 1) / x + y), i.e. (.../x) + y.
+#define CEIL_DIV_INT(a, b) (((a) + (b) - 1) / (b))
+
+// Abort with file/line and the CUDA error string when a runtime call fails.
+#define CUDA_CHECK(cmd)                                                                     \
+  do {                                                                                      \
+    cudaError_t e = cmd;                                                                    \
+    if (e != cudaSuccess) {                                                                 \
+      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
+      exit(EXIT_FAILURE);                                                                   \
+    }                                                                                       \
+  } while (0)
+
+
+#endif  // COMMON_H_

+ 171 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/sparse_gather/gather_gpu.cu

@@ -0,0 +1,171 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <math.h>
+
+#include <cassert>
+#include <iostream>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+
+// For simplicity, boundary checks are removed.
+// All the kernels MUST be launched with grid size = batch size and block size = embedding size.
+
+// Gathers query_nnz embedding rows per sample. Per the file-level launch
+// convention: one block per sample, one thread per embedding column
+// (blockDim.x == embed_size), no boundary checks. The sample's indices are
+// staged in dynamic shared memory before the copy loop.
+// NOTE(review): indices are narrowed from int64_t to int in shared memory;
+// assumes each index fits in 32 bits.
+__global__ void GatherKernel(const float* params,
+                             int64_t num_features,
+                             int embed_size,
+                             int batch_size,
+                             int query_nnz,
+                             const int64_t* indices,
+                             float* ret) {
+  int tid = threadIdx.x, bid = blockIdx.x;
+
+  extern __shared__ int shmem_indices[];
+
+  // each CTA load one row of indices in the mini batch into shared memory
+  for (int i = tid; i < query_nnz; i += blockDim.x) {
+    shmem_indices[i] = indices[query_nnz * bid + i];
+  }
+  __syncthreads();
+
+#pragma unroll
+  for (int i = 0; i < query_nnz; ++i) {
+    // printf("%d, %d, %d\n", bid, i, shmem_indices[i]);
+    ret[(bid * query_nnz + i) * embed_size + tid] =
+        params[(int64_t)shmem_indices[i] * embed_size + tid];
+  }
+}
+
+// Special case of GatherKernel for query_nnz == 1: each block copies exactly
+// one embedding row, so no shared-memory index staging is needed.
+__global__ void OneHotKernel(const float* params,
+                             int64_t num_features,
+                             int embed_size,
+                             int batch_size,
+                             const int64_t* indices,
+                             float* ret) {
+  int tid = threadIdx.x, bid = blockIdx.x;
+
+  ret[bid * embed_size + tid] = params[(int64_t)indices[bid] * embed_size + tid];
+}
+
+// grads is used to update params directly by atomic instead of forming wgrad
+// Only SGD without momentum and without weight decay is supported
+// Backward gather fused with a plain SGD step: instead of materializing a
+// weight gradient, each gradient row is scaled by -lr and atomically added
+// into the embedding table in place. Same launch convention and shared-memory
+// index staging (with int64->int narrowing) as GatherKernel.
+__global__ void GatherBackwardFuseSgdKernel(const float* grads,
+                                            int64_t num_features,
+                                            int embed_size,
+                                            int batch_size,
+                                            int query_nnz,
+                                            const int64_t* indices,
+                                            float lr,
+                                            float* params) {
+  int tid = threadIdx.x, bid = blockIdx.x;
+
+  extern __shared__ int shmem_indices[];
+
+  for (int i = tid; i < query_nnz; i += blockDim.x) {
+    shmem_indices[i] = indices[query_nnz * bid + i];
+  }
+  __syncthreads();
+
+#pragma unroll
+  for (int i = 0; i < query_nnz; ++i) {
+    // atomicAdd handles duplicate indices within/across samples correctly.
+    atomicAdd(&params[(int64_t)shmem_indices[i] * embed_size + tid],
+              -lr * grads[(bid * query_nnz + i) * embed_size + tid]);
+  }
+}
+
+// Keep the interface and argument name as torch.embedding()
+// input is indices, and weight is embedding table
+// Forward embedding gather: weight is the (num_features, embed_size) FP32
+// table, indices is (batch) or (batch, query_nnz) int64. Returns a dense
+// (batch, query_nnz, embed_size) FP32 tensor of the gathered rows.
+// Dispatches to OneHotKernel when query_nnz == 1, otherwise GatherKernel.
+torch::Tensor gather_gpu_fwd(const torch::Tensor weight, const torch::Tensor indices) {
+  AT_ASSERT(indices.is_cuda());
+  AT_ASSERT(weight.is_cuda());
+  AT_ASSERT(indices.scalar_type() == torch::ScalarType::Long);
+  AT_ASSERT(weight.scalar_type() == torch::ScalarType::Float);
+  AT_ASSERT(weight.is_contiguous());
+
+  int batch_size = indices.size(0);
+  int query_nnz = 1;
+  if (indices.dim() > 1) {
+    query_nnz = indices.size(1);
+  }
+
+  // Shared memory size limit: 12288 ints == 48KB. Larger nnz can also be supported
+  // by skipping shared memory if necessary.
+  // NOTE(fix): the message previously said "Embedding width", but this check is on
+  // query_nnz (indices per sample), not the embedding width.
+  TORCH_CHECK(query_nnz <= 12288, "query_nnz must be at most 12288 (48KB shared memory limit)");
+
+  int num_features = weight.size(0);
+  int embed_size = weight.size(1);
+
+  // Block dimension limit. Larger than 1024 width can be easily supported by letting
+  // each block read from different strides if necessary.
+  TORCH_CHECK(embed_size <= 1024, "Embedding width must be smaller than 1024");
+
+  auto outputs =
+      torch::empty(batch_size * query_nnz * embed_size, at::device(at::kCUDA).dtype(at::kFloat));
+
+  if (query_nnz != 1) {
+    GatherKernel<<<batch_size,
+                   embed_size,
+                   query_nnz * sizeof(int),
+                   at::cuda::getCurrentCUDAStream()>>>(weight.data_ptr<float>(),
+                                                       num_features,
+                                                       embed_size,
+                                                       batch_size,
+                                                       query_nnz,
+                                                       indices.contiguous().data_ptr<int64_t>(),
+                                                       outputs.data_ptr<float>());
+  } else {
+    OneHotKernel<<<batch_size, embed_size, 0, at::cuda::getCurrentCUDAStream()>>>(
+        weight.data_ptr<float>(),
+        num_features,
+        embed_size,
+        batch_size,
+        indices.contiguous().data_ptr<int64_t>(),
+        outputs.data_ptr<float>());
+  }
+
+  return outputs.reshape({batch_size, query_nnz, embed_size});
+}
+
+// Because of the complication of handling sparse tensors, using the native backward function is still faster.
+// TODO(haow): Figure out a way to write out a sparse tensor directly, to avoid the additional copy that makes
+// the customized implementation slower than PyTorch's own despite its kernels being more efficient.
+// Backward pass of the embedding gather. Delegates to PyTorch's native
+// at::embedding_sparse_backward and returns a sparse gradient for the
+// embedding table; padding_idx=-1 disables padding handling and gradients
+// are not scaled by index frequency.
+torch::Tensor gather_gpu_bwd(const torch::Tensor grad,
+                             const torch::Tensor indices,
+                             const int num_features) {
+  return at::embedding_sparse_backward(grad, indices, num_features, /*padding_idx=*/-1, /*scale_grad_by_freq=*/false);
+}
+
+// Backward gather with fused plain SGD (no weight decay nor momentum)
+// Backward gather fused with a plain SGD update (no weight decay or momentum):
+// updates `weight` in place via GatherBackwardFuseSgdKernel, returns nothing.
+// NOTE(review): unlike gather_gpu_fwd, there are no TORCH_CHECKs here on
+// query_nnz (shared-memory size) or embed_size (block dimension) — confirm
+// callers respect the same limits.
+void gather_gpu_bwd_fuse_sgd(const torch::Tensor grad,
+                             const torch::Tensor indices,
+                             float lr,
+                             torch::Tensor weight) {
+  AT_ASSERT(grad.is_cuda());
+  AT_ASSERT(indices.is_cuda());
+  AT_ASSERT(weight.is_cuda());
+  AT_ASSERT(grad.scalar_type() == torch::ScalarType::Float);
+  AT_ASSERT(indices.scalar_type() == torch::ScalarType::Long);
+  AT_ASSERT(weight.scalar_type() == torch::ScalarType::Float);
+  AT_ASSERT(weight.is_contiguous());
+
+  int batch_size = indices.size(0);
+  int query_nnz = 1;
+  if (indices.dim() > 1) {
+    query_nnz = indices.size(1);
+  }
+  int num_features = weight.size(0);
+  int embed_size = weight.size(1);
+
+  GatherBackwardFuseSgdKernel<<<batch_size,
+                                embed_size,
+                                query_nnz * sizeof(int),
+                                at::cuda::getCurrentCUDAStream()>>>(
+      grad.contiguous().data_ptr<float>(),
+      num_features,
+      embed_size,
+      batch_size,
+      query_nnz,
+      indices.contiguous().data_ptr<int64_t>(),
+      lr,
+      weight.data_ptr<float>());
+}

+ 14 - 0
PyTorch/Recommendation/DLRM/dlrm/cuda_src/sparse_gather/sparse_pytorch_ops.cpp

@@ -0,0 +1,14 @@
+#include <torch/extension.h>
+
+// Forward declarations; definitions live in gather_gpu.cu.
+// NOTE(fix): the declaration and the py::arg names for gather_gpu_fwd
+// previously listed the parameters as (indices, weight), but the definition
+// is gather_gpu_fwd(weight, indices) — Python keyword calls would have
+// swapped the tensors. The names below now match the definition's
+// positional order.
+torch::Tensor gather_gpu_fwd(const torch::Tensor weight, const torch::Tensor indices);
+void gather_gpu_bwd_fuse_sgd(const torch::Tensor grad, const torch::Tensor indices, float lr, torch::Tensor weight);
+torch::Tensor gather_gpu_bwd(const torch::Tensor grad, const torch::Tensor indices, const int num_features);
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("gather_gpu_fwd", &gather_gpu_fwd, "Embedding gather", py::arg("weight"), py::arg("indices"));
+  m.def("gather_gpu_bwd_fuse_sgd", &gather_gpu_bwd_fuse_sgd, "Embedding gather backward with fused plain SGD",
+        py::arg("grad"), py::arg("indices"), py::arg("lr"), py::arg("weight"));
+  m.def("gather_gpu_bwd", &gather_gpu_bwd, "Embedding gather backward",
+        py::arg("grad"), py::arg("indices"), py::arg("num_features"));
+}

+ 14 - 54
PyTorch/Recommendation/DLRM/dlrm/data/data_loader.py

@@ -13,66 +13,26 @@
 # limitations under the License.
 
 
-import math
-import os
-import time
-import numpy as np
 import argparse
+import time
+from typing import Tuple, Optional
 
-import torch
-from torch.utils.data import Dataset
-
-class CriteoBinDataset(Dataset):
-    """Simple dataloader for a recommender system. Designed to work with a single binary file."""
-
-    def __init__(self, data_file, batch_size=1, subset=None,
-                 numerical_features=13, categorical_features=26,
-                 data_type='int32', online_shuffle=True):
-        self.data_type = np.__dict__[data_type]
-        bytes_per_feature = self.data_type().nbytes
-
-        self.tad_fea = 1 + numerical_features
-        self.tot_fea = 1 + numerical_features + categorical_features
-
-        self.batch_size = batch_size
-        self.bytes_per_entry = (bytes_per_feature * self.tot_fea * batch_size)
-
-        self.num_entries = math.ceil(os.path.getsize(data_file) / self.bytes_per_entry)
-
-        if subset is not None:
-            if subset <= 0 or subset > 1:
-                raise ValueError('Subset parameter must be in (0,1) range')
-            self.num_entries = self.num_entries * subset
-
-        print('data file:', data_file, 'number of batches:', self.num_entries)
-        self.file = open(data_file, 'rb')
-        self.online_shuffle=online_shuffle
-
-    def __len__(self):
-        return self.num_entries
-
-    def __getitem__(self, idx):
-        if idx == 0:
-            self.file.seek(0, 0)
-
-        if self.online_shuffle:
-            self.file.seek(idx * self.bytes_per_entry, 0)
-
-        raw_data = self.file.read(self.bytes_per_entry)
-        array = np.frombuffer(raw_data, dtype=self.data_type).reshape(-1, self.tot_fea)
+from torch.utils.data import DataLoader
 
-        # numerical features are encoded as float32
-        numerical_features = array[:, 1:self.tad_fea].view(dtype=np.float32)
-        numerical_features = torch.from_numpy(numerical_features)
+from dlrm.data.datasets import CriteoBinDataset
+from dlrm.data.factories import create_dataset_factory
 
 
-        categorical_features = torch.from_numpy(array[:, self.tad_fea:])
-        labels = torch.from_numpy(array[:, 0])
+def get_data_loaders(flags, device_mapping: Optional[dict] = None) -> Tuple[DataLoader, DataLoader]:
+    """Build the train and test DataLoaders for the dataset type selected by ``flags``.
+
+    Delegates dataset/sampler/collate construction to the factory returned by
+    ``create_dataset_factory``. ``device_mapping`` (distributed runs) describes which
+    rank owns the bottom MLP and which embedding tables each rank holds.
+    """
+    dataset_factory = create_dataset_factory(flags, device_mapping=device_mapping)

-        return numerical_features, categorical_features, labels
+    dataset_train, dataset_test = dataset_factory.create_datasets()
+    # Shuffling is optional and applies to the training set only.
+    train_sampler = dataset_factory.create_sampler(dataset_train) if flags.shuffle_batch_order else None
+    collate_fn = dataset_factory.create_collate_fn()

-    def __del__(self):
-        self.file.close()
+    data_loader_train = dataset_factory.create_data_loader(dataset_train, collate_fn=collate_fn, sampler=train_sampler)
+    data_loader_test = dataset_factory.create_data_loader(dataset_test, collate_fn=collate_fn)
+    return data_loader_train, data_loader_test
 
 
 if __name__ == '__main__':
@@ -90,7 +50,7 @@ if __name__ == '__main__':
     for i in range(args.steps):
         _ = dataset[i]
     end = time.time()
-    
+
     step_time = (end - begin) / args.steps
     throughput = args.batch_size / step_time
 

+ 242 - 0
PyTorch/Recommendation/DLRM/dlrm/data/datasets.py

@@ -0,0 +1,242 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import concurrent
+import math
+import os
+import queue
+
+import torch
+
+import numpy as np
+from torch.utils.data import Dataset
+from typing import Optional, Sequence, Tuple, Any, Dict
+
+from dlrm.data.utils import get_categorical_feature_type
+from dlrm.utils.distributed import get_rank
+
+
+class SyntheticDataset(Dataset):
+    """Synthetic dataset version of criteo dataset.
+
+    Generates a single random batch at construction time and returns views of it
+    from every ``__getitem__`` call, so all "batches" are identical — intended for
+    benchmarking the pipeline without disk I/O.
+    """
+
+    def __init__(
+        self,
+        num_entries: int,
+        device: str = 'cuda',
+        batch_size: int = 1,
+        numerical_features: Optional[int] = None,
+        categorical_feature_sizes: Optional[Sequence[int]] = None,
+        device_mapping: Optional[Dict[str, Any]] = None
+    ):
+        if device_mapping:
+            # distributed setting: only the rank owning the bottom MLP keeps the
+            # numerical features; each rank keeps only its own embedding tables
+            rank = get_rank()
+            numerical_features = numerical_features if device_mapping["bottom_mlp"] == rank else None
+            categorical_feature_sizes = device_mapping["embedding"][rank]
+
+        self.cat_features_count = len(categorical_feature_sizes) if categorical_feature_sizes is not None else 0
+        self.num_features_count = numerical_features if numerical_features is not None else 0
+
+        # column layout of the backing tensor: [label | numerical... | categorical...]
+        self.tot_fea = 1 + self.num_features_count + self.cat_features_count
+        self.batch_size = batch_size
+        self.batches_per_epoch = math.ceil(num_entries / batch_size)
+        self.categorical_feature_sizes = categorical_feature_sizes
+        self.device = device
+
+        # values are in {0, 1}: valid as labels and as embedding indices
+        # (assumes every embedding table has size >= 2 — TODO confirm)
+        self.tensor = torch.randint(low=0, high=2, size=(self.batch_size, self.tot_fea), device=self.device)
+        self.tensor = self.tensor.float()
+
+    def __len__(self):
+        return self.batches_per_epoch
+
+    def __getitem__(self, idx: int):
+        # explicit IndexError lets plain iteration (no sampler) terminate
+        if idx >= self.batches_per_epoch:
+            raise IndexError()
+
+        numerical_features = (self.tensor[:, 1: 1 + self.num_features_count].to(torch.float32)
+                              if self.num_features_count > 0 else None)
+        categorical_features = (self.tensor[:, 1 + self.num_features_count:].to(torch.long)
+                                if self.cat_features_count > 0 else None)
+        target = self.tensor[:, 0].to(torch.float32)
+
+        return numerical_features, categorical_features, target
+
+
+class CriteoBinDataset(Dataset):
+    """Simple dataloader for a recommender system. Designed to work with a single binary file.
+
+    One ``__getitem__`` call reads one whole batch and returns the raw numpy array of
+    shape ``(batch_size, 1 + numerical + categorical)``; decoding to tensors is left to
+    the collate function (see ``collate_array``). The final batch may be partial.
+    """
+
+    def __init__(
+        self,
+        data_file: str,
+        batch_size: int = 1,
+        subset: float = None,
+        numerical_features: int = 13,
+        categorical_features: int = 26,
+        data_type: str = 'int32'
+    ):
+        # all fields are stored with the same fixed-width dtype; numerical columns
+        # are reinterpreted as float32 later by the collate function
+        self.data_type = np.__dict__[data_type]
+        bytes_per_feature = self.data_type().nbytes
+
+        # column layout: [label | numerical... | categorical...]
+        self.tad_fea = 1 + numerical_features
+        self.tot_fea = 1 + numerical_features + categorical_features
+
+        self.batch_size = batch_size
+        self.bytes_per_entry = (bytes_per_feature * self.tot_fea * batch_size)
+        self.num_entries = math.ceil(os.path.getsize(data_file) / self.bytes_per_entry)
+
+        if subset is not None:
+            if subset <= 0 or subset > 1:
+                raise ValueError('Subset parameter must be in (0,1) range')
+            self.num_entries = math.ceil(self.num_entries * subset)
+
+        self.file = open(data_file, 'rb')
+        # tracks the previous index so sequential access can skip the seek()
+        self._last_read_idx = -1
+
+    def __len__(self):
+        return self.num_entries
+
+    def __getitem__(self, idx):
+        if idx >= self.num_entries:
+            raise IndexError()
+
+        if idx == 0:
+            self.file.seek(0, 0)
+        elif self._last_read_idx != (idx - 1):
+            # random access: reposition; sequential reads fall through and
+            # continue from the current file offset
+            self.file.seek(idx * self.bytes_per_entry, 0)
+
+        raw_data = self.file.read(self.bytes_per_entry)
+        self._last_read_idx = idx
+
+        array = np.frombuffer(raw_data, dtype=self.data_type).reshape(-1, self.tot_fea)
+        return array
+
+    def __del__(self):
+        self.file.close()
+
+
+class SplitCriteoDataset(Dataset):
+    """Split version of the Criteo dataset.
+
+    Every feature group lives in its own binary file — ``label.bin``,
+    ``numerical.bin`` and ``cat_0.bin`` ... ``cat_25.bin`` — and one batch is read
+    per index with positional ``os.pread`` calls (no shared file offset).
+
+    Args:
+        data_path (str): Directory containing the split binary files of the dataset.
+        batch_size (int): Number of samples per returned batch.
+        numerical_features (bool): If True, load numerical features for bottom_mlp. Default False.
+        categorical_features (list or None): Global ids of the categorical features used by this rank.
+        categorical_feature_sizes (list or None): Cardinalities of ALL categorical features,
+            indexed by global feature id; used to choose each feature's storage int type.
+        prefetch_depth (int): How many batches to prefetch on a background thread. Default 10.
+    """
+    def __init__(
+        self,
+        data_path: str,
+        batch_size: int = 1,
+        numerical_features: bool = False,
+        categorical_features: Optional[Sequence[int]] = None,
+        categorical_feature_sizes: Optional[Sequence[int]] = None,
+        prefetch_depth: int = 10
+    ):
+        # NOTE(review): np.bool is deprecated (removed in NumPy >= 1.24); this
+        # should be np.bool_ — labels are stored as 1 byte per sample.
+        self._label_bytes_per_batch = np.dtype(np.bool).itemsize * batch_size
+        # 13 numerical features stored as float16
+        self._numerical_bytes_per_batch = 13 * np.dtype(np.float16).itemsize * batch_size if numerical_features else 0
+        self._categorical_feature_types = [
+            get_categorical_feature_type(size) for size in categorical_feature_sizes
+        ] if categorical_feature_sizes else []
+        self._categorical_bytes_per_batch = [
+            np.dtype(cat_type).itemsize * batch_size for cat_type in self._categorical_feature_types
+        ]
+        self._categorical_features = categorical_features
+        self._batch_size = batch_size
+        self._label_file = os.open(os.path.join(data_path, F"label.bin"), os.O_RDONLY)
+        # number of batches is derived from the label file; the last batch may be partial
+        self._num_entries = int(math.ceil(os.fstat(self._label_file).st_size / self._label_bytes_per_batch))
+
+        if numerical_features:
+            self._numerical_features_file = os.open(os.path.join(data_path, "numerical.bin"), os.O_RDONLY)
+            # sanity check: every feature file must describe the same number of batches
+            if math.ceil(os.fstat(self._numerical_features_file).st_size /
+                         self._numerical_bytes_per_batch) != self._num_entries:
+                raise ValueError("Size miss match in data files")
+        else:
+            self._numerical_features_file = None
+
+        if categorical_features:
+            self._categorical_features_files = []
+            for cat_id in categorical_features:
+                cat_file = os.open(os.path.join(data_path, F"cat_{cat_id}.bin"), os.O_RDONLY)
+                # byte sizes are indexed by the *global* feature id
+                cat_bytes = self._categorical_bytes_per_batch[cat_id]
+                if math.ceil(
+                        os.fstat(cat_file).st_size / cat_bytes) != self._num_entries:
+                    raise ValueError("Size miss match in data files")
+                self._categorical_features_files.append(cat_file)
+        else:
+            self._categorical_features_files = None
+
+        # single worker thread keeps prefetched batches strictly in order
+        self._prefetch_depth = min(prefetch_depth, self._num_entries)
+        self._prefetch_queue = queue.Queue()
+        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
+    def __len__(self):
+        return self._num_entries
+
+    def __getitem__(self, idx: int):
+        # assumes strictly sequential access (idx = 0, 1, 2, ...) once prefetching
+        # is active — TODO confirm callers never index randomly
+        if idx >= self._num_entries:
+            raise IndexError()
+
+        if self._prefetch_depth <= 1:
+            return self._get_item(idx)
+
+        # seed the pipeline with the first prefetch_depth reads, then keep it
+        # topped up; the queue holds Futures, resolved here in FIFO order
+        if idx == 0:
+            for i in range(self._prefetch_depth):
+                self._prefetch_queue.put(self._executor.submit(self._get_item, (i)))
+        if idx < self._num_entries - self._prefetch_depth:
+            self._prefetch_queue.put(self._executor.submit(self._get_item, (idx + self._prefetch_depth)))
+        return self._prefetch_queue.get().result()
+
+    def _get_item(self, idx: int) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        # reads one batch from every open feature file at the given batch index
+        click = self._get_label(idx)
+        numerical_features = self._get_numerical_features(idx)
+        categorical_features = self._get_categorical_features(idx)
+        return numerical_features, categorical_features, click
+
+    def _get_label(self, idx: int) -> torch.Tensor:
+        raw_label_data = os.pread(self._label_file, self._label_bytes_per_batch,
+                                  idx * self._label_bytes_per_batch)
+        # NOTE(review): np.bool deprecated — should be np.bool_ (see __init__)
+        array = np.frombuffer(raw_label_data, dtype=np.bool)
+        return torch.from_numpy(array).to(torch.float32)
+
+    def _get_numerical_features(self, idx: int) -> Optional[torch.Tensor]:
+        if self._numerical_features_file is None:
+            return None
+
+        raw_numerical_data = os.pread(self._numerical_features_file, self._numerical_bytes_per_batch,
+                                      idx * self._numerical_bytes_per_batch)
+        array = np.frombuffer(raw_numerical_data, dtype=np.float16)
+        # reshape flat buffer to (batch, 13)
+        return torch.from_numpy(array).view(-1, 13)
+
+    def _get_categorical_features(self, idx: int) -> Optional[torch.Tensor]:
+        if self._categorical_features_files is None:
+            return None
+
+        # each feature is read from its own file with its own dtype, then the
+        # columns are widened to int64 and concatenated to (batch, n_features)
+        categorical_features = []
+        for cat_id, cat_file in zip(self._categorical_features, self._categorical_features_files):
+            cat_bytes = self._categorical_bytes_per_batch[cat_id]
+            cat_type = self._categorical_feature_types[cat_id]
+            raw_cat_data = os.pread(cat_file, cat_bytes, idx * cat_bytes)
+            array = np.frombuffer(raw_cat_data, dtype=cat_type)
+            tensor = torch.from_numpy(array).unsqueeze(1).to(torch.long)
+            categorical_features.append(tensor)
+        return torch.cat(categorical_features, dim=1)
+
+    def __del__(self):
+        # NOTE(review): when no categorical features were requested,
+        # _categorical_features_files is None and this list + None concatenation
+        # raises TypeError at destruction — guard before concatenating.
+        data_files = [self._label_file, self._numerical_features_file] + self._categorical_features_files
+        for data_file in data_files:
+            if data_file is not None:
+                os.close(data_file)

+ 219 - 0
PyTorch/Recommendation/DLRM/dlrm/data/factories.py

@@ -0,0 +1,219 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+from typing import Tuple, Optional, Callable, Dict, Sequence
+
+import torch
+from torch.utils.data import Dataset, Sampler, RandomSampler
+
+from dlrm.data.datasets import CriteoBinDataset, SyntheticDataset, SplitCriteoDataset
+from dlrm.data.samplers import RandomDistributedSampler
+from dlrm.data.utils import collate_array, write_dataset_to_disk, get_categorical_feature_sizes, collate_split_tensors
+from dlrm.utils.distributed import is_distributed, is_main_process, get_rank
+
+
+def create_synthetic_datasets(flags, device_mapping: Optional[Dict] = None):
+    """Return (train, test) in-memory SyntheticDataset instances configured from ``flags``."""
+    dataset_train = SyntheticDataset(num_entries=flags.synthetic_dataset_num_entries,
+                                     batch_size=flags.batch_size,
+                                     numerical_features=flags.num_numerical_features,
+                                     categorical_feature_sizes=get_categorical_feature_sizes(flags),
+                                     device_mapping=device_mapping)
+
+    # test set differs only in batch size
+    dataset_test = SyntheticDataset(num_entries=flags.synthetic_dataset_num_entries,
+                                    batch_size=flags.test_batch_size,
+                                    numerical_features=flags.num_numerical_features,
+                                    categorical_feature_sizes=get_categorical_feature_sizes(flags),
+                                    device_mapping=device_mapping)
+    return dataset_train, dataset_test
+
+
+def create_real_datasets(flags, path, dataset_class: type = CriteoBinDataset):
+    """Return (train, test) datasets backed by ``train_data.bin``/``test_data.bin`` under ``path``.
+
+    ``dataset_class`` must accept the CriteoBinDataset constructor signature.
+    Note that ``subset`` is applied to the training split only.
+    """
+    train_dataset = os.path.join(path, "train_data.bin")
+    test_dataset = os.path.join(path, "test_data.bin")
+    categorical_sizes = get_categorical_feature_sizes(flags)
+
+    dataset_train = dataset_class(
+        data_file=train_dataset,
+        batch_size=flags.batch_size,
+        subset=flags.dataset_subset,
+        numerical_features=flags.num_numerical_features,
+        categorical_features=len(categorical_sizes),
+    )
+
+    dataset_test = dataset_class(
+        data_file=test_dataset,
+        batch_size=flags.test_batch_size,
+        numerical_features=flags.num_numerical_features,
+        categorical_features=len(categorical_sizes),
+    )
+
+    return dataset_train, dataset_test
+
+
+class DatasetFactory:
+    """Base factory bundling dataset, sampler, collate-fn and DataLoader creation.
+
+    Subclasses override ``create_datasets`` (required) and, where needed,
+    ``create_collate_fn`` / ``create_sampler``.
+    """
+
+    def __init__(self, flags, device_mapping: Optional[Dict] = None):
+        self._flags = flags
+        self._device_mapping = device_mapping
+
+    def create_collate_fn(self) -> Optional[Callable]:
+        """Return a collate fn that decodes raw batch arrays and moves them to the target device."""
+        if self._device_mapping is not None:
+            # selection of categorical features assigned to this device
+            device_cat_features = torch.tensor(
+                self._device_mapping["embedding"][get_rank()], device=self._flags.base_device, dtype=torch.long)
+        else:
+            device_cat_features = None
+
+        # original stream is needed so collate can record_stream() tensors that
+        # were copied on a side stream (see collate_array)
+        orig_stream = torch.cuda.current_stream() if self._flags.base_device == 'cuda' else None
+        return functools.partial(
+            collate_array,
+            device=self._flags.base_device,
+            orig_stream=orig_stream,
+            num_numerical_features=self._flags.num_numerical_features,
+            selected_categorical_features=device_cat_features
+        )
+
+    def create_sampler(self, dataset: Dataset) -> Optional[Sampler]:
+        # distributed runs need all ranks to agree on the shuffle order
+        return RandomDistributedSampler(dataset) if is_distributed() else RandomSampler(dataset)
+
+    def create_datasets(self) -> Tuple[Dataset, Dataset]:
+        raise NotImplementedError()
+
+    def create_data_loader(self, dataset, collate_fn: Optional[Callable] = None, sampler: Optional[Sampler] = None):
+        # batch_size=None: the datasets already yield whole batches per index
+        return torch.utils.data.DataLoader(
+            dataset, collate_fn=collate_fn, sampler=sampler, batch_size=None,
+            num_workers=0, pin_memory=False
+        )
+
+
+class SyntheticDiskDatasetFactory(DatasetFactory):
+    """Materializes synthetic data to disk once, then serves it as a binary dataset.
+
+    Useful for benchmarking the full disk-reading pipeline without real data.
+    """
+
+    def create_sampler(self, dataset: Dataset) -> Optional[Sampler]:
+        # synthetic data is random anyway; shuffling adds nothing
+        return None
+
+    def create_datasets(self) -> Tuple[Dataset, Dataset]:
+        synthetic_train, synthetic_test = create_synthetic_datasets(self._flags)
+
+        if is_distributed():
+            self._synchronized_write(synthetic_train, synthetic_test)
+        else:
+            self._write(synthetic_train, synthetic_test)
+
+        # re-open what was just written through the regular binary-dataset path
+        return create_real_datasets(self._flags, self._flags.synthetic_dataset_dir)
+
+    def _synchronized_write(self, train_dataset: Dataset, test_dataset: Dataset):
+        # only rank 0 writes; the barrier keeps other ranks from reading early
+        if is_main_process():
+            self._write(train_dataset, test_dataset)
+        torch.distributed.barrier()
+
+    def _write(self, train_dataset: Dataset, test_dataset: Dataset):
+        write_dataset_to_disk(self._flags.synthetic_dataset_dir, train_dataset, test_dataset,
+                              self._flags.synthetic_dataset_table_sizes)
+
+
+class SyntheticGpuDatasetFactory(DatasetFactory):
+    """Serves synthetic batches that already live on the GPU — no collate or sampling needed."""
+
+    def create_collate_fn(self) -> Optional[Callable]:
+        # SyntheticDataset already returns device tensors in final form
+        return None
+
+    def create_sampler(self, dataset) -> Optional[Sampler]:
+        return None
+
+    def create_datasets(self) -> Tuple[Dataset, Dataset]:
+        return create_synthetic_datasets(self._flags, self._device_mapping)
+
+
+class BinaryDatasetFactory(DatasetFactory):
+    """Factory for the single-file binary Criteo format (CriteoBinDataset)."""
+
+    def create_datasets(self) -> Tuple[Dataset, Dataset]:
+        return create_real_datasets(self._flags, self._flags.dataset)
+
+
+class SplitBinaryDatasetFactory(DatasetFactory):
+    """Factory for the per-feature split binary format (SplitCriteoDataset).
+
+    ``numerical_features``/``categorical_features`` select the subset of feature
+    files this process reads (per-rank in distributed runs).
+    """
+
+    def __init__(self, flags, numerical_features: bool, categorical_features: Sequence[int]):
+        super().__init__(flags)
+        self._numerical_features = numerical_features
+        self._categorical_features = categorical_features
+
+    def create_collate_fn(self):
+        orig_stream = torch.cuda.current_stream() if self._flags.base_device == 'cuda' else None
+        return functools.partial(
+            collate_split_tensors,
+            device=self._flags.base_device,
+            orig_stream=orig_stream,
+            # split data stores numericals as float16; widen unless running AMP
+            numerical_type=torch.float16 if self._flags.amp else torch.float32
+        )
+
+    def create_datasets(self) -> Tuple[Dataset, Dataset]:
+        train_dataset_path = os.path.join(self._flags.dataset, "train")
+        test_dataset_path = os.path.join(self._flags.dataset, "test")
+        categorical_sizes = get_categorical_feature_sizes(self._flags)
+
+        dataset_train = SplitCriteoDataset(
+            data_path=train_dataset_path,
+            batch_size=self._flags.batch_size,
+            numerical_features=self._numerical_features,
+            categorical_features=self._categorical_features,
+            categorical_feature_sizes=categorical_sizes
+        )
+        dataset_test = SplitCriteoDataset(
+            data_path=test_dataset_path,
+            batch_size=self._flags.test_batch_size,
+            numerical_features=self._numerical_features,
+            categorical_features=self._categorical_features,
+            categorical_feature_sizes=categorical_sizes
+        )
+        return dataset_train, dataset_test
+
+
+def create_dataset_factory(flags, device_mapping: Optional[dict] = None) -> DatasetFactory:
+    """
+    Select and build the DatasetFactory matching ``flags.dataset_type``.
+
+    By default each dataset can be used in single GPU or distributed setting - please keep that in mind when adding
+    new datasets. Distributed case requires selection of categorical features provided in `device_mapping`
+    (see `DatasetFactory#create_collate_fn`).
+
+    :param flags: parsed command-line flags object
+    :param device_mapping: dict, information about model bottom mlp and embeddings devices assignment
+    :return: a DatasetFactory subclass instance
+    :raises NotImplementedError: for an unrecognized ``dataset_type``
+    """
+    dataset_type = flags.dataset_type
+
+    if dataset_type == "binary":
+        return BinaryDatasetFactory(flags, device_mapping)
+
+    if dataset_type == "split":
+        if is_distributed():
+            assert device_mapping is not None, "Distributed dataset requires information about model device mapping."
+            rank = get_rank()
+            # each rank reads only its own feature files
+            return SplitBinaryDatasetFactory(
+                flags=flags,
+                numerical_features=device_mapping["bottom_mlp"] == rank,
+                categorical_features=device_mapping["embedding"][rank]
+            )
+        # single-process run reads everything
+        return SplitBinaryDatasetFactory(
+            flags=flags,
+            numerical_features=True,
+            categorical_features=range(len(get_categorical_feature_sizes(flags)))
+        )
+
+    if dataset_type == "synthetic_gpu":
+        return SyntheticGpuDatasetFactory(flags, device_mapping)
+
+    if dataset_type == "synthetic_disk":
+        return SyntheticDiskDatasetFactory(flags, device_mapping)
+
+    raise NotImplementedError(f"unknown dataset type: {dataset_type}")

+ 37 - 0
PyTorch/Recommendation/DLRM/dlrm/data/samplers.py

@@ -0,0 +1,37 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+from torch.utils.data import RandomSampler
+
+from dlrm.utils.distributed import get_local_rank
+
+
+class RandomDistributedSampler(RandomSampler):
+    """RandomSampler whose permutation is shared across all ranks via a file on disk."""
+
+    # NOTE(review): a fixed path in /tmp is per-node and not unique per job —
+    # concurrent jobs on one machine would clobber each other's file; verify.
+    _SAMPLE_FILE = "/tmp/dlrm_training_sample.npy"
+
+    def __iter__(self):
+        """
+        To guarantee all ranks have the same permutation, generate it on rank 0 and sync
+        it to the other ranks by writing it to disk.
+        """
+        # NOTE(review): local rank 0 of *each* node writes its own copy — in a
+        # multi-node run the nodes would use different permutations unless RNG
+        # state is identical; confirm single-node assumption.
+        if get_local_rank() == 0:
+            # NOTE(review): np.array(...) is applied to an *iterator*, which yields
+            # a 0-d object array; presumably np.array(list(super().__iter__()))
+            # was intended — verify np.load can read the result.
+            np.save(self._SAMPLE_FILE, np.array(super().__iter__()))
+        torch.distributed.barrier()
+
+        sample = np.load(self._SAMPLE_FILE)
+        return iter(sample)

+ 0 - 42
PyTorch/Recommendation/DLRM/dlrm/data/synthetic_dataset.py

@@ -1,42 +0,0 @@
-# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#       http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import math
-from torch.utils.data import Dataset
-
-
-class SyntheticDataset(Dataset):
-    """Synthetic dataset version of criteo dataset."""
-
-    def __init__(self,  num_entries, device='cuda', batch_size=1, dense_features=13,
-                 categorical_feature_sizes=None):
-        # dataset. single target, 13 dense features, 26 sparse features
-        self.sparse_features = len(categorical_feature_sizes)
-        self.dense_features = dense_features
-
-        self.tot_fea = 1 + dense_features + self.sparse_features
-        self.batch_size = batch_size
-        self.batches_per_epoch = math.ceil(num_entries / batch_size)
-        self.categorical_feature_sizes = categorical_feature_sizes
-        self.device = device
-
-        self.tensor = torch.randint(low=0, high=2, size=(self.batch_size, self.tot_fea), device=self.device)
-        self.tensor = self.tensor.float()
-
-    def __len__(self):
-        return self.batches_per_epoch
-
-    def __getitem__(self, idx):
-        return self.tensor[:, 1:14], self.tensor[:, 14:], self.tensor[:, 0]

+ 164 - 0
PyTorch/Recommendation/DLRM/dlrm/data/utils.py

@@ -0,0 +1,164 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+import tqdm
+from torch import Tensor
+from torch.cuda import Stream
+from typing import Tuple, Optional
+
+
+def collate_split_tensors(
+        tensors: Tuple[Tensor, Tensor, Tensor],
+        device: str,
+        orig_stream: Stream,
+        numerical_type: torch.dtype = torch.float32
+):
+    """Move (numerical, categorical, label) tensors to ``device`` and cast numericals.
+
+    ``orig_stream`` is the training stream: when copies happen on a DataLoader side
+    stream, ``record_stream`` prevents the allocator from reusing the memory before
+    the training stream has consumed it. Any element may be None (distributed runs).
+    """
+    tensors = [tensor.to(device, non_blocking=True) if tensor is not None else None for tensor in tensors]
+    if device == 'cuda':
+        for tensor in tensors:
+            if tensor is not None:
+                tensor.record_stream(orig_stream)
+
+    numerical_features, categorical_features, click = tensors
+
+    if numerical_features is not None:
+        numerical_features = numerical_features.to(numerical_type)
+
+    return numerical_features, categorical_features, click
+
+
+def collate_array(
+        array: np.array,
+        device: str,
+        orig_stream: Stream,
+        num_numerical_features: int,
+        selected_categorical_features: Optional[Tensor] = None
+):
+    """Decode one raw CriteoBinDataset batch array into device tensors.
+
+    Splits columns [label | numerical | categorical], reinterprets the numerical
+    columns as float32, and optionally keeps only this rank's categorical columns.
+    ``orig_stream`` is used with ``record_stream`` for cross-stream copy safety.
+    """
+    # numerical features are encoded as float32
+    numerical_features = array[:, 1:1 + num_numerical_features].view(dtype=np.float32)
+    numerical_features = torch.from_numpy(numerical_features)
+
+    categorical_features = torch.from_numpy(array[:, 1 + num_numerical_features:])
+    click = torch.from_numpy(array[:, 0])
+
+    categorical_features = categorical_features.to(device, non_blocking=True).to(torch.long)
+    numerical_features = numerical_features.to(device, non_blocking=True)
+    click = click.to(torch.float32).to(device, non_blocking=True)
+
+    # column selection happens after the device copy, on the full tensor
+    if selected_categorical_features is not None:
+        categorical_features = categorical_features[:, selected_categorical_features]
+
+    if device == 'cuda':
+        numerical_features.record_stream(orig_stream)
+        categorical_features.record_stream(orig_stream)
+        click.record_stream(orig_stream)
+
+    return numerical_features, categorical_features, click
+
+
+def write_dataset_to_disk(destination, dataset_train, dataset_test, table_sizes):
+    """Serialize two batch-yielding datasets to the single-file binary Criteo format.
+
+    Writes ``train_data.bin``/``test_data.bin`` (one int32/float32 record per sample,
+    columns [label | n0..n12 | c0..c25]) plus ``model_size.json`` mapping embedding
+    table names to sizes. Existing files are left untouched.
+    """
+    for filename, dataset in zip(('train_data.bin', 'test_data.bin'),
+                                 (dataset_train, dataset_test)):
+
+        os.makedirs(destination, exist_ok=True)
+        dst_file = os.path.join(destination, filename)
+        if os.path.exists(dst_file):
+            print(f'File {dst_file} already exists, skipping')
+            continue
+
+        with open(dst_file, 'wb') as dst_fd:
+            for numeric, categorical, label in tqdm.tqdm(dataset):
+                # numeric, categorical, label = collate(batch, device='cpu',
+                #                                       orig_stream=None,
+                #                                       num_numerical_features=13)
+
+                # int32 matches CriteoBinDataset's fixed-width record layout
+                categorical = categorical.to(torch.int32)
+                label = label.to(torch.int32)
+
+                # build one DataFrame per batch so to_records() emits the packed
+                # [label | numeric | categorical] row layout
+                l = pd.DataFrame(label.cpu().numpy())
+                l.columns = ['label']
+                n = pd.DataFrame(numeric.cpu().numpy())
+                n.columns = ['n' + str(i) for i in range(len(n.columns))]
+
+                c = pd.DataFrame(categorical.cpu().numpy())
+                c.columns = ['c' + str(i) for i in range(len(c.columns))]
+                df = pd.concat([l, n, c], axis=1)
+
+                records = df.to_records(index=False)
+                raw_data = records.tobytes()
+
+                dst_fd.write(raw_data)
+
+    # keys _c14.._c39 mirror the Criteo column naming for the 26 categorical features
+    model_size_dict = {'_c' + str(i): size for i, size in zip(range(14, 40), table_sizes)}
+    with open(os.path.join(destination, 'model_size.json'), 'w') as f:
+        json.dump(model_size_dict, f, indent=4, sort_keys=True)
+
+
+def prefetcher(load_iterator, prefetch_stream):
+    """Wrap a batch iterator so the next batch is fetched on ``prefetch_stream``.
+
+    Yields batches one step behind the fetch, synchronizing the current stream
+    with the prefetch stream before handing each batch out.
+    """
+    def _prefetch():
+        # issue the next load on the side stream; None signals exhaustion
+        with torch.cuda.stream(prefetch_stream):
+            try:
+                data_batch = next(load_iterator)
+            except StopIteration:
+                return None
+
+        return data_batch
+
+    next_data_batch = _prefetch()
+
+    while next_data_batch is not None:
+        # make the consumer stream wait for the in-flight copy before use
+        torch.cuda.current_stream().wait_stream(prefetch_stream)
+        data_batch = next_data_batch
+        next_data_batch = _prefetch()
+        yield data_batch
+
+
+def get_categorical_feature_sizes(FLAGS):
+    """Return the list of embedding table sizes for the configured dataset.
+
+    Synthetic datasets take the sizes straight from the flags; real datasets read
+    ``model_size.json`` (which stores max values, hence the +1) and optionally clip
+    each size to ``FLAGS.max_table_size``.
+    """
+    if FLAGS.dataset_type in ['synthetic_disk', 'synthetic_gpu']:
+        feature_sizes = [int(s) for s in FLAGS.synthetic_dataset_table_sizes]
+        print('feature sizes: ', feature_sizes)
+        return feature_sizes
+
+    categorical_sizes_file = os.path.join(FLAGS.dataset, "model_size.json")
+    with open(categorical_sizes_file) as f:
+        categorical_sizes = json.load(f).values()
+
+    categorical_sizes = list(categorical_sizes)
+
+    # need to add 1 because the JSON file contains the max value not the count
+    categorical_sizes = [s + 1 for s in categorical_sizes]
+
+    print('feature sizes: ', categorical_sizes)
+
+    if FLAGS.max_table_size is None:
+        return categorical_sizes
+
+    # cap table sizes; hashing of out-of-range indices is handled elsewhere
+    clipped_sizes = [min(s, FLAGS.max_table_size) for s in categorical_sizes]
+    return clipped_sizes
+
+
+def get_categorical_feature_type(size: int):
+    """Return the smallest numpy int type able to store indices 0..size-1.
+
+    The strict ``<`` against the dtype max leaves one value of headroom (indices
+    only reach size-1), so the choice is conservative rather than tight.
+    """
+    types = (np.int8, np.int16, np.int32)
+
+    for numpy_type in types:
+        if size < np.iinfo(numpy_type).max:
+            return numpy_type
+
+    raise RuntimeError(f"Categorical feature of size {size} is too big for defined types")

+ 0 - 224
PyTorch/Recommendation/DLRM/dlrm/model.py

@@ -1,224 +0,0 @@
-# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#       http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import copy
-import json
-import math
-
-from absl import logging
-
-import torch
-from torch import nn
-from typing import List
-
-
-class Dlrm(nn.Module):
-    """Reimplement Facebook's DLRM model
-
-    Original implementation is from https://github.com/facebookresearch/dlrm.
-
-    """
-
-    def __init__(self, num_numerical_features, categorical_feature_sizes, bottom_mlp_sizes, top_mlp_sizes,
-                     embedding_dim=32, interaction_op="dot", self_interaction=False, hash_indices=False,
-                     base_device="cuda", sigmoid=False):
-
-        # Running everything on gpu by default
-        self._base_device = base_device
-        self._embedding_device_map = [base_device for _ in range(len(categorical_feature_sizes))]
-
-        super(Dlrm, self).__init__()
-
-        if embedding_dim != bottom_mlp_sizes[-1]:
-            raise TypeError("The last bottom MLP layer must have same size as embedding.")
-
-        self._embedding_dim = embedding_dim
-        self._interaction_op = interaction_op
-        self._self_interaction = self_interaction
-        self._hash_indices = hash_indices
-        self._categorical_feature_sizes = copy.copy(categorical_feature_sizes)
-
-        # Interactions are among outputs of all the embedding tables and bottom MLP, total number of
-        # (num_embedding_tables + 1) vectors with size embdding_dim. ``dot`` product interaction computes dot product
-        # between any 2 vectors. ``cat`` interaction concatenate all the vectors together.
-        # Output of interaction will have shape [num_interactions, embdding_dim].
-        self._num_interaction_inputs = len(categorical_feature_sizes) + 1
-        if interaction_op == "dot":
-            if self_interaction:
-                raise NotImplementedError
-            num_interactions = (self._num_interaction_inputs * (self._num_interaction_inputs - 1)) // 2 + embedding_dim
-        elif interaction_op == "cat":
-            num_interactions = self._num_interaction_inputs * embedding_dim
-        else:
-            raise TypeError(F"Unknown interaction {interaction_op}.")
-
-        self.embeddings = nn.ModuleList()
-        self._create_embeddings(self.embeddings, embedding_dim, categorical_feature_sizes)
-
-        # Create bottom MLP
-        bottom_mlp_layers = []
-        input_dims = num_numerical_features
-        for output_dims in bottom_mlp_sizes:
-            bottom_mlp_layers.append(
-                nn.Linear(input_dims, output_dims))
-            bottom_mlp_layers.append(nn.ReLU(inplace=True))
-            input_dims = output_dims
-        self.bottom_mlp = nn.Sequential(*bottom_mlp_layers)
-
-        # Create Top MLP
-        top_mlp_layers = []
-
-        input_dims = num_interactions
-        if self._interaction_op == 'dot':
-            input_dims += 1  # pad 1 to be multiple of 8
-
-        for output_dims in top_mlp_sizes[:-1]:
-            top_mlp_layers.append(nn.Linear(input_dims, output_dims))
-            top_mlp_layers.append(nn.ReLU(inplace=True))
-            input_dims = output_dims
-        # last Linear layer uses sigmoid
-        top_mlp_layers.append(nn.Linear(input_dims, top_mlp_sizes[-1]))
-        if sigmoid:
-            top_mlp_layers.append(nn.Sigmoid())
-        self.top_mlp = nn.Sequential(*top_mlp_layers)
-
-        self._initialize_mlp_weights()
-        self._interaction_padding = torch.zeros(1, 1, dtype=torch.float32, device=base_device)
-        self.tril_indices = torch.tensor([[i for i in range(len(self.embeddings) + 1) 
-                                             for j in range(i + int(self_interaction))],
-                                          [j for i in range(len(self.embeddings) + 1) 
-                                             for j in range(i + int(self_interaction))]])
-
-    def _interaction(self, 
-            bottom_mlp_output: torch.Tensor, 
-            embedding_outputs: List[torch.Tensor], 
-            batch_size: int) -> torch.Tensor:
-        """Interaction
-
-        "dot" interaction is a bit tricky to implement and test. Break it out from forward so that it can be tested
-        independently.
-
-        Args:
-            bottom_mlp_output (Tensor):
-            embedding_outputs (list): Sequence of tensors
-            batch_size (int):
-        """
-        if self._interaction_padding is None:
-            self._interaction_padding = torch.zeros(
-                batch_size, 1, dtype=bottom_mlp_output.dtype, device=bottom_mlp_output.device)
-        concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
-        if self._interaction_op == "dot" and not self._self_interaction:
-            concat = concat.view((-1, self._num_interaction_inputs, self._embedding_dim))
-            interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
-            interaction_flat = interaction[:, self.tril_indices[0], self.tril_indices[1]]
-            # concatenate dense features and interactions
-            interaction_padding = self._interaction_padding.expand(batch_size, 1).to(dtype=bottom_mlp_output.dtype)
-            interaction_output = torch.cat(
-                (bottom_mlp_output, interaction_flat, interaction_padding), dim=1)
-        elif self._interaction_op == "cat":
-            interaction_output = concat
-        else:
-            raise NotImplementedError
-
-        return interaction_output
-
-    def _initialize_mlp_weights(self):
-        """Initializing weights same as original DLRM"""
-        for module in self.modules():
-            if isinstance(module, nn.Linear):
-                nn.init.normal_(module.weight.data, 0., math.sqrt(2. / (module.in_features + module.out_features)))
-                nn.init.normal_(module.bias.data, 0., math.sqrt(1. /  module.out_features))
-
-        # Explicitly set weight corresponding to zero padded interaction output. They will
-        # stay 0 throughout the entire training. An assert can be added to the end of the training
-        # to prove it doesn't increase model capacity but just 0 paddings.
-        nn.init.zeros_(self.top_mlp[0].weight[:, -1].data)
-
-    @property
-    def num_categorical_features(self):
-        return len(self._categorical_feature_sizes)
-
-    def extra_repr(self):
-        s = (F"interaction_op={self._interaction_op}, self_interaction={self._self_interaction}, "
-             F"hash_indices={self._hash_indices}")
-        return s
-    # pylint:enable=missing-docstring
-
-    @classmethod
-    def from_dict(cls, obj_dict, **kwargs):
-        """Create from json str"""
-        return cls(**obj_dict, **kwargs)
-
-    def _create_embeddings(self, embeddings, embedding_dim, categorical_feature_sizes):
-        # Each embedding table has size [num_features, embedding_dim]
-        for i, num_features in enumerate(categorical_feature_sizes):
-            # Allocate directly on GPU is much faster than allocating on CPU then copying over
-            embedding_weight = torch.empty((num_features, embedding_dim), device=self._embedding_device_map[i])
-            embedding = nn.Embedding.from_pretrained(embedding_weight, freeze=False, sparse=True)
-
-            # Initializing embedding same as original DLRM
-            nn.init.uniform_(
-                embedding.weight.data,
-                -math.sqrt(1. / embedding.num_embeddings),
-                math.sqrt(1. / embedding.num_embeddings))
-
-            embeddings.append(embedding)
-
-    def set_devices(self, base_device):
-        """Set devices to run the model
-
-        Args:
-            base_device (string);
-        """
-        self._base_device = base_device
-        self.bottom_mlp.to(base_device)
-        self.top_mlp.to(base_device)
-        self._interaction_padding = self._interaction_padding.to(base_device)
-        self._embedding_device_map = [base_device for _ in range(self.num_categorical_features)]
-
-        for embedding_id, device in enumerate(self._embedding_device_map):
-            logging.info("Place embedding %d on device %s", embedding_id, device)
-            self.embeddings[embedding_id].to(device)
-
-    def forward(self, numerical_input, categorical_inputs):
-        """
-
-        Args:
-            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
-            categorical_inputs (Tensor): with shape [batch_size, num_categorical_features]
-        """
-        batch_size = numerical_input.size()[0]
-
-        # Put indices on the same device as corresponding embedding
-        device_indices = []
-        for embedding_id, _ in enumerate(self.embeddings):
-            device_indices.append(categorical_inputs[:, embedding_id].to(self._embedding_device_map[embedding_id]))
-
-        bottom_mlp_output = self.bottom_mlp(numerical_input)
-
-        # embedding_outputs will be a list of (26 in the case of Criteo) fetched embeddings with shape
-        # [batch_size, embedding_size]
-        embedding_outputs = []
-        for embedding_id, embedding in enumerate(self.embeddings):
-            if self._hash_indices:
-                device_indices[embedding_id] = device_indices[embedding_id] % embedding.num_embeddings
-
-            embedding_outputs.append(embedding(device_indices[embedding_id]).to(self._base_device))
-
-        interaction_output = self._interaction(bottom_mlp_output, embedding_outputs, batch_size)
-
-        top_mlp_output = self.top_mlp(interaction_output)
-
-        return top_mlp_output

+ 0 - 0
PyTorch/Recommendation/DLRM/dlrm/model/__init__.py


+ 156 - 0
PyTorch/Recommendation/DLRM/dlrm/model/distributed.py

@@ -0,0 +1,156 @@
+from typing import Sequence, Optional
+
+import torch
+from torch import nn
+
+from dlrm.nn.factories import create_interaction
+from dlrm.nn.parts import DlrmBottom, DlrmTop
+from dlrm.utils import distributed as dist
+
+
class BottomToTop(torch.autograd.Function):
    """Switch from model parallel to data parallel

    Wrap the communication of doing from bottom model in model parallel fashion to top model in data parallel
    """

    @staticmethod
    def forward(
        ctx,
        local_bottom_outputs: torch.Tensor,
        batch_sizes_per_gpu: Sequence[int],
        vector_dim: int,
        vectors_per_gpu: Sequence[int],
        feature_order: Optional[torch.Tensor] = None,
        device_feature_order: Optional[torch.Tensor] = None
    ):
        """
        Args:
            ctx : Pytorch convention
            local_bottom_outputs (Tensor): Concatenated output of bottom model
            batch_sizes_per_gpu (Sequence[int]): how the global batch is sharded across GPUs
            vector_dim (int): width of each bottom-model output vector
            vectors_per_gpu (Sequence[int]): Note, bottom MLP is considered as 1 vector
            feature_order: permutation applied to the received feature axis (optional)
            device_feature_order: inverse permutation, applied to gradients in backward (optional)

        Returns:
            slice_embedding_outputs (Tensor): Partial output from bottom model to feed into data parallel top model
        """
        rank = dist.get_rank()

        # Stash the communication layout so backward() can run the inverse all-to-all.
        ctx.world_size = torch.distributed.get_world_size()
        ctx.batch_sizes_per_gpu = batch_sizes_per_gpu
        ctx.vector_dim = vector_dim
        ctx.vectors_per_gpu = vectors_per_gpu
        ctx.feature_order = feature_order
        ctx.device_feature_order = device_feature_order

        # Buffer shouldn't need to be zero out. If not zero out buffer affecting accuracy, there must be a bug.
        bottom_output_buffer = [torch.empty(
            batch_sizes_per_gpu[rank], n * vector_dim,
            device=local_bottom_outputs.device, dtype=local_bottom_outputs.dtype) for n in vectors_per_gpu]

        torch.distributed.all_to_all(bottom_output_buffer, list(local_bottom_outputs.split(batch_sizes_per_gpu, dim=0)))
        slice_bottom_outputs = torch.cat(bottom_output_buffer, dim=1).view(batch_sizes_per_gpu[rank], -1, vector_dim)

        # feature reordering is just for consistency across different device mapping configurations
        if feature_order is not None and device_feature_order is not None:
            return slice_bottom_outputs[:, feature_order, :]

        return slice_bottom_outputs

    @staticmethod
    def backward(ctx, grad_slice_bottom_outputs):
        rank = dist.get_rank()

        # Undo the forward-pass feature permutation before scattering gradients back.
        if ctx.feature_order is not None and ctx.device_feature_order is not None:
            grad_slice_bottom_outputs = grad_slice_bottom_outputs[:, ctx.device_feature_order, :]

        grad_local_bottom_outputs = torch.empty(
            sum(ctx.batch_sizes_per_gpu), ctx.vectors_per_gpu[rank] * ctx.vector_dim,
            device=grad_slice_bottom_outputs.device,
            dtype=grad_slice_bottom_outputs.dtype)
        # All to all only takes list while split() returns tuple

        grad_local_bottom_outputs_split = list(grad_local_bottom_outputs.split(ctx.batch_sizes_per_gpu, dim=0))

        # NOTE(review): .contiguous() presumably required by all_to_all on the split views — confirm
        split_grads = [t.contiguous() for t in (grad_slice_bottom_outputs.view(ctx.batch_sizes_per_gpu[rank], -1).split(
            [ctx.vector_dim * n for n in ctx.vectors_per_gpu], dim=1))]

        torch.distributed.all_to_all(grad_local_bottom_outputs_split, split_grads)

        # One gradient for the tensor input, then one None per non-tensor forward argument.
        return (grad_local_bottom_outputs.view(grad_local_bottom_outputs.shape[0], -1, ctx.vector_dim), None, None,
                None, None, None)
+
+
+bottom_to_top = BottomToTop.apply
+
+
class DistributedDlrm(nn.Module):
    """Hybrid-parallel DLRM: a model-parallel bottom (embeddings sharded across GPUs)
    feeding a data-parallel top model through an all-to-all exchange (``bottom_to_top``).
    """

    def __init__(
        self,
        vectors_per_gpu: Sequence[int],
        embedding_device_mapping: Sequence[Sequence[int]],
        world_num_categorical_features: int,
        num_numerical_features: int,
        categorical_feature_sizes: Sequence[int],
        bottom_mlp_sizes: Sequence[int],
        top_mlp_sizes: Sequence[int],
        embedding_type: str = "multi_table",
        embedding_dim: int = 128,
        interaction_op: str = "dot",
        hash_indices: bool = False,
        use_cpp_mlp: bool = False,
        fp16: bool = False,
        bottom_features_ordered: bool = False,
        device: str = "cuda"
    ):
        super().__init__()

        self._vectors_per_gpu = vectors_per_gpu
        self._embedding_dim = embedding_dim
        self._interaction_op = interaction_op
        self._hash_indices = hash_indices

        # TODO: take bottom_mlp GPU from device mapping, do not assume it's always first
        # ``[-1] ... + 1`` prepends slot 0 for the bottom-MLP vector and shifts every
        # embedding feature index by one.  The conditional expression binds last, so the
        # tensor is only built when bottom_features_ordered is True.
        self._device_feature_order = torch.tensor(
            [-1] + [i for bucket in embedding_device_mapping for i in bucket], dtype=torch.long, device=device
        ) + 1 if bottom_features_ordered else None
        # argsort gives the inverse permutation used to restore the canonical order.
        self._feature_order = self._device_feature_order.argsort() if bottom_features_ordered else None

        # Interaction spans ALL features in the world, not just the local shard.
        interaction = create_interaction(interaction_op, world_num_categorical_features, embedding_dim)

        self.bottom_model = DlrmBottom(
            num_numerical_features, categorical_feature_sizes, bottom_mlp_sizes,
            embedding_type, embedding_dim, hash_indices=hash_indices, use_cpp_mlp=use_cpp_mlp, fp16=fp16, device=device
        )
        self.top_model = DlrmTop(top_mlp_sizes, interaction, use_cpp_mlp=use_cpp_mlp).to(device)

    def extra_repr(self):
        return f"interaction_op={self._interaction_op}, hash_indices={self._hash_indices}"

    # pylint:enable=missing-docstring

    @classmethod
    def from_dict(cls, obj_dict, **kwargs):
        """Create from json str"""
        return cls(**obj_dict, **kwargs)

    def forward(self, numerical_input, categorical_inputs, batch_sizes_per_gpu: Sequence[int]):
        """
        Args:
            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
            categorical_inputs (Tensor): with shape [batch_size, num_categorical_features]
            batch_sizes_per_gpu (Sequence[int]): per-rank shard sizes for the all-to-all
        """
        # bottom mlp output may be not present before all to all communication
        bottom_output, _ = self.bottom_model(numerical_input, categorical_inputs)

        # Exchange model-parallel bottom outputs so each rank gets its data-parallel slice.
        from_bottom = bottom_to_top(bottom_output, batch_sizes_per_gpu, self._embedding_dim, self._vectors_per_gpu,
                                    self._feature_order, self._device_feature_order)

        # TODO: take bottom_mlp GPU from device mapping, do not assume it's always first
        bottom_mlp_output = from_bottom[:, 0, :]
        return self.top_model(from_bottom, bottom_mlp_output)

+ 81 - 0
PyTorch/Recommendation/DLRM/dlrm/model/single.py

@@ -0,0 +1,81 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+from torch import nn
+
+from dlrm.nn.factories import create_interaction
+from dlrm.nn.parts import DlrmBottom, DlrmTop
+
+
class Dlrm(nn.Module):
    """Reimplement Facebook's DLRM model

    Original implementation is from https://github.com/facebookresearch/dlrm.

    Composed of a bottom part (embeddings + bottom MLP, see ``DlrmBottom``) and
    a top part (feature interaction + top MLP, see ``DlrmTop``).
    """
    def __init__(
        self,
        num_numerical_features: int,
        categorical_feature_sizes: Sequence[int],
        bottom_mlp_sizes: Sequence[int],
        top_mlp_sizes: Sequence[int],
        embedding_type: str = "multi_table",
        embedding_dim: int = 32,
        interaction_op: str = "dot",
        hash_indices: bool = False,
        use_cpp_mlp: bool = False,
        fp16: bool = False,
        base_device: str = "cuda",
    ):
        super().__init__()

        # Raise instead of ``assert``: asserts are stripped under ``python -O`` and a
        # size mismatch here is a caller configuration error, not an internal invariant.
        # TypeError matches what the previous implementation of this model raised.
        if embedding_dim != bottom_mlp_sizes[-1]:
            raise TypeError("The last bottom MLP layer must have same size as embedding.")

        interaction = create_interaction(interaction_op, len(categorical_feature_sizes), embedding_dim)

        self._interaction_op = interaction_op
        self._hash_indices = hash_indices

        self.bottom_model = DlrmBottom(
            num_numerical_features=num_numerical_features,
            categorical_feature_sizes=categorical_feature_sizes,
            bottom_mlp_sizes=bottom_mlp_sizes,
            embedding_type=embedding_type,
            embedding_dim=embedding_dim,
            hash_indices=hash_indices,
            use_cpp_mlp=use_cpp_mlp,
            fp16=fp16,
            device=base_device
        )
        self.top_model = DlrmTop(top_mlp_sizes, interaction, use_cpp_mlp=use_cpp_mlp).to(base_device)

    def extra_repr(self):
        return f"interaction_op={self._interaction_op}, hash_indices={self._hash_indices}"

    # pylint:enable=missing-docstring
    @classmethod
    def from_dict(cls, obj_dict, **kwargs):
        """Create from json str"""
        return cls(**obj_dict, **kwargs)

    def forward(self, numerical_input, categorical_inputs):
        """
        Args:
            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
            categorical_inputs (Tensor): with shape [batch_size, num_categorical_features]
        """
        bottom_output, bottom_mlp_output = self.bottom_model(numerical_input, categorical_inputs)
        return self.top_model(bottom_output, bottom_mlp_output)

+ 0 - 0
PyTorch/Recommendation/DLRM/dlrm/nn/__init__.py


+ 248 - 0
PyTorch/Recommendation/DLRM/dlrm/nn/embeddings.py

@@ -0,0 +1,248 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Sequence, List, Iterable
+
+import torch
+from absl import logging
+from torch import nn
+
+from dlrm import cuda_ext
+from dlrm.cuda_ext.fused_gather_embedding import BuckleEmbeddingFusedGatherFunction
+
+
class Embeddings(nn.Module):
    """Abstract interface for the embedding part of the DLRM bottom model."""

    def forward(self, categorical_inputs) -> List[torch.Tensor]:
        """Look up embeddings; implementations return one tensor per handled feature group."""
        raise NotImplementedError()

    @property
    def weights(self) -> List[torch.Tensor]:
        """
        Note: output list size should match number of handled categorical features
        """
        raise NotImplementedError()

    def load_weights(self, weights: Iterable[torch.Tensor]):
        """Copy the given per-feature weight tensors into this module's storage."""
        raise NotImplementedError()
+
+
class MultiTableEmbeddings(Embeddings):
    """One ``nn.Embedding`` table per categorical feature.

    Each table is placed according to ``_embedding_device_map`` (currently the
    same device for all of them); lookups are moved back to the base device.
    """

    def __init__(
        self,
        categorical_feature_sizes: Sequence[int],
        embedding_dim: int,
        hash_indices: bool = False,
        device: str = "cuda"
    ):
        super().__init__()
        self._categorical_feature_sizes = copy.copy(categorical_feature_sizes)
        self._base_device = device
        self._embedding_device_map = [device] * len(categorical_feature_sizes)

        tables = []
        # Each embedding table has size [num_features, embedding_dim].
        for table_id, num_features in enumerate(categorical_feature_sizes):
            # Allocating directly on the GPU is much faster than allocating on CPU then copying.
            weight = torch.empty((num_features, embedding_dim), device=self._embedding_device_map[table_id])
            tables.append(nn.Embedding.from_pretrained(weight, freeze=False, sparse=True))

        self.embeddings = nn.ModuleList(tables)
        self.hash_indices = hash_indices
        self.embedding_dim = embedding_dim

    def forward(self, categorical_inputs) -> List[torch.Tensor]:
        """
        Args:
            categorical_inputs (Tensor): with shape [batch_size, num_categorical_features]

        Returns:
            list of Tensors shaped [batch_size, 1, embedding_dim], one per table
        """
        outputs = []
        for table_id, table in enumerate(self.embeddings):
            # Move indices to the device of the corresponding table before the lookup.
            indices = categorical_inputs[:, table_id].to(self._embedding_device_map[table_id])
            if self.hash_indices:
                indices %= table.num_embeddings
            outputs.append(table(indices).to(self._base_device).unsqueeze(1))
        return outputs

    @property
    def weights(self):
        return [table.weight.data for table in self.embeddings]

    def load_weights(self, weights: Iterable[torch.Tensor]):
        for table, weight in zip(self.embeddings, weights):
            table.weight.data = weight
            table.weight.data.requires_grad_()
+
+
class JointEmbedding(Embeddings):
    """Buckle multiple one hot embedding together

    All tables are stacked into a single ``nn.Embedding``; per-table lookups become
    one indexing op after shifting each feature's indices by its table offset.

    Args:
        categorical_feature_sizes (list): A list of integer indicating number of features of each embedding table
        embedding_dim (int): the size of each embedding vector
        device (torch.device): where to create the embedding. Default "cuda"
    """
    def __init__(
        self,
        categorical_feature_sizes: Sequence[int],
        embedding_dim: int,
        device: str = "cuda",
        hash_indices: bool = False
    ):
        super().__init__()
        self._categorical_feature_sizes = copy.copy(categorical_feature_sizes)

        # offsets[i] marks where table i starts inside the fused weight matrix.
        self.register_buffer("offsets", torch.tensor([0] + list(categorical_feature_sizes), device=device).cumsum(0))

        fused_weight = torch.empty((self.offsets[-1].item(), embedding_dim), device=device)
        self.embedding = nn.Embedding.from_pretrained(fused_weight, freeze=False, sparse=True)
        self.hash_indices = hash_indices

    # pylint:disable=missing-docstring
    def forward(self, categorical_inputs) -> List[torch.Tensor]:
        if self.hash_indices:
            # In-place hash of out-of-range indices, one feature column at a time.
            for cat, size in enumerate(self._categorical_feature_sizes):
                categorical_inputs[:, cat] %= size
                logging.log_first_n(logging.WARNING, F"Hashed indices out of range.", 1)

        return [self.embedding(categorical_inputs + self.offsets[:-1])]

    def extra_repr(self):
        return F"offsets={self.offsets.cpu().numpy()}"
    # pylint:enable=missing-docstring

    @property
    def weights(self):
        # Slice the fused matrix back into one view per original table.
        return [self.embedding.weight.data[lo:hi]
                for lo, hi in zip(self.offsets[:-1], self.offsets[1:])]

    def load_weights(self, weights: Iterable[torch.Tensor]):
        fused = self.embedding.weight.data
        for cat, weight in zip(range(len(self._categorical_feature_sizes)), weights):
            fused[self.offsets[cat]:self.offsets[cat + 1]] = weight
+
+
class FusedJointEmbedding(Embeddings):
    """
    Buckle multiple one hot embedding together

    Multiple one hot embedding can be done as one embedding (indexing); the lookup is
    performed by the fused custom op ``BuckleEmbeddingFusedGatherFunction``.

    Args:
        categorical_feature_sizes (list): A list of integer indicating number of features of each embedding table
        embedding_dim (int): the size of each embedding vector
        device (torch.device): where to create the embedding. Default "cuda"
        hash_indices (bool): hash out-of-range indices into each table's valid range
        amp_train (bool): forwarded to the fused gather op for mixed-precision training
    """

    def __init__(
        self,
        categorical_feature_sizes: Sequence[int],
        embedding_dim: int,
        device: str = "cuda",
        hash_indices: bool = False,
        amp_train: bool = False
    ):
        super().__init__()
        self._categorical_feature_sizes = copy.copy(categorical_feature_sizes)

        self.embedding_dim = embedding_dim
        self.amp_train = amp_train
        self.hash_indices = hash_indices

        # Explicit list() conversion: the parameter is typed as a Sequence, and
        # ``[0] + categorical_feature_sizes`` raises TypeError when a tuple is
        # passed. This matches how JointEmbedding builds its offsets.
        self.register_buffer("offsets", torch.tensor([0] + list(categorical_feature_sizes)).cumsum(0).to(device))

        self.register_parameter("weight", torch.nn.Parameter(
            torch.empty((self.offsets[-1].item(), embedding_dim), device=device), requires_grad=True))

    def forward(self, categorical_inputs) -> List[torch.Tensor]:
        # Optionally hash indices into range (in place) before the fused lookup.
        if self.hash_indices:
            for cat, size in enumerate(self._categorical_feature_sizes):
                categorical_inputs[:, cat] %= size
                logging.log_first_n(logging.WARNING, F"Hashed indices out of range.", 1)

        return [BuckleEmbeddingFusedGatherFunction.apply(self.weight, categorical_inputs, self.offsets, self.amp_train)]

    def extra_repr(self):
        return 'embedding_dim={}, categorical_feature_sizes={}, offsets={}'.format(
            self.embedding_dim, self._categorical_feature_sizes, self.offsets)

    @property
    def weights(self) -> List[torch.Tensor]:
        # One view per original table, sliced out of the fused weight matrix.
        return [self.weight.data[self.offsets[cat]:self.offsets[cat + 1]]
                for cat in range(len(self._categorical_feature_sizes))]

    def load_weights(self, weights: Iterable[torch.Tensor]):
        data = self.weight.data
        offsets = self.offsets

        for cat, weight in zip(range(len(self._categorical_feature_sizes)), weights):
            data[offsets[cat]:offsets[cat + 1]] = weight
+
+
class JointSparseEmbedding(Embeddings):
    """Embeddings backed by the custom CUDA extension ``cuda_ext.JointSparseEmbedding``.

    All tables share one fused weight matrix managed by the extension; lookups for
    all categorical features are performed in a single call.
    """

    def __init__(
        self,
        categorical_feature_sizes: List[int],
        embedding_dim: int,
        device: str = "cuda",
        hash_indices: bool = False
    ):
        super().__init__()
        self._categorical_feature_sizes = categorical_feature_sizes
        self.embedding = cuda_ext.JointSparseEmbedding(categorical_feature_sizes, embedding_dim, device)
        self.hash_indices = hash_indices

    def forward(self, categorical_inputs) -> List[torch.Tensor]:
        # NOTE: hashing mutates ``categorical_inputs`` in place, same as the other
        # Embeddings implementations in this file.
        if self.hash_indices:
            for cat, size in enumerate(self._categorical_feature_sizes):
                categorical_inputs[:, cat] %= size
                logging.log_first_n(logging.WARNING, F"Hashed indices out of range.", 1)

        return [
            self.embedding(categorical_inputs)
        ]

    @property
    def weights(self):
        # Slice the extension's fused weight matrix back into per-table views.
        data = self.embedding.weights.data
        offsets = self.embedding.offsets
        return [data[offsets[cat]:offsets[cat + 1]] for cat in range(len(self._categorical_feature_sizes))]

    def load_weights(self, weights: Iterable[torch.Tensor]):
        data = self.embedding.weights.data
        offsets = self.embedding.offsets

        for cat, weight in zip(range(len(self._categorical_feature_sizes)), weights):
            data[offsets[cat]:offsets[cat + 1]] = weight

+ 64 - 0
PyTorch/Recommendation/DLRM/dlrm/nn/factories.py

@@ -0,0 +1,64 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence
+
+from dlrm.nn.embeddings import (
+    JointEmbedding, MultiTableEmbeddings, FusedJointEmbedding, JointSparseEmbedding,
+    Embeddings
+)
+from dlrm.nn.interactions import Interaction, CudaDotInteraction, DotInteraction, CatInteraction
+from dlrm.nn.mlps import AbstractMlp, CppMlp, TorchMlp
+from dlrm.utils.distributed import is_distributed
+
+
def create_mlp(input_dim: int, sizes: Sequence[int], use_cpp_mlp: bool) -> AbstractMlp:
    """Build an MLP: the fused C++/CUDA implementation when requested, otherwise plain torch."""
    mlp_cls = CppMlp if use_cpp_mlp else TorchMlp
    return mlp_cls(input_dim, sizes)
+
+
def create_embeddings(
        embedding_type: str,
        categorical_feature_sizes: Sequence[int],
        embedding_dim: int,
        device: str = "cuda",
        hash_indices: bool = False,
        fp16: bool = False
) -> Embeddings:
    """Instantiate the embedding implementation selected by ``embedding_type``.

    Raises:
        NotImplementedError: for an unrecognized ``embedding_type``.
    """
    if embedding_type == "joint":
        return JointEmbedding(categorical_feature_sizes, embedding_dim, device=device, hash_indices=hash_indices)

    if embedding_type == "joint_fused":
        # The fused CUDA kernel has no multi-GPU support.
        assert not is_distributed(), "Joint fused embedding is not supported in the distributed mode. " \
                                     "You may want to use 'joint_sparse' option instead."
        return FusedJointEmbedding(categorical_feature_sizes, embedding_dim, device=device,
                                   hash_indices=hash_indices, amp_train=fp16)

    if embedding_type == "joint_sparse":
        return JointSparseEmbedding(categorical_feature_sizes, embedding_dim, device=device,
                                    hash_indices=hash_indices)

    if embedding_type == "multi_table":
        return MultiTableEmbeddings(categorical_feature_sizes, embedding_dim,
                                    hash_indices=hash_indices, device=device)

    raise NotImplementedError(f"unknown embedding type: {embedding_type}")
+
+
def create_interaction(interaction_op: str, embedding_num: int, embedding_dim: int) -> Interaction:
    """Instantiate the feature-interaction layer selected by ``interaction_op``.

    Raises:
        NotImplementedError: for an unrecognized ``interaction_op``.
    """
    if interaction_op == "cat":
        return CatInteraction(embedding_num, embedding_dim)
    if interaction_op in ("dot", "cuda_dot"):
        dot = DotInteraction(embedding_num, embedding_dim)
        # The CUDA variant wraps the reference dot interaction with a fused kernel.
        return CudaDotInteraction(dot) if interaction_op == "cuda_dot" else dot
    raise NotImplementedError(f"unknown interaction op: {interaction_op}")

+ 113 - 0
PyTorch/Recommendation/DLRM/dlrm/nn/interactions.py

@@ -0,0 +1,113 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from dlrm.cuda_ext import dotBasedInteract
+
+
class Interaction:
    """Interface for combining the bottom-model feature vectors into the top-MLP input."""

    @property
    def num_interactions(self) -> int:
        # Width of the flattened vector produced by interact(), i.e. the top MLP input size.
        raise NotImplementedError()

    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output
        :return:
        """
        raise NotImplementedError()
+
+
class DotInteraction(Interaction):
    """Reference (pure PyTorch) pairwise dot-product interaction."""

    def __init__(self, embedding_num: int, embedding_dim: int):
        """
        Interactions are among outputs of all the embedding tables and bottom MLP, total number of
        (num_embedding_tables + 1) vectors with size embedding_dim. ``dot`` product interaction computes dot product
        between any 2 vectors. Output of interaction will have shape [num_interactions, embedding_dim].
        """
        self._num_interaction_inputs = embedding_num + 1
        self._embedding_dim = embedding_dim

        # Row/column indices of the strictly-lower triangle, enumerated row by row,
        # so each unordered pair of inputs is selected exactly once.
        rows, cols = [], []
        for row in range(self._num_interaction_inputs):
            for col in range(row):
                rows.append(row)
                cols.append(col)
        self._tril_indices = torch.tensor([rows, cols])

    @property
    def num_interactions(self) -> int:
        num_pairs = (self._num_interaction_inputs * (self._num_interaction_inputs - 1)) // 2
        # +1 zero column pads the output width by one (original note: "to be multiple of 8").
        return num_pairs + self._embedding_dim + 1

    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output
        :return:
        """
        batch_size = bottom_output.size()[0]

        # All pairwise dot products between the input vectors ...
        pairwise = torch.bmm(bottom_output, torch.transpose(bottom_output, 1, 2))
        # ... of which only the strictly-lower triangle is kept.
        flat_pairs = pairwise[:, self._tril_indices[0], self._tril_indices[1]]

        # Dense (bottom MLP) features are prepended; one zero column pads the width.
        padding = torch.zeros(batch_size, 1, dtype=bottom_output.dtype, device=bottom_output.device)
        return torch.cat((bottom_mlp_output, flat_pairs, padding), dim=1)
+
+
class CudaDotInteraction(Interaction):
    """Dot interaction delegated to the custom CUDA kernel (``dotBasedInteract``)."""

    def __init__(self, dot_interaction: DotInteraction):
        # The reference implementation is kept only to report the output width;
        # the actual computation runs in the CUDA extension.
        self._dot_interaction = dot_interaction

    @property
    def num_interactions(self):
        return self._dot_interaction.num_interactions

    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output
        :return:
        """
        return dotBasedInteract(bottom_output, bottom_mlp_output)
+
+
class CatInteraction(Interaction):
    """``cat`` interaction: flattens all per-feature vectors into one wide vector."""

    def __init__(self, embedding_num: int, embedding_dim: int):
        """
        Interactions are among outputs of all the embedding tables and bottom MLP, total number of
        (num_embedding_tables + 1) vectors with size embedding_dim. ``cat`` interaction concatenates all the
        vectors together. Output of interaction will have shape [num_interactions, embedding_dim].
        """
        self._num_interaction_inputs = embedding_num + 1
        self._embedding_dim = embedding_dim

    @property
    def num_interactions(self) -> int:
        # Flattened width: one embedding_dim-wide slot per input vector.
        return self._num_interaction_inputs * self._embedding_dim

    def interact(self, bottom_output, bottom_mlp_output):
        """
        :param bottom_output: [batch_size, 1 + #embeddings, embedding_dim]
        :param bottom_mlp_output: unused — its vector is already part of ``bottom_output``
        :return:
        """
        return bottom_output.view(-1, self.num_interactions)

+ 117 - 0
PyTorch/Recommendation/DLRM/dlrm/nn/mlps.py

@@ -0,0 +1,117 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Sequence, List, Iterable
+
+import apex.mlp
+import torch
+from torch import nn
+
+
class AbstractMlp(nn.Module):
    """
    MLP interface used for configuration-agnostic checkpointing (`dlrm.utils.checkpointing`)
    and easily swappable MLP implementation
    """

    @property
    def weights(self) -> List[torch.Tensor]:
        """All MLP layer weight tensors, in layer order (biases excluded)."""
        raise NotImplementedError()

    @property
    def biases(self) -> List[torch.Tensor]:
        """All MLP layer bias tensors, in layer order."""
        raise NotImplementedError()

    def forward(self, mlp_input: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

    def load_state(self, weights: Iterable[torch.Tensor], biases: Iterable[torch.Tensor]):
        """Copy the given weight/bias tensors into this MLP's layers, pairing them in order."""
        layer_state = zip(weights, self.weights, biases, self.biases)
        for src_weight, dst_weight, src_bias, dst_bias in layer_state:
            dst_weight.data = src_weight.data
            dst_weight.data.requires_grad_()

            dst_bias.data = src_bias.data
            dst_bias.data.requires_grad_()
+
+
class TorchMlp(AbstractMlp):
    """Plain PyTorch MLP: a stack of Linear layers, each followed by an in-place ReLU."""

    def __init__(self, input_dim: int, sizes: Sequence[int]):
        super().__init__()

        modules = []
        fan_in = input_dim
        for fan_out in sizes:
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)]
            fan_in = fan_out
        self.layers = nn.Sequential(*modules)

        self._initialize_weights()

    def _initialize_weights(self):
        # Normal init; weight std scales with combined fan-in/fan-out, bias std with fan-out.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            weight_std = math.sqrt(2. / (module.in_features + module.out_features))
            bias_std = math.sqrt(1. / module.out_features)
            nn.init.normal_(module.weight.data, 0., weight_std)
            nn.init.normal_(module.bias.data, 0., bias_std)

    @property
    def weights(self):
        return [module.weight for module in self.layers if isinstance(module, nn.Linear)]

    @property
    def biases(self):
        return [module.bias for module in self.layers if isinstance(module, nn.Linear)]

    def forward(self, mlp_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mlp_input (Tensor): with shape [batch_size, num_features]

        Returns:
            Tensor: Mlp output in shape [batch_size, num_output_features]
        """
        return self.layers(mlp_input)
+
+
class CppMlp(AbstractMlp):
    """MLP backed by apex's fused C++/CUDA implementation (``apex.mlp.MLP``)."""

    def __init__(self, input_dim: int, sizes: Sequence[int]):
        super().__init__()
        layer_sizes = [input_dim] + list(sizes)
        self.mlp = apex.mlp.MLP(layer_sizes)

    @property
    def weights(self):
        # apex.mlp.MLP exposes its parameters directly.
        return self.mlp.weights

    @property
    def biases(self):
        return self.mlp.biases

    def forward(self, mlp_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mlp_input (Tensor): with shape [batch_size, num_features]

        Returns:
            Tensor: Mlp output in shape [batch_size, num_output_features]
        """
        return self.mlp(mlp_input)

+ 135 - 0
PyTorch/Recommendation/DLRM/dlrm/nn/parts.py

@@ -0,0 +1,135 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from typing import Sequence, Optional, Tuple
+
+import torch
+from torch import nn
+
+from dlrm.nn.embeddings import Embeddings
+from dlrm.nn.factories import create_embeddings, create_mlp
+from dlrm.nn.interactions import Interaction
+
+
class DlrmBottom(nn.Module):
    """Bottom part of DLRM: embeds the categorical features and (optionally) runs the
    bottom MLP on the numerical features, producing the per-feature vectors consumed
    by the interaction op.
    """

    def __init__(
        self,
        num_numerical_features: int,
        categorical_feature_sizes: Sequence[int],
        bottom_mlp_sizes: Optional[Sequence[int]] = None,
        embedding_type: str = "multi_table",
        embedding_dim: int = 128,
        hash_indices: bool = False,
        use_cpp_mlp: bool = False,
        fp16: bool = False,
        device: str = "cuda"
    ):
        super().__init__()
        assert bottom_mlp_sizes is None or embedding_dim == bottom_mlp_sizes[-1], "The last bottom MLP layer must" \
                                                                                  " have same size as embedding."
        self._embedding_dim = embedding_dim
        self._categorical_feature_sizes = copy.copy(categorical_feature_sizes)
        self._fp16 = fp16

        self.embeddings = create_embeddings(
            embedding_type,
            categorical_feature_sizes,
            embedding_dim,
            device,
            hash_indices,
            fp16
        )
        # When there is no bottom MLP, an empty ModuleList (which is falsy) stands in,
        # matching the `if self.mlp:` checks below.
        self.mlp = (create_mlp(num_numerical_features, bottom_mlp_sizes, use_cpp_mlp).to(device)
                    if bottom_mlp_sizes else torch.nn.ModuleList())

        self._initialize_embeddings_weights(self.embeddings, categorical_feature_sizes)

    def _initialize_embeddings_weights(self, embeddings: "Embeddings", categorical_feature_sizes: Sequence[int]):
        # Uniform init per table, with the range shrinking as the table size grows.
        assert len(embeddings.weights) == len(categorical_feature_sizes)

        for size, weight in zip(categorical_feature_sizes, embeddings.weights):
            nn.init.uniform_(
                weight,
                -math.sqrt(1. / size),
                math.sqrt(1. / size)
            )

    @property
    def num_categorical_features(self) -> int:
        return len(self._categorical_feature_sizes)

    @property
    def num_feature_vectors(self) -> int:
        # BUGFIX: `self.mlp` is never None — an empty ModuleList stands in when the
        # bottom MLP is absent — so `int(self.mlp is not None)` always added 1, even
        # when forward() produces no bottom MLP vector. Use truthiness instead, which
        # is exactly the condition forward() checks (empty ModuleList is falsy).
        return self.num_categorical_features + (1 if self.mlp else 0)

    def forward(self, numerical_input, categorical_inputs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
            categorical_inputs (Tensor): with shape [batch_size, num_categorical_features]

        Returns:
            Tuple of:
            - Tensor: concatenated bottom MLP and embedding output in shape
              [batch, 1 + #embeddings, embedding_dim] (no leading MLP vector if the
              bottom MLP is absent)
            - Optional[Tensor]: the raw bottom MLP output, or None if there is no bottom MLP
        """
        batch_size = categorical_inputs.size()[0]
        bottom_output = []
        bottom_mlp_output = None

        if self.mlp:
            bottom_mlp_output = self.mlp(numerical_input)
            if self._fp16:
                bottom_mlp_output = bottom_mlp_output.half()

            # reshape bottom mlp to concatenate with embeddings
            bottom_output.append(bottom_mlp_output.view(batch_size, 1, -1))

        bottom_output += self.embeddings(categorical_inputs)

        if self._fp16:
            # Embedding implementations may emit fp32; cast only what is not fp16 already.
            bottom_output = [x.half() if x.dtype != torch.half else x for x in bottom_output]

        if len(bottom_output) == 1:
            # Single tensor: skip the concatenation.
            return bottom_output[0], bottom_mlp_output

        return torch.cat(bottom_output, dim=1), bottom_mlp_output
+
+
class DlrmTop(nn.Module):
    """Top part of DLRM: interaction op, top MLP, and a final linear output layer."""

    def __init__(self, top_mlp_sizes: Sequence[int], interaction: Interaction, use_cpp_mlp: bool = False):
        super().__init__()

        self.interaction = interaction
        # The last entry of top_mlp_sizes is the width of the output layer, not the MLP.
        self.mlp = create_mlp(interaction.num_interactions, top_mlp_sizes[:-1], use_cpp_mlp)
        self.out = nn.Linear(top_mlp_sizes[-2], top_mlp_sizes[-1])

        self._initialize_weights()

    def _initialize_weights(self):
        # Explicitly set weight corresponding to zero padded interaction output. They will
        # stay 0 throughout the entire training. An assert can be added to the end of the training
        # to prove it doesn't increase model capacity but just 0 paddings.
        first_layer_weight = self.mlp.weights[0]
        nn.init.zeros_(first_layer_weight[:, -1].data)

    def forward(self, bottom_output, bottom_mlp_output):
        """
        Args:
            bottom_output (Tensor): with shape [batch_size, 1 + #embeddings, embedding_dim]
            bottom_mlp_output (Tensor): with shape [batch_size, embedding_dim]
        """
        interacted = self.interaction.interact(bottom_output, bottom_mlp_output)
        hidden = self.mlp(interacted)
        return self.out(hidden)

+ 401 - 0
PyTorch/Recommendation/DLRM/dlrm/scripts/dist_main.py

@@ -0,0 +1,401 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import datetime
+import itertools
+import sys
+from pprint import pprint
+from time import time
+
+import dllogger
+import numpy as np
+import torch
+from absl import app, flags, logging
+from apex import amp, parallel, optimizers as apex_optim
+
+import dlrm.scripts.utils as utils
+from dlrm.data.data_loader import get_data_loaders
+from dlrm.data.utils import prefetcher
+from dlrm.model.distributed import DistributedDlrm
+from dlrm.scripts.main import FLAGS, get_categorical_feature_sizes
+from dlrm.utils import distributed as dist
+from dlrm.utils.checkpointing.distributed import make_distributed_checkpoint_writer, make_distributed_checkpoint_loader
+from dlrm.utils.distributed import get_gpu_batch_sizes, get_criteo_device_mapping, is_main_process, is_distributed
+
# Training schedule flags
# These override the single-GPU defaults defined in dlrm.scripts.main (whose FLAGS
# object is imported above) with values used for the multi-GPU training recipe.
FLAGS.set_default("batch_size", 65536)
FLAGS.set_default("test_batch_size", 131072)
FLAGS.set_default("lr", 24.0)
FLAGS.set_default("warmup_factor", 0)
FLAGS.set_default("warmup_steps", 8000)
FLAGS.set_default("decay_steps", 24000)
FLAGS.set_default("decay_start_step", 48000)
FLAGS.set_default("decay_power", 2)
FLAGS.set_default("decay_end_lr", 0)
FLAGS.set_default("embedding_type", "joint_sparse")

# Flags that only exist for the distributed entry point.
flags.DEFINE_string("backend", "nccl", "Backend to use for distributed training. Default nccl")
flags.DEFINE_boolean("bottom_features_ordered", False, "Sort features from the bottom model, useful when using saved "
                                                       "checkpoint in different device configurations")
+
+
def main(argv):
    """Distributed DLRM entry point: builds the hybrid-parallel model, then either
    evaluates it (``--mode test``) or runs the full training loop with periodic
    evaluation, logging the best AUC and average throughput at the end.
    """
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    device = FLAGS.base_device

    if not is_distributed():
        raise NotImplementedError("This file is only for distributed training.")

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    # Round the test batch size down to a multiple of world_size so it splits evenly.
    FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_criteo_device_mapping(world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size)
    # batch_indices[r]:batch_indices[r+1] is rank r's slice of the global batch.
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of embeddings for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[device_mapping['embedding'][rank]].tolist()

    # Only the rank that owns the bottom MLP instantiates it.
    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping['bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device
    )
    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate it with further scaling lr
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr
    # scaled_lrs[0]: bottom model (per-rank), scaled_lrs[1]: DDP-wrapped top model.
    scaled_lrs = [scaled_lr / world_size, scaled_lr]

    # Embeddings use plain SGD (supports sparse gradients); the MLPs use fused SGD.
    embedding_optimizer = torch.optim.SGD([
        {'params': model.bottom_model.embeddings.parameters(), 'lr': scaled_lrs[0]},
    ])
    mlp_optimizer = apex_optim.FusedSGD([
        {'params': model.bottom_model.mlp.parameters(), 'lr': scaled_lrs[0]},
        {'params': model.top_model.parameters(), 'lr': scaled_lrs[1]}
    ])

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict()
    )

    checkpoint_loader = make_distributed_checkpoint_loader(device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        # loss_scale=1 here because the loss is scaled manually in the training loop below.
        (model.top_model, model.bottom_model.mlp), mlp_optimizer = amp.initialize(
            [model.top_model, model.bottom_model.mlp], mlp_optimizer, opt_level="O2", loss_scale=1)

    # Only the top model is data-parallel; the bottom model is model-parallel.
    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(F"Finished testing. Test auc {auc:.4f}")
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning("Saving checkpoint without --bottom_features_ordered flag will result in "
                        "a device-order dependent model. Consider using --bottom_features_ordered "
                        "if you plan to load the checkpoint in different device configurations.")

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()

    lr_scheduler = utils.LearningRateScheduler(optimizers=[mlp_optimizer, embedding_optimizer],
                                               base_lrs=[scaled_lrs, [scaled_lrs[0]]],
                                               warmup_steps=FLAGS.warmup_steps,
                                               warmup_factor=FLAGS.warmup_factor,
                                               decay_start_step=FLAGS.decay_start_step,
                                               decay_steps=FLAGS.decay_steps,
                                               decay_power=FLAGS.decay_power,
                                               end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            # NOTE(review): this `break` only exits the inner step loop; the outer
            # epoch loop continues — confirm this is the intended max_steps behavior.
            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(F"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                break

            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported", click.shape[0])
            else:
                output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()

                # Each rank computes the loss on its own slice of the global batch.
                loss = loss_fn(output, click[batch_indices[rank]: batch_indices[rank + 1]])

                # We don't need to accumulate gradient. Set grad to None is faster than optimizer.zero_grad()
                for param_group in itertools.chain(embedding_optimizer.param_groups, mlp_optimizer.param_groups):
                    for param in param_group['params']:
                        param.grad = None

                if FLAGS.amp:
                    # Manual loss scaling (amp was initialized with loss_scale=1).
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                mlp_optimizer.step()
                embedding_optimizer.step()

                # Accumulate the loss on a side stream so it overlaps with the next step.
                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(F"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)
                # Averaging cross a print_freq period to reduce the error.
                # An accurate timing needs synchronize which would slow things down.

                # Skip step_time during warmup so it doesn't skew the benchmark averages.
                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[1]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq / (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[1]["lr"] * (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()

                eta_str = datetime.timedelta(seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(
                    header=F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}")

                # Rebinds moving_loss to a Python float; the next `moving_loss += loss`
                # produces a fresh tensor again.
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                # dist_evaluate returns None on non-main ranks.
                if auc is None:
                    continue

                print(F"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                          F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(F"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
              F"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s.")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)

    results = {'best_auc': best_auc,
               'best_epoch': best_epoch,
               'average_train_throughput': avg_throughput}

    dllogger.log(data=results, step=tuple())
+
+
def dist_evaluate(model, data_loader):
    """Test distributed DLRM model

    Args:
        model (DistDLRM):
        data_loader (torch.utils.data.DataLoader):

    Returns:
        float or None: AUC on the main process, None on all other ranks.
    """
    model.eval()

    device = FLAGS.base_device
    world_size = dist.get_world_size()

    # Even per-rank split; the global test batch is rounded down to a multiple of world_size.
    batch_sizes_per_gpu = [FLAGS.test_batch_size // world_size for _ in range(world_size)]
    test_batch_size = sum(batch_sizes_per_gpu)

    if FLAGS.test_batch_size != test_batch_size:
        print(f"Rounded test_batch_size to {test_batch_size}")
    print(f"Batch sizes per GPU {batch_sizes_per_gpu}")

    # Test bach size could be big, make sure it prints
    default_print_freq = max(524288 * 100 // test_batch_size, 1)
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))

    with torch.no_grad():
        timer = utils.StepTimer()

        # ROC can be computed per batch and then compute AUC globally, but I don't have the code.
        # So pack all the outputs and labels together to compute AUC. y_true and y_score naming follows sklearn
        y_true = []
        y_score = []
        data_stream = torch.cuda.Stream()

        batch_iter = prefetcher(iter(data_loader), data_stream)

        timer.click()

        for step in range(len(data_loader)):
            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            last_batch_size = None
            if click.shape[0] != test_batch_size:  # last batch
                last_batch_size = click.shape[0]
                logging.warning("Pad the last test batch of size %d to %d", last_batch_size, test_batch_size)
                padding_size = test_batch_size - last_batch_size

                # torch.empty leaves the padding rows uninitialized — their outputs
                # are discarded below, only the shape needs to match.
                if numerical_features is not None:
                    padding_numerical = torch.empty(
                        padding_size, numerical_features.shape[1],
                        device=numerical_features.device, dtype=numerical_features.dtype)
                    numerical_features = torch.cat((numerical_features, padding_numerical), dim=0)

                # Index 1 is used for categorical padding (a valid embedding row).
                if categorical_features is not None:
                    padding_categorical = torch.ones(
                        padding_size, categorical_features.shape[1],
                        device=categorical_features.device, dtype=categorical_features.dtype)
                    categorical_features = torch.cat((categorical_features, padding_categorical), dim=0)

            output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()

            # Gather the per-rank output slices back into global-batch order.
            output_receive_buffer = torch.empty(test_batch_size, device=device)
            torch.distributed.all_gather(list(output_receive_buffer.split(batch_sizes_per_gpu)), output)
            # NOTE(review): trimming to the first last_batch_size entries assumes the
            # padded samples end up at the tail of the gathered buffer — verify against
            # the model's batch-splitting order.
            if last_batch_size is not None:
                output_receive_buffer = output_receive_buffer[:last_batch_size]

            y_true.append(click)
            y_score.append(output_receive_buffer)

            timer.click()

            if timer.measured is not None:
                metric_logger.update(step_time=timer.measured)
                if step % print_freq == 0 and step > 0:
                    metric_logger.print(header=F"Test: [{step}/{steps_per_epoch}]")

        # AUC is computed only on the main process; other ranks return None.
        if is_main_process():
            auc = utils.roc_auc_score(torch.cat(y_true), torch.sigmoid(torch.cat(y_score).float()))
        else:
            auc = None

        torch.distributed.barrier()

    model.train()

    return auc
+
+
# absl entry point: parses command-line flags, then invokes main(argv).
if __name__ == '__main__':
    app.run(main)

+ 135 - 161
PyTorch/Recommendation/DLRM/dlrm/scripts/main.py

@@ -11,29 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 import datetime
-import os
-import numpy as np
-import json
-from pprint import pprint
 from time import time
-from sklearn.metrics import roc_auc_score
-
-from absl import app
-from absl import flags
 
 import dllogger
-
+import numpy as np
 import torch
+from absl import app, flags
 from apex import amp
 
-from dlrm.data import data_loader
-from dlrm.data.synthetic_dataset import SyntheticDataset
-from dlrm.model import Dlrm
-
 import dlrm.scripts.utils as utils
+from dlrm.data.data_loader import get_data_loaders
+from dlrm.data.utils import get_categorical_feature_sizes, prefetcher
+from dlrm.model.single import Dlrm
+from dlrm.utils.checkpointing.serial import SerialCheckpointWriter, make_serial_checkpoint_writer, \
+    make_serial_checkpoint_loader
 
 FLAGS = flags.FLAGS
 
@@ -55,24 +47,32 @@ flags.DEFINE_integer("warmup_steps", 6400, "Number of warmup optimization steps"
 flags.DEFINE_integer("decay_steps", 80000, "Polynomial learning rate decay steps. If equal to 0 will not do any decaying")
 flags.DEFINE_integer("decay_start_step", 64000,
     "Optimization step after which to start decaying the learning rate, if None will start decaying right after the warmup phase is completed")
+flags.DEFINE_integer("decay_power", 2, "Polynomial learning rate decay power")
+flags.DEFINE_float("decay_end_lr", 0, "LR after the decay ends")
 
 # Model configuration
+flags.DEFINE_enum("embedding_type", "joint_fused", ["joint", "joint_fused", "joint_sparse", "multi_table"],
+                  help="The type of the embedding operation to use")
 flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of embedding space for categorical features")
 flags.DEFINE_list("top_mlp_sizes", [1024, 1024, 512, 256, 1], "Linear layer sizes for the top MLP")
 flags.DEFINE_list("bottom_mlp_sizes", [512, 256, 128], "Linear layer sizes for the bottom MLP")
 
-flags.DEFINE_string("interaction_op", "dot",
-                    "Type of interaction operation to perform. Supported choices: 'dot' or 'cat'")
-flags.DEFINE_boolean("self_interaction", False, "Set to True to use self-interaction")
+flags.DEFINE_enum("interaction_op", default="cuda_dot", enum_values=["cuda_dot", "dot", "cat"],
+                  help="Type of interaction operation to perform.")
 
 flags.DEFINE_string(
     "dataset", None,
     "Full path to binary dataset. Must include files such as: train_data.bin, test_data.bin")
+flags.DEFINE_enum("dataset_type", default="split", enum_values=['binary', 'split', 'synthetic_gpu', 'synthetic_disk'],
+                  help='The type of the dataset to use')
 
-flags.DEFINE_boolean("synthetic_dataset", False, "Use synthetic instead of real data for benchmarking purposes")
+flags.DEFINE_string("synthetic_dataset_dir", "/tmp/dlrm_sythetic_dataset", "Default synthetic disk dataset directory")
 flags.DEFINE_list("synthetic_dataset_table_sizes", default=','.join(26 * [str(10**5)]),
                   help="Embedding table sizes to use with the synthetic dataset")
 
+flags.DEFINE_integer("synthetic_dataset_num_entries", default=int(2**15 * 1024), # 1024 batches by default
+                     help="Number of samples per epoch for the synthetic dataset")
+
 flags.DEFINE_boolean("shuffle_batch_order", False, "Read batch in train dataset by random order", short_name="shuffle")
 
 flags.DEFINE_integer("num_numerical_features", 13,
@@ -101,8 +101,8 @@ flags.DEFINE_integer("benchmark_warmup_steps", 0, "Number of initial iterations
 
 # Machine setting flags
 flags.DEFINE_string("base_device", "cuda", "Device to run the majority of the model operations")
-flags.DEFINE_boolean("fp16", True, "If True (default) the script will use Automatic Mixed Precision")
-flags.DEFINE_float("loss_scale", 8192, "Static loss scale for Mixed Precision Training")
+flags.DEFINE_boolean("amp", False, "If True the script will use Automatic Mixed Precision")
+flags.DEFINE_float("loss_scale", 1024, "Static loss scale for Mixed Precision Training")
 
 # inference benchmark
 flags.DEFINE_list("inference_benchmark_batch_sizes", default=[1, 64, 4096],
@@ -111,91 +111,30 @@ flags.DEFINE_integer("inference_benchmark_steps", 200,
                      "Number of steps for measuring inference latency and throughput")
 
 flags.DEFINE_float("auc_threshold", None, "Stop the training after achieving this AUC")
+flags.DEFINE_boolean("optimized_mlp", True, "Use an optimized implementation of MLP from apex")
 
 
 def validate_flags():
     if FLAGS.max_table_size is not None and not FLAGS.hash_indices:
        raise ValueError('Hash indices must be True when setting a max_table_size')
 
+    if FLAGS.base_device == 'cpu':
+        if FLAGS.embedding_type in ('joint_fused', 'joint_sparse'):
+            print('WARNING: CUDA joint embeddings are not supported on CPU')
+            FLAGS.embedding_type = 'joint'
 
-def create_synthetic_datasets(train_batch_size, test_batch_size):
-    categorical_sizes = get_categorical_feature_sizes()
+        if FLAGS.amp:
+            print('WARNING: Automatic mixed precision not supported on CPU')
+            FLAGS.amp = False
 
-    dataset_train = SyntheticDataset(num_entries=4 * 10**9,
-                                     batch_size=train_batch_size,
-                                     dense_features=FLAGS.num_numerical_features,
-                                     categorical_feature_sizes=categorical_sizes)
+        if FLAGS.optimized_mlp:
+            print('WARNING: Optimized MLP is not supported on CPU')
+            FLAGS.optimized_mlp = False
 
-    dataset_test = SyntheticDataset(num_entries=100 * 10**6,
-                                    batch_size=test_batch_size,
-                                    dense_features=FLAGS.num_numerical_features,
-                                    categorical_feature_sizes=categorical_sizes)
-
-    return dataset_train, dataset_test
-
-
-def create_real_datasets(train_batch_size, test_batch_size, online_shuffle=True):
-    train_dataset = os.path.join(FLAGS.dataset, "train_data.bin")
-    test_dataset = os.path.join(FLAGS.dataset, "test_data.bin")
-    categorical_sizes = get_categorical_feature_sizes()
-
-    dataset_train = data_loader.CriteoBinDataset(
-        data_file=train_dataset,
-        batch_size=train_batch_size, subset=FLAGS.dataset_subset,
-        numerical_features=FLAGS.num_numerical_features,
-        categorical_features=len(categorical_sizes),
-        online_shuffle=online_shuffle
-    )
-
-    dataset_test = data_loader.CriteoBinDataset(
-        data_file=test_dataset, batch_size=test_batch_size,
-        numerical_features=FLAGS.num_numerical_features,
-        categorical_features=len(categorical_sizes),
-        online_shuffle = False
-    )
 
-    return dataset_train, dataset_test
+def is_data_prefetching_enabled() -> bool:
+    return FLAGS.base_device == 'cuda'
 
-def get_dataloaders(train_batch_size, test_batch_size):
-    print("Creating data loaders")
-    if FLAGS.synthetic_dataset:
-        dataset_train, dataset_test = create_synthetic_datasets(train_batch_size, test_batch_size)
-    else:
-        dataset_train, dataset_test = create_real_datasets(train_batch_size,
-                                                           test_batch_size,
-                                                           online_shuffle=FLAGS.shuffle_batch_order)
-
-    if FLAGS.shuffle_batch_order and not FLAGS.synthetic_dataset:
-        train_sampler = torch.utils.data.RandomSampler(dataset_train)
-    else:
-        train_sampler = None
-    data_loader_train = torch.utils.data.DataLoader(
-        dataset_train, batch_size=None, num_workers=0, pin_memory=False, sampler=train_sampler)
-    data_loader_test = torch.utils.data.DataLoader(
-        dataset_test, batch_size=None, num_workers=0, pin_memory=False)
-
-    return data_loader_train, data_loader_test
-
-
-def get_categorical_feature_sizes():
-    if FLAGS.synthetic_dataset:
-        feature_sizes = [int(s) for s in FLAGS.synthetic_dataset_table_sizes]
-        return feature_sizes
-
-    categorical_sizes_file = os.path.join(FLAGS.dataset, "model_size.json")
-    with open(categorical_sizes_file) as f:
-        categorical_sizes = json.load(f).values()
-
-    categorical_sizes = list(categorical_sizes)
-
-    # need to add 1 because the JSON file contains the max value not the count
-    categorical_sizes = [s + 1 for s in categorical_sizes]
-
-    if FLAGS.max_table_size is None:
-        return categorical_sizes
-
-    clipped_sizes = [min(s, FLAGS.max_table_size) for s in categorical_sizes]
-    return clipped_sizes
 
 def create_model():
     print("Creating model")
@@ -203,23 +142,30 @@ def create_model():
     model_config = {
         'top_mlp_sizes': FLAGS.top_mlp_sizes,
         'bottom_mlp_sizes': FLAGS.bottom_mlp_sizes,
+        'embedding_type': FLAGS.embedding_type,
         'embedding_dim': FLAGS.embedding_dim,
         'interaction_op': FLAGS.interaction_op,
-        'self_interaction': FLAGS.self_interaction,
-        'categorical_feature_sizes': get_categorical_feature_sizes(),
+        'categorical_feature_sizes': get_categorical_feature_sizes(FLAGS),
         'num_numerical_features': FLAGS.num_numerical_features,
         'hash_indices': FLAGS.hash_indices,
+        'use_cpp_mlp': FLAGS.optimized_mlp,
+        'fp16': FLAGS.amp,
         'base_device': FLAGS.base_device,
     }
 
     model = Dlrm.from_dict(model_config)
     print(model)
 
-    if FLAGS.load_checkpoint_path is not None:
-        model.load_state_dict(torch.load(FLAGS.load_checkpoint_path, map_location="cpu"))
-
     model.to(FLAGS.base_device)
 
+    if FLAGS.load_checkpoint_path is not None:
+        checkpoint_loader = make_serial_checkpoint_loader(
+            embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
+            device="cpu"
+        )
+        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
+        model.to(FLAGS.base_device)
+
     return model
 
 
@@ -230,25 +176,21 @@ def main(argv):
     utils.init_logging(log_path=FLAGS.log_path)
     dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')
 
-    data_loader_train, data_loader_test = get_dataloaders(train_batch_size=FLAGS.batch_size,
-                                                          test_batch_size=FLAGS.test_batch_size)
+    data_loader_train, data_loader_test = get_data_loaders(FLAGS)
 
-    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.fp16 else FLAGS.lr
+    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr
 
     model = create_model()
 
     optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr)
 
-    if FLAGS.fp16 and FLAGS.mode == 'train':
-        (model.top_mlp, model.bottom_mlp), optimizer = amp.initialize([model.top_mlp, model.bottom_mlp],
-                                                                      optimizer, opt_level="O2",
-                                                                      loss_scale=1)
-    elif FLAGS.fp16:
+    if FLAGS.amp and FLAGS.mode == 'train':
+        (model.top_model, model.bottom_model.mlp), optimizer = amp.initialize([model.top_model, model.bottom_model.mlp],
+                                                                              optimizer, opt_level="O2", loss_scale=1)
+    elif FLAGS.amp:
         model = model.half()
 
     loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")
-    loss_fn = torch.jit.trace(loss_fn.forward, (torch.rand(FLAGS.batch_size, 1).cuda(),
-                                                torch.rand(FLAGS.batch_size, 1).cuda()))
 
     if FLAGS.mode == 'test':
         loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)
@@ -265,14 +207,15 @@ def main(argv):
     if FLAGS.mode == 'inference_benchmark':
         results = {}
 
-        if FLAGS.fp16:
+        if FLAGS.amp:
             # can use pure FP16 for inference
             model = model.half()
 
         for batch_size in FLAGS.inference_benchmark_batch_sizes:
             batch_size = int(batch_size)
-            _, benchmark_data_loader = get_dataloaders(train_batch_size=batch_size,
-                                                       test_batch_size=batch_size)
+            FLAGS.test_batch_size = batch_size
+
+            _, benchmark_data_loader = get_data_loaders(FLAGS)
 
             latencies = inference_benchmark(model=model, data_loader=benchmark_data_loader,
                                             num_batches=FLAGS.inference_benchmark_steps)
@@ -293,12 +236,14 @@ def main(argv):
         train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled_lr)
 
 
-def maybe_save_checkpoint(model, path):
+def maybe_save_checkpoint(checkpoint_writer: SerialCheckpointWriter, model, path):
     if path is None:
         return
 
+    print(f'Saving a checkpoint to {path}')
+
     begin = time()
-    torch.save(model.state_dict(), path)
+    checkpoint_writer.save_checkpoint(model, path)
     end = time()
     print(f'Checkpoint saving took {end-begin:,.2f} [s]')
 
@@ -314,35 +259,43 @@ def train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled
         data_loader_test (torch.utils.data.DataLoader):
     """
     model.train()
+    prefetching_enabled = is_data_prefetching_enabled()
     base_device = FLAGS.base_device
     print_freq = FLAGS.print_freq
     steps_per_epoch = len(data_loader_train)
 
-    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch
+    checkpoint_writer = make_serial_checkpoint_writer(
+        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
+        config=FLAGS.flag_values_dict()
+    )
+
+    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1
 
     metric_logger = utils.MetricLogger(delimiter="  ")
-    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
-    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.6f}'))
+    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
 
+    if prefetching_enabled:
+        data_stream = torch.cuda.Stream()
+
     timer = utils.StepTimer()
 
     best_auc = 0
     best_epoch = 0
     start_time = time()
-    for epoch in range(FLAGS.epochs):
 
-        batch_iter = iter(data_loader_train)
-        for step in range(len(data_loader_train)):
-            timer.click()
+    timer.click()
 
-            global_step = steps_per_epoch * epoch + step
+    for epoch in range(FLAGS.epochs):
+        input_pipeline = iter(data_loader_train)
 
-            numerical_features, categorical_features, click = next(batch_iter)
+        if prefetching_enabled:
+            input_pipeline = prefetcher(input_pipeline, data_stream)
 
-            categorical_features = categorical_features.to(base_device).to(torch.long)
-            numerical_features = numerical_features.to(base_device)
-            click = click.to(base_device).to(torch.float32)
+        for step, batch in enumerate(input_pipeline):
+            global_step = steps_per_epoch * epoch + step
+            numerical_features, categorical_features, click = batch
 
             utils.lr_step(optimizer, num_warmup_iter=FLAGS.warmup_steps, current_step=global_step + 1,
                           base_lr=scaled_lr, warmup_factor=FLAGS.warmup_factor,
@@ -352,37 +305,43 @@ def train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled
                 print(F"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                 break
 
+            if prefetching_enabled:
+                torch.cuda.synchronize()
+
             output = model(numerical_features, categorical_features).squeeze().float()
 
             loss = loss_fn(output, click.squeeze())
 
-            optimizer.zero_grad()
-            if FLAGS.fp16:
+            # Setting grad to None is faster than zero_grad()
+            for param_group in optimizer.param_groups:
+                for param in param_group['params']:
+                    param.grad = None
+
+            if FLAGS.amp:
                 loss *= FLAGS.loss_scale
                 with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
             else:
                 loss.backward()
-            optimizer.step()
 
-            loss_value = loss.item()
+            optimizer.step()
 
-            if timer.measured is None:
-                # first iteration, no step time etc. to print
-                continue
+            if step % print_freq == 0 and step > 0:
+                loss_value = loss.item()
 
+                timer.click()
 
-            if global_step < FLAGS.benchmark_warmup_steps:
-                metric_logger.update(
-                    loss=loss_value, lr=optimizer.param_groups[0]["lr"])
-            else:
-                unscale_factor = FLAGS.loss_scale if FLAGS.fp16 else 1
-                metric_logger.update(
-                     loss=loss_value / unscale_factor, step_time=timer.measured,
-                     lr=optimizer.param_groups[0]["lr"] * unscale_factor
-                )
+                if global_step < FLAGS.benchmark_warmup_steps:
+                    metric_logger.update(
+                        loss=loss_value, lr=optimizer.param_groups[0]["lr"])
+                else:
+                    unscale_factor = FLAGS.loss_scale if FLAGS.amp else 1
+                    metric_logger.update(
+                        loss=loss_value / unscale_factor,
+                        step_time=timer.measured / FLAGS.print_freq,
+                        lr=optimizer.param_groups[0]["lr"] * unscale_factor
+                    )
 
-            if step % print_freq == 0 and step > 0:
                 if global_step < FLAGS.benchmark_warmup_steps:
                     print(F'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]')
                     continue
@@ -391,14 +350,15 @@ def train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled
                 metric_logger.print(
                     header=F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}")
 
-            if (global_step + 1) % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
+            if (global_step % test_freq == 0 and global_step > 0 and
+                    global_step / steps_per_epoch >= FLAGS.test_after):
                 loss, auc, test_step_time = evaluate(model, loss_fn, data_loader_test)
                 print(F"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}")
 
                 if auc > best_auc:
                     best_auc = auc
                     best_epoch = epoch + ((step + 1) / steps_per_epoch)
-                    maybe_save_checkpoint(model, FLAGS.save_checkpoint_path)
+                    maybe_save_checkpoint(checkpoint_writer, model, FLAGS.save_checkpoint_path)
 
                 if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                     stop_time = time()
@@ -408,6 +368,12 @@ def train(model, loss_fn, optimizer, data_loader_train, data_loader_test, scaled
                           F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
                     return
 
+    stop_time = time()
+    run_time_s = int(stop_time - start_time)
+
+    print(F"Finished training in {run_time_s}s. "
+          F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s.")
+
     avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg
 
     results = {'best_auc' : best_auc,
@@ -430,31 +396,36 @@ def evaluate(model, loss_fn, data_loader):
         data_loader (torch.utils.data.DataLoader):
     """
     model.eval()
-    base_device = FLAGS.base_device
     print_freq = FLAGS.print_freq
+    prefetching_enabled = is_data_prefetching_enabled()
 
     steps_per_epoch = len(data_loader)
     metric_logger = utils.MetricLogger(delimiter="  ")
-    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
-    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
+    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
+    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
+
+    if prefetching_enabled:
+        data_stream = torch.cuda.Stream()
+
     with torch.no_grad():
         y_true = []
         y_score = []
 
         timer = utils.StepTimer()
-        batch_iter = iter(data_loader)
-
         timer.click()
-        for step in range(len(data_loader)):
-            numerical_features, categorical_features, click = next(batch_iter)
 
-            categorical_features = categorical_features.to(base_device).to(torch.long)
-            numerical_features = numerical_features.to(base_device)
-            click = click.to(torch.float32).to(base_device)
+        input_pipeline = iter(data_loader)
 
-            if FLAGS.fp16:
+        if prefetching_enabled:
+            input_pipeline = prefetcher(input_pipeline, data_stream)
+
+        for step, (numerical_features, categorical_features, click) in enumerate(input_pipeline):
+            if FLAGS.amp:
                 numerical_features = numerical_features.half()
 
+            if prefetching_enabled:
+                torch.cuda.synchronize()
+
             output = model(numerical_features, categorical_features).squeeze()
 
             loss = loss_fn(output, click)
@@ -469,9 +440,12 @@ def evaluate(model, loss_fn, data_loader):
                 if step % print_freq == 0 and step > 0:
                     metric_logger.print(header=F"Test: [{step}/{steps_per_epoch}]")
 
-        y_true = torch.cat(y_true).cpu().numpy()
-        y_score = torch.cat(y_score).cpu().numpy()
-        auc = roc_auc_score(y_true=y_true, y_score=y_score)
+        y_true = torch.cat(y_true)
+        y_score = torch.cat(y_score)
+
+        before_auc_timestamp = time()
+        auc = utils.roc_auc_score(y_true=y_true, y_score=y_score)
+        print(f'AUC computation took: {time() - before_auc_timestamp:.2f} [s]')
 
     model.train()
 
@@ -491,7 +465,7 @@ def inference_benchmark(model, data_loader, num_batches=100):
             step_start_time = time()
 
             numerical_features = numerical_features.to(base_device)
-            if FLAGS.fp16:
+            if FLAGS.amp:
                 numerical_features = numerical_features.half()
 
             categorical_features = categorical_features.to(device=base_device, dtype=torch.int64)

+ 26 - 0
PyTorch/Recommendation/DLRM/dlrm/scripts/prepare_synthetic_dataset.py

@@ -0,0 +1,26 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dlrm.data.factories import SyntheticDiskDatasetFactory
+from dlrm.scripts.main import FLAGS
+from absl import app
+
+
+def main(argv):
+    dataset_factory = SyntheticDiskDatasetFactory(FLAGS)
+    dataset_factory.create_datasets()
+
+
+if __name__ == '__main__':
+    app.run(main)

+ 105 - 78
PyTorch/Recommendation/DLRM/dlrm/scripts/utils.py

@@ -13,16 +13,16 @@
 # limitations under the License.
 
 
-from collections import defaultdict, deque
-import datetime
-import time
-import torch
-import torch.distributed as dist
-
 import errno
 import os
+import time
+from collections import defaultdict, deque
 
 import dllogger
+import torch
+import torch.distributed as dist
+
+from dlrm.utils.distributed import is_dist_avail_and_initialized
 
 
 class SmoothedValue(object):
@@ -172,7 +172,6 @@ def lr_step(optim, num_warmup_iter, current_step, base_lr, warmup_factor, decay_
         param_group['lr'] = new_lr
 
 
-
 def mkdir(path):
     try:
         os.makedirs(path)
@@ -181,45 +180,6 @@ def mkdir(path):
             raise
 
 
-def setup_for_distributed(is_master):
-    """
-    This function disables printing when not in master process
-    """
-    import builtins as __builtin__
-    builtin_print = __builtin__.print
-
-    def print(*args, **kwargs):
-        force = kwargs.pop('force', False)
-        if is_master or force:
-            builtin_print(*args, **kwargs)
-
-    __builtin__.print = print
-
-
-def is_dist_avail_and_initialized():
-    if not dist.is_available():
-        return False
-    if not dist.is_initialized():
-        return False
-    return True
-
-
-def get_world_size():
-    if not is_dist_avail_and_initialized():
-        return 1
-    return dist.get_world_size()
-
-
-def get_rank():
-    if not is_dist_avail_and_initialized():
-        return 0
-    return dist.get_rank()
-
-
-def is_main_process():
-    return get_rank() == 0
-
-
 def init_logging(log_path):
     json_backend = dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                               filename=log_path)
@@ -233,37 +193,6 @@ def init_logging(log_path):
     dllogger.init(backends=[json_backend, stdout_backend])
 
 
-def save_on_master(*args, **kwargs):
-    if is_main_process():
-        torch.save(*args, **kwargs)
-
-
-def init_distributed_mode(args):
-    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-        args.rank = int(os.environ["RANK"])
-        args.world_size = int(os.environ['WORLD_SIZE'])
-        args.gpu = int(os.environ['LOCAL_RANK'])
-    elif 'SLURM_PROCID' in os.environ:
-        args.rank = int(os.environ['SLURM_PROCID'])
-        args.gpu = args.rank % torch.cuda.device_count()
-    elif hasattr(args, "rank"):
-        pass
-    else:
-        print('Not using distributed mode')
-        args.distributed = False
-        return
-
-    args.distributed = True
-
-    torch.cuda.set_device(args.gpu)
-    args.dist_backend = 'nccl'
-    print('| distributed init (rank {}): {}'.format(
-        args.rank, args.dist_url), flush=True)
-    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
-                                         world_size=args.world_size, rank=args.rank)
-    setup_for_distributed(args.rank == 0)
-
-
 class StepTimer():
     def __init__(self):
         self._previous = None
@@ -275,4 +204,102 @@ class StepTimer():
         self._new = time.time()
 
         if self._previous is not None:
-            self.measured = self._new - self._previous
+            self.measured = self._new - self._previous
+
+
+class LearningRateScheduler:
+    """Polynomial learning rate decay for multiple optimizers and multiple param groups
+
+    Args:
+        optimizers (list): optimizers for which to apply the learning rate changes
+        base_lrs (list): a nested list of base_lrs to use for each param_group of each optimizer
+        warmup_steps (int): number of linear warmup steps to perform at the beginning of training
+        warmup_factor (int)
+        decay_steps (int): number of steps over which to apply poly LR decay from base_lr to 0
+        decay_start_step (int): the optimization step at which to start decaying the learning rate
+            if None will start the decay immediately after
+        decay_power (float): polynomial learning rate decay power
+        end_lr_factor (float): for each optimizer and param group:
+            lr = max(current_lr_factor, end_lr_factor) * base_lr
+
+    Example:
+        lr_scheduler = LearningRateScheduler(optimizers=[optimizer], base_lrs=[[lr]],
+                                             warmup_steps=100, warmup_factor=0,
+                                             decay_start_step=1000, decay_steps=2000,
+                                             decay_power=2, end_lr_factor=1e-6)
+
+        for batch in data_loader:
+            lr_scheduler.step()
+            # foward, backward, weight update
+    """
+    def __init__(self, optimizers, base_lrs, warmup_steps, warmup_factor,
+                 decay_steps, decay_start_step, decay_power=2, end_lr_factor=0):
+        self.current_step = 0
+        self.optimizers = optimizers
+        self.base_lrs = base_lrs
+        self.warmup_steps = warmup_steps
+        self.warmup_factor = warmup_factor
+        self.decay_steps = decay_steps
+        self.decay_start_step = decay_start_step
+        self.decay_power = decay_power
+        self.end_lr_factor = end_lr_factor
+        self.decay_end_step = self.decay_start_step + self.decay_steps
+
+        if self.decay_start_step < self.warmup_steps:
+            raise ValueError('Learning rate warmup must finish before decay starts')
+
+    def _compute_lr_factor(self):
+        lr_factor = 1
+
+        if self.current_step <= self.warmup_steps:
+            warmup_step = 1 / (self.warmup_steps * (2 ** self.warmup_factor))
+            lr_factor = 1 - (self.warmup_steps - self.current_step) * warmup_step
+        elif self.decay_start_step < self.current_step <= self.decay_end_step:
+            lr_factor = ((self.decay_end_step - self.current_step) / self.decay_steps) ** self.decay_power
+            lr_factor = max(lr_factor, self.end_lr_factor)
+        elif self.current_step > self.decay_end_step:
+            lr_factor = self.end_lr_factor
+
+        return lr_factor
+
+    def step(self):
+        self.current_step += 1
+        lr_factor = self._compute_lr_factor()
+
+        for optim, base_lrs in zip(self.optimizers, self.base_lrs):
+            for group_id, base_lr in enumerate(base_lrs):
+                optim.param_groups[group_id]['lr'] = base_lr * lr_factor
+
+
+def roc_auc_score(y_true, y_score):
+    """ROC AUC score in PyTorch
+
+    Args:
+        y_true (Tensor):
+        y_score (Tensor):
+    """
+    device = y_true.device
+    y_true.squeeze_()
+    y_score.squeeze_()
+    if y_true.shape != y_score.shape:
+        raise TypeError(F"Shape of y_true and y_score must match. Got {y_true.shape()} and {y_score.shape()}.")
+
+    desc_score_indices = torch.argsort(y_score, descending=True)
+    y_score = y_score[desc_score_indices]
+    y_true = y_true[desc_score_indices]
+
+    distinct_value_indices = torch.nonzero(y_score[1:] - y_score[:-1]).squeeze()
+    threshold_idxs = torch.cat([distinct_value_indices, torch.tensor([y_true.numel() - 1], device=device)])
+
+    tps = torch.cumsum(y_true, dim=0)[threshold_idxs]
+    fps = 1 + threshold_idxs - tps
+
+    tps = torch.cat([torch.zeros(1, device=device), tps])
+    fps = torch.cat([torch.zeros(1, device=device), fps])
+
+    fpr = fps / fps[-1]
+    tpr = tps / tps[-1]
+
+    area = torch.trapz(tpr, fpr).item()
+
+    return area

+ 0 - 0
PyTorch/Recommendation/DLRM/dlrm/utils/__init__.py


+ 139 - 0
PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing.py

@@ -0,0 +1,139 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+from os.path import join
+from typing import Dict, Any, Optional, Sequence
+
+
class DlrmCheckpointNavigator:
    """Single source of truth for the file layout of a DLRM checkpoint directory."""

    @property
    def bottom_mlp_path(self) -> str:
        """File name holding the bottom MLP state dict."""
        return "bottom_model.mlp.pt"

    @property
    def top_model_path(self) -> str:
        """File name holding the top model state dict."""
        return "top_model.pt"

    @property
    def metadata_path(self) -> str:
        """File name holding checkpoint metadata (config, epoch, step)."""
        return "metadata.pt"

    def embedding_path(self, embedding_index: int) -> str:
        """File name holding the weights of one embedding table."""
        return "bottom_model.embeddings.{}.pt".format(embedding_index)
+
+
class DistributedCheckpointWriter:
    """Saves a DLRM checkpoint from a hybrid-parallel (multi-GPU) run.

    Each rank writes only the embedding tables it owns; the rank hosting the
    bottom MLP additionally saves its state, and the main process saves the
    top model and the checkpoint metadata.
    """

    def __init__(
        self,
        device_mapping: Dict[str, Any],
        config: Dict[str, Any],
        rank: int,
        main_process: bool
    ):
        # device_mapping maps the bottom MLP and embedding-table indices to ranks.
        self._device_mapping = device_mapping
        self._config = config
        self._main_process = main_process
        # This rank saves the bottom MLP only if it is the rank hosting it.
        self._has_bottom_mlp = rank == device_mapping["bottom_mlp"]
        # Embedding tables owned (and therefore saved) by this rank.
        self._embedding_indices = device_mapping["embedding"][rank]
        self._navigator = DlrmCheckpointNavigator()

    def save_checkpoint(
        self,
        model,
        checkpoint_path: str,
        epoch: Optional[int] = None,
        step: Optional[int] = None
    ):
        """Write this rank's model parts into `checkpoint_path` (a directory).

        Blocks on a collective barrier so the checkpoint is complete on return.
        """
        os.makedirs(checkpoint_path, exist_ok=True)

        self._save_embeddings_weights(checkpoint_path, model)

        if self._has_bottom_mlp:
            torch.save(model.bottom_model.mlp.state_dict(), join(checkpoint_path, self._navigator.bottom_mlp_path))

        if self._main_process:
            torch.save(model.top_model.state_dict(), join(checkpoint_path, self._navigator.top_model_path))
            self._save_metadata(checkpoint_path, epoch, step)

        # Wait until every rank has finished writing its shard.
        torch.distributed.barrier()

    def _save_embeddings_weights(self, checkpoint_path: str, model):
        # One file per embedding table; indices are global, weights are local to this rank.
        for embedding_index, weight in zip(self._embedding_indices, model.bottom_model.embeddings.weights):
            torch.save({"weight": weight}, join(checkpoint_path, self._navigator.embedding_path(embedding_index)))

    def _save_metadata(self, checkpoint_path, epoch, step):
        # Metadata makes the checkpoint self-describing for later loading.
        torch.save({
            "config": self._config,
            "device_mapping": self._device_mapping,
            "epoch": epoch,
            "step": step
        }, join(checkpoint_path, self._navigator.metadata_path))
+
+
class DistributedCheckpointLoader:
    """Loads a DLRM checkpoint into a hybrid-parallel (multi-GPU) run.

    Each rank restores only the embedding tables it owns; the rank hosting the
    bottom MLP also restores it, and every rank restores the top model.
    """

    def __init__(self, device_mapping: Dict[str, Any], rank: int):
        self._device_mapping = device_mapping
        # This rank loads the bottom MLP only if it is the rank hosting it.
        self._has_bottom_mlp = rank == device_mapping["bottom_mlp"]
        # Embedding tables owned (and therefore loaded) by this rank.
        self._embedding_indices = device_mapping["embedding"][rank]
        self._navigator = DlrmCheckpointNavigator()

    def load_checkpoint(self, model, checkpoint_path: str):
        """Restore this rank's model parts from the `checkpoint_path` directory."""
        top_model_state = self._load(checkpoint_path, self._navigator.top_model_path)
        model.top_model.load_state_dict(top_model_state)

        if self._has_bottom_mlp:
            bottom_mlp_state = self._load(checkpoint_path, self._navigator.bottom_mlp_path)
            model.bottom_model.mlp.load_state_dict(bottom_mlp_state)

        # Generator: keeps at most one table's weights in host memory at a time.
        embedding_weights = (self._load(checkpoint_path, self._navigator.embedding_path(index))["weight"]
                             for index in self._embedding_indices)
        model.bottom_model.embeddings.load_weights(embedding_weights)

        # Wait until every rank has finished restoring its shard.
        torch.distributed.barrier()

    def _load(self, checkpoint_path: str, state_path: str):
        return torch.load(join(checkpoint_path, state_path), map_location="cpu")  # loading to CUDA causes OOM errors
+
+
class CpuCheckpointLoader:
    """Restores a full DLRM checkpoint in a single (non-distributed) process.

    Loads the top model, the bottom MLP and the requested embedding tables,
    stripping DistributedDataParallel's "module." key prefix where present.
    """

    def __init__(self, embedding_indices: Sequence[int]):
        self._embedding_indices = embedding_indices
        self._navigator = DlrmCheckpointNavigator()

    def load_checkpoint(self, model, checkpoint_path: str):
        """Restore every model part from `checkpoint_path` into `model`."""
        model.top_model.load_state_dict(
            self._load(checkpoint_path, self._navigator.top_model_path))
        model.bottom_model.mlp.load_state_dict(
            self._load(checkpoint_path, self._navigator.bottom_mlp_path))

        # Generator: one embedding table in memory at a time.
        weights = (
            self._load(checkpoint_path, self._navigator.embedding_path(index))["weight"]
            for index in self._embedding_indices
        )
        model.bottom_model.embeddings.load_weights(weights)

    def _load(self, checkpoint_path: str, state_path: str):
        state = torch.load(join(checkpoint_path, state_path), map_location="cpu")
        return {self._strip_key(key): value for key, value in state.items()}

    def _strip_key(self, key: str):
        # DistributedDataParallel prefixes parameter names with "module.".
        prefix = "module."
        return key[len(prefix):] if key.startswith(prefix) else key

+ 0 - 0
PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/__init__.py


+ 105 - 0
PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/distributed.py

@@ -0,0 +1,105 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any, Optional
+
+import torch
+
+from dlrm.utils.checkpointing.model import DlrmCheckpointWriter, DlrmCheckpointLoader
+
+
class DistributedCheckpointWriter:
    """Coordinates checkpoint saving across ranks in a hybrid-parallel run.

    Delegates serialization to a `DlrmCheckpointWriter`; this class only
    decides which model parts the current rank is responsible for.
    """

    def __init__(
        self,
        writer: DlrmCheckpointWriter,
        device_mapping: Dict[str, Any],
        rank: int,
        main_process: bool
    ):
        self._device_mapping = device_mapping
        self._main_process = main_process
        # Only the rank hosting the bottom MLP saves its weights.
        self._has_bottom_mlp = rank == device_mapping["bottom_mlp"]
        self._writer = writer

    def save_checkpoint(
        self,
        model,
        checkpoint_path: str,
        epoch: Optional[int] = None,
        step: Optional[int] = None
    ):
        """Write this rank's model parts; returns only after all ranks finish."""
        self._writer.save_embeddings(checkpoint_path, model)

        if self._has_bottom_mlp:
            self._writer.save_bottom_mlp(checkpoint_path, model)

        if self._main_process:
            self._writer.save_top_model(checkpoint_path, model)
            self._save_metadata(checkpoint_path, epoch, step)

        # Make sure the checkpoint is complete before any rank moves on.
        torch.distributed.barrier()

    def _save_metadata(self, checkpoint_path, epoch, step):
        # Device mapping is stored so the checkpoint can be re-sharded on load.
        self._writer.save_metadata(checkpoint_path, {
            "device_mapping": self._device_mapping,
            "epoch": epoch,
            "step": step
        })
+
+
class DistributedCheckpointLoader:
    """Coordinates checkpoint loading across ranks in a hybrid-parallel run.

    Delegates deserialization to a `DlrmCheckpointLoader`; this class only
    decides which model parts the current rank restores.
    """

    def __init__(self, loader: DlrmCheckpointLoader, device_mapping: Dict[str, Any], rank: int):
        # Only the rank hosting the bottom MLP loads its weights.
        self._has_bottom_mlp = rank == device_mapping["bottom_mlp"]
        self._loader = loader

    def load_checkpoint(self, model, checkpoint_path: str):
        """Restore this rank's model parts; returns only after all ranks finish."""
        self._loader.load_top_model(checkpoint_path, model)

        if self._has_bottom_mlp:
            self._loader.load_bottom_mlp(checkpoint_path, model)

        self._loader.load_embeddings(checkpoint_path, model)
        torch.distributed.barrier()
+
+
def make_distributed_checkpoint_loader(device_mapping, rank: int, device: str = "cpu") -> DistributedCheckpointLoader:
    """Build a DistributedCheckpointLoader for this rank's shard of embedding tables."""
    model_loader = DlrmCheckpointLoader(
        embedding_indices=device_mapping["embedding"][rank],
        device=device,
    )
    return DistributedCheckpointLoader(
        loader=model_loader,
        device_mapping=device_mapping,
        rank=rank
    )
+
+
def make_distributed_checkpoint_writer(
        device_mapping,
        rank: int,
        is_main_process: bool,
        config: Dict[str, Any],
) -> DistributedCheckpointWriter:
    """Build a DistributedCheckpointWriter for this rank's shard of embedding tables."""
    model_writer = DlrmCheckpointWriter(
        embedding_indices=device_mapping["embedding"][rank],
        config=config
    )
    return DistributedCheckpointWriter(
        writer=model_writer,
        device_mapping=device_mapping,
        rank=rank,
        main_process=is_main_process
    )

+ 132 - 0
PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/model.py

@@ -0,0 +1,132 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+from os.path import join
+from typing import Sequence, Any, Dict
+
+import torch
+
+_BOTTOM_MLP_FILE = "bottom_model.mlp.pt"
+_TOP_MLP_FILE = "top_model.mlp.pt"
+_TOP_OUT_FILE = "top_model.out.pt"
+_EMBEDDING_METADATA_FILE = "embeddings.metadata.pt"
+_METADATA_FILE = "metadata.pt"
+
+
+def _get_embedding_file(embedding_index: int) -> str:
+    return f"bottom_model.embeddings.{embedding_index}.bin"
+
+
+def _get_embedding_meta_file(embedding_index: int) -> str:
+    return f"embeddings.{embedding_index}.meta.pt"
+
+
class DlrmCheckpointWriter:
    """
    Class responsible for saving checkpoints of DLRM model parts.

    Depends on `dlrm.nn.embeddings.Embeddings` and `dlrm.nn.mlps.AbstractMlp` interfaces
    (for handling multiple model configurations)
    """
    def __init__(self, embedding_indices: Sequence[int], config: Dict[str, Any]):
        # Global indices of the embedding tables this writer saves.
        self._embedding_indices = embedding_indices
        # Training configuration, embedded into the checkpoint metadata.
        self._config = config

    def save_embeddings(self, checkpoint_path: str, model):
        """Save each owned embedding table as raw float32 bytes plus a shape file."""
        self._ensure_directory(checkpoint_path)
        for embedding_index, weight in zip(self._embedding_indices, model.bottom_model.embeddings.weights):
            self._save_as_bytes(weight.data, join(checkpoint_path, _get_embedding_file(embedding_index)))
            # Shape is stored separately so the raw binary can be reshaped on load.
            torch.save({"shape": weight.shape}, join(checkpoint_path, _get_embedding_meta_file(embedding_index)))

    def save_bottom_mlp(self, checkpoint_path: str, model):
        """Save the bottom MLP weights and biases (upcast to float32)."""
        self._ensure_directory(checkpoint_path)
        torch.save(self._mlp_state(model.bottom_model.mlp), join(checkpoint_path, _BOTTOM_MLP_FILE))

    def save_top_model(self, checkpoint_path: str, model):
        """Save the top MLP and the final output layer."""
        self._ensure_directory(checkpoint_path)
        # DistributedDataParallel wraps top_model under "module" attribute
        top_model = model.top_model.module if hasattr(model.top_model, 'module') else model.top_model

        torch.save(self._mlp_state(top_model.mlp), join(checkpoint_path, _TOP_MLP_FILE))
        torch.save(top_model.out.state_dict(), join(checkpoint_path, _TOP_OUT_FILE))

    def save_metadata(self, checkpoint_path: str, data: Dict[str, Any]):
        """Save checkpoint metadata together with the training config."""
        self._ensure_directory(checkpoint_path)
        torch.save({"data": data, "config": self._config}, join(checkpoint_path, _METADATA_FILE))

    def _ensure_directory(self, checkpoint_path: str):
        # Idempotent: safe when several save_* calls target the same directory.
        os.makedirs(checkpoint_path, exist_ok=True)

    def _mlp_state(self, mlp):
        # Cast to float32 so checkpoints are independent of the training
        # precision (the model may hold half-precision parameters).
        return {
            "weights": [x.to(torch.float32) for x in mlp.weights],
            "biases": [x.to(torch.float32) for x in mlp.biases]
        }

    def _save_as_bytes(self, tensor: torch.Tensor, path: str):
        # Raw float32 byte dump; the tensor shape is saved by the caller.
        with open(path, "wb+") as file:
            file.write(tensor.cpu().numpy().astype(np.float32).tobytes())
+
+
class DlrmCheckpointLoader:
    """
    Class responsible for loading checkpoints of DLRM model parts.

    Depends on `dlrm.nn.embeddings.Embeddings` and `dlrm.nn.mlps.AbstractMlp` interfaces
    (for handling multiple model configurations)
    """
    def __init__(self, embedding_indices: Sequence[int], device: str = "cpu"):
        # Global indices of the embedding tables this loader restores.
        self._embedding_indices = embedding_indices
        # Target device for the restored tensors.
        self._device = device

    def load_embeddings(self, checkpoint_path: str, model):
        """Restore each owned embedding table from its raw-bytes file."""
        # Generator: keeps at most one table's weights in memory at a time.
        embedding_weights = (self._load_from_bytes(join(checkpoint_path, _get_embedding_file(index)),
                                                   self._get_embedding_shape(checkpoint_path, index))
                             for index in self._embedding_indices)
        model.bottom_model.embeddings.load_weights(embedding_weights)

    def load_bottom_mlp(self, checkpoint_path: str, model):
        """Restore the bottom MLP weights and biases."""
        bottom_mlp_state = self._load(checkpoint_path, _BOTTOM_MLP_FILE)
        model.bottom_model.mlp.load_state(bottom_mlp_state["weights"], bottom_mlp_state["biases"])

    def load_top_model(self, checkpoint_path: str, model):
        """Restore the top MLP and the final output layer."""
        # DistributedDataParallel wraps top_model under "module" attribute
        top_model = model.top_model.module if hasattr(model.top_model, 'module') else model.top_model
        top_mlp_state = self._load(checkpoint_path, _TOP_MLP_FILE)
        top_model.mlp.load_state(top_mlp_state["weights"], top_mlp_state["biases"])

        top_out_state = self._load(checkpoint_path, _TOP_OUT_FILE)
        top_model.out.load_state_dict(top_out_state)

    def _load(self, checkpoint_path: str, state_path: str):
        data = torch.load(join(checkpoint_path, state_path), map_location=self._device)
        return {self._strip_key(key): value for key, value in data.items()}

    def _strip_key(self, key: str):
        # DistributedDataParallel wraps top_model under "module" attribute
        prefix = "module."
        if key.startswith(prefix):
            return key[len(prefix):]
        return key

    def _load_from_bytes(self, path: str, shape) -> torch.Tensor:
        # Read the raw float32 dump and reshape with the stored metadata shape.
        with open(path, "rb") as file:
            array = np.frombuffer(file.read(), dtype=np.float32).reshape(*shape)
            return torch.from_numpy(array).to(self._device)

    def _get_embedding_shape(self, checkpoint_path: str, index: int) -> tuple:
        embedding_meta = torch.load(join(checkpoint_path, _get_embedding_meta_file(index)))
        return embedding_meta["shape"]

+ 66 - 0
PyTorch/Recommendation/DLRM/dlrm/utils/checkpointing/serial.py

@@ -0,0 +1,66 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Dict, Any
+
+from dlrm.utils.checkpointing.model import DlrmCheckpointWriter, DlrmCheckpointLoader
+
+
class SerialCheckpointWriter:
    """Saves a complete DLRM checkpoint from a single (non-distributed) process."""

    def __init__(self, writer: DlrmCheckpointWriter):
        self._writer = writer

    def save_checkpoint(
        self,
        model,
        checkpoint_path: str,
        epoch: Optional[int] = None,
        step: Optional[int] = None
    ):
        """Persist embeddings, bottom MLP, top model and metadata to `checkpoint_path`."""
        for save_part in (self._writer.save_embeddings,
                          self._writer.save_bottom_mlp,
                          self._writer.save_top_model):
            save_part(checkpoint_path, model)
        self._writer.save_metadata(checkpoint_path, {"epoch": epoch, "step": step})
+
+
class SerialCheckpointLoader:
    """Restores a complete DLRM checkpoint in a single (non-distributed) process."""

    def __init__(self, loader: DlrmCheckpointLoader):
        self._loader = loader

    def load_checkpoint(self, model, checkpoint_path: str):
        """Restore top model, bottom MLP and embeddings from `checkpoint_path`."""
        for load_part in (self._loader.load_top_model,
                          self._loader.load_bottom_mlp,
                          self._loader.load_embeddings):
            load_part(checkpoint_path, model)
+
+
def make_serial_checkpoint_loader(embedding_indices: Sequence[int], device: str) -> SerialCheckpointLoader:
    """Build a SerialCheckpointLoader restoring the given embedding tables onto `device`."""
    model_loader = DlrmCheckpointLoader(
        embedding_indices=embedding_indices,
        device=device,
    )
    return SerialCheckpointLoader(model_loader)
+
+
def make_serial_checkpoint_writer(
        embedding_indices: Sequence[int],
        config: Dict[str, Any],
) -> SerialCheckpointWriter:
    """Build a SerialCheckpointWriter saving the given embedding tables."""
    model_writer = DlrmCheckpointWriter(
        embedding_indices=embedding_indices,
        config=config
    )
    return SerialCheckpointWriter(model_writer)

+ 153 - 0
PyTorch/Recommendation/DLRM/dlrm/utils/distributed.py

@@ -0,0 +1,153 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from functools import reduce
+from itertools import combinations_with_replacement
+from typing import MutableSequence, Any, Sequence, List
+
+import torch
+import torch.distributed as dist
+
+
def setup_distributed_print(enable):
    """Monkey-patch builtins.print so only the chosen process prints.

    When `enable` is False, calls to print() become no-ops unless the caller
    passes the extra keyword `force=True`.
    """
    import builtins
    original_print = builtins.print

    def filtered_print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if enable or force:
            original_print(*args, **kwargs)

    builtins.print = filtered_print
+
+
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
+
+
def get_world_size():
    """World size of the default process group, or 1 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
+
+
def is_distributed() -> bool:
    """True when running with more than one process."""
    world_size = get_world_size()
    return world_size > 1
+
+
def get_rank():
    """Global rank of this process, or 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
+
+
def get_local_rank():
    """Per-node rank from the LOCAL_RANK env var, or 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return int(os.environ['LOCAL_RANK'])
    return 0
+
+
def is_main_process():
    """True only on global rank 0."""
    rank = get_rank()
    return rank == 0
+
+
def init_distributed_mode(backend="nccl", use_gpu=True):
    """Initialize torch.distributed from launcher or OpenMPI environment variables.

    Args:
        backend (str): process-group backend, "nccl" by default.
        use_gpu (bool): pin this process to its local GPU before init.

    Returns:
        (rank, world_size, gpu) on success, or (None, 1, None) when no
        distributed environment variables are present.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched with torch.distributed.launch / torchrun.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'OMPI_COMM_WORLD_RANK' in os.environ and 'OMPI_COMM_WORLD_SIZE' in os.environ:
        # Launched with mpirun (OpenMPI). The rendezvous address is hard-coded
        # to localhost, so this path presumably targets single-node jobs only.
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
    else:
        print('Not using distributed mode')
        return None, 1, None

    if use_gpu:
        # Bind the process to its GPU before creating the process group.
        torch.cuda.set_device(gpu)

    print('| distributed init (rank {})'.format(rank), flush=True)
    torch.distributed.init_process_group(backend=backend, world_size=world_size, rank=rank, init_method='env://')

    return rank, world_size, gpu
+
+
def get_gpu_batch_sizes(global_batch_size: int, num_gpus: int = 4, batch_std: int = 64, divisible_by: int = 64):
    """Split a global batch size into per-GPU batch sizes.

    Considers every non-decreasing tuple of `num_gpus` sizes that lies within
    `batch_std` of the per-GPU mean, is a multiple of `divisible_by`, and sums
    to `global_batch_size`; returns the tuple with the largest product
    (the most balanced split).

    Raises:
        RuntimeError: when no such split exists.
    """
    mean_size = global_batch_size // num_gpus
    lower, upper = mean_size - batch_std, mean_size + batch_std
    candidate_sizes = (size for size in range(lower, upper + 1) if size % divisible_by == 0)

    valid_splits = [
        split for split in combinations_with_replacement(candidate_sizes, num_gpus)
        if sum(split) == global_batch_size
    ]

    if not valid_splits:
        raise RuntimeError("Could not find GPU batch sizes for a given configuration. "
                           "Please adjust global batch size or number of used GPUs.")

    def _product(split):
        result = 1
        for size in split:
            result *= size
        return result

    return max(valid_splits, key=_product)
+
+
def distribute_to_buckets(elements: MutableSequence[Any], buckets: Sequence[List[Any]], start_bucket: int = 0):
    """Round-robin `elements` (consumed from the back) into `buckets`.

    `elements` is emptied in place. Returns the index of the bucket the next
    element would have gone to, so chained calls can continue the rotation.
    """
    bucket_index = start_bucket % len(buckets)
    while elements:
        buckets[bucket_index].append(elements.pop())
        bucket_index = (bucket_index + 1) % len(buckets)
    return bucket_index
+
+
def get_criteo_device_mapping(num_gpus: int = 4, num_embeddings: int = 26, heavy_components=(0, 9, 19, 21, 20)):
    """Get device mappings for hybrid parallelism.

    The bottom MLP runs on device 0; the embedding tables are spread across all
    devices, placing the heavy (large) tables first.

    Args:
        num_gpus (int): number of GPUs to spread the model over. Default 4.
        num_embeddings (int): total number of embedding tables.
        heavy_components (tuple): indices of the large embedding tables
            (0, 9, 19, 20, 21 on Criteo, roughly 20GB each).

    Returns:
        device_mapping (dict): keys 'bottom_mlp' (rank owning the MLP),
            'embedding' (per-rank table indices) and 'vectors_per_gpu'
            (per-rank output-vector counts, the bottom MLP counting as one).
    """
    BOTTOM_MLP_MARKER = -1  # placeholder so rank 0's vector count includes the MLP
    heavy = list(heavy_components)
    regular = [index for index in range(num_embeddings) if index not in heavy]

    gpu_buckets = [[] for _ in range(num_gpus)]
    gpu_buckets[0].append(BOTTOM_MLP_MARKER)

    # Heavy tables first, starting at rank 1 so rank 0 keeps room for the MLP;
    # the remaining tables continue the same round-robin rotation.
    next_bucket = distribute_to_buckets(heavy, gpu_buckets, start_bucket=1)
    distribute_to_buckets(regular, gpu_buckets, start_bucket=next_bucket)

    vectors_per_gpu = [len(bucket) for bucket in gpu_buckets]

    gpu_buckets[0].pop(0)  # drop the bottom-MLP marker; it is not an embedding

    return {
        'bottom_mlp': 0,
        'embedding': gpu_buckets,
        'vectors_per_gpu': vectors_per_gpu,
    }

+ 12 - 4
PyTorch/Recommendation/DLRM/preproc/prepare_dataset.sh

@@ -39,9 +39,9 @@ conversion_intermediate_dir=${conversion_intermediate_dir:-'/data/dlrm/intermedi
 final_output_dir=${final_output_dir:-'/data/dlrm/binary_dataset'}
 
 
-if [ -f ${final_output_dir}/train_data.bin ] \
-   && [ -f ${final_output_dir}/val_data.bin ] \
-   && [ -f ${final_output_dir}/test_data.bin ] \
+if [ -d ${final_output_dir}/train ] \
+   && [ -d ${final_output_dir}/val ] \
+   && [ -d ${final_output_dir}/test ] \
   && [ -f ${final_output_dir}/model_size.json ]; then
 
     echo "Final conversion already done"
@@ -52,8 +52,16 @@ else
                                 --dst_dir ${final_output_dir}
 
     cp "${spark_output_path}/model_size.json" "${final_output_dir}/model_size.json"
+
+    python split_dataset.py --dataset "${final_output_dir}" --output "${final_output_dir}/split"
+    rm ${final_output_dir}/train_data.bin
+    rm ${final_output_dir}/val_data.bin
+    rm ${final_output_dir}/test_data.bin
+
+    mv ${final_output_dir}/split/* ${final_output_dir}
+    rm -rf ${final_output_dir}/split
 fi
 
 echo "Done preprocessing the Criteo Kaggle Dataset"
 echo "You can now start the training with: "
-echo "python -m dlrm.scripts.main --mode train --dataset  /data/dlrm/binary_dataset/ --model_config dlrm/config/default.json"
+echo "python -m dlrm.scripts.main --mode train --dataset  ${final_output_dir}"

+ 119 - 0
PyTorch/Recommendation/DLRM/preproc/split_dataset.py

@@ -0,0 +1,119 @@
+# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import os
+import math
+from shutil import copyfile
+
+from tqdm import tqdm
+import numpy as np
+from typing import Sequence
+
+from dlrm.data.utils import get_categorical_feature_type
+
+
def split_binary_file(
    binary_file_path: str,
    output_dir: str,
    categorical_feature_sizes: Sequence[int],
    num_numerical_features: int,
    batch_size: int,
    source_data_type: str = 'int32',
):
    """Split one monolithic Criteo binary file into per-feature binary files.

    The source stores fixed-width records of `source_data_type` fields:
    [label, numerical features..., categorical features...]. Output files:
      - numerical.bin: numerical features narrowed to float16
      - label.bin: labels as bool
      - cat_<i>.bin: i-th categorical feature, narrowed to the smallest
        integer type that fits its cardinality.

    Args:
        binary_file_path: path to the source binary file.
        output_dir: directory for the per-feature output files (must exist).
        categorical_feature_sizes: cardinality of each categorical feature.
        num_numerical_features: number of numerical (dense) features.
        batch_size: number of records processed per chunk.
        source_data_type: numpy dtype name of the source fields.
    """
    record_width = 1 + num_numerical_features + len(categorical_feature_sizes)  # label + numerical + categorical
    # np.dtype(...).itemsize replaces the np.__dict__[...]().nbytes hack.
    bytes_per_feature = np.dtype(source_data_type).itemsize
    bytes_per_entry = record_width * bytes_per_feature

    total_size = os.path.getsize(binary_file_path)
    batches_num = int(math.ceil((total_size // bytes_per_entry) / batch_size))

    cat_feature_types = [get_categorical_feature_type(cat_size) for cat_size in categorical_feature_sizes]

    file_streams = []
    try:
        input_data_f = open(binary_file_path, "rb")
        file_streams.append(input_data_f)

        numerical_f = open(os.path.join(output_dir, "numerical.bin"), "wb+")
        file_streams.append(numerical_f)

        label_f = open(os.path.join(output_dir, 'label.bin'), 'wb+')
        file_streams.append(label_f)

        categorical_fs = []
        for i in range(len(categorical_feature_sizes)):
            fs = open(os.path.join(output_dir, F'cat_{i}.bin'), 'wb+')
            categorical_fs.append(fs)
            file_streams.append(fs)

        for _ in tqdm(range(batches_num)):
            raw_data = np.frombuffer(input_data_f.read(bytes_per_entry * batch_size), dtype=np.int32)
            batch_data = raw_data.reshape(-1, record_width)

            # Numerical columns were written as raw float32 bits inside the
            # int32 stream; reinterpret, then narrow to float16.
            numerical_features = batch_data[:, 1:1 + num_numerical_features].view(dtype=np.float32)
            numerical_f.write(numerical_features.astype(np.float16).tobytes())

            label = batch_data[:, 0]
            # np.bool_ instead of np.bool: the bare alias was deprecated in
            # NumPy 1.20 and removed in 1.24, which made this line crash.
            label_f.write(label.astype(np.bool_).tobytes())

            cat_offset = num_numerical_features + 1
            for cat_idx, cat_feature_type in enumerate(cat_feature_types):
                cat_data = batch_data[:, (cat_idx + cat_offset):(cat_idx + cat_offset + 1)].astype(cat_feature_type)
                categorical_fs[cat_idx].write(cat_data.tobytes())
    finally:
        # Close every stream that was successfully opened, even on error.
        for stream in file_streams:
            stream.close()
+
+
def split_dataset(dataset_dir: str, output_dir: str, batch_size: int, numerical_features: int):
    """Split the train/test/val binary files of a dataset into per-feature files.

    Reads the categorical cardinalities from model_size.json, copies it to the
    output directory, and converts each split with `split_binary_file`.
    """
    categorical_sizes_file = os.path.join(dataset_dir, "model_size.json")
    with open(categorical_sizes_file) as f:
        categorical_sizes = list(json.load(f).values())

    os.makedirs(output_dir, exist_ok=True)
    copyfile(categorical_sizes_file, os.path.join(output_dir, "model_size.json"))

    # Same processing order as before: test, then train, then val.
    for split_name in ("test", "train", "val"):
        source_file = os.path.join(dataset_dir, f"{split_name}_data.bin")
        target_dir = os.path.join(output_dir, split_name)
        os.makedirs(target_dir, exist_ok=True)
        split_binary_file(source_file, target_dir, categorical_sizes, numerical_features, batch_size)
+
+
def _parse_args():
    # CLI for the split step; defaults match the DLRM training pipeline.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--output', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=32768)
    parser.add_argument('--numerical_features', type=int, default=13)
    return parser.parse_args()


if __name__ == '__main__':
    args = _parse_args()
    split_dataset(
        dataset_dir=args.dataset,
        output_dir=args.output,
        batch_size=args.batch_size,
        numerical_features=args.numerical_features
    )
+

+ 52 - 2
PyTorch/Recommendation/DLRM/setup.py

@@ -14,9 +14,9 @@
 
 
 import os
-import subprocess
+
 from setuptools import setup, find_packages
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 abspath = os.path.dirname(os.path.realpath(__file__))
 
@@ -28,4 +28,54 @@ setup(name="dlrm",
       description="Reimplementation of Facebook's DLRM",
       packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
       zip_safe=False,
+      ext_modules=[
+          CUDAExtension(name="dlrm.cuda_ext.fused_embedding",
+                        sources=[
+                            os.path.join(abspath, "dlrm/cuda_src/pytorch_embedding_ops.cpp"),
+                            os.path.join(abspath, "dlrm/cuda_src/gather_gpu_fused_pytorch_impl.cu")
+                        ],
+                        extra_compile_args={
+                            'cxx': [],
+                            'nvcc': ["-arch=sm_70",
+                                     '-gencode', 'arch=compute_80,code=sm_80']
+                        }),
+          CUDAExtension(name="dlrm.cuda_ext.interaction_volta",
+                        sources=[
+                            os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_volta/pytorch_ops.cpp"),
+                            os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_volta/dot_based_interact_pytorch_types.cu")
+                        ],
+                        extra_compile_args={
+                            'cxx': [],
+                            'nvcc': [
+                                '-DCUDA_HAS_FP16=1',
+                                '-D__CUDA_NO_HALF_OPERATORS__',
+                                '-D__CUDA_NO_HALF_CONVERSIONS__',
+                                '-D__CUDA_NO_HALF2_OPERATORS__',
+                                '-gencode', 'arch=compute_70,code=sm_70']
+                        }),
+          CUDAExtension(name="dlrm.cuda_ext.interaction_ampere",
+                        sources=[
+                            os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_ampere/pytorch_ops.cpp"),
+                            os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_pytorch_types.cu")
+                        ],
+                        extra_compile_args={
+                            'cxx': [],
+                            'nvcc': [
+                                '-DCUDA_HAS_FP16=1',
+                                '-D__CUDA_NO_HALF_OPERATORS__',
+                                '-D__CUDA_NO_HALF_CONVERSIONS__',
+                                '-D__CUDA_NO_HALF2_OPERATORS__',
+                                '-gencode', 'arch=compute_80,code=sm_80']
+                        }),
+          CUDAExtension(name="dlrm.cuda_ext.sparse_gather",
+                        sources=[
+                            os.path.join(abspath, "dlrm/cuda_src/sparse_gather/sparse_pytorch_ops.cpp"),
+                            os.path.join(abspath, "dlrm/cuda_src/sparse_gather/gather_gpu.cu")
+                        ],
+                        extra_compile_args={
+                            'cxx': [],
+                            'nvcc': ["-arch=sm_70",
+                                     '-gencode', 'arch=compute_80,code=sm_80']
+                        })
+      ],
       cmdclass={"build_ext": BuildExtension})

+ 9 - 8
PyTorch/Recommendation/DLRM/triton/Dockerfile

@@ -12,20 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
-FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk as trt
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.06-py3
+ARG TRITON_BASE_IMAGE=nvcr.io/nvidia/tritonserver:20.06-py3-clientsdk
+FROM ${TRITON_BASE_IMAGE} as trt
 FROM ${FROM_IMAGE_NAME}
 
 ADD requirements.txt .
 RUN pip install -r requirements.txt
 RUN pip install onnxruntime
 
-COPY --from=trt /workspace/install /workspace/install/
+COPY --from=trt /workspace/v2.0.0.clients.tar.gz ./v2.0.0.clients.tar.gz
+RUN tar -xzf v2.0.0.clients.tar.gz \
+    && pip install ./python/tritonclientutils-2.0.0-py3-none-any.whl \
+    && pip install ./python/tritonhttpclient-2.0.0-py3-none-any.whl \
+    && pip install ./python/tritongrpcclient-2.0.0-py3-none-any.whl
 
-ENV LD_LIBRARY_PATH /workspace/install/lib:${LD_LIBRARY_PATH}
-RUN ls /workspace/install/python
-RUN pip install /workspace/install/python/tensorrtserver-1.12.0-py3-none-linux_x86_64.whl
-
-ENV PYTHONPATH /workspace/dlrm
 WORKDIR /workspace/dlrm
 COPY . .
+RUN pip install --no-cache-dir -e .

+ 88 - 94
PyTorch/Recommendation/DLRM/triton/README.md

@@ -1,6 +1,6 @@
 # Deploying the DLRM model using Triton Inference Server
 
-The [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) provides a datacenter and cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or gRPC endpoint, allowing remote clients to request inferencing for any number of GPU or CPU models being managed by the server. 
+The [NVIDIA Triton Inference Server](https://github.com/NVIDIA/trtis-inference-server) provides a datacenter and cloud inferencing solution optimized for NVIDIA GPUs. The server provides an inference service via an HTTP or gRPC endpoint, allowing remote clients to request inferencing for any number of GPU or CPU models being managed by the server.
 
This folder contains instructions for deployment and an example client application to run inference on
 Triton Inference Server as well as detailed performance analysis.
@@ -28,13 +28,13 @@ container:
 
 `docker run -it --rm --gpus device=0 --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --net=host -v <PATH_TO_MODEL_REPOSITORY>:/repository dlrm-inference bash`
 
-Here `device=0,1,2,3` selects GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates location where
+Here `--gpus '"device=0,1,2,3"'` selects GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where
 deployed models were stored.
 
 ### Deploying the model
 
To deploy the model into a Triton-compatible format, the `deployer.py` script can be used. This script is
-meant to be run from inside deployment docker container. 
+meant to be run from inside deployment docker container.
 
 ```
 usage: deployer.py [-h] (--ts-script | --ts-trace | --onnx) [--triton-no-cuda]
@@ -43,7 +43,7 @@ usage: deployer.py [-h] (--ts-script | --ts-trace | --onnx) [--triton-no-cuda]
                    [--triton-max-batch-size TRITON_MAX_BATCH_SIZE]
                    [--triton-dyn-batching-delay TRITON_DYN_BATCHING_DELAY]
                    [--triton-engine-count TRITON_ENGINE_COUNT]
-                   [--save-dir SAVE_DIR]
+                   [--save-dir SAVE_DIR] [--deploy_cpu]
                    ...
 
 optional arguments:
@@ -51,6 +51,7 @@ optional arguments:
   --ts-script           convert to torchscript using torch.jit.script
   --ts-trace            convert to torchscript using torch.jit.trace
   --onnx                convert to onnx using torch.onnx.export
+  --deploy_cpu
 
 triton related flags:
   --triton-no-cuda      Use the CPU for tracing.
@@ -75,14 +76,11 @@ triton related flags:
 other flags:
   model_arguments       arguments that will be ignored by deployer lib and
                         will be forwarded to your deployer script
-
 ```
 
 Following model specific arguments have to be specified for model deployment:
-  
+
 ```
-  --num_numerical_features NUM_FEATURES
-                        Number of numerical features at network input.
   --embedding_dim EMBEDDING_DIM
                         Embedding dimensionality.
   --top_mlp_sizes TOP_MLP_SIZES [TOP_MLP_SIZES ...]
@@ -91,10 +89,6 @@ Following model specific arguments have to be specified for model deployment:
                         Units in layers of bottom MLP (default: 512 256 128).
   --interaction_op {cat,dot}
                         Interaction operator to use.
-  --self_interaction
-                        Enables self interaction.
-  --hash_indices
-                        Hash indices for categorical features.
   --dataset DATASET
                        Path to dataset directory containing model_size.json file
                         describing input sizes for each embedding layer.
@@ -113,7 +107,7 @@ Following model specific arguments have to be specified for model deployment:
 For example, to deploy model into onnx format, using half precision and max batch size 4096 called
 `dlrm-onnx-16` execute:
 
-`python triton/deployer.py --onnx --triton-model-name dlrm-onnx-16 --triton-max-batch-size 4096 --save-dir /repository -- --model_checkpoint /results/checkpoint --fp16 --batch_size 4096 --num_numerical_features 13 --embedding_dim 128 --top_mlp_sizes 1024 1024 512 256 1 --bottom_mlp_sizes 512 256 128 --interaction_op dot --hash_indices --dataset /data`
+`python -m triton.deployer --ts-trace --triton-model-name dlrm-ts-trace-16 --triton-max-batch-size 4096 --save-dir /repository -- --model_checkpoint /results/checkpoint --fp16 --batch_size 4096 --num_numerical_features 13 --embedding_dim 128 --top_mlp_sizes 1024 1024 512 256 1 --bottom_mlp_sizes 512 256 128 --interaction_op dot --dataset /data`
 
 Where `model_checkpoint` is a checkpoint for a trained model with the same configuration as used during export and dataset (or at least dataset configuration)
 is mounted under `/data`
@@ -121,10 +115,10 @@ is mounted under `/data`
 ### Running the Triton server
 **NOTE: This step is executed outside inference container**
 
-1. `docker pull nvcr.io/nvidia/tritonserver:20.03-py3`
-2. `docker run -d --rm --gpus device=0 --ipc=host --network=host [--cpuset-cpus=0-15] -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <PATH_TO_MODEL_REPOSITORY>:/models nvcr.io/nvidia/tritonserver:20.03-py3 trtserver --model-store=/models --log-verbose=1 --model-control-mode=explicit`
+1. `docker pull nvcr.io/nvidia/tritonserver:20.06-py3`
+2. `docker run -d --rm --gpus device=0 --ipc=host --network=host [--cpuset-cpus=0-15] -p 8000:8000 -p 8001:8001 -p 8002:8002 -v <PATH_TO_MODEL_REPOSITORY>:/models nvcr.io/nvidia/tritonserver:20.06-py3 tritonserver --model-repository=/models --log-verbose=1 --model-control-mode=explicit`
 
-Here `device=0,1,2,3` selects GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates location where
+Here `--gpus '"device=0,1,2,3"'` selects GPUs indexed by ordinals `0,1,2` and `3`, respectively. The server will see only these GPUs. If you write `device=all`, then the server will see all the available GPUs. `PATH_TO_MODEL_REPOSITORY` indicates the location where
deployed models are stored. The additional `--model-control-mode` option allows you to manually load and
 unload models. This is especially useful when dealing with numerous large models like DLRM.
 
@@ -133,14 +127,13 @@ For models exported to onnx format and hosted inside onnx runtime it might be re
 ### Running client
 
 Exemplary client `client.py` allows to check model performance against synthetic or real validation
-data. Client connects to Triton server and perform inference. 
+data. The client connects to the Triton server and performs inference.
 
 ```
 usage: client.py [-h] --triton-server-url TRITON_SERVER_URL
                  --triton-model-name TRITON_MODEL_NAME
                  [--triton-model-version TRITON_MODEL_VERSION]
-                 [--protocol PROTOCOL] [-v] [-H HTTP_HEADER]
-                 [--num_numerical_features NUM_NUMERICAL_FEATURES]
+                 [-v] [-H HTTP_HEADER]
                  --dataset_config DATASET_CONFIG
                  [--inference_data INFERENCE_DATA] [--batch_size BATCH_SIZE]
                  [--fp16]
@@ -153,12 +146,9 @@ optional arguments:
                         Triton deployed model name
   --triton-model-version TRITON_MODEL_VERSION
                         Triton model version
-  --protocol PROTOCOL   Communication protocol (HTTP/GRPC)
   -v, --verbose         Verbose mode.
   -H HTTP_HEADER        HTTP headers to add to inference server requests.
                         Format is -H"Header:Value".
-  --num_numerical_features NUM_NUMERICAL_FEATURES
-                        Number of numerical features as an input.
   --dataset_config DATASET_CONFIG
                         Configuration file describing categorical features
   --inference_data INFERENCE_DATA
@@ -169,110 +159,114 @@ optional arguments:
 ```
 
 To run inference on model exported in previous steps, using data located under
-`/data/test_data.bin` execute:
-
-`python triton/client.py --triton-server-url localhost:8000 --protocol HTTP --triton-model-name dlrm-onnx-16 --num_numerical_features 13 --dataset_config /data/model_size.json --inference_data /data/test_data.bin --batch_size 4096 --fp16`
-
-or
+`/data/test` execute:
 
-`python triton/client.py --triton-server-url localhost:8001 --protocol GRPC --triton-model-name dlrm-onnx-16 --num_numerical_features 13 --dataset_config /data/model_size.json --inference_data /data/test_data.bin --batch_size 4096 --fp16`
+`python -m triton.client --triton-server-url localhost:8000 --triton-model-name dlrm-ts-trace-16 --dataset_config /data/model_size.json --inference_data /data/test --batch_size 4096 --fp16`
 
 
 ### Gathering performance data
 Performance data can be gathered using `perf_client` tool. To use this tool, performance data needs
 to be dumped during deployment. To do that, use `--dump_perf_data` option for the deployer:
 
-`python triton/deployer.py --onnx --triton-model-name dlrm-onnx-16 --triton-max-batch-size 4096 --save-dir /repository -- --model_checkpoint /results/checkpoint --fp16 --batch_size 4096 --num_numerical_features 13 --embedding_dim 128 --top_mlp_sizes 1024 1024 512 256 1 --bottom_mlp_sizes 512 256 128 --interaction_op dot --hash_indices --dataset /data --dump_perf_data /location/for/perfdata`
+`python -m triton.deployer --ts-trace --triton-model-name dlrm-ts-trace-16 --triton-max-batch-size 4096 --save-dir /repository -- --model_checkpoint /results/checkpoint --fp16 --batch_size 4096 --num_numerical_features 13 --embedding_dim 128 --top_mlp_sizes 1024 1024 512 256 1 --bottom_mlp_sizes 512 256 128 --interaction_op dot --dataset /data --dump_perf_data /location/for/perfdata`
 
 When perf data are dumped, `perf_client` can be used with following command:
 
-`/workspace/install/bin/perf_client --max-threads 10 -m dlrm-onnx-16 -x 1 -p 5000 -v -i gRPC -u localhost:8001 -b 4096 -l 5000 --concurrency-range 1 --input-data /location/for/perfdata -f result.csv`
+`/workspace/bin/perf_client --max-threads 10 -m dlrm-onnx-16 -x 1 -p 5000 -v -i gRPC -u localhost:8001 -b 4096 -l 5000 --concurrency-range 1 --input-data /location/for/perfdata -f result.csv`
 
 For more information about `perf_client` please refer to [official documentation](https://docs.nvidia.com/deeplearning/sdk/triton-inference-server-master-branch-guide/docs/optimization.html#perf-client).
 
 ## Throughput/Latency results
 
-Throughput is measured in recommendations/second, and latency in milliseconds. 
+Throughput is measured in recommendations/second, and latency in milliseconds.
 
-**ONNX FP16 inference (V100-32G)**
-
-| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
-|----------------|----------------|-----------------|-----------------|-----------------|
-| 1	             | 432.4 rec/s    | 2.31 ms         | 2.42 ms         | 2.51 ms         |
-| 8	             | 3214.4 rec/s   |	2.48 ms         |	2.64 ms         |	2.72 ms         |
-| 64	           | 26924.8 rec/s  |	2.37 ms         | 2.50 ms	        | 2.57 ms         |
-| 512	           | 190413 rec/s   |	2.68 ms         | 2.85 ms         | 2.94 ms         |
-| 4096	         | 891290 rec/s   | 4.58 ms         |	4.82 ms         |	4.96 ms         |
-| 32768	         | 1218970 rec/s  |	26.85 ms        |	27.43 ms        |	28.81 ms        |
-| 65536	         | 1245180 rec/s  |	52.55	ms        | 53.46	ms        | 55.83 ms        |
-| 131072	       | 1140330 rec/s  |	115.24 ms       |	117.56 ms       |	120.32 ms       |
 
 **TorchScript FP16 inference (V100-32G)**
 
-| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
-|----------------|----------------|-----------------|-----------------|-----------------|
-| 1	             | 399.6 rec/s    |	2.50 ms         | 2.56 ms         | 2.70 ms         |
-| 8	             | 3563.2 rec/s   |	2.24 ms         | 2.29 ms         | 2.42 ms         |
-| 64             | 28288.2 rec/s  | 2.26 ms         | 2.33 ms         | 2.41 ms         |
-| 512            | 220774 rec/s   | 2.31 ms         | 2.38 ms         | 2.44 ms         |
-| 4096           | 1104280 rec/s  | 3.70 ms         | 3.78 ms         | 3.86 ms         |
-| 32768          | 1428680 rec/s  | 22.97 ms        | 23.29 ms        | 24.05 ms        |
-| 65536          | 1402470 rec/s  | 46.80 ms        | 48.12 ms        | 52.88 ms        |
-| 131072         | 1546650 rec/s  | 85.27 ms        | 86.17 ms        | 87.05 ms        |
+|   Batch  Size|   Throughput [samples / s]  |   Median Latency [ms]|   95% latency [ms]|   99% latency [ms]|
+|--------:|--------------------:|--------------:|--------------:|---------------:|
+|       1 |      1019         |         0.966 |         1.027 |          1.082 |
+|       2 |      2119         |         0.989 |         1.047 |          1.086 |
+|       4 |      3340         |         1.199 |         1.277 |          1.290 |
+|       8 |      6641         |         1.197 |         1.284 |          1.314 |
+|      16 |     12.5k         |         1.082 |         1.196 |          1.214 |
+|      32 |     28k         |         1.133 |         1.271 |          1.291 |
+|      64 |     45k         |         1.413  |         1.489 |          1.551 |
+|     128 |    105k           |         1.223 |         1.270 |          1.290 |
+|     256 |    193.6k           |         1.320 |         1.471 |          1.518 |
+|     512 |    376k           |         1.367 |         1.449 |          1.486 |
+|    1024 |    740k           |         1.379 |         1.463 |          1.536 |
+|    2048 |         1.105M |         1.817 |         2.106 |          2.195 |
+|    4096 |         1.488M |         2.730 |         2.851 |          3.266 |
+|    8192 |         1.676M |         4.851 |         5.174 |          5.486 |
+|   16384 |         1.831M |        8.926 |        9.127 |         9.415 |
+|   32768 |         1.756M |        18.543 |        19.625 |         20.223   |
+|   65536 |         1.678M |        38.950 |        41.112 |         41.985 |
+|  131072 |         1.547M |        86.258 |        90.772 |         92.511 |
+
 
 **TorchScript FP32 inference (V100-32G)**
 
-| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
-|----------------|----------------|-----------------|-----------------|-----------------|
-| 1              | 333.7 rec/s    | 2.99 ms         | 3.17 ms         | 3.32 ms         |
-| 8              | 3092.8 rec/s   | 2.58 ms         | 2.79 ms         | 2.91 ms         |
-| 64             | 24435.2 rec/s  | 2.61 ms         | 2.78 ms         | 2.89 ms         |
-| 512            | 169216 rec/s   | 3.02 ms         | 3.14 ms         | 3.19 ms         |
-| 4096           | 718438 rec/s   | 5.69 ms         | 5.93 ms         | 6.08 ms         |
-| 32768          | 842138 rec/s   | 38.96 ms        | 39.68 ms        | 41.02 ms        |
-| 65536          | 892138 rec/s   | 73.53 ms        | 74.56 ms        | 74.99 ms        |
-| 131072         | 904397 rec/s   | 146.11 ms       | 149.88 ms       | 151.43 ms       |
-
-**ONNX FP32 inference CPU (2x E5-2698 v4 @ 2.20GHz)**
-
-| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
-|----------------|----------------|-----------------|-----------------|-----------------|
-| 1              | 402.5 rec/s    | 2.48 ms         | 2.34 ms         | 3.16 ms         |
-| 8              | 2316 rec/s     | 3.39 ms         | 2.89 ms         | 6.93 ms         |
-| 64             | 9248 rec/s     | 6.91 ms         | 6.73 ms         | 13.14 ms        |
-| 512            | 14643.3 rec/s  | 35.00 ms        | 42.77 ms        | 69.24 ms        |
-| 4096           | 13926.4 rec/s  | 291.28 ms       | 321.90 ms       | 490.06 ms       |
-| 32768          | 13107.2 rec/s  | 2387.24 ms      | 2395.80 ms      | 2395.80 ms      |
-| 65536          | 14417.9 rec/s  | 5008.26 ms      | 5311.47 ms      | 5311.47 ms      |
-| 131072         | 13107.2 rec/s  | 10033.19 ms     | 10416.43 ms     | 10416.43 ms     |
+|   Batch  Size|   Throughput [samples / s]  |   Median Latency [ms]|   95% latency [ms]|   99% latency [ms]|
+|--------:|--------------------:|--------------:|--------------:|---------------:|
+|       1 |       1153         |         0.855 |         0.909 |          0.929 |
+|       2 |      2084         |         0.950  |         1.042   |          1.199 |
+|       4 |      4105        |         0.955 |         1.033 |          1.177 |
+|       8 |      8356         |         0.943 |         1.029 |          1.179 |
+|      16 |     16.8k           |       0.942 |         1.009 |          1.158 |
+|      32 |     28.3k         |         1.134 |         1.274 |          1.336 |
+|      64 |     54.7k         |         1.214 |         1.307  |          1.353 |
+|     128 |    118k          |         1.036 |         1.255 |          1.303 |
+|     256 |    202k          |         1.275 |         1.404 |          1.449 |
+|     512 |    401k           |         1.286 |         1.365 |          1.397 |
+|    1024 |    707k           |         1.448 |         1.518 |          1.550 |
+|    2048 |    833k           |         2.450 |         2.547 |          2.610 |
+|    4096 |    1.013M           |        3.996 |         4.361 |          4.683 |
+|    8192 |    1.091M           |         7.333 |        7.951 |         8.115 |
+|   16384 |    1.180M          |        13.8  |        14.462 |         15.053 |
+|   32768 |    1.173M           |        27.927 |        28.655 |         28.841 |
+|   65536 |    1.140M          |        57.046 |        58.627 |         58.861 |
+|  131072 |         1.074M |       120.982 |       122.193 |        122.337 |
+
 
 **TorchScript FP32 inference CPU (2x E5-2698 v4 @ 2.20GHz)**
 
-| **Batch Size** | **Throughput** | **Avg Latency** | **95% Latency** | **99% Latency** |
-|----------------|----------------|-----------------|-----------------|-----------------|
-| 1              | 116.3 rec/s    | 8.60 ms         | 9.83 ms         | 14.60 ms        |
-| 8              | 3723.2 rec/s   | 2.14 ms         | 2.55 ms         | 2.78 ms         |
-| 64             | 3014.4 rec/s   | 21.22 ms        | 31.34 ms        | 41.28 ms        |
-| 512            | 6451.2 rec/s   | 79.69 ms        | 106.00 ms       | 296.39 ms       |
-| 4096           | 41984 rec/s    | 97.71 ms        | 118.70 ms       | 123.37 ms       |
-| 32768          | 79735.5 rec/s  | 407.98 ms       | 426.64 ms       | 430.66 ms       |
-| 65536          | 79021.8 rec/s  | 852.90 ms       | 902.39 ms       | 911.46 ms       |
-| 131072         | 81264.6 rec/s  | 1601.28 ms      | 1694.64 ms      | 1711.57 ms      |
+|   Batch  Size|   Throughput [samples / s]  |   Avg Latency [ms]|   95% latency [ms]|   99% latency [ms]|
+|--------:|--------------------:|--------------:|--------------:|---------------:|
+|       1 |               923.2 |         1.051 |         1.195 |          1.225 |
+|       2 |              1660.8 |         1.204 |         1.486 |          1.597 |
+|       4 |              3553.6 |         1.095 |         1.456 |          1.65  |
+|       8 |              6692.8 |         1.112 |         1.56  |          1.787 |
+|      16 |             11.8k |         1.317 |         1.545 |          1.698 |
+|      32 |             19.2k |         1.636 |         1.851 |          2.261 |
+|      64 |             28.6k |         2.203 |         2.403 |          2.615 |
+|     128 |             37.3k |         3.333 |         3.968 |          4.143 |
+|     256 |             46.5k |         5.286 |         6.538 |          7.102 |
+|     512 |             63.5k   |         7.962 |         8.256 |         10.052 |
+|    1024 |             85.8k |        10.777 |        16.563 |         17.917 |
+|    2048 |            117k   |        17.169 |        19.441 |         26.955 |
+|    4096 |             95.8k |        41.988 |        45.996 |         50.483 |
+|    8192 |             85.1k |        92.251 |       131.333 |        133.578 |
+|   16384 |             88.5k |       187.677 |       204.676 |        231.393 |
+|   32768 |             78.6k |       408.815 |       428.574 |        429.58  |
+|   65536 |             91.8k |       804.059 |       810.328 |        810.328 |
+|  131072 |             78.6k|      1606.89  |      1635.36  |       1635.36  |
+
 
 ![Latency vs Throughput](./img/lat_vs_thr.png)
 
-The plot above shows, that the GPU is saturated with batch size 4096. However, running inference with larger batches 
-might be faster, than running several inference requests. Therefore, we choose 65536 as the optimal batch size. 
+The plot above shows that the GPU is saturated with batch size 4096. However, running inference with larger batches
+might be faster than running several smaller inference requests. Therefore, we choose 65536 as the optimal batch size.
 
 
 ## Dynamic batching support
-The Triton server has a dynamic batching mechanism built in, that can be enabled. When it is enabled, then the server creates 
-inference batches from the received requests. Since the output of the model is a single probability, the batch size of a 
-single request may be large. Here it is assumed to be 4096. With dynamic batching enabled, the server will concatenate requests of this size into 
-an inference batch. The upper bound of the size of the inference batch is set to 65536. All these parameters are configurable. 
+The Triton server has a built-in dynamic batching mechanism that can be enabled. When it is enabled, the server creates
+inference batches from the received requests. Since the output of the model is a single probability, the batch size of a
+single request may be large. Here it is assumed to be 4096. With dynamic batching enabled, the server will concatenate requests of this size into
+an inference batch. The upper bound on the size of an inference batch is set to 65536. All these parameters are configurable.
 Performance results on a single V100-32G (ONNX FP16 model) for various numbers of simultaneous requests are shown in the figure below.
 
 ![Dynamic batching](./img/dyn_batch_concurrency.png)
 
-The plot above shows, that if we have a 20ms upper bound on latency, then a single GPU can handle up to 8 concurrent requests. 
-This leads to total throughput of 1.776.030 recommendations/sec. This means 35520 recommendations within 20ms, on a single GPU. 
+The plot above shows that, with a 20ms upper bound on latency, a single GPU can handle up to 8 concurrent requests.
+This leads to a total throughput of 1,776,030 recommendations/sec, which means 35,520 recommendations within 20ms on a single GPU.

+ 113 - 85
PyTorch/Recommendation/DLRM/triton/client.py

@@ -1,45 +1,63 @@
-# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
+#!/usr/bin/env python
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
 #
-#       http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import argparse
 import json
+import sys
 
+import numpy as np
 import torch
-
-from dlrm.data import data_loader
-from dlrm.data.synthetic_dataset import SyntheticDataset
-
+import tritonhttpclient
+from sklearn.metrics import roc_auc_score
 from tqdm import tqdm
-from tensorrtserver.api import *
 
-from sklearn.metrics import roc_auc_score
-from functools import partial
+from dlrm.data.datasets import SyntheticDataset, SplitCriteoDataset
 
-def get_data_loader(batch_size, *, data_file, model_config):
+
+def get_data_loader(batch_size, *, data_path, model_config):
     with open(model_config.dataset_config) as f:
         categorical_sizes = list(json.load(f).values())
-    if data_file:
-        data = data_loader.CriteoBinDataset(data_file=data_file,
-                batch_size=batch_size, subset=None,
-                numerical_features=model_config.num_numerical_features,
-                categorical_features=len(categorical_sizes),
-                online_shuffle=False)
+    if data_path:
+        data = SplitCriteoDataset(
+            data_path=data_path,
+            batch_size=batch_size,
+            numerical_features=True,
+            categorical_features=range(len(categorical_sizes)),
+            categorical_feature_sizes=categorical_sizes,
+            prefetch_depth=1
+        )
     else:
-        data = SyntheticDataset(num_entries=batch_size * 1024, batch_size=batch_size,
-                dense_features=model_config.num_numerical_features,
-                categorical_feature_sizes=categorical_sizes,
-                device="cpu")
+        data = SyntheticDataset(
+            num_entries=batch_size * 1024,
+            batch_size=batch_size,
+            numerical_features=model_config.num_numerical_features,
+            categorical_feature_sizes=categorical_sizes,
+            device="cpu"
+        )
 
     return torch.utils.data.DataLoader(data,
                                        batch_size=None,
@@ -47,76 +65,87 @@ def get_data_loader(batch_size, *, data_file, model_config):
                                        pin_memory=False)
 
 
-if __name__ == "__main__":
+def run_infer(model_name, model_version, numerical_features, categorical_features, headers=None):
+    inputs = []
+    outputs = []
+    num_type = "FP16" if numerical_features.dtype == np.float16 else "FP32"
+    inputs.append(tritonhttpclient.InferInput('input__0', numerical_features.shape, num_type))
+    inputs.append(tritonhttpclient.InferInput('input__1', categorical_features.shape, "INT64"))
+
+    # Initialize the data
+    inputs[0].set_data_from_numpy(numerical_features, binary_data=True)
+    inputs[1].set_data_from_numpy(categorical_features, binary_data=False)
+
+    outputs.append(tritonhttpclient.InferRequestedOutput('output__0', binary_data=True))
+    results = triton_client.infer(model_name,
+                                  inputs,
+                                  model_version=str(model_version) if model_version != -1 else '',
+                                  outputs=outputs,
+                                  headers=headers)
+    return results
+
+
+if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--triton-server-url", type=str, required=True, 
-                        help="URL adress of trtion server (with port)")
-    parser.add_argument("--triton-model-name", type=str, required=True,
-                        help="Triton deployed model name")
-    parser.add_argument("--triton-model-version", type=int, default=-1,
-                        help="Triton model version")
-    parser.add_argument("--protocol", type=str, default="HTTP",
-                        help="Communication protocol (HTTP/GRPC)")
-    parser.add_argument("-v", "--verbose", action="store_true", default=False,
-                        help="Verbose mode.")
+    parser.add_argument('--triton-server-url',
+                        type=str,
+                        required=True,
+                        help='URL address of triton server (with port)')
+    parser.add_argument('--triton-model-name', type=str, required=True,
+                        help='Triton deployed model name')
+    parser.add_argument('--triton-model-version', type=int, default=-1,
+                        help='Triton model version')
+    parser.add_argument('-v',
+                        '--verbose',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Enable verbose output')
     parser.add_argument('-H', dest='http_headers', metavar="HTTP_HEADER",
                         required=False, action='append',
                         help='HTTP headers to add to inference server requests. ' +
                         'Format is -H"Header:Value".')
 
-    parser.add_argument("--num_numerical_features", type=int, default=13)
     parser.add_argument("--dataset_config", type=str, required=True)
-    parser.add_argument("--inference_data", type=str, 
+    parser.add_argument("--inference_data", type=str,
                         help="Path to file with inference data.")
     parser.add_argument("--batch_size", type=int, default=1,
                         help="Inference request batch size")
     parser.add_argument("--fp16", action="store_true", default=False,
                         help="Use 16bit for numerical input")
+
     FLAGS = parser.parse_args()
+    try:
+        triton_client = tritonhttpclient.InferenceServerClient(url=FLAGS.triton_server_url, verbose=FLAGS.verbose)
+    except Exception as e:
+        print("channel creation failed: " + str(e))
+        sys.exit(1)
+
+    if FLAGS.http_headers is not None:
+        headers_dict = {l.split(':')[0]: l.split(':')[1]
+                        for l in FLAGS.http_headers}
+    else:
+        headers_dict = None
 
-    FLAGS.protocol = ProtocolType.from_str(FLAGS.protocol)
-    
-    # Create a health context, get the ready and live state of server.
-    health_ctx = ServerHealthContext(FLAGS.triton_server_url, FLAGS.protocol, 
-                                     http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
-    print("Health for model {}".format(FLAGS.triton_model_name))
-    print("Live: {}".format(health_ctx.is_live()))
-    print("Ready: {}".format(health_ctx.is_ready()))
-    
-    with ModelControlContext(FLAGS.triton_server_url, FLAGS.protocol) as ctx:
-        ctx.load(FLAGS.triton_model_name)
-
-    # Create a status context and get server status
-    status_ctx = ServerStatusContext(FLAGS.triton_server_url, FLAGS.protocol, FLAGS.triton_model_name, 
-                                     http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
-    print("Status for model {}".format(FLAGS.triton_model_name))
-    print(status_ctx.get_server_status())
-    
-    # Create the inference context for the model.
-    infer_ctx = InferContext(FLAGS.triton_server_url, FLAGS.protocol, FLAGS.triton_model_name, 
-                             FLAGS.triton_model_version, 
-                             http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
-
-    dataloader = get_data_loader(FLAGS.batch_size, 
-                                 data_file=FLAGS.inference_data,
-                                 model_config=FLAGS)
+    triton_client.load_model(FLAGS.triton_model_name)
+    if not triton_client.is_model_ready(FLAGS.triton_model_name):
+        sys.exit(1)
 
+    dataloader = get_data_loader(FLAGS.batch_size,
+                                 data_path=FLAGS.inference_data,
+                                 model_config=FLAGS)
     results = []
     tgt_list = []
 
-    for num, cat, target in tqdm(dataloader):
-        num = num.cpu().numpy()
-        if FLAGS.fp16:
-            num = num.astype(np.float16)
-        cat = cat.long().cpu().numpy()
+    for numerical_features, categorical_features, target in tqdm(dataloader):
+        numerical_features = numerical_features.cpu().numpy()
+        numerical_features = numerical_features.astype(np.float16 if FLAGS.fp16 else np.float32)
+        categorical_features = categorical_features.long().cpu().numpy()
 
-        input_dict = {"input__0": tuple(num[i] for i in range(len(num))),
-                      "input__1": tuple(cat[i] for i in range(len(cat)))}
-        output_keys = ["output__0"]
-        output_dict = {x: InferContext.ResultFormat.RAW for x in output_keys}
+        output = run_infer(FLAGS.triton_model_name, FLAGS.triton_model_version,
+                           numerical_features, categorical_features, headers_dict)
 
-        result = infer_ctx.run(input_dict, output_dict, len(num))
-        results.append(result["output__0"])
+        results.append(output.as_numpy('output__0'))
         tgt_list.append(target.cpu().numpy())
 
     results = np.concatenate(results).squeeze()
@@ -125,9 +154,8 @@ if __name__ == "__main__":
     score = roc_auc_score(tgt_list, results)
     print(F"Model score: {score}")
 
-    with ModelControlContext(FLAGS.triton_server_url, FLAGS.protocol) as ctx:
-        ctx.unload(FLAGS.triton_model_name)
-
-
-
-
+    statistics = triton_client.get_inference_statistics(model_name=FLAGS.triton_model_name, headers=headers_dict)
+    print(statistics)
+    if len(statistics['model_stats']) != 1:
+        print("FAILED: Inference Statistics")
+        sys.exit(1)

+ 45 - 36
PyTorch/Recommendation/DLRM/triton/deployer.py

@@ -1,30 +1,34 @@
 #!/usr/bin/python
 
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License. 
+# you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, 
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License. 
+# limitations under the License.
 
-import os
-import torch
 import argparse
-import deployer_lib
 import json
-# 
+import os
+#
 import sys
+
+import torch
+
+from dlrm.data.datasets import SyntheticDataset
+from dlrm.model.single import Dlrm
+from dlrm.utils.checkpointing.serial import make_serial_checkpoint_loader
+from triton import deployer_lib
+
 sys.path.append('../')
 
-from dlrm.model import Dlrm
-from dlrm.data.synthetic_dataset import SyntheticDataset
 
 def get_model_args(model_args):
     parser = argparse.ArgumentParser()
@@ -35,52 +39,54 @@ def get_model_args(model_args):
 
     parser.add_argument("--num_numerical_features", type=int, default=13)
     parser.add_argument("--embedding_dim", type=int, default=128)
+    parser.add_argument("--embedding_type", type=str, default="joint", choices=["joint", "multi_table"])
     parser.add_argument("--top_mlp_sizes", type=int, nargs="+",
                         default=[1024, 1024, 512, 256, 1])
     parser.add_argument("--bottom_mlp_sizes", type=int, nargs="+",
                         default=[512, 256, 128])
     parser.add_argument("--interaction_op", type=str, default="dot",
                         choices=["dot", "cat"])
-    parser.add_argument("--self_interaction", default=False, 
-                        action="store_true")
-    parser.add_argument("--hash_indices", default=False, 
-                        action="store_true")
     parser.add_argument("--cpu", default=False, action="store_true")
     parser.add_argument("--dataset", type=str, required=True)
-    
+
     return parser.parse_args(model_args)
 
+
 def initialize_model(args, categorical_sizes):
     ''' return model, ready to trace '''
     base_device = "cuda:0" if not args.cpu else "cpu"
     model_config = {
-        "top_mlp_sizes": args.top_mlp_sizes,
-        "bottom_mlp_sizes": args.bottom_mlp_sizes,
-        "embedding_dim": args.embedding_dim,
-        "interaction_op": args.interaction_op,
-        "self_interaction": args.self_interaction,
-        "categorical_feature_sizes": categorical_sizes,
-        "num_numerical_features": args.num_numerical_features,
-        "hash_indices": args.hash_indices,
-        "base_device": base_device
+        'top_mlp_sizes': args.top_mlp_sizes,
+        'bottom_mlp_sizes': args.bottom_mlp_sizes,
+        'embedding_dim': args.embedding_dim,
+        'interaction_op': args.interaction_op,
+        'categorical_feature_sizes': categorical_sizes,
+        'num_numerical_features': args.num_numerical_features,
+        'embedding_type': args.embedding_type,
+        'hash_indices': False,
+        'use_cpp_mlp': False,
+        'fp16': args.fp16,
+        'base_device': base_device,
     }
-        
-    model = Dlrm.from_dict(model_config, sigmoid=True)
+
+    model = Dlrm.from_dict(model_config)
     model.to(base_device)
 
     if args.model_checkpoint:
-        model.load_state_dict(torch.load(args.model_checkpoint,  
-                                         map_location="cpu"))
+        checkpoint_loader = make_serial_checkpoint_loader(range(len(categorical_sizes)), device="cpu")
+        checkpoint_loader.load_checkpoint(model, args.model_checkpoint)
+        model.to(base_device)
 
     if args.fp16:
         model = model.half()
 
     return model
 
+
 def get_dataloader(args, categorical_sizes):
     dataset_test = SyntheticDataset(num_entries=2000,
                                     batch_size=args.batch_size,
-                                    dense_features=args.num_numerical_features,
+                                    numerical_features=args.num_numerical_features,
                                     categorical_feature_sizes=categorical_sizes,
                                     device="cpu" if args.cpu else "cuda:0")
     class RemoveOutput:
@@ -98,19 +104,22 @@ def get_dataloader(args, categorical_sizes):
         def __len__(self):
             return len(self.dataset)
 
-    test_loader = torch.utils.data.DataLoader(RemoveOutput(dataset_test), 
-                                              batch_size=None, 
-                                              num_workers=0, 
+    test_loader = torch.utils.data.DataLoader(RemoveOutput(dataset_test),
+                                              batch_size=None,
+                                              num_workers=0,
                                               pin_memory=False)
 
     return test_loader
 
 
 if __name__=='__main__':
-    deployer, model_args = deployer_lib.create_deployer(sys.argv[1:], 
-            get_model_args) # deployer and returns removed deployer arguments
+    # create the deployer; returns the deployer and the model arguments it did not consume
+    deployer, model_args = deployer_lib.create_deployer(sys.argv[1:],
+                                                        get_model_args)
+
     with open(os.path.join(model_args.dataset, "model_size.json")) as f:
         categorical_sizes = list(json.load(f).values())
+        categorical_sizes = [s + 1 for s in categorical_sizes]
 
     model = initialize_model(model_args, categorical_sizes)
     dataloader = get_dataloader(model_args, categorical_sizes)
@@ -123,5 +132,5 @@ if __name__=='__main__':
         os.makedirs(model_args.dump_perf_data, exist_ok=True)
         input_0.detach().cpu().numpy()[0].tofile(os.path.join(model_args.dump_perf_data, "input__0"))
         input_1.detach().cpu().numpy()[0].tofile(os.path.join(model_args.dump_perf_data, "input__1"))
-        
+
     deployer.deploy(dataloader, model)

+ 15 - 21
PyTorch/Recommendation/DLRM/triton/deployer_lib.py

@@ -54,7 +54,7 @@ output [
 instance_group [
     {{
         count: {engine_count}
-        kind: KIND_GPU
+        kind: {kind}
         gpus: [ {gpu_list} ]
     }}
 ]
@@ -149,6 +149,9 @@ def create_deployer(argv, model_args_parser):
                            type=str,
                            default='./triton_models',
                            help='Saved model directory')
+
+    parser.add_argument("--deploy_cpu", default=False, action="store_true")
+
     # other args
     arguments = parser.add_argument_group('other flags')
 
@@ -369,26 +372,17 @@ dynamic_batching {{
             accelerator_str = accelerator_template.format_map({})
 
         config_values = {
-            "model_name":
-            self.args.triton_model_name,
-            "platform":
-            self.platform,
-            "max_batch_size":
-            max_batch_size,
-            "spec_inputs":
-            spec_inputs,
-            "spec_outputs":
-            spec_outputs,
-            "dynamic_batching":
-            batching_str,
-            "model_parameters":
-            parameters_str,
-            "model_optimizations":
-            accelerator_str,
-            "gpu_list":
-            ", ".join([str(x) for x in range(torch.cuda.device_count())]),
-            "engine_count":
-            self.args.triton_engine_count
+            "model_name": self.args.triton_model_name,
+            "platform": self.platform,
+            "max_batch_size": max_batch_size,
+            "spec_inputs": spec_inputs,
+            "spec_outputs": spec_outputs,
+            "dynamic_batching": batching_str,
+            "model_parameters": parameters_str,
+            "model_optimizations": accelerator_str,
+            "gpu_list": "" if self.args.deploy_cpu else ", ".join([str(x) for x in range(torch.cuda.device_count())]),
+            "engine_count": self.args.triton_engine_count,
+            "kind": "KIND_CPU" if self.args.deploy_cpu else "KIND_GPU"
         }
 
         # write config

BIN
PyTorch/Recommendation/DLRM/triton/img/lat_vs_thr.png