瀏覽代碼

caching of pre-trained weights added to entrypoints

Krzysztof Kudrynski 6 年之前
父節點
當前提交
161a4ea165
共有 1 個文件被更改,包括 17 次插入、6 次刪除
  1. 17 6
      hubconf.py

+ 17 - 6
hubconf.py

@@ -1,5 +1,7 @@
 import urllib.request
 import torch
+import os
+import sys
 
 # from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
 def checkpoint_from_distributed(state_dict):
@@ -54,6 +56,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
     from PyTorch.Recommendation.NCF import neumf as ncf
 
     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]
 
     config = {'nb_users': None, 'nb_items': None, 'mf_dim': 64, 'mf_reg': 0.,
               'mlp_layer_sizes': [256, 256, 128, 64], 'mlp_layer_regs':[0, 0, 0, 0], 'dropout': 0.5}
@@ -63,8 +66,10 @@ def nvidia_ncf(pretrained=True, **kwargs):
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
-        ckpt_file = "ncf_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
 
         if checkpoint_from_distributed(ckpt):
@@ -117,14 +122,17 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
     from PyTorch.SpeechSynthesis.Tacotron2.models import lstmcell_to_float, batchnorm_to_float
 
     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]
 
     if pretrained:
         if fp16:
             checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
-        ckpt_file = "tacotron2_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):
@@ -172,14 +180,17 @@ def nvidia_waveglow(pretrained=True, **kwargs):
     from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float
 
     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]
 
     if pretrained:
         if fp16:
             checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
-        ckpt_file = "waveglow_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):