[TFT/PyT] Update Checkpoints

Izzy Putterman 2 years ago
parent
commit
a6c678ef03

+ 2 - 4
PyTorch/Forecasting/TFT/tft_torchhub.py

@@ -19,10 +19,8 @@ from zipfile import ZipFile
 import torch
 from torch.utils.data import DataLoader
 NGC_CHECKPOINT_URLS = {}
-NGC_CHECKPOINT_URLS["electricity"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip"
-NGC_CHECKPOINT_URLS["traffic"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip"
-
-
+NGC_CHECKPOINT_URLS["electricity"] = "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip"
+NGC_CHECKPOINT_URLS["traffic"] = "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip"
 def _download_checkpoint(checkpoint, force_reload):
     model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
     if not os.path.exists(model_dir):
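
The two new URLs in this hunk differ only in the dataset name, so the mapping could also be derived from one template. The following is a minimal sketch of that idea, not what the commit does: `NGC_MODEL_URL` and `ngc_checkpoint_url` are hypothetical names, and the template is inferred only from the two URLs shown above.

```python
# Hypothetical helper: rebuilds the checkpoint URLs from the pattern visible in this commit.
# The template is an assumption inferred from the two "+" lines above.
NGC_MODEL_URL = (
    "https://api.ngc.nvidia.com/v2/models/nvidia/dle/"
    "tft_base_pyt_ckpt_ds-{dataset}/versions/{version}/zip"
)


def ngc_checkpoint_url(dataset, version="22.11.0_amp"):
    """Return the NGC zip URL for a TFT checkpoint trained on `dataset`."""
    return NGC_MODEL_URL.format(dataset=dataset, version=version)


# Same keys as the dictionary in tft_torchhub.py.
NGC_CHECKPOINT_URLS = {
    name: ngc_checkpoint_url(name) for name in ("electricity", "traffic")
}
```

A single template keeps the model path and version in one place, at the cost of assuming every dataset follows the same naming scheme.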

+ 5 - 2
PyTorch/Forecasting/TFT/triton/README.md

@@ -146,6 +146,9 @@ NVIDIA DGX A100 (1x A100 80GB): bash ./triton/runner/start_NVIDIA-DGX-A100-\(1x-
 NVIDIA T4: bash ./triton/runner/start_NVIDIA-T4.sh
 ```
 
+If one encounters an error like `the provided PTX was compiled with an unsupported toolchain`, follow the steps in
+[Step by step deployment process](#step-by-step-deployment-process).
+
 ## Performance
 The performance measurements in this document were conducted at the time of publication and may not reflect
 the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to
@@ -2077,7 +2080,7 @@ Please use the data download from the [Main QSG](https://github.com/NVIDIA/DeepL
 #### Prepare Checkpoint
 Please place a `checkpoint.pt` from TFT trained on electricity in `runner_workspace/checkpoints/electricity_bin/`.  Note that the `electricity_bin` 
 subdirectory may not be created yet.  In addition one can download a zip archive of a trained checkpoint 
-[here](https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip)
+[here](https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip)
 
 #### Setup Container
 Build and run a container that extends the NGC PyTorch container with the Triton Inference Server client libraries and dependencies.
@@ -2242,7 +2245,7 @@ mkdir -p ${SHARED_DIR}/input_data
 python triton/prepare_input_data.py \
     --input-data-dir ${SHARED_DIR}/input_data/ \
     --dataset ${DATASETS_DIR}/${DATASET} \
-    --checkpoint ${CHECKPOINT_DIR}/ \
+    --checkpoint ${CHECKPOINT_DIR}/
 ```
 
 </details>
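
The "Prepare Checkpoint" step above asks for a `checkpoint.pt` under `runner_workspace/checkpoints/electricity_bin/` and notes that the subdirectory may not exist yet. A rough sketch of that step, assuming the zip archive has already been downloaded and that it holds `checkpoint.pt` at its top level (the archive layout is not confirmed by this diff); `place_checkpoint` is a hypothetical helper name.

```python
import os
from zipfile import ZipFile


def place_checkpoint(archive_path,
                     target_dir="runner_workspace/checkpoints/electricity_bin"):
    """Extract a downloaded checkpoint archive into `target_dir`.

    Creates `target_dir` first, since the README notes it may not exist yet.
    Returns the path where checkpoint.pt is expected (assumed archive layout).
    """
    os.makedirs(target_dir, exist_ok=True)
    with ZipFile(archive_path) as archive:
        archive.extractall(target_dir)
    return os.path.join(target_dir, "checkpoint.pt")
```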

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-A30.yaml

@@ -1,8 +1,8 @@
 checkpoints:
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 - accelerator: none
   batch_size:

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-DGX-1-(1x-V100-32GB).yaml

@@ -1,8 +1,8 @@
 checkpoints:
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 - accelerator: none
   batch_size:

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-DGX-A100-(1x-A100-80GB).yaml

@@ -1,8 +1,8 @@
 checkpoints:
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 - accelerator: none
   batch_size:

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-T4.yaml

@@ -1,8 +1,8 @@
 checkpoints:
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 - accelerator: none
   batch_size:
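
Because the same two URLs are repeated across the Python module, the README and four runner configs, a quick scan can confirm that no old-style URL was missed. This is a minimal sketch, assuming the old scheme is fully described by the removed lines in this commit; `OLD_URL` and `find_stale_urls` are hypothetical names.

```python
import re

# Matches the pre-commit URL scheme seen in the "-" lines: no "dle/" segment
# and a model name starting with "tft_pyt_ckpt_base_".
OLD_URL = re.compile(
    r"https://api\.ngc\.nvidia\.com/v2/models/nvidia/"
    r"tft_pyt_ckpt_base_\w+/versions/[\w.]+/zip"
)


def find_stale_urls(text):
    """Return (line_number, url) pairs for every old-style checkpoint URL in `text`."""
    hits = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        for match in OLD_URL.finditer(line):
            hits.append((lineno, match.group(0)))
    return hits
```

Running `find_stale_urls` over the contents of each file touched here should return an empty list once the commit is applied.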