Selaa lähdekoodia

[FastPitch/PyT] Fix updated regulate_len

Adrian Lancucki 4 vuotta sitten
vanhempi
sitoutus
8d8c524df6
1 muutettua tiedostoa jossa 3 lisäystä ja 2 poistoa
  1. 3 2
      PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py

+ 3 - 2
PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py

@@ -38,8 +38,9 @@ from fastpitch.transformer import FFTransformer
 def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
     """If target=None, then predicted durations are applied"""
     dtype = enc_out.dtype
-    reps = (durations.float() / pace + 0.5)
-    dec_lens = reps.sum(dim=1).long()
+    reps = durations.float() / pace
+    reps = (reps + 0.5).long()
+    dec_lens = reps.sum(dim=1)
 
     max_len = dec_lens.max()
     reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]