@@ -1,5 +1,7 @@
 import urllib.request
 import torch
+import os
+import sys

 # from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
 def checkpoint_from_distributed(state_dict):
@@ -54,6 +56,7 @@ def nvidia_ncf(pretrained=True, **kwargs):
     from PyTorch.Recommendation.NCF import neumf as ncf

     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]

     config = {'nb_users': None, 'nb_items': None, 'mf_dim': 64, 'mf_reg': 0.,
               'mlp_layer_sizes': [256, 256, 128, 64], 'mlp_layer_regs':[0, 0, 0, 0], 'dropout': 0.5}
@@ -63,8 +66,10 @@ def nvidia_ncf(pretrained=True, **kwargs):
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225'
-        ckpt_file = "ncf_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)

         if checkpoint_from_distributed(ckpt):
@@ -117,14 +122,17 @@ def nvidia_tacotron2(pretrained=True, **kwargs):
     from PyTorch.SpeechSynthesis.Tacotron2.models import lstmcell_to_float, batchnorm_to_float

     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]

     if pretrained:
         if fp16:
             checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306'
-        ckpt_file = "tacotron2_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):
@@ -172,14 +180,17 @@ def nvidia_waveglow(pretrained=True, **kwargs):
     from PyTorch.SpeechSynthesis.Tacotron2.models import batchnorm_to_float

     fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
+    force_reload = "force_reload" in kwargs and kwargs["force_reload"]

     if pretrained:
         if fp16:
             checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306'
         else:
             checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306'
-        ckpt_file = "waveglow_ckpt.pt"
-        urllib.request.urlretrieve(checkpoint, ckpt_file)
+        ckpt_file = os.path.basename(checkpoint)
+        if not os.path.exists(ckpt_file) or force_reload:
+            sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
+            urllib.request.urlretrieve(checkpoint, ckpt_file)
         ckpt = torch.load(ckpt_file)
         state_dict = ckpt['state_dict']
         if checkpoint_from_distributed(state_dict):
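
For reference, a minimal sketch of how the new kwarg behaves once this patch is applied. The direct hubconf import is illustrative; the entrypoint name matches the function patched above, and force_reload here is the kwarg this diff introduces, consumed by the entrypoint itself rather than by torch.hub:

    # Illustrative only: assumes this hubconf.py and the NVIDIA
    # DeepLearningExamples sources it imports are on sys.path.
    import hubconf

    # First call: no local 'joc-ncf-fp32-pyt-20190225' file exists yet,
    # so the checkpoint is fetched and cached under its URL basename
    # in the current working directory.
    model = hubconf.nvidia_ncf(pretrained=True, model_math='fp32')

    # Later calls reuse the cached file; force_reload=True re-downloads it.
    model = hubconf.nvidia_ncf(pretrained=True, model_math='fp32',
                               force_reload=True)

Naming the cache file after the URL basename also keeps the fp16 and fp32 checkpoints from clobbering each other, which the old fixed names like "ncf_ckpt.pt" did.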