inference.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import argparse
  28. import models
  29. import time
  30. import tqdm
  31. import sys
  32. import warnings
  33. from pathlib import Path
  34. import torch
  35. import numpy as np
  36. from scipy.stats import norm
  37. from scipy.io.wavfile import write
  38. from torch.nn.utils.rnn import pad_sequence
  39. import dllogger as DLLogger
  40. from apex import amp
  41. from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
  42. from common import utils
  43. from common.text import text_to_sequence
  44. from waveglow import model as glow
  45. from waveglow.denoiser import Denoiser
  46. sys.modules['glow'] = glow
  47. def parse_args(parser):
  48. """
  49. Parse commandline arguments.
  50. """
  51. parser.add_argument('-i', '--input', type=str, required=True,
  52. help='Full path to the input text (phareses separated by newlines)')
  53. parser.add_argument('-o', '--output', default=None,
  54. help='Output folder to save audio (file per phrase)')
  55. parser.add_argument('--log-file', type=str, default='nvlog.json',
  56. help='Filename for logging')
  57. parser.add_argument('--cuda', action='store_true',
  58. help='Run inference on a GPU using CUDA')
  59. parser.add_argument('--fastpitch', type=str,
  60. help='Full path to the generator checkpoint file (skip to use ground truth mels)')
  61. parser.add_argument('--waveglow', type=str,
  62. help='Full path to the WaveGlow model checkpoint file (skip to only generate mels)')
  63. parser.add_argument('-s', '--sigma-infer', default=0.9, type=float,
  64. help='WaveGlow sigma')
  65. parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
  66. help='WaveGlow denoising')
  67. parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
  68. help='Sampling rate')
  69. parser.add_argument('--stft-hop-length', type=int, default=256,
  70. help='STFT hop length for estimating audio length from mel size')
  71. parser.add_argument('--amp-run', action='store_true',
  72. help='Inference with AMP')
  73. parser.add_argument('--batch-size', type=int, default=64)
  74. parser.add_argument('--include-warmup', action='store_true',
  75. help='Include warmup')
  76. parser.add_argument('--repeats', type=int, default=1,
  77. help='Repeat inference for benchmarking')
  78. parser.add_argument('--torchscript', action='store_true',
  79. help='Apply TorchScript')
  80. parser.add_argument('--ema', action='store_true',
  81. help='Use EMA averaged model (if saved in checkpoints)')
  82. parser.add_argument('--dataset-path', type=str,
  83. help='Path to dataset (for loading extra data fields)')
  84. transform = parser.add_argument_group('transform')
  85. transform.add_argument('--fade-out', type=int, default=5,
  86. help='Number of fadeout frames at the end')
  87. transform.add_argument('--pace', type=float, default=1.0,
  88. help='Adjust the pace of speech')
  89. transform.add_argument('--pitch-transform-flatten', action='store_true',
  90. help='Flatten the pitch')
  91. transform.add_argument('--pitch-transform-invert', action='store_true',
  92. help='Invert the pitch wrt mean value')
  93. transform.add_argument('--pitch-transform-amplify', action='store_true',
  94. help='Amplify the pitch variability')
  95. transform.add_argument('--pitch-transform-shift', type=float, default=0.0,
  96. help='Raise/lower the pitch by <hz>')
  97. return parser
  98. def load_and_setup_model(model_name, parser, checkpoint, amp_run, device,
  99. unk_args=[], forward_is_infer=False, ema=True,
  100. jitable=False):
  101. model_parser = models.parse_model_args(model_name, parser, add_help=False)
  102. model_args, model_unk_args = model_parser.parse_known_args()
  103. unk_args[:] = list(set(unk_args) & set(model_unk_args))
  104. model_config = models.get_model_config(model_name, model_args)
  105. model = models.get_model(model_name, model_config, device,
  106. forward_is_infer=forward_is_infer,
  107. jitable=jitable)
  108. if checkpoint is not None:
  109. checkpoint_data = torch.load(checkpoint)
  110. status = ''
  111. if 'state_dict' in checkpoint_data:
  112. sd = checkpoint_data['state_dict']
  113. if ema and 'ema_state_dict' in checkpoint_data:
  114. sd = checkpoint_data['ema_state_dict']
  115. status += ' (EMA)'
  116. elif ema and not 'ema_state_dict' in checkpoint_data:
  117. print(f'WARNING: EMA weights missing for {model_name}')
  118. if any(key.startswith('module.') for key in sd):
  119. sd = {k.replace('module.', ''): v for k,v in sd.items()}
  120. status += ' ' + str(model.load_state_dict(sd, strict=False))
  121. else:
  122. model = checkpoint_data['model']
  123. print(f'Loaded {model_name}{status}')
  124. if model_name == "WaveGlow":
  125. model = model.remove_weightnorm(model)
  126. if amp_run:
  127. model.half()
  128. model.eval()
  129. return model.to(device)
  130. def load_fields(fpath):
  131. lines = [l.strip() for l in open(fpath, encoding='utf-8')]
  132. if fpath.endswith('.tsv'):
  133. columns = lines[0].split('\t')
  134. fields = list(zip(*[t.split('\t') for t in lines[1:]]))
  135. else:
  136. columns = ['text']
  137. fields = [lines]
  138. return {c:f for c, f in zip(columns, fields)}
  139. def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
  140. load_mels=False, load_pitch=False):
  141. fields['text'] = [torch.LongTensor(text_to_sequence(t, ['english_cleaners']))
  142. for t in fields['text']]
  143. order = np.argsort([-t.size(0) for t in fields['text']])
  144. fields['text'] = [fields['text'][i] for i in order]
  145. fields['text_lens'] = torch.LongTensor([t.size(0) for t in fields['text']])
  146. if load_mels:
  147. assert 'mel' in fields
  148. fields['mel'] = [
  149. torch.load(Path(dataset, fields['mel'][i])).t() for i in order]
  150. fields['mel_lens'] = torch.LongTensor([t.size(0) for t in fields['mel']])
  151. if load_pitch:
  152. assert 'pitch' in fields
  153. fields['pitch'] = [
  154. torch.load(Path(dataset, fields['pitch'][i])) for i in order]
  155. fields['pitch_lens'] = torch.LongTensor([t.size(0) for t in fields['pitch']])
  156. if 'output' in fields:
  157. fields['output'] = [fields['output'][i] for i in order]
  158. # cut into batches & pad
  159. batches = []
  160. for b in range(0, len(order), batch_size):
  161. batch = {f: values[b:b+batch_size] for f, values in fields.items()}
  162. for f in batch:
  163. if f == 'text':
  164. batch[f] = pad_sequence(batch[f], batch_first=True)
  165. elif f == 'mel' and load_mels:
  166. batch[f] = pad_sequence(batch[f], batch_first=True).permute(0, 2, 1)
  167. elif f == 'pitch' and load_pitch:
  168. batch[f] = pad_sequence(batch[f], batch_first=True)
  169. if type(batch[f]) is torch.Tensor:
  170. batch[f] = batch[f].to(device)
  171. batches.append(batch)
  172. return batches
  173. def build_pitch_transformation(args):
  174. fun = 'pitch'
  175. if args.pitch_transform_flatten:
  176. fun = f'({fun}) * 0.0'
  177. if args.pitch_transform_invert:
  178. fun = f'({fun}) * -1.0'
  179. if args.pitch_transform_amplify:
  180. fun = f'({fun}) * 2.0'
  181. if args.pitch_transform_shift != 0.0:
  182. hz = args.pitch_transform_shift
  183. fun = f'({fun}) + {hz} / std'
  184. return eval(f'lambda pitch, mean, std: {fun}')
  185. class MeasureTime(list):
  186. def __enter__(self):
  187. torch.cuda.synchronize()
  188. self.t0 = time.perf_counter()
  189. def __exit__(self, exc_type, exc_value, exc_traceback):
  190. torch.cuda.synchronize()
  191. self.append(time.perf_counter() - self.t0)
  192. def __add__(self, other):
  193. assert len(self) == len(other)
  194. return MeasureTime(sum(ab) for ab in zip(self, other))
def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    # Unknown args are kept: they may belong to the model-specific parsers
    # and are validated after both models have had a chance to claim them.
    args, unk_args = parser.parse_known_args()

    DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, args.log_file),
                            StdOutBackend(Verbosity.VERBOSE)])
    for k,v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k:v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'FastPitch_PyT'})

    if args.output is not None:
        # parents=False: the parent directory must already exist.
        Path(args.output).mkdir(parents=False, exist_ok=True)

    device = torch.device('cuda' if args.cuda else 'cpu')

    # Both models are optional: skipping --fastpitch synthesizes from ground
    # truth mels; skipping --waveglow only produces mels (no audio).
    if args.fastpitch is not None:
        generator = load_and_setup_model(
            'FastPitch', parser, args.fastpitch, args.amp_run, device,
            unk_args=unk_args, forward_is_infer=True, ema=args.ema,
            jitable=args.torchscript)
        if args.torchscript:
            generator = torch.jit.script(generator)
    else:
        generator = None

    if args.waveglow is not None:
        # WaveGlow checkpoint loading is noisy with deprecation warnings.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            waveglow = load_and_setup_model(
                'WaveGlow', parser, args.waveglow, args.amp_run, device,
                unk_args=unk_args, forward_is_infer=True, ema=args.ema)
            denoiser = Denoiser(waveglow).to(device)
            # Use the infer() entry point when available (e.g. non-scripted
            # models); otherwise call the model itself.
            waveglow = getattr(waveglow, 'infer', waveglow)
    else:
        waveglow = None

    # Anything still unclaimed by the main or model parsers is an error.
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    fields = load_fields(args.input)
    batches = prepare_input_sequence(
        fields, device, args.batch_size, args.dataset_path,
        load_mels=(generator is None))

    if args.include_warmup:
        # Use real data rather than synthetic - FastPitch predicts len
        for i in range(3):
            with torch.no_grad():
                if generator is not None:
                    b = batches[0]
                    mel, *_ = generator(b['text'], b['text_lens'])
                if waveglow is not None:
                    # NOTE(review): if generator is None, `mel` is unbound
                    # here — warmup with --waveglow but no --fastpitch would
                    # raise NameError. Confirm intended usage.
                    audios = waveglow(mel, sigma=args.sigma_infer).float()
                    _ = denoiser(audios, strength=args.denoising_strength)

    gen_measures = MeasureTime()
    waveglow_measures = MeasureTime()

    gen_kw = {'pace': args.pace,
              'pitch_tgt': None,
              'pitch_transform': build_pitch_transformation(args)}

    if args.torchscript:
        # TorchScript cannot serialize the Python pitch-transform callable.
        gen_kw.pop('pitch_transform')

    all_utterances = 0
    all_samples = 0
    all_letters = 0
    all_frames = 0

    reps = args.repeats
    # Per-batch logging is suppressed during multi-repeat benchmarking and
    # re-enabled for the final summary below.
    log_enabled = reps == 1
    log = lambda s, d: DLLogger.log(step=s, data=d) if log_enabled else None

    for repeat in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
        for b in batches:
            if generator is None:
                # NOTE(review): this passes a *set* literal, not a dict, to
                # DLLogger.log — likely a bug; verify DLLogger accepts it.
                log(0, {'Synthesizing from ground truth mels'})
                mel, mel_lens = b['mel'], b['mel_lens']
            else:
                with torch.no_grad(), gen_measures:
                    mel, mel_lens, *_ = generator(
                        b['text'], b['text_lens'], **gen_kw)

                # Frames produced per second of generator wall time
                # (mel is assumed (B, n_mels, T) — TODO confirm).
                gen_infer_perf = mel.size(0) * mel.size(2) / gen_measures[-1]
                all_letters += b['text_lens'].sum().item()
                all_frames += mel.size(0) * mel.size(2)
                log(0, {"generator_frames_per_sec": gen_infer_perf})
                log(0, {"generator_latency": gen_measures[-1]})

            if waveglow is not None:
                with torch.no_grad(), waveglow_measures:
                    audios = waveglow(mel, sigma=args.sigma_infer)
                    audios = denoiser(audios.float(),
                                      strength=args.denoising_strength
                                      ).squeeze(1)

                all_utterances += len(audios)
                all_samples += sum(audio.size(0) for audio in audios)
                waveglow_infer_perf = (
                    audios.size(0) * audios.size(1) / waveglow_measures[-1])
                log(0, {"waveglow_samples_per_sec": waveglow_infer_perf})
                log(0, {"waveglow_latency": waveglow_measures[-1]})

                # Audio files are written only for a single (non-benchmark)
                # pass.
                if args.output is not None and reps == 1:
                    for i, audio in enumerate(audios):
                        # Trim padding: true length = mel frames * hop size.
                        audio = audio[:mel_lens[i].item() * args.stft_hop_length]

                        if args.fade_out:
                            # Linear fade over the last `fade_out` frames to
                            # avoid an audible click at the end.
                            fade_len = args.fade_out * args.stft_hop_length
                            fade_w = torch.linspace(1.0, 0.0, fade_len)
                            audio[-fade_len:] *= fade_w.to(audio.device)

                        # Peak-normalize before writing.
                        audio = audio/torch.max(torch.abs(audio))
                        fname = b['output'][i] if 'output' in b else f'audio_{i}.wav'
                        audio_path = Path(args.output, fname)
                        write(audio_path, args.sampling_rate, audio.cpu().numpy())

            if generator is not None and waveglow is not None:
                log(0, {"latency": (gen_measures[-1] + waveglow_measures[-1])})

    # Re-enable logging for the aggregate summary regardless of `reps`.
    log_enabled = True
    if generator is not None:
        gm = np.sort(np.asarray(gen_measures))
        log('avg', {"generator letters/s": all_letters / gm.sum()})
        log('avg', {"generator_frames/s": all_frames / gm.sum()})
        log('avg', {"generator_latency": gm.mean()})
        # One-sided normal-approximation percentiles of the latency
        # distribution (mean + z * std).
        log('90%', {"generator_latency": gm.mean() + norm.ppf((1.0 + 0.90) / 2) * gm.std()})
        log('95%', {"generator_latency": gm.mean() + norm.ppf((1.0 + 0.95) / 2) * gm.std()})
        log('99%', {"generator_latency": gm.mean() + norm.ppf((1.0 + 0.99) / 2) * gm.std()})
    if waveglow is not None:
        wm = np.sort(np.asarray(waveglow_measures))
        log('avg', {"waveglow_samples/s": all_samples / wm.sum()})
        log('avg', {"waveglow_latency": wm.mean()})
        log('90%', {"waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.90) / 2) * wm.std()})
        log('95%', {"waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.95) / 2) * wm.std()})
        log('99%', {"waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.99) / 2) * wm.std()})
        if generator is not None and waveglow is not None:
            # End-to-end (generator + vocoder) per-measurement latencies.
            m = gm + wm
            # Real-time factor: generated audio seconds per wall-clock second.
            rtf = all_samples / (len(batches) * all_utterances * m.mean() * args.sampling_rate)
            log('avg', {"samples/s": all_samples / m.sum()})
            log('avg', {"letters/s": all_letters / m.sum()})
            log('avg', {"latency": m.mean()})
            log('avg', {"RTF": rtf})
            log('90%', {"latency": m.mean() + norm.ppf((1.0 + 0.90) / 2) * m.std()})
            log('95%', {"latency": m.mean() + norm.ppf((1.0 + 0.95) / 2) * m.std()})
            log('99%', {"latency": m.mean() + norm.ppf((1.0 + 0.99) / 2) * m.std()})
    DLLogger.flush()


if __name__ == '__main__':
    main()