model_jit.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # *****************************************************************************
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. from typing import List, Optional
  28. import torch
  29. from torch import nn as nn
  30. from torch.nn.utils.rnn import pad_sequence
  31. from common.layers import ConvReLUNorm
  32. from fastpitch.transformer_jit import FFTransformer
  33. def regulate_len(durations, enc_out, pace: float = 1.0,
  34. mel_max_len: Optional[int] = None):
  35. """If target=None, then predicted durations are applied"""
  36. reps = torch.round(durations.float() / pace).long()
  37. dec_lens = reps.sum(dim=1)
  38. max_len = dec_lens.max()
  39. bsz, _, hid = enc_out.size()
  40. reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
  41. pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
  42. device=enc_out.device)
  43. enc_rep = torch.cat([enc_out, pad_vec], dim=1)
  44. enc_rep = torch.repeat_interleave(
  45. enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
  46. ).view(bsz, -1, hid)
  47. # enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
  48. # for o, r in zip(enc_out, reps)],
  49. # batch_first=True)
  50. if mel_max_len is not None:
  51. enc_rep = enc_rep[:, :mel_max_len]
  52. dec_lens = torch.clamp_max(dec_lens, mel_max_len)
  53. return enc_rep, dec_lens
  54. class TemporalPredictor(nn.Module):
  55. """Predicts a single float per each temporal location"""
  56. def __init__(self, input_size, filter_size, kernel_size, dropout,
  57. n_layers=2):
  58. super(TemporalPredictor, self).__init__()
  59. self.layers = nn.Sequential(*[
  60. ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
  61. kernel_size=kernel_size, dropout=dropout)
  62. for i in range(n_layers)]
  63. )
  64. self.fc = nn.Linear(filter_size, 1, bias=True)
  65. def forward(self, enc_out, enc_out_mask):
  66. out = enc_out * enc_out_mask
  67. out = self.layers(out.transpose(1, 2)).transpose(1, 2)
  68. out = self.fc(out) * enc_out_mask
  69. return out.squeeze(-1)
  70. class FastPitch(nn.Module):
  71. def __init__(self, n_mel_channels, max_seq_len, n_symbols,
  72. symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
  73. in_fft_d_head,
  74. in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
  75. in_fft_output_size,
  76. p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
  77. out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
  78. out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
  79. out_fft_output_size,
  80. p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
  81. dur_predictor_kernel_size, dur_predictor_filter_size,
  82. p_dur_predictor_dropout, dur_predictor_n_layers,
  83. pitch_predictor_kernel_size, pitch_predictor_filter_size,
  84. p_pitch_predictor_dropout, pitch_predictor_n_layers):
  85. super(FastPitch, self).__init__()
  86. del max_seq_len # unused
  87. del n_symbols
  88. self.encoder = FFTransformer(
  89. n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
  90. d_model=symbols_embedding_dim,
  91. d_head=in_fft_d_head,
  92. d_inner=in_fft_conv1d_filter_size,
  93. kernel_size=in_fft_conv1d_kernel_size,
  94. dropout=p_in_fft_dropout,
  95. dropatt=p_in_fft_dropatt,
  96. dropemb=p_in_fft_dropemb,
  97. d_embed=symbols_embedding_dim,
  98. embed_input=True)
  99. self.duration_predictor = TemporalPredictor(
  100. in_fft_output_size,
  101. filter_size=dur_predictor_filter_size,
  102. kernel_size=dur_predictor_kernel_size,
  103. dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
  104. )
  105. self.decoder = FFTransformer(
  106. n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
  107. d_model=symbols_embedding_dim,
  108. d_head=out_fft_d_head,
  109. d_inner=out_fft_conv1d_filter_size,
  110. kernel_size=out_fft_conv1d_kernel_size,
  111. dropout=p_out_fft_dropout,
  112. dropatt=p_out_fft_dropatt,
  113. dropemb=p_out_fft_dropemb,
  114. d_embed=symbols_embedding_dim,
  115. embed_input=False)
  116. self.pitch_predictor = TemporalPredictor(
  117. in_fft_output_size,
  118. filter_size=pitch_predictor_filter_size,
  119. kernel_size=pitch_predictor_kernel_size,
  120. dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
  121. )
  122. self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
  123. padding=1)
  124. # Store values precomputed for training data within the model
  125. self.register_buffer('pitch_mean', torch.zeros(1))
  126. self.register_buffer('pitch_std', torch.zeros(1))
  127. self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
  128. def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
  129. use_gt_pitch: bool = True, pace: float = 1.0,
  130. max_duration: int = 75):
  131. inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
  132. mel_max_len = mel_tgt.size(2)
  133. # Input FFT
  134. enc_out, enc_mask = self.encoder(inputs)
  135. # Embedded for predictors
  136. pred_enc_out, pred_enc_mask = enc_out, enc_mask
  137. # Predict durations
  138. log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
  139. dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
  140. # Predict pitch
  141. pitch_pred = self.pitch_predictor(enc_out, enc_mask)
  142. if use_gt_pitch and pitch_tgt is not None:
  143. pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
  144. else:
  145. pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
  146. enc_out = enc_out + pitch_emb.transpose(1, 2)
  147. len_regulated, dec_lens = regulate_len(
  148. dur_tgt if use_gt_durations else dur_pred,
  149. enc_out, pace, mel_max_len)
  150. # Output FFT
  151. dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
  152. mel_out = self.proj(dec_out)
  153. return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
  154. def infer(self, inputs, input_lens, pace: float = 1.0,
  155. dur_tgt: Optional[torch.Tensor] = None,
  156. pitch_tgt: Optional[torch.Tensor] = None,
  157. max_duration: float = 75):
  158. # Input FFT
  159. enc_out, enc_mask = self.encoder(inputs)
  160. # Embedded for predictors
  161. pred_enc_out, pred_enc_mask = enc_out, enc_mask
  162. # Predict durations
  163. log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
  164. dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
  165. # Pitch over chars
  166. pitch_pred = self.pitch_predictor(enc_out, enc_mask)
  167. if pitch_tgt is None:
  168. pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
  169. else:
  170. pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
  171. enc_out = enc_out + pitch_emb
  172. len_regulated, dec_lens = regulate_len(
  173. dur_pred if dur_tgt is None else dur_tgt,
  174. enc_out, pace, mel_max_len=None)
  175. dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
  176. mel_out = self.proj(dec_out)
  177. # mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
  178. mel_out = mel_out.permute(0, 2, 1) # For inference.py
  179. return mel_out, dec_lens, dur_pred, pitch_pred