# model.py
# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import torch
from torch import nn as nn
from torch.nn.utils.rnn import pad_sequence

from common.layers import ConvReLUNorm
from common.utils import mask_from_lens
from fastpitch.transformer import FFTransformer
  33. def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
  34. """If target=None, then predicted durations are applied"""
  35. reps = torch.round(durations.float() / pace).long()
  36. dec_lens = reps.sum(dim=1)
  37. enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
  38. for o, r in zip(enc_out, reps)],
  39. batch_first=True)
  40. if mel_max_len:
  41. enc_rep = enc_rep[:, :mel_max_len]
  42. dec_lens = torch.clamp_max(dec_lens, mel_max_len)
  43. return enc_rep, dec_lens
  44. class TemporalPredictor(nn.Module):
  45. """Predicts a single float per each temporal location"""
  46. def __init__(self, input_size, filter_size, kernel_size, dropout,
  47. n_layers=2):
  48. super(TemporalPredictor, self).__init__()
  49. self.layers = nn.Sequential(*[
  50. ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
  51. kernel_size=kernel_size, dropout=dropout)
  52. for i in range(n_layers)]
  53. )
  54. self.fc = nn.Linear(filter_size, 1, bias=True)
  55. def forward(self, enc_out, enc_out_mask):
  56. out = enc_out * enc_out_mask
  57. out = self.layers(out.transpose(1, 2)).transpose(1, 2)
  58. out = self.fc(out) * enc_out_mask
  59. return out.squeeze(-1)
class FastPitch(nn.Module):
    """FastPitch acoustic model: text symbols -> mel-spectrogram frames.

    Pipeline: input FFT (transformer) encoder -> duration & pitch
    predictors -> pitch embedding added to encoder outputs -> length
    regulation -> output FFT decoder -> linear projection to mels.
    """

    def __init__(self, n_mel_channels, max_seq_len, n_symbols,
                 symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
                 in_fft_d_head,
                 in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
                 in_fft_output_size,
                 p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
                 out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
                 out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
                 out_fft_output_size,
                 p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
                 dur_predictor_kernel_size, dur_predictor_filter_size,
                 p_dur_predictor_dropout, dur_predictor_n_layers,
                 pitch_predictor_kernel_size, pitch_predictor_filter_size,
                 p_pitch_predictor_dropout, pitch_predictor_n_layers):
        """Build encoder, decoder, duration/pitch predictors and projection.

        `max_seq_len` and `n_symbols` are accepted for config compatibility
        but unused here (symbol embedding lives inside the encoder FFT).
        """
        super(FastPitch, self).__init__()
        del max_seq_len  # unused
        del n_symbols

        # Input FFT: embeds the symbol ids (embed_input=True) and encodes them.
        self.encoder = FFTransformer(
            n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=in_fft_d_head,
            d_inner=in_fft_conv1d_filter_size,
            kernel_size=in_fft_conv1d_kernel_size,
            dropout=p_in_fft_dropout,
            dropatt=p_in_fft_dropatt,
            dropemb=p_in_fft_dropemb,
            d_embed=symbols_embedding_dim,
            embed_input=True)

        # Predicts log-durations per input symbol from encoder outputs.
        self.duration_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=dur_predictor_filter_size,
            kernel_size=dur_predictor_kernel_size,
            dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
        )

        # Output FFT: consumes length-regulated frames (embed_input=False).
        self.decoder = FFTransformer(
            n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=out_fft_d_head,
            d_inner=out_fft_conv1d_filter_size,
            kernel_size=out_fft_conv1d_kernel_size,
            dropout=p_out_fft_dropout,
            dropatt=p_out_fft_dropatt,
            dropemb=p_out_fft_dropemb,
            d_embed=symbols_embedding_dim,
            embed_input=False)

        # Predicts one pitch value per input symbol from encoder outputs.
        self.pitch_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=pitch_predictor_filter_size,
            kernel_size=pitch_predictor_kernel_size,
            dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
        )

        # Maps the scalar pitch track to an embedding added to encoder outputs.
        self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
                                   padding=1)

        # Store values precomputed for training data within the model
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)

    def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
                pace=1.0, max_duration=75):
        """Training forward pass.

        `inputs` is a batched tuple containing (among other fields) symbol
        ids, the target mel-spectrogram, target durations and target pitch
        — assumed order matches the data loader; TODO confirm against it.
        Returns (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred).
        """
        inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
        # mel_tgt is presumably (batch, n_mel_channels, frames), so dim 2
        # is the frame count used to cap the length-regulated sequence.
        mel_max_len = mel_tgt.size(2)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        # Durations are predicted in log space; clamp to [0, max_duration].
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Predict pitch
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        # Condition on ground-truth pitch during training when available,
        # otherwise fall back to the prediction (teacher forcing).
        if use_gt_pitch and pitch_tgt is not None:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        # Expand to frame rate using ground-truth or predicted durations.
        len_regulated, dec_lens = regulate_len(
            dur_tgt if use_gt_durations else dur_pred,
            enc_out, pace, mel_max_len)

        # Output FFT
        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred

    def infer(self, inputs, input_lens, pace=1.0, dur_tgt=None, pitch_tgt=None,
              pitch_transform=None, max_duration=75):
        """Inference pass: no targets required.

        `dur_tgt` / `pitch_tgt`, when given, override the predictors;
        `pitch_transform(pitch, mean, std)` lets the caller reshape the
        predicted pitch contour (e.g. shift/flatten) before embedding.
        Returns (mel_out, dec_lens, dur_pred, pitch_pred) with mel_out
        permuted to (batch, n_mel_channels, frames).
        """
        del input_lens  # unused

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Pitch over chars
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        if pitch_transform is not None:
            # A zero pitch_std buffer means stats were never loaded into
            # the model; fall back to hard-coded dataset defaults.
            if self.pitch_std[0] == 0.0:
                # XXX LJSpeech-1.1 defaults
                mean, std = 218.14, 67.24
            else:
                mean, std = self.pitch_mean[0], self.pitch_std[0]
            pitch_pred = pitch_transform(pitch_pred, mean, std)

        if pitch_tgt is None:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)

        enc_out = enc_out + pitch_emb

        len_regulated, dec_lens = regulate_len(
            dur_pred if dur_tgt is None else dur_tgt,
            enc_out, pace, mel_max_len=None)

        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        # mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
        mel_out = mel_out.permute(0, 2, 1)  # For inference.py
        return mel_out, dec_lens, dur_pred, pitch_pred