Explorar el Código

[TFT/PyT] Update Checkpoints

Izzy Putterman hace 2 años
padre
commit
a6c678ef03

+ 2 - 4
PyTorch/Forecasting/TFT/tft_torchhub.py

@@ -19,10 +19,8 @@ from zipfile import ZipFile
 import torch
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data import DataLoader
 NGC_CHECKPOINT_URLS = {}
 NGC_CHECKPOINT_URLS = {}
-NGC_CHECKPOINT_URLS["electricity"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip"
-NGC_CHECKPOINT_URLS["traffic"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip"
-
-
+NGC_CHECKPOINT_URLS["electricity"] = "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip"
+NGC_CHECKPOINT_URLS["traffic"] = "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip"
 def _download_checkpoint(checkpoint, force_reload):
 def _download_checkpoint(checkpoint, force_reload):
     model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
     model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
     if not os.path.exists(model_dir):
     if not os.path.exists(model_dir):

+ 5 - 2
PyTorch/Forecasting/TFT/triton/README.md

@@ -146,6 +146,9 @@ NVIDIA DGX A100 (1x A100 80GB): bash ./triton/runner/start_NVIDIA-DGX-A100-\(1x-
 NVIDIA T4: bash ./triton/runner/start_NVIDIA-T4.sh
 NVIDIA T4: bash ./triton/runner/start_NVIDIA-T4.sh
 ```
 ```
 
 
+If one encounters an error like `the provided PTX was compiled with an unsupported toolchain`, follow the steps in
+[Step by step deployment process](#step-by-step-deployment-process).
+
 ## Performance
 ## Performance
 The performance measurements in this document were conducted at the time of publication and may not reflect
 The performance measurements in this document were conducted at the time of publication and may not reflect
 the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to
 the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to
@@ -2077,7 +2080,7 @@ Please use the data download from the [Main QSG](https://github.com/NVIDIA/DeepL
 #### Prepare Checkpoint
 #### Prepare Checkpoint
 Please place a `checkpoint.pt` from TFT trained on electricity in `runner_workspace/checkpoints/electricity_bin/`.  Note that the `electricity_bin` 
 Please place a `checkpoint.pt` from TFT trained on electricity in `runner_workspace/checkpoints/electricity_bin/`.  Note that the `electricity_bin` 
 subdirectory may not be created yet.  In addition one can download a zip archive of a trained checkpoint 
 subdirectory may not be created yet.  In addition one can download a zip archive of a trained checkpoint 
-[here](https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip)
+[here](https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip)
 
 
 #### Setup Container
 #### Setup Container
 Build and run a container that extends the NGC PyTorch container with the Triton Inference Server client libraries and dependencies.
 Build and run a container that extends the NGC PyTorch container with the Triton Inference Server client libraries and dependencies.
@@ -2242,7 +2245,7 @@ mkdir -p ${SHARED_DIR}/input_data
 python triton/prepare_input_data.py \
 python triton/prepare_input_data.py \
     --input-data-dir ${SHARED_DIR}/input_data/ \
     --input-data-dir ${SHARED_DIR}/input_data/ \
     --dataset ${DATASETS_DIR}/${DATASET} \
     --dataset ${DATASETS_DIR}/${DATASET} \
-    --checkpoint ${CHECKPOINT_DIR}/ \
+    --checkpoint ${CHECKPOINT_DIR}/
 ```
 ```
 
 
 </details>
 </details>

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-A30.yaml

@@ -1,8 +1,8 @@
 checkpoints:
 checkpoints:
 - name: electricity_bin
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 configurations:
 - accelerator: none
 - accelerator: none
   batch_size:
   batch_size:

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-DGX-1-(1x-V100-32GB).yaml

@@ -1,8 +1,8 @@
 checkpoints:
 checkpoints:
 - name: electricity_bin
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 configurations:
 - accelerator: none
 - accelerator: none
   batch_size:
   batch_size:

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-DGX-A100-(1x-A100-80GB).yaml

@@ -1,8 +1,8 @@
 checkpoints:
 checkpoints:
 - name: electricity_bin
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 configurations:
 - accelerator: none
 - accelerator: none
   batch_size:
   batch_size:

+ 2 - 2
PyTorch/Forecasting/TFT/triton/runner/config_NVIDIA-T4.yaml

@@ -1,8 +1,8 @@
 checkpoints:
 checkpoints:
 - name: electricity_bin
 - name: electricity_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip
 - name: traffic_bin
 - name: traffic_bin
-  url: https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip
+  url: https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip
 configurations:
 configurations:
 - accelerator: none
 - accelerator: none
   batch_size:
   batch_size: