|
|
@@ -13,7 +13,7 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
from apex import amp
|
|
|
-import torch
|
|
|
+import torch
|
|
|
import torch.nn as nn
|
|
|
from parts.features import FeatureFactory
|
|
|
from helpers import Optimization
|
|
|
@@ -50,7 +50,6 @@ def init_weights(m, mode='xavier_uniform'):
|
|
|
def get_same_padding(kernel_size, stride, dilation):
|
|
|
if stride > 1 and dilation > 1:
|
|
|
raise ValueError("Only stride OR dilation may be greater than 1")
|
|
|
-
|
|
|
return (kernel_size // 2) * dilation
|
|
|
|
|
|
class AudioPreprocessing(nn.Module):
|
|
|
@@ -74,7 +73,7 @@ class AudioPreprocessing(nn.Module):
|
|
|
return processed_signal, processed_length
|
|
|
|
|
|
class SpectrogramAugmentation(nn.Module):
|
|
|
- """Spectrogram augmentation
|
|
|
+ """Spectrogram augmentation
|
|
|
"""
|
|
|
def __init__(self, **kwargs):
|
|
|
nn.Module.__init__(self)
|
|
|
@@ -90,11 +89,8 @@ class SpectrogramAugmentation(nn.Module):
|
|
|
class SpecAugment(nn.Module):
|
|
|
"""Spec augment. refer to https://arxiv.org/abs/1904.08779
|
|
|
"""
|
|
|
- def __init__(self, cfg, rng=None):
|
|
|
+ def __init__(self, cfg):
|
|
|
super(SpecAugment, self).__init__()
|
|
|
-
|
|
|
- self._rng = random.Random() if rng is None else rng
|
|
|
-
|
|
|
self.cutout_x_regions = cfg.get('cutout_x_regions', 0)
|
|
|
self.cutout_y_regions = cfg.get('cutout_y_regions', 0)
|
|
|
|
|
|
@@ -108,12 +104,12 @@ class SpecAugment(nn.Module):
|
|
|
mask = torch.zeros(x.shape).byte()
|
|
|
for idx in range(sh[0]):
|
|
|
for _ in range(self.cutout_x_regions):
|
|
|
- cutout_x_left = int(self._rng.uniform(0, sh[1] - self.cutout_x_width))
|
|
|
+ cutout_x_left = int(random.uniform(0, sh[1] - self.cutout_x_width))
|
|
|
|
|
|
mask[idx, cutout_x_left:cutout_x_left + self.cutout_x_width, :] = 1
|
|
|
|
|
|
for _ in range(self.cutout_y_regions):
|
|
|
- cutout_y_left = int(self._rng.uniform(0, sh[2] - self.cutout_y_width))
|
|
|
+ cutout_y_left = int(random.uniform(0, sh[2] - self.cutout_y_width))
|
|
|
|
|
|
mask[idx, :, cutout_y_left:cutout_y_left + self.cutout_y_width] = 1
|
|
|
|
|
|
@@ -124,11 +120,9 @@ class SpecAugment(nn.Module):
|
|
|
class SpecCutoutRegions(nn.Module):
|
|
|
"""Cutout. refer to https://arxiv.org/pdf/1708.04552.pdf
|
|
|
"""
|
|
|
- def __init__(self, cfg, rng=None):
|
|
|
+ def __init__(self, cfg):
|
|
|
super(SpecCutoutRegions, self).__init__()
|
|
|
|
|
|
- self._rng = random.Random() if rng is None else rng
|
|
|
-
|
|
|
self.cutout_rect_regions = cfg.get('cutout_rect_regions', 0)
|
|
|
self.cutout_rect_time = cfg.get('cutout_rect_time', 5)
|
|
|
self.cutout_rect_freq = cfg.get('cutout_rect_freq', 20)
|
|
|
@@ -141,9 +135,9 @@ class SpecCutoutRegions(nn.Module):
|
|
|
|
|
|
for idx in range(sh[0]):
|
|
|
for i in range(self.cutout_rect_regions):
|
|
|
- cutout_rect_x = int(self._rng.uniform(
|
|
|
+ cutout_rect_x = int(random.uniform(
|
|
|
0, sh[1] - self.cutout_rect_freq))
|
|
|
- cutout_rect_y = int(self._rng.uniform(
|
|
|
+ cutout_rect_y = int(random.uniform(
|
|
|
0, sh[2] - self.cutout_rect_time))
|
|
|
|
|
|
mask[idx, cutout_rect_x:cutout_rect_x + self.cutout_rect_freq,
|
|
|
@@ -154,18 +148,19 @@ class SpecCutoutRegions(nn.Module):
|
|
|
return x
|
|
|
|
|
|
class JasperEncoder(nn.Module):
|
|
|
- """Jasper encoder
|
|
|
+
|
|
|
+ """Jasper encoder
|
|
|
"""
|
|
|
def __init__(self, **kwargs):
|
|
|
cfg = {}
|
|
|
for key, value in kwargs.items():
|
|
|
cfg[key] = value
|
|
|
|
|
|
- nn.Module.__init__(self)
|
|
|
+ nn.Module.__init__(self)
|
|
|
self._cfg = cfg
|
|
|
|
|
|
activation = jasper_activations[cfg['encoder']['activation']]()
|
|
|
- use_conv_mask = cfg['encoder'].get('convmask', False)
|
|
|
+ self.use_conv_mask = cfg['encoder'].get('convmask', False)
|
|
|
feat_in = cfg['input']['features'] * cfg['input'].get('frame_splicing', 1)
|
|
|
init_mode = cfg.get('init_mode', 'xavier_uniform')
|
|
|
|
|
|
@@ -183,7 +178,7 @@ class JasperEncoder(nn.Module):
|
|
|
kernel_size=lcfg['kernel'], stride=lcfg['stride'],
|
|
|
dilation=lcfg['dilation'], dropout=lcfg['dropout'],
|
|
|
residual=lcfg['residual'], activation=activation,
|
|
|
- residual_panes=dense_res, conv_mask=use_conv_mask))
|
|
|
+ residual_panes=dense_res, use_conv_mask=self.use_conv_mask))
|
|
|
feat_in = lcfg['filters']
|
|
|
|
|
|
self.encoder = nn.Sequential(*encoder_layers)
|
|
|
@@ -193,106 +188,146 @@ class JasperEncoder(nn.Module):
|
|
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
def forward(self, x):
|
|
|
- audio_signal, length = x
|
|
|
- s_input, length = self.encoder(([audio_signal], length))
|
|
|
- return s_input, length
|
|
|
+ if self.use_conv_mask:
|
|
|
+ audio_signal, length = x
|
|
|
+ return self.encoder(([audio_signal], length))
|
|
|
+ else:
|
|
|
+ return self.encoder([x])
|
|
|
|
|
|
class JasperDecoderForCTC(nn.Module):
|
|
|
- """Jasper decoder
|
|
|
+ """Jasper decoder
|
|
|
"""
|
|
|
def __init__(self, **kwargs):
|
|
|
- nn.Module.__init__(self)
|
|
|
+ nn.Module.__init__(self)
|
|
|
self._feat_in = kwargs.get("feat_in")
|
|
|
self._num_classes = kwargs.get("num_classes")
|
|
|
init_mode = kwargs.get('init_mode', 'xavier_uniform')
|
|
|
|
|
|
self.decoder_layers = nn.Sequential(
|
|
|
- nn.Conv1d(self._feat_in, self._num_classes, kernel_size=1, bias=True),
|
|
|
- nn.LogSoftmax(dim=1))
|
|
|
+ nn.Conv1d(self._feat_in, self._num_classes, kernel_size=1, bias=True),)
|
|
|
self.apply(lambda x: init_weights(x, mode=init_mode))
|
|
|
|
|
|
-
|
|
|
def num_weights(self):
|
|
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
def forward(self, encoder_output):
|
|
|
- out = self.decoder_layers(encoder_output[-1])
|
|
|
- return out.transpose(1, 2)
|
|
|
+ out = self.decoder_layers(encoder_output[-1]).transpose(1, 2)
|
|
|
+ return nn.functional.log_softmax(out, dim=2)
|
|
|
|
|
|
class Jasper(nn.Module):
|
|
|
- """Contains data preprocessing, spectrogram augmentation, jasper encoder and decoder
|
|
|
+ """Contains data preprocessing, spectrogram augmentation, jasper encoder and decoder
|
|
|
"""
|
|
|
def __init__(self, **kwargs):
|
|
|
- nn.Module.__init__(self)
|
|
|
- self.audio_preprocessor = AudioPreprocessing(**kwargs.get("feature_config"))
|
|
|
+ nn.Module.__init__(self)
|
|
|
+ if kwargs.get("no_featurizer", False):
|
|
|
+ self.audio_preprocessor = None
|
|
|
+ else:
|
|
|
+ self.audio_preprocessor = AudioPreprocessing(**kwargs.get("feature_config"))
|
|
|
+
|
|
|
self.data_spectr_augmentation = SpectrogramAugmentation(**kwargs.get("feature_config"))
|
|
|
self.jasper_encoder = JasperEncoder(**kwargs.get("jasper_model_definition"))
|
|
|
self.jasper_decoder = JasperDecoderForCTC(feat_in=kwargs.get("feat_in"),
|
|
|
- num_classes=kwargs.get("num_classes"))
|
|
|
+ num_classes=kwargs.get("num_classes"))
|
|
|
+ self.acoustic_model = JasperAcousticModel(self.jasper_encoder, self.jasper_decoder)
|
|
|
|
|
|
def num_weights(self):
|
|
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
def forward(self, x):
|
|
|
- input_signal, length = x
|
|
|
- t_processed_signal, p_length_t = self.audio_preprocessor(x)
|
|
|
+
|
|
|
+ # Apply optional preprocessing
|
|
|
+ if self.audio_preprocessor is not None:
|
|
|
+ t_processed_signal, p_length_t = self.audio_preprocessor(x)
|
|
|
+ # Apply optional spectral augmentation
|
|
|
if self.training:
|
|
|
t_processed_signal = self.data_spectr_augmentation(input_spec=t_processed_signal)
|
|
|
- t_encoded_t, t_encoded_len_t = self.jasper_encoder((t_processed_signal, p_length_t))
|
|
|
- return self.jasper_decoder(encoder_output=t_encoded_t), t_encoded_len_t
|
|
|
+
|
|
|
+ if (self.jasper_encoder.use_conv_mask):
|
|
|
+ a_inp = (t_processed_signal, p_length_t)
|
|
|
+ else:
|
|
|
+ a_inp = t_processed_signal
|
|
|
+ # Forward Pass through Encoder-Decoder
|
|
|
+ return self.acoustic_model.forward(a_inp)
|
|
|
+
|
|
|
+
|
|
|
+class JasperAcousticModel(nn.Module):
|
|
|
+ def __init__(self, enc, dec, transpose_in=False):
|
|
|
+ nn.Module.__init__(self)
|
|
|
+ self.jasper_encoder = enc
|
|
|
+ self.jasper_decoder = dec
|
|
|
+ self.transpose_in = transpose_in
|
|
|
+ def forward(self, x):
|
|
|
+ if self.jasper_encoder.use_conv_mask:
|
|
|
+ t_encoded_t, t_encoded_len_t = self.jasper_encoder(x)
|
|
|
+ else:
|
|
|
+ if self.transpose_in:
|
|
|
+ x = x.transpose(1, 2)
|
|
|
+ t_encoded_t = self.jasper_encoder(x)
|
|
|
+
|
|
|
+ out = self.jasper_decoder(encoder_output=t_encoded_t)
|
|
|
+ if self.jasper_encoder.use_conv_mask:
|
|
|
+ return out, t_encoded_len_t
|
|
|
+ else:
|
|
|
+ return out
|
|
|
|
|
|
class JasperEncoderDecoder(nn.Module):
|
|
|
- """Contains jasper encoder and decoder
|
|
|
+ """Contains jasper encoder and decoder
|
|
|
"""
|
|
|
def __init__(self, **kwargs):
|
|
|
- nn.Module.__init__(self)
|
|
|
+ nn.Module.__init__(self)
|
|
|
self.jasper_encoder = JasperEncoder(**kwargs.get("jasper_model_definition"))
|
|
|
self.jasper_decoder = JasperDecoderForCTC(feat_in=kwargs.get("feat_in"),
|
|
|
- num_classes=kwargs.get("num_classes"))
|
|
|
+ num_classes=kwargs.get("num_classes"))
|
|
|
+ self.acoustic_model = JasperAcousticModel(self.jasper_encoder,
|
|
|
+ self.jasper_decoder,
|
|
|
+ kwargs.get("transpose_in", False))
|
|
|
+
|
|
|
def num_weights(self):
|
|
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
def forward(self, x):
|
|
|
- t_processed_signal, p_length_t = x
|
|
|
- t_encoded_t, t_encoded_len_t = self.jasper_encoder((t_processed_signal, p_length_t))
|
|
|
- return self.jasper_decoder(encoder_output=t_encoded_t), t_encoded_len_t
|
|
|
+ return self.acoustic_model.forward(x)
|
|
|
|
|
|
class MaskedConv1d(nn.Conv1d):
|
|
|
- """1D convolution with sequence masking
|
|
|
+ """1D convolution with sequence masking
|
|
|
"""
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
|
|
- padding=0, dilation=1, groups=1, bias=False, use_mask=True):
|
|
|
+ padding=0, dilation=1, groups=1, bias=False, use_conv_mask=True):
|
|
|
super(MaskedConv1d, self).__init__(in_channels, out_channels, kernel_size,
|
|
|
stride=stride,
|
|
|
padding=padding, dilation=dilation,
|
|
|
groups=groups, bias=bias)
|
|
|
- self.use_mask = use_mask
|
|
|
+ self.use_conv_mask = use_conv_mask
|
|
|
|
|
|
def get_seq_len(self, lens):
|
|
|
return ((lens + 2 * self.padding[0] - self.dilation[0] * (
|
|
|
self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
|
|
|
|
|
|
def forward(self, inp):
|
|
|
- x, lens = inp
|
|
|
- if self.use_mask:
|
|
|
+ if self.use_conv_mask:
|
|
|
+ x, lens = inp
|
|
|
max_len = x.size(2)
|
|
|
- mask = torch.arange(max_len).to(lens.dtype).to(lens.device).expand(len(lens),
|
|
|
- max_len) >= lens.unsqueeze(
|
|
|
- 1)
|
|
|
+ idxs = torch.arange(max_len).to(lens.dtype).to(lens.device).expand(len(lens), max_len)
|
|
|
+ mask = idxs >= lens.unsqueeze(1)
|
|
|
x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
|
|
|
del mask
|
|
|
-
|
|
|
+ del idxs
|
|
|
lens = self.get_seq_len(lens)
|
|
|
-
|
|
|
+ else:
|
|
|
+ x = inp
|
|
|
out = super(MaskedConv1d, self).forward(x)
|
|
|
- return out, lens
|
|
|
+
|
|
|
+ if self.use_conv_mask:
|
|
|
+ return out, lens
|
|
|
+ else:
|
|
|
+ return out
|
|
|
|
|
|
class JasperBlock(nn.Module):
|
|
|
"""Jasper Block. See https://arxiv.org/pdf/1904.03288.pdf
|
|
|
"""
|
|
|
def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
|
|
|
dilation=1, padding='same', dropout=0.2, activation=None,
|
|
|
- residual=True, residual_panes=[], conv_mask=False):
|
|
|
+ residual=True, residual_panes=[], use_conv_mask=False):
|
|
|
super(JasperBlock, self).__init__()
|
|
|
|
|
|
if padding != "same":
|
|
|
@@ -300,7 +335,7 @@ class JasperBlock(nn.Module):
|
|
|
|
|
|
|
|
|
padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0])
|
|
|
- self.conv_mask = conv_mask
|
|
|
+ self.use_conv_mask = use_conv_mask
|
|
|
self.conv = nn.ModuleList()
|
|
|
inplanes_loop = inplanes
|
|
|
for _ in range(repeat - 1):
|
|
|
@@ -334,7 +369,7 @@ class JasperBlock(nn.Module):
|
|
|
layers = [
|
|
|
MaskedConv1d(in_channels, out_channels, kernel_size, stride=stride,
|
|
|
dilation=dilation, padding=padding, bias=bias,
|
|
|
- use_mask=self.conv_mask),
|
|
|
+ use_conv_mask=self.use_conv_mask),
|
|
|
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.1)
|
|
|
]
|
|
|
return layers
|
|
|
@@ -352,13 +387,16 @@ class JasperBlock(nn.Module):
|
|
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
|
|
def forward(self, input_):
|
|
|
-
|
|
|
- xs, lens_orig = input_
|
|
|
+ if self.use_conv_mask:
|
|
|
+ xs, lens_orig = input_
|
|
|
+ else:
|
|
|
+ xs = input_
|
|
|
+ lens_orig = 0
|
|
|
# compute forward convolutions
|
|
|
out = xs[-1]
|
|
|
lens = lens_orig
|
|
|
for i, l in enumerate(self.conv):
|
|
|
- if isinstance(l, MaskedConv1d):
|
|
|
+ if self.use_conv_mask and isinstance(l, MaskedConv1d):
|
|
|
out, lens = l((out, lens))
|
|
|
else:
|
|
|
out = l(out)
|
|
|
@@ -367,7 +405,7 @@ class JasperBlock(nn.Module):
|
|
|
for i, layer in enumerate(self.res):
|
|
|
res_out = xs[i]
|
|
|
for j, res_layer in enumerate(layer):
|
|
|
- if j == 0:
|
|
|
+ if j == 0 and self.use_conv_mask:
|
|
|
res_out, _ = res_layer((res_out, lens_orig))
|
|
|
else:
|
|
|
res_out = res_layer(res_out)
|
|
|
@@ -376,9 +414,14 @@ class JasperBlock(nn.Module):
|
|
|
# compute the output
|
|
|
out = self.out(out)
|
|
|
if self.res is not None and self.dense_residual:
|
|
|
- return xs + [out], lens
|
|
|
+ out = xs + [out]
|
|
|
+ else:
|
|
|
+ out = [out]
|
|
|
|
|
|
- return [out], lens
|
|
|
+ if self.use_conv_mask:
|
|
|
+ return out, lens
|
|
|
+ else:
|
|
|
+ return out
|
|
|
|
|
|
class GreedyCTCDecoder(nn.Module):
|
|
|
""" Greedy CTC Decoder
|