features.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import random
  16. import librosa
  17. import torch
  18. import torch.nn as nn
  19. class BaseFeatures(nn.Module):
  20. """Base class for GPU accelerated audio preprocessing."""
  21. __constants__ = ["pad_align", "pad_to_max_duration", "max_len"]
  22. def __init__(self, pad_align, pad_to_max_duration, max_duration,
  23. sample_rate, window_size, window_stride, spec_augment=None,
  24. cutout_augment=None):
  25. super(BaseFeatures, self).__init__()
  26. self.pad_align = pad_align
  27. self.pad_to_max_duration = pad_to_max_duration
  28. self.win_length = int(sample_rate * window_size) # frame size
  29. self.hop_length = int(sample_rate * window_stride)
  30. # Calculate maximum sequence length (# frames)
  31. if pad_to_max_duration:
  32. self.max_len = 1 + math.ceil(
  33. (max_duration * sample_rate - self.win_length) / self.hop_length
  34. )
  35. if spec_augment is not None:
  36. self.spec_augment = SpecAugment(**spec_augment)
  37. else:
  38. self.spec_augment = None
  39. if cutout_augment is not None:
  40. self.cutout_augment = CutoutAugment(**cutout_augment)
  41. else:
  42. self.cutout_augment = None
  43. @torch.no_grad()
  44. def calculate_features(self, audio, audio_lens):
  45. return audio, audio_lens
  46. def __call__(self, audio, audio_lens):
  47. dtype = audio.dtype
  48. audio = audio.float()
  49. feat, feat_lens = self.calculate_features(audio, audio_lens)
  50. feat = self.apply_padding(feat)
  51. if self.cutout_augment is not None:
  52. feat = self.cutout_augment(feat)
  53. if self.spec_augment is not None:
  54. feat = self.spec_augment(feat)
  55. feat = feat.to(dtype)
  56. return feat, feat_lens
  57. def apply_padding(self, x):
  58. if self.pad_to_max_duration:
  59. x_size = max(x.size(-1), self.max_len)
  60. else:
  61. x_size = x.size(-1)
  62. if self.pad_align > 0:
  63. pad_amt = x_size % self.pad_align
  64. else:
  65. pad_amt = 0
  66. padded_len = x_size + (self.pad_align - pad_amt if pad_amt > 0 else 0)
  67. return nn.functional.pad(x, (0, padded_len - x.size(-1)))
  68. class SpecAugment(nn.Module):
  69. """Spec augment. refer to https://arxiv.org/abs/1904.08779
  70. """
  71. def __init__(self, freq_masks=0, min_freq=0, max_freq=10, time_masks=0,
  72. min_time=0, max_time=10):
  73. super(SpecAugment, self).__init__()
  74. assert 0 <= min_freq <= max_freq
  75. assert 0 <= min_time <= max_time
  76. self.freq_masks = freq_masks
  77. self.min_freq = min_freq
  78. self.max_freq = max_freq
  79. self.time_masks = time_masks
  80. self.min_time = min_time
  81. self.max_time = max_time
  82. @torch.no_grad()
  83. def forward(self, x):
  84. sh = x.shape
  85. mask = torch.zeros(x.shape, dtype=torch.bool, device=x.device)
  86. for idx in range(sh[0]):
  87. for _ in range(self.freq_masks):
  88. w = torch.randint(self.min_freq, self.max_freq + 1, size=(1,)).item()
  89. f0 = torch.randint(0, max(1, sh[1] - w), size=(1,))
  90. mask[idx, f0:f0+w] = 1
  91. for _ in range(self.time_masks):
  92. w = torch.randint(self.min_time, self.max_time + 1, size=(1,)).item()
  93. t0 = torch.randint(0, max(1, sh[2] - w), size=(1,))
  94. mask[idx, :, t0:t0+w] = 1
  95. return x.masked_fill(mask, 0)
  96. class CutoutAugment(nn.Module):
  97. """Cutout. refer to https://arxiv.org/pdf/1708.04552.pdf
  98. """
  99. def __init__(self, masks=0, min_freq=20, max_freq=20, min_time=5, max_time=5):
  100. super(CutoutAugment, self).__init__()
  101. assert 0 <= min_freq <= max_freq
  102. assert 0 <= min_time <= max_time
  103. self.masks = masks
  104. self.min_freq = min_freq
  105. self.max_freq = max_freq
  106. self.min_time = min_time
  107. self.max_time = max_time
  108. @torch.no_grad()
  109. def forward(self, x):
  110. sh = x.shape
  111. mask = torch.zeros(x.shape, dtype=torch.bool, device=x.device)
  112. for idx in range(sh[0]):
  113. for i in range(self.masks):
  114. w = torch.randint(self.min_freq, self.max_freq + 1, size=(1,)).item()
  115. h = torch.randint(self.min_time, self.max_time + 1, size=(1,)).item()
  116. f0 = int(random.uniform(0, sh[1] - w))
  117. t0 = int(random.uniform(0, sh[2] - h))
  118. mask[idx, f0:f0+w, t0:t0+h] = 1
  119. return x.masked_fill(mask, 0)
  120. @torch.jit.script
  121. def normalize_batch(x, seq_len, normalize_type: str):
  122. if normalize_type == "per_feature":
  123. x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
  124. device=x.device)
  125. x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
  126. device=x.device)
  127. for i in range(x.shape[0]):
  128. x_mean[i, :] = x[i, :, :seq_len[i]].mean(dim=1)
  129. x_std[i, :] = x[i, :, :seq_len[i]].std(dim=1)
  130. # make sure x_std is not zero
  131. x_std += 1e-5
  132. return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
  133. elif normalize_type == "all_features":
  134. x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
  135. x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
  136. for i in range(x.shape[0]):
  137. x_mean[i] = x[i, :, :int(seq_len[i])].mean()
  138. x_std[i] = x[i, :, :int(seq_len[i])].std()
  139. # make sure x_std is not zero
  140. x_std += 1e-5
  141. return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
  142. else:
  143. return x
  144. @torch.jit.script
  145. def stack_subsample_frames(x, x_lens, stacking: int = 1, subsampling: int = 1):
  146. """ Stacks frames together across feature dim, and then subsamples
  147. input is batch_size, feature_dim, num_frames
  148. output is batch_size, feature_dim * stacking, num_frames / subsampling
  149. """
  150. seq = [x]
  151. for n in range(1, stacking):
  152. tmp = torch.zeros_like(x)
  153. tmp[:, :, :-n] = x[:, :, n:]
  154. seq.append(tmp)
  155. x = torch.cat(seq, dim=1)[:, :, ::subsampling]
  156. if subsampling > 1:
  157. x_lens = torch.ceil(x_lens.float() / subsampling).int()
  158. if x.size(2) > x_lens.max().item():
  159. assert abs(x.size(2) - x_lens.max().item()) <= 1
  160. x = x[:,:,:x_lens.max().item()]
  161. return x, x_lens
  162. class FilterbankFeatures(BaseFeatures):
  163. # For JIT, https://pytorch.org/docs/stable/jit.html#python-defined-constants
  164. __constants__ = ["dither", "preemph", "n_fft", "hop_length", "win_length",
  165. "log", "frame_stacking", "frame_subsampling", "normalize"]
  166. # torchscript: "center" removed due to a bug
  167. def __init__(self, spec_augment=None, cutout_augment=None,
  168. sample_rate=16000, window_size=0.02, window_stride=0.01,
  169. window="hann", normalize="per_feature", n_fft=512,
  170. preemph=0.97, n_filt=80, lowfreq=0, highfreq=None, log=True,
  171. dither=1e-5, pad_align=16, pad_to_max_duration=False,
  172. max_duration=float('inf'), frame_stacking=1,
  173. frame_subsampling=1):
  174. super(FilterbankFeatures, self).__init__(
  175. pad_align=pad_align, pad_to_max_duration=pad_to_max_duration,
  176. max_duration=max_duration, sample_rate=sample_rate,
  177. window_size=window_size, window_stride=window_stride,
  178. spec_augment=spec_augment, cutout_augment=cutout_augment)
  179. torch_windows = {
  180. 'hann': torch.hann_window,
  181. 'hamming': torch.hamming_window,
  182. 'blackman': torch.blackman_window,
  183. 'bartlett': torch.bartlett_window,
  184. 'none': None,
  185. }
  186. self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
  187. self.normalize = normalize
  188. self.log = log
  189. #TORCHSCRIPT: Check whether or not we need this
  190. self.dither = dither
  191. self.frame_stacking = frame_stacking
  192. self.frame_subsampling = frame_subsampling
  193. self.n_filt = n_filt
  194. self.preemph = preemph
  195. highfreq = highfreq or sample_rate / 2
  196. window_fn = torch_windows.get(window, None)
  197. window_tensor = window_fn(self.win_length,
  198. periodic=False) if window_fn else None
  199. filterbanks = torch.tensor(
  200. librosa.filters.mel(sample_rate, self.n_fft, n_mels=n_filt,
  201. fmin=lowfreq, fmax=highfreq),
  202. dtype=torch.float).unsqueeze(0)
  203. # torchscript
  204. self.register_buffer("fb", filterbanks)
  205. self.register_buffer("window", window_tensor)
  206. def output_dim(self):
  207. return self.n_filt * self.frame_stacking
  208. def get_seq_len(self, seq_len):
  209. return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
  210. dtype=torch.int)
  211. # TORCHSCRIPT: center removed due to bug
  212. def stft(self, x):
  213. spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
  214. win_length=self.win_length,
  215. window=self.window.to(dtype=torch.float),
  216. return_complex=True)
  217. return torch.view_as_real(spec)
  218. @torch.no_grad()
  219. def calculate_features(self, x, x_lens):
  220. dtype = x.dtype
  221. x_lens = self.get_seq_len(x_lens)
  222. # dither
  223. if self.dither > 0:
  224. x += self.dither * torch.randn_like(x)
  225. # do preemphasis
  226. if self.preemph is not None:
  227. x = torch.cat(
  228. x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1],
  229. dim=1)
  230. x = self.stft(x)
  231. # get power spectrum
  232. x = x.pow(2).sum(-1)
  233. # dot with filterbank energies
  234. x = torch.matmul(self.fb.to(x.dtype), x)
  235. # log features if required
  236. if self.log:
  237. x = torch.log(x + 1e-20)
  238. # normalize if required
  239. x = normalize_batch(x, x_lens, normalize_type=self.normalize)
  240. if self.frame_stacking > 1 or self.frame_subsampling > 1:
  241. x, x_lens = stack_subsample_frames(x, x_lens, self.frame_stacking,
  242. self.frame_subsampling)
  243. # mask to zero any values beyond x_lens in batch,
  244. # pad to multiple of `pad_align` (for efficiency)
  245. max_len = x.size(-1)
  246. mask = torch.arange(max_len, dtype=x_lens.dtype, device=x.device)
  247. mask = mask.expand(x.size(0), max_len) >= x_lens.unsqueeze(1)
  248. x = x.masked_fill(mask.unsqueeze(1), 0)
  249. # TORCHSCRIPT: Is this del important? It breaks scripting
  250. # del mask
  251. return x.to(dtype), x_lens