# perf_infer.py
  1. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  2. # Redistribution and use in source and binary forms, with or without
  3. # modification, are permitted provided that the following conditions are met:
  4. # * Redistributions of source code must retain the above copyright
  5. # notice, this list of conditions and the following disclaimer.
  6. # * Redistributions in binary form must reproduce the above copyright
  7. # notice, this list of conditions and the following disclaimer in the
  8. # documentation and/or other materials provided with the distribution.
  9. # * Neither the name of the NVIDIA CORPORATION nor the
  10. # names of its contributors may be used to endorse or promote products
  11. # derived from this software without specific prior written permission.
  12. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  14. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  15. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  16. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  17. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  18. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  19. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  20. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  21. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  22. import pprint
  23. import sys
  24. import time
  25. import fire
  26. import torch
  27. from tqdm import tqdm
  28. from fastspeech import DEFAULT_DEVICE
  29. from fastspeech import hparam as hp
  30. from fastspeech.data_load import PadDataLoader
  31. from fastspeech.dataset.ljspeech_dataset import LJSpeechDataset
  32. from fastspeech.model.fastspeech import Fastspeech
  33. from fastspeech.utils.logging import tprint
  34. from fastspeech.utils.pytorch import to_cpu_numpy, to_device_async
  35. from fastspeech.infer import get_inferencer
  36. from fastspeech.inferencer.waveglow_inferencer import WaveGlowInferencer
  37. from contextlib import ExitStack
  38. from fastspeech.dataset.text_dataset import TextDataset
  39. import numpy as np
  40. try:
  41. from apex import amp
  42. except ImportError:
  43. ImportError('Required to install apex.')
# Pretty-printer used to dump the resolved hyperparameter set in one wide line per entry.
pp = pprint.PrettyPrinter(indent=4, width=1000)

# Fixed benchmark sentence; long enough to be truncated to a constant-length input.
SAMPLE_TEXT = "The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."

# Number of characters of SAMPLE_TEXT fed to the model for every batch element.
INPUT_LEN = 128
INPUT_TEXT = SAMPLE_TEXT[:INPUT_LEN]

# Leading iterations excluded from the latency/throughput statistics.
WARMUP_ITERS = 3
  49. def perf_inference(hparam="infer.yaml",
  50. with_vocoder=False,
  51. n_iters=None,
  52. device=DEFAULT_DEVICE,
  53. **kwargs):
  54. """The script for estimating inference performance.
  55. By default, this script assumes to load parameters in the default config file, fastspeech/hparams/infer.yaml.
  56. Besides the flags, you can also set parameters in the config file via the command-line. For examples,
  57. --dataset_path=DATASET_PATH
  58. Path to dataset directory.
  59. --checkpoint_path=CHECKPOINT_PATH
  60. Path to checkpoint directory. The latest checkpoint will be loaded.
  61. --batch_size=BATCH_SIZE
  62. Batch size to use. Defaults to 1.
  63. Refer to fastspeech/hparams/infer.yaml to see more parameters.
  64. Args:
  65. hparam (str, optional): Path to default config file. Defaults to "infer.yaml".
  66. with_vocoder (bool, optional): Whether or not to estimate with a vocoder. Defaults to False.
  67. n_iters (int, optional): Number of batches to estimate. Defaults to None (an epoch).
  68. device (str, optional): Device to use. Defaults to "cuda" if avaiable, or "cpu".
  69. """
  70. hp.set_hparam(hparam, kwargs)
  71. tprint("Hparams:\n{}".format(pp.pformat(hp)))
  72. tprint("Device count: {}".format(torch.cuda.device_count()))
  73. model = Fastspeech(
  74. max_seq_len=hp.max_seq_len,
  75. d_model=hp.d_model,
  76. phoneme_side_n_layer=hp.phoneme_side_n_layer,
  77. phoneme_side_head=hp.phoneme_side_head,
  78. phoneme_side_conv1d_filter_size=hp.phoneme_side_conv1d_filter_size,
  79. phoneme_side_output_size=hp.phoneme_side_output_size,
  80. mel_side_n_layer=hp.mel_side_n_layer,
  81. mel_side_head=hp.mel_side_head,
  82. mel_side_conv1d_filter_size=hp.mel_side_conv1d_filter_size,
  83. mel_side_output_size=hp.mel_side_output_size,
  84. duration_predictor_filter_size=hp.duration_predictor_filter_size,
  85. duration_predictor_kernel_size=hp.duration_predictor_kernel_size,
  86. fft_conv1d_kernel=hp.fft_conv1d_kernel,
  87. fft_conv1d_padding=hp.fft_conv1d_padding,
  88. dropout=hp.dropout,
  89. n_mels=hp.num_mels,
  90. fused_layernorm=hp.fused_layernorm
  91. )
  92. dataset_size = hp.batch_size * (n_iters if n_iters else 1)
  93. tprint("Dataset size: {}".format(dataset_size))
  94. dataset = TextDataset([INPUT_TEXT] * (dataset_size + (WARMUP_ITERS * hp.batch_size)))
  95. data_loader = PadDataLoader(dataset,
  96. batch_size=hp.batch_size,
  97. num_workers=hp.n_workers,
  98. shuffle=False if hp.use_trt and hp.trt_multi_engine else True,
  99. drop_last=True,
  100. )
  101. fs_inferencer = get_inferencer(model, data_loader, device)
  102. if with_vocoder:
  103. if hp.use_trt:
  104. from fastspeech.trt.waveglow_trt_inferencer import WaveGlowTRTInferencer
  105. wb_inferencer = WaveGlowTRTInferencer(ckpt_file=hp.waveglow_path, engine_file=hp.waveglow_engine_path, use_fp16=hp.use_fp16)
  106. else:
  107. wb_inferencer = WaveGlowInferencer(ckpt_file=hp.waveglow_path, device=device, use_fp16=hp.use_fp16)
  108. with fs_inferencer, wb_inferencer if with_vocoder else ExitStack():
  109. tprint("Perf started. Batch size={}.".format(hp.batch_size))
  110. latencies = []
  111. throughputs = []
  112. for i in tqdm(range(len(data_loader))):
  113. start = time.time()
  114. outputs = fs_inferencer.infer()
  115. mels = outputs['mel']
  116. mel_masks = outputs['mel_mask']
  117. assert(mels.is_cuda)
  118. if with_vocoder:
  119. # remove padding
  120. max_len = mel_masks.sum(axis=1).max()
  121. mels = mels[..., :max_len]
  122. mel_masks = mel_masks[..., :max_len]
  123. with torch.no_grad():
  124. wavs = wb_inferencer.infer(mels)
  125. wavs = to_cpu_numpy(wavs)
  126. else:
  127. # include time for DtoH copy
  128. to_cpu_numpy(mels)
  129. to_cpu_numpy(mel_masks)
  130. end = time.time()
  131. if i > WARMUP_ITERS-1:
  132. time_elapsed = end - start
  133. generated_samples = len(mel_masks.nonzero()) * hp.hop_len
  134. throughput = generated_samples / time_elapsed
  135. latencies.append(time_elapsed)
  136. throughputs.append(throughput)
  137. latencies.sort()
  138. avg_latency = np.mean(latencies)
  139. std_latency = np.std(latencies)
  140. latency_90 = max(latencies[:int(len(latencies)*0.90)]) if n_iters > 1 else 0
  141. latency_95 = max(latencies[:int(len(latencies)*0.95)]) if n_iters > 1 else 0
  142. latency_99 = max(latencies[:int(len(latencies)*0.99)]) if n_iters > 1 else 0
  143. throughput = np.mean(throughputs)
  144. rtf = throughput / (hp.sr * hp.batch_size)
  145. tprint("Batch size\tPrecision\tAvg Latency(s)\tStd Latency(s)\tLatency 90%(s)\tLatency 95%(s)\tLatency 99%(s)\tThroughput(samples/s)\tAvg RTF\n\
  146. {}\t{}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{}\t{:.2f}".format(
  147. hp.batch_size,
  148. "FP16" if hp.use_fp16 else "FP32",
  149. avg_latency,
  150. std_latency,
  151. latency_90,
  152. latency_95,
  153. latency_99,
  154. int(throughput),
  155. rtf))
if __name__ == '__main__':
    # Expose perf_inference's parameters (and **kwargs hparam overrides) as CLI flags.
    fire.Fire(perf_inference)