Browse Source

[TFT/PyTorch] Move to nvFuser

Izzy Putterman 2 years ago
parent
commit
777d174008

+ 2 - 1
PyTorch/Forecasting/TFT/Dockerfile

@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.12-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.11-py3
+
 FROM ${FROM_IMAGE_NAME}
 
 # Set workdir and python path

+ 1 - 1
PyTorch/Forecasting/TFT/Dockerfile-triton

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.12-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.11-py3
 FROM ${FROM_IMAGE_NAME}
 
 # Ensure apt-get won't prompt for selecting options

+ 43 - 41
PyTorch/Forecasting/TFT/README.md

@@ -123,9 +123,6 @@ For information about:
   Training of Deep Neural
   Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/)
   blog.
-* APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in
-  PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/)
-  .
 
 
 #### Enabling mixed precision
@@ -169,7 +166,7 @@ The following section lists the requirements that you need to meet in order to s
 
 This repository contains Dockerfile, which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 -   [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
--   [PyTorch 21.12 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch)
+-   [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch)
 -   Supported GPUs:
 - [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
 - [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
@@ -371,7 +368,7 @@ The [NVIDIA Triton Inference Server](https://github.com/triton-inference-server/
 
 ### Benchmarking
 
-The following section shows how to run benchmarks measuring the model performance in training and inference modes.
+The following section shows how to run benchmarks measuring the model performance in training and inference modes. Note that the first three steps of each epoch are excluded from the throughput and latency calculations, because nvFuser performs its optimizations during the third step of the first epoch, causing a multi-second pause.
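The warm-up exclusion described in the README note can be sketched as follows. This is a minimal, hypothetical accumulator mirroring the `exclude_from_total=step in [0, 1, 2, ...]` pattern this commit adds to `inference.py`; it is not the repository's actual `PerformanceMeter`.

```python
# Minimal sketch of warm-up exclusion (hypothetical class, not the repo's PerformanceMeter).
class ThroughputMeter:
    def __init__(self):
        self.total_items = 0
        self.total_time = 0.0

    def update(self, items, elapsed, exclude=False):
        # Excluded steps (e.g. the nvFuser compilation pause) do not count
        # toward the measured throughput.
        if exclude:
            return
        self.total_items += items
        self.total_time += elapsed

    @property
    def avg(self):
        return self.total_items / self.total_time if self.total_time else 0.0

meter = ThroughputMeter()
num_steps, batch_size = 10, 1024
for step in range(num_steps):
    elapsed = 5.0 if step == 2 else 0.1   # simulate the multi-second fuser pause
    # Skip the first three steps and the (possibly short) last batch.
    meter.update(batch_size, elapsed, exclude=step in [0, 1, 2, num_steps - 1])

print(meter.avg)  # items/s over the six non-excluded steps only
```

Without the exclusion, the one slow compilation step would dominate the average and badly understate steady-state throughput.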
 
 #### Training performance benchmark
 
@@ -390,24 +387,24 @@ We conducted an extensive hyperparameter search along with stability tests. The
 
 ##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs.
 
 | Dataset | GPUs | Batch size / GPU    | Accuracy - TF32  | Accuracy - mixed precision  |   Time to train - TF32  |  Time to train - mixed precision | Time to train speedup (TF32 to mixed precision)     
 |-------------|---|------|-----------------------|-----------------------|-------|-------|-------
-| Electricity | 8 | 1024 | 0.027 / 0.057 / 0.029 | 0.028 / 0.057 / 0.029 | 216s  | 176s  | 1.227x
-| Traffic     | 8 | 1024 | 0.043 / 0.108 / 0.079 | 0.042 / 0.107 / 0.078 | 151s  | 126s  | 1.198x
+| Electricity | 8 | 1024 | 0.026 / 0.056 / 0.029 | 0.028 / 0.058 / 0.029 | 200s  | 176s  | 1.136x
+| Traffic     | 8 | 1024 | 0.044 / 0.108 / 0.078 | 0.044 / 0.109 / 0.079 | 140s  | 129s  | 1.085x
 
 
 
 
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
 
 | Dataset | GPUs    | Batch size / GPU    | Accuracy - FP32  | Accuracy - mixed precision  |   Time to train - FP32  |  Time to train - mixed precision | Time to train speedup (FP32 to mixed precision)        
 |-------------|---|------|-----------------------|-----------------------|-------|-------|-----------
-| Electricity | 8 | 1024 | 0.028 / 0.057 / 0.029 | 0.027 / 0.057 / 0.029 | 381s  | 261s  | 1.460x   
-| Traffic     | 8 | 1024 | 0.042 / 0.106 / 0.076 | 0.040 / 0.103 / 0.074 | 256s  | 176s  | 1.455x
+| Electricity | 8 | 1024 | 0.028 / 0.057 / 0.028 | 0.027 / 0.059 / 0.030 | 371s  | 269s  | 1.379x   
+| Traffic     | 8 | 1024 | 0.042 / 0.110 / 0.080 | 0.043 / 0.109 / 0.080 | 251s  | 191s  | 1.314x
 
 
 
@@ -417,22 +414,22 @@ In order to get a greater picture of the model’s accuracy, we performed a hype
 
 | Dataset     | #GPU | Hidden size | #Heads | Local BS | LR   | Gradient clipping | Dropout | Mean q-risk | Std q-risk | Min q-risk | Max q-risk
 |-------------|------|-------------|--------|----------|------|-------------------|---------|-------------|------------| -----------|------ 
-| Electricity | 8    | 128         | 4      | 1024     | 1e-3 | 0.0               | 0.1     | 0.1131      | 0.0025     | 0.1080     | 0.1200
-| Traffic     | 8    | 128         | 4      | 1024     | 1e-3 | 0.0               | 0.3     | 0.2180      | 0.0049     | 0.2069     | 0.2336
+| Electricity | 8    | 128         | 4      | 1024     | 1e-3 | 0.0               | 0.1     | 0.1129      | 0.0025     | 0.1074     | 0.1244
+| Traffic     | 8    | 128         | 4      | 1024     | 1e-3 | 0.0               | 0.3     | 0.2262      | 0.0027     | 0.2207     | 0.2331
 
 
 #### Training performance results
 
 ##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA A100 (8x A100 80GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
 
 | Dataset | GPUs   | Batch size / GPU   | Throughput - TF32    | Throughput - mixed precision    | Throughput speedup (TF32 - mixed precision)   | Weak scaling - TF32    | Weak scaling - mixed precision        
 |-------------|---|------|--------|--------|-------|-------|-----
-| Electricity | 1 | 1024 | 10173  | 13703  | 1.35x | 1     | 1
-| Electricity | 8 | 1024 | 80596  | 107761 | 1.34x | 7.92x | 7.86x
-| Traffic     | 1 | 1024 | 10197  | 13779  | 1.35x | 1     | 1
-| Traffic     | 8 | 1024 | 80692  | 107979 | 1.34x | 7.91x | 7.84x
+| Electricity | 1 | 1024 | 12435  | 17608  | 1.42x | 1     | 1
+| Electricity | 8 | 1024 | 94389  | 130769 | 1.39x | 7.59x | 7.42x
+| Traffic     | 1 | 1024 | 12509  | 17591  | 1.40x | 1     | 1
+| Traffic     | 8 | 1024 | 94476  | 130992 | 1.39x | 7.55x | 7.45x
 
 
 To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
@@ -442,14 +439,14 @@ The performance metrics used were items per second.
 
 ##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
 
-Our results were obtained by running the `train.sh` training script in the [PyTorch 21.06 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 with (8x V100 16GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
+Our results were obtained by running the `train.sh` training script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 with (8x V100 16GB) GPUs. Performance numbers (in items/images per second) were averaged over an entire training epoch.
 
 | Dataset | GPUs   | Batch size / GPU   | Throughput - FP32    | Throughput - mixed precision    | Throughput speedup (FP32 - mixed precision)   | Weak scaling - FP32    | Weak scaling - mixed precision        
 |-------------|---|------|-------|-------|-------|------|----
-| Electricity | 1 | 1024 | 5580  | 9148  | 1.64x | 1     | 1
-| Electricity | 8 | 1024 | 43351 | 69855 | 1.61x | 7.77x | 7.64x
-| Traffic     | 1 | 1024 | 5593  | 9194  | 1.64x | 1     | 1
-| Traffic     | 8 | 1024 | 43426 | 69983 | 1.61x | 7.76x | 7.61x
+| Electricity | 1 | 1024 | 5932  | 10163 | 1.71x | 1     | 1
+| Electricity | 8 | 1024 | 45566 | 75660 | 1.66x | 7.68x | 7.44x
+| Traffic     | 1 | 1024 | 5971  | 10166 | 1.70x | 1     | 1
+| Traffic     | 8 | 1024 | 45925 | 75640 | 1.64x | 7.69x | 7.44x
 
 
 
@@ -463,39 +460,44 @@ The performance metrics used were items per second.
 
 ##### Inference Performance: NVIDIA DGX A100
 
-Our results were obtained by running the `inference.py` script in the [PyTorch 21.12 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX A100.  Throughput is measured in items per second and latency is measured in milliseconds.
+Our results were obtained by running the `inference.py` script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX A100.  Throughput is measured in items per second and latency is measured in milliseconds.
 To benchmark the inference performance on a specific batch size and dataset, run the `inference.py` script.
 | Dataset | GPUs   | Batch size / GPU   | Throughput - mixed precision (item/s)    | Average Latency (ms) | Latency p90 (ms) | Latency p95 (ms) | Latency p99 (ms)
 |-------------|--------|-----|---------------------------------|-----------------|-------------|-------------|------------
-| Electricity | 1      | 1   | 144.37   | 6.93 | 7.00 | 7.04 | 7.25
-| Electricity | 1      | 2   | 277.53   | 7.21 | 7.25 | 7.27 | 7.48
-| Electricity | 1      | 4   | 564.37   | 7.09 | 7.13 | 7.15 | 7.64
-| Electricity | 1      | 8   | 1399.25  | 5.72 | 5.71 | 5.77 | 7.51
-| Traffic     | 1      | 1   | 145.26   | 6.88 | 6.91 | 6.95 | 7.60
-| Traffic     | 1      | 2   | 277.97   | 7.19 | 7.28 | 7.30 | 7.46
-| Traffic     | 1      | 4   | 563.05   | 7.10 | 7.14 | 7.16 | 7.42
-| Traffic     | 1      | 8   | 1411.62  | 5.67 | 5.69 | 5.79 | 6.21
+| Electricity | 1      | 1   | 272.43   | 3.67 | 3.70 | 3.87 | 4.18
+| Electricity | 1      | 2   | 518.13   | 3.86 | 3.88 | 3.93 | 4.19
+| Electricity | 1      | 4   | 1039.31  | 3.85 | 3.89 | 3.97 | 4.15
+| Electricity | 1      | 8   | 2039.54  | 3.92 | 3.93 | 3.95 | 4.32
+| Traffic     | 1      | 1   | 269.59   | 3.71 | 3.74 | 3.79 | 4.30
+| Traffic     | 1      | 2   | 518.73   | 3.86 | 3.78 | 3.91 | 4.66
+| Traffic     | 1      | 4   | 1021.49  | 3.92 | 3.94 | 3.95 | 4.25
+| Traffic     | 1      | 8   | 2005.54  | 3.99 | 4.01 | 4.03 | 4.39
 
 
 ##### Inference Performance: NVIDIA DGX-1 V100
 
-Our results were obtained by running the `inference.py` script in the [PyTorch 21.12 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 V100.  Throughput is measured in items per second and latency is measured in milliseconds.
+Our results were obtained by running the `inference.py` script in the [PyTorch 22.11 NGC container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) on NVIDIA DGX-1 V100.  Throughput is measured in items per second and latency is measured in milliseconds.
 To benchmark the inference performance on a specific batch size and dataset, run the `inference.py` script.
 | Dataset | GPUs   | Batch size / GPU   | Throughput - mixed precision (item/s)    | Average Latency (ms) | Latency p90 (ms) | Latency p95 (ms) | Latency p99 (ms)
 |-------------|--------|-----|---------------------------------|-----------------|-------------|-------------|------------
-| Electricity | 1      | 1   | 95.65  | 10.45 | 11.30 | 11.95 | 12.13 
-| Electricity | 1      | 2   | 193.15  | 10.35 | 10.80 | 11.46 | 12.16 
-| Electricity | 1      | 4   | 381.09  | 10.49 | 10.75 | 12.29 | 12.41
-| Electricity | 1      | 8   | 805.49 | 9.93 | 10.41 | 10.48 | 10.91
-| Traffic     | 1      | 1   | 96.72  | 10.34 | 10.53 | 11.99 | 12.13
-| Traffic     | 1      | 2   | 192.93  | 10.37 | 10.80 | 11.97 | 12.12
-| Traffic     | 1      | 4   | 379.00  | 10.55 | 10.88 | 11.09 | 11.96
-| Traffic     | 1      | 8   | 859.69 | 9.30 | 10.58 | 10.65 | 11.28
+| Electricity | 1      | 1   | 171.68  | 5.82 | 5.99 | 6.17 | 7.00 
+| Electricity | 1      | 2   | 318.92  | 6.27 | 6.43 | 6.60 | 7.51 
+| Electricity | 1      | 4   | 684.79  | 5.84 | 6.02 | 6.08 | 6.47
+| Electricity | 1      | 8   | 1275.54 | 6.27 | 7.31 | 7.36 | 7.51
+| Traffic     | 1      | 1   | 183.39  | 5.45 | 5.64 | 5.86 | 6.73
+| Traffic     | 1      | 2   | 340.73  | 5.87 | 6.07 | 6.77 | 7.25
+| Traffic     | 1      | 4   | 647.33  | 6.18 | 6.35 | 7.99 | 8.07
+| Traffic     | 1      | 8   | 1364.39 | 5.86 | 6.07 | 6.40 | 7.31
 ## Release notes
 The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to https://developer.nvidia.com/deep-learning-performance-training-inference.
 
 ### Changelog
 
+March 2023
+- 23.01 Container Update
+- Switch from NVIDIA Apex AMP and NVIDIA Apex FusedLayerNorm to Native PyTorch AMP and Native PyTorch LayerNorm
+- Acceleration using nvFuser
+
 February 2022
 - 21.12 Container Update
 - Triton Inference Performance Numbers

+ 1 - 1
PyTorch/Forecasting/TFT/configuration.py

@@ -124,5 +124,5 @@ class TrafficConfig():
 
 
 CONFIGS = {'electricity':  ElectricityConfig,
-           'traffic':      TrafficConfig, 
+           'traffic':      TrafficConfig,
            }

+ 9 - 0
PyTorch/Forecasting/TFT/criterions.py

@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
 
 class QuantileLoss(nn.Module):
     def __init__(self, config):
@@ -26,3 +27,11 @@ class QuantileLoss(nn.Module):
         ql = (1-self.q)*F.relu(diff) + self.q*F.relu(-diff)
         losses = ql.view(-1, ql.shape[-1]).mean(0)
         return losses
+
+def qrisk(pred, tgt, quantiles):
+    diff = pred - tgt
+    ql = (1-quantiles)*np.clip(diff,0, float('inf')) + quantiles*np.clip(-diff,0, float('inf'))
+    losses = ql.reshape(-1, ql.shape[-1])
+    normalizer = np.abs(tgt).mean()
+    risk = 2 * losses / normalizer
+    return risk.mean(0)
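The new `qrisk` helper returns one normalized quantile risk per quantile. A standalone check of its behavior (the function body is reproduced from the diff above; the input shapes are illustrative assumptions):

```python
import numpy as np

# qrisk as added to criterions.py, reproduced for a self-contained check.
def qrisk(pred, tgt, quantiles):
    diff = pred - tgt
    ql = (1 - quantiles) * np.clip(diff, 0, float('inf')) + quantiles * np.clip(-diff, 0, float('inf'))
    losses = ql.reshape(-1, ql.shape[-1])
    normalizer = np.abs(tgt).mean()
    risk = 2 * losses / normalizer
    return risk.mean(0)

quantiles = np.array([0.1, 0.5, 0.9])
tgt = np.ones((2, 4, 1))            # (batch, horizon, 1); broadcasts against pred
pred = np.tile(tgt, (1, 1, 3))      # a perfect prediction at every quantile

print(qrisk(pred, tgt, quantiles))        # all zeros: no error, no risk
print(qrisk(pred + 1.0, tgt, quantiles))  # over-prediction by 1: 2*(1-q), so approx [1.8, 1.0, 0.2]
```

Over-prediction is penalized most at the low quantiles and least at the high ones, which is exactly the asymmetry the pinball loss encodes.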

+ 47 - 2
PyTorch/Forecasting/TFT/data_utils.py

@@ -41,7 +41,8 @@ import numpy as np
 from bisect import bisect
 
 import torch
-from torch.utils.data import Dataset,IterableDataset,DataLoader
+from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data.dataloader import default_collate
 
 class DataTypes(enum.IntEnum):
     """Defines numerical types of each column."""
@@ -401,6 +402,51 @@ def sample_data(dataset, num_samples):
     else:
         return torch.utils.data.Subset(dataset, np.random.choice(np.arange(len(dataset)), size=num_samples, replace=False))
 
+def load_dataset(args, config, collate_fn=default_collate):
+    from utils import print_once
+    train_split = TFTBinaryDataset(os.path.join(args.data_path, 'train.bin'), config)
+    train_split = sample_data(train_split, args.sample_data[0])
+    if args.distributed_world_size > 1:
+        data_sampler = DistributedSampler(train_split, args.distributed_world_size, args.distributed_rank, seed=args.seed + args.distributed_rank, drop_last=True)
+    else:
+        data_sampler = RandomSampler(train_split)
+    train_loader = DataLoader(train_split,
+                              batch_size=args.batch_size,
+                              num_workers=4,
+                              sampler=data_sampler, 
+                              collate_fn=collate_fn,
+                              pin_memory=True)
+
+    valid_split = TFTBinaryDataset(os.path.join(args.data_path, 'valid.bin'), config)
+    valid_split = sample_data(valid_split, args.sample_data[1])
+    if args.distributed_world_size > 1:
+        data_sampler = DistributedSampler(valid_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
+    else:
+        data_sampler = None
+    valid_loader = DataLoader(valid_split, 
+                              batch_size=args.batch_size, 
+                              sampler=data_sampler, 
+                              num_workers=4, 
+                              collate_fn=collate_fn,
+                              pin_memory=True)
+
+    test_split = TFTBinaryDataset(os.path.join(args.data_path, 'test.bin'), config)
+    if args.distributed_world_size > 1:
+        data_sampler = DistributedSampler(test_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
+    else:
+        data_sampler = None
+    test_loader = DataLoader(test_split,
+                             batch_size=args.batch_size, 
+                             sampler=data_sampler, 
+                             num_workers=4, 
+                             collate_fn=collate_fn,
+                             pin_memory=True)
+
+    print_once(f'Train split length: {len(train_split)}')
+    print_once(f'Valid split length: {len(valid_split)}')
+    print_once(f'Test split length: {len(test_split)}')
+
+    return train_loader, valid_loader, test_loader
 
 def standarize_electricity(path):
     """Code taken from https://github.com/google-research/google-research/blob/master/tft/script_download_data.py"""
@@ -574,4 +620,3 @@ def standarize_traffic(path):
   
     flat_df.to_csv(os.path.join(path, 'standarized.csv'))
 
-

+ 32 - 32
PyTorch/Forecasting/TFT/inference.py

@@ -26,12 +26,12 @@ from modeling import TemporalFusionTransformer
 from configuration import ElectricityConfig
 from data_utils import TFTDataset
 from utils import PerformanceMeter
-from criterions import QuantileLoss
+from criterions import qrisk
 import dllogger
 from log_helper import setup_logger
+from torch.cuda import amp
 
 def _unscale_per_id(config, values, ids, scalers):
-    values = values.cpu().numpy()
     num_horizons = config.example_length - config.encoder_length + 1
     flat_values = pd.DataFrame(
             values,
@@ -51,11 +51,9 @@ def _unscale_per_id(config, values, ids, scalers):
     flat_values = pd.concat(df_list, axis=0)
 
     flat_values = flat_values[[col for col in flat_values if not 'id' in col]]
-    flat_tensor = torch.from_numpy(flat_values.values)
-    return flat_tensor
+    return flat_values.values
 
 def _unscale(config, values, scaler):
-    values = values.cpu().numpy()
     num_horizons = config.example_length - config.encoder_length + 1
     flat_values = pd.DataFrame(
             values,
@@ -68,8 +66,7 @@ def _unscale(config, values, scaler):
             flat_values[col] = _t_col
 
     flat_values = flat_values[[col for col in flat_values if not 'id' in col]]
-    flat_tensor = torch.from_numpy(flat_values.values)
-    return flat_tensor
+    return flat_values.values
 
 def predict(args, config, model, data_loader, scalers, cat_encodings, extend_targets=False):
     model.eval()
@@ -78,36 +75,37 @@ def predict(args, config, model, data_loader, scalers, cat_encodings, extend_tar
     ids = []
     perf_meter = PerformanceMeter(benchmark_mode=not args.disable_benchmark)
     n_workers = args.distributed_world_size if hasattr(args, 'distributed_world_size') else 1
-
-    for step, batch in enumerate(data_loader):
-        perf_meter.reset_current_lap()
-        with torch.no_grad():
-            batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
-            ids.append(batch['id'][:,0,:])
-            targets.append(batch['target'])
-            predictions.append(model(batch).float())
-
-        perf_meter.update(args.batch_size * n_workers,
-            exclude_from_total=step in [0, len(data_loader)-1])
-
-    targets = torch.cat(targets, dim=0)
+    
+    with torch.jit.fuser("fuser2"):
+        for step, batch in enumerate(data_loader):
+            perf_meter.reset_current_lap()
+            with torch.no_grad():
+                batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
+                ids.append(batch['id'][:,0,:])
+                targets.append(batch['target'])
+                predictions.append(model(batch).float())
+
+            perf_meter.update(args.batch_size * n_workers,
+                exclude_from_total=step in [0, 1, 2, len(data_loader)-1])
+
+    targets = torch.cat(targets, dim=0).cpu().numpy()
     if not extend_targets:
         targets = targets[:,config.encoder_length:,:] 
-    predictions = torch.cat(predictions, dim=0)
+    predictions = torch.cat(predictions, dim=0).cpu().numpy()
     
     if config.scale_per_id:
         ids = torch.cat(ids, dim=0).cpu().numpy()
 
-        unscaled_predictions = torch.stack(
+        unscaled_predictions = np.stack(
                 [_unscale_per_id(config, predictions[:,:,i], ids, scalers) for i in range(len(config.quantiles))], 
-                dim=-1)
-        unscaled_targets = _unscale_per_id(config, targets[:,:,0], ids, scalers).unsqueeze(-1)
+                axis=-1)
+        unscaled_targets = np.expand_dims(_unscale_per_id(config, targets[:,:,0], ids, scalers), axis=-1)
     else:
         ids = None
-        unscaled_predictions = torch.stack(
+        unscaled_predictions = np.stack(
                 [_unscale(config, predictions[:,:,i], scalers['']) for i in range(len(config.quantiles))], 
-                dim=-1)
-        unscaled_targets = _unscale(config, targets[:,:,0], scalers['']).unsqueeze(-1)
+                axis=-1)
+        unscaled_targets = np.expand_dims(_unscale(config, targets[:,:,0], scalers['']), axis=-1)
 
     return unscaled_predictions, unscaled_targets, ids, perf_meter
 
@@ -173,9 +171,11 @@ def inference(args, config, model, data_loader, scalers, cat_encodings):
                     os.makedirs(os.path.join(args.results, 'predictions', str(key)), exist_ok=True)
                     df.to_csv(os.path.join(args.results, 'predictions', str(key), q+'.csv'))
 
-    losses = QuantileLoss(config)(unscaled_predictions, unscaled_targets)
-    normalizer = unscaled_targets.abs().mean()
-    q_risk = 2 * losses / normalizer
+    #losses = QuantileLoss(config)(torch.from_numpy(unscaled_predictions).contiguous(),
+    #        torch.from_numpy(unscaled_targets).contiguous()).numpy()
+    #normalizer = np.mean(np.abs(unscaled_targets))
+    #q_risk = 2 * losses / normalizer
+    risk = qrisk(unscaled_predictions, unscaled_targets, np.array(config.quantiles))
 
     perf_dict = {
                 'throughput': perf_meter.avg,
@@ -186,7 +186,7 @@ def inference(args, config, model, data_loader, scalers, cat_encodings):
                 'total_infernece_time': perf_meter.total_time,
                 }
 
-    return q_risk, perf_dict
+    return risk, perf_dict
 
 
 def main(args):
@@ -215,7 +215,7 @@ def main(args):
     quantiles = {'test_p10': quantiles[0].item(), 'test_p50': quantiles[1].item(), 'test_p90': quantiles[2].item(), 'sum':sum(quantiles).item()}
     finish_log = {**quantiles, **perf_dict}
     dllogger.log(step=(), data=finish_log, verbosity=1)
-    print('Test q-risk: P10 {} | P50 {} | P90 {}'.format(*quantiles))
+    print('Test q-risk: P10 {test_p10} | P50 {test_p50} | P90 {test_p90}'.format(**quantiles))
     print('Latency:\n\tAverage {:.3f}s\n\tp90 {:.3f}s\n\tp95 {:.3f}s\n\tp99 {:.3f}s'.format(
         perf_dict['latency_avg'], perf_dict['latency_p90'], perf_dict['latency_p95'], perf_dict['latency_p99']))
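The corrected print call above fixes a subtle bug: `format(*quantiles)` on a dict unpacks its keys as positional arguments, not its values. A minimal illustration:

```python
q = {'test_p10': 0.1, 'test_p50': 0.5, 'test_p90': 0.9}

# The old call unpacked the dict's KEYS as positional arguments:
old = 'P10 {} | P50 {} | P90 {}'.format(*q)
# The fixed call looks values up by name:
new = 'P10 {test_p10} | P50 {test_p50} | P90 {test_p90}'.format(**q)

print(old)  # P10 test_p10 | P50 test_p50 | P90 test_p90
print(new)  # P10 0.1 | P50 0.5 | P90 0.9
```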
 

+ 166 - 65
PyTorch/Forecasting/TFT/modeling.py

@@ -17,12 +17,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from torch import Tensor
+from torch.nn.parameter import UninitializedParameter
 from typing import Dict, Tuple, Optional, List
 
-if os.environ.get("TFT_SCRIPTING", False):
-    from torch.nn import LayerNorm
-else:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
+MAKE_CONVERT_COMPATIBLE = os.environ.get("TFT_SCRIPTING", None) is not None
+from torch.nn import LayerNorm
 
 class MaybeLayerNorm(nn.Module):
     def __init__(self, output_size, hidden_size, eps):
@@ -46,21 +45,20 @@ class GLU(nn.Module):
         x = F.glu(x)
         return x
 
-
 class GRN(nn.Module):
     def __init__(self,
                  input_size,
-                 hidden_size, 
+                 hidden_size,
                  output_size=None,
                  context_hidden_size=None,
-                 dropout=0):
+                 dropout=0.0,):
         super().__init__()
-
-        
         self.layer_norm = MaybeLayerNorm(output_size, hidden_size, eps=1e-3)
         self.lin_a = nn.Linear(input_size, hidden_size)
         if context_hidden_size is not None:
             self.lin_c = nn.Linear(context_hidden_size, hidden_size, bias=False)
+        else:
+            self.lin_c = nn.Identity()
         self.lin_i = nn.Linear(hidden_size, hidden_size)
         self.glu = GLU(hidden_size, output_size if output_size else hidden_size)
         self.dropout = nn.Dropout(dropout)
@@ -74,13 +72,28 @@ class GRN(nn.Module):
         x = self.lin_i(x)
         x = self.dropout(x)
         x = self.glu(x)
-        y = a if not self.out_proj else self.out_proj(a)
+        y = a if self.out_proj is None else self.out_proj(a)
         x = x + y
-        x = self.layer_norm(x)
-        return x 
+        return self.layer_norm(x) 
+
+
+# @torch.jit.script #Currently broken with autocast
+def fused_pointwise_linear_v1(x, a, b):
+    out = torch.mul(x.unsqueeze(-1), a)
+    out = out + b
+    return out
+
[email protected]
+def fused_pointwise_linear_v2(x, a, b):
+    out = x.unsqueeze(3) * a
+    out = out + b
+    return out
+
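The fused pointwise-linear helpers above compute `torch.einsum('btf,fh->bthf', x, a) + b` through broadcasting, which nvFuser can fuse into one kernel. The same shape algebra, sketched in NumPy for illustration (this mirrors the torch code's semantics but is not the repository's implementation):

```python
import numpy as np

def pointwise_linear(x, a, bias):
    # x:    (batch, time, features) -> unsqueeze to (batch, time, features, 1)
    # a:    (features, hidden)
    # bias: (features, hidden)
    # Broadcasting gives out[b, t, f, h] = x[b, t, f] * a[f, h] + bias[f, h]
    return x[..., np.newaxis] * a + bias

rng = np.random.default_rng(0)
batch, time, feats, hidden = 2, 3, 4, 5
x = rng.random((batch, time, feats))
a = rng.random((feats, hidden))
bias = rng.random((feats, hidden))

out = pointwise_linear(x, a, bias)
print(out.shape)  # (2, 3, 4, 5)

# Matches the explicit einsum formulation used in the code comments:
ref = np.einsum('btf,fh->btfh', x, a) + bias
print(np.allclose(out, ref))  # True
```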
 
 class TFTEmbedding(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, initialize_cont_params=True):
+        # initialize_cont_params=False prevents this class from initializing its parameters
+        # so they can be lazily initialized in the LazyEmbedding module
         super().__init__()
         self.s_cat_inp_lens    = config.static_categorical_inp_lens
         self.t_cat_k_inp_lens  = config.temporal_known_categorical_inp_lens
@@ -108,23 +121,43 @@ class TFTEmbedding(nn.Module):
         self.t_cat_o_embed = nn.ModuleList([
             nn.Embedding(n, self.hidden_size) for n in self.t_cat_o_inp_lens]) if self.t_cat_o_inp_lens else None
 
-        self.s_cont_embedding_vectors = nn.Parameter(torch.Tensor(self.s_cont_inp_size, self.hidden_size)) if self.s_cont_inp_size else None
-        self.t_cont_k_embedding_vectors = nn.Parameter(torch.Tensor(self.t_cont_k_inp_size, self.hidden_size)) if self.t_cont_k_inp_size else None
-        self.t_cont_o_embedding_vectors = nn.Parameter(torch.Tensor(self.t_cont_o_inp_size, self.hidden_size)) if self.t_cont_o_inp_size else None
-        self.t_tgt_embedding_vectors = nn.Parameter(torch.Tensor(self.t_tgt_size, self.hidden_size))
+        if initialize_cont_params:
+            self.s_cont_embedding_vectors = nn.Parameter(torch.Tensor(self.s_cont_inp_size, self.hidden_size)) if self.s_cont_inp_size else None
+            self.t_cont_k_embedding_vectors = nn.Parameter(torch.Tensor(self.t_cont_k_inp_size, self.hidden_size)) if self.t_cont_k_inp_size else None
+            self.t_cont_o_embedding_vectors = nn.Parameter(torch.Tensor(self.t_cont_o_inp_size, self.hidden_size)) if self.t_cont_o_inp_size else None
+            self.t_tgt_embedding_vectors = nn.Parameter(torch.Tensor(self.t_tgt_size, self.hidden_size))
 
-        self.s_cont_embedding_bias = nn.Parameter(torch.zeros(self.s_cont_inp_size, self.hidden_size)) if self.s_cont_inp_size else None
-        self.t_cont_k_embedding_bias = nn.Parameter(torch.zeros(self.t_cont_k_inp_size, self.hidden_size)) if self.t_cont_k_inp_size else None
-        self.t_cont_o_embedding_bias = nn.Parameter(torch.zeros(self.t_cont_o_inp_size, self.hidden_size)) if self.t_cont_o_inp_size else None
-        self.t_tgt_embedding_bias = nn.Parameter(torch.zeros(self.t_tgt_size, self.hidden_size))
+            self.s_cont_embedding_bias = nn.Parameter(torch.zeros(self.s_cont_inp_size, self.hidden_size)) if self.s_cont_inp_size else None
+            self.t_cont_k_embedding_bias = nn.Parameter(torch.zeros(self.t_cont_k_inp_size, self.hidden_size)) if self.t_cont_k_inp_size else None
+            self.t_cont_o_embedding_bias = nn.Parameter(torch.zeros(self.t_cont_o_inp_size, self.hidden_size)) if self.t_cont_o_inp_size else None
+            self.t_tgt_embedding_bias = nn.Parameter(torch.zeros(self.t_tgt_size, self.hidden_size))
 
+            self.reset_parameters()
+
+
+    def reset_parameters(self):
         if self.s_cont_embedding_vectors is not None:
             torch.nn.init.xavier_normal_(self.s_cont_embedding_vectors)
+            torch.nn.init.zeros_(self.s_cont_embedding_bias)
         if self.t_cont_k_embedding_vectors is not None:
             torch.nn.init.xavier_normal_(self.t_cont_k_embedding_vectors)
+            torch.nn.init.zeros_(self.t_cont_k_embedding_bias)
         if self.t_cont_o_embedding_vectors is not None:
             torch.nn.init.xavier_normal_(self.t_cont_o_embedding_vectors)
-        torch.nn.init.xavier_normal_(self.t_tgt_embedding_vectors)
+            torch.nn.init.zeros_(self.t_cont_o_embedding_bias)
+        if self.t_tgt_embedding_vectors is not None:
+            torch.nn.init.xavier_normal_(self.t_tgt_embedding_vectors)
+            torch.nn.init.zeros_(self.t_tgt_embedding_bias)
+        if self.s_cat_embed is not None:
+            for module in self.s_cat_embed:
+                module.reset_parameters()
+        if self.t_cat_k_embed is not None:
+            for module in self.t_cat_k_embed:
+                module.reset_parameters()
+        if self.t_cat_o_embed is not None:
+            for module in self.t_cat_o_embed:
+                module.reset_parameters()
+
 
     def _apply_embedding(self,
             cat: Optional[Tensor],
@@ -138,8 +171,11 @@ class TFTEmbedding(nn.Module):
             #the line below is equivalent to following einsums
             #e_cont = torch.einsum('btf,fh->bthf', cont, cont_emb)
             #e_cont = torch.einsum('bf,fh->bhf', cont, cont_emb)
-            e_cont = torch.mul(cont.unsqueeze(-1), cont_emb)
-            e_cont = e_cont + cont_bias
+            if MAKE_CONVERT_COMPATIBLE:
+                e_cont = torch.mul(cont.unsqueeze(-1), cont_emb)
+                e_cont = e_cont + cont_bias
+            else:
+                e_cont = fused_pointwise_linear_v1(cont, cont_emb, cont_bias)
         else:
             e_cont = None
 
@@ -185,11 +221,68 @@ class TFTEmbedding(nn.Module):
 
         # Temporal observed targets
         # t_observed_tgt = torch.einsum('btf,fh->btfh', t_tgt_obs, self.t_tgt_embedding_vectors)
-        t_observed_tgt = torch.matmul(t_tgt_obs.unsqueeze(3).unsqueeze(4), self.t_tgt_embedding_vectors.unsqueeze(1)).squeeze(3)
-        t_observed_tgt = t_observed_tgt + self.t_tgt_embedding_bias
+        if MAKE_CONVERT_COMPATIBLE:
+            t_observed_tgt = torch.matmul(t_tgt_obs.unsqueeze(3).unsqueeze(4), self.t_tgt_embedding_vectors.unsqueeze(1)).squeeze(3)
+            t_observed_tgt = t_observed_tgt + self.t_tgt_embedding_bias
+        else:
+            t_observed_tgt = fused_pointwise_linear_v2(t_tgt_obs, self.t_tgt_embedding_vectors, self.t_tgt_embedding_bias)
 
         return s_inp, t_known_inp, t_observed_inp, t_observed_tgt
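The commented-out einsums above pin down what the broadcasted multiply-add (and, presumably, the `fused_pointwise_linear_*` helpers used on the nvFuser path) compute: a per-feature linear map into the hidden dimension, `e[b,t,f,h] = x[b,t,f] * W[f,h] + B[f,h]`. A NumPy sketch with hypothetical shapes, checking that the einsum form and the broadcast form agree:

```python
import numpy as np

# Per-feature "pointwise linear": each of the f input features gets its own
# h-dim embedding vector and bias.
b, t, f, h = 2, 6, 3, 4          # hypothetical batch/time/feature/hidden sizes
x = np.random.rand(b, t, f)
W = np.random.rand(f, h)
B = np.random.rand(f, h)

ref = np.einsum('btf,fh->btfh', x, W) + B   # the einsum form from the comments
fused = x[..., None] * W + B                # the broadcasted mul + bias form
assert np.allclose(ref, fused)
```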
 
+class LazyEmbedding(nn.modules.lazy.LazyModuleMixin, TFTEmbedding):
+    cls_to_become = TFTEmbedding
+
+    def __init__(self, config):
+        super().__init__(config, initialize_cont_params=False)
+
+        if config.static_continuous_inp_size:
+            self.s_cont_embedding_vectors = UninitializedParameter()
+            self.s_cont_embedding_bias = UninitializedParameter()
+        else:
+            self.s_cont_embedding_vectors = None
+            self.s_cont_embedding_bias = None
+
+        if config.temporal_known_continuous_inp_size:
+            self.t_cont_k_embedding_vectors = UninitializedParameter()
+            self.t_cont_k_embedding_bias = UninitializedParameter()
+        else:
+            self.t_cont_k_embedding_vectors = None
+            self.t_cont_k_embedding_bias = None
+
+        if config.temporal_observed_continuous_inp_size:
+            self.t_cont_o_embedding_vectors = UninitializedParameter()
+            self.t_cont_o_embedding_bias = UninitializedParameter()
+        else:
+            self.t_cont_o_embedding_vectors = None
+            self.t_cont_o_embedding_bias = None
+
+        self.t_tgt_embedding_vectors = UninitializedParameter()
+        self.t_tgt_embedding_bias = UninitializedParameter()
+
+    def initialize_parameters(self, x):
+        if self.has_uninitialized_params():
+            s_cont_inp = x.get('s_cont', None)
+            t_cont_k_inp = x.get('k_cont', None)
+            t_cont_o_inp = x.get('o_cont', None)
+            t_tgt_obs = x['target'] # Has to be present
+
+            if s_cont_inp is not None:
+                self.s_cont_embedding_vectors.materialize((s_cont_inp.shape[-1], self.hidden_size))
+                self.s_cont_embedding_bias.materialize((s_cont_inp.shape[-1], self.hidden_size))
+
+            if t_cont_k_inp is not None:
+                self.t_cont_k_embedding_vectors.materialize((t_cont_k_inp.shape[-1], self.hidden_size))
+                self.t_cont_k_embedding_bias.materialize((t_cont_k_inp.shape[-1], self.hidden_size))
+
+            if t_cont_o_inp is not None:
+                self.t_cont_o_embedding_vectors.materialize((t_cont_o_inp.shape[-1], self.hidden_size))
+                self.t_cont_o_embedding_bias.materialize((t_cont_o_inp.shape[-1], self.hidden_size))
+
+            self.t_tgt_embedding_vectors.materialize((t_tgt_obs.shape[-1], self.hidden_size))
+            self.t_tgt_embedding_bias.materialize((t_tgt_obs.shape[-1], self.hidden_size))
+
+            self.reset_parameters()
+
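`LazyEmbedding` follows PyTorch's lazy-module protocol: parameters start as `UninitializedParameter`, `initialize_parameters` materializes them from the first batch's shapes via a forward pre-hook, and `cls_to_become` swaps the class back to the concrete `TFTEmbedding` once everything is initialized. A minimal sketch of the same protocol with a hypothetical module:

```python
import torch
from torch.nn.parameter import UninitializedParameter

class LazyScale(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
    """Multiplies the input by a per-feature weight whose size is inferred lazily."""
    def __init__(self):
        super().__init__()
        self.weight = UninitializedParameter()

    def initialize_parameters(self, x):
        # Called by the mixin's forward pre-hook on the first forward pass.
        if self.has_uninitialized_params():
            self.weight.materialize((x.shape[-1],))  # shape is known only now
            torch.nn.init.ones_(self.weight)

    def forward(self, x):
        return x * self.weight

m = LazyScale()
y = m(torch.randn(4, 7))   # first call materializes a (7,)-shaped weight
assert m.weight.shape == (7,)
```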
 class VariableSelectionNetwork(nn.Module):
     def __init__(self, config, num_inputs):
         super().__init__()
@@ -197,7 +290,7 @@ class VariableSelectionNetwork(nn.Module):
         self.var_grns = nn.ModuleList([GRN(config.hidden_size, config.hidden_size, dropout=config.dropout) for _ in range(num_inputs)])
 
     def forward(self, x: Tensor, context: Optional[Tensor] = None):
-        Xi = x.reshape(*x.shape[:-2], -1)
+        Xi = torch.flatten(x, start_dim=-2)
         grn_outputs = self.joint_grn(Xi, c=context)
         sparse_weights = F.softmax(grn_outputs, dim=-1)
         transformed_embed_list = [m(x[...,i,:]) for i, m in enumerate(self.var_grns)]
@@ -223,7 +316,7 @@ class StaticCovariateEncoder(nn.Module):
         # enrichment context
         # state_c context
         # state_h context
-        cs, ce, ch, cc = tuple(m(variable_ctx) for m in self.context_grns)
+        cs, ce, ch, cc = [m(variable_ctx) for m in self.context_grns]
 
         return cs, ce, ch, cc
 
@@ -241,7 +334,7 @@ class InterpretableMultiHeadAttention(nn.Module):
         self.scale = self.d_head**-0.5
         self.register_buffer("_mask", torch.triu(torch.full((config.example_length, config.example_length), float('-inf')), 1).unsqueeze(0))
 
-    def forward(self, x: Tensor, mask_future_timesteps: bool = True) -> Tuple[Tensor, Tensor]:
+    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         bs, t, h_size = x.shape
         qkv = self.qkv_linears(x)
         q, k, v = qkv.split((self.n_head * self.d_head, self.n_head * self.d_head, self.d_head), dim=-1)
@@ -253,8 +346,7 @@ class InterpretableMultiHeadAttention(nn.Module):
         attn_score = torch.matmul(q.permute((0, 2, 1, 3)), k.permute((0, 2, 3, 1)))
         attn_score.mul_(self.scale)
 
-        if mask_future_timesteps:
-            attn_score = attn_score + self._mask
+        attn_score = attn_score + self._mask
 
         attn_prob = F.softmax(attn_score, dim=3)
         attn_prob = self.attn_dropout(attn_prob)
@@ -267,24 +359,12 @@ class InterpretableMultiHeadAttention(nn.Module):
 
         return out, attn_prob
 
-
-
-class TemporalFusionTransformer(nn.Module):
-    """ 
-    Implementation of https://arxiv.org/abs/1912.09363 
-    """
+class TFTBack(nn.Module):
     def __init__(self, config):
         super().__init__()
 
-        if hasattr(config, 'model'):
-            config = config.model
-
-        self.encoder_length = config.encoder_length #this determines from how distant past we want to use data from
-
-        self.embedding = TFTEmbedding(config)
-        self.static_encoder = StaticCovariateEncoder(config)
-
-        self.history_vsn = VariableSelectionNetwork(config, config.num_historic_vars) 
+        self.encoder_length = config.encoder_length
+        self.history_vsn = VariableSelectionNetwork(config, config.num_historic_vars)
         self.history_encoder = nn.LSTM(config.hidden_size, config.hidden_size, batch_first=True)
         self.future_vsn = VariableSelectionNetwork(config, config.num_future_vars)
         self.future_encoder = nn.LSTM(config.hidden_size, config.hidden_size, batch_first=True)
@@ -309,28 +389,13 @@ class TemporalFusionTransformer(nn.Module):
         self.decoder_ln = LayerNorm(config.hidden_size, eps=1e-3)
 
         self.quantile_proj = nn.Linear(config.hidden_size, len(config.quantiles))
-
-    def forward(self, x: Dict[str, Tensor]) -> Tensor:
-        s_inp, t_known_inp, t_observed_inp, t_observed_tgt = self.embedding(x)
-
-        # Static context
-        cs, ce, ch, cc = self.static_encoder(s_inp)
-        ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) #lstm initial states
-
-        # Temporal input
-        _historical_inputs = [t_known_inp[:,:self.encoder_length,:], t_observed_tgt[:,:self.encoder_length,:]]
-        if t_observed_inp is not None:
-            _historical_inputs.insert(0,t_observed_inp[:,:self.encoder_length,:])
-
-        historical_inputs = torch.cat(_historical_inputs, dim=-2)
-        future_inputs = t_known_inp[:, self.encoder_length:]
-
-        # Encoders
+        
+    def forward(self, historical_inputs, cs, ch, cc, ce, future_inputs):
         historical_features, _ = self.history_vsn(historical_inputs, cs)
         history, state = self.history_encoder(historical_features, (ch, cc))
         future_features, _ = self.future_vsn(future_inputs, cs)
         future, _ = self.future_encoder(future_features, state)
-        torch.cuda.synchronize() # this call gives perf boost for unknown reasons
+        torch.cuda.synchronize()
 
         # skip connection
         input_embedding = torch.cat([historical_features, future_features], dim=1)
@@ -343,7 +408,7 @@ class TemporalFusionTransformer(nn.Module):
         enriched = self.enrichment_grn(temporal_features, c=ce)
 
         # Temporal self attention
-        x, _ = self.attention(enriched, mask_future_timesteps=True)
+        x, _ = self.attention(enriched)
 
         # Don't compute historical quantiles
         x = x[:, self.encoder_length:, :]
@@ -365,3 +430,39 @@ class TemporalFusionTransformer(nn.Module):
         out = self.quantile_proj(x)
 
         return out
+
+
+class TemporalFusionTransformer(nn.Module):
+    """ 
+    Implementation of https://arxiv.org/abs/1912.09363 
+    """
+    def __init__(self, config):
+        super().__init__()
+
+        if hasattr(config, 'model'):
+            config = config.model
+
+        self.encoder_length = config.encoder_length #this determines how far into the past we use data from
+
+        self.embedding = LazyEmbedding(config)
+        self.static_encoder = StaticCovariateEncoder(config)
+        if MAKE_CONVERT_COMPATIBLE:
+            self.TFTpart2 = TFTBack(config)
+        else:
+            self.TFTpart2 = torch.jit.script(TFTBack(config))
+
+    def forward(self, x: Dict[str, Tensor]) -> Tensor:
+        s_inp, t_known_inp, t_observed_inp, t_observed_tgt = self.embedding(x)
+
+        # Static context
+        cs, ce, ch, cc = self.static_encoder(s_inp)
+        ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) #lstm initial states
+
+        # Temporal input
+        _historical_inputs = [t_known_inp[:,:self.encoder_length,:], t_observed_tgt[:,:self.encoder_length,:]]
+        if t_observed_inp is not None:
+            _historical_inputs.insert(0,t_observed_inp[:,:self.encoder_length,:])
+
+        historical_inputs = torch.cat(_historical_inputs, dim=-2)
+        future_inputs = t_known_inp[:, self.encoder_length:]
+        return self.TFTpart2(historical_inputs, cs, ch, cc, ce, future_inputs)
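The refactor above keeps `TemporalFusionTransformer` eager (so `LazyEmbedding` can materialize its parameters on the first call) and wraps only `TFTBack` in `torch.jit.script`, giving nvFuser a scripted region to fuse. The split pattern, sketched with hypothetical modules:

```python
import torch

class Back(torch.nn.Module):
    # TorchScript wants type annotations on scripted forward signatures
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0

class Full(torch.nn.Module):
    def __init__(self, script: bool = True):
        super().__init__()
        # The eager front half could hold lazy modules; only the back is scripted.
        self.back = torch.jit.script(Back()) if script else Back()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.back(x)

out = Full()(torch.zeros(3))
assert torch.equal(out, torch.ones(3))
```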

+ 37 - 55
PyTorch/Forecasting/TFT/train.py

@@ -23,10 +23,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
-from apex import amp
 from apex.optimizers import FusedAdam
-#from torch.nn.parallel import DistributedDataParallel as DDP
-from apex.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.cuda import amp
 
 import numpy as np
 
@@ -34,48 +33,14 @@ import dllogger
 
 from modeling import TemporalFusionTransformer
 from configuration import CONFIGS
-from data_utils import TFTBinaryDataset, sample_data
+from data_utils import load_dataset
 from log_helper import setup_logger
 from criterions import QuantileLoss
 from inference import predict
-from utils import PerformanceMeter
+from utils import PerformanceMeter, print_once
 import gpu_affinity
 from ema import ModelEma
 
-def load_dataset(args, config):
-    train_split = TFTBinaryDataset(os.path.join(args.data_path, 'train.bin'), config)
-    train_split = sample_data(train_split, args.sample_data[0])
-    if args.distributed_world_size > 1:
-        data_sampler = DistributedSampler(train_split, args.distributed_world_size, args.distributed_rank, seed=args.seed + args.distributed_rank, drop_last=True)
-    else:
-        data_sampler = RandomSampler(train_split)
-    train_loader = DataLoader(train_split, batch_size=args.batch_size, num_workers=4, sampler=data_sampler, pin_memory=True)
-
-    valid_split = TFTBinaryDataset(os.path.join(args.data_path, 'valid.bin'), config)
-    valid_split = sample_data(valid_split, args.sample_data[1])
-    if args.distributed_world_size > 1:
-        data_sampler = DistributedSampler(valid_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
-    else:
-        data_sampler = None
-    valid_loader = DataLoader(valid_split, batch_size=args.batch_size, sampler=data_sampler, num_workers=4, pin_memory=True)
-
-    test_split = TFTBinaryDataset(os.path.join(args.data_path, 'test.bin'), config)
-    if args.distributed_world_size > 1:
-        data_sampler = DistributedSampler(test_split, args.distributed_world_size, args.distributed_rank, shuffle=False, drop_last=False)
-    else:
-        data_sampler = None
-    test_loader = DataLoader(test_split, batch_size=args.batch_size, sampler=data_sampler, num_workers=4, pin_memory=True)
-
-    print_once(f'Train split length: {len(train_split)}')
-    print_once(f'Valid split length: {len(valid_split)}')
-    print_once(f'Test split length: {len(test_split)}')
-
-    return train_loader, valid_loader, test_loader
-
-def print_once(*args, **kwargs):
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        print(*args, **kwargs)
-
 
 def main(args):
     ### INIT DISTRIBUTED
@@ -113,23 +78,28 @@ def main(args):
 
     dllogger.log(step='HPARAMS', data={**vars(args), **vars(config)}, verbosity=1)
 
+    train_loader, valid_loader, test_loader = load_dataset(args, config)
+
     model = TemporalFusionTransformer(config).cuda()
     if args.ema_decay:
         model_ema = ModelEma(model, decay=args.ema_decay)
 
-    print_once('Model params: {}'.format(sum(p.numel() for p in model.parameters())))
+    # Run dummy iteration to initialize lazy modules
+    dummy_batch = next(iter(train_loader))
+    dummy_batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in dummy_batch.items()}
+    model(dummy_batch)
+
     criterion = QuantileLoss(config).cuda()
     optimizer = FusedAdam(model.parameters(), lr=args.lr)
-    if args.use_amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
     if args.distributed_world_size > 1:
-        #model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
-        model = DDP(model)
+        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
 
-    train_loader, valid_loader, test_loader = load_dataset(args, config)
 
+    print_once('Model params: {}'.format(sum(p.numel() for p in model.parameters())))
     global_step = 0
     perf_meter = PerformanceMeter(benchmark_mode=not args.disable_benchmark)
+    if args.use_amp:
+        scaler = amp.GradScaler(init_scale=32768.0)
 
     for epoch in range(args.epochs):
         start = time.time()
@@ -139,20 +109,28 @@ def main(args):
         for local_step, batch in enumerate(train_loader):
             perf_meter.reset_current_lap()
             batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
-            predictions = model(batch)
-            targets = batch['target'][:,config.encoder_length:,:]
-            p_losses = criterion(predictions, targets)
-            loss = p_losses.sum()
-
+            with torch.jit.fuser("fuser2"), amp.autocast(enabled=args.use_amp):
+                predictions = model(batch)
+                targets = batch['target'][:,config.encoder_length:,:]
+                p_losses = criterion(predictions, targets)
+                loss = p_losses.sum()
+            if global_step == 0 and args.ema_decay:
+                model_ema(batch)
             if args.use_amp:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
+                scaler.scale(loss).backward()
+
             else:
                 loss.backward()
             if not args.grad_accumulation or (global_step+1) % args.grad_accumulation == 0:
+                if args.use_amp:
+                    scaler.unscale_(optimizer)
                 if args.clip_grad:
                     torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
-                optimizer.step()
+                if args.use_amp:
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    optimizer.step()
                 optimizer.zero_grad()
                 if args.ema_decay:
                     model_ema.update(model)
@@ -164,7 +142,7 @@ def main(args):
 
             torch.cuda.synchronize()
             ips = perf_meter.update(args.batch_size * args.distributed_world_size,
-                    exclude_from_total=local_step in [0, len(train_loader)-1])
+                    exclude_from_total=local_step in [0, 1, 2, len(train_loader)-1])
 
             log_dict = {'P10':p_losses[0].item(), 'P50':p_losses[1].item(), 'P90':p_losses[2].item(), 'loss': loss.item(), 'items/s':ips}
             dllogger.log(step=global_step, data=log_dict, verbosity=1)
@@ -188,6 +166,10 @@ def main(args):
     cat_encodings = pickle.load(open(os.path.join(args.data_path,'cat_encodings.bin'), 'rb'))
 
     unscaled_predictions, unscaled_targets, _, _ = predict(args, config, model, test_loader, tgt_scalers, cat_encodings)
+
+    unscaled_predictions = torch.from_numpy(unscaled_predictions).contiguous()
+    unscaled_targets = torch.from_numpy(unscaled_targets).contiguous()
+
     losses = QuantileLoss(config)(unscaled_predictions, unscaled_targets)
     normalizer = unscaled_targets.abs().mean()
     quantiles = 2 * losses / normalizer
@@ -212,7 +194,7 @@ def validate(args, config, model, criterion, dataloader, global_step):
     torch.cuda.synchronize()
     validation_start = time.time()
     for batch in dataloader:
-        with torch.no_grad():
+        with torch.jit.fuser("fuser2"), amp.autocast(enabled=args.use_amp), torch.no_grad():
             batch = {key: tensor.cuda() if tensor.numel() else None for key, tensor in batch.items()}
             predictions = model(batch)
             targets = batch['target'][:,config.encoder_length:,:]
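The training-loop changes above swap Apex's `amp.initialize`/`scale_loss` for native `torch.cuda.amp`: autocast around the forward pass, `scaler.scale(loss).backward()`, `scaler.unscale_` before gradient clipping, then `scaler.step`/`scaler.update` (the `torch.jit.fuser("fuser2")` context additionally selects nvFuser, which needs CUDA). A minimal CPU-safe sketch of that ordering, with autocast and scaling disabled when CUDA is absent:

```python
import torch
from torch.cuda import amp

use_amp = torch.cuda.is_available()
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = amp.GradScaler(enabled=use_amp, init_scale=32768.0)

x, y = torch.randn(8, 4), torch.randn(8, 2)
with amp.autocast(enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # identity scaling when disabled
scaler.unscale_(optimizer)     # must precede manual grad clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)         # skips the step on inf/nan grads
scaler.update()
optimizer.zero_grad()
```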

+ 3 - 0
PyTorch/Forecasting/TFT/triton/deployment_toolkit/bermuda/pyt.py

@@ -161,6 +161,8 @@ class PyTorchModelLoader(BaseLoader):
     def _trace(self, model: Model, dataloader_fn) -> Model:
         device = get_model_device(model.handle)
         dummy_input = get_sample_input(dataloader_fn(), device)
+        # Run dummy forward to initialize lazy modules
+        model.handle(*dummy_input)
         traced_model = torch.jit.trace_module(model.handle, {"forward": dummy_input})
         return Model(traced_model, precision=model.precision, inputs=model.inputs, outputs=model.outputs)
 
@@ -213,6 +215,7 @@ class PYT2ONNXSaver(BaseSaver):
 
         device = get_model_device(model.handle)
         dummy_input = get_sample_input(dataloader_fn(), device)
+        model.handle(*dummy_input)
         with torch.no_grad():
             torch.onnx.export(
                 model.handle,

+ 2 - 1
PyTorch/Forecasting/TFT/triton/requirements.txt

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-model_navigator[pyt] @ git+https://github.com/triton-inference-server/[email protected].5#egg=model_navigator
+model_navigator[pyt] @ git+https://github.com/triton-inference-server/[email protected].7#egg=model_navigator
 natsort>=7.0.0
 networkx==2.5
 numpy
@@ -21,3 +21,4 @@ pycuda>=2019.1.2
 PyYAML>=5.2
 tabulate>=0.8.7
 tqdm>=4.44.1
+triton-model-analyzer==1.22.0

+ 1 - 1
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-A30.yaml

@@ -112,7 +112,7 @@ configurations:
   triton_gpu_engine_count: 2
   triton_max_queue_delay: 1
   triton_preferred_batch_sizes: 512 1024
-container_version: '21.12'
+container_version: '22.11'
 datasets:
 - name: electricity_bin
 - name: traffic_bin

+ 1 - 1
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-DGX-1-(1x-V100-32GB).yaml

@@ -112,7 +112,7 @@ configurations:
   triton_gpu_engine_count: 2
   triton_max_queue_delay: 1
   triton_preferred_batch_sizes: 512 1024
-container_version: '21.12'
+container_version: '22.11'
 datasets:
 - name: electricity_bin
 - name: traffic_bin

+ 1 - 1
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-DGX-A100-(1x-A100-80GB).yaml

@@ -112,7 +112,7 @@ configurations:
   triton_gpu_engine_count: 2
   triton_max_queue_delay: 1
   triton_preferred_batch_sizes: 512 1024
-container_version: '21.12'
+container_version: '22.11'
 datasets:
 - name: electricity_bin
 - name: traffic_bin

+ 1 - 1
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-T4.yaml

@@ -112,7 +112,7 @@ configurations:
   triton_gpu_engine_count: 2
   triton_max_queue_delay: 1
   triton_preferred_batch_sizes: 512 1024
-container_version: '21.12'
+container_version: '22.11'
 datasets:
 - name: electricity_bin
 - name: traffic_bin

+ 1 - 1
PyTorch/Forecasting/TFT/triton/scripts/docker/triton_inference_server.sh

@@ -41,7 +41,7 @@ docker run --rm -d \
   --ulimit memlock=-1 \
   --ulimit stack=67108864 \
   --ipc=host \
-  nvcr.io/nvidia/tritonserver:21.12-py3 tritonserver \
+  nvcr.io/nvidia/tritonserver:22.11-py3 tritonserver \
   --model-store=${MODEL_REPOSITORY_PATH} \
   --strict-model-config=false \
   --exit-on-error=true \