| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- # *****************************************************************************
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- # * Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # * Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- # * Neither the name of the NVIDIA CORPORATION nor the
- # names of its contributors may be used to endorse or promote products
- # derived from this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- #
- # *****************************************************************************
- import sys
- from os.path import abspath, dirname
- # enabling modules discovery from global entrypoint
- sys.path.append(abspath(dirname(__file__)+'/'))
- from tacotron2.model import Tacotron2
- from waveglow.model import WaveGlow
- import torch
def model_parser(model_name, parser, add_help=False):
    """Register the command-line options of ``model_name`` on ``parser``.

    The model-specific parser helper is imported lazily, so only the
    package for the requested model is loaded.

    Args:
        model_name: ``'Tacotron2'`` or ``'WaveGlow'``.
        parser: parser object handed to the model-specific helper
            (presumably an ``argparse.ArgumentParser`` -- confirm at call sites).
        add_help: forwarded unchanged to the model-specific helper.

    Returns:
        Whatever ``tacotron2_parser`` / ``waveglow_parser`` returns.

    Raises:
        NotImplementedError: if ``model_name`` is not a known model.
    """
    if model_name == 'Tacotron2':
        from tacotron2.arg_parser import tacotron2_parser as build_parser
    elif model_name == 'WaveGlow':
        from waveglow.arg_parser import waveglow_parser as build_parser
    else:
        raise NotImplementedError(model_name)
    return build_parser(parser, add_help)
def batchnorm_to_float(module):
    """Cast every batch-norm layer inside ``module`` to FP32, in place.

    ``module`` itself and all of its descendants are inspected. Each
    instance of ``torch.nn.modules.batchnorm._BatchNorm`` has ``.float()``
    called on it. Layers of any other type keep their current dtype.

    Returns:
        The same ``module`` object, for call chaining.
    """
    bn_type = torch.nn.modules.batchnorm._BatchNorm
    # modules() yields ``module`` itself first, then every descendant.
    for submodule in module.modules():
        if isinstance(submodule, bn_type):
            submodule.float()
    return module
def init_bn(module):
    """Re-draw the affine weights of every batch-norm layer from U[0, 1).

    Walks ``module`` and its descendants in depth-first pre-order: a
    parent is visited before its children, and children are visited in
    registration order. Every ``_BatchNorm`` that has ``affine`` set gets
    its ``weight`` filled in place with ``uniform_()``. Biases and all
    other layers are left untouched. Returns ``None``.
    """
    bn_type = torch.nn.modules.batchnorm._BatchNorm
    stack = [module]
    while stack:
        node = stack.pop()
        if isinstance(node, bn_type) and node.affine:
            node.weight.data.uniform_()
        # Push children reversed so the first child is popped next,
        # reproducing recursive pre-order (and therefore RNG draw order).
        stack.extend(reversed(list(node.children())))
def get_model(model_name, model_config, cpu_run,
              uniform_initialize_bn_weight=False, forward_is_infer=False,
              jittable=False):
    """Instantiate the network named ``model_name``.

    Args:
        model_name: ``'Tacotron2'`` or ``'WaveGlow'``.
        model_config: keyword arguments passed to the model constructor.
        cpu_run: when falsy, the model is moved with ``.cuda()`` before
            being returned.
        uniform_initialize_bn_weight: if true, ``init_bn`` is applied to the
            freshly built model.
        forward_is_infer: if true, calling the model runs ``infer``.
            Tacotron2 does this through a subclass whose
            ``forward(inputs, input_lengths)`` delegates to ``infer``.
            WaveGlow does it by rebinding the instance's ``forward``
            attribute to ``infer``.
        jittable: not referenced by this function.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if ``model_name`` is not a known model.
    """
    if model_name != 'Tacotron2' and model_name != 'WaveGlow':
        raise NotImplementedError(model_name)

    if model_name == 'WaveGlow':
        net = WaveGlow(**model_config)
        if forward_is_infer:
            net.forward = net.infer
    elif forward_is_infer:
        class Tacotron2__forward_is_infer(Tacotron2):
            def forward(self, inputs, input_lengths):
                return self.infer(inputs, input_lengths)
        net = Tacotron2__forward_is_infer(**model_config)
    else:
        net = Tacotron2(**model_config)

    if uniform_initialize_bn_weight:
        init_bn(net)

    return net if cpu_run else net.cuda()
def get_model_config(model_name, args):
    """Build the constructor keyword-argument dict for ``model_name``.

    Args:
        model_name: ``'Tacotron2'`` or ``'WaveGlow'``.
        args: object exposing the hyperparameters as attributes
            (presumably the namespace produced by the model's argument
            parser -- confirm at call sites).

    Returns:
        dict: For Tacotron2, every key equals the ``args`` attribute it was
        read from. For WaveGlow, ``flows``, ``groups``, ``early_every`` and
        ``early_size`` become ``n_flows``, ``n_group``, ``n_early_every``
        and ``n_early_size``. The WN options are nested under
        ``WN_config``.

    Raises:
        NotImplementedError: if ``model_name`` is not a known model.
    """
    if model_name == 'Tacotron2':
        # Keys map one-to-one onto attribute names; tuple order fixes both
        # the attribute read order and the key order of the result.
        field_names = (
            # optimization
            'mask_padding',
            # audio
            'n_mel_channels',
            # symbols
            'n_symbols',
            'symbols_embedding_dim',
            # encoder
            'encoder_kernel_size',
            'encoder_n_convolutions',
            'encoder_embedding_dim',
            # attention
            'attention_rnn_dim',
            'attention_dim',
            # attention location
            'attention_location_n_filters',
            'attention_location_kernel_size',
            # decoder
            'n_frames_per_step',
            'decoder_rnn_dim',
            'prenet_dim',
            'max_decoder_steps',
            'gate_threshold',
            'p_attention_dropout',
            'p_decoder_dropout',
            # postnet
            'postnet_embedding_dim',
            'postnet_kernel_size',
            'postnet_n_convolutions',
            'decoder_no_early_stopping',
        )
        return {name: getattr(args, name) for name in field_names}

    if model_name == 'WaveGlow':
        return {
            'n_mel_channels': args.n_mel_channels,
            'n_flows': args.flows,
            'n_group': args.groups,
            'n_early_every': args.early_every,
            'n_early_size': args.early_size,
            'WN_config': {
                'n_layers': args.wn_layers,
                'kernel_size': args.wn_kernel_size,
                'n_channels': args.wn_channels,
            },
        }

    raise NotImplementedError(model_name)
|