|
|
@@ -42,6 +42,8 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
|
|
|
|
|
|
from apex import amp
|
|
|
|
|
|
+from waveglow.denoiser import Denoiser
|
|
|
+
|
|
|
def parse_args(parser):
|
|
|
"""
|
|
|
Parse commandline arguments.
|
|
|
@@ -51,6 +53,7 @@ def parse_args(parser):
|
|
|
parser.add_argument('--waveglow', type=str,
|
|
|
help='full path to the WaveGlow model checkpoint file')
|
|
|
parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
|
|
|
+ parser.add_argument('-d', '--denoising-strength', default=0.01, type=float)
|
|
|
parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
|
|
|
help='Sampling rate')
|
|
|
parser.add_argument('--amp-run', action='store_true',
|
|
|
@@ -65,23 +68,24 @@ def parse_args(parser):
|
|
|
help='Input length')
|
|
|
parser.add_argument('-bs', '--batch-size', type=int, default=1,
|
|
|
help='Batch size')
|
|
|
-
|
|
|
-
|
|
|
+ parser.add_argument('--cpu-run', action='store_true',
|
|
|
+ help='Run inference on CPU')
|
|
|
return parser
|
|
|
|
|
|
|
|
|
-def load_and_setup_model(model_name, parser, checkpoint, amp_run, to_cuda=True):
|
|
|
+def load_and_setup_model(model_name, parser, checkpoint, amp_run, cpu_run, forward_is_infer=False):
|
|
|
model_parser = models.parse_model_args(model_name, parser, add_help=False)
|
|
|
model_args, _ = model_parser.parse_known_args()
|
|
|
|
|
|
model_config = models.get_model_config(model_name, model_args)
|
|
|
- model = models.get_model(model_name, model_config, to_cuda=to_cuda)
|
|
|
+ model = models.get_model(model_name, model_config, cpu_run, forward_is_infer=forward_is_infer)
|
|
|
|
|
|
if checkpoint is not None:
|
|
|
- if to_cuda:
|
|
|
- state_dict = torch.load(checkpoint)['state_dict']
|
|
|
+ if cpu_run:
|
|
|
+ state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
|
|
|
else:
|
|
|
- state_dict = torch.load(checkpoint,map_location='cpu')['state_dict']
|
|
|
+ state_dict = torch.load(checkpoint)['state_dict']
|
|
|
+
|
|
|
if checkpoint_from_distributed(state_dict):
|
|
|
state_dict = unwrap_distributed(state_dict)
|
|
|
|
|
|
@@ -141,7 +145,7 @@ def print_stats(measurements_all):
|
|
|
def main():
|
|
|
"""
|
|
|
Launches text to speech (inference).
|
|
|
- Inference is executed on a single GPU.
|
|
|
+ Inference is executed on a single GPU or CPU.
|
|
|
"""
|
|
|
parser = argparse.ArgumentParser(
|
|
|
description='PyTorch Tacotron 2 Inference')
|
|
|
@@ -168,8 +172,15 @@ def main():
|
|
|
|
|
|
print("args:", args, unknown_args)
|
|
|
|
|
|
- tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run)
|
|
|
- waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run)
|
|
|
+ tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run, args.cpu_run, forward_is_infer=True)
|
|
|
+ waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run, args.cpu_run)
|
|
|
+
|
|
|
+ if args.cpu_run:
|
|
|
+ denoiser = Denoiser(waveglow, args.cpu_run)
|
|
|
+ else:
|
|
|
+ denoiser = Denoiser(waveglow, args.cpu_run).cuda()
|
|
|
+
|
|
|
+ jitted_tacotron2 = torch.jit.script(tacotron2)
|
|
|
|
|
|
texts = ["The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."]
|
|
|
texts = [texts[0][:args.input_length]]
|
|
|
@@ -181,27 +192,29 @@ def main():
|
|
|
|
|
|
measurements = {}
|
|
|
|
|
|
- with MeasureTime(measurements, "pre_processing"):
|
|
|
+ with MeasureTime(measurements, "pre_processing", args.cpu_run):
|
|
|
sequences_padded, input_lengths = prepare_input_sequence(texts)
|
|
|
|
|
|
with torch.no_grad():
|
|
|
- with MeasureTime(measurements, "latency"):
|
|
|
- with MeasureTime(measurements, "tacotron2_latency"):
|
|
|
- mel, mel_lengths, _ = tacotron2.infer(sequences_padded, input_lengths)
|
|
|
+ with MeasureTime(measurements, "latency", args.cpu_run):
|
|
|
+ with MeasureTime(measurements, "tacotron2_latency", args.cpu_run):
|
|
|
+ mel, mel_lengths, _ = jitted_tacotron2(sequences_padded, input_lengths)
|
|
|
|
|
|
- with MeasureTime(measurements, "waveglow_latency"):
|
|
|
+ with MeasureTime(measurements, "waveglow_latency", args.cpu_run):
|
|
|
audios = waveglow.infer(mel, sigma=args.sigma_infer)
|
|
|
+ audios = audios.float()
|
|
|
+ audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)
|
|
|
|
|
|
num_mels = mel.size(0)*mel.size(2)
|
|
|
num_samples = audios.size(0)*audios.size(1)
|
|
|
|
|
|
- with MeasureTime(measurements, "type_conversion"):
|
|
|
+ with MeasureTime(measurements, "type_conversion", args.cpu_run):
|
|
|
audios = audios.float()
|
|
|
|
|
|
- with MeasureTime(measurements, "data_transfer"):
|
|
|
+ with MeasureTime(measurements, "data_transfer", args.cpu_run):
|
|
|
audios = audios.cpu()
|
|
|
|
|
|
- with MeasureTime(measurements, "storage"):
|
|
|
+ with MeasureTime(measurements, "storage", args.cpu_run):
|
|
|
audios = audios.numpy()
|
|
|
for i, audio in enumerate(audios):
|
|
|
audio_path = "audio_"+str(i)+".wav"
|