| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- # *****************************************************************************
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- # * Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # * Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- # * Neither the name of the NVIDIA CORPORATION nor the
- # names of its contributors may be used to endorse or promote products
- # derived from this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- #
- # *****************************************************************************
- import torch
- from torch import nn as nn
- from torch.nn.utils.rnn import pad_sequence
- from common.layers import ConvReLUNorm
- from common.utils import mask_from_lens
- from fastpitch.transformer import FFTransformer
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """Expand encoder frames in time according to per-symbol durations.

    Each frame of ``enc_out`` is repeated ``round(duration / pace)`` times,
    and the expanded per-sample sequences are right-padded (with zeros) to a
    common length. If ``mel_max_len`` is given, the expanded output and the
    reported lengths are truncated/clamped to that many frames.

    Returns a tuple ``(enc_rep, dec_lens)``: the expanded, padded tensor and
    the per-sample expanded lengths.
    """
    repeats = torch.round(durations.float() / pace).long()
    dec_lens = repeats.sum(dim=1)

    # Expand every sample independently, then pad to the batch maximum.
    expanded = [torch.repeat_interleave(frames, counts, dim=0)
                for frames, counts in zip(enc_out, repeats)]
    enc_rep = pad_sequence(expanded, batch_first=True)

    if mel_max_len:
        enc_rep = enc_rep[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return enc_rep, dec_lens
class TemporalPredictor(nn.Module):
    """Predicts a single float per each temporal location."""

    def __init__(self, input_size, filter_size, kernel_size, dropout,
                 n_layers=2):
        super(TemporalPredictor, self).__init__()

        # Stack of conv+ReLU+norm blocks; the first maps input_size to
        # filter_size, all subsequent ones keep filter_size channels.
        conv_stack = []
        for i in range(n_layers):
            in_channels = input_size if i == 0 else filter_size
            conv_stack.append(
                ConvReLUNorm(in_channels, filter_size,
                             kernel_size=kernel_size, dropout=dropout))
        self.layers = nn.Sequential(*conv_stack)
        self.fc = nn.Linear(filter_size, 1, bias=True)

    def forward(self, enc_out, enc_out_mask):
        # Zero out padded positions before and after the conv stack so the
        # prediction at masked locations is exactly zero.
        masked = enc_out * enc_out_mask
        # Convs expect (batch, channels, time); transpose in and back out.
        hidden = self.layers(masked.transpose(1, 2)).transpose(1, 2)
        scores = self.fc(hidden) * enc_out_mask
        return scores.squeeze(-1)
class FastPitch(nn.Module):
    """FastPitch acoustic model: symbol sequences -> mel-spectrogram frames.

    Pipeline: an input FFT (feed-forward Transformer) encodes the symbols;
    per-symbol durations and pitch are predicted from the encoder output;
    the pitch is embedded via a 1-D conv and added to the encoder output;
    the sequence is expanded in time by the durations (regulate_len); an
    output FFT decodes it and a linear layer projects to mel channels.
    """

    def __init__(self, n_mel_channels, max_seq_len, n_symbols,
                 symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
                 in_fft_d_head,
                 in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
                 in_fft_output_size,
                 p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
                 out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
                 out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
                 out_fft_output_size,
                 p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
                 dur_predictor_kernel_size, dur_predictor_filter_size,
                 p_dur_predictor_dropout, dur_predictor_n_layers,
                 pitch_predictor_kernel_size, pitch_predictor_filter_size,
                 p_pitch_predictor_dropout, pitch_predictor_n_layers):
        super(FastPitch, self).__init__()
        del max_seq_len  # unused
        del n_symbols

        # Input FFT: embeds the symbols itself (embed_input=True) and
        # encodes them into symbols_embedding_dim-wide frames.
        self.encoder = FFTransformer(
            n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=in_fft_d_head,
            d_inner=in_fft_conv1d_filter_size,
            kernel_size=in_fft_conv1d_kernel_size,
            dropout=p_in_fft_dropout,
            dropatt=p_in_fft_dropatt,
            dropemb=p_in_fft_dropemb,
            d_embed=symbols_embedding_dim,
            embed_input=True)

        # Predicts log(duration + 1) per encoder position.
        self.duration_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=dur_predictor_filter_size,
            kernel_size=dur_predictor_kernel_size,
            dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
        )

        # Output FFT: decodes the duration-expanded sequence; its input is
        # already embedded, hence embed_input=False.
        self.decoder = FFTransformer(
            n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=out_fft_d_head,
            d_inner=out_fft_conv1d_filter_size,
            kernel_size=out_fft_conv1d_kernel_size,
            dropout=p_out_fft_dropout,
            dropatt=p_out_fft_dropatt,
            dropemb=p_out_fft_dropemb,
            d_embed=symbols_embedding_dim,
            embed_input=False)

        # Predicts one pitch value per encoder position.
        self.pitch_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=pitch_predictor_filter_size,
            kernel_size=pitch_predictor_kernel_size,
            dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
        )

        # Maps the scalar pitch track to symbols_embedding_dim channels so it
        # can be added to the encoder output.
        self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
                                   padding=1)

        # Store values precomputed for training data within the model
        # (persisted buffers; zero means "not set" — see infer()).
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        # Final projection from decoder width to mel channels.
        self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)

    def forward(self, inputs, use_gt_durations=True, use_gt_pitch=True,
                pace=1.0, max_duration=75):
        """Training forward pass.

        ``inputs`` is a 7-tuple; only the symbol ids, the target mel
        (used here only for its max length — presumably (batch, n_mel, T);
        TODO confirm against the data loader), the target durations and the
        target pitch are used.

        Returns (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred).
        """
        inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
        mel_max_len = mel_tgt.size(2)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations: predictor outputs log(dur + 1); invert and clamp
        # to [0, max_duration] frames.
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Predict pitch; during training the ground-truth pitch is normally
        # used for conditioning (teacher forcing) while the prediction is
        # returned for the loss.
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        if use_gt_pitch and pitch_tgt is not None:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
        else:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))

        # Conv output is (batch, channels, time); add as (batch, time, channels).
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        # Expand in time by (ground-truth or predicted) durations, truncated
        # to the target mel length.
        len_regulated, dec_lens = regulate_len(
            dur_tgt if use_gt_durations else dur_pred,
            enc_out, pace, mel_max_len)

        # Output FFT
        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred

    def infer(self, inputs, input_lens, pace=1.0, dur_tgt=None, pitch_tgt=None,
              pitch_transform=None, max_duration=75):
        """Inference pass: predicted durations/pitch unless targets are given.

        ``pitch_transform``, if provided, is called with the predicted pitch
        and the training-set mean/std so it can modify the pitch contour in
        de-normalized space.

        Returns (mel_out, dec_lens, dur_pred, pitch_pred) with ``mel_out``
        permuted to (batch, n_mel, time) for inference.py.
        """
        del input_lens  # unused

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs)

        # Embedded for predictors
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Predict durations (same log(dur + 1) inversion as in forward()).
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Pitch over chars
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        if pitch_transform is not None:
            if self.pitch_std[0] == 0.0:
                # XXX LJSpeech-1.1 defaults — used when the buffers were never
                # filled in from training statistics.
                mean, std = 218.14, 67.24
            else:
                mean, std = self.pitch_mean[0], self.pitch_std[0]
            pitch_pred = pitch_transform(pitch_pred, mean, std)

        if pitch_tgt is None:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)

        enc_out = enc_out + pitch_emb

        # Expand in time; no mel-length cap at inference.
        len_regulated, dec_lens = regulate_len(
            dur_pred if dur_tgt is None else dur_tgt,
            enc_out, pace, mel_max_len=None)

        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        # mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
        mel_out = mel_out.permute(0, 2, 1)  # For inference.py
        return mel_out, dec_lens, dur_pred, pitch_pred
|