|
|
@@ -1,218 +1,246 @@
|
|
|
-# *****************************************************************************
|
|
|
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
|
-#
|
|
|
-# Redistribution and use in source and binary forms, with or without
|
|
|
-# modification, are permitted provided that the following conditions are met:
|
|
|
-# * Redistributions of source code must retain the above copyright
|
|
|
-# notice, this list of conditions and the following disclaimer.
|
|
|
-# * Redistributions in binary form must reproduce the above copyright
|
|
|
-# notice, this list of conditions and the following disclaimer in the
|
|
|
-# documentation and/or other materials provided with the distribution.
|
|
|
-# * Neither the name of the NVIDIA CORPORATION nor the
|
|
|
-# names of its contributors may be used to endorse or promote products
|
|
|
-# derived from this software without specific prior written permission.
|
|
|
-#
|
|
|
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
|
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
|
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
|
-# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
|
|
-# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
|
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
|
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
|
|
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
|
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
|
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
-#
|
|
|
-# *****************************************************************************
|
|
|
-
|
|
|
-from typing import List, Optional
|
|
|
-
|
|
|
-import torch
|
|
|
-from torch import nn as nn
|
|
|
-from torch.nn.utils.rnn import pad_sequence
|
|
|
-
|
|
|
-from common.layers import ConvReLUNorm
|
|
|
-from fastpitch.transformer_jit import FFTransformer
|
|
|
-
|
|
|
-
|
|
|
-def regulate_len(durations, enc_out, pace: float = 1.0,
|
|
|
- mel_max_len: Optional[int] = None):
|
|
|
- """If target=None, then predicted durations are applied"""
|
|
|
- reps = torch.round(durations.float() / pace).long()
|
|
|
- dec_lens = reps.sum(dim=1)
|
|
|
-
|
|
|
- max_len = dec_lens.max()
|
|
|
- bsz, _, hid = enc_out.size()
|
|
|
-
|
|
|
- reps_padded = torch.cat([reps, (max_len - dec_lens)[:, None]], dim=1)
|
|
|
- pad_vec = torch.zeros(bsz, 1, hid, dtype=enc_out.dtype,
|
|
|
- device=enc_out.device)
|
|
|
-
|
|
|
- enc_rep = torch.cat([enc_out, pad_vec], dim=1)
|
|
|
- enc_rep = torch.repeat_interleave(
|
|
|
- enc_rep.view(-1, hid), reps_padded.view(-1), dim=0
|
|
|
- ).view(bsz, -1, hid)
|
|
|
-
|
|
|
- # enc_rep = pad_sequence([torch.repeat_interleave(o, r, dim=0)
|
|
|
- # for o, r in zip(enc_out, reps)],
|
|
|
- # batch_first=True)
|
|
|
- if mel_max_len is not None:
|
|
|
- enc_rep = enc_rep[:, :mel_max_len]
|
|
|
- dec_lens = torch.clamp_max(dec_lens, mel_max_len)
|
|
|
- return enc_rep, dec_lens
|
|
|
-
|
|
|
-
|
|
|
-class TemporalPredictor(nn.Module):
|
|
|
- """Predicts a single float per each temporal location"""
|
|
|
-
|
|
|
- def __init__(self, input_size, filter_size, kernel_size, dropout,
|
|
|
- n_layers=2):
|
|
|
- super(TemporalPredictor, self).__init__()
|
|
|
-
|
|
|
- self.layers = nn.Sequential(*[
|
|
|
- ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
|
|
|
- kernel_size=kernel_size, dropout=dropout)
|
|
|
- for i in range(n_layers)]
|
|
|
- )
|
|
|
- self.fc = nn.Linear(filter_size, 1, bias=True)
|
|
|
-
|
|
|
- def forward(self, enc_out, enc_out_mask):
|
|
|
- out = enc_out * enc_out_mask
|
|
|
- out = self.layers(out.transpose(1, 2)).transpose(1, 2)
|
|
|
- out = self.fc(out) * enc_out_mask
|
|
|
- return out.squeeze(-1)
|
|
|
-
|
|
|
-
|
|
|
-class FastPitch(nn.Module):
|
|
|
- def __init__(self, n_mel_channels, max_seq_len, n_symbols,
|
|
|
- symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
|
|
|
- in_fft_d_head,
|
|
|
- in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
|
|
|
- in_fft_output_size,
|
|
|
- p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
|
|
|
- out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
|
|
|
- out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
|
|
|
- out_fft_output_size,
|
|
|
- p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
|
|
|
- dur_predictor_kernel_size, dur_predictor_filter_size,
|
|
|
- p_dur_predictor_dropout, dur_predictor_n_layers,
|
|
|
- pitch_predictor_kernel_size, pitch_predictor_filter_size,
|
|
|
- p_pitch_predictor_dropout, pitch_predictor_n_layers):
|
|
|
- super(FastPitch, self).__init__()
|
|
|
- del max_seq_len # unused
|
|
|
- del n_symbols
|
|
|
-
|
|
|
- self.encoder = FFTransformer(
|
|
|
- n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
|
|
|
- d_model=symbols_embedding_dim,
|
|
|
- d_head=in_fft_d_head,
|
|
|
- d_inner=in_fft_conv1d_filter_size,
|
|
|
- kernel_size=in_fft_conv1d_kernel_size,
|
|
|
- dropout=p_in_fft_dropout,
|
|
|
- dropatt=p_in_fft_dropatt,
|
|
|
- dropemb=p_in_fft_dropemb,
|
|
|
- d_embed=symbols_embedding_dim,
|
|
|
- embed_input=True)
|
|
|
-
|
|
|
- self.duration_predictor = TemporalPredictor(
|
|
|
- in_fft_output_size,
|
|
|
- filter_size=dur_predictor_filter_size,
|
|
|
- kernel_size=dur_predictor_kernel_size,
|
|
|
- dropout=p_dur_predictor_dropout, n_layers=dur_predictor_n_layers
|
|
|
- )
|
|
|
-
|
|
|
- self.decoder = FFTransformer(
|
|
|
- n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
|
|
|
- d_model=symbols_embedding_dim,
|
|
|
- d_head=out_fft_d_head,
|
|
|
- d_inner=out_fft_conv1d_filter_size,
|
|
|
- kernel_size=out_fft_conv1d_kernel_size,
|
|
|
- dropout=p_out_fft_dropout,
|
|
|
- dropatt=p_out_fft_dropatt,
|
|
|
- dropemb=p_out_fft_dropemb,
|
|
|
- d_embed=symbols_embedding_dim,
|
|
|
- embed_input=False)
|
|
|
-
|
|
|
- self.pitch_predictor = TemporalPredictor(
|
|
|
- in_fft_output_size,
|
|
|
- filter_size=pitch_predictor_filter_size,
|
|
|
- kernel_size=pitch_predictor_kernel_size,
|
|
|
- dropout=p_pitch_predictor_dropout, n_layers=pitch_predictor_n_layers
|
|
|
- )
|
|
|
- self.pitch_emb = nn.Conv1d(1, symbols_embedding_dim, kernel_size=3,
|
|
|
- padding=1)
|
|
|
-
|
|
|
- # Store values precomputed for training data within the model
|
|
|
- self.register_buffer('pitch_mean', torch.zeros(1))
|
|
|
- self.register_buffer('pitch_std', torch.zeros(1))
|
|
|
-
|
|
|
- self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)
|
|
|
-
|
|
|
- def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
|
|
|
- use_gt_pitch: bool = True, pace: float = 1.0,
|
|
|
- max_duration: int = 75):
|
|
|
- inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt = inputs
|
|
|
- mel_max_len = mel_tgt.size(2)
|
|
|
-
|
|
|
- # Input FFT
|
|
|
- enc_out, enc_mask = self.encoder(inputs)
|
|
|
-
|
|
|
- # Embedded for predictors
|
|
|
- pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
|
|
-
|
|
|
- # Predict durations
|
|
|
- log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
|
|
- dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
|
|
-
|
|
|
- # Predict pitch
|
|
|
- pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
|
|
-
|
|
|
- if use_gt_pitch and pitch_tgt is not None:
|
|
|
- pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1))
|
|
|
- else:
|
|
|
- pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1))
|
|
|
- enc_out = enc_out + pitch_emb.transpose(1, 2)
|
|
|
-
|
|
|
- len_regulated, dec_lens = regulate_len(
|
|
|
- dur_tgt if use_gt_durations else dur_pred,
|
|
|
- enc_out, pace, mel_max_len)
|
|
|
-
|
|
|
- # Output FFT
|
|
|
- dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
|
|
- mel_out = self.proj(dec_out)
|
|
|
- return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred
|
|
|
-
|
|
|
- def infer(self, inputs, input_lens, pace: float = 1.0,
|
|
|
- dur_tgt: Optional[torch.Tensor] = None,
|
|
|
- pitch_tgt: Optional[torch.Tensor] = None,
|
|
|
- max_duration: float = 75):
|
|
|
-
|
|
|
- # Input FFT
|
|
|
- enc_out, enc_mask = self.encoder(inputs)
|
|
|
-
|
|
|
- # Embedded for predictors
|
|
|
- pred_enc_out, pred_enc_mask = enc_out, enc_mask
|
|
|
-
|
|
|
- # Predict durations
|
|
|
- log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
|
|
|
- dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
|
|
|
-
|
|
|
- # Pitch over chars
|
|
|
- pitch_pred = self.pitch_predictor(enc_out, enc_mask)
|
|
|
-
|
|
|
- if pitch_tgt is None:
|
|
|
- pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
|
|
|
- else:
|
|
|
- pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
|
|
|
-
|
|
|
- enc_out = enc_out + pitch_emb
|
|
|
-
|
|
|
- len_regulated, dec_lens = regulate_len(
|
|
|
- dur_pred if dur_tgt is None else dur_tgt,
|
|
|
- enc_out, pace, mel_max_len=None)
|
|
|
-
|
|
|
- dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
|
|
|
- mel_out = self.proj(dec_out)
|
|
|
- # mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
|
|
|
- mel_out = mel_out.permute(0, 2, 1) # For inference.py
|
|
|
- return mel_out, dec_lens, dur_pred, pitch_pred
|
|
|
+# *****************************************************************************
|
|
|
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
|
|
+#
|
|
|
+# Redistribution and use in source and binary forms, with or without
|
|
|
+# modification, are permitted provided that the following conditions are met:
|
|
|
+# * Redistributions of source code must retain the above copyright
|
|
|
+# notice, this list of conditions and the following disclaimer.
|
|
|
+# * Redistributions in binary form must reproduce the above copyright
|
|
|
+# notice, this list of conditions and the following disclaimer in the
|
|
|
+# documentation and/or other materials provided with the distribution.
|
|
|
+# * Neither the name of the NVIDIA CORPORATION nor the
|
|
|
+# names of its contributors may be used to endorse or promote products
|
|
|
+# derived from this software without specific prior written permission.
|
|
|
+#
|
|
|
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
|
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
|
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
|
+# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
|
|
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
|
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
|
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
|
|
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
|
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
|
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
+#
|
|
|
+# *****************************************************************************
|
|
|
+
|
|
|
+from typing import List, Optional
|
|
|
+
|
|
|
+import torch
|
|
|
+from torch import nn as nn
|
|
|
+from torch.nn.utils.rnn import pad_sequence
|
|
|
+
|
|
|
+from common.layers import ConvReLUNorm
|
|
|
+from fastpitch.transformer_jit import FFTransformer
|
|
|
+
|
|
|
+
|
|
|
def regulate_len(durations, enc_out, pace: float = 1.0,
                 mel_max_len: Optional[int] = None):
    """Expand encoder frames in time according to per-symbol durations.

    Each encoder timestep is repeated ``round(duration / pace)`` times.
    If target durations are absent, predicted durations are applied.

    Args:
        durations: (batch, text_len) integer-valued repeat counts.
        enc_out: (batch, text_len, channels) encoder activations.
        pace: speaking-rate factor; >1 shortens, <1 lengthens output.
        mel_max_len: optional hard cap on the expanded length.

    Returns:
        Tuple of the expanded tensor (batch, out_len, channels) and the
        per-sample output lengths (batch,).
    """
    frame_counts = torch.round(durations.float() / pace).long()
    out_lens = frame_counts.sum(dim=1)

    longest = out_lens.max()
    batch, _, channels = enc_out.size()

    # Append one filler count per row plus a zero feature row, so that a
    # single flat repeat_interleave expands every sample to the same length.
    filler = (longest - out_lens)[:, None]
    counts = torch.cat([frame_counts, filler], dim=1)
    zero_row = torch.zeros(batch, 1, channels,
                           dtype=enc_out.dtype, device=enc_out.device)
    padded = torch.cat([enc_out, zero_row], dim=1)

    expanded = torch.repeat_interleave(
        padded.view(-1, channels), counts.view(-1), dim=0
    ).view(batch, -1, channels)

    if mel_max_len is not None:
        # Truncate to the target mel length and clamp lengths accordingly.
        expanded = expanded[:, :mel_max_len]
        out_lens = torch.clamp_max(out_lens, mel_max_len)
    return expanded, out_lens
|
|
|
+
|
|
|
+
|
|
|
class TemporalPredictor(nn.Module):
    """Predicts a single scalar per temporal location.

    A stack of ConvReLUNorm blocks followed by a pointwise linear layer;
    used for duration and pitch prediction over encoder outputs.
    """

    def __init__(self, input_size, filter_size, kernel_size, dropout,
                 n_layers=2):
        super(TemporalPredictor, self).__init__()

        conv_stack = []
        for i in range(n_layers):
            # First layer adapts the encoder width; the rest keep filter_size.
            in_channels = input_size if i == 0 else filter_size
            conv_stack.append(
                ConvReLUNorm(in_channels, filter_size,
                             kernel_size=kernel_size, dropout=dropout))
        self.layers = nn.Sequential(*conv_stack)
        self.fc = nn.Linear(filter_size, 1, bias=True)

    def forward(self, enc_out, enc_out_mask):
        """Return (batch, time) predictions; padded steps are zeroed."""
        masked = enc_out * enc_out_mask
        # Convolutions operate channels-first, hence the transposes.
        hidden = self.layers(masked.transpose(1, 2)).transpose(1, 2)
        scores = self.fc(hidden) * enc_out_mask
        return scores.squeeze(-1)
|
|
|
+
|
|
|
+
|
|
|
class FastPitch(nn.Module):
    """FastPitch acoustic model: text symbols -> mel spectrogram.

    A feed-forward Transformer encoder produces symbol-level features;
    separate temporal predictors estimate per-symbol durations and pitch;
    features are length-regulated and decoded by a second feed-forward
    Transformer. Supports optional multi-speaker conditioning.
    """

    def __init__(self, n_mel_channels, max_seq_len, n_symbols, padding_idx,
                 symbols_embedding_dim, in_fft_n_layers, in_fft_n_heads,
                 in_fft_d_head,
                 in_fft_conv1d_kernel_size, in_fft_conv1d_filter_size,
                 in_fft_output_size,
                 p_in_fft_dropout, p_in_fft_dropatt, p_in_fft_dropemb,
                 out_fft_n_layers, out_fft_n_heads, out_fft_d_head,
                 out_fft_conv1d_kernel_size, out_fft_conv1d_filter_size,
                 out_fft_output_size,
                 p_out_fft_dropout, p_out_fft_dropatt, p_out_fft_dropemb,
                 dur_predictor_kernel_size, dur_predictor_filter_size,
                 p_dur_predictor_dropout, dur_predictor_n_layers,
                 pitch_predictor_kernel_size, pitch_predictor_filter_size,
                 p_pitch_predictor_dropout, pitch_predictor_n_layers,
                 pitch_embedding_kernel_size, n_speakers, speaker_emb_weight):
        super(FastPitch, self).__init__()
        del max_seq_len  # unused

        # Input FFT: embeds symbol ids itself (embed_input=True).
        self.encoder = FFTransformer(
            n_layer=in_fft_n_layers, n_head=in_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=in_fft_d_head,
            d_inner=in_fft_conv1d_filter_size,
            kernel_size=in_fft_conv1d_kernel_size,
            dropout=p_in_fft_dropout,
            dropatt=p_in_fft_dropatt,
            dropemb=p_in_fft_dropemb,
            embed_input=True,
            d_embed=symbols_embedding_dim,
            n_embed=n_symbols,
            padding_idx=padding_idx)

        # Speaker conditioning is only materialized for multi-speaker setups.
        if n_speakers > 1:
            self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
        else:
            self.speaker_emb = None
        self.speaker_emb_weight = speaker_emb_weight

        self.duration_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=dur_predictor_filter_size,
            kernel_size=dur_predictor_kernel_size,
            dropout=p_dur_predictor_dropout,
            n_layers=dur_predictor_n_layers)

        # Output FFT: consumes length-regulated features (embed_input=False).
        self.decoder = FFTransformer(
            n_layer=out_fft_n_layers, n_head=out_fft_n_heads,
            d_model=symbols_embedding_dim,
            d_head=out_fft_d_head,
            d_inner=out_fft_conv1d_filter_size,
            kernel_size=out_fft_conv1d_kernel_size,
            dropout=p_out_fft_dropout,
            dropatt=p_out_fft_dropatt,
            dropemb=p_out_fft_dropemb,
            embed_input=False,
            d_embed=symbols_embedding_dim)

        self.pitch_predictor = TemporalPredictor(
            in_fft_output_size,
            filter_size=pitch_predictor_filter_size,
            kernel_size=pitch_predictor_kernel_size,
            dropout=p_pitch_predictor_dropout,
            n_layers=pitch_predictor_n_layers)

        # Lifts the scalar pitch track into the encoder channel space;
        # 'same'-style padding keeps the temporal length unchanged.
        self.pitch_emb = nn.Conv1d(
            1, symbols_embedding_dim,
            kernel_size=pitch_embedding_kernel_size,
            padding=int((pitch_embedding_kernel_size - 1) / 2))

        # Pitch normalization statistics precomputed on the training data;
        # stored as buffers so they travel with checkpoints.
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        self.proj = nn.Linear(out_fft_output_size, n_mel_channels, bias=True)

    def forward(self, inputs: List[torch.Tensor], use_gt_durations: bool = True,
                use_gt_pitch: bool = True, pace: float = 1.0,
                max_duration: int = 75):
        """Training-time pass with teacher forcing.

        Returns (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred).
        """
        inputs, _, mel_tgt, _, dur_tgt, _, pitch_tgt, speaker = inputs
        mel_max_len = mel_tgt.size(2)

        # Speaker embedding, scaled by its weight (0 disables additively).
        if self.speaker_emb is None:
            spk_emb = 0
        else:
            spk_emb = self.speaker_emb(speaker).unsqueeze(1)
            spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)

        # Predictors see the pre-pitch encoder output.
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        # Durations are predicted in log space, then clamped to a sane range.
        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Pitch over characters.
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)

        # Condition on ground-truth pitch when available and requested.
        if use_gt_pitch and pitch_tgt is not None:
            pitch_source = pitch_tgt
        else:
            pitch_source = pitch_pred
        pitch_emb = self.pitch_emb(pitch_source.unsqueeze(1))
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        len_regulated, dec_lens = regulate_len(
            dur_tgt if use_gt_durations else dur_pred,
            enc_out, pace, mel_max_len)

        # Output FFT
        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        return mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred

    def infer(self, inputs, input_lens, pace: float = 1.0,
              dur_tgt: Optional[torch.Tensor] = None,
              pitch_tgt: Optional[torch.Tensor] = None,
              max_duration: float = 75,
              speaker: int = 0):
        """Inference pass; durations/pitch may be supplied or predicted.

        Returns (mel_out, dec_lens, dur_pred, pitch_pred) with mel_out
        shaped (batch, n_mel_channels, time).
        """
        del input_lens  # unused

        if self.speaker_emb is None:
            spk_emb = None
        else:
            # Broadcast the requested speaker id across the whole batch.
            spk_ids = torch.ones(inputs.size(0), dtype=torch.long,
                                 device=inputs.device).fill_(speaker)
            spk_emb = self.speaker_emb(spk_ids).unsqueeze(1)
            spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)

        # Predictors see the pre-pitch encoder output.
        pred_enc_out, pred_enc_mask = enc_out, enc_mask

        log_dur_pred = self.duration_predictor(pred_enc_out, pred_enc_mask)
        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

        # Pitch over characters; prefer the provided target if any.
        pitch_pred = self.pitch_predictor(enc_out, enc_mask)
        if pitch_tgt is None:
            pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)

        enc_out = enc_out + pitch_emb

        len_regulated, dec_lens = regulate_len(
            dur_pred if dur_tgt is None else dur_tgt,
            enc_out, pace, mel_max_len=None)

        dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
        mel_out = self.proj(dec_out)
        # mel_lens = dec_mask.squeeze(2).sum(axis=1).long()
        mel_out = mel_out.permute(0, 2, 1)  # For inference.py
        return mel_out, dec_lens, dur_pred, pitch_pred
|