
Merge pull request #482 from maggiezha/cpu-run

Cpu run
GrzegorzKarchNV 5 years ago
Parent
Commit
67a7d9c4eb

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/Dockerfile

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.01-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.03-py3
 FROM ${FROM_IMAGE_NAME}
 
 ADD . /workspace/tacotron2

+ 30 - 4
PyTorch/SpeechSynthesis/Tacotron2/README.md

@@ -231,7 +231,7 @@ and encapsulates some dependencies. Aside from these dependencies, ensure you
 have the following components:
 
 * [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
-* [PyTorch 20.01-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
+* [PyTorch 20.03-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
 or newer
 * [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU
 
@@ -320,12 +320,16 @@ Ensure your loss values are comparable to those listed in the table in the
 7. Start inference.
 After you have trained the Tacotron 2 and WaveGlow models, you can perform
 inference using the respective checkpoints that are passed as `--tacotron2`
-and `--waveglow` arguments.
+and `--waveglow` arguments. Tacotron2 and WaveGlow checkpoints can also be downloaded from NGC:
+
+   https://ngc.nvidia.com/catalog/models/nvidia:tacotron2pyt_fp16/files?version=3
+   
+   https://ngc.nvidia.com/catalog/models/nvidia:waveglow256pyt_fp16/files?version=2
 
    To run inference, issue:
 
    ```bash
-   python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> -o output/ -i phrases/phrase.txt --amp-run
+   python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> --wn-channels 256 -o output/ -i phrases/phrase.txt --amp-run
    ```
 
    The speech is generated from lines of text in the file that is passed with
@@ -333,6 +337,14 @@ and `--waveglow` arguments.
    inference in mixed precision, use the `--amp-run` flag. The output audio will
    be stored in the path specified by the `-o` argument.
 
+   You can also run inference on CPU with TorchScript by adding the `--cpu-run` flag; first hide all GPUs by clearing `CUDA_VISIBLE_DEVICES`:
+   ```bash
+   export CUDA_VISIBLE_DEVICES=
+   ```
+   ```bash
+   python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> --wn-channels 256 --cpu-run -o output/ -i phrases/phrase.txt
+   ```
+
 ## Advanced
 
 The following sections provide greater details of the dataset, running
@@ -372,6 +384,7 @@ WaveGlow models.
 * `--learning-rate` - learning rate (Tacotron 2: 1e-3, WaveGlow: 1e-4)
 * `--batch-size` - batch size (Tacotron 2 FP16/FP32: 104/48, WaveGlow FP16/FP32: 10/4)
 * `--amp-run` - use mixed precision training
+* `--cpu-run` - use CPU with TorchScript for inference
 
 #### Shared audio/STFT parameters
 
@@ -469,7 +482,7 @@ models and input text as a text file, with one phrase per line.
 
 To run inference, issue:
 ```bash
-python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> -o output/ --include-warmup -i phrases/phrase.txt --amp-run
+python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> --wn-channels 256 -o output/ --include-warmup -i phrases/phrase.txt --amp-run
 ```
 Here, `Tacotron2_checkpoint` and `WaveGlow_checkpoint` are pre-trained
 checkpoints for the respective models, and `phrases/phrase.txt` contains input 
@@ -480,6 +493,14 @@ mixed precision and FP32 training, respectively.
 
 You can find all the available options by calling `python inference.py --help`.
 
+You can also run inference on CPU with TorchScript by adding the `--cpu-run` flag; first hide all GPUs by clearing `CUDA_VISIBLE_DEVICES`:
+```bash
+export CUDA_VISIBLE_DEVICES=
+```
+```bash
+python inference.py --tacotron2 <Tacotron2_checkpoint> --waveglow <WaveGlow_checkpoint> --wn-channels 256 --cpu-run -o output/ -i phrases/phrase.txt
+```
+
 ## Performance
 
 ### Benchmarking
@@ -557,6 +578,7 @@ The output log files will contain performance numbers for Tacotron 2 model
 and for WaveGlow (number of output samples per second, reported as `waveglow_items_per_sec`).
 The `inference.py` script will run a few warmup iterations before running the benchmark.
 
+
 ### Results
 
 The following sections provide details on how we achieved our performance
@@ -672,6 +694,10 @@ the PyTorch-19.09-py3 NGC container. Please note that to reproduce the results,
 you need to provide pretrained checkpoints for Tacotron 2 and WaveGlow. Please
 edit the script to provide your checkpoint filenames.
 
+
+To compare GPU results against CPU inference with TorchScript, benchmark on CPU using the `./run_latency_tests_cpu.sh` script, which reports performance numbers for batch sizes 1 and 4.
+
+
 ## Release notes
 
 ### Changelog

+ 35 - 19
PyTorch/SpeechSynthesis/Tacotron2/inference.py

@@ -68,7 +68,8 @@ def parse_args(parser):
                         help='Include warmup')
     parser.add_argument('--stft-hop-length', type=int, default=256,
                         help='STFT hop length for estimating audio length from mel size')
-
+    parser.add_argument('--cpu-run', action='store_true', 
+                        help='Run inference on CPU')
 
     return parser
 
@@ -102,16 +103,18 @@ def unwrap_distributed(state_dict):
     return new_state_dict
 
 
-def load_and_setup_model(model_name, parser, checkpoint, amp_run, forward_is_infer=False):
+def load_and_setup_model(model_name, parser, checkpoint, amp_run, cpu_run, forward_is_infer=False):
     model_parser = models.parse_model_args(model_name, parser, add_help=False)
     model_args, _ = model_parser.parse_known_args()
-
     model_config = models.get_model_config(model_name, model_args)
-    model = models.get_model(model_name, model_config, to_cuda=True,
-                             forward_is_infer=forward_is_infer)
-
+    model = models.get_model(model_name, model_config, cpu_run, forward_is_infer=forward_is_infer)
+    
     if checkpoint is not None:
-        state_dict = torch.load(checkpoint)['state_dict']
+        if cpu_run:
+            state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
+        else:
+            state_dict = torch.load(checkpoint)['state_dict']
+            
         if checkpoint_from_distributed(state_dict):
             state_dict = unwrap_distributed(state_dict)
 
@@ -164,23 +167,26 @@ def prepare_input_sequence(texts):
 
 
 class MeasureTime():
-    def __init__(self, measurements, key):
+    def __init__(self, measurements, key, cpu_run):
         self.measurements = measurements
         self.key = key
+        self.cpu_run = cpu_run
 
     def __enter__(self):
-        torch.cuda.synchronize()
+        if self.cpu_run == False:
+            torch.cuda.synchronize()
         self.t0 = time.perf_counter()
 
     def __exit__(self, exc_type, exc_value, exc_traceback):
-        torch.cuda.synchronize()
+        if self.cpu_run == False:
+            torch.cuda.synchronize()
         self.measurements[self.key] = time.perf_counter() - self.t0
 
 
 def main():
     """
     Launches text to speech (inference).
-    Inference is executed on a single GPU.
+    Inference is executed on a single GPU or CPU.
     """
     parser = argparse.ArgumentParser(
         description='PyTorch Tacotron 2 Inference')
@@ -195,10 +201,14 @@ def main():
     DLLogger.log(step="PARAMETER", data={'model_name':'Tacotron2_PyT'})
 
     tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
-                                     args.amp_run, forward_is_infer=True)
+                                     args.amp_run, args.cpu_run, forward_is_infer=True)
     waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
-                                    args.amp_run, forward_is_infer=True)
-    denoiser = Denoiser(waveglow).cuda()
+                                    args.amp_run, args.cpu_run, forward_is_infer=True)
+    
+    if args.cpu_run:
+        denoiser = Denoiser(waveglow, args.cpu_run)
+    else:
+        denoiser = Denoiser(waveglow, args.cpu_run).cuda()
 
     jitted_tacotron2 = torch.jit.script(tacotron2)
 
@@ -211,9 +221,14 @@ def main():
         sys.exit(1)
 
     if args.include_warmup:
-        sequence = torch.randint(low=0, high=148, size=(1,50),
+        if args.cpu_run:
+            sequence = torch.randint(low=0, high=148, size=(1,50),
+                                 dtype=torch.long)
+            input_lengths = torch.IntTensor([sequence.size(1)]).long()
+        else:
+            sequence = torch.randint(low=0, high=148, size=(1,50),
                                  dtype=torch.long).cuda()
-        input_lengths = torch.IntTensor([sequence.size(1)]).cuda().long()
+            input_lengths = torch.IntTensor([sequence.size(1)]).cuda().long()
         for i in range(3):
             with torch.no_grad():
                 mel, mel_lengths, _ = jitted_tacotron2(sequence, input_lengths)
@@ -223,16 +238,17 @@ def main():
 
     sequences_padded, input_lengths = prepare_input_sequence(texts)
 
-    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time"):
+    with torch.no_grad(), MeasureTime(measurements, "tacotron2_time", args.cpu_run):
         mel, mel_lengths, alignments = jitted_tacotron2(sequences_padded, input_lengths)
 
-    with torch.no_grad(), MeasureTime(measurements, "waveglow_time"):
+    with torch.no_grad(), MeasureTime(measurements, "waveglow_time", args.cpu_run):
         audios = waveglow(mel, sigma=args.sigma_infer)
         audios = audios.float()
         audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)
 
     print("Stopping after",mel.size(2),"decoder steps")
-    tacotron2_infer_perf = mel.size(0)*mel.size(2)/measurements['tacotron2_time']
+
+    tacotron2_infer_perf = mel.size(0)*mel.size(2)/measurements['tacotron2_time']
     waveglow_infer_perf = audios.size(0)*audios.size(1)/measurements['waveglow_time']
 
     DLLogger.log(step=0, data={"tacotron2_items_per_sec": tacotron2_infer_perf})
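The `MeasureTime` changes above thread a `cpu_run` flag through every timing block, because GPU kernels launch asynchronously and the clock must not be read until `torch.cuda.synchronize()` has drained the stream. The pattern can be sketched as a stand-alone version (simplified; the lazy `torch` import is an assumption of this sketch so it also runs on machines without PyTorch):

```python
import time

class MeasureTime:
    """Records elapsed wall-clock time for a block into a dict.

    With cpu_run=False, torch.cuda.synchronize() runs before each clock
    read so pending GPU work is included in the measurement; with
    cpu_run=True that step is skipped entirely.
    """

    def __init__(self, measurements, key, cpu_run):
        self.measurements = measurements
        self.key = key
        self.cpu_run = cpu_run

    def _maybe_sync(self):
        if not self.cpu_run:
            import torch  # imported lazily: only needed on the GPU path
            torch.cuda.synchronize()

    def __enter__(self):
        self._maybe_sync()
        self.t0 = time.perf_counter()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self._maybe_sync()
        self.measurements[self.key] = time.perf_counter() - self.t0

measurements = {}
with MeasureTime(measurements, "tacotron2_time", cpu_run=True):
    total = sum(range(100_000))  # stand-in for model inference
```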

+ 2 - 2
PyTorch/SpeechSynthesis/Tacotron2/models.py

@@ -62,7 +62,7 @@ def init_bn(module):
         init_bn(child)
 
 
-def get_model(model_name, model_config, to_cuda,
+def get_model(model_name, model_config, cpu_run,
               uniform_initialize_bn_weight=False, forward_is_infer=False):
     """ Code chooses a model based on name"""
     model = None
@@ -88,7 +88,7 @@ def get_model(model_name, model_config, to_cuda,
     if uniform_initialize_bn_weight:
         init_bn(model)
 
-    if to_cuda:
+    if cpu_run==False:
         model = model.cuda()
     return model
 

+ 4 - 4
PyTorch/SpeechSynthesis/Tacotron2/run_latency_tests.sh

@@ -1,4 +1,4 @@
-bash test_infer.sh -bs 1 -il 128 -p amp --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_amp --waveglow ./checkpoints/checkpoint_WaveGlow_amp
-bash test_infer.sh -bs 4 -il 128 -p amp --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_amp --waveglow ./checkpoints/checkpoint_WaveGlow_amp
-bash test_infer.sh -bs 1 -il 128 -p fp32 --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_fp32 --waveglow ./checkpoints/checkpoint_WaveGlow_fp32
-bash test_infer.sh -bs 4 -il 128 -p fp32 --num-iters 1003 --tacotron2 ./checkpoints/checkpoint_Tacotron2_fp32 --waveglow ./checkpoints/checkpoint_WaveGlow_fp32
+bash test_infer.sh -bs 1 -il 128 -p amp --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256
+bash test_infer.sh -bs 4 -il 128 -p amp --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256
+bash test_infer.sh -bs 1 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256
+bash test_infer.sh -bs 4 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256

+ 4 - 0
PyTorch/SpeechSynthesis/Tacotron2/run_latency_tests_cpu.sh

@@ -0,0 +1,4 @@
+export CUDA_VISIBLE_DEVICES=
+
+bash test_infer.sh -bs 1 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256 --cpu-run
+bash test_infer.sh -bs 4 -il 128 -p fp32 --num-iters 1003 --tacotron2 tacotron2_1032590_6000_amp --waveglow waveglow_1076430_14000_amp --wn-channels 256 --cpu-run
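The first line of `run_latency_tests_cpu.sh` sets `CUDA_VISIBLE_DEVICES` to an empty value, which is different from leaving it unset: CUDA then enumerates zero devices, so `torch.cuda.is_available()` returns `False` in every child process and the `--cpu-run` code paths are genuinely exercised. A quick sketch:

```shell
# An empty (but exported) CUDA_VISIBLE_DEVICES hides all GPUs from
# child processes; CUDA sees no devices at all.
export CUDA_VISIBLE_DEVICES=
echo "CUDA_VISIBLE_DEVICES='${CUDA_VISIBLE_DEVICES}'"
```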

+ 3 - 4
PyTorch/SpeechSynthesis/Tacotron2/tacotron2/model.py

@@ -535,10 +535,9 @@ class Decoder(nn.Module):
          attention_weights_cum,
          attention_context,
          processed_memory) = self.initialize_decoder_states(memory)
-
-        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).cuda()
-        not_finished = torch.ones([memory.size(0)], dtype=torch.int32).cuda()
-
+        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32, device=memory.device)
+        not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device=memory.device)
+        
         mel_outputs, gate_outputs, alignments = (
             torch.zeros(1), torch.zeros(1), torch.zeros(1))
         first_iter = True
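The `model.py` change above replaces hardcoded `.cuda()` calls with `device=memory.device`, so the decoder's bookkeeping tensors are allocated on whatever device the encoder output already occupies. A minimal sketch of the pattern (CPU tensors here; the variable names mirror the diff but the shapes are illustrative):

```python
import torch

# A CPU tensor standing in for the encoder output `memory`.
memory = torch.randn(2, 5, 8)

# Allocate counters on memory's own device instead of calling .cuda();
# the identical line then works unchanged for both CPU and GPU runs.
mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32,
                          device=memory.device)
not_finished = torch.ones([memory.size(0)], dtype=torch.int32,
                          device=memory.device)
```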

+ 31 - 18
PyTorch/SpeechSynthesis/Tacotron2/test_infer.py

@@ -42,6 +42,8 @@ from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 
 from apex import amp
 
+from waveglow.denoiser import Denoiser
+
 def parse_args(parser):
     """
     Parse commandline arguments.
@@ -51,6 +53,7 @@ def parse_args(parser):
     parser.add_argument('--waveglow', type=str,
                         help='full path to the WaveGlow model checkpoint file')
     parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
+    parser.add_argument('-d', '--denoising-strength', default=0.01, type=float)
     parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                         help='Sampling rate')
     parser.add_argument('--amp-run', action='store_true',
@@ -65,23 +68,24 @@ def parse_args(parser):
                         help='Input length')
     parser.add_argument('-bs', '--batch-size', type=int, default=1,
                         help='Batch size')
-
-
+    parser.add_argument('--cpu-run', action='store_true', 
+                        help='Run inference on CPU')
     return parser
 
 
-def load_and_setup_model(model_name, parser, checkpoint, amp_run, to_cuda=True):
+def load_and_setup_model(model_name, parser, checkpoint, amp_run, cpu_run, forward_is_infer=False):
     model_parser = models.parse_model_args(model_name, parser, add_help=False)
     model_args, _ = model_parser.parse_known_args()
 
     model_config = models.get_model_config(model_name, model_args)
-    model = models.get_model(model_name, model_config, to_cuda=to_cuda)
+    model = models.get_model(model_name, model_config, cpu_run, forward_is_infer=forward_is_infer)
 
     if checkpoint is not None:
-        if to_cuda:
-            state_dict = torch.load(checkpoint)['state_dict']
+        if cpu_run:
+            state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))['state_dict']
         else:
-            state_dict = torch.load(checkpoint,map_location='cpu')['state_dict']
+            state_dict = torch.load(checkpoint)['state_dict']
+
         if checkpoint_from_distributed(state_dict):
             state_dict = unwrap_distributed(state_dict)
 
@@ -141,7 +145,7 @@ def print_stats(measurements_all):
 def main():
     """
     Launches text to speech (inference).
-    Inference is executed on a single GPU.
+    Inference is executed on a single GPU or CPU.
     """
     parser = argparse.ArgumentParser(
         description='PyTorch Tacotron 2 Inference')
@@ -168,8 +172,15 @@ def main():
 
     print("args:", args, unknown_args)
 
-    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run)
-    waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run)
+    tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2, args.amp_run, args.cpu_run, forward_is_infer=True)
+    waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow, args.amp_run, args.cpu_run)
+
+    if args.cpu_run:
+        denoiser = Denoiser(waveglow, args.cpu_run)
+    else:
+        denoiser = Denoiser(waveglow, args.cpu_run).cuda()
+
+    jitted_tacotron2 = torch.jit.script(tacotron2)
 
     texts = ["The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."]
     texts = [texts[0][:args.input_length]]
@@ -181,27 +192,29 @@ def main():
 
         measurements = {}
 
-        with MeasureTime(measurements, "pre_processing"):
+        with MeasureTime(measurements, "pre_processing", args.cpu_run):
             sequences_padded, input_lengths = prepare_input_sequence(texts)
 
         with torch.no_grad():
-            with MeasureTime(measurements, "latency"):
-                with MeasureTime(measurements, "tacotron2_latency"):
-                    mel, mel_lengths, _ = tacotron2.infer(sequences_padded, input_lengths)
+            with MeasureTime(measurements, "latency", args.cpu_run):
+                with MeasureTime(measurements, "tacotron2_latency", args.cpu_run):
+                    mel, mel_lengths, _ = jitted_tacotron2(sequences_padded, input_lengths)
 
-                with MeasureTime(measurements, "waveglow_latency"):
+                with MeasureTime(measurements, "waveglow_latency", args.cpu_run):
                     audios = waveglow.infer(mel, sigma=args.sigma_infer)
+                    audios = audios.float()
+                    audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)
 
         num_mels = mel.size(0)*mel.size(2)
         num_samples = audios.size(0)*audios.size(1)
 
-        with MeasureTime(measurements, "type_conversion"):
+        with MeasureTime(measurements, "type_conversion", args.cpu_run):
             audios = audios.float()
 
-        with MeasureTime(measurements, "data_transfer"):
+        with MeasureTime(measurements, "data_transfer", args.cpu_run):
             audios = audios.cpu()
 
-        with MeasureTime(measurements, "storage"):
+        with MeasureTime(measurements, "storage", args.cpu_run):
             audios = audios.numpy()
             for i, audio in enumerate(audios):
                 audio_path = "audio_"+str(i)+".wav"
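Both `inference.py` and `test_infer.py` now branch on `cpu_run` when loading checkpoints: a checkpoint saved from GPU tensors needs `map_location` so `torch.load` remaps its storages to host memory. The branching can be sketched with a stand-in loader (`fake_loader` is illustrative and exists only in this sketch; the real call is `torch.load(path, map_location=torch.device('cpu'))`):

```python
def load_state_dict(checkpoint, cpu_run, loader):
    """Return the checkpoint's state dict, remapped to CPU when needed.

    `loader` stands in for torch.load so the sketch runs without
    PyTorch; only the branching logic mirrors the pull request.
    """
    if cpu_run:
        checkpoint_data = loader(checkpoint, map_location="cpu")
    else:
        checkpoint_data = loader(checkpoint)
    return checkpoint_data["state_dict"]

calls = []

def fake_loader(path, map_location=None):
    # Record how the loader was invoked instead of touching the disk.
    calls.append(map_location)
    return {"state_dict": {"weight": 0.0}}

cpu_sd = load_state_dict("ckpt.pt", cpu_run=True, loader=fake_loader)
gpu_sd = load_state_dict("ckpt.pt", cpu_run=False, loader=fake_loader)
```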

+ 9 - 3
PyTorch/SpeechSynthesis/Tacotron2/test_infer.sh

@@ -4,11 +4,12 @@ BATCH_SIZE=1
 INPUT_LENGTH=128
 PRECISION="fp32"
 NUM_ITERS=1003 # extra 3 iterations for warmup
-TACOTRON2_CKPT="checkpoint_Tacotron2_1500_fp32"
-WAVEGLOW_CKPT="checkpoint_WaveGlow_1000_fp32"
+TACOTRON2_CKPT="tacotron2_1032590_6000_amp"
+WAVEGLOW_CKPT="waveglow_1076430_14000_amp"
 AMP_RUN=""
 TEST_PROGRAM="test_infer.py"
-WN_CHANNELS=512
+WN_CHANNELS=256
+CPU_RUN=""
 
 while [ -n "$1" ]
 do
@@ -57,6 +58,10 @@ do
 	    WN_CHANNELS="$2"
 	    shift
 	    ;;
+	--cpu-run)
+	    CPU_RUN="--cpu-run"
+	    shift
+	    ;;
 	*)
 	    echo "Option $1 not recognized"
     esac
@@ -93,6 +98,7 @@ python $TEST_PROGRAM \
        --log-file $NVLOG_FILE \
        --num-iters $NUM_ITERS \
        --wn-channels $WN_CHANNELS \
+       $CPU_RUN \
        |& tee $TMP_LOGFILE
 set +x
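`test_infer.sh` forwards the new flag through its `while`/`case` option loop, and because `CPU_RUN` defaults to the empty string it expands to nothing in the final `python` command, leaving GPU runs unaffected. The loop can be sketched as (simplified; only two of the script's options shown):

```shell
# Simplified version of the option loop: flags that take a value
# consume it with an inner shift; --cpu-run is a bare switch.
BATCH_SIZE=1
CPU_RUN=""
set -- -bs 4 --cpu-run
while [ -n "$1" ]; do
    case "$1" in
        -bs)
            BATCH_SIZE="$2"
            shift
            ;;
        --cpu-run)
            CPU_RUN="--cpu-run"
            ;;
        *)
            echo "Option $1 not recognized"
            ;;
    esac
    shift
done
echo "BATCH_SIZE=$BATCH_SIZE CPU_RUN=$CPU_RUN"
```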
 

+ 8 - 3
PyTorch/SpeechSynthesis/Tacotron2/waveglow/denoiser.py

@@ -34,10 +34,15 @@ from common.layers import STFT
 class Denoiser(torch.nn.Module):
     """ Removes model bias from audio produced with waveglow """
 
-    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
+    def __init__(self, waveglow, cpu_run, filter_length=1024, n_overlap=4,
                  win_length=1024, mode='zeros'):
         super(Denoiser, self).__init__()
-        self.stft = STFT(filter_length=filter_length,
+        if cpu_run:
+            self.stft = STFT(filter_length=filter_length,
+                         hop_length=int(filter_length/n_overlap),
+                         win_length=win_length)
+        else:
+            self.stft = STFT(filter_length=filter_length,
                          hop_length=int(filter_length/n_overlap),
                          win_length=win_length).cuda()
         if mode == 'zeros':
@@ -60,7 +65,7 @@ class Denoiser(torch.nn.Module):
         self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
 
     def forward(self, audio, strength=0.1):
-        audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
+        audio_spec, audio_angles = self.stft.transform(audio.float())
         audio_spec_denoised = audio_spec - self.bias_spec * strength
         audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
         audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
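The new `Denoiser` constructor duplicates the STFT setup in both branches, with the two differing only in the trailing `.cuda()`; the same logic can be written once. A stand-alone sketch with a stand-in class (`FakeSTFT` is illustrative; the real `STFT` lives in `common.layers` and is not reproduced here):

```python
class FakeSTFT:
    """Stand-in for common.layers.STFT that records device placement."""

    def __init__(self, filter_length, hop_length, win_length):
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.on_gpu = False

    def cuda(self):
        self.on_gpu = True
        return self

def build_stft(cpu_run, filter_length=1024, n_overlap=4, win_length=1024):
    # Construct once; only the device move depends on cpu_run.
    stft = FakeSTFT(filter_length=filter_length,
                    hop_length=filter_length // n_overlap,
                    win_length=win_length)
    return stft if cpu_run else stft.cuda()
```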