# generate.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  22. import os
  23. import pathlib
  24. import sys
  25. import time
  26. import fire
  27. import librosa
  28. import torch
  29. from fastspeech.data_load import PadDataLoader
  30. from fastspeech.dataset.text_dataset import TextDataset
  31. from fastspeech.inferencer.fastspeech_inferencer import FastSpeechInferencer
  32. from fastspeech.model.fastspeech import Fastspeech
  33. from fastspeech import hparam as hp, DEFAULT_DEVICE
  34. from fastspeech.utils.logging import tprint
  35. from fastspeech.utils.time import TimeElapsed
  36. from fastspeech.utils.pytorch import to_device_async, to_cpu_numpy
  37. from fastspeech.infer import get_inferencer
  38. from fastspeech.inferencer.waveglow_inferencer import WaveGlowInferencer
  39. MAX_FILESIZE=128
  40. # TODO test with different speeds
  41. def generate(hparam='infer.yaml',
  42. text='test_sentences.txt',
  43. results_path='results',
  44. device=DEFAULT_DEVICE,
  45. **kwargs):
  46. """The script for generating waveforms from texts with a vocoder.
  47. By default, this script assumes to load parameters in the default config file, fastspeech/hparams/infer.yaml.
  48. Besides the flags, you can also set parameters in the config file via the command-line. For examples,
  49. --checkpoint_path=CHECKPOINT_PATH
  50. Path to checkpoint directory. The latest checkpoint will be loaded.
  51. --waveglow_path=WAVEGLOW_PATH
  52. Path to the WaveGlow checkpoint file.
  53. --waveglow_engine_path=WAVEGLOW_ENGINE_PATH
  54. Path to the WaveGlow engine file. It can be only used with --use_trt=True.
  55. --batch_size=BATCH_SIZE
  56. Batch size to use. Defaults to 1.
  57. Refer to fastspeech/hparams/infer.yaml to see more parameters.
  58. Args:
  59. hparam (str, optional): Path to default config file. Defaults to "infer.yaml".
  60. text (str, optional): a sample text or a text file path to generate its waveform. Defaults to 'test_sentences.txt'.
  61. results_path (str, optional): Path to output waveforms directory. Defaults to 'results'.
  62. device (str, optional): Device to use. Defaults to "cuda" if avaiable, or "cpu".
  63. """
  64. hp.set_hparam(hparam, kwargs)
  65. if os.path.isfile(text):
  66. f = open(text, 'r', encoding="utf-8")
  67. texts = f.read().splitlines()
  68. else: # single string
  69. texts = [text]
  70. dataset = TextDataset(texts)
  71. data_loader = PadDataLoader(dataset,
  72. batch_size=hp.batch_size,
  73. num_workers=hp.n_workers,
  74. shuffle=False,
  75. drop_last=False)
  76. # text to mel
  77. model = Fastspeech(
  78. max_seq_len=hp.max_seq_len,
  79. d_model=hp.d_model,
  80. phoneme_side_n_layer=hp.phoneme_side_n_layer,
  81. phoneme_side_head=hp.phoneme_side_head,
  82. phoneme_side_conv1d_filter_size=hp.phoneme_side_conv1d_filter_size,
  83. phoneme_side_output_size=hp.phoneme_side_output_size,
  84. mel_side_n_layer=hp.mel_side_n_layer,
  85. mel_side_head=hp.mel_side_head,
  86. mel_side_conv1d_filter_size=hp.mel_side_conv1d_filter_size,
  87. mel_side_output_size=hp.mel_side_output_size,
  88. duration_predictor_filter_size=hp.duration_predictor_filter_size,
  89. duration_predictor_kernel_size=hp.duration_predictor_kernel_size,
  90. fft_conv1d_kernel=hp.fft_conv1d_kernel,
  91. fft_conv1d_padding=hp.fft_conv1d_padding,
  92. dropout=hp.dropout,
  93. n_mels=hp.num_mels,
  94. fused_layernorm=hp.fused_layernorm
  95. )
  96. fs_inferencer = get_inferencer(model, data_loader, device)
  97. # set up WaveGlow
  98. if hp.use_trt:
  99. from fastspeech.trt.waveglow_trt_inferencer import WaveGlowTRTInferencer
  100. wb_inferencer = WaveGlowTRTInferencer(
  101. ckpt_file=hp.waveglow_path, engine_file=hp.waveglow_engine_path, use_fp16=hp.use_fp16)
  102. else:
  103. wb_inferencer = WaveGlowInferencer(
  104. ckpt_file=hp.waveglow_path, device=device, use_fp16=hp.use_fp16)
  105. tprint("Generating {} sentences.. ".format(len(dataset)))
  106. with fs_inferencer, wb_inferencer:
  107. try:
  108. for i in range(len(data_loader)):
  109. tprint("------------- BATCH # {} -------------".format(i))
  110. with TimeElapsed(name="Inferece Time: E2E", format=":.6f"):
  111. ## Text-to-Mel ##
  112. with TimeElapsed(name="Inferece Time: FastSpeech", device=device, cuda_sync=True, format=":.6f"), torch.no_grad():
  113. outputs = fs_inferencer.infer()
  114. texts = outputs["text"]
  115. mels = outputs["mel"] # (b, n_mels, t)
  116. mel_masks = outputs['mel_mask'] # (b, t)
  117. # assert(mels.is_cuda)
  118. # remove paddings
  119. mel_lens = mel_masks.sum(axis=1)
  120. max_len = mel_lens.max()
  121. mels = mels[..., :max_len]
  122. mel_masks = mel_masks[..., :max_len]
  123. ## Vocoder ##
  124. with TimeElapsed(name="Inferece Time: WaveGlow", device=device, cuda_sync=True, format=":.6f"), torch.no_grad():
  125. wavs = wb_inferencer.infer(mels)
  126. wavs = to_cpu_numpy(wavs)
  127. ## Write wavs ##
  128. pathlib.Path(results_path).mkdir(parents=True, exist_ok=True)
  129. for i, (text, wav) in enumerate(zip(texts, wavs)):
  130. tprint("TEXT #{}: \"{}\"".format(i, text))
  131. # remove paddings in case of batch size > 1
  132. wav_len = mel_lens[i] * hp.hop_len
  133. wav = wav[:wav_len]
  134. path = os.path.join(results_path, text[:MAX_FILESIZE] + ".wav")
  135. librosa.output.write_wav(path, wav, hp.sr)
  136. except StopIteration:
  137. tprint("Generation has been done.")
  138. except KeyboardInterrupt:
  139. tprint("Generation has been canceled.")
  140. if __name__ == '__main__':
  141. fire.Fire(generate)