# models.py
  1. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import json
  16. import re
  17. import sys
  18. import torch
  19. from common.text.symbols import get_symbols, get_pad_idx
  20. from common.utils import DefaultAttrDict, AttrDict
  21. from fastpitch.model import FastPitch
  22. from fastpitch.model_jit import FastPitchJIT
  23. from hifigan.models import Generator
  24. try:
  25. from waveglow.model import WaveGlow
  26. from waveglow import model as glow
  27. from waveglow.denoiser import Denoiser
  28. sys.modules['glow'] = glow
  29. except ImportError:
  30. print("WARNING: Couldn't import WaveGlow")
  31. def parse_model_args(model_name, parser, add_help=False):
  32. if model_name == 'FastPitch':
  33. from fastpitch import arg_parser
  34. return arg_parser.parse_fastpitch_args(parser, add_help)
  35. elif model_name == 'HiFi-GAN':
  36. from hifigan import arg_parser
  37. return arg_parser.parse_hifigan_args(parser, add_help)
  38. elif model_name == 'WaveGlow':
  39. from waveglow.arg_parser import parse_waveglow_args
  40. return parse_waveglow_args(parser, add_help)
  41. else:
  42. raise NotImplementedError(model_name)
  43. def get_model(model_name, model_config, device, bn_uniform_init=False,
  44. forward_is_infer=False, jitable=False):
  45. """Chooses a model based on name"""
  46. del bn_uniform_init # unused (old name: uniform_initialize_bn_weight)
  47. if model_name == 'FastPitch':
  48. if jitable:
  49. model = FastPitchJIT(**model_config)
  50. else:
  51. model = FastPitch(**model_config)
  52. elif model_name == 'HiFi-GAN':
  53. model = Generator(model_config)
  54. elif model_name == 'WaveGlow':
  55. model = WaveGlow(**model_config)
  56. else:
  57. raise NotImplementedError(model_name)
  58. if forward_is_infer and hasattr(model, 'infer'):
  59. model.forward = model.infer
  60. return model.to(device)
def get_model_config(model_name, args, ckpt_config=None):
    """ Get config needed to instantiate the model.

    Builds a constructor-kwargs dict for `model_name` from parsed CLI
    `args`, then overlays it with `ckpt_config` (the config stored in a
    checkpoint), which takes priority.  Keys absent from `args` are marked
    with a private sentinel; the final assert guarantees that every value
    in the returned config was resolved either from args or the checkpoint.

    Args:
        args: argparse.Namespace with model arguments (may be partial).
        ckpt_config: optional dict loaded from a checkpoint; wins over args.

    Returns:
        dict of keyword arguments for the model constructor.

    Raises:
        NotImplementedError: for an unrecognized `model_name`.
        AssertionError: if a required key is missing from both sources, or
            if --hifigan-config conflicts with a checkpoint-stored config.
    """
    # Mark keys missing in `args` with an object (None is ambiguous)
    _missing = object()
    args = DefaultAttrDict(lambda: _missing, vars(args))

    # `ckpt_config` is loaded from the checkpoint and has the priority
    # `model_config` is based on args and fills empty slots in `ckpt_config`
    if model_name == 'FastPitch':
        model_config = dict(
            # io
            n_mel_channels=args.n_mel_channels,
            # symbols
            # Both values derive from args.symbol_set; propagate the sentinel
            # if the symbol set itself was not supplied.
            n_symbols=(len(get_symbols(args.symbol_set))
                       if args.symbol_set is not _missing else _missing),
            padding_idx=(get_pad_idx(args.symbol_set)
                         if args.symbol_set is not _missing else _missing),
            symbols_embedding_dim=args.symbols_embedding_dim,
            # input FFT
            in_fft_n_layers=args.in_fft_n_layers,
            in_fft_n_heads=args.in_fft_n_heads,
            in_fft_d_head=args.in_fft_d_head,
            in_fft_conv1d_kernel_size=args.in_fft_conv1d_kernel_size,
            in_fft_conv1d_filter_size=args.in_fft_conv1d_filter_size,
            in_fft_output_size=args.in_fft_output_size,
            p_in_fft_dropout=args.p_in_fft_dropout,
            p_in_fft_dropatt=args.p_in_fft_dropatt,
            p_in_fft_dropemb=args.p_in_fft_dropemb,
            # output FFT
            out_fft_n_layers=args.out_fft_n_layers,
            out_fft_n_heads=args.out_fft_n_heads,
            out_fft_d_head=args.out_fft_d_head,
            out_fft_conv1d_kernel_size=args.out_fft_conv1d_kernel_size,
            out_fft_conv1d_filter_size=args.out_fft_conv1d_filter_size,
            out_fft_output_size=args.out_fft_output_size,
            p_out_fft_dropout=args.p_out_fft_dropout,
            p_out_fft_dropatt=args.p_out_fft_dropatt,
            p_out_fft_dropemb=args.p_out_fft_dropemb,
            # duration predictor
            dur_predictor_kernel_size=args.dur_predictor_kernel_size,
            dur_predictor_filter_size=args.dur_predictor_filter_size,
            p_dur_predictor_dropout=args.p_dur_predictor_dropout,
            dur_predictor_n_layers=args.dur_predictor_n_layers,
            # pitch predictor
            pitch_predictor_kernel_size=args.pitch_predictor_kernel_size,
            pitch_predictor_filter_size=args.pitch_predictor_filter_size,
            p_pitch_predictor_dropout=args.p_pitch_predictor_dropout,
            pitch_predictor_n_layers=args.pitch_predictor_n_layers,
            # pitch conditioning
            pitch_embedding_kernel_size=args.pitch_embedding_kernel_size,
            # speakers parameters
            n_speakers=args.n_speakers,
            speaker_emb_weight=args.speaker_emb_weight,
            # energy predictor
            energy_predictor_kernel_size=args.energy_predictor_kernel_size,
            energy_predictor_filter_size=args.energy_predictor_filter_size,
            p_energy_predictor_dropout=args.p_energy_predictor_dropout,
            energy_predictor_n_layers=args.energy_predictor_n_layers,
            # energy conditioning
            energy_conditioning=args.energy_conditioning,
            energy_embedding_kernel_size=args.energy_embedding_kernel_size,
        )
    elif model_name == 'HiFi-GAN':
        # An external JSON config (--hifigan-config) replaces `args` wholesale;
        # it is mutually exclusive with a config stored in the checkpoint.
        if args.hifigan_config is not None:
            assert ckpt_config is None, (
                "Supplied --hifigan-config, but the checkpoint has a config. "
                "Drop the flag or remove the config from the checkpoint file.")
            print(f'HiFi-GAN: Reading model config from {args.hifigan_config}')
            with open(args.hifigan_config) as f:
                args = AttrDict(json.load(f))

        model_config = dict(
            # generator architecture
            upsample_rates=args.upsample_rates,
            upsample_kernel_sizes=args.upsample_kernel_sizes,
            upsample_initial_channel=args.upsample_initial_channel,
            resblock=args.resblock,
            resblock_kernel_sizes=args.resblock_kernel_sizes,
            resblock_dilation_sizes=args.resblock_dilation_sizes,
        )
    elif model_name == 'WaveGlow':
        model_config = dict(
            n_mel_channels=args.n_mel_channels,
            n_flows=args.flows,
            n_group=args.groups,
            n_early_every=args.early_every,
            n_early_size=args.early_size,
            WN_config=dict(
                n_layers=args.wn_layers,
                kernel_size=args.wn_kernel_size,
                n_channels=args.wn_channels
            )
        )
    else:
        raise NotImplementedError(model_name)

    # Start with ckpt_config, and fill missing keys from model_config
    final_config = {} if ckpt_config is None else ckpt_config.copy()
    missing_keys = set(model_config.keys()) - set(final_config.keys())
    final_config.update({k: model_config[k] for k in missing_keys})

    # If there was a ckpt_config, it should have had all args
    if ckpt_config is not None and len(missing_keys) > 0:
        print(f'WARNING: Keys {missing_keys} missing from the loaded config; '
              'using args instead.')

    # Every slot must have been filled by either the checkpoint or args.
    assert all(v is not _missing for v in final_config.values())
    return final_config
  164. def get_model_train_setup(model_name, args):
  165. """ Dump train setup for documentation purposes """
  166. if model_name == 'FastPitch':
  167. return dict()
  168. elif model_name == 'HiFi-GAN':
  169. return dict(
  170. # audio
  171. segment_size=args.segment_size,
  172. filter_length=args.filter_length,
  173. num_mels=args.num_mels,
  174. hop_length=args.hop_length,
  175. win_length=args.win_length,
  176. sampling_rate=args.sampling_rate,
  177. mel_fmin=args.mel_fmin,
  178. mel_fmax=args.mel_fmax,
  179. mel_fmax_loss=args.mel_fmax_loss,
  180. max_wav_value=args.max_wav_value,
  181. # other
  182. seed=args.seed,
  183. # optimization
  184. base_lr=args.learning_rate,
  185. lr_decay=args.lr_decay,
  186. epochs_all=args.epochs,
  187. )
  188. elif model_name == 'WaveGlow':
  189. return dict()
  190. else:
  191. raise NotImplementedError(model_name)
  192. def load_model_from_ckpt(checkpoint_data, model, key='state_dict'):
  193. if key is None:
  194. return checkpoint_data['model'], None
  195. sd = checkpoint_data[key]
  196. sd = {re.sub('^module\.', '', k): v for k, v in sd.items()}
  197. status = model.load_state_dict(sd, strict=False)
  198. return model, status
  199. def load_and_setup_model(model_name, parser, checkpoint, amp, device,
  200. unk_args=[], forward_is_infer=False, jitable=False):
  201. if checkpoint is not None:
  202. ckpt_data = torch.load(checkpoint)
  203. print(f'{model_name}: Loading {checkpoint}...')
  204. ckpt_config = ckpt_data.get('config')
  205. if ckpt_config is None:
  206. print(f'{model_name}: No model config in the checkpoint; using args.')
  207. else:
  208. print(f'{model_name}: Found model config saved in the checkpoint.')
  209. else:
  210. ckpt_config = None
  211. ckpt_data = {}
  212. model_parser = parse_model_args(model_name, parser, add_help=False)
  213. model_args, model_unk_args = model_parser.parse_known_args()
  214. unk_args[:] = list(set(unk_args) & set(model_unk_args))
  215. model_config = get_model_config(model_name, model_args, ckpt_config)
  216. model = get_model(model_name, model_config, device,
  217. forward_is_infer=forward_is_infer,
  218. jitable=jitable)
  219. if checkpoint is not None:
  220. key = 'generator' if model_name == 'HiFi-GAN' else 'state_dict'
  221. model, status = load_model_from_ckpt(ckpt_data, model, key)
  222. missing = [] if status is None else status.missing_keys
  223. unexpected = [] if status is None else status.unexpected_keys
  224. # Attention is only used during training, we won't miss it
  225. if model_name == 'FastPitch':
  226. missing = [k for k in missing if not k.startswith('attention.')]
  227. unexpected = [k for k in unexpected if not k.startswith('attention.')]
  228. assert len(missing) == 0 and len(unexpected) == 0, (
  229. f'Mismatched keys when loading parameters. Missing: {missing}, '
  230. f'unexpected: {unexpected}.')
  231. if model_name == "WaveGlow":
  232. for k, m in model.named_modules():
  233. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
  234. model = model.remove_weightnorm(model)
  235. elif model_name == 'HiFi-GAN':
  236. assert model_args.hifigan_config is not None or ckpt_config is not None, (
  237. 'Use a HiFi-GAN checkpoint from NVIDIA DeepLearningExamples with '
  238. 'saved config or supply --hifigan-config <json_file>.')
  239. model.remove_weight_norm()
  240. if amp:
  241. model.half()
  242. model.eval()
  243. return model.to(device), model_config, ckpt_data.get('train_setup', {})
  244. def load_and_setup_ts_model(model_name, checkpoint, amp, device=None):
  245. print(f'{model_name}: Loading TorchScript checkpoint {checkpoint}...')
  246. model = torch.jit.load(checkpoint).eval()
  247. if device is not None:
  248. model = model.to(device)
  249. if amp:
  250. model.half()
  251. elif next(model.parameters()).dtype == torch.float16:
  252. raise ValueError('Trying to load FP32 model,'
  253. 'TS checkpoint is in FP16 precision.')
  254. return model
  255. def convert_ts_to_trt(model_name, ts_model, parser, amp, unk_args=[]):
  256. trt_parser = _parse_trt_compilation_args(model_name, parser, add_help=False)
  257. trt_args, trt_unk_args = trt_parser.parse_known_args()
  258. unk_args[:] = list(set(unk_args) & set(trt_unk_args))
  259. if model_name == 'HiFi-GAN':
  260. return _convert_ts_to_trt_hifigan(
  261. ts_model, amp, trt_args.trt_min_opt_max_batch,
  262. trt_args.trt_min_opt_max_hifigan_length)
  263. else:
  264. raise NotImplementedError
  265. def _parse_trt_compilation_args(model_name, parent, add_help=False):
  266. """
  267. Parse model and inference specific commandline arguments.
  268. """
  269. parser = argparse.ArgumentParser(parents=[parent], add_help=add_help,
  270. allow_abbrev=False)
  271. trt = parser.add_argument_group(f'{model_name} Torch-TensorRT compilation parameters')
  272. trt.add_argument('--trt-min-opt-max-batch', nargs=3, type=int,
  273. default=(1, 8, 16),
  274. help='Torch-TensorRT min, optimal and max batch size')
  275. if model_name == 'HiFi-GAN':
  276. trt.add_argument('--trt-min-opt-max-hifigan-length', nargs=3, type=int,
  277. default=(100, 800, 1200),
  278. help='Torch-TensorRT min, optimal and max audio length (in frames)')
  279. return parser
  280. def _convert_ts_to_trt_hifigan(ts_model, amp, trt_min_opt_max_batch,
  281. trt_min_opt_max_hifigan_length, num_mels=80):
  282. import torch_tensorrt
  283. trt_dtype = torch.half if amp else torch.float
  284. print(f'Torch TensorRT: compiling HiFi-GAN for dtype {trt_dtype}.')
  285. min_shp, opt_shp, max_shp = zip(trt_min_opt_max_batch,
  286. (num_mels,) * 3,
  287. trt_min_opt_max_hifigan_length)
  288. compile_settings = {
  289. "inputs": [torch_tensorrt.Input(
  290. min_shape=min_shp,
  291. opt_shape=opt_shp,
  292. max_shape=max_shp,
  293. dtype=trt_dtype,
  294. )],
  295. "enabled_precisions": {trt_dtype},
  296. "require_full_compilation": True,
  297. }
  298. trt_model = torch_tensorrt.compile(ts_model, **compile_settings)
  299. print('Torch TensorRT: compilation successful.')
  300. return trt_model