# perf_infer_ljspeech.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of the NVIDIA CORPORATION nor the
#    names of its contributors may be used to endorse or promote products
#    derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  22. import pprint
  23. import sys
  24. import time
  25. import fire
  26. import torch
  27. from tqdm import tqdm
  28. from fastspeech import DEFAULT_DEVICE
  29. from fastspeech import hparam as hp
  30. from fastspeech.data_load import PadDataLoader
  31. from fastspeech.dataset.ljspeech_dataset import LJSpeechDataset
  32. from fastspeech.model.fastspeech import Fastspeech
  33. from fastspeech.utils.logging import tprint
  34. from fastspeech.utils.pytorch import to_cpu_numpy, to_device_async
  35. from fastspeech.infer import get_inferencer
  36. from fastspeech.inferencer.waveglow_inferencer import WaveGlowInferencer
  37. from contextlib import ExitStack
  38. import numpy as np
  39. try:
  40. from apex import amp
  41. except ImportError:
  42. ImportError('Required to install apex.')
  43. pp = pprint.PrettyPrinter(indent=4, width=1000)
  44. WARMUP_ITERS = 3
  45. def perf_inference(hparam="infer.yaml",
  46. with_vocoder=False,
  47. n_iters=None,
  48. device=DEFAULT_DEVICE,
  49. **kwargs):
  50. """The script for estimating inference performance.
  51. By default, this script assumes to load parameters in the default config file, fastspeech/hparams/infer.yaml.
  52. Besides the flags, you can also set parameters in the config file via the command-line. For examples,
  53. --dataset_path=DATASET_PATH
  54. Path to dataset directory.
  55. --checkpoint_path=CHECKPOINT_PATH
  56. Path to checkpoint directory. The latest checkpoint will be loaded.
  57. --batch_size=BATCH_SIZE
  58. Batch size to use. Defaults to 1.
  59. Refer to fastspeech/hparams/infer.yaml to see more parameters.
  60. Args:
  61. hparam (str, optional): Path to default config file. Defaults to "infer.yaml".
  62. with_vocoder (bool, optional): Whether or not to estimate with a vocoder. Defaults to False.
  63. n_iters (int, optional): Number of batches to estimate. Defaults to None (an epoch).
  64. device (str, optional): Device to use. Defaults to "cuda" if avaiable, or "cpu".
  65. """
  66. hp.set_hparam(hparam, kwargs)
  67. tprint("Hparams:\n{}".format(pp.pformat(hp)))
  68. tprint("Device count: {}".format(torch.cuda.device_count()))
  69. model = Fastspeech(
  70. max_seq_len=hp.max_seq_len,
  71. d_model=hp.d_model,
  72. phoneme_side_n_layer=hp.phoneme_side_n_layer,
  73. phoneme_side_head=hp.phoneme_side_head,
  74. phoneme_side_conv1d_filter_size=hp.phoneme_side_conv1d_filter_size,
  75. phoneme_side_output_size=hp.phoneme_side_output_size,
  76. mel_side_n_layer=hp.mel_side_n_layer,
  77. mel_side_head=hp.mel_side_head,
  78. mel_side_conv1d_filter_size=hp.mel_side_conv1d_filter_size,
  79. mel_side_output_size=hp.mel_side_output_size,
  80. duration_predictor_filter_size=hp.duration_predictor_filter_size,
  81. duration_predictor_kernel_size=hp.duration_predictor_kernel_size,
  82. fft_conv1d_kernel=hp.fft_conv1d_kernel,
  83. fft_conv1d_padding=hp.fft_conv1d_padding,
  84. dropout=hp.dropout,
  85. n_mels=hp.num_mels,
  86. fused_layernorm=hp.fused_layernorm
  87. )
  88. dataset = LJSpeechDataset(root_path=hp.dataset_path,
  89. sr=hp.sr,
  90. n_fft=hp.n_fft,
  91. win_len=hp.win_len,
  92. hop_len=hp.hop_len,
  93. n_mels=hp.num_mels,
  94. mel_fmin=hp.mel_fmin,
  95. mel_fmax=hp.mel_fmax,
  96. exclude_mels=True,
  97. sort_by_length=True if hp.use_trt and hp.trt_multi_engine else False
  98. )
  99. tprint("Dataset size: {}".format(len(dataset)))
  100. data_loader = PadDataLoader(dataset,
  101. batch_size=hp.batch_size,
  102. num_workers=hp.n_workers,
  103. shuffle=False if hp.use_trt and hp.trt_multi_engine else True,
  104. drop_last=True,
  105. )
  106. fs_inferencer = get_inferencer(model, data_loader, device)
  107. if with_vocoder:
  108. if hp.use_trt:
  109. from fastspeech.trt.waveglow_trt_inferencer import WaveGlowTRTInferencer
  110. wb_inferencer = WaveGlowTRTInferencer(ckpt_file=hp.waveglow_path, engine_file=hp.waveglow_engine_path, use_fp16=hp.use_fp16)
  111. else:
  112. wb_inferencer = WaveGlowInferencer(ckpt_file=hp.waveglow_path, device=device, use_fp16=hp.use_fp16)
  113. with fs_inferencer, wb_inferencer if with_vocoder else ExitStack():
  114. tprint("Perf started. Batch size={}.".format(hp.batch_size))
  115. latencies = []
  116. throughputs = []
  117. n_iters = min(n_iters, len(data_loader)) if n_iters else len(data_loader)
  118. assert(n_iters > WARMUP_ITERS)
  119. for i in tqdm(range(n_iters)):
  120. start = time.time()
  121. outputs = fs_inferencer.infer()
  122. mels = outputs['mel']
  123. mel_masks = outputs['mel_mask']
  124. assert(mels.is_cuda)
  125. if with_vocoder:
  126. # remove padding
  127. max_len = mel_masks.sum(axis=1).max()
  128. mels = mels[..., :max_len]
  129. mel_masks = mel_masks[..., :max_len]
  130. with torch.no_grad():
  131. wavs = wb_inferencer.infer(mels)
  132. wavs = to_cpu_numpy(wavs)
  133. else:
  134. # include time for DtoH copy
  135. to_cpu_numpy(mels)
  136. to_cpu_numpy(mel_masks)
  137. end = time.time()
  138. if i > WARMUP_ITERS-1:
  139. time_elapsed = end - start
  140. generated_samples = len(mel_masks.nonzero()) * hp.hop_len
  141. throughput = generated_samples / time_elapsed
  142. latencies.append(time_elapsed)
  143. throughputs.append(throughput)
  144. latencies.sort()
  145. avg_latency = np.mean(latencies)
  146. std_latency = np.std(latencies)
  147. latency_90 = max(latencies[:int(len(latencies)*0.90)]) if n_iters > 1 else 0
  148. latency_95 = max(latencies[:int(len(latencies)*0.95)]) if n_iters > 1 else 0
  149. latency_99 = max(latencies[:int(len(latencies)*0.99)]) if n_iters > 1 else 0
  150. throughput = np.mean(throughputs)
  151. rtf = throughput / (hp.sr * hp.batch_size)
  152. tprint("Batch size\tPrecision\tAvg Latency(s)\tStd Latency(s)\tLatency 90%(s)\tLatency 95%(s)\tLatency 99%(s)\tThroughput(samples/s)\tAvg RTF\n\
  153. {}\t{}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{}\t{:.2f}".format(
  154. hp.batch_size,
  155. "FP16" if hp.use_fp16 else "FP32",
  156. avg_latency,
  157. std_latency,
  158. latency_90,
  159. latency_95,
  160. latency_99,
  161. int(throughput),
  162. rtf))
  163. if __name__ == '__main__':
  164. fire.Fire(perf_inference)