소스 검색

[Tacotron2/PyT] Updates: better perf, better trt7 support, new logging, bug fixes

Przemek Strzelczyk 6 년 전
부모
커밋
77a1bb917a
25개의 변경된 파일225개의 추가작업 그리고 526개의 파일을 삭제
  1. 4 0
      PyTorch/SpeechSynthesis/Tacotron2/.gitignore
  2. 3 3
      PyTorch/SpeechSynthesis/Tacotron2/Dockerfile
  3. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/Dockerfile_trtis_client
  4. 20 17
      PyTorch/SpeechSynthesis/Tacotron2/README.md
  5. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/common/stft.py
  6. 2 0
      PyTorch/SpeechSynthesis/Tacotron2/exports/export_tacotron2_ts.py
  7. 47 24
      PyTorch/SpeechSynthesis/Tacotron2/exports/export_waveglow_onnx.py
  8. 10 6
      PyTorch/SpeechSynthesis/Tacotron2/exports/export_waveglow_trt_config.py
  9. 2 1
      PyTorch/SpeechSynthesis/Tacotron2/inference.py
  10. 7 7
      PyTorch/SpeechSynthesis/Tacotron2/notebooks/trtis/README.md
  11. 17 174
      PyTorch/SpeechSynthesis/Tacotron2/notebooks/trtis/notebook.ipynb
  12. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_AMP_DGX1_16GB_1GPU.sh
  13. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_AMP_DGX1_16GB_4GPU.sh
  14. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_AMP_DGX1_16GB_8GPU.sh
  15. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_FP32_DGX1_16GB_1GPU.sh
  16. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_FP32_DGX1_16GB_4GPU.sh
  17. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_FP32_DGX1_16GB_8GPU.sh
  18. 1 0
      PyTorch/SpeechSynthesis/Tacotron2/requirements.txt
  19. 1 1
      PyTorch/SpeechSynthesis/Tacotron2/scripts/train_tacotron2.sh
  20. 54 31
      PyTorch/SpeechSynthesis/Tacotron2/train.py
  21. 7 4
      PyTorch/SpeechSynthesis/Tacotron2/trt/inference_trt.py
  22. 1 6
      PyTorch/SpeechSynthesis/Tacotron2/trt/run_latency_tests_trt.sh
  23. 0 181
      PyTorch/SpeechSynthesis/Tacotron2/trt/test_infer_trt.py
  24. 40 17
      PyTorch/SpeechSynthesis/Tacotron2/trt/trt_utils.py
  25. 1 46
      PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py

+ 4 - 0
PyTorch/SpeechSynthesis/Tacotron2/.gitignore

@@ -0,0 +1,4 @@
+__pycache__/
+/checkpoints/
+/output/
+nvlog.json

+ 3 - 3
PyTorch/SpeechSynthesis/Tacotron2/Dockerfile

@@ -1,6 +1,6 @@
-FROM nvcr.io/nvidia/pytorch:19.11-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.01-py3
+FROM ${FROM_IMAGE_NAME}
 
 ADD . /workspace/tacotron2
 WORKDIR /workspace/tacotron2
-RUN pip install -r requirements.txt
-RUN pip --no-cache-dir --no-cache install  'git+https://github.com/NVIDIA/dllogger'
+RUN pip install --no-cache-dir -r requirements.txt

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/Dockerfile_trtis_client

@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-FROM nvcr.io/nvidia/tensorrtserver:19.10-py3-clientsdk AS trt
+FROM nvcr.io/nvidia/tensorrtserver:20.01-py3-clientsdk AS trt
 FROM continuumio/miniconda3
 RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract mc iputils-ping wget
 

+ 20 - 17
PyTorch/SpeechSynthesis/Tacotron2/README.md

@@ -231,7 +231,7 @@ and encapsulates some dependencies. Aside from these dependencies, ensure you
 have the following components:
 
 * [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
-* [PyTorch 19.06-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
+* [PyTorch 20.01-py3+ NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
 or newer
 * [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) or [Turing](https://www.nvidia.com/en-us/geforce/turing/) based GPU
 
@@ -370,7 +370,7 @@ WaveGlow models.
 
 * `--epochs` - number of epochs (Tacotron 2: 1501, WaveGlow: 1001)
 * `--learning-rate` - learning rate (Tacotron 2: 1e-3, WaveGlow: 1e-4)
-* `--batch-size` - batch size (Tacotron 2 FP16/FP32: 128/64, WaveGlow FP16/FP32: 10/4)
+* `--batch-size` - batch size (Tacotron 2 FP16/FP32: 104/48, WaveGlow FP16/FP32: 10/4)
 * `--amp-run` - use mixed precision training
 
 #### Shared audio/STFT parameters
@@ -496,21 +496,21 @@ To benchmark the training performance on a specific batch size, run:
 * For 1 GPU
 	* FP16
         ```bash
-        python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path> --amp-run
+        python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path> --amp-run
         ```
 	* FP32
         ```bash
-        python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path>
+        python train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path>
         ```
 
 * For multiple GPUs
 	* FP16
         ```bash
-        python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path> --amp-run
+        python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path> --amp-run
         ```
 	* FP32
         ```bash
-        python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --training-files filelists/ljs_audio_text_train_subset_2500_filelist.txt --dataset-path <dataset-path>
+        python -m multiproc train.py -m Tacotron2 -o <output_dir> -lr 1e-3 --epochs 10 -bs <batch_size> --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_subset_2500_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --dataset-path <dataset-path>
         ```
 
 **WaveGlow**
@@ -579,10 +579,10 @@ All of the results were produced using the `train.py` script as described in the
 | WaveGlow FP16  | -2.2054 | -5.7602 |  -5.901 | -5.9706 | -6.0258 |
 | WaveGlow FP32  | -3.0327 |  -5.858 | -6.0056 | -6.0613 | -6.1087 |
 
-Tacotron 2 FP16 loss - batch size 128 (mean and std over 16 runs)
+Tacotron 2 FP16 loss - batch size 104 (mean and std over 16 runs)
 ![](./img/tacotron2_amp_loss.png "Tacotron 2 FP16 loss")
 
-Tacotron 2 FP32 loss - batch size 64 (mean and std over 16 runs)
+Tacotron 2 FP32 loss - batch size 48 (mean and std over 16 runs)
 ![](./img/tacotron2_fp32_loss.png "Tacotron 2 FP16 loss")
 
 WaveGlow FP16 loss - batch size 10 (mean and std over 16 runs)
@@ -597,7 +597,7 @@ WaveGlow FP32 loss - batch size 4 (mean and std over 16 runs)
 ##### Training performance: NVIDIA DGX-1 (8x V100 16G)
 
 Our results were obtained by running the `./platform/train_{tacotron2,waveglow}_{AMP,FP32}_DGX1_16GB_8GPU.sh`
-training script in the PyTorch-19.06-py3 NGC container on NVIDIA DGX-1 with
+training script in the PyTorch-19.12-py3 NGC container on NVIDIA DGX-1 with
 8x V100 16G GPUs. Performance numbers (in output mel-spectrograms per second for
 Tacotron 2 and output samples per second for WaveGlow) were averaged over
 an entire training epoch.
@@ -606,9 +606,9 @@ This table shows the results for Tacotron 2:
 
 |Number of GPUs|Batch size per GPU|Number of mels used with mixed precision|Number of mels used with FP32|Speed-up with mixed precision|Multi-GPU weak scaling with mixed precision|Multi-GPU weak scaling with FP32|
 |---:|---:|---:|---:|---:|---:|---:|
-|1|128@FP16, 64@FP32 | 20,992  | 12,933 | 1.62 | 1.00 | 1.00 |
-|4|128@FP16, 64@FP32 | 74,989  | 46,115 | 1.63 | 3.57 | 3.57 |
-|8|128@FP16, 64@FP32 | 140,060 | 88,719 | 1.58 | 6.67 | 6.86 |
+|1|104@FP16, 48@FP32 | 15,313 | 9,674 | 1.58 | 1.00 | 1.00 |
+|4|104@FP16, 48@FP32 | 53,661 | 32,778 | 1.64 | 3.50 | 3.39 |
+|8|104@FP16, 48@FP32 | 100,422 | 59,549 | 1.69 | 6.56 | 6.16 |
 
 The following table shows the results for WaveGlow:
 
@@ -626,9 +626,9 @@ The following table shows the expected training time for convergence for Tacotro
 
 |Number of GPUs|Batch size per GPU|Time to train with mixed precision (Hrs)|Time to train with FP32 (Hrs)|Speed-up with mixed precision|
 |---:|---:|---:|---:|---:|
-|1| 128@FP16, 64@FP32 | 153 | 234 | 1.53 |
-|4| 128@FP16, 64@FP32 | 42 | 64 | 1.54 |
-|8| 128@FP16, 64@FP32 | 22 | 33 | 1.52 |
+|1| 104@FP16, 48@FP32 | 193 | 312 | 1.62 |
+|4| 104@FP16, 48@FP32 | 53 | 85 | 1.58 |
+|8| 104@FP16, 48@FP32 | 31 | 45 | 1.47 |
 
 The following table shows the expected training time for convergence for WaveGlow (1001 epochs):
 
@@ -704,8 +704,11 @@ November 2019
 * Implemented training resume from checkpoint
 * Added notebook for running Tacotron 2 and WaveGlow in TRTIS.
 
-December  2019
-* Added `trt` subfolder for running Tacotron 2 and WaveGlow in TensorRT.
+December 2019
+* Added export and inference scripts for TensorRT. See [Tacotron2 TensorRT README](trt/README.md).
+
+January 2020
+* Updated batch sizes and performance results for Tacotron 2.
 
 ### Known issues
 

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/common/stft.py

@@ -58,7 +58,7 @@ class STFT(torch.nn.Module):
 
         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
         inverse_basis = torch.FloatTensor(
-            np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :].astype(np.float32))
 
         if window is not None:
             assert(filter_length >= win_length)

+ 2 - 0
PyTorch/SpeechSynthesis/Tacotron2/exports/export_tacotron2_ts.py

@@ -27,6 +27,8 @@
 
 import torch
 import argparse
+import sys
+sys.path.append('./')
 from inference import checkpoint_from_distributed, unwrap_distributed, load_and_setup_model
 
 def parse_args(parser):

+ 47 - 24
PyTorch/SpeechSynthesis/Tacotron2/exports/export_waveglow_onnx.py

@@ -25,6 +25,7 @@
 #
 # *****************************************************************************
 
+import types
 import torch
 import argparse
 
@@ -113,32 +114,52 @@ def convert_1d_to_2d_(glow):
 
     glow.cuda()
 
-def test_inference(waveglow):
 
+def infer_onnx(self, spect, z, sigma=0.9):
 
-    from scipy.io.wavfile import write
+    spect = self.upsample(spect)
+    # trim conv artifacts. maybe pad spec to kernel multiple
+    time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+    spect = spect[:, :, :-time_cutoff]
 
-    mel = torch.load("mel.pt").cuda()
-    # mel = torch.load("mel_spectrograms/LJ001-0015.wav.pt").cuda()
-    # mel = mel.unsqueeze(0)
-    mel_lengths = [mel.size(2)]
-    stride = 256
-    kernel_size = 1024
-    n_group = 8
-    z_size2 = (mel.size(2)-1)*stride+(kernel_size-1)+1
-    # corresponds to cutoff in infer_onnx
-    z_size2 = z_size2 - (kernel_size-stride)
-    z_size2 = z_size2//n_group
-    z = torch.randn(1, n_group, z_size2, 1).cuda()
-    mel = mel.unsqueeze(3)
+    length_spect_group = spect.size(2)//8
+    mel_dim = 80
+    batch_size = spect.size(0)
 
-    with torch.no_grad():
-        audios = waveglow(mel, z)
+    spect = torch.squeeze(spect, 3)
+    spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
+    spect = spect.permute(0, 2, 1, 3)
+    spect = spect.contiguous()
+    spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
+    spect = spect.permute(0, 2, 1)
+    spect = torch.unsqueeze(spect, 3)
+    spect = spect.contiguous()
+
+    audio = z[:, :self.n_remaining_channels, :, :]
+    z = z[:, self.n_remaining_channels:self.n_group, :, :]
+    audio = sigma*audio
+
+    for k in reversed(range(self.n_flows)):
+        n_half = int(audio.size(1) / 2)
+        audio_0 = audio[:, :n_half, :, :]
+        audio_1 = audio[:, n_half:(n_half+n_half), :, :]
 
-    for i, audio in enumerate(audios):
-        audio = audio[:mel_lengths[i]*256]
-        audio = audio/torch.max(torch.abs(audio))
-        write("audio_pyt.wav", 22050, audio.cpu().numpy())
+        output = self.WN[k]((audio_0, spect))
+        s = output[:, n_half:(n_half+n_half), :, :]
+        b = output[:, :n_half, :, :]
+        audio_1 = (audio_1 - b) / torch.exp(s)
+        audio = torch.cat([audio_0, audio_1], 1)
+
+        audio = self.convinv[k](audio)
+
+        if k % self.n_early_every == 0 and k > 0:
+            audio = torch.cat((z[:, :self.n_early_size, :, :], audio), 1)
+            z = z[:, self.n_early_size:self.n_group, :, :]
+
+    audio = torch.squeeze(audio, 3)
+    audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
+
+    return audio
 
 
 def export_onnx(parser, args):
@@ -166,12 +187,16 @@ def export_onnx(parser, args):
 
         # export to ONNX
         convert_1d_to_2d_(waveglow)
-        waveglow.forward = waveglow.infer_onnx
+
+        fType = types.MethodType
+        waveglow.forward = fType(infer_onnx, waveglow)
+
         if args.amp_run:
             waveglow.half()
         mel = mel.unsqueeze(3)
 
         opset_version = 10
+
         torch.onnx.export(waveglow, (mel, z), args.output+"/"+"waveglow.onnx",
                           opset_version=opset_version,
                           do_constant_folding=True,
@@ -181,8 +206,6 @@ def export_onnx(parser, args):
                                         "z":     {0: "batch_size", 2: "z_seq"},
                                         "audio": {0: "batch_size", 1: "audio_seq"}})
 
-    test_inference(waveglow)
-
 
 def main():
 

+ 10 - 6
PyTorch/SpeechSynthesis/Tacotron2/exports/export_waveglow_trt_config.py

@@ -64,20 +64,24 @@ def main():
     config_template = r"""
 name: "{model_name}"
 platform: "tensorrt_plan"
+default_model_filename: "waveglow_fp16.engine"
+
+max_batch_size: 1
+
 input {{
-  name: "0"
+  name: "mel"
   data_type: {fp_type}
-  dims: [1, 80, 620, 1]
+  dims: [80, -1, 1]
 }}
 input {{
-  name: "1"
+  name: "z"
   data_type: {fp_type}
-  dims: [1, 8, 19840, 1]
+  dims: [8, -1, 1]
 }}
 output {{
-  name: "1991"
+  name: "audio"
   data_type: {fp_type}
-  dims: [1, 158720]
+  dims: [-1]
 }}
 """
     

+ 2 - 1
PyTorch/SpeechSynthesis/Tacotron2/inference.py

@@ -50,6 +50,7 @@ def parse_args(parser):
                         help='full path to the input text (phareses separated by new line)')
     parser.add_argument('-o', '--output', required=True,
                         help='output folder to save audio (file per phrase)')
+    parser.add_argument('--suffix', type=str, default="", help="output filename suffix")
     parser.add_argument('--tacotron2', type=str,
                         help='full path to the Tacotron2 model checkpoint file')
     parser.add_argument('--waveglow', type=str,
@@ -242,7 +243,7 @@ def main():
     for i, audio in enumerate(audios):
         audio = audio[:mel_lengths[i]*args.stft_hop_length]
         audio = audio/torch.max(torch.abs(audio))
-        audio_path = args.output + "audio_"+str(i)+".wav"
+        audio_path = args.output+"audio_"+str(i)+"_"+args.suffix+".wav"
         write(audio_path, args.sampling_rate, audio.cpu().numpy())
 
     DLLogger.flush()

+ 7 - 7
PyTorch/SpeechSynthesis/Tacotron2/notebooks/trtis/README.md

@@ -106,7 +106,7 @@ cd /workspace/onnx-tensorrt/build && cmake .. -DCMAKE_CXX_FLAGS=-isystem\ /usr/l
 In order to export the model into the ONNX intermediate representation, type:
 
 ```bash
-python exports/export_waveglow_onnx.py --waveglow <waveglow_checkpoint> --wn-channels 256 --amp-run
+python exports/export_waveglow_onnx.py --waveglow <waveglow_checkpoint> --wn-channels 256 --amp-run --output ./output
 ```
 
 This will save the model as `waveglow.onnx` (you can change its name with the flag `--output <filename>`).
@@ -114,15 +114,15 @@ This will save the model as `waveglow.onnx` (you can change its name with the fl
 With the model exported to ONNX, type the following to obtain a TRT engine and save it as `trtis_repo/waveglow/1/model.plan`:
 
 ```bash
-onnx2trt <exported_waveglow_onnx> -o trtis_repo/waveglow/1/model.plan -b 1 -w 8589934592
+python trt/export_onnx2trt.py --waveglow  <exported_waveglow_onnx> -o trtis_repo/waveglow/1/ --fp16
 ```
 
 ### Setup the TRTIS server.
 
 Download the TRTIS container by typing:
 ```bash
-docker pull nvcr.io/nvidia/tensorrtserver:19.10-py3
-docker tag nvcr.io/nvidia/tensorrtserver:19.10-py3 tensorrtserver:19.10
+docker pull nvcr.io/nvidia/tensorrtserver:20.01-py3
+docker tag nvcr.io/nvidia/tensorrtserver:20.01-py3 tensorrtserver:20.01
 ```
 
 ### Setup the TRTIS notebook client.
@@ -130,14 +130,14 @@ docker tag nvcr.io/nvidia/tensorrtserver:19.10-py3 tensorrtserver:19.10
 Now go to the root directory of the Tacotron 2 repo, and type: 
 
 ```bash
-docker build -f Dockerfile_trtis_client --network=host -t speech_ai__tts_only:demo .
+docker build -f Dockerfile_trtis_client --network=host -t speech_ai_tts_only:demo .
 ```
 
 ### Run the TRTIS server.
 
 To run the server, type in the root directory of the Tacotron 2 repo:
 ```bash
-NV_GPU=1 nvidia-docker run -ti --ipc=host --network=host --rm -p8000:8000 -p8001:8001 -v $PWD/trtis_repo/:/models tensorrtserver:19.10 trtserver --model-store=/models --log-verbose 1
+NV_GPU=1 nvidia-docker run -ti --ipc=host --network=host --rm -p8000:8000 -p8001:8001 -v $PWD/trtis_repo/:/models tensorrtserver:20.01 trtserver --model-store=/models --log-verbose 1
 ```
 
 The flag `NV_GPU` selects the GPU the server is going to see. If we want it to see all the available GPUs, then run the above command without this flag.
@@ -147,7 +147,7 @@ By default, the model repository will be in `trtis_repo/`.
 
 Leave the server running. In another terminal, type:
 ```bash
-docker run -it --rm --network=host --device /dev/snd:/dev/snd --device /dev/usb:/dev/usb speech_ai__tts_only:demo bash ./run_this.sh
+docker run -it --rm --network=host --device /dev/snd:/dev/snd --device /dev/usb:/dev/usb speech_ai_tts_only:demo bash ./run_this.sh
 ```
 
 Open the URL in a browser, open `notebook.ipynb`, click play, and enjoy.

+ 17 - 174
PyTorch/SpeechSynthesis/Tacotron2/notebooks/trtis/notebook.ipynb

@@ -2,168 +2,9 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "bfd62b9362ec4dbb825e786821358a5e",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "VBox(layout=Layout(height='1in'))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/markdown": [
-       "**tacotron2 input**"
-      ],
-      "text/plain": [
-       "<IPython.core.display.Markdown object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "fead313d50594d0688a399cff8b6eb86",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Textarea(value='type here', layout=Layout(height='80px', width='550px'), placeholder='')"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/markdown": [
-       "**tacotron2 preprocessing**"
-      ],
-      "text/plain": [
-       "<IPython.core.display.Markdown object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "e610d8bd77ad44a7839d6ede103a71ef",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output(layout=Layout(height='1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/markdown": [
-       "**tacotron2 output / waveglow input**"
-      ],
-      "text/plain": [
-       "<IPython.core.display.Markdown object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "240214dbb35c4bfd97b84df7cb5cd8cf",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output(layout=Layout(height='2.1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/markdown": [
-       "**waveglow output**"
-      ],
-      "text/plain": [
-       "<IPython.core.display.Markdown object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5bc0e4519ea74a89b34c55fa8a040b02",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output(layout=Layout(height='2in', object_fit='fill', object_position='{center} {center}', width='10in'))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/markdown": [
-       "**play**"
-      ],
-      "text/plain": [
-       "<IPython.core.display.Markdown object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "30976ec770c34d34880af8d7f3d6ff08",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output(layout=Layout(height='1in', object_fit='fill', object_position='{center} {center}', width='10in'))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "bfd62b9362ec4dbb825e786821358a5e",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "VBox(layout=Layout(height='1in'))"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "import os\n",
     "import time\n",
@@ -192,7 +33,7 @@
     "    'protocol': 0,             # 0: http, 1: grpc \n",
     "    'autoplay': True,          # autoplay\n",
     "    'character_limit_min': 4,  # don't touch this\n",
-    "    'character_limit_max': 124 # don't touch this\n",
+    "    'character_limit_max': 340 # don't touch this\n",
     "}\n",
     "\n",
     "\n",
@@ -330,23 +171,25 @@
     "        ::returns:: waveform\n",
     "    '''\n",
     "    # padding/trimming mel to dimension 620\n",
-    "    mel = force_to_shape(mel, 620)\n",
+    "    mel = mel[:,:,None]\n",
     "    # prepare input/output\n",
-    "    mel = mel[None,:,:]\n",
     "    input_dict = {}\n",
-    "    input_dict['0'] = (mel,)\n",
-    "    shape = (8,19840,1)\n",
-    "    shape = (1,*shape)\n",
-    "    input_dict['1'] = np.random.normal(0.0, 1.0, shape).astype(mel.dtype)\n",
-    "    input_dict['1'] = (input_dict['1'],)\n",
+    "    input_dict['mel'] = (mel,)\n",
+    "    stride = 256\n",
+    "    kernel_size = 1024\n",
+    "    n_group = 8\n",
+    "    z_size = (mel.shape[1]-1)*stride + (kernel_size-1) + 1 - (kernel_size-stride)\n",
+    "    z_size = z_size//n_group\n",
+    "    shape = (n_group,z_size,1)\n",
+    "    input_dict['z'] = np.random.normal(0.0, 1.0, shape).astype(mel.dtype)\n",
+    "    input_dict['z'] = (input_dict['z'],)\n",
     "    output_dict = {}\n",
-    "    output_dict['1991'] = InferContext.ResultFormat.RAW\n",
+    "    output_dict['audio'] = InferContext.ResultFormat.RAW\n",
     "    batch_size = 1\n",
     "    # call waveglow\n",
     "    result = infer_ctx_waveglow.run(input_dict, output_dict, batch_size)\n",
     "    # get the results\n",
-    "    signal = result['1991'][0] # take only the first instance in the output batch\n",
-    "    signal = signal[0] # remove this line, when waveglow supports dynamic batch sizes\n",
+    "    signal = result['audio'][0] # take only the first instance in the output batch\n",
     "    # postprocessing of waveglow: trimming signal to its actual size\n",
     "    trimmed_length = mel_lengths[0] * args.stft_hop_length\n",
     "    signal = signal[:trimmed_length] # trim\n",
@@ -432,7 +275,7 @@
     ")\n",
     "\n",
     "# default text\n",
-    "text_area.value = \"I think grown-ups just act like they know what they're doing. \""
+    "text_area.value = \"The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves.\""
    ]
   },
   {
@@ -459,7 +302,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.7.4"
   }
  },
  "nbformat": 4,

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_AMP_DGX1_16GB_1GPU.sh

@@ -1,2 +1,2 @@
 mkdir -p output
-python train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3
+python train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_AMP_DGX1_16GB_4GPU.sh

@@ -1,2 +1,2 @@
 mkdir -p output
-python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3
+python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_AMP_DGX1_16GB_8GPU.sh

@@ -1,2 +1,2 @@
 mkdir -p output
-python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3
+python -m multiproc train.py -m Tacotron2 -o output/ --amp-run -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.3

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_FP32_DGX1_16GB_1GPU.sh

@@ -1,2 +1,2 @@
 mkdir -p output
-python train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 64 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1
+python train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 48 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_FP32_DGX1_16GB_4GPU.sh

@@ -1,2 +1,2 @@
 mkdir -p output
-python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 64 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1
+python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 48 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/platform/train_tacotron2_FP32_DGX1_16GB_8GPU.sh

@@ -1,2 +1,2 @@
 mkdir -p output
-python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 64 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1
+python -m multiproc train.py -m Tacotron2 -o output/ -lr 1e-3 --epochs 1501 -bs 48 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --load-mel-from-disk --training-files=filelists/ljs_mel_text_train_filelist.txt --validation-files=filelists/ljs_mel_text_val_filelist.txt --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1

+ 1 - 0
PyTorch/SpeechSynthesis/Tacotron2/requirements.txt

@@ -4,3 +4,4 @@ inflect
 librosa
 scipy
 Unidecode
+git+git://github.com/NVIDIA/dllogger.git@26a0f8f1958de2c0c460925ff6102a4d2486d6cc#egg=dllogger

+ 1 - 1
PyTorch/SpeechSynthesis/Tacotron2/scripts/train_tacotron2.sh

@@ -1,2 +1,2 @@
 mkdir -p output
-python -m multiproc train.py -m Tacotron2 -o ./output/ -lr 1e-3 --epochs 1501 -bs 128 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1 --amp-run
+python -m multiproc train.py -m Tacotron2 -o ./output/ -lr 1e-3 --epochs 1501 -bs 104 --weight-decay 1e-6 --grad-clip-thresh 1.0 --cudnn-enabled --log-file nvlog.json --anneal-steps 500 1000 1500 --anneal-factor 0.1 --amp-run

+ 54 - 31
PyTorch/SpeechSynthesis/Tacotron2/train.py

@@ -236,18 +236,38 @@ def validate(model, criterion, valset, epoch, batch_iter, batch_size,
                                 collate_fn=collate_fn)
 
         val_loss = 0.0
+        num_iters = 0
+        val_items_per_sec = 0.0
         for i, batch in enumerate(val_loader):
-            x, y, len_x = batch_to_gpu(batch)
+            torch.cuda.synchronize()
+            iter_start_time = time.perf_counter()
+
+            x, y, num_items = batch_to_gpu(batch)
             y_pred = model(x)
             loss = criterion(y_pred, y)
             if distributed_run:
                 reduced_val_loss = reduce_tensor(loss.data, world_size).item()
+                reduced_num_items = reduce_tensor(num_items.data, 1).item()
             else:
                 reduced_val_loss = loss.item()
+                reduced_num_items = num_items.item()
             val_loss += reduced_val_loss
-        val_loss = val_loss / (i + 1)
 
-        DLLogger.log(step=(epoch, batch_iter, epoch), data={'val_iter_loss': val_loss})
+            torch.cuda.synchronize()
+            iter_stop_time = time.perf_counter()
+            iter_time = iter_stop_time - iter_start_time
+
+            items_per_sec = reduced_num_items/iter_time
+            DLLogger.log(step=(epoch, batch_iter, i), data={'val_items_per_sec': items_per_sec})
+            val_items_per_sec += items_per_sec
+            num_iters += 1
+
+        val_loss = val_loss/(i + 1)
+
+        DLLogger.log(step=(epoch,), data={'val_loss': val_loss})
+        DLLogger.log(step=(epoch,), data={'val_items_per_sec':
+                                         (val_items_per_sec/num_iters if num_iters > 0 else 0.0)})
+
         return val_loss
 
 def adjust_learning_rate(iteration, epoch, optimizer, learning_rate,
@@ -307,8 +327,8 @@ def main():
     if distributed_run:
         init_distributed(args, world_size, local_rank, args.group_name)
 
-    run_start_time = time.time()
-    DLLogger.log(step=tuple(), data={'run_start': run_start_time})
+    torch.cuda.synchronize()
+    run_start_time = time.perf_counter()
 
     model_config = models.get_model_config(model_name, args)
     model = models.get_model(model_name, model_config,
@@ -350,8 +370,14 @@ def main():
         model_name, n_frames_per_step)
     trainset = data_functions.get_data_loader(
         model_name, args.dataset_path, args.training_files, args)
-    train_sampler = DistributedSampler(trainset) if distributed_run else None
-    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
+    if distributed_run:
+        train_sampler = DistributedSampler(trainset)
+        shuffle = False
+    else:
+        train_sampler = None
+        shuffle = True
+
+    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                               sampler=train_sampler,
                               batch_size=args.batch_size, pin_memory=False,
                               drop_last=True, collate_fn=collate_fn)
@@ -362,21 +388,21 @@ def main():
     batch_to_gpu = data_functions.get_batch_to_gpu(model_name)
 
     iteration = 0
-    train_epoch_avg_items_per_sec = 0.0
+    train_epoch_items_per_sec = 0.0
     val_loss = 0.0
     num_iters = 0
 
     model.train()
 
     for epoch in range(start_epoch, args.epochs):
-        epoch_start_time = time.time()
-        DLLogger.log(step=(epoch,) , data={'train_epoch_start': epoch_start_time})
+        torch.cuda.synchronize()
+        epoch_start_time = time.perf_counter()
         # used to calculate avg items/sec over epoch
         reduced_num_items_epoch = 0
 
         # used to calculate avg loss over epoch
         train_epoch_avg_loss = 0.0
-        train_epoch_avg_items_per_sec = 0.0
+        train_epoch_items_per_sec = 0.0
 
         num_iters = 0
 
@@ -387,12 +413,11 @@ def main():
             train_loader.sampler.set_epoch(epoch)
 
         for i, batch in enumerate(train_loader):
-            iter_start_time = time.time()
+            torch.cuda.synchronize()
+            iter_start_time = time.perf_counter()
             DLLogger.log(step=(epoch, i),
                          data={'glob_iter/iters_per_epoch': str(iteration)+"/"+str(len(train_loader))})
-            DLLogger.log(step=(epoch, i), data={'train_iter_start': iter_start_time})
 
-            start = time.perf_counter()
             adjust_learning_rate(iteration, epoch, optimizer, args.learning_rate,
                                  args.anneal_steps, args.anneal_factor, local_rank)
 
@@ -411,7 +436,7 @@ def main():
             if np.isnan(reduced_loss):
                 raise Exception("loss is NaN")
 
-            DLLogger.log(step=(epoch,i), data={'train_iter_loss': reduced_loss})
+            DLLogger.log(step=(epoch,i), data={'train_loss': reduced_loss})
 
             train_epoch_avg_loss += reduced_loss
             num_iters += 1
@@ -431,25 +456,24 @@ def main():
 
             optimizer.step()
 
-            iter_stop_time = time.time()
+            torch.cuda.synchronize()
+            iter_stop_time = time.perf_counter()
             iter_time = iter_stop_time - iter_start_time
             items_per_sec = reduced_num_items/iter_time
-            train_epoch_avg_items_per_sec += items_per_sec
+            train_epoch_items_per_sec += items_per_sec
 
-            DLLogger.log(step=(epoch, i), data={'train_iter_items/sec': items_per_sec})
-            DLLogger.log(step=(epoch, i), data={'train_iter_stop': iter_stop_time})
+            DLLogger.log(step=(epoch, i), data={'train_items_per_sec': items_per_sec})
             DLLogger.log(step=(epoch, i), data={'train_iter_time': iter_time})
             iteration += 1
 
-
-        epoch_stop_time = time.time()
+        torch.cuda.synchronize()
+        epoch_stop_time = time.perf_counter()
         epoch_time = epoch_stop_time - epoch_start_time
 
-        DLLogger.log(step=(epoch,), data={'train_epoch_items/sec': reduced_num_items_epoch/epoch_time})
-        DLLogger.log(step=(epoch,), data={'train_epoch_avg_items/sec':
-                                          (train_epoch_avg_items_per_sec/num_iters if num_iters > 0 else 0.0)})
-        DLLogger.log(step=(epoch,), data={'train_epoch_avg_loss': (train_epoch_avg_loss/num_iters if num_iters > 0 else 0.0)})
-        DLLogger.log(step=(epoch,), data={'epoch_time': epoch_time})
+        DLLogger.log(step=(epoch,), data={'train_items_per_sec':
+                                          (train_epoch_items_per_sec/num_iters if num_iters > 0 else 0.0)})
+        DLLogger.log(step=(epoch,), data={'train_loss': (train_epoch_avg_loss/num_iters if num_iters > 0 else 0.0)})
+        DLLogger.log(step=(epoch,), data={'train_epoch_time': epoch_time})
 
         val_loss = validate(model, criterion, valset, epoch, i,
                             args.batch_size, world_size, collate_fn,
@@ -463,14 +487,13 @@ def main():
         if local_rank == 0:
             DLLogger.flush()
 
-
-    run_stop_time = time.time()
-    DLLogger.log(step=tuple(), data={'run_stop': run_start_time})
+    torch.cuda.synchronize()
+    run_stop_time = time.perf_counter()
     run_time = run_stop_time - run_start_time
     DLLogger.log(step=tuple(), data={'run_time': run_time})
-    DLLogger.log(step=tuple(), data={'train_items_per_sec':
-                                     (train_epoch_avg_items_per_sec/num_iters if num_iters > 0 else 0.0)})
     DLLogger.log(step=tuple(), data={'val_loss': val_loss})
+    DLLogger.log(step=tuple(), data={'train_items_per_sec':
+                                     (train_epoch_items_per_sec/num_iters if num_iters > 0 else 0.0)})
 
     if local_rank == 0:
         DLLogger.flush()

+ 7 - 4
PyTorch/SpeechSynthesis/Tacotron2/trt/inference_trt.py

@@ -265,9 +265,13 @@ def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements):
     z_size = (mel_size-1)*stride+(kernel_size-1)+1
     z_size = z_size - (kernel_size-stride)
     z_size = z_size//n_group
-    z = torch.randn(batch_size, n_group, z_size, 1).cuda().float()
+    z = torch.randn(batch_size, n_group, z_size, 1).cuda()
+    audios = torch.zeros(batch_size, mel_size*stride).cuda()
 
-    audios = torch.zeros(batch_size, mel_size*256).cuda()
+    if "HALF" in str(waveglow.get_binding_dtype(waveglow.get_binding_index("mel"))):
+        z = z.half()
+        mel = mel.half()
+        audios = audios.half()
 
     waveglow_tensors = {
         # inputs
@@ -330,7 +334,6 @@ def main():
     measurements = {}
 
     sequences, sequence_lengths = prepare_input_sequence(texts)
-    print("|||sequence_lengths", sequence_lengths)
     sequences = sequences.to(torch.int32)
     sequence_lengths = sequence_lengths.to(torch.int32)
     with MeasureTime(measurements, "latency"):
@@ -342,7 +345,7 @@ def main():
     with encoder_context, decoder_context,  postnet_context, waveglow_context:
         pass
 
-    audios.float()
+    audios = audios.float()
     if args.waveglow_ckpt != "":
         with MeasureTime(measurements, "denoiser"):
             audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)

+ 1 - 6
PyTorch/SpeechSynthesis/Tacotron2/trt/run_latency_tests_trt.sh

@@ -2,10 +2,5 @@
 
 for i in {1..1003}
 do
-    python trt/inference_trt.py -i ./phrases/phrase_1_128.txt --encoder ./output/encoder_fp32.engine --decoder ./output/decoder_iter_fp32.engine --postnet ./output/postnet_fp32.engine  --waveglow ./output/waveglow_fp32.engine -o output_1/ >> tmp_log_bs1_fp32.log 2>&1
-done
-
-for i in {1..1003}
-do
-    python trt/inference_trt.py -i ./phrases/phrase_1_128.txt --encoder ./output/encoder_fp16.engine --decoder ./output/decoder_iter_fp16.engine --postnet ./output/postnet_fp16.engine  --waveglow ./output/waveglow_fp16.engine -o output_1/ >> tmp_log_bs1_fp16.log 2>&1
+    python trt/inference_trt.py -i ./phrases/phrase_1_128.txt --encoder ./output/encoder_fp16.engine --decoder ./output/decoder_iter_fp16.engine --postnet ./output/postnet_fp16.engine  --waveglow ./output/waveglow_fp16.engine -o output/ --fp16 >> tmp_log_bs1_fp16.log 2>&1
 done

+ 0 - 181
PyTorch/SpeechSynthesis/Tacotron2/trt/test_infer_trt.py

@@ -1,181 +0,0 @@
-# *****************************************************************************
-#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
-#
-#  Redistribution and use in source and binary forms, with or without
-#  modification, are permitted provided that the following conditions are met:
-#      * Redistributions of source code must retain the above copyright
-#        notice, this list of conditions and the following disclaimer.
-#      * Redistributions in binary form must reproduce the above copyright
-#        notice, this list of conditions and the following disclaimer in the
-#        documentation and/or other materials provided with the distribution.
-#      * Neither the name of the NVIDIA CORPORATION nor the
-#        names of its contributors may be used to endorse or promote products
-#        derived from this software without specific prior written permission.
-#
-#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-# *****************************************************************************
-
-import torch
-import argparse
-import numpy as np
-from scipy.io.wavfile import write
-import tensorrt as trt
-import sys
-sys.path.append('./')
-
-import time
-import dllogger as DLLogger
-from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
-
-from apex import amp
-
-from inference import MeasureTime, prepare_input_sequence
-from test_infer import print_stats
-from trt.inference_trt import infer_tacotron2_trt, infer_waveglow_trt
-from trt.trt_utils import load_engine
-import models
-
-from test_infer import print_stats
-
-def parse_args(parser):
-    """
-    Parse commandline arguments.
-    """
-    parser.add_argument('--encoder', type=str, required=True,
-                        help='full path to the Encoder TRT engine')
-    parser.add_argument('--decoder', type=str, required=True,
-                        help='full path to the DecoderIter TRT engine')
-    parser.add_argument('--postnet', type=str, required=True,
-                        help='full path to the Postnet TRT engine')
-    parser.add_argument('--waveglow', type=str, required=True,
-                        help='full path to the WaveGlow TRT engine')
-    parser.add_argument('-s', '--sigma-infer', default=0.6, type=float)
-    parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
-                        help='Sampling rate')
-    parser.add_argument('--fp16', action='store_true',
-                        help='inference ')
-    parser.add_argument('--log-file', type=str, default='nvlog.json',
-                        help='Filename for logging')
-    parser.add_argument('--stft-hop-length', type=int, default=256,
-                        help='STFT hop length for estimating audio length from mel size')
-    parser.add_argument('--num-iters', type=int, default=10,
-                        help='Number of iterations')
-    parser.add_argument('-il', '--input-length', type=int, default=64,
-                        help='Input length')
-    parser.add_argument('-bs', '--batch-size', type=int, default=1,
-                        help='Batch size')
-
-    return parser
-
-
-def main():
-    """
-    Launches text to speech (inference).
-    Inference is executed on a single GPU.
-    """
-    parser = argparse.ArgumentParser(
-        description='PyTorch Tacotron 2 Inference')
-    parser = parse_args(parser)
-    args, unknown_args = parser.parse_known_args()
-
-    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
-    encoder = load_engine(args.encoder, TRT_LOGGER)
-    decoder_iter = load_engine(args.decoder, TRT_LOGGER)
-    postnet = load_engine(args.postnet, TRT_LOGGER)
-    waveglow = load_engine(args.waveglow, TRT_LOGGER)
-
-    # initialize CUDA state
-    torch.cuda.init()
-    # create TRT contexts for each engine
-    encoder_context = encoder.create_execution_context()
-    decoder_context = decoder_iter.create_execution_context()
-    postnet_context = postnet.create_execution_context()
-    waveglow_context = waveglow.create_execution_context()
-
-    DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, args.log_file),
-                            StdOutBackend(Verbosity.VERBOSE)])
-    for k,v in vars(args).items():
-        DLLogger.log(step="PARAMETER", data={k:v})
-    DLLogger.log(step="PARAMETER", data={'model_name':'Tacotron2_PyT'})
-
-    measurements_all = {"pre_processing": [],
-                        "tacotron2_latency": [],
-                        "waveglow_latency": [],
-                        "latency": [],
-                        "type_conversion": [],
-                        "data_transfer": [],
-                        "storage": [],
-                        "tacotron2_items_per_sec": [],
-                        "waveglow_items_per_sec": [],
-                        "num_mels_per_audio": [],
-                        "throughput": []}
-
-    texts = ["The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."]
-    texts = [texts[0][:args.input_length]]
-    texts = texts*args.batch_size
-
-    warmup_iters = 3
-
-    for iter in range(args.num_iters):
-
-        measurements = {}
-
-        with MeasureTime(measurements, "pre_processing"):
-            sequences_padded, input_lengths = prepare_input_sequence(texts)
-
-        with torch.no_grad():
-            with MeasureTime(measurements, "latency"):
-                with MeasureTime(measurements, "tacotron2_latency"):
-                    mel, mel_lengths = infer_tacotron2_trt(encoder, decoder_iter, postnet,
-                                                           encoder_context, decoder_context, postnet_context,
-                                                           sequences_padded, input_lengths, measurements)
-
-                with MeasureTime(measurements, "waveglow_latency"):
-                    audios = infer_waveglow_trt(waveglow, waveglow_context, mel, measurements)
-
-        num_mels = mel.size(0)*mel.size(2)
-        num_samples = audios.size(0)*audios.size(1)
-
-        with MeasureTime(measurements, "type_conversion"):
-            audios = audios.float()
-
-        with MeasureTime(measurements, "data_transfer"):
-            audios = audios.cpu()
-
-        with MeasureTime(measurements, "storage"):
-            audios = audios.numpy()
-            for i, audio in enumerate(audios):
-                audio_path = "audio_"+str(i)+".wav"
-                write(audio_path, args.sampling_rate,
-                      audio[:mel_lengths[i]*args.stft_hop_length])
-
-        measurements['tacotron2_items_per_sec'] = num_mels/measurements['tacotron2_latency']
-        measurements['waveglow_items_per_sec'] = num_samples/measurements['waveglow_latency']
-        measurements['num_mels_per_audio'] = mel.size(2)
-        measurements['throughput'] = num_samples/measurements['latency']
-
-        if iter >= warmup_iters:
-            for k,v in measurements.items():
-                if k in measurements_all.keys():
-                    measurements_all[k].append(v)
-                    DLLogger.log(step=(iter-warmup_iters), data={k: v})
-
-    with encoder_context, decoder_context, postnet_context, waveglow_context:
-        pass
-
-    DLLogger.flush()
-
-    print_stats(measurements_all)
-
-if __name__ == '__main__':
-    main()

+ 40 - 17
PyTorch/SpeechSynthesis/Tacotron2/trt/trt_utils.py

@@ -27,19 +27,6 @@
 
 import tensorrt as trt
 
-def binding_info(engine, context):
-    for i in range(engine.num_bindings):
-        print("|||| binding", i)
-        print("|||| binding_is_input", engine.binding_is_input(i))
-        print("|||| get_binding_dtype", engine.get_binding_dtype(i))
-        print("|||| get_binding_name", engine.get_binding_name(i))
-        print("|||| get_binding_shape", engine.get_binding_shape(i))
-        print("|||| get_binding_vectorized_dim", engine.get_binding_vectorized_dim(i))
-
-    print("|||| all_binding_shapes_specified", context.all_binding_shapes_specified)
-    print("|||| all_shape_inputs_specified", context.all_shape_inputs_specified)
-
-
 def is_dimension_dynamic(dim):
     return dim is None or dim <= 0
 
@@ -60,7 +47,6 @@ def run_trt_engine(context, engine, tensors):
         elif is_shape_dynamic(context.get_binding_shape(idx)):
             context.set_binding_shape(idx, tensor.shape)
 
-    # binding_info(engine, context)
     context.execute_v2(bindings=bindings)
 
 
@@ -70,6 +56,45 @@ def load_engine(engine_filepath, trt_logger):
     return engine
 
 
+def engine_info(engine_filepath):
+
+    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+    engine = load_engine(engine_filepath, TRT_LOGGER)
+
+    binding_template = r"""
+{btype} {{
+  name: "{bname}"
+  data_type: {dtype}
+  dims: {dims}
+}}"""
+    type_mapping = {"DataType.HALF": "TYPE_FP16",
+    "DataType.FLOAT": "TYPE_FP32",
+    "DataType.INT32": "TYPE_INT32"}
+
+    print("engine name", engine.name)
+    print("has_implicit_batch_dimension", engine.has_implicit_batch_dimension)
+    start_dim = 0 if engine.has_implicit_batch_dimension else 1
+    print("num_optimization_profiles", engine.num_optimization_profiles)
+    print("max_batch_size:", engine.max_batch_size)
+    print("device_memory_size:", engine.device_memory_size)
+    print("max_workspace_size:", engine.max_workspace_size)
+    print("num_layers:", engine.num_layers)
+
+    for i in range(engine.num_bindings):
+        btype = "input" if engine.binding_is_input(i) else "output"
+        bname = engine.get_binding_name(i)
+        dtype = engine.get_binding_dtype(i)
+        bdims = engine.get_binding_shape(i)
+        config_values = {
+            "btype": btype,
+            "bname": bname,
+            "dtype": type_mapping[str(dtype)],
+            "dims": list(bdims[start_dim:])
+        }
+        final_binding_str = binding_template.format_map(config_values)
+        print(final_binding_str)
+
+
 def build_engine(model_file, shapes, max_ws=512*1024*1024, fp16=False):
     TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
     builder = trt.Builder(TRT_LOGGER)
@@ -90,9 +115,7 @@ def build_engine(model_file, shapes, max_ws=512*1024*1024, fp16=False):
         with open(model_file, 'rb') as model:
             parsed = parser.parse(model.read())
             for i in range(parser.num_errors):
-                e = parser.get_error(i)
+                print("TensorRT ONNX parser error:", parser.get_error(i))
             engine = builder.build_engine(network, config=config)
 
             return engine
-
-

+ 1 - 46
PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py

@@ -78,7 +78,7 @@ class Invertible1x1Conv(torch.nn.Module):
             return z
         else:
             # Forward computation
-            log_det_W = batch_size * n_of_groups * torch.logdet(W.float())
+            log_det_W = batch_size * n_of_groups * torch.logdet(W.unsqueeze(0).float()).squeeze()
             z = self.conv(z)
             return z, log_det_W
 
@@ -273,51 +273,6 @@ class WaveGlow(torch.nn.Module):
         return audio
 
 
-    def infer_onnx(self, spect, z, sigma=0.9):
-
-        spect = self.upsample(spect)
-        # trim conv artifacts. maybe pad spec to kernel multiple
-        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
-        spect = spect[:, :, :-time_cutoff]
-
-        length_spect_group = spect.size(2)//8
-        mel_dim = 80
-        batch_size = spect.size(0)
-
-        spect = torch.squeeze(spect, 3)
-        spect = spect.view((batch_size, mel_dim, length_spect_group, self.n_group))
-        spect = spect.permute(0, 2, 1, 3)
-        spect = spect.contiguous()
-        spect = spect.view((batch_size, length_spect_group, self.n_group*mel_dim))
-        spect = spect.permute(0, 2, 1)
-        spect = torch.unsqueeze(spect, 3)
-
-        audio = z[:, :self.n_remaining_channels, :, :]
-        z = z[:, self.n_remaining_channels:self.n_group, :, :]
-        audio = sigma*audio
-
-        for k in reversed(range(self.n_flows)):
-            n_half = int(audio.size(1) / 2)
-            audio_0 = audio[:, :n_half, :, :]
-            audio_1 = audio[:, n_half:(n_half+n_half), :, :]
-
-            output = self.WN[k]((audio_0, spect))
-            s = output[:, n_half:(n_half+n_half), :, :]
-            b = output[:, :n_half, :, :]
-            audio_1 = (audio_1 - b) / torch.exp(s)
-            audio = torch.cat([audio_0, audio_1], 1)
-
-            audio = self.convinv[k](audio)
-
-            if k % self.n_early_every == 0 and k > 0:
-                audio = torch.cat((z[:, :self.n_early_size, :, :], audio), 1)
-                z = z[:, self.n_early_size:self.n_group, :, :]
-
-        audio = torch.squeeze(audio, 3)
-        audio = audio.permute(0,2,1).contiguous().view(batch_size, (length_spect_group * self.n_group))
-
-        return audio
-
     @staticmethod
     def remove_weightnorm(model):
         waveglow = model