train.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  2. # Redistribution and use in source and binary forms, with or without
  3. # modification, are permitted provided that the following conditions are met:
  4. # * Redistributions of source code must retain the above copyright
  5. # notice, this list of conditions and the following disclaimer.
  6. # * Redistributions in binary form must reproduce the above copyright
  7. # notice, this list of conditions and the following disclaimer in the
  8. # documentation and/or other materials provided with the distribution.
  9. # * Neither the name of the NVIDIA CORPORATION nor the
  10. # names of its contributors may be used to endorse or promote products
  11. # derived from this software without specific prior written permission.
  12. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  14. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  15. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  16. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  17. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  18. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  19. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  20. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  21. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  22. import pprint
  23. import fire
  24. import torch
  25. from torch.optim.lr_scheduler import LambdaLR
  26. from fastspeech import DEFAULT_DEVICE
  27. from fastspeech import hparam as hp
  28. from fastspeech.data_load import PadDataLoader
  29. from fastspeech.dataset.ljspeech_dataset import LJSpeechDataset
  30. from fastspeech.model.fastspeech import Fastspeech
  31. from fastspeech.trainer.fastspeech_trainer import FastspeechTrainer
  32. from fastspeech.utils.logging import tprint
  33. try:
  34. import apex
  35. except ImportError:
  36. ImportError('Required to install apex.')
  37. # import multiprocessing
  38. # multiprocessing.set_start_method('spawn', True)
  39. pp = pprint.PrettyPrinter(indent=4, width=1000)
  40. def train(hparam="train.yaml",
  41. device=DEFAULT_DEVICE,
  42. **kwargs):
  43. """ The FastSpeech model training script.
  44. By default, this script assumes to load parameters in the default config file, fastspeech/hparams/train.yaml.
  45. Besides the flags, you can also set parameters in the config file via the command-line. For examples,
  46. --dataset_path=DATASET_PATH
  47. Path to dataset directory.
  48. --tacotron2_path=TACOTRON2_PATH
  49. Path to tacotron2 checkpoint file.
  50. --mels_path=MELS_PATH
  51. Path to preprocessed mels directory.
  52. --aligns_path=ALIGNS_PATH
  53. Path to preprocessed alignments directory.
  54. --log_path=LOG_PATH
  55. Path to log directory.
  56. --checkpoint_path=CHECKPOINT_PATH
  57. Path to checkpoint directory. The latest checkpoint will be loaded.
  58. --batch_size=BATCH_SIZE
  59. Batch size to use. Defaults to 16.
  60. Refer to fastspeech/hparams/train.yaml to see more parameters.
  61. Args:
  62. hparam (str, optional): Path to default config file. Defaults to "train.yaml".
  63. device (str, optional): Device to use. Defaults to "cuda" if avaiable, or "cpu".
  64. """
  65. hp.set_hparam(hparam, kwargs)
  66. tprint("Hparams:\n{}".format(pp.pformat(hp)))
  67. tprint("Device count: {}".format(torch.cuda.device_count()))
  68. # model
  69. model = Fastspeech(
  70. max_seq_len=hp.max_seq_len,
  71. d_model=hp.d_model,
  72. phoneme_side_n_layer=hp.phoneme_side_n_layer,
  73. phoneme_side_head=hp.phoneme_side_head,
  74. phoneme_side_conv1d_filter_size=hp.phoneme_side_conv1d_filter_size,
  75. phoneme_side_output_size=hp.phoneme_side_output_size,
  76. mel_side_n_layer=hp.mel_side_n_layer,
  77. mel_side_head=hp.mel_side_head,
  78. mel_side_conv1d_filter_size=hp.mel_side_conv1d_filter_size,
  79. mel_side_output_size=hp.mel_side_output_size,
  80. duration_predictor_filter_size=hp.duration_predictor_filter_size,
  81. duration_predictor_kernel_size=hp.duration_predictor_kernel_size,
  82. fft_conv1d_kernel=hp.fft_conv1d_kernel,
  83. fft_conv1d_padding=hp.fft_conv1d_padding,
  84. dropout=hp.dropout,
  85. n_mels=hp.num_mels,
  86. fused_layernorm=hp.fused_layernorm
  87. )
  88. # dataset
  89. dataset = LJSpeechDataset(root_path=hp.dataset_path,
  90. meta_file=hp.meta_file,
  91. mels_path=hp.mels_path,
  92. aligns_path=hp.aligns_path,
  93. sr=hp.sr,
  94. n_fft=hp.n_fft,
  95. win_len=hp.win_len,
  96. hop_len=hp.hop_len,
  97. n_mels=hp.num_mels,
  98. mel_fmin=hp.mel_fmin,
  99. mel_fmax=hp.mel_fmax,
  100. )
  101. tprint("Dataset size: {}".format(len(dataset)))
  102. # data loader
  103. data_loader = PadDataLoader(dataset,
  104. batch_size=hp.batch_size,
  105. num_workers=hp.n_workers,
  106. drop_last=True,
  107. )
  108. # optimizer
  109. def get_optimizer(model):
  110. optimizer = torch.optim.Adam(
  111. model.parameters(),
  112. lr=hp.learning_rate,
  113. betas=(0.9, 0.98),
  114. eps=1e-9)
  115. return optimizer
  116. def get_warmup_lr_scheduler(optimizer):
  117. d_model = hp.d_model
  118. warmup_steps = hp.warmup_steps
  119. lr = lambda step: d_model ** -0.5 * min((step + 1) ** -0.5,
  120. (step + 1) * warmup_steps ** -1.5) / hp.learning_rate
  121. scheduler = LambdaLR(optimizer, lr_lambda=[lr])
  122. return scheduler
  123. # trainer
  124. trainer = FastspeechTrainer(data_loader,
  125. 'fastspeech',
  126. model,
  127. optimizer_fn=get_optimizer,
  128. final_steps=hp.final_steps,
  129. log_steps=hp.log_step,
  130. ckpt_path=hp.checkpoint_path,
  131. save_steps=hp.save_step,
  132. log_path=hp.log_path,
  133. lr_scheduler_fn=get_warmup_lr_scheduler,
  134. pre_aligns=True if hp.aligns_path else False,
  135. device=device,
  136. use_amp=hp.use_amp,
  137. nvprof_iter_start=hp.nvprof_iter_start,
  138. nvprof_iter_end=hp.nvprof_iter_end,
  139. pyprof_enabled=hp.pyprof_enabled,
  140. )
  141. trainer.train()
if __name__ == '__main__':
    # Keep cuDNN enabled but turn off its autotuner; presumably because
    # sequence lengths vary per batch, so autotuning would re-benchmark
    # kernels constantly — TODO confirm against the trainer's input shapes.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    # Expose train()'s parameters (and **kwargs hparam overrides such as
    # --batch_size) as command-line flags via Python Fire.
    fire.Fire(train)