# features.py
  1. import math
  2. import random
  3. import librosa
  4. import torch
  5. import torch.nn as nn
  6. class BaseFeatures(nn.Module):
  7. """Base class for GPU accelerated audio preprocessing."""
  8. __constants__ = ["pad_align", "pad_to_max_duration", "max_len"]
  9. def __init__(self, pad_align, pad_to_max_duration, max_duration,
  10. sample_rate, window_size, window_stride, spec_augment=None,
  11. cutout_augment=None):
  12. super(BaseFeatures, self).__init__()
  13. self.pad_align = pad_align
  14. self.pad_to_max_duration = pad_to_max_duration
  15. self.win_length = int(sample_rate * window_size) # frame size
  16. self.hop_length = int(sample_rate * window_stride)
  17. # Calculate maximum sequence length (# frames)
  18. if pad_to_max_duration:
  19. self.max_len = 1 + math.ceil(
  20. (max_duration * sample_rate - self.win_length) / self.hop_length
  21. )
  22. if spec_augment is not None:
  23. self.spec_augment = SpecAugment(**spec_augment)
  24. else:
  25. self.spec_augment = None
  26. if cutout_augment is not None:
  27. self.cutout_augment = CutoutAugment(**cutout_augment)
  28. else:
  29. self.cutout_augment = None
  30. @torch.no_grad()
  31. def calculate_features(self, audio, audio_lens):
  32. return audio, audio_lens
  33. def __call__(self, audio, audio_lens):
  34. dtype = audio.dtype
  35. audio = audio.float()
  36. feat, feat_lens = self.calculate_features(audio, audio_lens)
  37. feat = self.apply_padding(feat)
  38. if self.cutout_augment is not None:
  39. feat = self.cutout_augment(feat)
  40. if self.spec_augment is not None:
  41. feat = self.spec_augment(feat)
  42. feat = feat.to(dtype)
  43. return feat, feat_lens
  44. def apply_padding(self, x):
  45. if self.pad_to_max_duration:
  46. x_size = max(x.size(-1), self.max_len)
  47. else:
  48. x_size = x.size(-1)
  49. if self.pad_align > 0:
  50. pad_amt = x_size % self.pad_align
  51. else:
  52. pad_amt = 0
  53. padded_len = x_size + (self.pad_align - pad_amt if pad_amt > 0 else 0)
  54. return nn.functional.pad(x, (0, padded_len - x.size(-1)))
  55. class SpecAugment(nn.Module):
  56. """Spec augment. refer to https://arxiv.org/abs/1904.08779
  57. """
  58. def __init__(self, freq_masks=0, min_freq=0, max_freq=10, time_masks=0,
  59. min_time=0, max_time=10):
  60. super(SpecAugment, self).__init__()
  61. assert 0 <= min_freq <= max_freq
  62. assert 0 <= min_time <= max_time
  63. self.freq_masks = freq_masks
  64. self.min_freq = min_freq
  65. self.max_freq = max_freq
  66. self.time_masks = time_masks
  67. self.min_time = min_time
  68. self.max_time = max_time
  69. @torch.no_grad()
  70. def forward(self, x):
  71. sh = x.shape
  72. mask = torch.zeros(x.shape, dtype=torch.bool, device=x.device)
  73. for idx in range(sh[0]):
  74. for _ in range(self.freq_masks):
  75. w = torch.randint(self.min_freq, self.max_freq + 1, size=(1,)).item()
  76. f0 = torch.randint(0, max(1, sh[1] - w), size=(1,))
  77. mask[idx, f0:f0+w] = 1
  78. for _ in range(self.time_masks):
  79. w = torch.randint(self.min_time, self.max_time + 1, size=(1,)).item()
  80. t0 = torch.randint(0, max(1, sh[2] - w), size=(1,))
  81. mask[idx, :, t0:t0+w] = 1
  82. return x.masked_fill(mask, 0)
  83. class CutoutAugment(nn.Module):
  84. """Cutout. refer to https://arxiv.org/pdf/1708.04552.pdf
  85. """
  86. def __init__(self, masks=0, min_freq=20, max_freq=20, min_time=5, max_time=5):
  87. super(CutoutAugment, self).__init__()
  88. assert 0 <= min_freq <= max_freq
  89. assert 0 <= min_time <= max_time
  90. self.masks = masks
  91. self.min_freq = min_freq
  92. self.max_freq = max_freq
  93. self.min_time = min_time
  94. self.max_time = max_time
  95. @torch.no_grad()
  96. def forward(self, x):
  97. sh = x.shape
  98. mask = torch.zeros(x.shape, dtype=torch.bool, device=x.device)
  99. for idx in range(sh[0]):
  100. for i in range(self.masks):
  101. w = torch.randint(self.min_freq, self.max_freq + 1, size=(1,)).item()
  102. h = torch.randint(self.min_time, self.max_time + 1, size=(1,)).item()
  103. f0 = int(random.uniform(0, sh[1] - w))
  104. t0 = int(random.uniform(0, sh[2] - h))
  105. mask[idx, f0:f0+w, t0:t0+h] = 1
  106. return x.masked_fill(mask, 0)
  107. @torch.jit.script
  108. def normalize_batch(x, seq_len, normalize_type: str):
  109. # print ("normalize_batch: x, seq_len, shapes: ", x.shape, seq_len, seq_len.shape)
  110. if normalize_type == "per_feature":
  111. x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
  112. device=x.device)
  113. x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype,
  114. device=x.device)
  115. for i in range(x.shape[0]):
  116. x_mean[i, :] = x[i, :, :seq_len[i]].mean(dim=1)
  117. x_std[i, :] = x[i, :, :seq_len[i]].std(dim=1)
  118. # make sure x_std is not zero
  119. x_std += 1e-5
  120. return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
  121. elif normalize_type == "all_features":
  122. x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
  123. x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
  124. for i in range(x.shape[0]):
  125. x_mean[i] = x[i, :, :int(seq_len[i])].mean()
  126. x_std[i] = x[i, :, :int(seq_len[i])].std()
  127. # make sure x_std is not zero
  128. x_std += 1e-5
  129. return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
  130. else:
  131. return x
  132. @torch.jit.script
  133. def stack_subsample_frames(x, x_lens, stacking: int = 1, subsampling: int = 1):
  134. """ Stacks frames together across feature dim, and then subsamples
  135. input is batch_size, feature_dim, num_frames
  136. output is batch_size, feature_dim * stacking, num_frames / subsampling
  137. """
  138. seq = [x]
  139. for n in range(1, stacking):
  140. tmp = torch.zeros_like(x)
  141. tmp[:, :, :-n] = x[:, :, n:]
  142. seq.append(tmp)
  143. x = torch.cat(seq, dim=1)[:, :, ::subsampling]
  144. if subsampling > 1:
  145. x_lens = torch.ceil(x_lens.float() / subsampling).int()
  146. if x.size(2) > x_lens.max().item():
  147. assert abs(x.size(2) - x_lens.max().item()) <= 1
  148. x = x[:,:,:x_lens.max().item()]
  149. return x, x_lens
  150. class FilterbankFeatures(BaseFeatures):
  151. # For JIT, https://pytorch.org/docs/stable/jit.html#python-defined-constants
  152. __constants__ = ["dither", "preemph", "n_fft", "hop_length", "win_length",
  153. "log", "frame_splicing", "normalize"]
  154. # torchscript: "center" removed due to a bug
  155. def __init__(self, spec_augment=None, cutout_augment=None,
  156. sample_rate=8000, window_size=0.02, window_stride=0.01,
  157. window="hamming", normalize="per_feature", n_fft=None,
  158. preemph=0.97, n_filt=64, lowfreq=0, highfreq=None, log=True,
  159. dither=1e-5, pad_align=8, pad_to_max_duration=False,
  160. max_duration=float('inf'), frame_splicing=1):
  161. super(FilterbankFeatures, self).__init__(
  162. pad_align=pad_align, pad_to_max_duration=pad_to_max_duration,
  163. max_duration=max_duration, sample_rate=sample_rate,
  164. window_size=window_size, window_stride=window_stride,
  165. spec_augment=spec_augment, cutout_augment=cutout_augment)
  166. torch_windows = {
  167. 'hann': torch.hann_window,
  168. 'hamming': torch.hamming_window,
  169. 'blackman': torch.blackman_window,
  170. 'bartlett': torch.bartlett_window,
  171. 'none': None,
  172. }
  173. self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
  174. self.normalize = normalize
  175. self.log = log
  176. #TORCHSCRIPT: Check whether or not we need this
  177. self.dither = dither
  178. self.frame_splicing = frame_splicing
  179. self.n_filt = n_filt
  180. self.preemph = preemph
  181. highfreq = highfreq or sample_rate / 2
  182. window_fn = torch_windows.get(window, None)
  183. window_tensor = window_fn(self.win_length,
  184. periodic=False) if window_fn else None
  185. filterbanks = torch.tensor(
  186. librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=n_filt,
  187. fmin=lowfreq, fmax=highfreq),
  188. dtype=torch.float).unsqueeze(0)
  189. # torchscript
  190. self.register_buffer("fb", filterbanks)
  191. self.register_buffer("window", window_tensor)
  192. def get_seq_len(self, seq_len):
  193. return torch.ceil(seq_len.to(dtype=torch.float) / self.hop_length).to(
  194. dtype=torch.int)
  195. # TORCHSCRIPT: center removed due to bug
  196. def stft(self, x):
  197. spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
  198. win_length=self.win_length,
  199. window=self.window.to(dtype=torch.float),
  200. return_complex=True)
  201. return torch.view_as_real(spec)
  202. @torch.no_grad()
  203. def calculate_features(self, x, seq_len):
  204. dtype = x.dtype
  205. seq_len = self.get_seq_len(seq_len)
  206. # dither
  207. if self.dither > 0:
  208. x += self.dither * torch.randn_like(x)
  209. # do preemphasis
  210. if self.preemph is not None:
  211. x = torch.cat(
  212. (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
  213. x = self.stft(x)
  214. # get power spectrum
  215. x = x.pow(2).sum(-1)
  216. # dot with filterbank energies
  217. x = torch.matmul(self.fb.to(x.dtype), x)
  218. # log features if required
  219. if self.log:
  220. x = torch.log(x + 1e-20)
  221. # frame splicing if required
  222. if self.frame_splicing > 1:
  223. raise ValueError('Frame splicing not supported')
  224. # normalize if required
  225. x = normalize_batch(x, seq_len, normalize_type=self.normalize)
  226. # mask to zero any values beyond seq_len in batch,
  227. # pad to multiple of `pad_align` (for efficiency)
  228. max_len = x.size(-1)
  229. mask = torch.arange(max_len, dtype=seq_len.dtype, device=x.device)
  230. mask = mask.expand(x.size(0), max_len) >= seq_len.unsqueeze(1)
  231. x = x.masked_fill(mask.unsqueeze(1), 0)
  232. # TORCHSCRIPT: Is this del important? It breaks scripting
  233. # del mask
  234. return x.to(dtype), seq_len