tft_torchhub.py 4.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import urllib.request
  17. from zipfile import ZipFile
  18. import torch
  19. from torch.utils.data import DataLoader
  # NGC download URLs of the zipped TFT checkpoints, keyed by dataset name.
  NGC_CHECKPOINT_URLS = {
      "electricity": "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-electricity/versions/22.11.0_amp/zip",
      "traffic": "https://api.ngc.nvidia.com/v2/models/nvidia/dle/tft_base_pyt_ckpt_ds-traffic/versions/22.11.0_amp/zip",
  }
  def _download_checkpoint(checkpoint, force_reload):
      """Fetch a zipped checkpoint into the torch hub cache and unpack it.

      Args:
          checkpoint (str): URL of the zip archive to download.
          force_reload (bool): If True, download the archive again even when a
              cached copy is already present.

      Returns:
          str: Path of ``checkpoint.pt`` inside the directory the archive was
          extracted to. The archive is expected to contain a file with that
          name; this function does not verify it.
      """
      from urllib.parse import urlparse

      # The NGC URLs used by this module all end in ".../zip", so naming the
      # cache entry after the URL basename made every checkpoint share one
      # archive file and one extracted "checkpoint.pt". Build the cache key
      # from the whole URL path instead, reduced to filename-safe characters.
      url_path = urlparse(checkpoint).path.strip('/')
      cache_key = ''.join(
          ch if ch.isalnum() or ch in '._-' else '_' for ch in url_path
      ) or 'checkpoint'

      model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
      extract_dir = os.path.join(model_dir, cache_key)
      # exist_ok avoids a failure when another process creates the directory
      # between an existence check and the mkdir.
      os.makedirs(extract_dir, exist_ok=True)

      ckpt_file = os.path.join(model_dir, cache_key + '.zip')
      if force_reload or not os.path.exists(ckpt_file):
          sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
          # Download under a temporary name and rename afterwards, so that an
          # interrupted transfer is never mistaken for a complete cached archive.
          partial_file = ckpt_file + '.part'
          urllib.request.urlretrieve(checkpoint, partial_file)
          os.replace(partial_file, ckpt_file)

      with ZipFile(ckpt_file, "r") as zf:
          zf.extractall(path=extract_dir)
      return os.path.join(extract_dir, "checkpoint.pt")
  def nvidia_tft(pretrained=True, **kwargs):
      """Constructs a TFT model.

      For detailed information on model input and output, training recipes,
      inference and performance visit: github.com/NVIDIA/DeepLearningExamples
      and/or ngc.nvidia.com

      Args (type[, default value]):
          pretrained (bool, True): If True, returns a pretrained model.
          dataset (str, 'electricity'): loads selected model type electricity or traffic. Defaults to electricity

      Returns:
          A TemporalFusionTransformer instance switched to eval mode.

      Raises:
          KeyError: If ``dataset`` is not a key of NGC_CHECKPOINT_URLS.
      """
      # The docstring has to be the first statement of the function; placed
      # after this import it was only a discarded string expression and
      # ``nvidia_tft.__doc__`` was None.
      from .modeling import TemporalFusionTransformer

      ds_type = kwargs.get("dataset", "electricity")
      # The checkpoint is fetched even when ``pretrained`` is False, because the
      # model configuration is read from it.
      ckpt = _download_checkpoint(NGC_CHECKPOINT_URLS[ds_type], True)
      # The model below is constructed on the CPU, so deserialize the tensors
      # there as well; this also lets the checkpoint load on hosts without CUDA.
      # NOTE(review): torch.load unpickles the file, so only trusted checkpoint
      # URLs should be used. Recent PyTorch releases default to
      # weights_only=True; if 'config' is a pickled project object this call may
      # need weights_only=False - confirm against the supported torch version.
      state_dict = torch.load(ckpt, map_location="cpu")
      config = state_dict['config']
      model = TemporalFusionTransformer(config)
      if pretrained:
          model.load_state_dict(state_dict['model'])
      model.eval()
      return model
  def nvidia_tft_data_utils(**kwargs):
      """Build a helper object for preparing the electricity dataset.

      ``kwargs`` is accepted but not used.

      Returns:
          An instance of the locally defined ``Processing`` class. Its static
          methods ``download_data``, ``preprocess`` and ``get_batch`` each take
          the dataset root directory ``path``.
      """
      from .data_utils import TFTDataset
      from .configuration import ElectricityConfig

      class Processing:
          @staticmethod
          def download_data(path):
              """Download and unpack the raw electricity dataset below ``path``/raw.

              Returns immediately when the extraction directory already exists.
              """
              extract_dir = os.path.join(path, "raw/electricity/")
              if os.path.isdir(extract_dir):
                  return
              # Each step is checked on its own, so a run that stopped after
              # creating "raw" or after downloading the archive can be resumed.
              os.makedirs(os.path.join(path, "raw"), exist_ok=True)
              dataset_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip"
              archive = os.path.join(path, "raw/electricity.zip")
              if not os.path.exists(archive):
                  sys.stderr.write('Downloading dataset from {}\n'.format(dataset_url))
                  # Download under a temporary name so that an interrupted
                  # transfer is not mistaken for a complete archive later.
                  partial = archive + ".part"
                  urllib.request.urlretrieve(dataset_url, partial)
                  os.replace(partial, archive)
              with ZipFile(archive, "r") as zf:
                  zf.extractall(path=extract_dir)

          @staticmethod
          def preprocess(path):
              """Standardize and preprocess the raw electricity data.

              The work is skipped when ``path``/processed already exists.
              Reads ``raw/electricity`` and writes to
              ``processed/electricity_bin/`` below ``path``.
              """
              processed_dir = os.path.join(path, "processed")
              if os.path.exists(processed_dir):
                  return
              # Relative imports, matching how this function imports TFTDataset
              # from the same module above.
              from .data_utils import standarize_electricity as standarize
              from .data_utils import preprocess

              config = ElectricityConfig()
              os.makedirs(processed_dir, exist_ok=True)
              standarize(os.path.join(path, "raw/electricity"))
              preprocess(
                  os.path.join(path, "raw/electricity/standarized.csv"),
                  os.path.join(path, "processed/electricity_bin/"),
                  config,
              )

          @staticmethod
          def get_batch(path):
              """Return one batch of 16 samples from the processed test split.

              The batch at index 40 is returned, or the last batch when the
              loader yields fewer than 41 batches.

              Raises:
                  ValueError: If the loader yields no batches at all.
              """
              config = ElectricityConfig()
              test_split = TFTDataset(os.path.join(path, "processed/electricity_bin/", "test.csv"), config)
              data_loader = DataLoader(test_split, batch_size=16, num_workers=0)
              # Sentinel so an empty loader is reported explicitly instead of
              # surfacing as an UnboundLocalError on the return statement.
              missing = object()
              batch = missing
              for i, batch in enumerate(data_loader):
                  if i == 40:
                      break
              if batch is missing:
                  raise ValueError(
                      'No batches available in the test split under {}'.format(path)
                  )
              return batch

      return Processing()