Procházet zdrojové kódy

[FastPitch/PyT] Fix updated regulate_len

Adrian Lancucki před 4 roky
rodič
revize
8d8c524df6

+ 3 - 2
PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py

@@ -38,8 +38,9 @@ from fastpitch.transformer import FFTransformer
 def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
 def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
     """If target=None, then predicted durations are applied"""
     """If target=None, then predicted durations are applied"""
     dtype = enc_out.dtype
     dtype = enc_out.dtype
-    reps = (durations.float() / pace + 0.5)
-    dec_lens = reps.sum(dim=1).long()
+    reps = durations.float() / pace
+    reps = (reps + 0.5).long()
+    dec_lens = reps.sum(dim=1)
 
 
     max_len = dec_lens.max()
     max_len = dec_lens.max()
     reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
     reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]