Просмотр исходного кода

[TFT/PyTorch] Adding TFT to Torchhub

Izzy Putterman 3 лет назад
Родитель
Сommit
a07a807572
3 измененных файлов с 101 добавлено и 1 удалено
  1. 1 1
      PyTorch/Forecasting/TFT/modeling.py
  2. 97 0
      PyTorch/Forecasting/TFT/tft_torchhub.py
  3. 3 0
      hubconf.py

+ 1 - 1
PyTorch/Forecasting/TFT/modeling.py

@@ -265,7 +265,7 @@ class InterpretableMultiHeadAttention(nn.Module):
         out = self.out_proj(m_attn_vec)
         out = self.out_dropout(out)
 
-        return out, attn_vec
+        return out, attn_prob
 
 
 

+ 97 - 0
PyTorch/Forecasting/TFT/tft_torchhub.py

@@ -0,0 +1,97 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import urllib.request
+from zipfile import ZipFile
+import torch
+from torch.utils.data import DataLoader
+NGC_CHECKPOINT_URLS = {}
+NGC_CHECKPOINT_URLS["electricity"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_eletricity_amp/versions/21.06.0/zip"
+NGC_CHECKPOINT_URLS["traffic"] = "https://api.ngc.nvidia.com/v2/models/nvidia/tft_pyt_ckpt_base_traffic_amp/versions/21.06.0/zip"
+
+
+def _download_checkpoint(checkpoint, force_reload):
+    model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
+    if not os.path.exists(ckpt_file) or force_reload:
+        sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+        urllib.request.urlretrieve(checkpoint, ckpt_file)
+    with ZipFile(ckpt_file, "r") as zf:
+        zf.extractall(path=model_dir)
+    return os.path.join(model_dir, "checkpoint.pt")
+
+def nvidia_tft(pretrained=True, **kwargs):
+    from .modeling import TemporalFusionTransformer
+    """Constructs a TFT model.
+    For detailed information on model input and output, training recipies, inference and performance
+    visit: github.com/NVIDIA/DeepLearningExamples and/or ngc.nvidia.com
+    Args (type[, default value]):
+        pretrained (bool, True): If True, returns a pretrained model.
+        dataset (str, 'electricity'): loads selected model type electricity or traffic. Defaults to electricity
+    """
+    ds_type = kwargs.get("dataset", "electricity")
+    ckpt = _download_checkpoint(NGC_CHECKPOINT_URLS[ds_type], True)
+    state_dict = torch.load(ckpt)
+    config = state_dict['config']
+    
+    model = TemporalFusionTransformer(config)
+    if pretrained:
+        model.load_state_dict(state_dict['model'])
+    model.eval()
+    return model
+
+def nvidia_tft_data_utils(**kwargs):
+
+    from .data_utils import TFTDataset
+    from .configuration import ElectricityConfig
+    class Processing:
+        @staticmethod
+        def download_data(path):
+            if not os.path.exists(os.path.join(path, "raw")):
+                os.makedirs(os.path.join(path, "raw"), exist_ok=True)
+            dataset_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip"
+            ckpt_file = os.path.join(path, "raw/electricity.zip")
+            if not os.path.exists(ckpt_file):
+                sys.stderr.write('Downloading checkpoint from {}\n'.format(dataset_url))
+                urllib.request.urlretrieve(dataset_url, ckpt_file)
+            with ZipFile(ckpt_file, "r") as zf:
+                zf.extractall(path=os.path.join(path, "raw/electricity/"))
+
+        @staticmethod
+        def preprocess(path):
+            config = ElectricityConfig()
+            if not os.path.exists(os.path.join(path, "processed")):
+                os.makedirs(os.path.join(path, "processed"), exist_ok=True)
+            from data_utils import standarize_electricity as standarize
+            from data_utils import preprocess
+            standarize(os.path.join(path, "raw/electricity"))
+            preprocess(os.path.join(path, "raw/electricity/standarized.csv"), os.path.join(path, "processed/electricity_bin/"), config)
+
+
+        @staticmethod
+        def get_batch(path):
+            config = ElectricityConfig()
+            test_split = TFTDataset(os.path.join(path, "processed/electricity_bin/", "test.csv"), config)
+            data_loader = DataLoader(test_split, batch_size=16, num_workers=0)
+            for i, batch in enumerate(data_loader):
+                if i == 40:
+                    break
+            return batch
+
+    return Processing()
+

+ 3 - 0
hubconf.py

@@ -24,3 +24,6 @@ from PyTorch.SpeechSynthesis.Tacotron2.tacotron2 import nvidia_tacotron2
 from PyTorch.SpeechSynthesis.Tacotron2.tacotron2 import nvidia_tts_utils
 from PyTorch.SpeechSynthesis.Tacotron2.waveglow import nvidia_waveglow
 sys.path.append(os.path.join(sys.path[0], 'PyTorch/SpeechSynthesis/Tacotron2'))
+
+from PyTorch.Forecasting.TFT.tft_torchhub import nvidia_tft, nvidia_tft_data_utils
+sys.path.append(os.path.join(sys.path[0], 'PyTorch/Forecasting/TFT'))