inference.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import itertools
  16. import sys
  17. import time
  18. import warnings
  19. from pathlib import Path
  20. from tqdm import tqdm
  21. import torch
  22. import numpy as np
  23. from scipy.stats import norm
  24. from scipy.io.wavfile import write
  25. from torch.nn.functional import l1_loss
  26. from torch.nn.utils.rnn import pad_sequence
  27. import dllogger as DLLogger
  28. from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
  29. import models
  30. from common import gpu_affinity
  31. from common.tb_dllogger import (init_inference_metadata, stdout_metric_format,
  32. unique_log_fpath)
  33. from common.text import cmudict
  34. from common.text.text_processing import get_text_processing
  35. from common.utils import l2_promote
  36. from fastpitch.pitch_transform import pitch_transform_custom
  37. from hifigan.data_function import MAX_WAV_VALUE, mel_spectrogram
  38. from hifigan.models import Denoiser
  39. from waveglow import model as glow
# Hyperparameters that must agree between the spectrogram generator and the
# vocoder. When a value is stored in a loaded checkpoint's train setup, it
# overrides the corresponding CLI argument (see the override loop in main()).
CHECKPOINT_SPECIFIC_ARGS = [
    'sampling_rate', 'hop_length', 'win_length', 'p_arpabet', 'text_cleaners',
    'symbol_set', 'max_wav_value', 'prepend_space_to_text',
    'append_space_to_text']
  44. def parse_args(parser):
  45. """
  46. Parse commandline arguments.
  47. """
  48. parser.add_argument('-i', '--input', type=str, required=True,
  49. help='Full path to the input text (phareses separated by newlines)')
  50. parser.add_argument('-o', '--output', default=None,
  51. help='Output folder to save audio (file per phrase)')
  52. parser.add_argument('--log-file', type=str, default=None,
  53. help='Path to a DLLogger log file')
  54. parser.add_argument('--save-mels', action='store_true',
  55. help='Save generator outputs to disk')
  56. parser.add_argument('--cuda', action='store_true',
  57. help='Run inference on a GPU using CUDA')
  58. parser.add_argument('--cudnn-benchmark', action='store_true',
  59. help='Enable cudnn benchmark mode')
  60. parser.add_argument('--l2-promote', action='store_true',
  61. help='Increase max fetch granularity of GPU L2 cache')
  62. parser.add_argument('--fastpitch', type=str, default=None, required=False,
  63. help='Full path to the spectrogram generator .pt file '
  64. '(skip to synthesize from ground truth mels)')
  65. parser.add_argument('--waveglow', type=str, default=None, required=False,
  66. help='Full path to a WaveGlow model .pt file')
  67. parser.add_argument('-s', '--waveglow-sigma-infer', default=0.9, type=float,
  68. help='WaveGlow sigma')
  69. parser.add_argument('--hifigan', type=str, default=None, required=False,
  70. help='Full path to a HiFi-GAN model .pt file')
  71. parser.add_argument('-d', '--denoising-strength', default=0.0, type=float,
  72. help='Capture and subtract model bias to enhance audio')
  73. parser.add_argument('--hop-length', type=int, default=256,
  74. help='STFT hop length for estimating audio length from mel size')
  75. parser.add_argument('--win-length', type=int, default=1024,
  76. help='STFT win length for denoiser and mel loss')
  77. parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
  78. choices=[22050, 44100], help='Sampling rate')
  79. parser.add_argument('--max_wav_value', default=32768.0, type=float,
  80. help='Maximum audiowave value')
  81. parser.add_argument('--amp', action='store_true',
  82. help='Inference with AMP')
  83. parser.add_argument('-bs', '--batch-size', type=int, default=64)
  84. parser.add_argument('--warmup-steps', type=int, default=0,
  85. help='Warmup iterations before measuring performance')
  86. parser.add_argument('--repeats', type=int, default=1,
  87. help='Repeat inference for benchmarking')
  88. parser.add_argument('--torchscript', action='store_true',
  89. help='Run inference with TorchScript model (convert to TS if needed)')
  90. parser.add_argument('--checkpoint-format', type=str,
  91. choices=['pyt', 'ts'], default='pyt',
  92. help='Input checkpoint format (PyT or TorchScript)')
  93. parser.add_argument('--torch-tensorrt', action='store_true',
  94. help='Run inference with Torch-TensorRT model (compile beforehand)')
  95. parser.add_argument('--report-mel-loss', action='store_true',
  96. help='Report mel loss in metrics')
  97. parser.add_argument('--ema', action='store_true',
  98. help='Use EMA averaged model (if saved in checkpoints)')
  99. parser.add_argument('--dataset-path', type=str,
  100. help='Path to dataset (for loading extra data fields)')
  101. parser.add_argument('--speaker', type=int, default=0,
  102. help='Speaker ID for a multi-speaker model')
  103. parser.add_argument('--affinity', type=str, default='single',
  104. choices=['socket', 'single', 'single_unique',
  105. 'socket_unique_interleaved',
  106. 'socket_unique_continuous',
  107. 'disabled'],
  108. help='type of CPU affinity')
  109. transf = parser.add_argument_group('transform')
  110. transf.add_argument('--fade-out', type=int, default=10,
  111. help='Number of fadeout frames at the end')
  112. transf.add_argument('--pace', type=float, default=1.0,
  113. help='Adjust the pace of speech')
  114. transf.add_argument('--pitch-transform-flatten', action='store_true',
  115. help='Flatten the pitch')
  116. transf.add_argument('--pitch-transform-invert', action='store_true',
  117. help='Invert the pitch wrt mean value')
  118. transf.add_argument('--pitch-transform-amplify', type=float, default=1.0,
  119. help='Multiplicative amplification of pitch variability. '
  120. 'Typical values are in the range (1.0, 3.0).')
  121. transf.add_argument('--pitch-transform-shift', type=float, default=0.0,
  122. help='Raise/lower the pitch by <hz>')
  123. transf.add_argument('--pitch-transform-custom', action='store_true',
  124. help='Apply the transform from pitch_transform.py')
  125. txt = parser.add_argument_group('Text processing parameters')
  126. txt.add_argument('--text-cleaners', type=str, nargs='*',
  127. default=['english_cleaners_v2'],
  128. help='Type of text cleaners for input text')
  129. txt.add_argument('--symbol-set', type=str, default='english_basic',
  130. help='Define symbol set for input text')
  131. txt.add_argument('--p-arpabet', type=float, default=0.0, help='')
  132. txt.add_argument('--heteronyms-path', type=str,
  133. default='cmudict/heteronyms', help='')
  134. txt.add_argument('--cmudict-path', type=str,
  135. default='cmudict/cmudict-0.7b', help='')
  136. return parser
  137. def load_fields(fpath):
  138. lines = [l.strip() for l in open(fpath, encoding='utf-8')]
  139. if fpath.endswith('.tsv'):
  140. columns = lines[0].split('\t')
  141. fields = list(zip(*[t.split('\t') for t in lines[1:]]))
  142. else:
  143. columns = ['text']
  144. fields = [lines]
  145. return {c: f for c, f in zip(columns, fields)}
def prepare_input_sequence(fields, device, symbol_set, text_cleaners,
                           batch_size=128, dataset=None, load_mels=False,
                           load_pitch=False, p_arpabet=0.0):
    """Encode, sort, batch and pad the loaded input fields.

    Mutates ``fields`` in place: encodes text to LongTensors, sorts all
    fields by decreasing text length, optionally loads mel/pitch tensors
    from ``dataset``, then cuts everything into padded batches moved to
    ``device``.

    Returns:
        list of dict batches; tensor-valued entries live on ``device``.
    """
    tp = get_text_processing(symbol_set, text_cleaners, p_arpabet)

    fields['text'] = [torch.LongTensor(tp.encode_text(text))
                      for text in fields['text']]
    # Sort by decreasing text length so padding inside a batch is minimal.
    order = np.argsort([-t.size(0) for t in fields['text']])

    fields['text'] = [fields['text'][i] for i in order]
    fields['text_lens'] = torch.LongTensor([t.size(0) for t in fields['text']])

    # Echo the decoded symbol sequences for the user to inspect.
    for t in fields['text']:
        print(tp.sequence_to_text(t.numpy()))

    if load_mels:
        assert 'mel' in fields
        assert dataset is not None
        # NOTE(review): mels appear to be stored as (n_mel, T) and are
        # transposed here to (T, n_mel) for pad_sequence — confirm against
        # the dataset layout.
        fields['mel'] = [
            torch.load(Path(dataset, fields['mel'][i])).t() for i in order]
        fields['mel_lens'] = torch.LongTensor([t.size(0) for t in fields['mel']])

    if load_pitch:
        assert 'pitch' in fields
        fields['pitch'] = [
            torch.load(Path(dataset, fields['pitch'][i])) for i in order]
        fields['pitch_lens'] = torch.LongTensor([t.size(0) for t in fields['pitch']])

    # Keep output filenames aligned with the sorted utterances.
    if 'output' in fields:
        fields['output'] = [fields['output'][i] for i in order]

    # cut into batches & pad
    batches = []
    for b in range(0, len(order), batch_size):
        batch = {f: values[b:b+batch_size] for f, values in fields.items()}
        for f in batch:
            if f == 'text':
                batch[f] = pad_sequence(batch[f], batch_first=True)
            elif f == 'mel' and load_mels:
                # Back to (batch, n_mel, T) after padding over time.
                batch[f] = pad_sequence(batch[f], batch_first=True).permute(0, 2, 1)
            elif f == 'pitch' and load_pitch:
                batch[f] = pad_sequence(batch[f], batch_first=True)

            # Only tensors are moved; lists (e.g. 'output' names) stay on CPU.
            if type(batch[f]) is torch.Tensor:
                batch[f] = batch[f].to(device)
        batches.append(batch)

    return batches
  185. def build_pitch_transformation(args):
  186. if args.pitch_transform_custom:
  187. def custom_(pitch, pitch_lens, mean, std):
  188. return (pitch_transform_custom(pitch * std + mean, pitch_lens)
  189. - mean) / std
  190. return custom_
  191. fun = 'pitch'
  192. if args.pitch_transform_flatten:
  193. fun = f'({fun}) * 0.0'
  194. if args.pitch_transform_invert:
  195. fun = f'({fun}) * -1.0'
  196. if args.pitch_transform_amplify != 1.0:
  197. ampl = args.pitch_transform_amplify
  198. fun = f'({fun}) * {ampl}'
  199. if args.pitch_transform_shift != 0.0:
  200. hz = args.pitch_transform_shift
  201. fun = f'({fun}) + {hz} / std'
  202. if fun == 'pitch':
  203. return None
  204. return eval(f'lambda pitch, pitch_lens, mean, std: {fun}')
def setup_mel_loss_reporting(args, voc_train_setup):
    """Create a closure that measures mel-domain L1 loss of generated audio.

    Mel extraction parameters are taken from the vocoder's training setup
    when available, with fallbacks matching common HiFi-GAN defaults.

    Returns:
        ``compute_audio_mel_loss(gen_audios, gt_mels, mel_lens) -> float``
    """
    if args.denoising_strength > 0.0:
        # Denoised audio is what gets measured, so denoising skews the metric.
        print('WARNING: denoising will be included in vocoder mel loss')
    num_mels = voc_train_setup.get('num_mels', 80)
    fmin = voc_train_setup.get('mel_fmin', 0)
    fmax = voc_train_setup.get('mel_fmax', 8000)  # not mel_fmax_loss

    def compute_audio_mel_loss(gen_audios, gt_mels, mel_lens):
        # NOTE: normalizes the caller's tensor in place (/=) back to [-1, 1].
        gen_audios /= MAX_WAV_VALUE
        total_loss = 0
        for gen_audio, gt_mel, mel_len in zip(gen_audios, gt_mels, mel_lens):
            mel_len = mel_len.item()
            # Trim audio to the true (unpadded) length before re-analysis.
            gen_audio = gen_audio[None, :mel_len * args.hop_length]
            gen_mel = mel_spectrogram(gen_audio, args.win_length, num_mels,
                                      args.sampling_rate, args.hop_length,
                                      args.win_length, fmin, fmax)[0]
            total_loss += l1_loss(gen_mel, gt_mel[:, :mel_len])
        return total_loss.item()

    return compute_audio_mel_loss
  223. def compute_mel_loss(mels, lens, gt_mels, gt_lens):
  224. total_loss = 0
  225. for mel, len_, gt_mel, gt_len in zip(mels, lens, gt_mels, gt_lens):
  226. min_len = min(len_, gt_len)
  227. total_loss += l1_loss(gt_mel[:, :min_len], mel[:, :min_len])
  228. return total_loss.item()
  229. class MeasureTime(list):
  230. def __init__(self, *args, cuda=True, **kwargs):
  231. super(MeasureTime, self).__init__(*args, **kwargs)
  232. self.cuda = cuda
  233. def __enter__(self):
  234. if self.cuda:
  235. torch.cuda.synchronize()
  236. self.t0 = time.time()
  237. def __exit__(self, exc_type, exc_value, exc_traceback):
  238. if self.cuda:
  239. torch.cuda.synchronize()
  240. self.append(time.time() - self.t0)
  241. def __add__(self, other):
  242. assert len(self) == len(other)
  243. return MeasureTime((sum(ab) for ab in zip(self, other)), cuda=self.cuda)
def main():
    """
    Launches text-to-speech inference on a single GPU.

    Pipeline: parse args -> set up logging/affinity -> load the FastPitch
    generator and/or a vocoder (WaveGlow or HiFi-GAN) -> batch the input
    text -> run warmup, then timed inference -> optionally write .wav /
    .npy outputs -> report aggregate latency/throughput metrics.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()

    # Pin the process to CPUs near GPU 0, per the chosen affinity policy.
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        # print(nproc_per_node)
        affinity = gpu_affinity.set_affinity(
            0,
            nproc_per_node,
            args.affinity
        )
        print(f'Thread affinity: {affinity}')

    if args.l2_promote:
        l2_promote()

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if args.output is not None:
        Path(args.output).mkdir(parents=False, exist_ok=True)

    # Fall back to a log file inside the output folder when none is given.
    log_fpath = args.log_file or str(Path(args.output, 'nvlog_infer.json'))
    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath, append=True),
        JSONStreamBackend(Verbosity.DEFAULT, unique_log_fpath(log_fpath)),
        StdOutBackend(Verbosity.VERBOSE, metric_format=stdout_metric_format)
    ])
    init_inference_metadata(args.batch_size)
    # Log every CLI parameter once for reproducibility.
    [DLLogger.log("PARAMETER", {k: v}) for k, v in vars(args).items()]

    device = torch.device('cuda' if args.cuda else 'cpu')

    gen_train_setup = {}
    voc_train_setup = {}
    generator = None
    vocoder = None
    denoiser = None

    is_ts_based_infer = args.torch_tensorrt or args.torchscript

    assert args.checkpoint_format == 'pyt' or is_ts_based_infer, \
        'TorchScript checkpoint can be used only for TS or Torch-TRT' \
        ' inference. Please set --torchscript or --torch-tensorrt flag.'

    assert args.waveglow is None or args.hifigan is None, \
        "Specify a single vocoder model"

    def _load_pyt_or_ts_model(model_name, ckpt_path):
        # 'ts' checkpoints are already scripted and carry no train setup.
        if args.checkpoint_format == 'ts':
            model = models.load_and_setup_ts_model(model_name, ckpt_path,
                                                   args.amp, device)
            model_train_setup = {}
            return model, model_train_setup
        model, _, model_train_setup = models.load_and_setup_model(
            model_name, parser, ckpt_path, args.amp, device,
            unk_args=unk_args, forward_is_infer=True, jitable=is_ts_based_infer)
        if is_ts_based_infer:
            model = torch.jit.script(model)
        return model, model_train_setup

    # Spectrogram generator is optional; without it, ground-truth mels feed
    # the vocoder directly.
    if args.fastpitch is not None:
        gen_name = 'fastpitch'
        generator, gen_train_setup = _load_pyt_or_ts_model('FastPitch',
                                                           args.fastpitch)

    if args.waveglow is not None:
        voc_name = 'waveglow'
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            vocoder, _, voc_train_setup = models.load_and_setup_model(
                'WaveGlow', parser, args.waveglow, args.amp, device,
                unk_args=unk_args, forward_is_infer=True, jitable=False)

        if args.denoising_strength > 0.0:
            denoiser = Denoiser(vocoder, sigma=0.0,
                                win_length=args.win_length).to(device)

        # if args.torchscript:
        #     vocoder = torch.jit.script(vocoder)

        def generate_audio(mel):
            audios = vocoder(mel, sigma=args.waveglow_sigma_infer)
            if denoiser is not None:
                audios = denoiser(audios.float(), args.denoising_strength).squeeze(1)
            return audios

    elif args.hifigan is not None:
        voc_name = 'hifigan'
        vocoder, voc_train_setup = _load_pyt_or_ts_model('HiFi-GAN',
                                                         args.hifigan)

        if args.denoising_strength > 0.0:
            denoiser = Denoiser(vocoder, win_length=args.win_length).to(device)

        if args.torch_tensorrt:
            vocoder = models.convert_ts_to_trt('HiFi-GAN', vocoder, parser,
                                               args.amp, unk_args)

        def generate_audio(mel):
            audios = vocoder(mel).float()
            if denoiser is not None:
                audios = denoiser(audios.squeeze(1), args.denoising_strength)
            # HiFi-GAN emits audio in [-1, 1]; rescale to int16-like range.
            return audios.squeeze(1) * args.max_wav_value

    # Any CLI options not claimed by model loaders are a user error.
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    # Values stored in checkpoints win over the CLI, and must agree between
    # the generator and vocoder checkpoints when both provide them.
    for k in CHECKPOINT_SPECIFIC_ARGS:
        v1 = gen_train_setup.get(k, None)
        v2 = voc_train_setup.get(k, None)

        assert v1 is None or v2 is None or v1 == v2, \
            f'{k} mismatch in spectrogram generator and vocoder'

        val = v1 or v2
        if val and getattr(args, k) != val:
            src = 'generator' if v2 is None else 'vocoder'
            print(f'Overwriting args.{k}={getattr(args, k)} with {val} '
                  f'from {src} checkpoint.')
            setattr(args, k, val)

    gen_kw = {'pace': args.pace,
              'speaker': args.speaker,
              'pitch_tgt': None,
              'pitch_transform': build_pitch_transformation(args)}

    # Scripted models cannot accept an arbitrary Python callable.
    if is_ts_based_infer and generator is not None:
        gen_kw.pop('pitch_transform')
        print('Note: --pitch-transform-* args are disabled with TorchScript. '
              'To condition on pitch, pass pitch_tgt as input.')

    if args.p_arpabet > 0.0:
        cmudict.initialize(args.cmudict_path, args.heteronyms_path)

    if args.report_mel_loss:
        mel_loss_fn = setup_mel_loss_reporting(args, voc_train_setup)

    fields = load_fields(args.input)
    # Ground-truth mels are needed when there is no generator, or to
    # compute the generator's mel loss.
    batches = prepare_input_sequence(
        fields, device, args.symbol_set, args.text_cleaners, args.batch_size,
        args.dataset_path, load_mels=(generator is None or args.report_mel_loss),
        p_arpabet=args.p_arpabet)

    cycle = itertools.cycle(batches)
    # Use real data rather than synthetic - FastPitch predicts len
    for _ in tqdm(range(args.warmup_steps), 'Warmup'):
        with torch.no_grad():
            b = next(cycle)
            if generator is not None:
                mel, *_ = generator(b['text'])
            else:
                mel, mel_lens = b['mel'], b['mel_lens']
                if args.amp:
                    mel = mel.half()
            if vocoder is not None:
                audios = generate_audio(mel)

    gen_measures = MeasureTime(cuda=args.cuda)
    vocoder_measures = MeasureTime(cuda=args.cuda)

    # Aggregate counters across all repeats/batches for the final metrics.
    all_utterances = 0
    all_samples = 0
    all_batches = 0
    all_letters = 0
    all_frames = 0

    gen_mel_loss_sum = 0
    voc_mel_loss_sum = 0

    reps = args.repeats
    # Per-step logging only for a single repeat; summary logging re-enables
    # it after the loop.
    log_enabled = reps == 1
    log = lambda s, d: DLLogger.log(step=s, data=d) if log_enabled else None

    for rep in (tqdm(range(reps), 'Inference') if reps > 1 else range(reps)):
        for b in batches:
            if generator is None:
                # Vocoder-only path: feed ground-truth mels.
                mel, mel_lens = b['mel'], b['mel_lens']
                if args.amp:
                    mel = mel.half()
            else:
                with torch.no_grad(), gen_measures:
                    mel, mel_lens, *_ = generator(b['text'], **gen_kw)

                if args.report_mel_loss:
                    gen_mel_loss_sum += compute_mel_loss(
                        mel, mel_lens, b['mel'], b['mel_lens'])

                gen_infer_perf = mel.size(0) * mel.size(2) / gen_measures[-1]
                all_letters += b['text_lens'].sum().item()
                all_frames += mel.size(0) * mel.size(2)
                log(rep, {f"{gen_name}_frames/s": gen_infer_perf})
                log(rep, {f"{gen_name}_latency": gen_measures[-1]})

                if args.save_mels:
                    for i, mel_ in enumerate(mel):
                        # Trim padding and store as (T, n_mel).
                        m = mel_[:, :mel_lens[i].item()].permute(1, 0)
                        fname = b['output'][i] if 'output' in b else f'mel_{i}.npy'
                        mel_path = Path(args.output, Path(fname).stem + '.npy')
                        np.save(mel_path, m.cpu().numpy())

            if vocoder is not None:
                with torch.no_grad(), vocoder_measures:
                    audios = generate_audio(mel)

                vocoder_infer_perf = (
                    audios.size(0) * audios.size(1) / vocoder_measures[-1])

                log(rep, {f"{voc_name}_samples/s": vocoder_infer_perf})
                log(rep, {f"{voc_name}_latency": vocoder_measures[-1]})

                if args.report_mel_loss:
                    voc_mel_loss_sum += mel_loss_fn(audios, mel, mel_lens)

                if args.output is not None and reps == 1:
                    for i, audio in enumerate(audios):
                        # Trim to the true audio length implied by mel length.
                        audio = audio[:mel_lens[i].item() * args.hop_length]

                        if args.fade_out:
                            fade_len = args.fade_out * args.hop_length
                            fade_w = torch.linspace(1.0, 0.0, fade_len)
                            audio[-fade_len:] *= fade_w.to(audio.device)

                        # Peak-normalize before writing the waveform.
                        audio = audio / torch.max(torch.abs(audio))
                        fname = b['output'][i] if 'output' in b else f'audio_{all_utterances + i}.wav'
                        audio_path = Path(args.output, fname)
                        write(audio_path, args.sampling_rate, audio.cpu().numpy())

                if generator is not None:
                    log(rep, {"latency": (gen_measures[-1] + vocoder_measures[-1])})

            all_utterances += mel.size(0)
            all_samples += mel_lens.sum().item() * args.hop_length
            all_batches += 1

    # Summary metrics are always logged, even in benchmarking mode.
    log_enabled = True
    if generator is not None:
        gm = np.sort(np.asarray(gen_measures))
        rtf = all_samples / (all_utterances * gm.mean() * args.sampling_rate)
        rtf_at = all_samples / (all_batches * gm.mean() * args.sampling_rate)
        log((), {f"avg_{gen_name}_tokens/s": all_letters / gm.sum()})
        log((), {f"avg_{gen_name}_frames/s": all_frames / gm.sum()})
        log((), {f"avg_{gen_name}_latency": gm.mean()})
        log((), {f"avg_{gen_name}_RTF": rtf})
        log((), {f"avg_{gen_name}_RTF@{args.batch_size}": rtf_at})
        # Latency percentiles assume a normal distribution of measurements.
        log((), {f"90%_{gen_name}_latency": gm.mean() + norm.ppf((1.0 + 0.90) / 2) * gm.std()})
        log((), {f"95%_{gen_name}_latency": gm.mean() + norm.ppf((1.0 + 0.95) / 2) * gm.std()})
        log((), {f"99%_{gen_name}_latency": gm.mean() + norm.ppf((1.0 + 0.99) / 2) * gm.std()})
        if args.report_mel_loss:
            log((), {f"avg_{gen_name}_mel-loss": gen_mel_loss_sum / all_utterances})
    if vocoder is not None:
        vm = np.sort(np.asarray(vocoder_measures))
        rtf = all_samples / (all_utterances * vm.mean() * args.sampling_rate)
        rtf_at = all_samples / (all_batches * vm.mean() * args.sampling_rate)
        log((), {f"avg_{voc_name}_samples/s": all_samples / vm.sum()})
        log((), {f"avg_{voc_name}_latency": vm.mean()})
        log((), {f"avg_{voc_name}_RTF": rtf})
        log((), {f"avg_{voc_name}_RTF@{args.batch_size}": rtf_at})
        log((), {f"90%_{voc_name}_latency": vm.mean() + norm.ppf((1.0 + 0.90) / 2) * vm.std()})
        log((), {f"95%_{voc_name}_latency": vm.mean() + norm.ppf((1.0 + 0.95) / 2) * vm.std()})
        log((), {f"99%_{voc_name}_latency": vm.mean() + norm.ppf((1.0 + 0.99) / 2) * vm.std()})
        if args.report_mel_loss:
            log((), {f"avg_{voc_name}_mel-loss": voc_mel_loss_sum / all_utterances})
    if generator is not None and vocoder is not None:
        # End-to-end figures combine per-batch generator + vocoder times.
        m = gm + vm
        rtf = all_samples / (all_utterances * m.mean() * args.sampling_rate)
        rtf_at = all_samples / (all_batches * m.mean() * args.sampling_rate)
        log((), {"avg_samples/s": all_samples / m.sum()})
        log((), {"avg_letters/s": all_letters / m.sum()})
        log((), {"avg_latency": m.mean()})
        log((), {"avg_RTF": rtf})
        log((), {f"avg_RTF@{args.batch_size}": rtf_at})
        log((), {"90%_latency": m.mean() + norm.ppf((1.0 + 0.90) / 2) * m.std()})
        log((), {"95%_latency": m.mean() + norm.ppf((1.0 + 0.95) / 2) * m.std()})
        log((), {"99%_latency": m.mean() + norm.ppf((1.0 + 0.99) / 2) * m.std()})
    DLLogger.flush()
# Standard script entry point.
if __name__ == '__main__':
    main()