# ***************************************************************************** # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of the NVIDIA CORPORATION nor the # names of its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# *****************************************************************************

import argparse
import models
import time
import tqdm
import sys
import warnings
from pathlib import Path

import torch
import numpy as np
from scipy.stats import norm
from scipy.io.wavfile import write
from torch.nn.utils.rnn import pad_sequence

import dllogger as DLLogger
from apex import amp
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity

from common import utils
from common.text import text_to_sequence
from waveglow import model as glow
from waveglow.denoiser import Denoiser

# NOTE(review): presumably lets pickled WaveGlow checkpoints that reference a
# top-level `glow` module be unpickled - confirm against the checkpoints used.
sys.modules['glow'] = glow

# (log step label, two-sided confidence level) pairs used for latency bounds.
_LATENCY_LEVELS = (('90%', 0.90), ('95%', 0.95), ('99%', 0.99))


def parse_args(parser):
    """
    Parse commandline arguments.

    Registers the inference options (plus a 'transform' argument group) on
    `parser` and returns the same parser; no parsing happens here.
    """
    parser.add_argument('-i', '--input', type=str, required=True,
                        help='Full path to the input text (phareses separated by newlines)')
    parser.add_argument('-o', '--output', default=None,
                        help='Output folder to save audio (file per phrase)')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')
    parser.add_argument('--cuda', action='store_true',
                        help='Run inference on a GPU using CUDA')
    parser.add_argument('--fastpitch', type=str,
                        help='Full path to the generator checkpoint file (skip to use ground truth mels)')
    parser.add_argument('--waveglow', type=str,
                        help='Full path to the WaveGlow model checkpoint file (skip to only generate mels)')
    parser.add_argument('-s', '--sigma-infer', default=0.9, type=float,
                        help='WaveGlow sigma')
    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
                        help='WaveGlow denoising')
    parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--stft-hop-length', type=int, default=256,
                        help='STFT hop length for estimating audio length from mel size')
    parser.add_argument('--amp-run', action='store_true',
                        help='Inference with AMP')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--include-warmup', action='store_true',
                        help='Include warmup')
    parser.add_argument('--repeats', type=int, default=1,
                        help='Repeat inference for benchmarking')
    parser.add_argument('--torchscript', action='store_true',
                        help='Apply TorchScript')
    parser.add_argument('--ema', action='store_true',
                        help='Use EMA averaged model (if saved in checkpoints)')
    parser.add_argument('--dataset-path', type=str,
                        help='Path to dataset (for loading extra data fields)')

    transform = parser.add_argument_group('transform')
    transform.add_argument('--fade-out', type=int, default=5,
                           help='Number of fadeout frames at the end')
    transform.add_argument('--pace', type=float, default=1.0,
                           help='Adjust the pace of speech')
    transform.add_argument('--pitch-transform-flatten', action='store_true',
                           help='Flatten the pitch')
    transform.add_argument('--pitch-transform-invert', action='store_true',
                           help='Invert the pitch wrt mean value')
    transform.add_argument('--pitch-transform-amplify', action='store_true',
                           help='Amplify the pitch variability')
    # The help text was cut off after "by"; the value is used as `hz` in
    # build_pitch_transformation().
    transform.add_argument('--pitch-transform-shift', type=float, default=0.0,
                           help='Raise/lower the pitch by <hz>')
    return parser


def load_and_setup_model(model_name, parser, checkpoint, amp_run, device,
                         unk_args=None, forward_is_infer=False, ema=True,
                         jitable=False):
    """
    Build a model, optionally restore it from a checkpoint, and prepare it
    for inference.

    Args:
        model_name: name passed to the `models` factory helpers; 'WaveGlow'
            additionally gets its weight norm removed.
        parser: argparse parser extended with the model-specific options.
        checkpoint: path of the checkpoint file, or None to keep the freshly
            constructed weights.
        amp_run: when true the model is converted with `.half()`.
        device: device the returned model is moved to.
        unk_args: list of still-unrecognised command line options. It is
            narrowed IN PLACE to the options this model's parser does not
            recognise either, so the caller can report what is left over.
        forward_is_infer, jitable: forwarded to `models.get_model`.
        ema: prefer 'ema_state_dict' over 'state_dict' when it is present.

    Returns:
        The model in eval mode on `device`.
    """
    if unk_args is None:
        # A fresh list per call instead of a shared mutable default.
        unk_args = []

    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, model_unk_args = model_parser.parse_known_args()
    # Slice-assign so the caller's list object is the one that gets updated.
    unk_args[:] = list(set(unk_args) & set(model_unk_args))

    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, device,
                             forward_is_infer=forward_is_infer,
                             jitable=jitable)

    if checkpoint is not None:
        # Deserialize onto the CPU so checkpoints saved on a GPU also load on
        # hosts without CUDA; the model is moved to `device` before returning.
        checkpoint_data = torch.load(checkpoint, map_location='cpu')
        status = ''

        if 'state_dict' in checkpoint_data:
            sd = checkpoint_data['state_dict']
            if ema and 'ema_state_dict' in checkpoint_data:
                sd = checkpoint_data['ema_state_dict']
                status += ' (EMA)'
            elif ema:
                print(f'WARNING: EMA weights missing for {model_name}')

            # presumably the 'module.' prefix comes from a (Distributed)
            # DataParallel wrapper used at training time - TODO confirm
            if any(key.startswith('module.') for key in sd):
                sd = {k.replace('module.', ''): v for k, v in sd.items()}
            status += ' ' + str(model.load_state_dict(sd, strict=False))
        else:
            # The checkpoint stores the whole model object under 'model'.
            model = checkpoint_data['model']
        print(f'Loaded {model_name}{status}')

    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)

    if amp_run:
        model.half()
    model.eval()
    return model.to(device)


def load_fields(fpath):
    """
    Read the input phrases (and optional extra columns) from a file.

    A path ending in '.tsv' is treated as tab-separated values with a header
    row naming the columns; any other file yields a single 'text' column with
    one entry per line. Lines are stripped and blank lines are skipped: in a
    TSV a blank row would otherwise make `zip(*rows)` truncate every column.

    Returns:
        dict mapping column name to the sequence of its values.
    """
    with open(fpath, encoding='utf-8') as f:
        lines = [stripped for stripped in (line.strip() for line in f)
                 if stripped]

    if str(fpath).endswith('.tsv'):
        columns = lines[0].split('\t')
        fields = list(zip(*[t.split('\t') for t in lines[1:]]))
    else:
        columns = ['text']
        fields = [lines]
    return {c: f for c, f in zip(columns, fields)}


def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
                           load_mels=False, load_pitch=False):
    """
    Encode the phrases, sort them by decreasing length and cut into batches.

    NOTE: `fields` is modified in place ('text' is replaced by tensors, the
    '*_lens' entries are added, and reordered columns are overwritten).

    Args:
        fields: dict as returned by `load_fields`; must contain 'text'.
        device: device the tensor entries of each batch are moved to.
        batch_size: maximum number of utterances per batch.
        dataset: directory that 'mel'/'pitch' file names are joined to; when
            None the names are used as given.
        load_mels: load the tensors named in the 'mel' column with torch.load.
        load_pitch: load the tensors named in the 'pitch' column.

    Returns:
        list of dicts, one per batch, holding a slice of every field.
    """
    fields['text'] = [torch.LongTensor(text_to_sequence(t, ['english_cleaners']))
                      for t in fields['text']]
    # Longest first, so each batch groups sequences of similar length.
    order = np.argsort([-t.size(0) for t in fields['text']])

    fields['text'] = [fields['text'][i] for i in order]
    fields['text_lens'] = torch.LongTensor([t.size(0) for t in fields['text']])

    # Path(None, ...) raises TypeError, so fall back to an empty root.
    root = Path(dataset) if dataset is not None else Path()

    if load_mels:
        assert 'mel' in fields, "input has no 'mel' column"
        # assumes stored mels are (n_mels, frames); .t() puts frames first so
        # pad_sequence pads along time - TODO confirm
        fields['mel'] = [
            torch.load(Path(root, fields['mel'][i])).t() for i in order]
        fields['mel_lens'] = torch.LongTensor([t.size(0) for t in fields['mel']])

    if load_pitch:
        assert 'pitch' in fields, "input has no 'pitch' column"
        fields['pitch'] = [
            torch.load(Path(root, fields['pitch'][i])) for i in order]
        fields['pitch_lens'] = torch.LongTensor([t.size(0) for t in fields['pitch']])

    if 'output' in fields:
        fields['output'] = [fields['output'][i] for i in order]

    # cut into batches & pad
    batches = []
    for b in range(0, len(order), batch_size):
        batch = {f: values[b:b+batch_size] for f, values in fields.items()}
        for f in batch:
            if f == 'text':
                batch[f] = pad_sequence(batch[f], batch_first=True)
            elif f == 'mel' and load_mels:
                # Back to (batch, n_mels, frames) after padding.
                batch[f] = pad_sequence(batch[f], batch_first=True).permute(0, 2, 1)
            elif f == 'pitch' and load_pitch:
                batch[f] = pad_sequence(batch[f], batch_first=True)

            if isinstance(batch[f], torch.Tensor):
                batch[f] = batch[f].to(device)
        batches.append(batch)

    return batches


def build_pitch_transformation(args):
    """
    Build the pitch post-processing function selected on the command line.

    The enabled steps are applied in a fixed order: flatten (* 0.0), invert
    (* -1.0), amplify (* 2.0), then shift (+ shift / std).

    Returns:
        A callable `(pitch, mean, std) -> pitch`. `mean` is accepted but not
        used by any of the steps.
    """
    flatten = args.pitch_transform_flatten
    invert = args.pitch_transform_invert
    amplify = args.pitch_transform_amplify
    hz = args.pitch_transform_shift

    # A plain closure instead of eval() on generated source: same arithmetic
    # in the same order, without executing a string built from arguments.
    def pitch_transform(pitch, mean, std):
        if flatten:
            pitch = pitch * 0.0
        if invert:
            pitch = pitch * -1.0
        if amplify:
            pitch = pitch * 2.0
        if hz != 0.0:
            # presumably pitch is expressed in units of std, hence the
            # division - TODO confirm against the generator
            pitch = pitch + hz / std
        return pitch

    return pitch_transform


class MeasureTime(list):
    """
    List of elapsed times (seconds, `time.perf_counter`) that doubles as a
    context manager: every `with` block appends its duration.
    """

    @staticmethod
    def _sync():
        # Wait for pending GPU work so it is attributed to the right block;
        # torch.cuda.synchronize() raises on hosts without CUDA, so skip it.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __enter__(self):
        self._sync()
        self.t0 = time.perf_counter()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self._sync()
        self.append(time.perf_counter() - self.t0)

    def __add__(self, other):
        """Return the element-wise sum of two equally long measurements."""
        assert len(self) == len(other)
        return MeasureTime(sum(ab) for ab in zip(self, other))


def _log_latency_bounds(log, key, measures):
    """Log mean + z * std of `measures` under `key` for each confidence level."""
    mean, std = measures.mean(), measures.std()
    for step, level in _LATENCY_LEVELS:
        # norm.ppf((1 + p) / 2) is the z-score of a two-sided interval at p.
        log(step, {key: mean + norm.ppf((1.0 + level) / 2) * std})


def main():
    """
    Launches text to speech (inference).

    Inference runs on a single device: CUDA when --cuda is given, otherwise
    the CPU. Depending on which checkpoints are supplied the script runs the
    mel generator, the WaveGlow vocoder, or both, logs per-batch and
    aggregated throughput/latency, and optionally writes one wav per phrase.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()

    DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, args.log_file),
                            StdOutBackend(Verbosity.VERBOSE)])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'FastPitch_PyT'})

    if args.output is not None:
        Path(args.output).mkdir(parents=False, exist_ok=True)

    device = torch.device('cuda' if args.cuda else 'cpu')

    if args.fastpitch is not None:
        generator = load_and_setup_model(
            'FastPitch', parser, args.fastpitch, args.amp_run, device,
            unk_args=unk_args, forward_is_infer=True, ema=args.ema,
            jitable=args.torchscript)

        if args.torchscript:
            generator = torch.jit.script(generator)
    else:
        generator = None

    if args.waveglow is not None:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            waveglow = load_and_setup_model(
                'WaveGlow', parser, args.waveglow, args.amp_run, device,
                unk_args=unk_args, forward_is_infer=True, ema=args.ema)
        denoiser = Denoiser(waveglow).to(device)
        # Call the model's `infer` method when it has one.
        waveglow = getattr(waveglow, 'infer', waveglow)
    else:
        waveglow = None

    # Each load_and_setup_model() call narrowed unk_args to the options that
    # model did not recognise either; whatever is left is invalid.
    if unk_args:
        raise ValueError(f'Invalid options {unk_args}')

    fields = load_fields(args.input)
    batches = prepare_input_sequence(
        fields, device, args.batch_size, args.dataset_path,
        load_mels=(generator is None))

    if not batches:
        # Otherwise warmup fails with IndexError and the stats are all NaN.
        raise ValueError(f'No input phrases found in {args.input}')

    if args.include_warmup:
        # Use real data rather than synthetic - FastPitch predicts len
        b = batches[0]
        for _ in range(3):
            with torch.no_grad():
                if generator is not None:
                    mel, *_ = generator(b['text'], b['text_lens'])
                else:
                    # Ground-truth mels: there is no generator to produce
                    # `mel`, so warm the vocoder up on the loaded ones.
                    mel = b['mel']
                if waveglow is not None:
                    audios = waveglow(mel, sigma=args.sigma_infer).float()
                    denoiser(audios, strength=args.denoising_strength)

    gen_measures = MeasureTime()
    waveglow_measures = MeasureTime()

    gen_kw = {'pace': args.pace,
              'pitch_tgt': None,
              'pitch_transform': build_pitch_transformation(args)}

    if args.torchscript:
        gen_kw.pop('pitch_transform')

    all_utterances = 0
    all_samples = 0
    all_letters = 0
    all_frames = 0

    reps = args.repeats
    # Per-batch logging is silenced while benchmarking with several repeats;
    # it is switched back on below for the aggregated statistics.
    log_enabled = reps == 1

    def log(step, data):
        if log_enabled:
            DLLogger.log(step=step, data=data)

    for repeat in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
        for b in batches:
            if generator is None:
                # DLLogger data must be a mapping; the original passed a set.
                log(0, {'message': 'Synthesizing from ground truth mels'})
                mel, mel_lens = b['mel'], b['mel_lens']
            else:
                with torch.no_grad(), gen_measures:
                    mel, mel_lens, *_ = generator(
                        b['text'], b['text_lens'], **gen_kw)

                gen_infer_perf = mel.size(0) * mel.size(2) / gen_measures[-1]
                all_letters += b['text_lens'].sum().item()
                all_frames += mel.size(0) * mel.size(2)
                log(0, {"generator_frames_per_sec": gen_infer_perf})
                log(0, {"generator_latency": gen_measures[-1]})

            if waveglow is not None:
                with torch.no_grad(), waveglow_measures:
                    audios = waveglow(mel, sigma=args.sigma_infer)
                    audios = denoiser(audios.float(),
                                      strength=args.denoising_strength
                                      ).squeeze(1)

                # Index of this batch's first utterance among all processed
                # ones; keeps default file names unique across batches.
                first_utterance = all_utterances

                all_utterances += len(audios)
                all_samples += sum(audio.size(0) for audio in audios)
                waveglow_infer_perf = (
                    audios.size(0) * audios.size(1) / waveglow_measures[-1])

                log(0, {"waveglow_samples_per_sec": waveglow_infer_perf})
                log(0, {"waveglow_latency": waveglow_measures[-1]})

                if args.output is not None and reps == 1:
                    for i, audio in enumerate(audios):
                        # Trim the padding using the mel length in frames.
                        audio = audio[:mel_lens[i].item() * args.stft_hop_length]

                        if args.fade_out:
                            fade_len = args.fade_out * args.stft_hop_length
                            fade_w = torch.linspace(1.0, 0.0, fade_len)
                            tail = audio[-fade_len:]
                            # A clip shorter than the fade window gets only
                            # the final part of the ramp instead of failing
                            # on a shape mismatch.
                            tail *= fade_w[fade_len - tail.size(0):].to(audio.device)

                        audio = audio/torch.max(torch.abs(audio))
                        fname = (b['output'][i] if 'output' in b
                                 else f'audio_{first_utterance + i}.wav')
                        audio_path = Path(args.output, fname)
                        write(audio_path, args.sampling_rate, audio.cpu().numpy())

            if generator is not None and waveglow is not None:
                log(0, {"latency": (gen_measures[-1] + waveglow_measures[-1])})

    log_enabled = True
    if generator is not None:
        gm = np.asarray(gen_measures)
        log('avg', {"generator letters/s": all_letters / gm.sum()})
        log('avg', {"generator_frames/s": all_frames / gm.sum()})
        log('avg', {"generator_latency": gm.mean()})
        _log_latency_bounds(log, "generator_latency", gm)

    if waveglow is not None:
        wm = np.asarray(waveglow_measures)
        log('avg', {"waveglow_samples/s": all_samples / wm.sum()})
        log('avg', {"waveglow_latency": wm.mean()})
        _log_latency_bounds(log, "waveglow_latency", wm)

    if generator is not None and waveglow is not None:
        # Add the two stages per batch, in measurement order. Summing two
        # independently sorted arrays would pair unrelated batches and
        # distort the standard deviation of the end-to-end latency.
        m = gm + wm
        rtf = all_samples / (len(batches) * all_utterances * m.mean() * args.sampling_rate)
        log('avg', {"samples/s": all_samples / m.sum()})
        log('avg', {"letters/s": all_letters / m.sum()})
        log('avg', {"latency": m.mean()})
        log('avg', {"RTF": rtf})
        _log_latency_bounds(log, "latency", m)
    DLLogger.flush()


if __name__ == '__main__':
    main()