
[SIM/TF2] Release new version of SIM model with prebatching support

Jakub Tomsia 3 years ago
parent
commit
bf00fe1dbe

+ 0 - 1
TensorFlow2/Recommendation/SIM/.gitignore

@@ -15,4 +15,3 @@
 .ipynb_checkpoints/
 .idea/
 __pycache__
-results/

+ 239 - 329
TensorFlow2/Recommendation/SIM/README.md

@@ -28,6 +28,7 @@ This repository provides a script and recipe to train the SIM model to achieve s
    * [Command-line options](#command-line-options)
    * [Getting the data](#getting-the-data)
        * [Dataset guidelines](#dataset-guidelines)
+        * [Prebatching](#prebatching)
        * [BYO dataset](#byo-dataset)
            * [Channel definitions and requirements](#channel-definitions-and-requirements)
    * [Training process](#training-process)
@@ -78,7 +79,7 @@ In the author’s SIM implementation, the internals of submodels differs slightl
 List of implementation differences between original SIM code and DIN/DIEN/SIM papers
 </b></summary>

-- Batch normalization before NLP is not included in papers.
+- Batch normalization before MLP is not included in papers.
 - Batch normalization in code used `trainable=False` during the training phase.
 - ItemItemInteraction in DIN's attention module in the SIM implementation didn't correspond to the activation unit inside the DIN paper.
   - Element-wise subtraction and multiplications are fed to MLP, skipping outer product operation.
@@ -375,7 +376,7 @@ The following section lists the requirements that you need to meet in order to s

 This repository contains a Dockerfile that extends the TensorFlow2 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 - [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
-- [TensorFlow2 21.10-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow/tags) NGC container
+- [TensorFlow2 22.01-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow/tags) NGC container
 - Supported GPUs:
   - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
   - [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
@@ -417,9 +418,6 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3
 5. Start preprocessing.

     For details of the required file format and certain preprocessing parameters refer to [BYO dataset](#byo-dataset).
-    
-    
-    `${NUMBER_OF_USER_FEATURES}` defines how many user specific features are present in dataset. If using default Amazon Books dataset and `sim_preprocessing` script (as shown below), this parameter should be set to <b>1</b> (in this case, the only user specific features is <b>user_id</b>. Other features are item specific).

   ```bash
   python preprocessing/sim_preprocessing.py \
@@ -428,8 +426,7 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3

   python preprocessing/parquet_to_tfrecord.py \
    --amazon_dataset_path ${PARQUET_PATH} \
-    --tfrecord_output_dir ${TF_RECORD_PATH} \
-    --number_of_user_features ${NUMBER_OF_USER_FEATURES}
+    --tfrecord_output_dir ${TF_RECORD_PATH}
   ```

 6. Start training (`${GPU}` is the number of GPUs to use).
@@ -496,10 +493,11 @@ The `main.py` script parameters are detailed in the following table.
 | training        | drop_remainder            | Drop remainder batch for training set (flag)                            | False                     |
 | training        | disable_cache             | Disable dataset caching after the first time it is iterated over (flag)        | False                     |
 | training        | repeat_count              | Repeat training dataset this number of times                            | 0                         |
-| training | prefetch_train_size |  Number of batches to prefetch in training. | -1 |
-| training | prefetch_test_size |  Number of batches to prefetch in evaluation. | -1 |
-| training | train_dataset_size |  Number of samples in training dataset (used to determine prefetch_train_size when --prefetch_train_size < 0) | 11796480 |
+| training | prefetch_train_size |  Number of batches to prefetch in training. | 10 |
+| training | prefetch_test_size |  Number of batches to prefetch in evaluation. | 2 |
 | training | long_seq_length | Determines the long history - short history split of history features | 90 |
+| training | prebatch_train_size | Batch size of batching applied during preprocessing to train dataset. | 0 |
+| training | prebatch_test_size | Batch size of batching applied during preprocessing to test dataset. | 0 |
 | results         | results_dir               | Path to the model result files storage                                  | /tmp/sim                  |
 | results         | log_filename              | Name of the file to store logger output                                 | log.json                  |
 | results         | save_checkpoint_path      | Directory to save model checkpoints                                     | ""                        |
@@ -511,8 +509,10 @@ The `main.py` script parameters are detailed in the following table.
 | run mode        | affinity                  | Type of CPU affinity                                                    | socket_unique_interleaved |
 | run mode        | inter_op_parallelism      | Number of inter op threads                                              | 0                         |
 | run mode        | intra_op_parallelism      | Number of intra op threads                                              | 0                         |
+| run mode        | num_parallel_calls        | Parallelism level for the tf.data API. If None, a heuristic based on the number of CPUs and GPUs is used   |   None  |
 | reproducibility | seed                      | Random seed                                                             | -1                        |

+
 ### Command-line options

 To view the full list of available options and their descriptions, use the `--help` command-line option, for example:
@@ -534,6 +534,56 @@ The preprocessing steps applied to the raw data include:
 - Determining embedding table sizes for categorical features needed to construct a model
 - Filter users for the training split based on their number of interactions (discard users with fewer than 20 interactions)

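The user-filtering step above can be sketched in plain Python (a minimal illustration with a hypothetical `filter_users` helper, not the repository's actual preprocessing code, which operates on the parquet dataset):

```python
from collections import Counter

def filter_users(interactions, min_interactions=20):
    """Discard interactions of users with fewer than `min_interactions` events.

    `interactions` is a list of (user_id, item_id) pairs.
    """
    counts = Counter(user_id for user_id, _ in interactions)
    return [(u, i) for u, i in interactions if counts[u] >= min_interactions]

# Toy data: user "a" has only 2 interactions and is discarded from the training split
data = [("a", 1), ("a", 2)] + [("b", i) for i in range(20)]
assert all(u == "b" for u, _ in filter_users(data))
```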
+#### Prebatching
+
+Preprocessing scripts allow batching to be applied before the model's dataloader. This reduces the size of the produced TFRecord files and speeds up dataloading.
+To do so, specify `--prebatch_train_size` and `--prebatch_test_size` when converting data with `preprocessing/parquet_to_tfrecord.py`. Later, when running the `main.py` script, pass the applied prebatch sizes via the same parameters.
+
+Example:
+
+Start preprocessing from step 5. from [Quick Start Guide](#quick-start-guide):
+
+```bash
+python preprocessing/sim_preprocessing.py \
+--amazon_dataset_path ${RAW_DATASET_PATH} \
+--output_path ${PARQUET_PATH}
+
+python preprocessing/parquet_to_tfrecord.py \
+--amazon_dataset_path ${PARQUET_PATH} \
+--tfrecord_output_dir ${TF_RECORD_PATH} \
+--prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+--prebatch_test_size ${PREBATCH_TEST_SIZE}
+```
+
+And then train the model (step 6.):
+
+```bash
+mpiexec --allow-run-as-root --bind-to socket -np ${GPU} python main.py \
+--dataset_dir ${TF_RECORD_PATH} \
+--mode train \
+--model_type sim \
+--embedding_dim 16 \
+--drop_remainder \
+--optimizer adam \
+--lr 0.01 \
+--epochs 3 \
+--global_batch_size 131072 \
+--amp \
+--prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+--prebatch_test_size ${PREBATCH_TEST_SIZE}
+```
+
+<details>
+<summary><b>Prebatching details</b></summary>
+
+- The last batch of each split is saved to a separate file, `remainder.tfrecord`, unless there are enough samples to form a full batch.
+- The final batch size used in the main script can be a multiple of the prebatch size.
+- The final batch size used in the main script can also be a divisor of the prebatch size. In this case, when using multi-GPU training, the number of batches received by each worker can be greater than 1, resulting in an error during the allgather operation. Dataset size, batch size, and prebatch size have to be chosen with this limitation in mind.
+- For the original Amazon Books Dataset, parameters were set to PREBATCH_TRAIN_SIZE = PREBATCH_TEST_SIZE = 4096 for performance benchmarking purposes.
+</details>
+
+&nbsp;
+
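The prebatching rules above can be illustrated with a small, hypothetical sketch in plain Python (not the repository's implementation): full prebatches are written out, leftover samples go to a separate remainder file, and the final batch size must be a multiple or a divisor of the prebatch size.

```python
def prebatch(samples, prebatch_size):
    """Group samples into full prebatches; leftovers mirror `remainder.tfrecord`."""
    n_full = len(samples) // prebatch_size * prebatch_size
    full_batches = [samples[i:i + prebatch_size] for i in range(0, n_full, prebatch_size)]
    remainder = samples[n_full:]  # saved separately unless it forms a full batch
    return full_batches, remainder

def compatible_batch_size(global_batch_size, prebatch_size):
    """The main script's batch size must be a multiple or a divisor of the prebatch size."""
    return (global_batch_size % prebatch_size == 0
            or prebatch_size % global_batch_size == 0)

batches, remainder = prebatch(list(range(10)), prebatch_size=4)
assert batches == [[0, 1, 2, 3], [4, 5, 6, 7]] and remainder == [8, 9]
assert compatible_batch_size(131072, 16384)   # 131072 is a multiple of 16384
assert not compatible_batch_size(3000, 4096)  # neither a multiple nor a divisor
```

With the benchmarking setup used later in this document (prebatch size 16384, global batch size 131072), the multiple-of-prebatch case applies.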
 #### BYO dataset 

 This implementation supports using other datasets thanks to BYO dataset functionality. 
@@ -676,7 +726,7 @@ source_spec:
     type: tfrecord
 ```

-`dimensions` should contain the length of the history to which the entries will be padded.
+`dimensions` should contain the length of the sequential features.

 Note that corresponding features in `negative_history`, `positive_history`, and `target_item_features` need to be listed in the same order in the channel spec in each channel, since they share embedding tables in the model (for example, `item_id` needs to be first and `cat_id` second). 

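Assuming sequence entries shorter than the declared dimension are padded up to it (90 by default, matching `long_seq_length`), a minimal, hypothetical padding helper could look like:

```python
def pad_history(history, length, pad_value=0):
    """Right-pad (or truncate) a sequential feature to the declared dimension."""
    return (history + [pad_value] * length)[:length]

assert pad_history([5, 7, 9], 5) == [5, 7, 9, 0, 0]       # padded
assert pad_history(list(range(8)), 5) == [0, 1, 2, 3, 4]  # truncated
```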
@@ -705,7 +755,7 @@ For performance reasons, the only supported dataset type is tfrecord.

 ### Training process

-Training can be run using `main.py` script by specifying the `--mode train` parameter. The speed of training is measured by throughput, that is, the number of samples processed per second. Evaluation is based on the [Area under ROC Curve (ROC AUC)](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) metric. Model checkpoints may be stored using Checkpoint manager as specified via (...). Training and inference logs are saved to a directory specified via the `--results_dir` parameter. Mixed precision training is supported via the `--amp` flag. Multi-GPU training is performed using mpiexec and Horovod libraries.
+Training can be run using the `main.py` script by specifying the `--mode train` parameter. The speed of training is measured by throughput, that is, the number of samples processed per second. Evaluation is based on the [Area under ROC Curve (ROC AUC)](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) metric. Model checkpoints may be stored using the checkpoint manager via the `--save_checkpoint_path` and `--load_checkpoint_path` parameters. Training and inference logs are saved to a directory specified via the `--results_dir` parameter. Mixed precision training is supported via the `--amp` flag. Multi-GPU training is performed using the mpiexec and Horovod libraries.

 ### Inference process

@@ -778,7 +828,9 @@ mpiexec --allow-run-as-root --bind-to socket -np ${GPU} python main.py \
   --global_batch_size 131072 \
   --drop_remainder \
   --amp \
-  --benchmark
+  --benchmark \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```

 Equivalent:
@@ -787,7 +839,9 @@ scripts/run_model.sh \
   --data_path ${TF_RECORD_PATH} \
   --gpus ${GPU} \
   --amp 1 \
-  --benchmark 1 
+  --benchmark 1 \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```

 #### Inference performance benchmark
@@ -801,7 +855,9 @@ mpiexec --allow-run-as-root --bind-to socket -np ${GPU} python main.py \
   --model_type sim \
   --global_batch_size 131072 \
   --amp \
-  --benchmark
+  --benchmark \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```

 Equivalent:
@@ -811,7 +867,8 @@ scripts/run_model.sh \
   --gpus ${GPU} \
   --amp 1 \
   --benchmark 1 \
-  --mode inference
+  --mode inference \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```

 ### Results
@@ -820,7 +877,7 @@ The following sections provide details on how we achieved our performance and ac

 #### Training accuracy results

-Our results were obtained by running the `run_model.sh` bash script in the TensorFlow2 21.10-py3 NGC container. Experiments were run on 1 and 8 GPUs, with FP32/TF32 Precision and AMP and with XLA-OFF/XLA-ON. Other parameters were set to defaults.
+Our results were obtained by running the `run_model.sh` bash script in the TensorFlow2 21.10-py3 NGC container. Experiments were run on 1 and 8 GPUs, with FP32/TF32 precision and AMP, and with XLA-OFF/XLA-ON. The dataset was prebatched with a size of 16384. Other parameters were set to defaults.

 There were 10 runs for each configuration. In the `Training accuracy` sections, average values are reported. In the `Training stability` sections, values from all runs are included in plots.

@@ -962,7 +1019,7 @@ Figure 8. ROC curve for different configurations of Ampere/Volta, 1/8 GPUs, doub

 #### Training performance results

-Our results were obtained by running the `scripts/run_model.sh` script in the TensorFlow2 21.10-py3 NGC container. 
+Our results were obtained by running the `scripts/run_model.sh` script in the TensorFlow2 21.10-py3 NGC container. The dataset was prebatched with a size of 16384.

 Numbers were averaged over 10 separate runs for each configuration.

@@ -974,12 +1031,12 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic

 ##### Training performance: NVIDIA DGX A100 (8x A100 80GB)

-|GPUs |XLA  |Throughput - TF32 (samples/s)  |Throughput - mixed precision (samples/s) |Throughput speedup (mixed precision / TF32)  | Strong scaling - TF32 | Strong scaling - mixed precision |
-|-----|-----|--------------------|------------------------------|---------------------------------------------|-----------|-------------|
-|1    |OFF  |381211.31           |484360.65                     |1.27    | 1.00 | 1.00 |
-|1    |ON   |462012.86           |571727.91                     |1.24    | 1.00 | 1.00 |
-|8    |OFF  |2304284.08          |2475445.94                    |1.07   | 6.04 | 5.11 |
-|8    |ON   |2679300.61          |3006370.96                    |1.12   | 5.80 | 5.26 |
+|   GPUs |   XLA |   Throughput - TF32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / TF32) |   Strong scaling - TF32 |   Strong scaling - mixed precision |
+|-------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|------------------------:|-----------------------------------:|
+|      1 |     OFF |                       377254.65 |                                  479921.54 |                                          1.27 |                    1.00 |                               1.00 |
+|      1 |     ON |                       455724.01 |                                  565221.04 |                                          1.24 |                    1.00 |                               1.00 |
+|      8 |     OFF |                      2161681.55 |                                 2603489.60 |                                          1.20 |                    5.73 |                               5.42 |
+|      8 |     ON |                      2662368.18 |                                 2979441.80 |                                          1.12 |                    5.84 |                               5.27 |

 <details>
 <summary><b>
@@ -990,24 +1047,24 @@ For each configuration of parameters present in the table, the `Speedup` column

 |GPUs |Precision      |Speedup |
 |-----|---------------|--------|
-|1    |TF32           |1.212   |
-|1    |AMP            |1.180   |
-|8    |TF32           |1.163   |
-|8    |AMP            |1.214   |
+|1    |TF32           |1.208   |
+|1    |AMP            |1.178   |
+|8    |TF32           |1.232   |
+|8    |AMP            |1.119   |
 </details>

 &nbsp;

 ##### Training performance: NVIDIA DGX-2 (16x V100 32GB)

-|GPUs |XLA  |Throughput - FP32 (samples/s)  |Throughput - mixed precision (samples/s) |Throughput speedup (mixed precision / FP32) | Strong scaling - FP32 | Strong scaling - mixed precision |
-|-----|-----|--------------------|------------------------------|---------------------------------------------|----------|-------------|
-|1    |OFF  |210772.27           |312580.01                     |1.48                                         | 1.00 | 1.00 |
-|1    |ON   |248514.27           |358305.52                     |1.44                                         | 1.00 | 1.00 |
-|8    |OFF  |1357463.39          |1785361.62                    |1.32                                         | 6.44 | 5.71 |
-|8    |ON   |1584757.09          |2091403.04                    |1.32                                         | 7.52 | 6.69 |
-|16   |OFF  |2319719.76          |2837309.15                    |1.22                                         | 11.00 | 9.08 |
-|16   |ON   |2681789.69          |3168488.89                    |1.18                                         | 12.73 | 10.14 |
+|   GPUs |   XLA |   Throughput - FP32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / FP32) |   Strong scaling - FP32 |   Strong scaling - mixed precision |
+|-------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|------------------------:|-----------------------------------:|
+|      1 |     OFF |                       209376.38 |                                  309752.48 |                                          1.48 |                    1.00 |                               1.00 |
+|      1 |     ON |                       245414.62 |                                  348945.59 |                                          1.42 |                    1.00 |                               1.00 |
+|      8 |     OFF |                      1310239.01 |                                 1689602.79 |                                          1.29 |                    6.26 |                               5.45 |
+|      8 |     ON |                      1483120.32 |                                 1962226.32 |                                          1.32 |                    6.04 |                               5.62 |
+|     16 |     OFF |                      2127221.65 |                                 2555926.79 |                                          1.20 |                   10.16 |                               8.25 |
+|     16 |     ON |                      2450499.40 |                                 2788997.07 |                                          1.14 |                    9.99 |                               7.99 |

 <details>
 <summary><b>
@@ -1018,12 +1075,12 @@ For each configuration of parameters present in the table, the `Speedup` column

 |GPUs |Precision           |Speedup        |
 |-----|--------------------|---------------|
-|1    |FP32                |1.179          |
-|1    |AMP                 |1.146          |
-|8    |FP32                |1.167          |
-|8    |AMP                 |1.171          |
-|16   |FP32                |1.156          |
-|16   |AMP                 |1.117          |
+|1    |FP32                |1.172          |
+|1    |AMP                 |1.127          |
+|8    |FP32                |1.132          |
+|8    |AMP                 |1.161          |
+|16   |FP32                |1.152          |
+|16   |AMP                 |1.091          |
 </details>

 &nbsp;
@@ -1033,16 +1090,17 @@ For each configuration of parameters present in the table, the `Speedup` column
 NVIDIA DGX A100 / DGX-2 (Ampere / Volta) training speedup
 </b></summary>

-|GPUs |XLA    |Precision       |Speedup|
-|-----|-------|---------------|-------|
-|1    |OFF    |TF32/FP32      |1.809  |
-|1    |OFF    |AMP            |1.550  |
-|1    |ON     |TF32/FP32      |1.860  |
-|1    |ON     |AMP            |1.596  |
-|8    |OFF    |TF32/FP32      |1.697  |
-|8    |OFF    |AMP            |1.387  |
-|8    |ON     |TF32/FP32      |1.691  |
-|8    |ON     |AMP            |1.437  |
+
+|   GPUs |   XLA | Precision   |   Speedup |
+|-------:|------:|:------------|----------:|
+|      1 |     OFF | TF32/FP32   |     1.802 |
+|      1 |     OFF | AMP         |     1.549 |
+|      1 |     ON | TF32/FP32   |     1.857 |
+|      1 |     ON | AMP         |     1.620 |
+|      8 |     OFF | TF32/FP32   |     1.650 |
+|      8 |     OFF | AMP         |     1.541 |
+|      8 |     ON | TF32/FP32   |     1.795 |
+|      8 |     ON | AMP         |     1.518 |

 </details>

@@ -1060,74 +1118,44 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic

 ##### Inference performance: NVIDIA DGX A100 (8x A100 80GB)

-|GPUs |Global batch size|XLA  |Throughput - TF32 (samples/s)|Throughput - mixed precision (samples/s)|Throughput speedup (mixed precision / TF32)  | Strong scaling - TF32 | Strong scaling - mixed precision |
-|-----|----------|-----|---------------|----------------------------|---------------------------------------------|----------------|---------|
-|1    |4096      |ON   |561967.1       |535674.63                   |0.95                                         | 1.00 | 1.00 |
-|1    |8192      |ON   |670885.47      |758801.43                   |1.13                                         | 1.00 | 1.00 |
-|1    |16384     |ON   |788890.79      |920695.88                   |1.17                                         | 1.00 | 1.00 |
-|1    |32768     |ON   |855056.39      |1035530.23                  |1.21                                         | 1.00 | 1.00 |
-|1    |65536     |ON   |918649.98      |1081408.05                  |1.18                                         | 1.00 | 1.00 |
-|1    |131072    |ON   |918555.37      |771119.78                   |0.84                                         | 1.00 | 1.00 |
-|8    |4096      |ON   |1130031.99     |935848.52                   |0.83                                         | 2.01 | 1.75 |
-|8    |8192      |ON   |2246441.94     |1885511.32                  |0.84                                         | 3.64 | 2.48 |
-|8    |16384     |ON   |4000071.31     |3303417.5                   |0.83                                         | 5.07 | 3.59 |
-|8    |32768     |ON   |5479754.01     |5762298.42                  |1.05                                         | 6.41 | 5.56 |
-|8    |65536     |ON   |6736333.91     |7869825.77                  |1.17                                         | 7.33 | 7.28 |
-|8    |131072    |ON   |7598665.72     |9002545.49                  |1.18                                         | 8.27 | 11.67 |
+|   Batch Size |   XLA |   Throughput - TF32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / TF32) |
+|--------------------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|
+|                4096 |     ON |                       618547.45 |                                  669640.65 |                                          1.08 |
+|                8192 |     ON |                       722801.14 |                                  849101.88 |                                          1.17 |
+|               16384 |     ON |                       859418.77 |                                 1051361.67 |                                          1.22 |
+|               32768 |     ON |                       976771.70 |                                 1269000.97 |                                          1.30 |
+|               65536 |     ON |                      1082688.51 |                                 1444729.52 |                                          1.33 |
+|              131072 |     ON |                      1094733.64 |                                 1483542.86 |                                          1.36 |

 <details>
 <summary><b> Complete table of DGX A100 inference performance results </b></summary>

-|GPUSs|Global Batch Size   |XLA    |Precision      |Throughput  (samples/s)           |
-|-----|--------------------|-------|---------------|-----------------------|
-|1    |4096                |OFF    |TF32           |585246.51 ± 10513.06   |
-|1    |8192                |OFF    |TF32           |750729.14 ± 17029.41   |
-|1    |16384               |OFF    |TF32           |803593.59 ± 11207.58   |
-|1    |32768               |OFF    |TF32           |822162.85 ± 5071.85    |
-|1    |65536               |OFF    |TF32           |775748.42 ± 36821.04   |
-|1    |131072              |OFF    |TF32           |644740.49 ± 31148.79   |
-|1    |4096                |OFF    |AMP            |516164.09 ± 9916.80    |
-|1    |8192                |OFF    |AMP            |778740.41 ± 19384.36   |
-|1    |16384               |OFF    |AMP            |932211.18 ± 20331.07   |
-|1    |32768               |OFF    |AMP            |990696.89 ± 11554.34   |
-|1    |65536               |OFF    |AMP            |715678.16 ± 30944.63   |
-|1    |131072              |OFF    |AMP            |611740.50 ± 21392.81   |
-|1    |4096                |ON     |TF32           |561967.10 ± 18100.55   |
-|1    |8192                |ON     |TF32           |670885.47 ± 11149.51   |
-|1    |16384               |ON     |TF32           |788890.79 ± 10058.99   |
-|1    |32768               |ON     |TF32           |855056.39 ± 14349.13   |
-|1    |65536               |ON     |TF32           |918649.98 ± 7571.32    |
-|1    |131072              |ON     |TF32           |918555.37 ± 15036.89   |
-|1    |4096                |ON     |AMP            |535674.63 ± 14003.35   |
-|1    |8192                |ON     |AMP            |758801.43 ± 15225.76   |
-|1    |16384               |ON     |AMP            |920695.88 ± 15325.29   |
-|1    |32768               |ON     |AMP            |1035530.23 ± 16055.40  |
-|1    |65536               |ON     |AMP            |1081408.05 ± 41906.29  |
-|1    |131072              |ON     |AMP            |771119.78 ± 79589.50   |
-|8    |4096                |OFF    |TF32           |765154.17 ± 30582.87   |
-|8    |8192                |OFF    |TF32           |1396414.24 ± 99987.01  |
-|8    |16384               |OFF    |TF32           |2281597.86 ± 77483.79  |
-|8    |32768               |OFF    |TF32           |3555014.42 ± 145944.33 |
-|8    |65536               |OFF    |TF32           |4792413.60 ± 203285.21 |
-|8    |131072              |OFF    |TF32           |5941195.01 ± 182519.72 |
-|8    |4096                |OFF    |AMP            |642706.11 ± 28063.45   |
-|8    |8192                |OFF    |AMP            |1197789.38 ± 47262.95  |
-|8    |16384               |OFF    |AMP            |1961353.19 ± 49818.70  |
-|8    |32768               |OFF    |AMP            |3267263.60 ± 130680.70 |
-|8    |65536               |OFF    |AMP            |4847783.16 ± 257991.99 |
-|8    |131072              |OFF    |AMP            |6413842.15 ± 289543.64 |
-|8    |4096                |ON     |TF32           |1130031.99 ± 75271.24  |
-|8    |8192                |ON     |TF32           |2246441.94 ± 26132.90  |
-|8    |16384               |ON     |TF32           |4000071.31 ± 48054.68  |
-|8    |32768               |ON     |TF32           |5479754.01 ± 170421.20 |
-|8    |65536               |ON     |TF32           |6736333.91 ± 153745.68 |
-|8    |131072              |ON     |TF32           |7598665.72 ± 174188.78 |
-|8    |4096                |ON     |AMP            |935848.52 ± 14583.48   |
-|8    |8192                |ON     |AMP            |1885511.32 ± 22206.00  |
-|8    |16384               |ON     |AMP            |3303417.50 ± 210306.61 |
-|8    |32768               |ON     |AMP            |5762298.42 ± 140412.56 |
-|8    |65536               |ON     |AMP            |7869825.77 ± 305838.69 |
-|8    |131072              |ON     |AMP            |9002545.49 ± 438204.32 |
+|   Batch Size | XLA   | Precision   | Throughput  (samples/s)   |
+|-------------:|:------|:------------|:--------------------------|
+|         4096 | OFF   | TF32        | 708349.73 ± 14161.58      |
+|         8192 | OFF   | TF32        | 873335.82 ± 8539.56       |
+|        16384 | OFF   | TF32        | 937987.79 ± 12114.34      |
+|        32768 | OFF   | TF32        | 943313.07 ± 8631.81       |
+|        65536 | OFF   | TF32        | 960794.46 ± 7388.45       |
+|       131072 | OFF   | TF32        | 966245.27 ± 8637.82       |
+|         4096 | OFF   | AMP         | 645394.94 ± 14844.27      |
+|         8192 | OFF   | AMP         | 919410.07 ± 11355.28      |
+|        16384 | OFF   | AMP         | 1136346.66 ± 14529.91     |
+|        32768 | OFF   | AMP         | 1216810.45 ± 21013.12     |
+|        65536 | OFF   | AMP         | 1287305.05 ± 19373.18     |
+|       131072 | OFF   | AMP         | 1298478.97 ± 10733.67     |
+|         4096 | ON    | TF32        | 618547.45 ± 6569.97       |
+|         8192 | ON    | TF32        | 722801.14 ± 9448.19       |
+|        16384 | ON    | TF32        | 859418.77 ± 10012.61      |
+|        32768 | ON    | TF32        | 976771.70 ± 13377.36      |
+|        65536 | ON    | TF32        | 1082688.51 ± 8523.55      |
+|       131072 | ON    | TF32        | 1094733.64 ± 11157.18     |
+|         4096 | ON    | AMP         | 669640.65 ± 9319.68       |
+|         8192 | ON    | AMP         | 849101.88 ± 14068.04      |
+|        16384 | ON    | AMP         | 1051361.67 ± 15310.42     |
+|        32768 | ON    | AMP         | 1269000.97 ± 23971.56     |
+|        65536 | ON    | AMP         | 1444729.52 ± 18011.54     |
+|       131072 | ON    | AMP         | 1483542.86 ± 6751.29      |
 
 
 </details>
 
 
@@ -1138,32 +1166,20 @@ DGX A100 XLA-ON / XLA-OFF inference Speedup
 
 
 For each configuration of parameters present in the table, the `Speedup` column shows the speedup achieved by turning on XLA.
 
 
-|GPUs |Global Batch Size   |Precision      |Speedup |
-|-----|--------------------|---------------|--------|
-|1    |4096                |TF32           |0.960   |
-|1    |8192                |TF32           |0.894   |
-|1    |16384               |TF32           |0.982   |
-|1    |32768               |TF32           |1.040   |
-|1    |65536               |TF32           |1.184   |
-|1    |131072              |TF32           |1.425   |
-|1    |4096                |AMP            |1.038   |
-|1    |8192                |AMP            |0.974   |
-|1    |16384               |AMP            |0.988   |
-|1    |32768               |AMP            |1.045   |
-|1    |65536               |AMP            |1.511   |
-|1    |131072              |AMP            |1.261   |
-|8    |4096                |TF32           |1.477   |
-|8    |8192                |TF32           |1.609   |
-|8    |16384               |TF32           |1.753   |
-|8    |32768               |TF32           |1.541   |
-|8    |65536               |TF32           |1.406   |
-|8    |131072              |TF32           |1.279   |
-|8    |4096                |AMP            |1.456   |
-|8    |8192                |AMP            |1.574   |
-|8    |16384               |AMP            |1.684   |
-|8    |32768               |AMP            |1.764   |
-|8    |65536               |AMP            |1.623   |
-|8    |131072              |AMP            |1.404   |
+|Batch Size   |Precision      |Speedup |
+|--------------------|---------------|--------|
+|4096                |TF32           |0.873   |
+|8192                |TF32           |0.828   |
+|16384               |TF32           |0.916   |
+|32768               |TF32           |1.035   |
+|65536               |TF32           |1.127   |
+|131072              |TF32           |1.133   |
+|4096                |AMP            |1.038   |
+|8192                |AMP            |0.924   |
+|16384               |AMP            |0.925   |
+|32768               |AMP            |1.043   |
+|65536               |AMP            |1.187   |
+|131072              |AMP            |1.143   |
 
 
 </details>
 
 
@@ -1171,153 +1187,69 @@ For each configuration of parameters present in the table, the `Speedup` column
 
 
 ##### Inference performance: NVIDIA DGX-2 (16x V100 32GB)
 
 
-|GPUs |Global batch size|XLA  |Throughput - FP32 (samples/s)|Throughput - mixed precision (samples/s)|Throughput speedup (mixed precision / FP32)  | Strong scaling - FP32 | Strong scaling - mixed precision |
-|-----|----------|-----|---------------|----------------------------|---------------------------------------------|--------|--------|
-|1    |4096      |ON   |403479.95      |479051.62                   |1.19                                         | 1.00 | 1.00 |
-|1    |8192      |ON   |480491.12      |600002.95                   |1.25                                         | 1.00 | 1.00 |
-|1    |16384     |ON   |538737.44      |713203.59                   |1.32                                         | 1.00 | 1.00 |
-|1    |32768     |ON   |580958.93      |790782.1                    |1.36                                         | 1.00 | 1.00 |
-|1    |65536     |ON   |586275.07      |818038.44                   |1.40                                         | 1.00 | 1.00 |
-|1    |131072    |ON   |613524.11      |734034.26                   |1.20                                         | 1.00 | 1.00 |
-|8    |4096      |ON   |1059775.22     |909719.3                    |0.86                                         | 2.63 | 1.90 |
-|8    |8192      |ON   |1845819.99     |1752510.62                  |0.95                                         | 3.84 | 2.92 |
-|8    |16384     |ON   |2801114.77     |2898423.08                  |1.03                                         | 5.20 | 4.06 |
-|8    |32768     |ON   |3396766.27     |4102026.01                  |1.21                                         | 5.85 | 5.19 |
-|8    |65536     |ON   |3911994.39     |4725023.23                  |1.21                                         | 6.67 | 5.78 |
-|8    |131072    |ON   |4197603.74     |5413542.58                  |1.29                                         | 6.84 | 7.38 |
-|16   |4096      |ON   |1142272.86     |924525.38                   |0.81                                         | 2.83 | 1.93 |
-|16   |8192      |ON   |2068920.7      |1917814.81                  |0.93                                         | 4.31 | 3.20 |
-|16   |16384     |ON   |3091676.83     |3496153.45                  |1.13                                         | 5.74 | 4.90 |
-|16   |32768     |ON   |5132772.75     |5063615.77                  |0.99                                         | 8.84 | 6.40 |
-|16   |65536     |ON   |6553882.87     |8247475.75                  |1.26                                         | 11.18 | 10.08 |
-|16   |131072    |ON   |7555906.17     |9571965.84                  |1.27                                         | 12.32 | 13.04 |
+|   Batch Size |   XLA |   Throughput - FP32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / FP32) |
+|--------------------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|
+|                4096 |     ON |                       444532.22 |                                  541975.24 |                                          1.22 |
+|                8192 |     ON |                       505047.64 |                                  642784.48 |                                          1.27 |
+|               16384 |     ON |                       549325.54 |                                  727077.63 |                                          1.32 |
+|               32768 |     ON |                       587452.73 |                                  788606.35 |                                          1.34 |
+|               65536 |     ON |                       605187.67 |                                  832651.59 |                                          1.38 |
+|              131072 |     ON |                       599557.03 |                                  840602.90 |                                          1.40 |
 
 
 <details>
 <summary><b>
-Complete table of DGX2 inference performance results
+Complete table of DGX-2 inference performance results
 </b></summary>
 
 
-|GPUs |Global Batch Size   |XLA    |Precision      |Throughput (samples/s)           |
-|-----|--------------------|-------|---------------|-----------------------|
-|1    |4096                |OFF    |FP32           |459149.07 ± 20971.34   |
-|1    |8192                |OFF    |FP32           |488763.98 ± 15037.09   |
-|1    |16384               |OFF    |FP32           |516804.05 ± 8355.49    |
-|1    |32768               |OFF    |FP32           |534387.97 ± 4763.49    |
-|1    |65536               |OFF    |FP32           |536215.89 ± 5794.77    |
-|1    |131072              |OFF    |FP32           |538646.76 ± 6359.47    |
-|1    |4096                |OFF    |AMP            |488475.14 ± 6226.30    |
-|1    |8192                |OFF    |AMP            |632098.48 ± 27370.49   |
-|1    |16384               |OFF    |AMP            |705878.12 ± 7852.19    |
-|1    |32768               |OFF    |AMP            |739740.73 ± 6866.73    |
-|1    |65536               |OFF    |AMP            |618291.18 ± 26749.52   |
-|1    |131072              |OFF    |AMP            |544071.41 ± 19200.23   |
-|1    |4096                |ON     |FP32           |403479.95 ± 4079.19    |
-|1    |8192                |ON     |FP32           |480491.12 ± 6828.93    |
-|1    |16384               |ON     |FP32           |538737.44 ± 10932.49   |
-|1    |32768               |ON     |FP32           |580958.93 ± 9544.37    |
-|1    |65536               |ON     |FP32           |586275.07 ± 7640.59    |
-|1    |131072              |ON     |FP32           |613524.11 ± 7931.04    |
-|1    |4096                |ON     |AMP            |479051.62 ± 6076.26    |
-|1    |8192                |ON     |AMP            |600002.95 ± 16380.88   |
-|1    |16384               |ON     |AMP            |713203.59 ± 9515.25    |
-|1    |32768               |ON     |AMP            |790782.10 ± 10788.69   |
-|1    |65536               |ON     |AMP            |818038.44 ± 14132.80   |
-|1    |131072              |ON     |AMP            |734034.26 ± 34664.74   |
-|8    |4096                |OFF    |FP32           |502947.25 ± 105758.96  |
-|8    |8192                |OFF    |FP32           |809285.58 ± 112765.45  |
-|8    |16384               |OFF    |FP32           |1974085.95 ± 476616.90 |
-|8    |32768               |OFF    |FP32           |2990517.14 ± 645490.89 |
-|8    |65536               |OFF    |FP32           |3662830.22 ± 191010.11 |
-|8    |131072              |OFF    |FP32           |3978985.17 ± 142801.19 |
-|8    |4096                |OFF    |AMP            |596945.98 ± 92977.56   |
-|8    |8192                |OFF    |AMP            |730694.36 ± 67972.28   |
-|8    |16384               |OFF    |AMP            |1758189.25 ± 340547.41 |
-|8    |32768               |OFF    |AMP            |3873856.45 ± 528746.35 |
-|8    |65536               |OFF    |AMP            |4863371.50 ± 297299.34 |
-|8    |131072              |OFF    |AMP            |5134261.52 ± 473726.31 |
-|8    |4096                |ON     |FP32           |1059775.22 ± 24386.54  |
-|8    |8192                |ON     |FP32           |1845819.99 ± 250767.40 |
-|8    |16384               |ON     |FP32           |2801114.77 ± 210397.18 |
-|8    |32768               |ON     |FP32           |3396766.27 ± 221795.61 |
-|8    |65536               |ON     |FP32           |3911994.39 ± 239259.17 |
-|8    |131072              |ON     |FP32           |4197603.74 ± 158110.80 |
-|8    |4096                |ON     |AMP            |909719.30 ± 135634.13  |
-|8    |8192                |ON     |AMP            |1752510.62 ± 87042.91  |
-|8    |16384               |ON     |AMP            |2898423.08 ± 231659.28 |
-|8    |32768               |ON     |AMP            |4102026.01 ± 254242.94 |
-|8    |65536               |ON     |AMP            |4725023.23 ± 322597.53 |
-|8    |131072              |ON     |AMP            |5413542.58 ± 364633.26 |
-|16   |4096                |OFF    |FP32           |865109.29 ± 40032.58   |
-|16   |8192                |OFF    |FP32           |1565843.18 ± 305582.99 |
-|16   |16384               |OFF    |FP32           |3109303.21 ± 240314.57 |
-|16   |32768               |OFF    |FP32           |5750753.42 ± 898435.09 |
-|16   |65536               |OFF    |FP32           |6456324.48 ± 730326.61 |
-|16   |131072              |OFF    |FP32           |7415730.04 ± 434928.14 |
-|16   |4096                |OFF    |AMP            |742890.53 ± 27541.80   |
-|16   |8192                |OFF    |AMP            |1468615.49 ± 67548.46  |
-|16   |16384               |OFF    |AMP            |2591245.05 ± 394504.75 |
-|16   |32768               |OFF    |AMP            |4671719.91 ± 721705.81 |
-|16   |65536               |OFF    |AMP            |7982733.55 ± 1242742.25|
-|16   |131072              |OFF    |AMP            |9867894.78 ± 679119.71 |
-|16   |4096                |ON     |FP32           |1142272.86 ± 43154.49  |
-|16   |8192                |ON     |FP32           |2068920.70 ± 130214.35 |
-|16   |16384               |ON     |FP32           |3091676.83 ± 991449.61 |
-|16   |32768               |ON     |FP32           |5132772.75 ± 525201.10 |
-|16   |65536               |ON     |FP32           |6553882.87 ± 400638.86 |
-|16   |131072              |ON     |FP32           |7555906.17 ± 626110.02 |
-|16   |4096                |ON     |AMP            |924525.38 ± 163488.57  |
-|16   |8192                |ON     |AMP            |1917814.81 ± 59114.71  |
-|16   |16384               |ON     |AMP            |3496153.45 ± 190771.71 |
-|16   |32768               |ON     |AMP            |5063615.77 ± 1281699.58|
-|16   |65536               |ON     |AMP            |8247475.75 ± 539827.60 |
-|16   |131072              |ON     |AMP            |9571965.84 ± 764075.50 |
+|   Batch Size | XLA   | Precision   | Throughput  (samples/s)   |
+|-------------:|:------|:------------|:--------------------------|
+|         4096 | OFF   | FP32        | 459175.30 ± 23184.33      |
+|         8192 | OFF   | FP32        | 499179.20 ± 15967.26      |
+|        16384 | OFF   | FP32        | 525180.72 ± 2521.56       |
+|        32768 | OFF   | FP32        | 532042.10 ± 4020.44       |
+|        65536 | OFF   | FP32        | 534307.20 ± 7276.26       |
+|       131072 | OFF   | FP32        | 532311.44 ± 6195.16       |
+|         4096 | OFF   | AMP         | 581771.66 ± 6163.50       |
+|         8192 | OFF   | AMP         | 665048.04 ± 4607.95       |
+|        16384 | OFF   | AMP         | 716355.19 ± 7174.98       |
+|        32768 | OFF   | AMP         | 741642.61 ± 4981.04       |
+|        65536 | OFF   | AMP         | 755141.25 ± 6175.05       |
+|       131072 | OFF   | AMP         | 744459.46 ± 8183.17       |
+|         4096 | ON    | FP32        | 444532.22 ± 6239.01       |
+|         8192 | ON    | FP32        | 505047.64 ± 6543.06       |
+|        16384 | ON    | FP32        | 549325.54 ± 2841.21       |
+|        32768 | ON    | FP32        | 587452.73 ± 2366.43       |
+|        65536 | ON    | FP32        | 605187.67 ± 3740.07       |
+|       131072 | ON    | FP32        | 599557.03 ± 11811.28      |
+|         4096 | ON    | AMP         | 541975.24 ± 4441.93       |
+|         8192 | ON    | AMP         | 642784.48 ± 4721.08       |
+|        16384 | ON    | AMP         | 727077.63 ± 5332.80       |
+|        32768 | ON    | AMP         | 788606.35 ± 11705.36      |
+|        65536 | ON    | AMP         | 832651.59 ± 10401.17      |
+|       131072 | ON    | AMP         | 840602.90 ± 16358.73      |
 </details>
 
 
 <details>
 <summary><b>
-DGX A100 XLA-ON / XLA-OFF inference speedup
+DGX-2 XLA-ON / XLA-OFF inference speedup
 </b></summary>
 
 
 For each configuration of parameters present in the table, the `Speedup` column shows the speedup achieved by turning on XLA.
 
 
-|GPUs |Global Batch Size   |Precision      |Speedup |
-|-----|--------------------|---------------|--------|
-|1    |4096                |FP32           |0.879   |
-|1    |8192                |FP32           |0.983   |
-|1    |16384               |FP32           |1.042   |
-|1    |32768               |FP32           |1.087   |
-|1    |65536               |FP32           |1.093   |
-|1    |131072              |FP32           |1.139   |
-|1    |4096                |AMP            |0.981   |
-|1    |8192                |AMP            |0.949   |
-|1    |16384               |AMP            |1.010   |
-|1    |32768               |AMP            |1.069   |
-|1    |65536               |AMP            |1.323   |
-|1    |131072              |AMP            |1.349   |
-|8    |4096                |FP32           |2.107   |
-|8    |8192                |FP32           |2.281   |
-|8    |16384               |FP32           |1.419   |
-|8    |32768               |FP32           |1.136   |
-|8    |65536               |FP32           |1.068   |
-|8    |131072              |FP32           |1.055   |
-|8    |4096                |AMP            |1.524   |
-|8    |8192                |AMP            |2.398   |
-|8    |16384               |AMP            |1.649   |
-|8    |32768               |AMP            |1.059   |
-|8    |65536               |AMP            |0.972   |
-|8    |131072              |AMP            |1.054   |
-|16   |4096                |FP32           |1.320   |
-|16   |8192                |FP32           |1.321   |
-|16   |16384               |FP32           |0.994   |
-|16   |32768               |FP32           |0.893   |
-|16   |65536               |FP32           |1.015   |
-|16   |131072              |FP32           |1.019   |
-|16   |4096                |AMP            |1.244   |
-|16   |8192                |AMP            |1.306   |
-|16   |16384               |AMP            |1.349   |
-|16   |32768               |AMP            |1.084   |
-|16   |65536               |AMP            |1.033   |
-|16   |131072              |AMP            |0.970   |
+|Batch Size   |Precision      |Speedup |
+|--------------------|---------------|--------|
+|4096                |FP32           |0.968   |
+|8192                |FP32           |1.012   |
+|16384               |FP32           |1.046   |
+|32768               |FP32           |1.104   |
+|65536               |FP32           |1.133   |
+|131072              |FP32           |1.126   |
+|4096                |AMP            |0.932   |
+|8192                |AMP            |0.967   |
+|16384               |AMP            |1.384   |
+|32768               |AMP            |1.063   |
+|65536               |AMP            |1.103   |
+|131072              |AMP            |1.129   |
 </details>
 
 
 &nbsp;
@@ -1327,56 +1259,32 @@ For each configuration of parameters present in the table, the `Speedup` column
 NVIDIA A100 / DGX-2 (Ampere / Volta) inference speedup
 </b></summary>
 
 
-|GPUs |Global Batch Size   |XLA    |Precision      |Speedup |
-|-----|--------------------|-------|---------------|--------|
-|1    |4096                |OFF    |TF32/FP32      |1.275   |
-|1    |8192                |OFF    |TF32/FP32      |1.536   |
-|1    |16384               |OFF    |TF32/FP32      |1.555   |
-|1    |32768               |OFF    |TF32/FP32      |1.539   |
-|1    |65536               |OFF    |TF32/FP32      |1.447   |
-|1    |131072              |OFF    |TF32/FP32      |1.197   |
-|1    |4096                |OFF    |AMP            |1.057   |
-|1    |8192                |OFF    |AMP            |1.232   |
-|1    |16384               |OFF    |AMP            |1.321   |
-|1    |32768               |OFF    |AMP            |1.339   |
-|1    |65536               |OFF    |AMP            |1.158   |
-|1    |131072              |OFF    |AMP            |1.124   |
-|1    |4096                |ON     |TF32/FP32      |1.393   |
-|1    |8192                |ON     |TF32/FP32      |1.396   |
-|1    |16384               |ON     |TF32/FP32      |1.464   |
-|1    |32768               |ON     |TF32/FP32      |1.472   |
-|1    |65536               |ON     |TF32/FP32      |1.567   |
-|1    |131072              |ON     |TF32/FP32      |1.497   |
-|1    |4096                |ON     |AMP            |1.118   |
-|1    |8192                |ON     |AMP            |1.265   |
-|1    |16384               |ON     |AMP            |1.291   |
-|1    |32768               |ON     |AMP            |1.310   |
-|1    |65536               |ON     |AMP            |1.322   |
-|1    |131072              |ON     |AMP            |1.051   |
-|8    |4096                |OFF    |TF32/FP32      |1.521   |
-|8    |8192                |OFF    |TF32/FP32      |1.725   |
-|8    |16384               |OFF    |TF32/FP32      |1.156   |
-|8    |32768               |OFF    |TF32/FP32      |1.189   |
-|8    |65536               |OFF    |TF32/FP32      |1.308   |
-|8    |131072              |OFF    |TF32/FP32      |1.493   |
-|8    |4096                |OFF    |AMP            |1.077   |
-|8    |8192                |OFF    |AMP            |1.639   |
-|8    |16384               |OFF    |AMP            |1.116   |
-|8    |32768               |OFF    |AMP            |0.843   |
-|8    |65536               |OFF    |AMP            |0.997   |
-|8    |131072              |OFF    |AMP            |1.249   |
-|8    |4096                |ON     |TF32/FP32      |1.066   |
-|8    |8192                |ON     |TF32/FP32      |1.217   |
-|8    |16384               |ON     |TF32/FP32      |1.428   |
-|8    |32768               |ON     |TF32/FP32      |1.613   |
-|8    |65536               |ON     |TF32/FP32      |1.722   |
-|8    |131072              |ON     |TF32/FP32      |1.810   |
-|8    |4096                |ON     |AMP            |1.029   |
-|8    |8192                |ON     |AMP            |1.076   |
-|8    |16384               |ON     |AMP            |1.140   |
-|8    |32768               |ON     |AMP            |1.405   |
-|8    |65536               |ON     |AMP            |1.666   |
-|8    |131072              |ON     |AMP            |1.663   |
+|   Batch Size | XLA   | Precision   |   Speedup |
+|-------------:|:------|:------------|----------:|
+|         4096 | OFF   | TF32/FP32   |      1.54 |
+|         8192 | OFF   | TF32/FP32   |      1.75 |
+|        16384 | OFF   | TF32/FP32   |      1.79 |
+|        32768 | OFF   | TF32/FP32   |      1.77 |
+|        65536 | OFF   | TF32/FP32   |      1.80 |
+|       131072 | OFF   | TF32/FP32   |      1.81 |
+|         4096 | OFF   | AMP         |      1.11 |
+|         8192 | OFF   | AMP         |      1.38 |
+|        16384 | OFF   | AMP         |      1.59 |
+|        32768 | OFF   | AMP         |      1.64 |
+|        65536 | OFF   | AMP         |      1.71 |
+|       131072 | OFF   | AMP         |      1.74 |
+|         4096 | ON    | TF32/FP32   |      1.39 |
+|         8192 | ON    | TF32/FP32   |      1.43 |
+|        16384 | ON    | TF32/FP32   |      1.56 |
+|        32768 | ON    | TF32/FP32   |      1.66 |
+|        65536 | ON    | TF32/FP32   |      1.79 |
+|       131072 | ON    | TF32/FP32   |      1.83 |
+|         4096 | ON    | AMP         |      1.24 |
+|         8192 | ON    | AMP         |      1.32 |
+|        16384 | ON    | AMP         |      1.45 |
+|        32768 | ON    | AMP         |      1.61 |
+|        65536 | ON    | AMP         |      1.74 |
+|       131072 | ON    | AMP         |      1.76 |
 </details>
 
 
 &nbsp;
@@ -1388,10 +1296,12 @@ NVIDIA A100 / DGX-2 (Ampere / Volta) inference speedup
 May 2022
 - Initial release
 
 
-### Known issues
+November 2022
+- Moved batching and padding operations to preprocessing
+- Added support for prebatched samples during dataloading
+- Reduced throughput variance (previously appearing mainly during inference)
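The two prebatching entries describe moving work out of the training loop: variable-length user histories are padded and grouped into fixed-size batches once, during preprocessing, so the dataloader only streams ready-made chunks. A minimal numpy sketch of the idea (function and names are hypothetical, not the repository's preprocessing code):

```python
import numpy as np

def prebatch(histories, prebatch_size, pad_to):
    """Pad variable-length sequences and group them into fixed-size
    batches at preprocessing time; the dataloader can then read these
    chunks directly instead of batching and padding per step."""
    batches = []
    for start in range(0, len(histories), prebatch_size):
        chunk = histories[start:start + prebatch_size]
        padded = np.zeros((len(chunk), pad_to), dtype=np.int64)
        for row, seq in enumerate(chunk):
            padded[row, :len(seq)] = seq  # right-pad with zeros
        batches.append(padded)
    return batches

histories = [[3, 7], [1, 2, 5], [9], [4, 4, 4], [8]]
batches = prebatch(histories, prebatch_size=2, pad_to=4)
# 5 samples at prebatch_size=2 -> 3 stored batches, the last one partial
```

The new `--prebatch_train_size` / `--prebatch_test_size` flags then only need to tell the dataloader which chunk size was used during preprocessing.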
 
 
-- While benchmarking inference on a single GPU, sometimes throughput drops drastically in the middle of the epoch and remains low until the end of the epoch.
-- On a multi-GPU setup, the summary of throughput (in the last line of the logfile) is lower than it would result from each step`s throughput (sample/s). It is probably the case when a single GPU is slower than the one on the logging node. In this case, the overhead for synchronization before the final throughput calculation is higher than usual.
+### Known issues
 - The SIM model results are non-deterministic, even using the same random seed. The reason for this non-determinism is the [tf.math.unsorted_segment_sum](https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum) operation called within an optimization step. Its influence depends on the categorical data distribution within a batch, and the issue is more severe for momentum-based optimizers. A potential solution is to use a deterministic version of this op, which allows perfect reproduction but makes training up to six times slower.
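The root cause is that floating-point addition is not associative, so an accumulation whose order varies between runs (as with GPU atomics inside `tf.math.unsorted_segment_sum`) can return different sums for identical inputs. A toy numpy illustration of the order dependence (not the TF op itself):

```python
import numpy as np

# The same three float32 values summed in two different orders:
vals = np.array([1.0, 1e8, -1e8], dtype=np.float32)

left_to_right = np.float32(0.0)
for v in vals:                  # (1 + 1e8) rounds the 1 away first
    left_to_right += v

reordered = np.float32(0.0)
for v in vals[[1, 2, 0]]:       # (1e8 - 1e8) cancels first, keeping the 1
    reordered += v

# left_to_right == 0.0 while reordered == 1.0
```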
 
 
 
 

+ 43 - 20
TensorFlow2/Recommendation/SIM/main.py

@@ -93,8 +93,8 @@ def init_logger(results_dir, filename):
 
 
 
 
 # In the future, select one of available dataloaders there (tfrecord, csv, etc...)
-def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length, prefetch_size, repeat_count=0,
-                      drop_remainder=False, amp=False, disable_cache=False):
+def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length, prefetch_size, num_parallel_calls=None, repeat_count=0,
+                      drop_remainder=False, amp=False, disable_cache=False, prebatch_size=0):
     return get_dataloader_tfrecord(
         paths,
         feature_spec=feature_spec,
@@ -105,7 +105,9 @@ def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length
         drop_remainder=drop_remainder,
         repeat_count=repeat_count,
         disable_cache=disable_cache,
-        prefetch_buffer_size=prefetch_size
+        prefetch_buffer_size=prefetch_size,
+        num_parallel_calls=num_parallel_calls,
+        prebatch_size=prebatch_size
     )


@@ -243,10 +245,24 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
         local_targets.append(targets)
         local_total_losses.append(loss_dict["total_loss"])
 
 
-    # concat all local variables into a single tensor
-    logits = tf.concat(local_logits, 0)
-    targets = tf.concat(local_targets, 0)
-    total_losses = tf.concat(local_total_losses, 0)
+    locals = [local_logits, local_targets, local_total_losses]
+    for i, local in enumerate(locals):
+
+        # wrap empty lists in tensor to allow tf.concat
+        if len(local) == 0:
+            local = tf.constant(local)
+
+        # concat all local variables into a single tensor
+        local = tf.concat(local, 0)
+
+        # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
+        # reshape it for hvd.allgather to work
+        if len(local.shape) == 0:
+            local = tf.reshape(local, -1)
+
+        locals[i] = local
+    
+    logits, targets, total_losses = locals
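The reshape guard in the loop above exists because concatenating a single scalar leaves a rank-0 tensor, and collective gathers such as `hvd.allgather` expect at least rank 1. The same shape pattern, shown with numpy as an analogy (not the TF code path):

```python
import numpy as np

scalar = np.float32(3.5)             # rank-0 value, shape ()
as_vector = np.reshape(scalar, -1)   # rank-1, shape (1,)

# rank-1 arrays can be concatenated/gathered safely
gathered = np.concatenate([as_vector, as_vector])
```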
 
 
     if distributed:
         # gather from all nodes
@@ -455,6 +471,9 @@ def inference(model, data_iterator, benchmark, performance_calculator):
 @click.option(
     "--global_batch_size", default=131072, help="Batch size used to train/eval the model.", type=int
 )
+@click.option(
+    "--num_parallel_calls", default=None, help="Parallelism level for tf.data API. If None, heuristic based on number of CPUs and number of GPUs will be used."
+)
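The help string above only promises a heuristic based on CPU and GPU counts when the flag stays `None`. One plausible shape for such a heuristic, offered purely as an illustration (this is not the repository's actual implementation):

```python
import multiprocessing

def default_num_parallel_calls(num_gpus: int) -> int:
    # Hypothetical heuristic: divide the available CPU cores evenly among
    # per-GPU dataloader instances, keeping at least one tf.data worker.
    cpu_count = multiprocessing.cpu_count()
    return max(1, cpu_count // num_gpus)
```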
 @click.option(
     "--epochs", default=3, help="Train for the following number of epochs.", type=int
 )
@@ -521,10 +540,8 @@ def inference(model, data_iterator, benchmark, performance_calculator):
 )
 @click.option(
     "--prefetch_train_size",
-    default=-1,
+    default=10,
     help="Number of batches to prefetch in training. "
-    "If == 0: No prefetching is done. "
-    "If < 0: Prefetch size is set to train_dataset_size // global_batch_size. ",
 )
 @click.option(
     "--prefetch_test_size",
@@ -532,9 +549,14 @@ def inference(model, data_iterator, benchmark, performance_calculator):
     help="Number of batches to prefetch in testing"
 )
 @click.option(
-    "--train_dataset_size",
-    default=11796480,
-    help="Number of train samples. Used to set prefetching size (see --prefetch_train_size for more information."
+    "--prebatch_train_size",
+    default=0,
+    help="Information about batch size applied during preprocessing to train dataset"
+)
[email protected](
+    "--prebatch_test_size",
+    default=0,
+    help="Information about batch size applied during preprocessing to test dataset"
 )
 )
 def main(
 def main(
         mode: str,
         mode: str,
@@ -554,6 +576,7 @@ def main(
         weight_decay: float,
         weight_decay: float,
         embedding_dim: int,
         embedding_dim: int,
         global_batch_size: int,
         global_batch_size: int,
+        num_parallel_calls: int,
         epochs: int,
         epochs: int,
         disable_cache: bool,
         disable_cache: bool,
         drop_remainder: bool,
         drop_remainder: bool,
@@ -570,7 +593,8 @@ def main(
         intra_op_parallelism: int,
         intra_op_parallelism: int,
         prefetch_train_size: int,
         prefetch_train_size: int,
         prefetch_test_size: int,
         prefetch_test_size: int,
-        train_dataset_size: int
+        prebatch_train_size: int,
+        prebatch_test_size: int
 ):
 ):
     hvd.init()
     hvd.init()
 
 
@@ -636,20 +660,19 @@ def main(
     # since each tfrecord file must include all of the features, it is enough to read first chunk for each split. 
     # since each tfrecord file must include all of the features, it is enough to read first chunk for each split. 
     train_files = [dataset_dir / file for file in feature_spec.source_spec[TRAIN_MAPPING][0][FILES_SELECTOR]]
     train_files = [dataset_dir / file for file in feature_spec.source_spec[TRAIN_MAPPING][0][FILES_SELECTOR]]
 
 
-    if prefetch_train_size < 0:
-        prefetch_train_size = train_dataset_size // global_batch_size
-
     data_iterator_train = get_data_iterator(
     data_iterator_train = get_data_iterator(
         train_files, feature_spec, batch_size, num_gpus, long_seq_length,
         train_files, feature_spec, batch_size, num_gpus, long_seq_length,
         repeat_count=repeat_count, drop_remainder=drop_remainder,
         repeat_count=repeat_count, drop_remainder=drop_remainder,
-        amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_train_size
+        amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_train_size,
+        num_parallel_calls=num_parallel_calls, prebatch_size=prebatch_train_size
     )
     )
 
 
     if mode == "train":
     if mode == "train":
         test_files = [dataset_dir / file for file in feature_spec.source_spec[TEST_MAPPING][0][FILES_SELECTOR]]
         test_files = [dataset_dir / file for file in feature_spec.source_spec[TEST_MAPPING][0][FILES_SELECTOR]]
         data_iterator_test = get_data_iterator(
         data_iterator_test = get_data_iterator(
             test_files, feature_spec, batch_size, num_gpus, long_seq_length,
             test_files, feature_spec, batch_size, num_gpus, long_seq_length,
-            amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_test_size
+            amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_test_size, num_parallel_calls=num_parallel_calls,
+            prebatch_size=prebatch_test_size
         )
         )
     else:
     else:
         data_iterator_test = []  # otherwise not used
         data_iterator_test = []  # otherwise not used
@@ -689,4 +712,4 @@ def main(
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    main()
+    main()

+ 33 - 0
TensorFlow2/Recommendation/SIM/preprocessing/ops.py

@@ -73,6 +73,39 @@ def _preserve_data(offsets, values, new_values):
            new_values[i] = values[rowid]


+@numba.cuda.jit
+def _slice_rjust(max_elements, offsets, elements, new_offsets, new_elements):
+    rowid = numba.cuda.grid(1)
+    if rowid < new_offsets.size - 1:
+        row_size = min(offsets[rowid + 1] - offsets[rowid], max_elements)
+        offset = offsets[rowid + 1] - row_size
+        new_start = new_offsets[rowid + 1] - row_size
+
+        for i in range(row_size):
+            new_elements[new_start + i] = elements[offset + i]
+
+
+def slice_and_pad_left(seq_col, max_elements, pad_value=0):
+    c = seq_col._column
+    offsets = c.offsets.values
+    elements = c.elements.values
+
+    threads = THREADS
+    blocks = (offsets.size + threads - 1) // threads
+
+    new_offsets = cupy.arange(offsets.size, dtype=offsets.dtype) * max_elements
+
+    new_elements = cupy.full(
+        new_offsets[-1].item(), fill_value=pad_value, dtype=elements.dtype
+    )
+    _slice_rjust[blocks, threads](
+        max_elements, offsets, elements, new_offsets, new_elements
+    )
+
+    new_col = nvt_build_list_column(new_elements, new_offsets)
+    return new_col
+
+
class ExplodeSequence:
    """
    For each row create a new one with a subsequence of the original list columns.

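The new `_slice_rjust` kernel and `slice_and_pad_left` wrapper above truncate each variable-length history to its most recent `max_elements` entries and left-pad it with `pad_value` to a fixed width. A plain-Python sketch of the same semantics (a hypothetical reference helper, not the cuDF/numba implementation from the commit):

```python
# Pure-Python reference for what the GPU `slice_and_pad_left` op computes:
# keep the last `max_elements` items of each row, left-pad with `pad_value`.
def slice_and_pad_left_ref(rows, max_elements, pad_value=0):
    out = []
    for row in rows:
        kept = row[-max_elements:]  # keep the most recent history entries
        out.append([pad_value] * (max_elements - len(kept)) + kept)
    return out

print(slice_and_pad_left_ref([[1, 2, 3, 4], [5]], max_elements=3))
# [[2, 3, 4], [0, 0, 5]]
```

Right-justifying the kept elements is what lets the dataloader drop its old `_pad_ragged_infront` double-reverse trick and read fixed-shape tensors directly.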
+ 108 - 68
TensorFlow2/Recommendation/SIM/preprocessing/parquet_to_tfrecord.py

@@ -21,9 +21,11 @@ from functools import partial

import click
import pandas as pd
+import numpy as np
import tensorflow as tf

from sim.data.feature_spec import FeatureSpec
+from sim.data.defaults import TRAIN_MAPPING, TEST_MAPPING, REMAINDER_FILENAME, FILES_SELECTOR

# Docker image sets it to "python" for NVTabular purposes (bugfix), which slows down the script 20x
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
@@ -34,46 +36,31 @@ logging.basicConfig(
    format="[%(asctime)s] %(levelname)s: %(message)s",
)

-
-def _int64_feature(value, islist=False):
-    """Returns an int64_list from a bool / enum / int / uint."""
-    if not islist:
-        value = [value]
-    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
-
-
-def process_chunk(df, sequential_data_start):
-    feature_values_lists = [df.iloc[:, i].values for i in range(sequential_data_start)]
-
-    for i in range(sequential_data_start, df.shape[1]):
-        values = df.iloc[:, i].values.tolist()
-        feature_values_lists.append(values)
-
-    return zip(*feature_values_lists)
-
-
-def prepare_record(sample, all_feature_names, sequential_data_start):
-
+def prepare_record(sample, all_feature_names, sequential_data_start, prebatch):
    feature = {}
-    for idx, (f_name, data) in enumerate(zip(all_feature_names, sample)):
-        islist = idx >= sequential_data_start
-        feature[f_name] = _int64_feature(data, islist)
+    for idx, (f_name, data) in enumerate(zip(all_feature_names, sample.values())):
+        if idx >= sequential_data_start:
+            if prebatch:
+                data = np.array(data).flatten()
+        else:
+            if not prebatch:
+                data = [data]

-    record_bytes = tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()
-    return record_bytes
+        feature[f_name] = tf.train.Feature(int64_list=tf.train.Int64List(value=data))

+    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

-def create_default_feature_spec(user_features_cardinalities, item_features_cardinalities,
-                                max_seq_len, tfrecord_output_dir, train_output_file, test_output_file):
+def save_records(output_path, records, base_output_path, feature_spec, mapping):

-    train_output = tfrecord_output_dir / train_output_file
-    test_output = tfrecord_output_dir / test_output_file
+    with tf.io.TFRecordWriter(str(output_path)) as file_writer:
+        for record_bytes in records:
+            file_writer.write(record_bytes)

-    f_spec = FeatureSpec.get_default_feature_spec(user_features_cardinalities, item_features_cardinalities,
-                                                  max_seq_len, train_output, test_output)
+    feature_spec.source_spec[mapping][0][FILES_SELECTOR].append(
+        str(output_path.relative_to(base_output_path))
+    )

-    save_path = tfrecord_output_dir / 'feature_spec.yaml'
-    f_spec.to_yaml(save_path)
+    logging.info(f'Created: {output_path}')


@click.command()
@@ -91,14 +78,15 @@ def create_default_feature_spec(user_features_cardi
)
@click.option(
    "--number_of_user_features",
-    required=True,
-    help="number of user specific features.",
+    default=1,
+    help="number of user specific features. Default is 1 for amazon books dataset (user_id).",
    type=int
)
@click.option(
    "--max_seq_len",
    default=100,
-    help="maximum possible length of history. (Entries will be padded to that length later)."
+    help="maximum possible length of history. (Entries will be padded to that length later).",
+    type=int
)
@click.option(
    "--n_proc",
@@ -109,30 +97,57 @@ def create_default_feature_spec(user_features_cardi
@click.option(
    "--train_split_dir",
    default='train',
-    help="name of directory within amazon dataset directory containing train data."
+    help="Name of directory within amazon dataset directory containing train data.",
+    type=str
)
@click.option(
    "--test_split_dir",
    default='test',
-    help="name of directory within amazon dataset directory containing test data."
+    help="Name of directory within amazon dataset directory containing test data.",
+    type=str,
)
@click.option(
    "--metadata_file",
    default='metadata.json',
-    help="name of metadata file within amazon dataset directory (containing feature cardinalities)."
+    help="Name of metadata file within amazon dataset directory (containing feature cardinalities).",
+    type=str
)
@click.option(
-    "--train_output_file",
-    default='train.tfrecord',
-    help='name of train file within output directory.',
+    "--train_output_dir",
+    default='train',
+    help="Name of train directory within output directory.",
    type=str
)
@click.option(
-    "--test_output_file",
-    default='test.tfrecord',
-    help='name of test file within output directory.',
+    "--test_output_dir",
+    default='test',
+    help='Name of test directory within output directory.',
    type=str
)
+@click.option(
+    "--train_parts",
+    default=8,
+    help="Number of output train files.",
+    type=int
+)
+@click.option(
+    "--test_parts",
+    default=4,
+    help="Number of output test files.",
+    type=int
+)
+@click.option(
+    "--prebatch_train_size",
+    default=0,
+    help='Apply batching to data in preprocessing. If prebatch_size == 0, no prebatching is done.',
+    type=int
+)
+@click.option(
+    "--prebatch_test_size",
+    default=0,
+    help='Apply batching to data in preprocessing. If prebatch_size == 0, no prebatching is done.',
+    type=int
+)
def main(
        amazon_dataset_path: str,
        tfrecord_output_dir: str,
@@ -142,8 +157,12 @@ def main(
        train_split_dir: str,
        test_split_dir: str,
        metadata_file: str,
-        train_output_file: str,
-        test_output_file: str
+        train_output_dir: str,
+        test_output_dir: str,
+        train_parts: int,
+        test_parts: int,
+        prebatch_train_size: int,
+        prebatch_test_size: int
):
    """
    read_parquet()
@@ -160,11 +179,12 @@ def main(
        amazon_dataset_path / test_split_dir
    ]

-    os.makedirs(tfrecord_output_dir, exist_ok=True)
    output_splits = [
-        tfrecord_output_dir / train_output_file,
-        tfrecord_output_dir / test_output_file
+        tfrecord_output_dir / train_output_dir,
+        tfrecord_output_dir / test_output_dir
    ]
+    for split_dir in output_splits:
+        os.makedirs(split_dir, exist_ok=True)

    with open(amazon_dataset_path / metadata_file, 'r') as file:
        metadata = json.load(file)
@@ -176,35 +196,55 @@ def main(
    user_features_cardinalities = feature_cardinalities[:number_of_user_features]
    item_features_cardinalities = feature_cardinalities[number_of_user_features:]

-    create_default_feature_spec(user_features_cardinalities, item_features_cardinalities, max_seq_len,
-                                tfrecord_output_dir, train_output_file, test_output_file)
+    feature_spec = FeatureSpec.get_default_feature_spec(user_features_cardinalities, item_features_cardinalities, max_seq_len)

    number_of_item_features = len(item_features_cardinalities)
    sequential_data_start = 1 + number_of_user_features + number_of_item_features
    all_feature_names = FeatureSpec.get_default_features_names(number_of_user_features, number_of_item_features)
-    prepare_record_function = partial(prepare_record, all_feature_names=all_feature_names,
-                                      sequential_data_start=sequential_data_start)
+
+    prebatch_per_split = [prebatch_train_size, prebatch_test_size]
+    parts_per_split = [train_parts, test_parts]
+    mappings = [TRAIN_MAPPING, TEST_MAPPING]
+
+    for mapping, input_dir, output_dir, output_parts, prebatch_size in zip(mappings, input_splits, output_splits, parts_per_split, prebatch_per_split):
+
+        prebatch = prebatch_size > 0
+        prepare_record_function = partial(prepare_record, all_feature_names=all_feature_names,
+                                        sequential_data_start=sequential_data_start, prebatch=prebatch)
+        save_records_function = partial(save_records, base_output_path=tfrecord_output_dir, feature_spec=feature_spec, mapping=mapping)
+
+        logging.info(f"Started conversion, will output to {output_dir}")
+
+        df = pd.read_parquet(input_dir, engine='pyarrow')
+
+        logging.info("Parquet loaded")
+
+        if prebatch:
+            df['batch_index'] = df.index // prebatch_size
+            df = df.groupby('batch_index').agg(list)
+            if len(df.iloc[-1, 0]) < prebatch_size:
+                remainder = df[-1:].to_dict('records')[0]
+                remainder = prepare_record_function(remainder)

-    for input_dir, output_file in zip(input_splits, output_splits):
+                df = df[:-1]

-        files = input_dir.glob("part.*.parquet")
-        def num_order(p): return int(p.name.split(".")[1])
-        paths = sorted(files, key=num_order)
+            logging.info("Prebatching applied")

-        logging.info(f"Started conversion, will output to {output_file}")
+        df = df.to_dict('records')
+        with multiprocessing.Pool(n_proc) as pool:
+            records = pool.map(prepare_record_function, df)

-        with tf.io.TFRecordWriter(str(output_file)) as file_writer:
-            with multiprocessing.Pool(n_proc) as pool:
-                for path in paths:
-                    df = pd.read_parquet(path)
+        logging.info("Records created")

-                    zipped_data = process_chunk(df, sequential_data_start)
+        records = np.array_split(records, output_parts)
+        for i, records_part in enumerate(records):
+            if len(records_part) > 0:
+                save_records_function(output_dir / f'part_{i}.tfrecord', records_part)

-                    records = pool.map(prepare_record_function, zipped_data)
-                    for record_bytes in records:
-                        file_writer.write(record_bytes)
+        if prebatch:
+            save_records_function(output_dir / REMAINDER_FILENAME, [remainder])

-                    logging.info(f"Processed {path}")
+    feature_spec.to_yaml(tfrecord_output_dir / 'feature_spec.yaml')


if __name__ == "__main__":

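The prebatching step added above groups consecutive rows into fixed-size chunks and splits off an incomplete tail as a separate remainder record (written as `remainder.tfrecord`). A minimal pure-Python sketch of that chunking logic, operating on a flat list of samples instead of a DataFrame (the `prebatch` helper name is illustrative, not part of the commit):

```python
# Chunk `samples` into groups of `prebatch_size`; a short trailing chunk
# becomes the remainder, mirroring the groupby/agg(list) logic in the diff.
def prebatch(samples, prebatch_size):
    batches = [samples[i:i + prebatch_size]
               for i in range(0, len(samples), prebatch_size)]
    remainder = None
    if batches and len(batches[-1]) < prebatch_size:
        remainder = batches.pop()  # incomplete tail is stored separately
    return batches, remainder

print(prebatch(list(range(5)), 2))
# ([[0, 1], [2, 3]], [4])
```

Storing the remainder separately keeps every regular record at exactly `prebatch_size` samples, so the dataloader can parse them with fixed-length features.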
+ 9 - 7
TensorFlow2/Recommendation/SIM/preprocessing/sim_preprocessing.py

@@ -25,7 +25,7 @@ import dask_cudf
import rmm

from preprocessing.io import load_metadata, load_review_data, save_metadata
-from preprocessing.ops import ExplodeSequence, add_negative_sequence, list_slice
+from preprocessing.ops import ExplodeSequence, add_negative_sequence, list_slice, slice_and_pad_left

DASK_TRAIN_DATASET_CHUNKSIZE = 15_000
TRAIN_DATA_DIR = "train"
@@ -179,11 +179,12 @@ def add_negative_sampling(df: cudf.DataFrame, sampling_df: cudf.DataFrame) -> cu
    return df


-def slice_sequences(df: cudf.DataFrame, max_elements: int) -> cudf.DataFrame:
-    df["item_sequence"] = list_slice(df["item_sequence"], -max_elements)
-    df["cat_sequence"] = list_slice(df["cat_sequence"], -max_elements)
-    df["neg_item_sequence"] = list_slice(df["neg_item_sequence"], -max_elements)
-    df["neg_cat_sequence"] = list_slice(df["neg_cat_sequence"], -max_elements)
+def pad_with_zeros(df: cudf.DataFrame, max_elements: int) -> cudf.DataFrame:
+    df["item_sequence"] = slice_and_pad_left(df["item_sequence"], max_elements)
+    df["cat_sequence"] = slice_and_pad_left(df["cat_sequence"], max_elements)
+    df["neg_item_sequence"] = slice_and_pad_left(df["neg_item_sequence"], max_elements)
+    df["neg_cat_sequence"] = slice_and_pad_left(df["neg_cat_sequence"], max_elements)
+
    return df


@@ -202,6 +203,7 @@ def create_train_dataset(

        df = explode_sequence(df, min_elements, max_elements)
        df = add_negative_sampling(df, sampling_df)
+        df = pad_with_zeros(df, max_elements)
        df = df.sort_values(by=["uid"])
        df.reset_index(drop=True, inplace=True)
        df = df[list(OUTPUT_META)]
@@ -222,7 +224,7 @@ def create_test_dataset(
    output_path: str,
) -> None:
    df = add_negative_sampling(df, sampling_df)
-    df = slice_sequences(df, max_elements)
+    df = pad_with_zeros(df, max_elements)
    df = df.sort_values(by=["uid"])
    df.reset_index(drop=True, inplace=True)
    df = df[list(OUTPUT_META)]

+ 5 - 1
TensorFlow2/Recommendation/SIM/scripts/run_model.sh

@@ -30,6 +30,8 @@ Usage: bash scripts/run_model.sh
--log_filename          Name of output log file within results_dir. Default: log.json.
--save_checkpoint_path  Path to output checkpoint after training.
--load_checkpoint_path  Path from which to restore checkpoint for inference or suspend/resume training.
+--prebatch_train_size
+--prebatch_test_size
EOF
}

@@ -82,10 +84,12 @@ results_dir_option=$(get_option_or_use_default --results_dir $results_dir)
log_filename_option=$(get_option_or_use_default --log_filename $log_filename)
save_checkpoint_path_option=$(get_option_or_use_default --save_checkpoint_path $save_checkpoint_path)
load_checkpoint_path_option=$(get_option_or_use_default --load_checkpoint_path $load_checkpoint_path)
+prebatch_train_size_option=$(get_option_or_use_default --prebatch_train_size $prebatch_train_size)
+prebatch_test_size_option=$(get_option_or_use_default --prebatch_test_size $prebatch_test_size)

command="mpiexec --allow-run-as-root --bind-to socket -np ${gpus} python main.py --dataset_dir ${data_path} --drop_remainder ${epochs_option} 
${xla_arg} ${amp_arg} ${benchmark_arg} ${mode_option} ${benchmark_steps_option} ${batch_size_option} ${results_dir_option} ${log_filename_option}
-${save_checkpoint_path_option} ${load_checkpoint_path_option}"
+${save_checkpoint_path_option} ${load_checkpoint_path_option} ${prebatch_train_size_option} ${prebatch_test_size_option}"

printf "[INFO] Running:\n%s\n" "${command}"
# run

+ 62 - 21
TensorFlow2/Recommendation/SIM/sim/data/dataloader.py

@@ -18,12 +18,7 @@ from functools import partial
import tensorflow as tf

from sim.data.defaults import (DIMENSIONS_SELECTOR, LABEL_CHANNEL, NEGATIVE_HISTORY_CHANNEL, POSITIVE_HISTORY_CHANNEL,
-                               TARGET_ITEM_FEATURES_CHANNEL, USER_FEATURES_CHANNEL)
-
-
-def _pad_ragged_infront(x, pad_length):
-    x = tf.reverse(x, axis=[1])
-    return tf.reverse(x.to_tensor(shape=(None, pad_length)), axis=[1])
+                               TARGET_ITEM_FEATURES_CHANNEL, USER_FEATURES_CHANNEL, REMAINDER_FILENAME)


def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
@@ -32,20 +27,20 @@ def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
    features = feature_spec.feature_spec

    user_features = {
-        f_name: sample[f_name] for f_name in channel_spec[USER_FEATURES_CHANNEL]
+        f_name: tf.reshape(sample[f_name], [-1]) for f_name in channel_spec[USER_FEATURES_CHANNEL]
    }

    target_item_features = {
-        f_name: sample[f_name] for f_name in channel_spec[TARGET_ITEM_FEATURES_CHANNEL]
+        f_name: tf.reshape(sample[f_name], [-1]) for f_name in channel_spec[TARGET_ITEM_FEATURES_CHANNEL]
    }

    padded_positive = {
-        f_name: _pad_ragged_infront(sample[f_name], features[f_name][DIMENSIONS_SELECTOR][0])
+        f_name: tf.reshape(sample[f_name], [-1, features[f_name][DIMENSIONS_SELECTOR][0]])
        for f_name in channel_spec[POSITIVE_HISTORY_CHANNEL]
    }

    padded_negative = {
-        f_name: _pad_ragged_infront(sample[f_name], features[f_name][DIMENSIONS_SELECTOR][0])
+        f_name: tf.reshape(sample[f_name], [-1, features[f_name][DIMENSIONS_SELECTOR][0]])
        for f_name in channel_spec[NEGATIVE_HISTORY_CHANNEL]
    }

@@ -70,7 +65,7 @@ def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
    short_sequence_mask = history_mask[:, long_seq_length:]

    label_name = channel_spec[LABEL_CHANNEL][0]
-    target = sample[label_name]
+    target = tf.reshape(sample[label_name], [-1])

    return {
        "user_features": user_features,
@@ -84,6 +79,14 @@ def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
    }, target


+def split_prebatch(sample, split_into):
+    res = {}
+    for f_name, val in sample.items():
+        res[f_name] = tf.reshape(val, [split_into, -1])
+
+    return tf.data.Dataset.from_tensor_slices(res)
+
+
def get_dataloader_tfrecord(
    file_paths,
    feature_spec,
@@ -94,36 +97,74 @@ def get_dataloader_tfrecord(
    drop_remainder=False,
    repeat_count=0,
    prefetch_buffer_size=90,
-    disable_cache=False):
+    num_parallel_calls=None,
+    disable_cache=False,
+    prebatch_size=0
+    ):

    features = feature_spec.feature_spec
+    prebatched = prebatch_size > 0
+
+    remainder_file = None
+    if file_paths[-1].name == REMAINDER_FILENAME:
+        remainder_file = file_paths[-1:]
+        file_paths = file_paths[:-1]

    tf_feature_spec = {}
    for name, feature in features.items():
        dimensions = feature.get(DIMENSIONS_SELECTOR)
        if dimensions is None:
-            tf_feature_spec[name] = tf.io.FixedLenFeature([], tf.int64)
-        else:
-            tf_feature_spec[name] = tf.io.RaggedFeature(tf.int64)
+            dimensions = [1] if prebatched else []
+
+        if prebatched:
+            dimensions = dimensions.copy()
+            dimensions[0] *= prebatch_size

-    num_cpus = multiprocessing.cpu_count()
+        tf_feature_spec[name] = tf.io.FixedLenFeature(dimensions, tf.int64)

-    dataset = tf.data.TFRecordDataset(file_paths)
+    if num_parallel_calls is None:
+        num_cpus = multiprocessing.cpu_count()
+        num_parallel_calls = 4 * num_cpus // num_gpus
+
+    dataset = tf.data.TFRecordDataset(file_paths, num_parallel_reads=num_parallel_calls)

    dataset = dataset.shard(num_gpus, id)

-    dataset = dataset.apply(
-        tf.data.experimental.dense_to_ragged_batch(batch_size, drop_remainder=drop_remainder)
+    splitting_function = None
+    if prebatched:
+        if batch_size >= prebatch_size:
+            batch_size = batch_size // prebatch_size
+        else:
+            split_into = prebatch_size // batch_size
+            splitting_function = partial(split_prebatch, split_into=split_into)
+            batch_size = 1
+
+    dataset = dataset.batch(
+        batch_size, drop_remainder=drop_remainder, num_parallel_calls=num_parallel_calls
    )

    dataset = dataset.map(
        map_func=partial(tf.io.parse_example, features=tf_feature_spec),
-        num_parallel_calls=num_cpus//num_gpus
+        num_parallel_calls=num_parallel_calls
    )

+    if splitting_function is not None:
+        dataset = dataset.flat_map(splitting_function)
+
+    if not drop_remainder and id == 0 and remainder_file is not None:
+        tf_feature_spec_remainder = {
+            name: tf.io.RaggedFeature(tf.int64) for name in tf_feature_spec
+        }
+        remainder = tf.data.TFRecordDataset(remainder_file)
+        remainder = remainder.map(
+            map_func=partial(tf.io.parse_example, features=tf_feature_spec_remainder)
+        )
+
+        dataset = dataset.concatenate(remainder)
+
    dataset = dataset.map(
        map_func=partial(_remap_column_values_tfrecord, feature_spec=feature_spec, long_seq_length=long_seq_length),
-        num_parallel_calls=num_cpus//num_gpus
+        num_parallel_calls=num_parallel_calls
    )

    if repeat_count > 0:

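The rebatching branch added to `get_dataloader_tfrecord` above reconciles the requested batch size with the prebatch size: when `batch_size >= prebatch_size`, several prebatched records are batched together; otherwise each record is split into smaller batches via `flat_map`. A minimal sketch of that arithmetic (the `plan_rebatch` helper is illustrative, not part of the commit, and assumes one size divides the other as the diff does):

```python
# How many prebatched records to combine per step, or how many pieces to
# split each record into, for a requested batch_size.
def plan_rebatch(batch_size, prebatch_size):
    if batch_size >= prebatch_size:
        # e.g. global_batch_size 131072 with prebatch 1024 -> 128 records/step
        return {"records_per_batch": batch_size // prebatch_size, "split_into": None}
    # requested batch is smaller than a record: split each record
    return {"records_per_batch": 1, "split_into": prebatch_size // batch_size}

print(plan_rebatch(131072, 1024))
# {'records_per_batch': 128, 'split_into': None}
print(plan_rebatch(512, 1024))
# {'records_per_batch': 1, 'split_into': 2}
```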
+ 2 - 0
TensorFlow2/Recommendation/SIM/sim/data/defaults.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+REMAINDER_FILENAME = 'remainder.tfrecord'
+
USER_FEATURES_CHANNEL = 'user_features'
TARGET_ITEM_FEATURES_CHANNEL = 'target_item_features'
POSITIVE_HISTORY_CHANNEL = 'positive_history'

+ 3 - 4
TensorFlow2/Recommendation/SIM/sim/data/feature_spec.py

@@ -72,8 +72,7 @@ class FeatureSpec:
        return [label_feature_name] + user_features_names + item_features_names

    @staticmethod
-    def get_default_feature_spec(user_features_cardinalities, item_features_cardinalities,
-                                 max_seq_len, train_output, test_output):
+    def get_default_feature_spec(user_features_cardinalities, item_features_cardinalities, max_seq_len):

        number_of_user_features = len(user_features_cardinalities)
        number_of_item_features = len(item_features_cardinalities)
@@ -127,9 +126,9 @@ class FeatureSpec:
                {
                    'type': 'tfrecord',
                    'features': all_features_names,
-                    'files': [filepath.name]
+                    'files': []
                }
-            ] for split, filepath in zip([TRAIN_MAPPING, TEST_MAPPING], [train_output, test_output])
+            ] for split in [TRAIN_MAPPING, TEST_MAPPING]
        }

        return FeatureSpec(feature_spec=feature_spec, channel_spec=channel_spec, source_spec=source_spec)