models.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import sys
  28. from os.path import abspath, dirname
  29. # enabling modules discovery from global entrypoint
  30. sys.path.append(abspath(dirname(__file__)+'/'))
  31. from tacotron2.model import Tacotron2
  32. from waveglow.model import WaveGlow
  33. import torch
  34. def model_parser(model_name, parser, add_help=False):
  35. if model_name == 'Tacotron2':
  36. from tacotron2.arg_parser import tacotron2_parser
  37. return tacotron2_parser(parser, add_help)
  38. if model_name == 'WaveGlow':
  39. from waveglow.arg_parser import waveglow_parser
  40. return waveglow_parser(parser, add_help)
  41. else:
  42. raise NotImplementedError(model_name)
  43. def batchnorm_to_float(module):
  44. """Converts batch norm to FP32"""
  45. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  46. module.float()
  47. for child in module.children():
  48. batchnorm_to_float(child)
  49. return module
  50. def init_bn(module):
  51. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  52. if module.affine:
  53. module.weight.data.uniform_()
  54. for child in module.children():
  55. init_bn(child)
  56. def get_model(model_name, model_config, cpu_run,
  57. uniform_initialize_bn_weight=False, forward_is_infer=False,
  58. jittable=False):
  59. """ Code chooses a model based on name"""
  60. model = None
  61. if model_name == 'Tacotron2':
  62. if forward_is_infer:
  63. class Tacotron2__forward_is_infer(Tacotron2):
  64. def forward(self, inputs, input_lengths):
  65. return self.infer(inputs, input_lengths)
  66. model = Tacotron2__forward_is_infer(**model_config)
  67. else:
  68. model = Tacotron2(**model_config)
  69. elif model_name == 'WaveGlow':
  70. model = WaveGlow(**model_config)
  71. if forward_is_infer:
  72. model.forward = model.infer
  73. else:
  74. raise NotImplementedError(model_name)
  75. if uniform_initialize_bn_weight:
  76. init_bn(model)
  77. if not cpu_run:
  78. model = model.cuda()
  79. return model
  80. def get_model_config(model_name, args):
  81. """ Code chooses a model based on name"""
  82. if model_name == 'Tacotron2':
  83. model_config = dict(
  84. # optimization
  85. mask_padding=args.mask_padding,
  86. # audio
  87. n_mel_channels=args.n_mel_channels,
  88. # symbols
  89. n_symbols=args.n_symbols,
  90. symbols_embedding_dim=args.symbols_embedding_dim,
  91. # encoder
  92. encoder_kernel_size=args.encoder_kernel_size,
  93. encoder_n_convolutions=args.encoder_n_convolutions,
  94. encoder_embedding_dim=args.encoder_embedding_dim,
  95. # attention
  96. attention_rnn_dim=args.attention_rnn_dim,
  97. attention_dim=args.attention_dim,
  98. # attention location
  99. attention_location_n_filters=args.attention_location_n_filters,
  100. attention_location_kernel_size=args.attention_location_kernel_size,
  101. # decoder
  102. n_frames_per_step=args.n_frames_per_step,
  103. decoder_rnn_dim=args.decoder_rnn_dim,
  104. prenet_dim=args.prenet_dim,
  105. max_decoder_steps=args.max_decoder_steps,
  106. gate_threshold=args.gate_threshold,
  107. p_attention_dropout=args.p_attention_dropout,
  108. p_decoder_dropout=args.p_decoder_dropout,
  109. # postnet
  110. postnet_embedding_dim=args.postnet_embedding_dim,
  111. postnet_kernel_size=args.postnet_kernel_size,
  112. postnet_n_convolutions=args.postnet_n_convolutions,
  113. decoder_no_early_stopping=args.decoder_no_early_stopping
  114. )
  115. return model_config
  116. elif model_name == 'WaveGlow':
  117. model_config = dict(
  118. n_mel_channels=args.n_mel_channels,
  119. n_flows=args.flows,
  120. n_group=args.groups,
  121. n_early_every=args.early_every,
  122. n_early_size=args.early_size,
  123. WN_config=dict(
  124. n_layers=args.wn_layers,
  125. kernel_size=args.wn_kernel_size,
  126. n_channels=args.wn_channels
  127. )
  128. )
  129. return model_config
  130. else:
  131. raise NotImplementedError(model_name)