# models.py
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import sys
  28. from typing import Optional
  29. from os.path import abspath, dirname
  30. import torch
  31. # enabling modules discovery from global entrypoint
  32. sys.path.append(abspath(dirname(__file__)+'/'))
  33. from fastpitch.model import FastPitch as _FastPitch
  34. from fastpitch.model_jit import FastPitch as _FastPitchJIT
  35. from tacotron2.model import Tacotron2
  36. from waveglow.model import WaveGlow
  37. def parse_model_args(model_name, parser, add_help=False):
  38. if model_name == 'Tacotron2':
  39. from tacotron2.arg_parser import parse_tacotron2_args
  40. return parse_tacotron2_args(parser, add_help)
  41. if model_name == 'WaveGlow':
  42. from waveglow.arg_parser import parse_waveglow_args
  43. return parse_waveglow_args(parser, add_help)
  44. elif model_name == 'FastPitch':
  45. from fastpitch.arg_parser import parse_fastpitch_args
  46. return parse_fastpitch_args(parser, add_help)
  47. else:
  48. raise NotImplementedError(model_name)
  49. def batchnorm_to_float(module):
  50. """Converts batch norm to FP32"""
  51. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  52. module.float()
  53. for child in module.children():
  54. batchnorm_to_float(child)
  55. return module
  56. def init_bn(module):
  57. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  58. if module.affine:
  59. module.weight.data.uniform_()
  60. for child in module.children():
  61. init_bn(child)
  62. def get_model(model_name, model_config, device,
  63. uniform_initialize_bn_weight=False, forward_is_infer=False,
  64. jitable=False):
  65. """ Code chooses a model based on name"""
  66. model = None
  67. if model_name == 'Tacotron2':
  68. if forward_is_infer:
  69. class Tacotron2__forward_is_infer(Tacotron2):
  70. def forward(self, inputs, input_lengths):
  71. return self.infer(inputs, input_lengths)
  72. model = Tacotron2__forward_is_infer(**model_config)
  73. else:
  74. model = Tacotron2(**model_config)
  75. elif model_name == 'WaveGlow':
  76. if forward_is_infer:
  77. class WaveGlow__forward_is_infer(WaveGlow):
  78. def forward(self, spect, sigma=1.0):
  79. return self.infer(spect, sigma)
  80. model = WaveGlow__forward_is_infer(**model_config)
  81. else:
  82. model = WaveGlow(**model_config)
  83. elif model_name == 'FastPitch':
  84. if forward_is_infer:
  85. if jitable:
  86. class FastPitch__forward_is_infer(_FastPitchJIT):
  87. def forward(self, inputs, input_lengths, pace: float = 1.0,
  88. dur_tgt: Optional[torch.Tensor] = None,
  89. pitch_tgt: Optional[torch.Tensor] = None):
  90. return self.infer(inputs, input_lengths, pace=pace,
  91. dur_tgt=dur_tgt, pitch_tgt=pitch_tgt)
  92. else:
  93. class FastPitch__forward_is_infer(_FastPitch):
  94. def forward(self, inputs, input_lengths, pace: float = 1.0,
  95. dur_tgt: Optional[torch.Tensor] = None,
  96. pitch_tgt: Optional[torch.Tensor] = None,
  97. pitch_transform=None):
  98. return self.infer(inputs, input_lengths, pace=pace,
  99. dur_tgt=dur_tgt, pitch_tgt=pitch_tgt,
  100. pitch_transform=pitch_transform)
  101. model = FastPitch__forward_is_infer(**model_config)
  102. else:
  103. model = _FastPitch(**model_config)
  104. else:
  105. raise NotImplementedError(model_name)
  106. if uniform_initialize_bn_weight:
  107. init_bn(model)
  108. return model.to(device)
  109. def get_model_config(model_name, args):
  110. """ Code chooses a model based on name"""
  111. if model_name == 'Tacotron2':
  112. model_config = dict(
  113. # optimization
  114. mask_padding=args.mask_padding,
  115. # audio
  116. n_mel_channels=args.n_mel_channels,
  117. # symbols
  118. n_symbols=args.n_symbols,
  119. symbols_embedding_dim=args.symbols_embedding_dim,
  120. # encoder
  121. encoder_kernel_size=args.encoder_kernel_size,
  122. encoder_n_convolutions=args.encoder_n_convolutions,
  123. encoder_embedding_dim=args.encoder_embedding_dim,
  124. # attention
  125. attention_rnn_dim=args.attention_rnn_dim,
  126. attention_dim=args.attention_dim,
  127. # attention location
  128. attention_location_n_filters=args.attention_location_n_filters,
  129. attention_location_kernel_size=args.attention_location_kernel_size,
  130. # decoder
  131. n_frames_per_step=args.n_frames_per_step,
  132. decoder_rnn_dim=args.decoder_rnn_dim,
  133. prenet_dim=args.prenet_dim,
  134. max_decoder_steps=args.max_decoder_steps,
  135. gate_threshold=args.gate_threshold,
  136. p_attention_dropout=args.p_attention_dropout,
  137. p_decoder_dropout=args.p_decoder_dropout,
  138. # postnet
  139. postnet_embedding_dim=args.postnet_embedding_dim,
  140. postnet_kernel_size=args.postnet_kernel_size,
  141. postnet_n_convolutions=args.postnet_n_convolutions,
  142. decoder_no_early_stopping=args.decoder_no_early_stopping,
  143. )
  144. return model_config
  145. elif model_name == 'WaveGlow':
  146. model_config = dict(
  147. n_mel_channels=args.n_mel_channels,
  148. n_flows=args.flows,
  149. n_group=args.groups,
  150. n_early_every=args.early_every,
  151. n_early_size=args.early_size,
  152. WN_config=dict(
  153. n_layers=args.wn_layers,
  154. kernel_size=args.wn_kernel_size,
  155. n_channels=args.wn_channels
  156. )
  157. )
  158. return model_config
  159. elif model_name == 'FastPitch':
  160. model_config = dict(
  161. # io
  162. n_mel_channels=args.n_mel_channels,
  163. max_seq_len=args.max_seq_len,
  164. # symbols
  165. n_symbols=args.n_symbols,
  166. symbols_embedding_dim=args.symbols_embedding_dim,
  167. # input FFT
  168. in_fft_n_layers=args.in_fft_n_layers,
  169. in_fft_n_heads=args.in_fft_n_heads,
  170. in_fft_d_head=args.in_fft_d_head,
  171. in_fft_conv1d_kernel_size=args.in_fft_conv1d_kernel_size,
  172. in_fft_conv1d_filter_size=args.in_fft_conv1d_filter_size,
  173. in_fft_output_size=args.in_fft_output_size,
  174. p_in_fft_dropout=args.p_in_fft_dropout,
  175. p_in_fft_dropatt=args.p_in_fft_dropatt,
  176. p_in_fft_dropemb=args.p_in_fft_dropemb,
  177. # output FFT
  178. out_fft_n_layers=args.out_fft_n_layers,
  179. out_fft_n_heads=args.out_fft_n_heads,
  180. out_fft_d_head=args.out_fft_d_head,
  181. out_fft_conv1d_kernel_size=args.out_fft_conv1d_kernel_size,
  182. out_fft_conv1d_filter_size=args.out_fft_conv1d_filter_size,
  183. out_fft_output_size=args.out_fft_output_size,
  184. p_out_fft_dropout=args.p_out_fft_dropout,
  185. p_out_fft_dropatt=args.p_out_fft_dropatt,
  186. p_out_fft_dropemb=args.p_out_fft_dropemb,
  187. # duration predictor
  188. dur_predictor_kernel_size=args.dur_predictor_kernel_size,
  189. dur_predictor_filter_size=args.dur_predictor_filter_size,
  190. p_dur_predictor_dropout=args.p_dur_predictor_dropout,
  191. dur_predictor_n_layers=args.dur_predictor_n_layers,
  192. # pitch predictor
  193. pitch_predictor_kernel_size=args.pitch_predictor_kernel_size,
  194. pitch_predictor_filter_size=args.pitch_predictor_filter_size,
  195. p_pitch_predictor_dropout=args.p_pitch_predictor_dropout,
  196. pitch_predictor_n_layers=args.pitch_predictor_n_layers,
  197. )
  198. return model_config
  199. else:
  200. raise NotImplementedError(model_name)