
[SIM/TF2] Release new version of SIM model with prebatching support

Jakub Tomsia, 3 years ago
parent commit
bf00fe1dbe

+ 0 - 1
TensorFlow2/Recommendation/SIM/.gitignore

@@ -15,4 +15,3 @@
 .ipynb_checkpoints/
 .idea/
 __pycache__
-results/

+ 239 - 329
TensorFlow2/Recommendation/SIM/README.md

@@ -28,6 +28,7 @@ This repository provides a script and recipe to train the SIM model to achieve s
     * [Command-line options](#command-line-options)
     * [Getting the data](#getting-the-data)
         * [Dataset guidelines](#dataset-guidelines)
+        * [Prebatching](#prebatching)
         * [BYO dataset](#byo-dataset)
             * [Channel definitions and requirements](#channel-definitions-and-requirements)
     * [Training process](#training-process)
@@ -78,7 +79,7 @@ In the author’s SIM implementation, the internals of submodels differs slightl
 List of implementation differences between original SIM code and DIN/DIEN/SIM papers
 </b></summary>
 
-- Batch normalization before NLP is not included in papers.
+- Batch normalization before MLP is not included in papers.
 - Batch normalization in code used `trainable=False` during the training phase.
 - ItemItemInteraction in DIN's attention module in the SIM implementation didn't correspond to the activation unit in the DIN paper.
   - Element-wise subtraction and multiplications are fed to MLP, skipping outer product operation.
@@ -375,7 +376,7 @@ The following section lists the requirements that you need to meet in order to s
 
 This repository contains a Dockerfile that extends the TensorFlow2 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 - [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
-- [TensorFlow2 21.10-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow/tags) NGC container
+- [TensorFlow2 22.01-py3](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow/tags) NGC container
 - Supported GPUs:
   - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
   - [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
@@ -417,9 +418,6 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3
 5. Start preprocessing.
 
     For details of the required file format and certain preprocessing parameters refer to [BYO dataset](#byo-dataset).
-    
-    
-    `${NUMBER_OF_USER_FEATURES}` defines how many user specific features are present in dataset. If using default Amazon Books dataset and `sim_preprocessing` script (as shown below), this parameter should be set to <b>1</b> (in this case, the only user specific features is <b>user_id</b>. Other features are item specific).
 
    ```bash
    python preprocessing/sim_preprocessing.py \
@@ -428,8 +426,7 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3
 
    python preprocessing/parquet_to_tfrecord.py \
     --amazon_dataset_path ${PARQUET_PATH} \
-    --tfrecord_output_dir ${TF_RECORD_PATH} \
-    --number_of_user_features ${NUMBER_OF_USER_FEATURES}
+    --tfrecord_output_dir ${TF_RECORD_PATH}
    ```
 
 6. Start training (`${GPU}` is an arbitrary number of GPUs to be used).
@@ -496,10 +493,11 @@ The `main.py` script parameters are detailed in the following table.
 | training        | drop_remainder            | Drop remainder batch for training set (flag)                            | False                     |
 | training        | disable_cache             | Disable dataset caching after the first time it is iterated over (flag)        | False                     |
 | training        | repeat_count              | Repeat training dataset this number of times                            | 0                         |
-| training | prefetch_train_size |  Number of batches to prefetch in training. | -1 |
-| training | prefetch_test_size |  Number of batches to prefetch in evaluation. | -1 |
-| training | train_dataset_size |  Number of samples in training dataset (used to determine prefetch_train_size when --prefetch_train_size < 0) | 11796480 |
+| training | prefetch_train_size |  Number of batches to prefetch in training. | 10 |
+| training | prefetch_test_size |  Number of batches to prefetch in evaluation. | 2 |
 | training | long_seq_length | Determines the long history - short history split of history features | 90 |
| training | prebatch_train_size | Prebatch size applied to the train dataset during preprocessing. | 0 |
| training | prebatch_test_size | Prebatch size applied to the test dataset during preprocessing. | 0 |
 | results         | results_dir               | Path to the model result files storage                                  | /tmp/sim                  |
 | results         | log_filename              | Name of the file to store logger output                                 | log.json                  |
 | results         | save_checkpoint_path      | Directory to save model checkpoints                                     | ""                        |
@@ -511,8 +509,10 @@ The `main.py` script parameters are detailed in the following table.
 | run mode        | affinity                  | Type of CPU affinity                                                    | socket_unique_interleaved |
 | run mode        | inter_op_parallelism      | Number of inter op threads                                              | 0                         |
 | run mode        | intra_op_parallelism      | Number of intra op threads                                              | 0                         |
| run mode        | num_parallel_calls        | Parallelism level for the tf.data API. If None, a heuristic based on the number of CPUs and GPUs is used   |   None  |
 | reproducibility | seed                      | Random seed                                                             | -1                        |
 
+
 ### Command-line options
 
 To view the full list of available options and their descriptions, use the `--help` command-line option, for example:
@@ -534,6 +534,56 @@ The preprocessing steps applied to the raw data include:
 - Determining embedding table sizes for categorical features needed to construct a model
- Filtering users for the training split based on their number of interactions (discarding users with fewer than 20 interactions)
 
+#### Prebatching
+
+The preprocessing scripts can apply batching before the model's dataloader. This reduces the size of the produced TFRecord files and speeds up dataloading.
+To do so, specify `--prebatch_train_size` and `--prebatch_test_size` when converting data with `preprocessing/parquet_to_tfrecord.py`. Later, when running the `main.py` script, pass the applied prebatch sizes via the same parameters.
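As a back-of-the-envelope illustration of the effect (a standalone sketch; `prebatch_stats` is a hypothetical helper, not part of the preprocessing scripts): packing `prebatch_size` samples into one serialized record divides the record count accordingly, with any leftover samples going to a separate remainder file.

```python
def prebatch_stats(num_samples, prebatch_size):
    """Return (full prebatched records, leftover samples).

    The leftover samples correspond to the partial batch that the
    preprocessing writes to a separate remainder file.
    """
    return num_samples // prebatch_size, num_samples % prebatch_size

# The Amazon Books training split (11,796,480 samples) with a 4096 prebatch
# collapses into a few thousand serialized records with no remainder.
full, leftover = prebatch_stats(11_796_480, 4096)
print(full, leftover)  # 2880 0
```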
+
+Example:
+
+Start preprocessing at step 5 of the [Quick Start Guide](#quick-start-guide):
+
+```bash
+python preprocessing/sim_preprocessing.py \
+--amazon_dataset_path ${RAW_DATASET_PATH} \
+--output_path ${PARQUET_PATH}
+
+python preprocessing/parquet_to_tfrecord.py \
+--amazon_dataset_path ${PARQUET_PATH} \
+--tfrecord_output_dir ${TF_RECORD_PATH} \
+--prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+--prebatch_test_size ${PREBATCH_TEST_SIZE}
+```
+
+Then train the model (step 6):
+
+```bash
+mpiexec --allow-run-as-root --bind-to socket -np ${GPU} python main.py \
+--dataset_dir ${TF_RECORD_PATH} \
+--mode train \
+--model_type sim \
+--embedding_dim 16 \
+--drop_remainder \
+--optimizer adam \
+--lr 0.01 \
+--epochs 3 \
+--global_batch_size 131072 \
+--amp \
+--prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+--prebatch_test_size ${PREBATCH_TEST_SIZE}
+```
+
+<details>
+<summary><b>Prebatching details</b></summary>
+
+- The last batch of each split is saved to a separate file, `remainder.tfrecord`, unless there are enough samples to form a full batch.
+- The final batch size used in the main script can be a multiple of the prebatch size.
+- The final batch size used in the main script can also be a divisor of the prebatch size. In that case, during multi-GPU training, each worker can receive more than one batch at a time, which results in an error during the allgather operation. The dataset size, batch size, and prebatch size have to be chosen with this limitation in mind.
+- For the original Amazon Books dataset, PREBATCH_TRAIN_SIZE = PREBATCH_TEST_SIZE = 4096 was used for performance benchmarking purposes.
+</details>
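The multiple/divisor constraint above can be expressed as a small check. This is an illustrative helper, not part of the repository's scripts, and it assumes the global batch is split evenly across GPU workers:

```python
def prebatch_is_compatible(global_batch_size, prebatch_size, num_gpus=1):
    """Rough check of the prebatch constraints described above."""
    # Assumption: the global batch is split evenly across GPU workers.
    per_gpu_batch = global_batch_size // num_gpus
    is_multiple = per_gpu_batch % prebatch_size == 0
    is_divisor = prebatch_size % per_gpu_batch == 0
    if not (is_multiple or is_divisor):
        return False
    # When the prebatch is larger than a worker's batch, one prebatched
    # record yields several batches per worker, which breaks the
    # allgather operation in multi-GPU training.
    if num_gpus > 1 and prebatch_size > per_gpu_batch:
        return False
    return True

# 131072 global batch on 8 GPUs -> 16384 per worker, a multiple of 4096
print(prebatch_is_compatible(131072, 4096, num_gpus=8))   # True
# A 16384 prebatch would span multiple 512-sample worker batches
print(prebatch_is_compatible(4096, 16384, num_gpus=8))    # False
```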
+
+&nbsp;
+
 #### BYO dataset 
 
 This implementation supports using other datasets thanks to BYO dataset functionality. 
@@ -676,7 +726,7 @@ source_spec:
     type: tfrecord
 ```
 
-`dimensions` should contain the length of the history to which the entries will be padded.
+`dimensions` should contain the length of the sequential features.
 
 Note that corresponding features in `negative_history`, `positive_history`, and `target_item_features` need to be listed in the same order in the channel spec for each channel, since they share embedding tables in the model (for example, `item_id` needs to be first and `cat_id` second). 
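This ordering requirement can be verified mechanically. A hypothetical sketch follows: the channel names come from the README, but the helper and the example specs are illustrative only.

```python
def shared_features_aligned(channel_spec):
    """Check that the channels sharing embedding tables list their
    features in the same order."""
    shared = ("negative_history", "positive_history", "target_item_features")
    orders = [channel_spec[name] for name in shared]
    return all(order == orders[0] for order in orders)

aligned_spec = {
    "negative_history": ["item_id", "cat_id"],
    "positive_history": ["item_id", "cat_id"],
    "target_item_features": ["item_id", "cat_id"],
}
swapped_spec = {**aligned_spec, "positive_history": ["cat_id", "item_id"]}

print(shared_features_aligned(aligned_spec))  # True
print(shared_features_aligned(swapped_spec))  # False
```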
 
@@ -705,7 +755,7 @@ For performance reasons, the only supported dataset type is tfrecord.
 
 ### Training process
 
-Training can be run using `main.py` script by specifying the `--mode train` parameter. The speed of training is measured by throughput, that is, the number of samples processed per second. Evaluation is based on the [Area under ROC Curve (ROC AUC)](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) metric. Model checkpoints may be stored using Checkpoint manager as specified via (...). Training and inference logs are saved to a directory specified via the `--results_dir` parameter. Mixed precision training is supported via the `--amp` flag. Multi-GPU training is performed using mpiexec and Horovod libraries.
+Training can be run using the `main.py` script by specifying the `--mode train` parameter. The speed of training is measured by throughput, that is, the number of samples processed per second. Evaluation is based on the [Area under ROC Curve (ROC AUC)](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) metric. Model checkpoints may be stored using the checkpoint manager via the `--save_checkpoint_path` and `--load_checkpoint_path` parameters. Training and inference logs are saved to the directory specified via the `--results_dir` parameter. Mixed precision training is supported via the `--amp` flag. Multi-GPU training is performed using the mpiexec and Horovod libraries.
 
 ### Inference process
 
@@ -778,7 +828,9 @@ mpiexec --allow-run-as-root --bind-to socket -np ${GPU} python main.py \
   --global_batch_size 131072 \
   --drop_remainder \
   --amp \
-  --benchmark
+  --benchmark \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```
 
 Equivalent:
@@ -787,7 +839,9 @@ scripts/run_model.sh \
   --data_path ${TF_RECORD_PATH} \
   --gpus ${GPU} \
   --amp 1 \
-  --benchmark 1 
+  --benchmark 1 \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```
 
 #### Inference performance benchmark
@@ -801,7 +855,9 @@ mpiexec --allow-run-as-root --bind-to socket -np ${GPU} python main.py \
   --model_type sim \
   --global_batch_size 131072 \
   --amp \
-  --benchmark
+  --benchmark \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```
 
 Equivalent:
@@ -811,7 +867,9 @@ scripts/run_model.sh \
   --gpus ${GPU} \
   --amp 1 \
   --benchmark 1 \
-  --mode inference
+  --mode inference \
+  --prebatch_train_size ${PREBATCH_TRAIN_SIZE} \
+  --prebatch_test_size ${PREBATCH_TEST_SIZE}
 ```
 
 ### Results
@@ -820,7 +877,7 @@ The following sections provide details on how we achieved our performance and ac
 
 #### Training accuracy results
 
-Our results were obtained by running the `run_model.sh` bash script in the TensorFlow2 21.10-py3 NGC container. Experiments were run on 1 and 8 GPUs, with FP32/TF32 Precision and AMP and with XLA-OFF/XLA-ON. Other parameters were set to defaults.
+Our results were obtained by running the `run_model.sh` bash script in the TensorFlow2 21.10-py3 NGC container. Experiments were run on 1 and 8 GPUs, with FP32/TF32 precision and AMP, and with XLA-OFF/XLA-ON. The dataset was prebatched with a size of 16384. Other parameters were set to defaults.
 
 There were 10 runs for each configuration. In the `Training accuracy` sections, average values are reported. In the `Training stability` sections, values from all runs are included in plots.
 
@@ -962,7 +1019,7 @@ Figure 8. ROC curve for different configurations of Ampere/Volta, 1/8 GPUs, doub
 
 #### Training performance results
 
-Our results were obtained by running the `scripts/run_model.sh` script in the TensorFlow2 21.10-py3 NGC container. 
+Our results were obtained by running the `scripts/run_model.sh` script in the TensorFlow2 21.10-py3 NGC container. The dataset was prebatched with a size of 16384.
 
 Numbers were averaged over 10 separate runs for each configuration.
 
@@ -974,12 +1031,12 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
 
 ##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
 
-|GPUs |XLA  |Throughput - TF32 (samples/s)  |Throughput - mixed precision (samples/s) |Throughput speedup (mixed precision / TF32)  | Strong scaling - TF32 | Strong scaling - mixed precision |
-|-----|-----|--------------------|------------------------------|---------------------------------------------|-----------|-------------|
-|1    |OFF  |381211.31           |484360.65                     |1.27    | 1.00 | 1.00 |
-|1    |ON   |462012.86           |571727.91                     |1.24    | 1.00 | 1.00 |
-|8    |OFF  |2304284.08          |2475445.94                    |1.07   | 6.04 | 5.11 |
-|8    |ON   |2679300.61          |3006370.96                    |1.12   | 5.80 | 5.26 |
+|   GPUs |   XLA |   Throughput - TF32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / TF32) |   Strong scaling - TF32 |   Strong scaling - mixed precision |
+|-------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|------------------------:|-----------------------------------:|
+|      1 |     OFF |                       377254.65 |                                  479921.54 |                                          1.27 |                    1.00 |                               1.00 |
+|      1 |     ON |                       455724.01 |                                  565221.04 |                                          1.24 |                    1.00 |                               1.00 |
+|      8 |     OFF |                      2161681.55 |                                 2603489.60 |                                          1.20 |                    5.73 |                               5.42 |
+|      8 |     ON |                      2662368.18 |                                 2979441.80 |                                          1.12 |                    5.84 |                               5.27 |
 
 <details>
 <summary><b>
@@ -990,24 +1047,24 @@ For each configuration of parameters present in the table, the `Speedup` column
 
 |GPUs |Precision      |Speedup |
 |-----|---------------|--------|
-|1    |TF32           |1.212   |
-|1    |AMP            |1.180   |
-|8    |TF32           |1.163   |
-|8    |AMP            |1.214   |
+|1    |TF32           |1.208   |
+|1    |AMP            |1.178   |
+|8    |TF32           |1.232   |
+|8    |AMP            |1.119   |
 </details>
 
 &nbsp;
 
 ##### Training performance: NVIDIA DGX-2 (16x V100 32GB)
 
-|GPUs |XLA  |Throughput - FP32 (samples/s)  |Throughput - mixed precision (samples/s) |Throughput speedup (mixed precision / FP32) | Strong scaling - FP32 | Strong scaling - mixed precision |
-|-----|-----|--------------------|------------------------------|---------------------------------------------|----------|-------------|
-|1    |OFF  |210772.27           |312580.01                     |1.48                                         | 1.00 | 1.00 |
-|1    |ON   |248514.27           |358305.52                     |1.44                                         | 1.00 | 1.00 |
-|8    |OFF  |1357463.39          |1785361.62                    |1.32                                         | 6.44 | 5.71 |
-|8    |ON   |1584757.09          |2091403.04                    |1.32                                         | 7.52 | 6.69 |
-|16   |OFF  |2319719.76          |2837309.15                    |1.22                                         | 11.00 | 9.08 |
-|16   |ON   |2681789.69          |3168488.89                    |1.18                                         | 12.73 | 10.14 |
+|   GPUs |   XLA |   Throughput - FP32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / FP32) |   Strong scaling - FP32 |   Strong scaling - mixed precision |
+|-------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|------------------------:|-----------------------------------:|
+|      1 |     OFF |                       209376.38 |                                  309752.48 |                                          1.48 |                    1.00 |                               1.00 |
+|      1 |     ON |                       245414.62 |                                  348945.59 |                                          1.42 |                    1.00 |                               1.00 |
+|      8 |     OFF |                      1310239.01 |                                 1689602.79 |                                          1.29 |                    6.26 |                               5.45 |
+|      8 |     ON |                      1483120.32 |                                 1962226.32 |                                          1.32 |                    6.04 |                               5.62 |
+|     16 |     OFF |                      2127221.65 |                                 2555926.79 |                                          1.20 |                   10.16 |                               8.25 |
+|     16 |     ON |                      2450499.40 |                                 2788997.07 |                                          1.14 |                    9.99 |                               7.99 |
 
 <details>
 <summary><b>
@@ -1018,12 +1075,12 @@ For each configuration of parameters present in the table, the `Speedup` column
 
 |GPUs |AMP                 |Speedup        |
 |-----|--------------------|---------------|
-|1    |FP32                |1.179          |
-|1    |AMP                 |1.146          |
-|8    |FP32                |1.167          |
-|8    |AMP                 |1.171          |
-|16   |FP32                |1.156          |
-|16   |AMP                 |1.117          |
+|1    |FP32                |1.172          |
+|1    |AMP                 |1.127          |
+|8    |FP32                |1.132          |
+|8    |AMP                 |1.161          |
+|16   |FP32                |1.152          |
+|16   |AMP                 |1.091          |
 </details>
 
 &nbsp;
@@ -1033,16 +1090,17 @@ For each configuration of parameters present in the table, the `Speedup` column
 NVIDIA DGX A100 / DGX-2 (Ampere / Volta) training speedup
 </b></summary>
 
-|GPUs |XLA    |Precision       |Speedup|
-|-----|-------|---------------|-------|
-|1    |OFF    |TF32/FP32      |1.809  |
-|1    |OFF    |AMP            |1.550  |
-|1    |ON     |TF32/FP32      |1.860  |
-|1    |ON     |AMP            |1.596  |
-|8    |OFF    |TF32/FP32      |1.697  |
-|8    |OFF    |AMP            |1.387  |
-|8    |ON     |TF32/FP32      |1.691  |
-|8    |ON     |AMP            |1.437  |
+
+|   GPUs |   XLA | Precision   |   Speedup |
+|-------:|------:|:------------|----------:|
+|      1 |     OFF | TF32/FP32   |     1.802 |
+|      1 |     OFF | AMP         |     1.549 |
+|      1 |     ON | TF32/FP32   |     1.857 |
+|      1 |     ON | AMP         |     1.620 |
+|      8 |     OFF | TF32/FP32   |     1.650 |
+|      8 |     OFF | AMP         |     1.541 |
+|      8 |     ON | TF32/FP32   |     1.795 |
+|      8 |     ON | AMP         |     1.518 |
 
 </details>
 
@@ -1060,74 +1118,44 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
 
 ##### Inference performance: NVIDIA DGX A100 (8x A100 80GB)
 
-|GPUs |Global batch size|XLA  |Throughput - TF32 (samples/s)|Throughput - mixed precision (samples/s)|Throughput speedup (mixed precision / TF32)  | Strong scaling - TF32 | Strong scaling - mixed precision |
-|-----|----------|-----|---------------|----------------------------|---------------------------------------------|----------------|---------|
-|1    |4096      |ON   |561967.1       |535674.63                   |0.95                                         | 1.00 | 1.00 |
-|1    |8192      |ON   |670885.47      |758801.43                   |1.13                                         | 1.00 | 1.00 |
-|1    |16384     |ON   |788890.79      |920695.88                   |1.17                                         | 1.00 | 1.00 |
-|1    |32768     |ON   |855056.39      |1035530.23                  |1.21                                         | 1.00 | 1.00 |
-|1    |65536     |ON   |918649.98      |1081408.05                  |1.18                                         | 1.00 | 1.00 |
-|1    |131072    |ON   |918555.37      |771119.78                   |0.84                                         | 1.00 | 1.00 |
-|8    |4096      |ON   |1130031.99     |935848.52                   |0.83                                         | 2.01 | 1.75 |
-|8    |8192      |ON   |2246441.94     |1885511.32                  |0.84                                         | 3.64 | 2.48 |
-|8    |16384     |ON   |4000071.31     |3303417.5                   |0.83                                         | 5.07 | 3.59 |
-|8    |32768     |ON   |5479754.01     |5762298.42                  |1.05                                         | 6.41 | 5.56 |
-|8    |65536     |ON   |6736333.91     |7869825.77                  |1.17                                         | 7.33 | 7.28 |
-|8    |131072    |ON   |7598665.72     |9002545.49                  |1.18                                         | 8.27 | 11.67 |
+|   Batch Size |   XLA |   Throughput - TF32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / TF32) |
+|--------------------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|
+|                4096 |     ON |                       618547.45 |                                  669640.65 |                                          1.08 |
+|                8192 |     ON |                       722801.14 |                                  849101.88 |                                          1.17 |
+|               16384 |     ON |                       859418.77 |                                 1051361.67 |                                          1.22 |
+|               32768 |     ON |                       976771.70 |                                 1269000.97 |                                          1.30 |
+|               65536 |     ON |                      1082688.51 |                                 1444729.52 |                                          1.33 |
+|              131072 |     ON |                      1094733.64 |                                 1483542.86 |                                          1.36 |
 
 <details>
 <summary><b> Complete table of DGX A100 inference performance results </b></summary>
 
-|GPUSs|Global Batch Size   |XLA    |Precision      |Throughput  (samples/s)           |
-|-----|--------------------|-------|---------------|-----------------------|
-|1    |4096                |OFF    |TF32           |585246.51 ± 10513.06   |
-|1    |8192                |OFF    |TF32           |750729.14 ± 17029.41   |
-|1    |16384               |OFF    |TF32           |803593.59 ± 11207.58   |
-|1    |32768               |OFF    |TF32           |822162.85 ± 5071.85    |
-|1    |65536               |OFF    |TF32           |775748.42 ± 36821.04   |
-|1    |131072              |OFF    |TF32           |644740.49 ± 31148.79   |
-|1    |4096                |OFF    |AMP            |516164.09 ± 9916.80    |
-|1    |8192                |OFF    |AMP            |778740.41 ± 19384.36   |
-|1    |16384               |OFF    |AMP            |932211.18 ± 20331.07   |
-|1    |32768               |OFF    |AMP            |990696.89 ± 11554.34   |
-|1    |65536               |OFF    |AMP            |715678.16 ± 30944.63   |
-|1    |131072              |OFF    |AMP            |611740.50 ± 21392.81   |
-|1    |4096                |ON     |TF32           |561967.10 ± 18100.55   |
-|1    |8192                |ON     |TF32           |670885.47 ± 11149.51   |
-|1    |16384               |ON     |TF32           |788890.79 ± 10058.99   |
-|1    |32768               |ON     |TF32           |855056.39 ± 14349.13   |
-|1    |65536               |ON     |TF32           |918649.98 ± 7571.32    |
-|1    |131072              |ON     |TF32           |918555.37 ± 15036.89   |
-|1    |4096                |ON     |AMP            |535674.63 ± 14003.35   |
-|1    |8192                |ON     |AMP            |758801.43 ± 15225.76   |
-|1    |16384               |ON     |AMP            |920695.88 ± 15325.29   |
-|1    |32768               |ON     |AMP            |1035530.23 ± 16055.40  |
-|1    |65536               |ON     |AMP            |1081408.05 ± 41906.29  |
-|1    |131072              |ON     |AMP            |771119.78 ± 79589.50   |
-|8    |4096                |OFF    |TF32           |765154.17 ± 30582.87   |
-|8    |8192                |OFF    |TF32           |1396414.24 ± 99987.01  |
-|8    |16384               |OFF    |TF32           |2281597.86 ± 77483.79  |
-|8    |32768               |OFF    |TF32           |3555014.42 ± 145944.33 |
-|8    |65536               |OFF    |TF32           |4792413.60 ± 203285.21 |
-|8    |131072              |OFF    |TF32           |5941195.01 ± 182519.72 |
-|8    |4096                |OFF    |AMP            |642706.11 ± 28063.45   |
-|8    |8192                |OFF    |AMP            |1197789.38 ± 47262.95  |
-|8    |16384               |OFF    |AMP            |1961353.19 ± 49818.70  |
-|8    |32768               |OFF    |AMP            |3267263.60 ± 130680.70 |
-|8    |65536               |OFF    |AMP            |4847783.16 ± 257991.99 |
-|8    |131072              |OFF    |AMP            |6413842.15 ± 289543.64 |
-|8    |4096                |ON     |TF32           |1130031.99 ± 75271.24  |
-|8    |8192                |ON     |TF32           |2246441.94 ± 26132.90  |
-|8    |16384               |ON     |TF32           |4000071.31 ± 48054.68  |
-|8    |32768               |ON     |TF32           |5479754.01 ± 170421.20 |
-|8    |65536               |ON     |TF32           |6736333.91 ± 153745.68 |
-|8    |131072              |ON     |TF32           |7598665.72 ± 174188.78 |
-|8    |4096                |ON     |AMP            |935848.52 ± 14583.48   |
-|8    |8192                |ON     |AMP            |1885511.32 ± 22206.00  |
-|8    |16384               |ON     |AMP            |3303417.50 ± 210306.61 |
-|8    |32768               |ON     |AMP            |5762298.42 ± 140412.56 |
-|8    |65536               |ON     |AMP            |7869825.77 ± 305838.69 |
-|8    |131072              |ON     |AMP            |9002545.49 ± 438204.32 |
+|   Batch Size | XLA   | Precision   | Throughput  (samples/s)   |
+|-------------:|:------|:------------|:--------------------------|
+|         4096 | OFF   | TF32        | 708349.73 ± 14161.58      |
+|         8192 | OFF   | TF32        | 873335.82 ± 8539.56       |
+|        16384 | OFF   | TF32        | 937987.79 ± 12114.34      |
+|        32768 | OFF   | TF32        | 943313.07 ± 8631.81       |
+|        65536 | OFF   | TF32        | 960794.46 ± 7388.45       |
+|       131072 | OFF   | TF32        | 966245.27 ± 8637.82       |
+|         4096 | OFF   | AMP         | 645394.94 ± 14844.27      |
+|         8192 | OFF   | AMP         | 919410.07 ± 11355.28      |
+|        16384 | OFF   | AMP         | 1136346.66 ± 14529.91     |
+|        32768 | OFF   | AMP         | 1216810.45 ± 21013.12     |
+|        65536 | OFF   | AMP         | 1287305.05 ± 19373.18     |
+|       131072 | OFF   | AMP         | 1298478.97 ± 10733.67     |
+|         4096 | ON    | TF32        | 618547.45 ± 6569.97       |
+|         8192 | ON    | TF32        | 722801.14 ± 9448.19       |
+|        16384 | ON    | TF32        | 859418.77 ± 10012.61      |
+|        32768 | ON    | TF32        | 976771.70 ± 13377.36      |
+|        65536 | ON    | TF32        | 1082688.51 ± 8523.55      |
+|       131072 | ON    | TF32        | 1094733.64 ± 11157.18     |
+|         4096 | ON    | AMP         | 669640.65 ± 9319.68       |
+|         8192 | ON    | AMP         | 849101.88 ± 14068.04      |
+|        16384 | ON    | AMP         | 1051361.67 ± 15310.42     |
+|        32768 | ON    | AMP         | 1269000.97 ± 23971.56     |
+|        65536 | ON    | AMP         | 1444729.52 ± 18011.54     |
+|       131072 | ON    | AMP         | 1483542.86 ± 6751.29      |
 
 </details>
 
@@ -1138,32 +1166,20 @@ DGX A100 XLA-ON / XLA-OFF inference Speedup
 
 For each configuration of parameters present in the table, the `Speedup` column shows the speedup achieved by turning on XLA.
 
-|GPUs |Global Batch Size   |Precision      |Speedup |
-|-----|--------------------|---------------|--------|
-|1    |4096                |TF32           |0.960   |
-|1    |8192                |TF32           |0.894   |
-|1    |16384               |TF32           |0.982   |
-|1    |32768               |TF32           |1.040   |
-|1    |65536               |TF32           |1.184   |
-|1    |131072              |TF32           |1.425   |
-|1    |4096                |AMP            |1.038   |
-|1    |8192                |AMP            |0.974   |
-|1    |16384               |AMP            |0.988   |
-|1    |32768               |AMP            |1.045   |
-|1    |65536               |AMP            |1.511   |
-|1    |131072              |AMP            |1.261   |
-|8    |4096                |TF32           |1.477   |
-|8    |8192                |TF32           |1.609   |
-|8    |16384               |TF32           |1.753   |
-|8    |32768               |TF32           |1.541   |
-|8    |65536               |TF32           |1.406   |
-|8    |131072              |TF32           |1.279   |
-|8    |4096                |AMP            |1.456   |
-|8    |8192                |AMP            |1.574   |
-|8    |16384               |AMP            |1.684   |
-|8    |32768               |AMP            |1.764   |
-|8    |65536               |AMP            |1.623   |
-|8    |131072              |AMP            |1.404   |
+|Batch Size   |Precision      |Speedup |
+|--------------------|---------------|--------|
+|4096                |TF32           |0.873   |
+|8192                |TF32           |0.828   |
+|16384               |TF32           |0.916   |
+|32768               |TF32           |1.035   |
+|65536               |TF32           |1.127   |
+|131072              |TF32           |1.133   |
+|4096                |AMP            |1.038   |
+|8192                |AMP            |0.924   |
+|16384               |AMP            |0.925   |
+|32768               |AMP            |1.043   |
+|65536               |AMP            |1.187   |
+|131072              |AMP            |1.143   |
 
 </details>
 
@@ -1171,153 +1187,69 @@ For each configuration of parameters present in the table, the `Speedup` column
 
 ##### Inference performance: NVIDIA DGX-2 (16x V100 32GB)
 
-|GPUs |Global batch size|XLA  |Throughput - FP32 (samples/s)|Throughput - mixed precision (samples/s)|Throughput speedup (mixed precision / FP32)  | Strong scaling - FP32 | Strong scaling - mixed precision |
-|-----|----------|-----|---------------|----------------------------|---------------------------------------------|--------|--------|
-|1    |4096      |ON   |403479.95      |479051.62                   |1.19                                         | 1.00 | 1.00 |
-|1    |8192      |ON   |480491.12      |600002.95                   |1.25                                         | 1.00 | 1.00 |
-|1    |16384     |ON   |538737.44      |713203.59                   |1.32                                         | 1.00 | 1.00 |
-|1    |32768     |ON   |580958.93      |790782.1                    |1.36                                         | 1.00 | 1.00 |
-|1    |65536     |ON   |586275.07      |818038.44                   |1.40                                         | 1.00 | 1.00 |
-|1    |131072    |ON   |613524.11      |734034.26                   |1.20                                         | 1.00 | 1.00 |
-|8    |4096      |ON   |1059775.22     |909719.3                    |0.86                                         | 2.63 | 1.90 |
-|8    |8192      |ON   |1845819.99     |1752510.62                  |0.95                                         | 3.84 | 2.92 |
-|8    |16384     |ON   |2801114.77     |2898423.08                  |1.03                                         | 5.20 | 4.06 |
-|8    |32768     |ON   |3396766.27     |4102026.01                  |1.21                                         | 5.85 | 5.19 |
-|8    |65536     |ON   |3911994.39     |4725023.23                  |1.21                                         | 6.67 | 5.78 |
-|8    |131072    |ON   |4197603.74     |5413542.58                  |1.29                                         | 6.84 | 7.38 |
-|16   |4096      |ON   |1142272.86     |924525.38                   |0.81                                         | 2.83 | 1.93 |
-|16   |8192      |ON   |2068920.7      |1917814.81                  |0.93                                         | 4.31 | 3.20 |
-|16   |16384     |ON   |3091676.83     |3496153.45                  |1.13                                         | 5.74 | 4.90 |
-|16   |32768     |ON   |5132772.75     |5063615.77                  |0.99                                         | 8.84 | 6.40 |
-|16   |65536     |ON   |6553882.87     |8247475.75                  |1.26                                         | 11.18 | 10.08 |
-|16   |131072    |ON   |7555906.17     |9571965.84                  |1.27                                         | 12.32 | 13.04 |
+|   Batch Size |   XLA |   Throughput - FP32 (samples/s) |   Throughput - mixed precision (samples/s) |   Throughput speedup (mixed precision / FP32) |
+|--------------------:|------:|--------------------------------:|-------------------------------------------:|----------------------------------------------:|
+|                4096 |     ON |                       444532.22 |                                  541975.24 |                                          1.22 |
+|                8192 |     ON |                       505047.64 |                                  642784.48 |                                          1.27 |
+|               16384 |     ON |                       549325.54 |                                  727077.63 |                                          1.32 |
+|               32768 |     ON |                       587452.73 |                                  788606.35 |                                          1.34 |
+|               65536 |     ON |                       605187.67 |                                  832651.59 |                                          1.38 |
+|              131072 |     ON |                       599557.03 |                                  840602.90 |                                          1.40 |
 
 <details>
 <summary><b>
-Complete table of DGX2 inference performance results
+Complete table of DGX-2 inference performance results
 </b></summary>
 
-|GPUs |Global Batch Size   |XLA    |Precision      |Throughput (samples/s)           |
-|-----|--------------------|-------|---------------|-----------------------|
-|1    |4096                |OFF    |FP32           |459149.07 ± 20971.34   |
-|1    |8192                |OFF    |FP32           |488763.98 ± 15037.09   |
-|1    |16384               |OFF    |FP32           |516804.05 ± 8355.49    |
-|1    |32768               |OFF    |FP32           |534387.97 ± 4763.49    |
-|1    |65536               |OFF    |FP32           |536215.89 ± 5794.77    |
-|1    |131072              |OFF    |FP32           |538646.76 ± 6359.47    |
-|1    |4096                |OFF    |AMP            |488475.14 ± 6226.30    |
-|1    |8192                |OFF    |AMP            |632098.48 ± 27370.49   |
-|1    |16384               |OFF    |AMP            |705878.12 ± 7852.19    |
-|1    |32768               |OFF    |AMP            |739740.73 ± 6866.73    |
-|1    |65536               |OFF    |AMP            |618291.18 ± 26749.52   |
-|1    |131072              |OFF    |AMP            |544071.41 ± 19200.23   |
-|1    |4096                |ON     |FP32           |403479.95 ± 4079.19    |
-|1    |8192                |ON     |FP32           |480491.12 ± 6828.93    |
-|1    |16384               |ON     |FP32           |538737.44 ± 10932.49   |
-|1    |32768               |ON     |FP32           |580958.93 ± 9544.37    |
-|1    |65536               |ON     |FP32           |586275.07 ± 7640.59    |
-|1    |131072              |ON     |FP32           |613524.11 ± 7931.04    |
-|1    |4096                |ON     |AMP            |479051.62 ± 6076.26    |
-|1    |8192                |ON     |AMP            |600002.95 ± 16380.88   |
-|1    |16384               |ON     |AMP            |713203.59 ± 9515.25    |
-|1    |32768               |ON     |AMP            |790782.10 ± 10788.69   |
-|1    |65536               |ON     |AMP            |818038.44 ± 14132.80   |
-|1    |131072              |ON     |AMP            |734034.26 ± 34664.74   |
-|8    |4096                |OFF    |FP32           |502947.25 ± 105758.96  |
-|8    |8192                |OFF    |FP32           |809285.58 ± 112765.45  |
-|8    |16384               |OFF    |FP32           |1974085.95 ± 476616.90 |
-|8    |32768               |OFF    |FP32           |2990517.14 ± 645490.89 |
-|8    |65536               |OFF    |FP32           |3662830.22 ± 191010.11 |
-|8    |131072              |OFF    |FP32           |3978985.17 ± 142801.19 |
-|8    |4096                |OFF    |AMP            |596945.98 ± 92977.56   |
-|8    |8192                |OFF    |AMP            |730694.36 ± 67972.28   |
-|8    |16384               |OFF    |AMP            |1758189.25 ± 340547.41 |
-|8    |32768               |OFF    |AMP            |3873856.45 ± 528746.35 |
-|8    |65536               |OFF    |AMP            |4863371.50 ± 297299.34 |
-|8    |131072              |OFF    |AMP            |5134261.52 ± 473726.31 |
-|8    |4096                |ON     |FP32           |1059775.22 ± 24386.54  |
-|8    |8192                |ON     |FP32           |1845819.99 ± 250767.40 |
-|8    |16384               |ON     |FP32           |2801114.77 ± 210397.18 |
-|8    |32768               |ON     |FP32           |3396766.27 ± 221795.61 |
-|8    |65536               |ON     |FP32           |3911994.39 ± 239259.17 |
-|8    |131072              |ON     |FP32           |4197603.74 ± 158110.80 |
-|8    |4096                |ON     |AMP            |909719.30 ± 135634.13  |
-|8    |8192                |ON     |AMP            |1752510.62 ± 87042.91  |
-|8    |16384               |ON     |AMP            |2898423.08 ± 231659.28 |
-|8    |32768               |ON     |AMP            |4102026.01 ± 254242.94 |
-|8    |65536               |ON     |AMP            |4725023.23 ± 322597.53 |
-|8    |131072              |ON     |AMP            |5413542.58 ± 364633.26 |
-|16   |4096                |OFF    |FP32           |865109.29 ± 40032.58   |
-|16   |8192                |OFF    |FP32           |1565843.18 ± 305582.99 |
-|16   |16384               |OFF    |FP32           |3109303.21 ± 240314.57 |
-|16   |32768               |OFF    |FP32           |5750753.42 ± 898435.09 |
-|16   |65536               |OFF    |FP32           |6456324.48 ± 730326.61 |
-|16   |131072              |OFF    |FP32           |7415730.04 ± 434928.14 |
-|16   |4096                |OFF    |AMP            |742890.53 ± 27541.80   |
-|16   |8192                |OFF    |AMP            |1468615.49 ± 67548.46  |
-|16   |16384               |OFF    |AMP            |2591245.05 ± 394504.75 |
-|16   |32768               |OFF    |AMP            |4671719.91 ± 721705.81 |
-|16   |65536               |OFF    |AMP            |7982733.55 ± 1242742.25|
-|16   |131072              |OFF    |AMP            |9867894.78 ± 679119.71 |
-|16   |4096                |ON     |FP32           |1142272.86 ± 43154.49  |
-|16   |8192                |ON     |FP32           |2068920.70 ± 130214.35 |
-|16   |16384               |ON     |FP32           |3091676.83 ± 991449.61 |
-|16   |32768               |ON     |FP32           |5132772.75 ± 525201.10 |
-|16   |65536               |ON     |FP32           |6553882.87 ± 400638.86 |
-|16   |131072              |ON     |FP32           |7555906.17 ± 626110.02 |
-|16   |4096                |ON     |AMP            |924525.38 ± 163488.57  |
-|16   |8192                |ON     |AMP            |1917814.81 ± 59114.71  |
-|16   |16384               |ON     |AMP            |3496153.45 ± 190771.71 |
-|16   |32768               |ON     |AMP            |5063615.77 ± 1281699.58|
-|16   |65536               |ON     |AMP            |8247475.75 ± 539827.60 |
-|16   |131072              |ON     |AMP            |9571965.84 ± 764075.50 |
+|   Batch Size | XLA   | Precision   | Throughput  (samples/s)   |
+|-------------:|:------|:------------|:--------------------------|
+|         4096 | OFF   | FP32        | 459175.30 ± 23184.33      |
+|         8192 | OFF   | FP32        | 499179.20 ± 15967.26      |
+|        16384 | OFF   | FP32        | 525180.72 ± 2521.56       |
+|        32768 | OFF   | FP32        | 532042.10 ± 4020.44       |
+|        65536 | OFF   | FP32        | 534307.20 ± 7276.26       |
+|       131072 | OFF   | FP32        | 532311.44 ± 6195.16       |
+|         4096 | OFF   | AMP         | 581771.66 ± 6163.50       |
+|         8192 | OFF   | AMP         | 665048.04 ± 4607.95       |
+|        16384 | OFF   | AMP         | 716355.19 ± 7174.98       |
+|        32768 | OFF   | AMP         | 741642.61 ± 4981.04       |
+|        65536 | OFF   | AMP         | 755141.25 ± 6175.05       |
+|       131072 | OFF   | AMP         | 744459.46 ± 8183.17       |
+|         4096 | ON    | FP32        | 444532.22 ± 6239.01       |
+|         8192 | ON    | FP32        | 505047.64 ± 6543.06       |
+|        16384 | ON    | FP32        | 549325.54 ± 2841.21       |
+|        32768 | ON    | FP32        | 587452.73 ± 2366.43       |
+|        65536 | ON    | FP32        | 605187.67 ± 3740.07       |
+|       131072 | ON    | FP32        | 599557.03 ± 11811.28      |
+|         4096 | ON    | AMP         | 541975.24 ± 4441.93       |
+|         8192 | ON    | AMP         | 642784.48 ± 4721.08       |
+|        16384 | ON    | AMP         | 727077.63 ± 5332.80       |
+|        32768 | ON    | AMP         | 788606.35 ± 11705.36      |
+|        65536 | ON    | AMP         | 832651.59 ± 10401.17      |
+|       131072 | ON    | AMP         | 840602.90 ± 16358.73      |
 </details>
 
 <details>
 <summary><b>
-DGX A100 XLA-ON / XLA-OFF inference speedup
+DGX-2 XLA-ON / XLA-OFF inference speedup
 </b></summary>
 
 For each configuration of parameters present in the table, the `Speedup` column shows the speedup achieved by turning on XLA.
 
-|GPUs |Global Batch Size   |Precision      |Speedup |
-|-----|--------------------|---------------|--------|
-|1    |4096                |FP32           |0.879   |
-|1    |8192                |FP32           |0.983   |
-|1    |16384               |FP32           |1.042   |
-|1    |32768               |FP32           |1.087   |
-|1    |65536               |FP32           |1.093   |
-|1    |131072              |FP32           |1.139   |
-|1    |4096                |AMP            |0.981   |
-|1    |8192                |AMP            |0.949   |
-|1    |16384               |AMP            |1.010   |
-|1    |32768               |AMP            |1.069   |
-|1    |65536               |AMP            |1.323   |
-|1    |131072              |AMP            |1.349   |
-|8    |4096                |FP32           |2.107   |
-|8    |8192                |FP32           |2.281   |
-|8    |16384               |FP32           |1.419   |
-|8    |32768               |FP32           |1.136   |
-|8    |65536               |FP32           |1.068   |
-|8    |131072              |FP32           |1.055   |
-|8    |4096                |AMP            |1.524   |
-|8    |8192                |AMP            |2.398   |
-|8    |16384               |AMP            |1.649   |
-|8    |32768               |AMP            |1.059   |
-|8    |65536               |AMP            |0.972   |
-|8    |131072              |AMP            |1.054   |
-|16   |4096                |FP32           |1.320   |
-|16   |8192                |FP32           |1.321   |
-|16   |16384               |FP32           |0.994   |
-|16   |32768               |FP32           |0.893   |
-|16   |65536               |FP32           |1.015   |
-|16   |131072              |FP32           |1.019   |
-|16   |4096                |AMP            |1.244   |
-|16   |8192                |AMP            |1.306   |
-|16   |16384               |AMP            |1.349   |
-|16   |32768               |AMP            |1.084   |
-|16   |65536               |AMP            |1.033   |
-|16   |131072              |AMP            |0.970   |
+|Batch Size   |Precision      |Speedup |
+|--------------------|---------------|--------|
+|4096                |FP32           |0.968   |
+|8192                |FP32           |1.012   |
+|16384               |FP32           |1.046   |
+|32768               |FP32           |1.104   |
+|65536               |FP32           |1.133   |
+|131072              |FP32           |1.126   |
+|4096                |AMP            |0.932   |
+|8192                |AMP            |0.967   |
+|16384               |AMP            |1.015   |
+|32768               |AMP            |1.063   |
+|65536               |AMP            |1.103   |
+|131072              |AMP            |1.129   |
 </details>
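For clarity, each `Speedup` value is simply the ratio of XLA-ON to XLA-OFF throughput. A quick check against the FP32, batch-size-32768 rows of the complete table above:

```python
# Speedup = throughput with XLA ON / throughput with XLA OFF
# (values taken from the FP32, batch size 32768 rows of the complete table)
xla_off = 532042.10  # samples/s, XLA OFF
xla_on = 587452.73   # samples/s, XLA ON
print(round(xla_on / xla_off, 3))  # 1.104
```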
 
 &nbsp;
@@ -1327,56 +1259,32 @@ For each configuration of parameters present in the table, the `Speedup` column
 NVIDIA A100 / DGX-2 (Ampere / Volta) inference speedup
 </b></summary>
 
-|GPUs |Global Batch Size   |XLA    |Precision      |Speedup |
-|-----|--------------------|-------|---------------|--------|
-|1    |4096                |OFF    |TF32/FP32      |1.275   |
-|1    |8192                |OFF    |TF32/FP32      |1.536   |
-|1    |16384               |OFF    |TF32/FP32      |1.555   |
-|1    |32768               |OFF    |TF32/FP32      |1.539   |
-|1    |65536               |OFF    |TF32/FP32      |1.447   |
-|1    |131072              |OFF    |TF32/FP32      |1.197   |
-|1    |4096                |OFF    |AMP            |1.057   |
-|1    |8192                |OFF    |AMP            |1.232   |
-|1    |16384               |OFF    |AMP            |1.321   |
-|1    |32768               |OFF    |AMP            |1.339   |
-|1    |65536               |OFF    |AMP            |1.158   |
-|1    |131072              |OFF    |AMP            |1.124   |
-|1    |4096                |ON     |TF32/FP32      |1.393   |
-|1    |8192                |ON     |TF32/FP32      |1.396   |
-|1    |16384               |ON     |TF32/FP32      |1.464   |
-|1    |32768               |ON     |TF32/FP32      |1.472   |
-|1    |65536               |ON     |TF32/FP32      |1.567   |
-|1    |131072              |ON     |TF32/FP32      |1.497   |
-|1    |4096                |ON     |AMP            |1.118   |
-|1    |8192                |ON     |AMP            |1.265   |
-|1    |16384               |ON     |AMP            |1.291   |
-|1    |32768               |ON     |AMP            |1.310   |
-|1    |65536               |ON     |AMP            |1.322   |
-|1    |131072              |ON     |AMP            |1.051   |
-|8    |4096                |OFF    |TF32/FP32      |1.521   |
-|8    |8192                |OFF    |TF32/FP32      |1.725   |
-|8    |16384               |OFF    |TF32/FP32      |1.156   |
-|8    |32768               |OFF    |TF32/FP32      |1.189   |
-|8    |65536               |OFF    |TF32/FP32      |1.308   |
-|8    |131072              |OFF    |TF32/FP32      |1.493   |
-|8    |4096                |OFF    |AMP            |1.077   |
-|8    |8192                |OFF    |AMP            |1.639   |
-|8    |16384               |OFF    |AMP            |1.116   |
-|8    |32768               |OFF    |AMP            |0.843   |
-|8    |65536               |OFF    |AMP            |0.997   |
-|8    |131072              |OFF    |AMP            |1.249   |
-|8    |4096                |ON     |TF32/FP32      |1.066   |
-|8    |8192                |ON     |TF32/FP32      |1.217   |
-|8    |16384               |ON     |TF32/FP32      |1.428   |
-|8    |32768               |ON     |TF32/FP32      |1.613   |
-|8    |65536               |ON     |TF32/FP32      |1.722   |
-|8    |131072              |ON     |TF32/FP32      |1.810   |
-|8    |4096                |ON     |AMP            |1.029   |
-|8    |8192                |ON     |AMP            |1.076   |
-|8    |16384               |ON     |AMP            |1.140   |
-|8    |32768               |ON     |AMP            |1.405   |
-|8    |65536               |ON     |AMP            |1.666   |
-|8    |131072              |ON     |AMP            |1.663   |
+|   Batch Size | XLA   | Precision   |   Speedup |
+|-------------:|:------|:------------|----------:|
+|         4096 | OFF   | TF32/FP32   |      1.54 |
+|         8192 | OFF   | TF32/FP32   |      1.75 |
+|        16384 | OFF   | TF32/FP32   |      1.79 |
+|        32768 | OFF   | TF32/FP32   |      1.77 |
+|        65536 | OFF   | TF32/FP32   |      1.80 |
+|       131072 | OFF   | TF32/FP32   |      1.81 |
+|         4096 | OFF   | AMP         |      1.11 |
+|         8192 | OFF   | AMP         |      1.38 |
+|        16384 | OFF   | AMP         |      1.59 |
+|        32768 | OFF   | AMP         |      1.64 |
+|        65536 | OFF   | AMP         |      1.71 |
+|       131072 | OFF   | AMP         |      1.74 |
+|         4096 | ON    | TF32/FP32   |      1.39 |
+|         8192 | ON    | TF32/FP32   |      1.43 |
+|        16384 | ON    | TF32/FP32   |      1.56 |
+|        32768 | ON    | TF32/FP32   |      1.66 |
+|        65536 | ON    | TF32/FP32   |      1.79 |
+|       131072 | ON    | TF32/FP32   |      1.83 |
+|         4096 | ON    | AMP         |      1.24 |
+|         8192 | ON    | AMP         |      1.32 |
+|        16384 | ON    | AMP         |      1.45 |
+|        32768 | ON    | AMP         |      1.61 |
+|        65536 | ON    | AMP         |      1.74 |
+|       131072 | ON    | AMP         |      1.76 |
 </details>
 
 &nbsp;
@@ -1388,10 +1296,12 @@ NVIDIA A100 / DGX-2 (Ampere / Volta) inference speedup
 May 2022
 - Initial release
 
-### Known issues
+November 2022
+- Moved batching and padding operations to preprocessing
+- Added support for prebatched samples during dataloading
+- Reduced throughput variance (previously appearing mainly during inference)
 
-- While benchmarking inference on a single GPU, sometimes throughput drops drastically in the middle of the epoch and remains low until the end of the epoch.
-- On a multi-GPU setup, the summary of throughput (in the last line of the logfile) is lower than it would result from each step`s throughput (sample/s). It is probably the case when a single GPU is slower than the one on the logging node. In this case, the overhead for synchronization before the final throughput calculation is higher than usual.
+### Known issues
- The SIM model results are non-deterministic, even when using the same random seed. The source of this non-determinism is the [tf.math.unsorted_segment_sum](https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum) operation called within an optimization step. Its influence depends on the distribution of categorical data within a batch, and the issue is more severe for momentum-based optimizers. A potential solution is to use a deterministic version of this op, which allows perfect reproduction but makes training up to six times slower.
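A minimal NumPy sketch of what `unsorted_segment_sum` computes (illustration only; on GPU, TensorFlow implements this accumulation with atomic adds, so the floating-point summation order, and therefore the result, can vary between runs):

```python
import numpy as np

def unsorted_segment_sum(data, segment_ids, num_segments):
    # Deterministic host-side emulation: accumulate each element into the
    # output row selected by its segment id, in a fixed iteration order.
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    for value, seg in zip(data, segment_ids):
        out[seg] += value
    return out

data = np.array([1.0, 2.0, 3.0, 4.0])
segment_ids = np.array([0, 1, 0, 1])
print(unsorted_segment_sum(data, segment_ids, 2))  # [4. 6.]
```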
 
 

+ 43 - 20
TensorFlow2/Recommendation/SIM/main.py

@@ -93,8 +93,8 @@ def init_logger(results_dir, filename):
 
 
 # In the future, select one of available dataloaders there (tfrecord, csv, etc...)
-def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length, prefetch_size, repeat_count=0,
-                      drop_remainder=False, amp=False, disable_cache=False):
+def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length, prefetch_size, num_parallel_calls=None, repeat_count=0,
+                      drop_remainder=False, amp=False, disable_cache=False, prebatch_size=0):
     return get_dataloader_tfrecord(
         paths,
         feature_spec=feature_spec,
@@ -105,7 +105,9 @@ def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length
         drop_remainder=drop_remainder,
         repeat_count=repeat_count,
         disable_cache=disable_cache,
-        prefetch_buffer_size=prefetch_size
+        prefetch_buffer_size=prefetch_size,
+        num_parallel_calls=num_parallel_calls,
+        prebatch_size=prebatch_size
     )
 
 
@@ -243,10 +245,24 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
         local_targets.append(targets)
         local_total_losses.append(loss_dict["total_loss"])
 
-    # concat all local variables into a single tensor
-    logits = tf.concat(local_logits, 0)
-    targets = tf.concat(local_targets, 0)
-    total_losses = tf.concat(local_total_losses, 0)
+    locals = [local_logits, local_targets, local_total_losses]
+    for i, local in enumerate(locals):
+
+        # wrap empty lists in tensor to allow tf.concat
+        if len(local) == 0:
+            local = tf.constant(local)
+
+        # concat all local variables into a single tensor
+        local = tf.concat(local, 0)
+
+        # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
+        # reshape it for hvd.allgather to work
+        if len(local.shape) == 0:
+            local = tf.reshape(local, -1)
+
+        locals[i] = local
+    
+    logits, targets, total_losses = locals
 
     if distributed:
         # gather from all nodes
@@ -455,6 +471,9 @@ def inference(model, data_iterator, benchmark, performance_calculator):
 @click.option(
     "--global_batch_size", default=131072, help="Batch size used to train/eval the model.", type=int
 )
+@click.option(
+    "--num_parallel_calls", default=None, help="Parallelism level for tf.data API. If None, heuristic based on number of CPUs and number of GPUs will be used."
+)
 @click.option(
     "--epochs", default=3, help="Train for the following number of epochs.", type=int
 )
@@ -521,10 +540,8 @@ def inference(model, data_iterator, benchmark, performance_calculator):
 )
 @click.option(
     "--prefetch_train_size",
-    default=-1,
+    default=10,
     help="Number of batches to prefetch in training. "
-    "If == 0: No prefetching is done. "
-    "If < 0: Prefetch size is set to train_dataset_size // global_batch_size. ",
 )
 @click.option(
     "--prefetch_test_size",
@@ -532,9 +549,14 @@ def inference(model, data_iterator, benchmark, performance_calculator):
     help="Number of batches to prefetch in testing"
 )
 @click.option(
-    "--train_dataset_size",
-    default=11796480,
-    help="Number of train samples. Used to set prefetching size (see --prefetch_train_size for more information."
+    "--prebatch_train_size",
+    default=0,
+    help="Information about batch size applied during preprocessing to train dataset"
+)
+@click.option(
+    "--prebatch_test_size",
+    default=0,
+    help="Information about batch size applied during preprocessing to test dataset"
 )
 def main(
         mode: str,
@@ -554,6 +576,7 @@ def main(
         weight_decay: float,
         embedding_dim: int,
         global_batch_size: int,
+        num_parallel_calls: int,
         epochs: int,
         disable_cache: bool,
         drop_remainder: bool,
@@ -570,7 +593,8 @@ def main(
         intra_op_parallelism: int,
         prefetch_train_size: int,
         prefetch_test_size: int,
-        train_dataset_size: int
+        prebatch_train_size: int,
+        prebatch_test_size: int
 ):
     hvd.init()
 
@@ -636,20 +660,19 @@ def main(
     # since each tfrecord file must include all of the features, it is enough to read first chunk for each split. 
     train_files = [dataset_dir / file for file in feature_spec.source_spec[TRAIN_MAPPING][0][FILES_SELECTOR]]
 
-    if prefetch_train_size < 0:
-        prefetch_train_size = train_dataset_size // global_batch_size
-
     data_iterator_train = get_data_iterator(
         train_files, feature_spec, batch_size, num_gpus, long_seq_length,
         repeat_count=repeat_count, drop_remainder=drop_remainder,
-        amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_train_size
+        amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_train_size,
+        num_parallel_calls=num_parallel_calls, prebatch_size=prebatch_train_size
     )
 
     if mode == "train":
         test_files = [dataset_dir / file for file in feature_spec.source_spec[TEST_MAPPING][0][FILES_SELECTOR]]
         data_iterator_test = get_data_iterator(
             test_files, feature_spec, batch_size, num_gpus, long_seq_length,
-            amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_test_size
+            amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_test_size, num_parallel_calls=num_parallel_calls,
+            prebatch_size=prebatch_test_size
         )
     else:
         data_iterator_test = []  # otherwise not used
@@ -689,4 +712,4 @@ def main(
 
 
 if __name__ == "__main__":
-    main()
+    main()

+ 33 - 0
TensorFlow2/Recommendation/SIM/preprocessing/ops.py

@@ -73,6 +73,39 @@ def _preserve_data(offsets, values, new_values):
             new_values[i] = values[rowid]
 
 
+@numba.cuda.jit
+def _slice_rjust(max_elements, offsets, elements, new_offsets, new_elements):
+    rowid = numba.cuda.grid(1)
+    if rowid < new_offsets.size - 1:
+        row_size = min(offsets[rowid + 1] - offsets[rowid], max_elements)
+        offset = offsets[rowid + 1] - row_size
+        new_start = new_offsets[rowid + 1] - row_size
+
+        for i in range(row_size):
+            new_elements[new_start + i] = elements[offset + i]
+
+
+def slice_and_pad_left(seq_col, max_elements, pad_value=0):
+    c = seq_col._column
+    offsets = c.offsets.values
+    elements = c.elements.values
+
+    threads = THREADS
+    blocks = (offsets.size + threads - 1) // threads
+
+    new_offsets = cupy.arange(offsets.size, dtype=offsets.dtype) * max_elements
+
+    new_elements = cupy.full(
+        new_offsets[-1].item(), fill_value=pad_value, dtype=elements.dtype
+    )
+    _slice_rjust[blocks, threads](
+        max_elements, offsets, elements, new_offsets, new_elements
+    )
+
+    new_col = nvt_build_list_column(new_elements, new_offsets)
+    return new_col
+
+
 class ExplodeSequence:
     """
     For each row create a new one with a subsequence of the original list columns.

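The `slice_and_pad_left` helper introduced above keeps at most the last `max_elements` entries of each variable-length row, right-justified, filling the left side with `pad_value`. A host-side NumPy sketch of the same semantics (the CUDA kernel is the actual implementation; this is illustration only):

```python
import numpy as np

def slice_and_pad_left(rows, max_elements, pad_value=0):
    # Keep the tail (most recent history) of each row, right-justified,
    # left-padded with pad_value up to max_elements.
    out = np.full((len(rows), max_elements), pad_value, dtype=np.int64)
    for i, row in enumerate(rows):
        kept = row[-max_elements:]                 # last max_elements items
        out[i, max_elements - len(kept):] = kept   # right-justify
    return out

print(slice_and_pad_left([[5, 6, 7, 8], [9]], 3))
# [[6 7 8]
#  [0 0 9]]
```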
+ 108 - 68
TensorFlow2/Recommendation/SIM/preprocessing/parquet_to_tfrecord.py

@@ -21,9 +21,11 @@ from functools import partial
 
 import click
 import pandas as pd
+import numpy as np
 import tensorflow as tf
 
 from sim.data.feature_spec import FeatureSpec
+from sim.data.defaults import TRAIN_MAPPING, TEST_MAPPING, REMAINDER_FILENAME, FILES_SELECTOR
 
 # Docker image sets it to "python" for NVTabular purposes (bugfix), which slows down the script 20x
 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
@@ -34,46 +36,31 @@ logging.basicConfig(
     format="[%(asctime)s] %(levelname)s: %(message)s",
 )
 
-
-def _int64_feature(value, islist=False):
-    """Returns an int64_list from a bool / enum / int / uint."""
-    if not islist:
-        value = [value]
-    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
-
-
-def process_chunk(df, sequential_data_start):
-    feature_values_lists = [df.iloc[:, i].values for i in range(sequential_data_start)]
-
-    for i in range(sequential_data_start, df.shape[1]):
-        values = df.iloc[:, i].values.tolist()
-        feature_values_lists.append(values)
-
-    return zip(*feature_values_lists)
-
-
-def prepare_record(sample, all_feature_names, sequential_data_start):
-
+def prepare_record(sample, all_feature_names, sequential_data_start, prebatch):
     feature = {}
-    for idx, (f_name, data) in enumerate(zip(all_feature_names, sample)):
-        islist = idx >= sequential_data_start
-        feature[f_name] = _int64_feature(data, islist)
+    for idx, (f_name, data) in enumerate(zip(all_feature_names, sample.values())):
+        if idx >= sequential_data_start:
+            if prebatch:
+                data = np.array(data).flatten()
+        else:
+            if not prebatch:
+                data = [data]
 
-    record_bytes = tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()
-    return record_bytes
+        feature[f_name] = tf.train.Feature(int64_list=tf.train.Int64List(value=data))
 
+    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()
 
-def create_default_feature_spec(user_features_cardinalities, item_features_cardinalities,
-                                max_seq_len, tfrecord_output_dir, train_output_file, test_output_file):
+def save_records(output_path, records, base_output_path, feature_spec, mapping):
 
-    train_output = tfrecord_output_dir / train_output_file
-    test_output = tfrecord_output_dir / test_output_file
+    with tf.io.TFRecordWriter(str(output_path)) as file_writer:
+        for record_bytes in records:
+            file_writer.write(record_bytes)
 
-    f_spec = FeatureSpec.get_default_feature_spec(user_features_cardinalities, item_features_cardinalities,
-                                                  max_seq_len, train_output, test_output)
+    feature_spec.source_spec[mapping][0][FILES_SELECTOR].append(
+        str(output_path.relative_to(base_output_path))
+    )
 
-    save_path = tfrecord_output_dir / 'feature_spec.yaml'
-    f_spec.to_yaml(save_path)
+    logging.info(f'Created: {output_path}')
 
 
 @click.command()
@@ -91,14 +78,15 @@ def create_default_feature_spec(user_features_cardinalities, item_features_cardi
 )
 @click.option(
     "--number_of_user_features",
-    required=True,
-    help="number of user specific features.",
+    default=1,
+    help="number of user specific features. Default is 1 for amazon books dataset (user_id).",
     type=int
 )
 @click.option(
     "--max_seq_len",
     default=100,
-    help="maximum possible length of history. (Entries will be padded to that length later)."
+    help="maximum possible length of history. (Entries will be padded to that length later).",
+    type=int
 )
 @click.option(
     "--n_proc",
@@ -109,30 +97,57 @@ def create_default_feature_spec(user_features_cardinalities, item_features_cardi
 @click.option(
     "--train_split_dir",
     default='train',
-    help="name of directory within amazon dataset directory containing train data."
+    help="Name of directory within amazon dataset directory containing train data.",
+    type=str
 )
 @click.option(
     "--test_split_dir",
     default='test',
-    help="name of directory within amazon dataset directory containing test data."
+    help="Name of directory within amazon dataset directory containing test data.",
+    type=str,
 )
 @click.option(
     "--metadata_file",
     default='metadata.json',
-    help="name of metadata file within amazon dataset directory (containing feature cardinalities)."
+    help="Name of metadata file within amazon dataset directory (containing feature cardinalities).",
+    type=str
 )
 @click.option(
-    "--train_output_file",
-    default='train.tfrecord',
-    help='name of train file within output directory.',
+    "--train_output_dir",
+    default='train',
+    help="Name of train directory within output directory.",
     type=str
 )
 @click.option(
-    "--test_output_file",
-    default='test.tfrecord',
-    help='name of test file within output directory.',
+    "--test_output_dir",
+    default='test',
+    help='Name of test directory within output directory.',
     type=str
 )
+@click.option(
+    "--train_parts",
+    default=8,
+    help="Number of output train files.",
+    type=int
+)
+@click.option(
+    "--test_parts",
+    default=4,
+    help="Number of output test files.",
+    type=int
+)
+@click.option(
+    "--prebatch_train_size",
+    default=0,
+    help='Prebatch size applied to train data in preprocessing. If prebatch_train_size == 0, no prebatching is done.',
+    type=int
+)
+@click.option(
+    "--prebatch_test_size",
+    default=0,
+    help='Prebatch size applied to test data in preprocessing. If prebatch_test_size == 0, no prebatching is done.',
+    type=int
+)
 def main(
         amazon_dataset_path: str,
         tfrecord_output_dir: str,
@@ -142,8 +157,12 @@ def main(
         train_split_dir: str,
         test_split_dir: str,
         metadata_file: str,
-        train_output_file: str,
-        test_output_file: str
+        train_output_dir: str,
+        test_output_dir: str,
+        train_parts: int,
+        test_parts: int,
+        prebatch_train_size: int,
+        prebatch_test_size: int
 ):
     """
     read_parquet()
@@ -160,11 +179,12 @@ def main(
         amazon_dataset_path / test_split_dir
     ]
 
-    os.makedirs(tfrecord_output_dir, exist_ok=True)
     output_splits = [
-        tfrecord_output_dir / train_output_file,
-        tfrecord_output_dir / test_output_file
+        tfrecord_output_dir / train_output_dir,
+        tfrecord_output_dir / test_output_dir
     ]
+    for split_dir in output_splits:
+        os.makedirs(split_dir, exist_ok=True)
 
     with open(amazon_dataset_path / metadata_file, 'r') as file:
         metadata = json.load(file)
@@ -176,35 +196,55 @@ def main(
     user_features_cardinalities = feature_cardinalities[:number_of_user_features]
     item_features_cardinalities = feature_cardinalities[number_of_user_features:]
 
-    create_default_feature_spec(user_features_cardinalities, item_features_cardinalities, max_seq_len,
-                                tfrecord_output_dir, train_output_file, test_output_file)
+    feature_spec = FeatureSpec.get_default_feature_spec(user_features_cardinalities, item_features_cardinalities, max_seq_len)
 
     number_of_item_features = len(item_features_cardinalities)
     sequential_data_start = 1 + number_of_user_features + number_of_item_features
     all_feature_names = FeatureSpec.get_default_features_names(number_of_user_features, number_of_item_features)
-    prepare_record_function = partial(prepare_record, all_feature_names=all_feature_names,
-                                      sequential_data_start=sequential_data_start)
+
+    prebatch_per_split = [prebatch_train_size, prebatch_test_size]
+    parts_per_split = [train_parts, test_parts]
+    mappings = [TRAIN_MAPPING, TEST_MAPPING]
+
+    for mapping, input_dir, output_dir, output_parts, prebatch_size in zip(mappings, input_splits, output_splits, parts_per_split, prebatch_per_split):
+
+        prebatch = prebatch_size > 0
+        prepare_record_function = partial(prepare_record, all_feature_names=all_feature_names,
+                                        sequential_data_start=sequential_data_start, prebatch=prebatch)
+        save_records_function = partial(save_records, base_output_path=tfrecord_output_dir, feature_spec=feature_spec, mapping=mapping)
+
+        logging.info(f"Started conversion, will output to {output_dir}")
+
+        df = pd.read_parquet(input_dir, engine='pyarrow')
+
+        logging.info("Parquet loaded")
+
+        if prebatch:
+            remainder = None
+            df['batch_index'] = df.index // prebatch_size
+            df = df.groupby('batch_index').agg(list)
+            if len(df.iloc[-1, 0]) < prebatch_size:
+                remainder = df[-1:].to_dict('records')[0]
+                remainder = prepare_record_function(remainder)
 
-    for input_dir, output_file in zip(input_splits, output_splits):
+                df = df[:-1]
 
-        files = input_dir.glob("part.*.parquet")
-        def num_order(p): return int(p.name.split(".")[1])
-        paths = sorted(files, key=num_order)
+            logging.info("Prebatching applied")
 
-        logging.info(f"Started conversion, will output to {output_file}")
+        df = df.to_dict('records')
+        with multiprocessing.Pool(n_proc) as pool:
+            records = pool.map(prepare_record_function, df)
 
-        with tf.io.TFRecordWriter(str(output_file)) as file_writer:
-            with multiprocessing.Pool(n_proc) as pool:
-                for path in paths:
-                    df = pd.read_parquet(path)
+        logging.info("Records created")
 
-                    zipped_data = process_chunk(df, sequential_data_start)
+        records = np.array_split(records, output_parts)
+        for i, records_part in enumerate(records):
+            if len(records_part) > 0:
+                save_records_function(output_dir / f'part_{i}.tfrecord', records_part)
 
-                    records = pool.map(prepare_record_function, zipped_data)
-                    for record_bytes in records:
-                        file_writer.write(record_bytes)
+        if prebatch and remainder is not None:
+            save_records_function(output_dir / REMAINDER_FILENAME, [remainder])
 
-                    logging.info(f"Processed {path}")
+    feature_spec.to_yaml(tfrecord_output_dir / 'feature_spec.yaml')
 
 
 if __name__ == "__main__":
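
The prebatching applied in `main()` above groups consecutive rows into fixed-size batches and saves an incomplete trailing batch as a separate remainder. A minimal pure-Python sketch of that grouping (pandas-free; the helper name and row data are hypothetical):

```python
def prebatch(rows, prebatch_size):
    """Group rows into fixed-size batches; return (full_batches, remainder)."""
    batches = [rows[i:i + prebatch_size] for i in range(0, len(rows), prebatch_size)]
    remainder = None
    if batches and len(batches[-1]) < prebatch_size:
        # An incomplete trailing batch is split off and stored separately,
        # mirroring the REMAINDER_FILENAME record written above.
        remainder = batches.pop()
    return batches, remainder

batches, remainder = prebatch(list(range(10)), prebatch_size=4)
# batches -> [[0, 1, 2, 3], [4, 5, 6, 7]], remainder -> [8, 9]
```

When the row count divides evenly, no remainder is produced and only full batches are written.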

+ 9 - 7
TensorFlow2/Recommendation/SIM/preprocessing/sim_preprocessing.py

@@ -25,7 +25,7 @@ import dask_cudf
 import rmm
 
 from preprocessing.io import load_metadata, load_review_data, save_metadata
-from preprocessing.ops import ExplodeSequence, add_negative_sequence, list_slice
+from preprocessing.ops import ExplodeSequence, add_negative_sequence, list_slice, slice_and_pad_left
 
 DASK_TRAIN_DATASET_CHUNKSIZE = 15_000
 TRAIN_DATA_DIR = "train"
@@ -179,11 +179,12 @@ def add_negative_sampling(df: cudf.DataFrame, sampling_df: cudf.DataFrame) -> cu
     return df
 
 
-def slice_sequences(df: cudf.DataFrame, max_elements: int) -> cudf.DataFrame:
-    df["item_sequence"] = list_slice(df["item_sequence"], -max_elements)
-    df["cat_sequence"] = list_slice(df["cat_sequence"], -max_elements)
-    df["neg_item_sequence"] = list_slice(df["neg_item_sequence"], -max_elements)
-    df["neg_cat_sequence"] = list_slice(df["neg_cat_sequence"], -max_elements)
+def pad_with_zeros(df: cudf.DataFrame, max_elements: int) -> cudf.DataFrame:
+    df["item_sequence"] = slice_and_pad_left(df["item_sequence"], max_elements)
+    df["cat_sequence"] = slice_and_pad_left(df["cat_sequence"], max_elements)
+    df["neg_item_sequence"] = slice_and_pad_left(df["neg_item_sequence"], max_elements)
+    df["neg_cat_sequence"] = slice_and_pad_left(df["neg_cat_sequence"], max_elements)
+
     return df
 
 
@@ -202,6 +203,7 @@ def create_train_dataset(
 
         df = explode_sequence(df, min_elements, max_elements)
         df = add_negative_sampling(df, sampling_df)
+        df = pad_with_zeros(df, max_elements)
         df = df.sort_values(by=["uid"])
         df.reset_index(drop=True, inplace=True)
         df = df[list(OUTPUT_META)]
@@ -222,7 +224,7 @@ def create_test_dataset(
     output_path: str,
 ) -> None:
     df = add_negative_sampling(df, sampling_df)
-    df = slice_sequences(df, max_elements)
+    df = pad_with_zeros(df, max_elements)
     df = df.sort_values(by=["uid"])
     df.reset_index(drop=True, inplace=True)
     df = df[list(OUTPUT_META)]
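
The new `pad_with_zeros()` replaces plain slicing with `slice_and_pad_left` from `preprocessing/ops.py`, which is not shown in this diff. Assuming it keeps the last `max_elements` entries of each sequence and left-pads shorter sequences with zeros, a pure-Python sketch of that semantics on a single list:

```python
def slice_and_pad_left(seq, max_elements, pad_value=0):
    """Keep the last `max_elements` items and left-pad to a fixed length."""
    sliced = seq[-max_elements:]
    return [pad_value] * (max_elements - len(sliced)) + sliced

slice_and_pad_left([5, 6, 7], 5)           # -> [0, 0, 5, 6, 7]
slice_and_pad_left([1, 2, 3, 4, 5, 6], 5)  # -> [2, 3, 4, 5, 6]
```

Padding at preprocessing time is what lets the dataloader parse fixed-shape `FixedLenFeature`s instead of ragged tensors.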

+ 5 - 1
TensorFlow2/Recommendation/SIM/scripts/run_model.sh

@@ -30,6 +30,8 @@ Usage: bash scripts/run_model.sh
 --log_filename          Name of output log file within results_dir. Default: log.json.
 --save_checkpoint_path  Path to output checkpoint after training.
 --load_checkpoint_path  Path from which to restore checkpoint for inference or suspend/resume training.
+--prebatch_train_size  Prebatch size used for train data, if the dataset was prebatched in preprocessing.
+--prebatch_test_size   Prebatch size used for test data, if the dataset was prebatched in preprocessing.
 EOF
 }
 
@@ -82,10 +84,12 @@ results_dir_option=$(get_option_or_use_default --results_dir $results_dir)
 log_filename_option=$(get_option_or_use_default --log_filename $log_filename)
 save_checkpoint_path_option=$(get_option_or_use_default --save_checkpoint_path $save_checkpoint_path)
 load_checkpoint_path_option=$(get_option_or_use_default --load_checkpoint_path $load_checkpoint_path)
+prebatch_train_size_option=$(get_option_or_use_default --prebatch_train_size $prebatch_train_size)
+prebatch_test_size_option=$(get_option_or_use_default --prebatch_test_size $prebatch_test_size)
 
 command="mpiexec --allow-run-as-root --bind-to socket -np ${gpus} python main.py --dataset_dir ${data_path} --drop_remainder ${epochs_option} 
 ${xla_arg} ${amp_arg} ${benchmark_arg} ${mode_option} ${benchmark_steps_option} ${batch_size_option} ${results_dir_option} ${log_filename_option}
-${save_checkpoint_path_option} ${load_checkpoint_path_option}"
+${save_checkpoint_path_option} ${load_checkpoint_path_option} ${prebatch_train_size_option} ${prebatch_test_size_option}"
 
 printf "[INFO] Running:\n%s\n" "${command}"
 # run

+ 62 - 21
TensorFlow2/Recommendation/SIM/sim/data/dataloader.py

@@ -18,12 +18,7 @@ from functools import partial
 import tensorflow as tf
 
 from sim.data.defaults import (DIMENSIONS_SELECTOR, LABEL_CHANNEL, NEGATIVE_HISTORY_CHANNEL, POSITIVE_HISTORY_CHANNEL,
-                               TARGET_ITEM_FEATURES_CHANNEL, USER_FEATURES_CHANNEL)
-
-
-def _pad_ragged_infront(x, pad_length):
-    x = tf.reverse(x, axis=[1])
-    return tf.reverse(x.to_tensor(shape=(None, pad_length)), axis=[1])
+                               TARGET_ITEM_FEATURES_CHANNEL, USER_FEATURES_CHANNEL, REMAINDER_FILENAME)
 
 
 def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
@@ -32,20 +27,20 @@ def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
     features = feature_spec.feature_spec
 
     user_features = {
-        f_name: sample[f_name] for f_name in channel_spec[USER_FEATURES_CHANNEL]
+        f_name: tf.reshape(sample[f_name], [-1]) for f_name in channel_spec[USER_FEATURES_CHANNEL]
     }
 
     target_item_features = {
-        f_name: sample[f_name] for f_name in channel_spec[TARGET_ITEM_FEATURES_CHANNEL]
+        f_name: tf.reshape(sample[f_name], [-1]) for f_name in channel_spec[TARGET_ITEM_FEATURES_CHANNEL]
     }
 
     padded_positive = {
-        f_name: _pad_ragged_infront(sample[f_name], features[f_name][DIMENSIONS_SELECTOR][0])
+        f_name: tf.reshape(sample[f_name], [-1, features[f_name][DIMENSIONS_SELECTOR][0]])
         for f_name in channel_spec[POSITIVE_HISTORY_CHANNEL]
     }
 
     padded_negative = {
-        f_name: _pad_ragged_infront(sample[f_name], features[f_name][DIMENSIONS_SELECTOR][0])
+        f_name: tf.reshape(sample[f_name], [-1, features[f_name][DIMENSIONS_SELECTOR][0]])
         for f_name in channel_spec[NEGATIVE_HISTORY_CHANNEL]
     }
 
@@ -70,7 +65,7 @@ def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
     short_sequence_mask = history_mask[:, long_seq_length:]
 
     label_name = channel_spec[LABEL_CHANNEL][0]
-    target = sample[label_name]
+    target = tf.reshape(sample[label_name], [-1])
 
     return {
         "user_features": user_features,
@@ -84,6 +79,14 @@ def _remap_column_values_tfrecord(sample, feature_spec, long_seq_length):
     }, target
 
 
+def split_prebatch(sample, split_into):
+    res = {}
+    for f_name, val in sample.items():
+        res[f_name] = tf.reshape(val, [split_into, -1])
+
+    return tf.data.Dataset.from_tensor_slices(res)
+
+
 def get_dataloader_tfrecord(
     file_paths,
     feature_spec,
@@ -94,36 +97,74 @@ def get_dataloader_tfrecord(
     drop_remainder=False,
     repeat_count=0,
     prefetch_buffer_size=90,
-    disable_cache=False):
+    num_parallel_calls=None,
+    disable_cache=False,
+    prebatch_size=0,
+):
 
     features = feature_spec.feature_spec
+    prebatched = prebatch_size > 0
+
+    remainder_file = None
+    if file_paths[-1].name == REMAINDER_FILENAME:
+        remainder_file = file_paths[-1:]
+        file_paths = file_paths[:-1]
 
     tf_feature_spec = {}
     for name, feature in features.items():
         dimensions = feature.get(DIMENSIONS_SELECTOR)
         if dimensions is None:
-            tf_feature_spec[name] = tf.io.FixedLenFeature([], tf.int64)
-        else:
-            tf_feature_spec[name] = tf.io.RaggedFeature(tf.int64)
+            dimensions = [1] if prebatched else []
+
+        if prebatched:
+            dimensions = dimensions.copy()
+            dimensions[0] *= prebatch_size
 
-    num_cpus = multiprocessing.cpu_count()
+        tf_feature_spec[name] = tf.io.FixedLenFeature(dimensions, tf.int64)
 
-    dataset = tf.data.TFRecordDataset(file_paths)
+    if num_parallel_calls is None:
+        num_cpus = multiprocessing.cpu_count()
+        num_parallel_calls = 4 * num_cpus // num_gpus
+
+    dataset = tf.data.TFRecordDataset(file_paths, num_parallel_reads=num_parallel_calls)
 
     dataset = dataset.shard(num_gpus, id)
 
-    dataset = dataset.apply(
-        tf.data.experimental.dense_to_ragged_batch(batch_size, drop_remainder=drop_remainder)
+    splitting_function = None
+    if prebatched:
+        if batch_size >= prebatch_size:
+            batch_size = batch_size // prebatch_size
+        else:
+            split_into = prebatch_size // batch_size
+            splitting_function = partial(split_prebatch, split_into=split_into)
+            batch_size = 1
+
+    dataset = dataset.batch(
+        batch_size, drop_remainder=drop_remainder, num_parallel_calls=num_parallel_calls
     )
 
     dataset = dataset.map(
         map_func=partial(tf.io.parse_example, features=tf_feature_spec),
-        num_parallel_calls=num_cpus//num_gpus
+        num_parallel_calls=num_parallel_calls
     )
 
+    if splitting_function is not None:
+        dataset = dataset.flat_map(splitting_function)
+
+    if not drop_remainder and id == 0 and remainder_file is not None:
+        tf_feature_spec_remainder = {
+            name: tf.io.RaggedFeature(tf.int64) for name in tf_feature_spec
+        }
+        remainder = tf.data.TFRecordDataset(remainder_file)
+        remainder = remainder.map(
+            map_func=partial(tf.io.parse_example, features=tf_feature_spec_remainder)
+        )
+
+        dataset = dataset.concatenate(remainder)
+
     dataset = dataset.map(
         map_func=partial(_remap_column_values_tfrecord, feature_spec=feature_spec, long_seq_length=long_seq_length),
-        num_parallel_calls=num_cpus//num_gpus
+        num_parallel_calls=num_parallel_calls
     )
 
     if repeat_count > 0:

+ 2 - 0
TensorFlow2/Recommendation/SIM/sim/data/defaults.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+REMAINDER_FILENAME = 'remainder.tfrecord'
+
 USER_FEATURES_CHANNEL = 'user_features'
 TARGET_ITEM_FEATURES_CHANNEL = 'target_item_features'
 POSITIVE_HISTORY_CHANNEL = 'positive_history'

+ 3 - 4
TensorFlow2/Recommendation/SIM/sim/data/feature_spec.py

@@ -72,8 +72,7 @@ class FeatureSpec:
         return [label_feature_name] + user_features_names + item_features_names
 
     @staticmethod
-    def get_default_feature_spec(user_features_cardinalities, item_features_cardinalities,
-                                 max_seq_len, train_output, test_output):
+    def get_default_feature_spec(user_features_cardinalities, item_features_cardinalities, max_seq_len):
 
         number_of_user_features = len(user_features_cardinalities)
         number_of_item_features = len(item_features_cardinalities)
@@ -127,9 +126,9 @@ class FeatureSpec:
                 {
                     'type': 'tfrecord',
                     'features': all_features_names,
-                    'files': [filepath.name]
+                    'files': []
                 }
-            ] for split, filepath in zip([TRAIN_MAPPING, TEST_MAPPING], [train_output, test_output])
+            ] for split in [TRAIN_MAPPING, TEST_MAPPING]
         }
 
         return FeatureSpec(feature_spec=feature_spec, channel_spec=channel_spec, source_spec=source_spec)