infer.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  2. # Redistribution and use in source and binary forms, with or without
  3. # modification, are permitted provided that the following conditions are met:
  4. # * Redistributions of source code must retain the above copyright
  5. # notice, this list of conditions and the following disclaimer.
  6. # * Redistributions in binary form must reproduce the above copyright
  7. # notice, this list of conditions and the following disclaimer in the
  8. # documentation and/or other materials provided with the distribution.
  9. # * Neither the name of the NVIDIA CORPORATION nor the
  10. # names of its contributors may be used to endorse or promote products
  11. # derived from this software without specific prior written permission.
  12. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  14. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  15. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  16. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  17. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  18. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  19. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  20. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  21. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  22. import fire
  23. from fastspeech import hparam as hp, DEFAULT_DEVICE
  24. from fastspeech.dataset.ljspeech_dataset import LJSpeechDataset
  25. from fastspeech.inferencer.fastspeech_inferencer import FastSpeechInferencer
  26. from fastspeech.model.fastspeech import Fastspeech
  27. from fastspeech.data_load import PadDataLoader
  28. from fastspeech.utils.logging import tprint
  29. import torch
  30. import pprint
  31. from fastspeech.utils.time import TimeElapsed
  32. # import multiprocessing
  33. # multiprocessing.set_start_method('spawn', True)
  34. pp = pprint.PrettyPrinter(indent=4, width=1000)
  35. def infer(hparam="infer.yaml",
  36. device=DEFAULT_DEVICE,
  37. n_iters=1,
  38. **kwargs):
  39. """ The FastSpeech model inference script.
  40. By default, this script assumes to load parameters in the default config file, fastspeech/hparams/infer.yaml.
  41. Besides the flags, you can also set parameters in the config file via the command-line. For examples,
  42. --dataset_path=DATASET_PATH
  43. Path to dataset directory.
  44. --checkpoint_path=CHECKPOINT_PATH
  45. Path to checkpoint directory. The latest checkpoint will be loaded.
  46. --batch_size=BATCH_SIZE
  47. Batch size to use. Defaults to 1.
  48. Refer to fastspeech/hparams/infer.yaml to see more parameters.
  49. Args:
  50. hparam (str, optional): Path to default config file. Defaults to "infer.yaml".
  51. device (str, optional): Device to use. Defaults to "cuda" if avaiable, or "cpu".
  52. n_iters (int, optional): Number of batches to infer. Defaults to 1.
  53. """
  54. hp.set_hparam(hparam, kwargs)
  55. tprint("Hparams:\n{}".format(pp.pformat(hp)))
  56. tprint("Device count: {}".format(torch.cuda.device_count()))
  57. # model
  58. model = Fastspeech(
  59. max_seq_len=hp.max_seq_len,
  60. d_model=hp.d_model,
  61. phoneme_side_n_layer=hp.phoneme_side_n_layer,
  62. phoneme_side_head=hp.phoneme_side_head,
  63. phoneme_side_conv1d_filter_size=hp.phoneme_side_conv1d_filter_size,
  64. phoneme_side_output_size=hp.phoneme_side_output_size,
  65. mel_side_n_layer=hp.mel_side_n_layer,
  66. mel_side_head=hp.mel_side_head,
  67. mel_side_conv1d_filter_size=hp.mel_side_conv1d_filter_size,
  68. mel_side_output_size=hp.mel_side_output_size,
  69. duration_predictor_filter_size=hp.duration_predictor_filter_size,
  70. duration_predictor_kernel_size=hp.duration_predictor_kernel_size,
  71. fft_conv1d_kernel=hp.fft_conv1d_kernel,
  72. fft_conv1d_padding=hp.fft_conv1d_padding,
  73. dropout=hp.dropout,
  74. n_mels=hp.num_mels,
  75. fused_layernorm=hp.fused_layernorm
  76. )
  77. dataset = LJSpeechDataset(root_path=hp.dataset_path,
  78. meta_file=hp.meta_file,
  79. sr=hp.sr,
  80. n_fft=hp.n_fft,
  81. win_len=hp.win_len,
  82. hop_len=hp.hop_len,
  83. n_mels=hp.num_mels,
  84. mel_fmin=hp.mel_fmin,
  85. mel_fmax=hp.mel_fmax,
  86. exclude_mels=True,
  87. sort_by_length=True if hp.use_trt and hp.trt_multi_engine else False
  88. )
  89. tprint("Dataset size: {}".format(len(dataset)))
  90. data_loader = PadDataLoader(dataset,
  91. batch_size=hp.batch_size,
  92. num_workers=hp.n_workers,
  93. shuffle=False if hp.use_trt and hp.trt_multi_engine else True,
  94. drop_last=True,
  95. )
  96. inferencer = get_inferencer(model, data_loader, device)
  97. try:
  98. n_iters = min(len(data_loader), n_iters) if n_iters else len(data_loader)
  99. tprint("Num of iters: {}".format(n_iters))
  100. with inferencer:
  101. for i in range(n_iters):
  102. tprint("------------- INFERENCE : batch #{} -------------".format(i))
  103. with TimeElapsed(name="Inference Time", cuda_sync=True):
  104. out_batch = inferencer.infer()
  105. # tprint("Output:\n{}".format(pp.pformat(out_batch)))
  106. tprint("Inference has been done.")
  107. except KeyboardInterrupt:
  108. tprint("Inference has been canceled.")
  109. def get_inferencer(model, data_loader, device):
  110. if hp.use_trt:
  111. if hp.trt_multi_engine:
  112. from fastspeech.trt.fastspeech_trt_multi_engine_inferencer import FastSpeechTRTMultiEngineInferencer
  113. inferencer = FastSpeechTRTMultiEngineInferencer('fastspeech',
  114. model,
  115. data_loader=data_loader,
  116. ckpt_path=hp.checkpoint_path,
  117. trt_max_ws_size=hp.trt_max_ws_size,
  118. trt_force_build=hp.trt_force_build,
  119. use_fp16=hp.use_fp16,
  120. trt_file_path_list=hp.trt_file_path_list,
  121. trt_max_input_seq_len_list=hp.trt_max_input_seq_len_list,
  122. trt_max_output_seq_len_list=hp.trt_max_output_seq_len_list,
  123. )
  124. else:
  125. from fastspeech.trt.fastspeech_trt_inferencer import FastSpeechTRTInferencer
  126. inferencer = FastSpeechTRTInferencer('fastspeech',
  127. model,
  128. data_loader=data_loader,
  129. ckpt_path=hp.checkpoint_path,
  130. trt_max_ws_size=hp.trt_max_ws_size,
  131. trt_file_path=hp.trt_file_path,
  132. use_fp16=hp.use_fp16,
  133. trt_force_build=hp.trt_force_build,
  134. trt_max_input_seq_len=hp.trt_max_input_seq_len,
  135. trt_max_output_seq_len=hp.trt_max_output_seq_len,
  136. )
  137. else:
  138. inferencer = FastSpeechInferencer(
  139. 'fastspeech',
  140. model,
  141. data_loader=data_loader,
  142. ckpt_path=hp.checkpoint_path,
  143. log_path=hp.log_path,
  144. device=device,
  145. use_fp16=hp.use_fp16)
  146. return inferencer
  147. if __name__ == '__main__':
  148. torch.backends.cudnn.enabled = True
  149. torch.backends.cudnn.benchmark = False
  150. fire.Fire(infer)