| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389 |
- # *****************************************************************************
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- # * Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # * Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- # * Neither the name of the NVIDIA CORPORATION nor the
- # names of its contributors may be used to endorse or promote products
- # derived from this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- #
- # *****************************************************************************
- import argparse
- import models
- import time
- import tqdm
- import sys
- import warnings
- from pathlib import Path
- import torch
- import numpy as np
- from scipy.stats import norm
- from scipy.io.wavfile import write
- from torch.nn.utils.rnn import pad_sequence
- import dllogger as DLLogger
- from apex import amp
- from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
- from common import utils
- from common.text import text_to_sequence
- from waveglow import model as glow
- from waveglow.denoiser import Denoiser
- sys.modules['glow'] = glow
def parse_args(parser):
    """
    Register the inference command-line options on ``parser``.

    General options (input/output paths, checkpoints, WaveGlow settings,
    benchmarking switches) are added directly to the parser; options that
    alter the synthesized speech are added under a ``transform`` group.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object) to
            extend in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    parser.add_argument('-i', '--input', type=str, required=True,
                        help='Full path to the input text (phrases separated by newlines)')
    parser.add_argument('-o', '--output', default=None,
                        help='Output folder to save audio (file per phrase)')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')
    parser.add_argument('--cuda', action='store_true',
                        help='Run inference on a GPU using CUDA')
    parser.add_argument('--fastpitch', type=str,
                        help='Full path to the generator checkpoint file (skip to use ground truth mels)')
    parser.add_argument('--waveglow', type=str,
                        help='Full path to the WaveGlow model checkpoint file (skip to only generate mels)')
    parser.add_argument('-s', '--sigma-infer', default=0.9, type=float,
                        help='WaveGlow sigma')
    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
                        help='WaveGlow denoising')
    parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--stft-hop-length', type=int, default=256,
                        help='STFT hop length for estimating audio length from mel size')
    parser.add_argument('--amp-run', action='store_true',
                        help='Inference with AMP')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--include-warmup', action='store_true',
                        help='Include warmup')
    parser.add_argument('--repeats', type=int, default=1,
                        help='Repeat inference for benchmarking')
    parser.add_argument('--torchscript', action='store_true',
                        help='Apply TorchScript')
    parser.add_argument('--ema', action='store_true',
                        help='Use EMA averaged model (if saved in checkpoints)')
    parser.add_argument('--dataset-path', type=str,
                        help='Path to dataset (for loading extra data fields)')

    # Options that modify the generated speech rather than the run itself.
    transform = parser.add_argument_group('transform')
    transform.add_argument('--fade-out', type=int, default=5,
                           help='Number of fadeout frames at the end')
    transform.add_argument('--pace', type=float, default=1.0,
                           help='Adjust the pace of speech')
    transform.add_argument('--pitch-transform-flatten', action='store_true',
                           help='Flatten the pitch')
    transform.add_argument('--pitch-transform-invert', action='store_true',
                           help='Invert the pitch wrt mean value')
    transform.add_argument('--pitch-transform-amplify', action='store_true',
                           help='Amplify the pitch variability')
    transform.add_argument('--pitch-transform-shift', type=float, default=0.0,
                           help='Raise/lower the pitch by <hz>')
    return parser
def load_and_setup_model(model_name, parser, checkpoint, amp_run, device,
                         unk_args=None, forward_is_infer=False, ema=True,
                         jitable=False):
    """
    Build a model, optionally restore its weights, and prepare it for inference.

    Args:
        model_name: Name passed to the ``models`` helpers; the value
            ``"WaveGlow"`` additionally triggers weight-norm removal.
        parser: Parent argument parser handed to ``models.parse_model_args``.
        checkpoint: Path of the checkpoint to load, or ``None`` to keep the
            freshly constructed weights.
        amp_run: When true, the model is converted with ``model.half()``.
        device: Target device for the returned model.
        unk_args: Optional list of unrecognized command-line options. It is
            narrowed IN PLACE to the options this model's parser did not
            recognize either, so the caller can detect truly invalid flags.
        forward_is_infer: Forwarded to ``models.get_model``.
        ema: Prefer ``ema_state_dict`` from the checkpoint when present.
        jitable: Forwarded to ``models.get_model``.

    Returns:
        The model in eval mode, moved to ``device``.
    """
    model_parser = models.parse_model_args(model_name, parser, add_help=False)
    model_args, model_unk_args = model_parser.parse_known_args()
    if unk_args is not None:
        # Slice-assign so the caller's list object is updated.
        unk_args[:] = list(set(unk_args) & set(model_unk_args))

    model_config = models.get_model_config(model_name, model_args)
    model = models.get_model(model_name, model_config, device,
                             forward_is_infer=forward_is_infer,
                             jitable=jitable)

    if checkpoint is not None:
        # Map to CPU first so checkpoints saved on a GPU can be restored on a
        # CPU-only host; the model is moved to `device` before returning.
        checkpoint_data = torch.load(checkpoint, map_location='cpu')
        status = ''

        if 'state_dict' in checkpoint_data:
            sd = checkpoint_data['state_dict']
            if ema and 'ema_state_dict' in checkpoint_data:
                sd = checkpoint_data['ema_state_dict']
                status += ' (EMA)'
            elif ema:
                print(f'WARNING: EMA weights missing for {model_name}')

            # Strip only a LEADING 'module.' wrapper prefix; a plain
            # str.replace would also damage keys that contain 'module.'
            # somewhere in the middle.
            prefix = 'module.'
            if any(key.startswith(prefix) for key in sd):
                sd = {(k[len(prefix):] if k.startswith(prefix) else k): v
                      for k, v in sd.items()}
            # strict=False: key mismatches are reported in the status string
            # instead of raising.
            status += ' ' + str(model.load_state_dict(sd, strict=False))
        else:
            # Checkpoint stores the whole model object under 'model'.
            model = checkpoint_data['model']
        print(f'Loaded {model_name}{status}')

    if model_name == "WaveGlow":
        model = model.remove_weightnorm(model)

    if amp_run:
        model.half()
    model.eval()
    return model.to(device)
def load_fields(fpath):
    """
    Read the input file into a mapping of column name to values.

    A file whose name ends in ``.tsv`` is treated as tab-separated with a
    header row: the result maps each header name to a tuple holding that
    column's values. Any other file is treated as plain text and the result is
    ``{'text': [line, ...]}``. Every line is stripped of surrounding
    whitespace before it is split or stored.

    Args:
        fpath: Path of the file to read (string or path-like), UTF-8 encoded.

    Returns:
        dict mapping column names to their values.
    """
    # Context manager guarantees the handle is closed even if reading fails.
    with open(fpath, encoding='utf-8') as f:
        lines = [line.strip() for line in f]

    if str(fpath).endswith('.tsv'):
        columns = lines[0].split('\t')
        # Transpose rows into columns.
        fields = list(zip(*[row.split('\t') for row in lines[1:]]))
    else:
        columns = ['text']
        fields = [lines]
    return dict(zip(columns, fields))
def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
                           load_mels=False, load_pitch=False):
    """
    Encode, sort, batch and pad the loaded input fields.

    ``fields`` is modified in place: texts are converted to ``LongTensor``
    symbol sequences and reordered longest-first, ``text_lens`` is added, and
    (when requested) ``mel`` / ``pitch`` entries are replaced by tensors
    loaded with ``torch.load`` from paths relative to ``dataset``, together
    with ``mel_lens`` / ``pitch_lens``. An ``output`` column, if present, is
    reordered the same way.

    Args:
        fields: Mapping of column name to values; must contain ``'text'``.
        device: Device onto which tensor entries of each batch are moved.
        batch_size: Maximum number of utterances per batch.
        dataset: Root directory for mel / pitch files.
        load_mels: Load ground-truth mels (requires a ``'mel'`` column).
        load_pitch: Load pitch tensors (requires a ``'pitch'`` column).

    Returns:
        List of dicts, one per batch, with padded tensors for text (and
        mel / pitch when loaded) and plain slices for every other field.
    """
    encoded = [torch.LongTensor(text_to_sequence(t, ['english_cleaners']))
               for t in fields['text']]
    fields['text'] = encoded

    # Longest utterance first.
    order = np.argsort([-seq.size(0) for seq in encoded])

    fields['text'] = [encoded[idx] for idx in order]
    fields['text_lens'] = torch.LongTensor(
        [seq.size(0) for seq in fields['text']])

    if load_mels:
        assert 'mel' in fields
        mels = []
        for idx in order:
            mels.append(torch.load(Path(dataset, fields['mel'][idx])).t())
        fields['mel'] = mels
        fields['mel_lens'] = torch.LongTensor([m.size(0) for m in mels])

    if load_pitch:
        assert 'pitch' in fields
        pitches = []
        for idx in order:
            pitches.append(torch.load(Path(dataset, fields['pitch'][idx])))
        fields['pitch'] = pitches
        fields['pitch_lens'] = torch.LongTensor(
            [p.size(0) for p in pitches])

    if 'output' in fields:
        fields['output'] = [fields['output'][idx] for idx in order]

    def collate(name, chunk):
        # Pad the sequence-valued fields; everything else passes through.
        if name == 'text' or (name == 'pitch' and load_pitch):
            chunk = pad_sequence(chunk, batch_first=True)
        elif name == 'mel' and load_mels:
            # Swap the last two axes after padding.
            chunk = pad_sequence(chunk, batch_first=True).permute(0, 2, 1)
        if type(chunk) is torch.Tensor:
            chunk = chunk.to(device)
        return chunk

    # cut into batches & pad
    batches = []
    for start in range(0, len(order), batch_size):
        stop = start + batch_size
        batches.append({name: collate(name, values[start:stop])
                        for name, values in fields.items()})
    return batches
def build_pitch_transformation(args):
    """
    Build the pitch post-processing function selected on the command line.

    The enabled steps are applied in a fixed order: flatten (``* 0.0``),
    invert (``* -1.0``), amplify (``* 2.0``) and finally shift
    (``+ hz / std``). With no option enabled the pitch is returned unchanged.

    Args:
        args: Namespace providing ``pitch_transform_flatten``,
            ``pitch_transform_invert``, ``pitch_transform_amplify`` and
            ``pitch_transform_shift``.

    Returns:
        A callable ``f(pitch, mean, std)`` returning the transformed pitch.
        ``mean`` is accepted but not used by any of the steps.
    """
    # Snapshot the options so the returned function does not depend on later
    # changes to `args`.
    flatten = args.pitch_transform_flatten
    invert = args.pitch_transform_invert
    amplify = args.pitch_transform_amplify
    hz = args.pitch_transform_shift

    # A plain closure replaces the former eval() of a formatted source
    # string, which failed for shifts such as inf/nan whose repr is not a
    # Python literal.
    def transform(pitch, mean, std):
        if flatten:
            pitch = pitch * 0.0
        if invert:
            pitch = pitch * -1.0
        if amplify:
            pitch = pitch * 2.0
        if hz != 0.0:
            pitch = pitch + hz / std
        return pitch

    return transform
class MeasureTime(list):
    """
    List of elapsed wall-clock durations that doubles as a context manager.

    Each ``with`` block appends the seconds spent inside it, measured with
    ``time.perf_counter``. When CUDA synchronization is enabled,
    ``torch.cuda.synchronize()`` is called on entry and exit.

    Args:
        *args: Forwarded to ``list`` (e.g. an iterable of initial values).
        cuda: Whether to synchronize CUDA around the timed region. ``None``
            (default) enables it only when ``torch.cuda.is_available()``, so
            timing also works on CPU-only hosts.
    """

    def __init__(self, *args, cuda=None):
        super().__init__(*args)
        self.cuda = torch.cuda.is_available() if cuda is None else cuda

    def _sync(self):
        # Synchronizing without a usable CUDA runtime raises, hence the guard.
        if self.cuda:
            torch.cuda.synchronize()

    def __enter__(self):
        self._sync()
        self.t0 = time.perf_counter()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self._sync()
        self.append(time.perf_counter() - self.t0)

    def __add__(self, other):
        """Return the element-wise sum of two equally long measurement lists."""
        assert len(self) == len(other)
        return MeasureTime((sum(ab) for ab in zip(self, other)),
                           cuda=self.cuda)
def main():
    """
    Launches text to speech (inference).

    Parses the command line, loads the requested FastPitch generator and/or
    WaveGlow vocoder, synthesizes every batch of the input file, optionally
    writes one WAV file per utterance, and logs throughput/latency figures
    through DLLogger. Inference runs on a single device: CUDA when ``--cuda``
    is given, CPU otherwise.

    Raises:
        ValueError: If the command line contains options that neither this
            script nor any loaded model recognizes.
    """
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()

    DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, args.log_file),
                            StdOutBackend(Verbosity.VERBOSE)])
    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'FastPitch_PyT'})

    if args.output is not None:
        Path(args.output).mkdir(parents=False, exist_ok=True)

    device = torch.device('cuda' if args.cuda else 'cpu')

    if args.fastpitch is not None:
        generator = load_and_setup_model(
            'FastPitch', parser, args.fastpitch, args.amp_run, device,
            unk_args=unk_args, forward_is_infer=True, ema=args.ema,
            jitable=args.torchscript)
        if args.torchscript:
            generator = torch.jit.script(generator)
    else:
        generator = None

    if args.waveglow is not None:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            waveglow = load_and_setup_model(
                'WaveGlow', parser, args.waveglow, args.amp_run, device,
                unk_args=unk_args, forward_is_infer=True, ema=args.ema)
        denoiser = Denoiser(waveglow).to(device)
        # Call the model's `infer` method when it has one.
        waveglow = getattr(waveglow, 'infer', waveglow)
    else:
        waveglow = None

    # Each model load narrows unk_args; whatever remains is unknown to all.
    if unk_args:
        raise ValueError(f'Invalid options {unk_args}')

    fields = load_fields(args.input)
    batches = prepare_input_sequence(
        fields, device, args.batch_size, args.dataset_path,
        load_mels=(generator is None))

    if args.include_warmup:
        # Use real data rather than synthetic - FastPitch predicts len
        for _ in range(3):
            with torch.no_grad():
                b = batches[0]
                if generator is not None:
                    mel, *_ = generator(b['text'], b['text_lens'])
                else:
                    # No generator loaded: warm up the vocoder on the
                    # ground-truth mels (previously `mel` was undefined here).
                    mel = b['mel']
                if waveglow is not None:
                    audios = waveglow(mel, sigma=args.sigma_infer).float()
                    _ = denoiser(audios, strength=args.denoising_strength)

    gen_measures = MeasureTime()
    waveglow_measures = MeasureTime()

    gen_kw = {'pace': args.pace,
              'pitch_tgt': None,
              'pitch_transform': build_pitch_transformation(args)}

    if args.torchscript:
        gen_kw.pop('pitch_transform')

    all_utterances = 0
    all_samples = 0
    all_letters = 0
    all_frames = 0

    reps = args.repeats
    # Per-batch logging is only enabled for a single pass; it is switched on
    # again below for the summary statistics.
    log_enabled = reps == 1

    def log(step, data):
        if log_enabled:
            DLLogger.log(step=step, data=data)

    for repeat in (tqdm.tqdm(range(reps)) if reps > 1 else range(reps)):
        for b in batches:
            if generator is None:
                # DLLogger payloads are dicts; the former set literal was not.
                log(0, {'info': 'Synthesizing from ground truth mels'})
                mel, mel_lens = b['mel'], b['mel_lens']
            else:
                with torch.no_grad(), gen_measures:
                    mel, mel_lens, *_ = generator(
                        b['text'], b['text_lens'], **gen_kw)

                gen_infer_perf = mel.size(0) * mel.size(2) / gen_measures[-1]
                all_letters += b['text_lens'].sum().item()
                all_frames += mel.size(0) * mel.size(2)
                log(0, {"generator_frames_per_sec": gen_infer_perf})
                log(0, {"generator_latency": gen_measures[-1]})

            if waveglow is not None:
                with torch.no_grad(), waveglow_measures:
                    audios = waveglow(mel, sigma=args.sigma_infer)
                    audios = denoiser(audios.float(),
                                      strength=args.denoising_strength
                                      ).squeeze(1)

                # Index of this batch's first utterance within the whole run,
                # used to give default output files unique names.
                first_utt = all_utterances
                all_utterances += len(audios)
                all_samples += sum(audio.size(0) for audio in audios)
                waveglow_infer_perf = (
                    audios.size(0) * audios.size(1) / waveglow_measures[-1])

                log(0, {"waveglow_samples_per_sec": waveglow_infer_perf})
                log(0, {"waveglow_latency": waveglow_measures[-1]})

                if args.output is not None and reps == 1:
                    for i, audio in enumerate(audios):
                        # Trim padding using the mel length of the utterance.
                        audio = audio[:mel_lens[i].item() * args.stft_hop_length]

                        if args.fade_out:
                            fade_len = args.fade_out * args.stft_hop_length
                            fade_w = torch.linspace(1.0, 0.0, fade_len)
                            audio[-fade_len:] *= fade_w.to(audio.device)

                        # Peak-normalize.
                        audio = audio / torch.max(torch.abs(audio))
                        if 'output' in b:
                            fname = b['output'][i]
                        else:
                            # Run-wide index: a per-batch index made later
                            # batches overwrite earlier files.
                            fname = f'audio_{first_utt + i}.wav'
                        audio_path = Path(args.output, fname)
                        write(audio_path, args.sampling_rate, audio.cpu().numpy())

            if generator is not None and waveglow is not None:
                log(0, {"latency": (gen_measures[-1] + waveglow_measures[-1])})

    log_enabled = True
    # Confidence levels reported as mean + z * std.
    levels = (('90%', 0.90), ('95%', 0.95), ('99%', 0.99))

    if generator is not None:
        # Kept in measurement order so entries stay paired with waveglow's.
        gm = np.asarray(gen_measures)
        log('avg', {"generator letters/s": all_letters / gm.sum()})
        log('avg', {"generator_frames/s": all_frames / gm.sum()})
        log('avg', {"generator_latency": gm.mean()})
        for label, level in levels:
            log(label, {"generator_latency":
                        gm.mean() + norm.ppf((1.0 + level) / 2) * gm.std()})

    if waveglow is not None:
        wm = np.asarray(waveglow_measures)
        log('avg', {"waveglow_samples/s": all_samples / wm.sum()})
        log('avg', {"waveglow_latency": wm.mean()})
        for label, level in levels:
            log(label, {"waveglow_latency":
                        wm.mean() + norm.ppf((1.0 + level) / 2) * wm.std()})

    if generator is not None and waveglow is not None:
        # Per-batch end-to-end latency. Adding two independently sorted
        # arrays (as before) mispaired batches and distorted the std.
        m = gm + wm
        # NOTE(review): RTF formula kept as in the original - confirm the
        # intended normalisation by len(batches) * all_utterances.
        rtf = all_samples / (len(batches) * all_utterances * m.mean() * args.sampling_rate)
        log('avg', {"samples/s": all_samples / m.sum()})
        log('avg', {"letters/s": all_letters / m.sum()})
        log('avg', {"latency": m.mean()})
        log('avg', {"RTF": rtf})
        for label, level in levels:
            log(label, {"latency":
                        m.mean() + norm.ppf((1.0 + level) / 2) * m.std()})

    DLLogger.flush()


if __name__ == '__main__':
    main()
|