entrypoints.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # *****************************************************************************
  2. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import urllib.request
  28. import torch
  29. import os
  30. import sys
  31. #from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
  32. def checkpoint_from_distributed(state_dict):
  33. """
  34. Checks whether checkpoint was generated by DistributedDataParallel. DDP
  35. wraps model in additional "module.", it needs to be unwrapped for single
  36. GPU inference.
  37. :param state_dict: model's state dict
  38. """
  39. ret = False
  40. for key, _ in state_dict.items():
  41. if key.find('module.') != -1:
  42. ret = True
  43. break
  44. return ret
  45. # from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
  46. def unwrap_distributed(state_dict):
  47. """
  48. Unwraps model from DistributedDataParallel.
  49. DDP wraps model in additional "module.", it needs to be removed for single
  50. GPU inference.
  51. :param state_dict: model's state dict
  52. """
  53. new_state_dict = {}
  54. for key, value in state_dict.items():
  55. new_key = key.replace('module.1.', '')
  56. new_key = new_key.replace('module.', '')
  57. new_state_dict[new_key] = value
  58. return new_state_dict
  59. def _download_checkpoint(checkpoint, force_reload):
  60. model_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
  61. if not os.path.exists(model_dir):
  62. os.makedirs(model_dir)
  63. ckpt_file = os.path.join(model_dir, os.path.basename(checkpoint))
  64. if not os.path.exists(ckpt_file) or force_reload:
  65. sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint))
  66. urllib.request.urlretrieve(checkpoint, ckpt_file)
  67. return ckpt_file
  68. def nvidia_hifigan(pretrained=True, **kwargs):
  69. """TODO
  70. """
  71. from hifigan import models as vocoder
  72. force_reload = "force_reload" in kwargs and kwargs["force_reload"]
  73. fp16 = "model_math" in kwargs and kwargs["model_math"] == "fp16"
  74. if pretrained:
  75. checkpoint = 'https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz/versions/21.08.0_amp/files/hifigan_gen_checkpoint_10000_ft.pt'
  76. ckpt_file = _download_checkpoint(checkpoint, force_reload)
  77. ckpt = torch.load(ckpt_file)
  78. state_dict = ckpt['generator']
  79. if checkpoint_from_distributed(state_dict):
  80. state_dict = unwrap_distributed(state_dict)
  81. config = ckpt['config']
  82. train_setup = ckpt.get('train_setup', {})
  83. else:
  84. config = {'upsample_rates': [8, 8, 2, 2], 'upsample_kernel_sizes': [16, 16, 4, 4],
  85. 'upsample_initial_channel': 512, 'resblock': '1', 'resblock_kernel_sizes': [3, 7, 11],
  86. 'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]]}
  87. for k,v in kwargs.items():
  88. if k in config.keys():
  89. config[k] = v
  90. train_setup = {}
  91. hifigan = vocoder.Generator(config)
  92. denoiser = None
  93. if pretrained:
  94. hifigan.load_state_dict(state_dict)
  95. hifigan.remove_weight_norm()
  96. denoiser = vocoder.Denoiser(hifigan, win_length=1024)
  97. if fp16:
  98. hifigan.half()
  99. denoiser.half()
  100. return hifigan, train_setup, denoiser