Bladeren bron

[Jasper/PyT] Update torch.stft for PyTorch 2.0

Adrian Lancucki 2 jaren geleden
bovenliggende
commit
29aaae3285

+ 4 - 3
PyTorch/SpeechRecognition/Jasper/common/features.py

@@ -244,12 +244,13 @@ class FilterbankFeatures(BaseFeatures):
         return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
             dtype=torch.int)
 
-    # do stft
     # TORCHSCRIPT: center removed due to bug
     def stft(self, x):
-        return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
+        spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                           win_length=self.win_length,
-                          window=self.window.to(dtype=torch.float))
+                          window=self.window.to(dtype=torch.float),
+                          return_complex=True)
+        return torch.view_as_real(spec)
 
     @torch.no_grad()
     def calculate_features(self, x, seq_len):

+ 4 - 3
PyTorch/SpeechRecognition/QuartzNet/common/features.py

@@ -248,12 +248,13 @@ class FilterbankFeatures(BaseFeatures):
         return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
             dtype=torch.int)
 
-    # do stft
     # TORCHSCRIPT: center removed due to bug
     def stft(self, x):
-        return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
+        spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                           win_length=self.win_length,
-                          window=self.window.to(dtype=torch.float))
+                          window=self.window.to(dtype=torch.float),
+                          return_complex=True)
+        return torch.view_as_real(spec)
 
     @torch.no_grad()
     def calculate_features(self, x, seq_len):

+ 4 - 3
PyTorch/SpeechRecognition/wav2vec2/common/features.py

@@ -261,12 +261,13 @@ class FilterbankFeatures(BaseFeatures):
         return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
             dtype=torch.int)
 
-    # do stft
     # TORCHSCRIPT: center removed due to bug
     def stft(self, x):
-        return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
+        spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                           win_length=self.win_length,
-                          window=self.window.to(dtype=torch.float))
+                          window=self.window.to(dtype=torch.float),
+                          return_complex=True)
+        return torch.view_as_real(spec)
 
     @torch.no_grad()
     def calculate_features(self, x, x_lens):