# fastspeech.py
  1. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  2. # Redistribution and use in source and binary forms, with or without
  3. # modification, are permitted provided that the following conditions are met:
  4. # * Redistributions of source code must retain the above copyright
  5. # notice, this list of conditions and the following disclaimer.
  6. # * Redistributions in binary form must reproduce the above copyright
  7. # notice, this list of conditions and the following disclaimer in the
  8. # documentation and/or other materials provided with the distribution.
  9. # * Neither the name of the NVIDIA CORPORATION nor the
  10. # names of its contributors may be used to endorse or promote products
  11. # derived from this software without specific prior written permission.
  12. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  14. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  15. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  16. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  17. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  18. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  19. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  20. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  21. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  22. from collections import OrderedDict
  23. import numpy as np
  24. import torch
  25. from torch import nn as nn
  26. from fastspeech.model.module import FFTBlocks, LengthRegulator
  27. from fastspeech.utils.pytorch import to_device_async
  28. from fastspeech.utils.nvtx import Nvtx
  29. from torch.nn import functional as F
  30. from fastspeech.utils.logging import tprint
  31. from fastspeech.text_norm.symbols import symbols
  32. class Fastspeech(nn.Module):
  33. """ FastSpeech """
  34. def __init__(self,
  35. max_seq_len,
  36. d_model,
  37. phoneme_side_n_layer,
  38. phoneme_side_head,
  39. phoneme_side_conv1d_filter_size,
  40. phoneme_side_output_size,
  41. mel_side_n_layer,
  42. mel_side_head,
  43. mel_side_conv1d_filter_size,
  44. mel_side_output_size,
  45. fft_conv1d_kernel,
  46. fft_conv1d_padding,
  47. duration_predictor_filter_size,
  48. duration_predictor_kernel_size,
  49. dropout,
  50. n_mels,
  51. fused_layernorm=False):
  52. super(Fastspeech, self).__init__()
  53. self.max_seq_len = max_seq_len
  54. self.d_model = d_model
  55. self.phoneme_side_n_layer = phoneme_side_n_layer
  56. self.phoneme_side_head = phoneme_side_head
  57. self.phoneme_side_conv1d_filter_size = phoneme_side_conv1d_filter_size
  58. self.phoneme_side_output_size = phoneme_side_output_size
  59. self.mel_side_n_layer = mel_side_n_layer
  60. self.mel_side_head = mel_side_head
  61. self.mel_side_conv1d_filter_size = mel_side_conv1d_filter_size
  62. self.mel_side_output_size = mel_side_output_size
  63. self.fft_conv1d_kernel = fft_conv1d_kernel
  64. self.fft_conv1d_padding = fft_conv1d_padding
  65. self.duration_predictor_filter_size = duration_predictor_filter_size
  66. self.duration_predictor_kernel_size = duration_predictor_kernel_size
  67. self.dropout = dropout
  68. self.n_mels = n_mels
  69. self.fused_layernorm = fused_layernorm
  70. self.n_phns = len(symbols)+1
  71. self.word_emb = nn.Embedding(
  72. self.n_phns,
  73. d_model,
  74. padding_idx=0)
  75. self.phoneme_side = FFTBlocks(
  76. max_seq_len=max_seq_len,
  77. n_layers=phoneme_side_n_layer,
  78. n_head=phoneme_side_head,
  79. d_k=64,
  80. d_v=64,
  81. d_model=d_model,
  82. d_inner=phoneme_side_conv1d_filter_size,
  83. fft_conv1d_kernel=fft_conv1d_kernel,
  84. fft_conv1d_padding=fft_conv1d_padding,
  85. dropout=dropout,
  86. name="phoneme_side",
  87. fused_layernorm=fused_layernorm
  88. )
  89. self.length_regulator = LengthRegulator(
  90. input_size=phoneme_side_output_size,
  91. duration_predictor_filter_size=duration_predictor_filter_size,
  92. duration_predictor_kernel_size=duration_predictor_kernel_size,
  93. dropout=dropout,
  94. fused_layernorm=fused_layernorm
  95. )
  96. self.mel_side = FFTBlocks(
  97. max_seq_len=max_seq_len,
  98. n_layers=mel_side_n_layer,
  99. n_head=mel_side_head,
  100. d_k=64,
  101. d_v=64,
  102. d_model=d_model,
  103. d_inner=mel_side_conv1d_filter_size,
  104. fft_conv1d_kernel=fft_conv1d_kernel,
  105. fft_conv1d_padding=fft_conv1d_padding,
  106. dropout=dropout,
  107. name="mel_side",
  108. fused_layernorm=fused_layernorm
  109. )
  110. self.mel_linear = nn.Linear(mel_side_output_size, n_mels, bias=True)
  111. def forward(self, seq, pos, duration_target=None, alpha=1.0, seq_output_len=None, use_fp16=False, acts=None):
  112. # Phoneme Embedding
  113. output = self.word_emb(seq)
  114. if acts is not None:
  115. acts["act.emb"] = output
  116. if use_fp16:
  117. output = output.half()
  118. # Phoneme Side FFT Blocks
  119. output, output_mask = self.phoneme_side(output, pos, acts=acts)
  120. if acts is not None:
  121. acts["act.phoneme_side.seq"] = output
  122. # Length Regulator
  123. output, pos, duration = self.length_regulator(
  124. output,
  125. output_mask,
  126. target=duration_target,
  127. alpha=alpha)
  128. if seq_output_len:
  129. output = F.pad(output, pad=(0, 0, 0, seq_output_len - output.size(1)))
  130. pos = F.pad(pos, pad=(0, seq_output_len - pos.size(1)))
  131. # length of output mel shouldn't exceed max_seq_len
  132. output = output[:, :self.max_seq_len]
  133. pos = pos[:, :self.max_seq_len]
  134. if acts is not None:
  135. acts["act.length_regulator.seq"] = output
  136. acts["act.length_regulator.dur"] = torch.round(duration)
  137. if self.training or output.bool().any():
  138. # Mel Side FFT Blocks
  139. output, output_mask = self.mel_side(output, pos, acts=acts)
  140. if acts is not None:
  141. acts["act.mel_side.seq"] = output
  142. # Linear Layer
  143. output = self.mel_linear(output)
  144. if acts is not None:
  145. acts["out.seq_mask"] = output_mask
  146. acts["out.seq"] = output
  147. else:
  148. # seq length could be zero, in case duration predictor outputs all zeros.
  149. # In this case, skip feed-forwarding.
  150. tprint("Duration Predictor outputs all zeros. Output will be zero length.")
  151. output_shape = (output.size(0), 0, output_mask.size(2))
  152. output = torch.zeros(size=(output_shape))
  153. output_mask = torch.ones(size=(output_shape))
  154. if torch.cuda.device_count() > 1:
  155. # In a multi-gpu setting, all output mels from devices must have the same length.
  156. # otherwise, an error occurs in process of gathering output.
  157. if not seq_output_len:
  158. seq_output_len = self.max_seq_len
  159. padding = (0, 0, 0, seq_output_len - output.size(1))
  160. output = F.pad(output, padding)
  161. output = output[:, :seq_output_len, :]
  162. output_mask = F.pad(output_mask, padding)
  163. output_mask = output_mask[:, :seq_output_len, :]
  164. return output, output_mask, duration