|
|
@@ -244,12 +244,13 @@ class FilterbankFeatures(BaseFeatures):
|
|
|
return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
|
|
|
dtype=torch.int)
|
|
|
|
|
|
- # do stft
|
|
|
# TORCHSCRIPT: center removed due to bug
|
|
|
def stft(self, x):
|
|
|
- return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
|
|
|
+ spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
|
|
|
win_length=self.win_length,
|
|
|
- window=self.window.to(dtype=torch.float))
|
|
|
+ window=self.window.to(dtype=torch.float),
|
|
|
+ return_complex=True)
|
|
|
+ return torch.view_as_real(spec)
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def calculate_features(self, x, seq_len):
|