# model.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from common import filter_warnings
  18. activations = {
  19. "hardtanh": nn.Hardtanh,
  20. "relu": nn.ReLU,
  21. "selu": nn.SELU,
  22. }
  23. def init_weights(m, mode='xavier_uniform'):
  24. if type(m) == nn.Conv1d or type(m) == MaskedConv1d:
  25. if mode == 'xavier_uniform':
  26. nn.init.xavier_uniform_(m.weight, gain=1.0)
  27. elif mode == 'xavier_normal':
  28. nn.init.xavier_normal_(m.weight, gain=1.0)
  29. elif mode == 'kaiming_uniform':
  30. nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
  31. elif mode == 'kaiming_normal':
  32. nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
  33. else:
  34. raise ValueError("Unknown Initialization mode: {0}".format(mode))
  35. elif type(m) == nn.BatchNorm1d:
  36. if m.track_running_stats:
  37. m.running_mean.zero_()
  38. m.running_var.fill_(1)
  39. m.num_batches_tracked.zero_()
  40. if m.affine:
  41. nn.init.ones_(m.weight)
  42. nn.init.zeros_(m.bias)
  43. def compute_new_kernel_size(kernel_size, kernel_width):
  44. new_kernel_size = max(int(kernel_size * kernel_width), 1)
  45. # If kernel is even shape, round up to make it odd
  46. if new_kernel_size % 2 == 0:
  47. new_kernel_size += 1
  48. return new_kernel_size
  49. def get_same_padding(kernel_size, stride, dilation):
  50. if stride > 1 and dilation > 1:
  51. raise ValueError("Only stride OR dilation may be greater than 1")
  52. return (kernel_size // 2) * dilation
  53. class GroupShuffle(nn.Module):
  54. def __init__(self, groups, channels):
  55. super(GroupShuffle, self).__init__()
  56. self.groups = groups
  57. self.channels_per_group = channels // groups
  58. def forward(self, x):
  59. sh = x.shape
  60. x = x.view(-1, self.groups, self.channels_per_group, sh[-1])
  61. x = torch.transpose(x, 1, 2).contiguous()
  62. x = x.view(-1, self.groups * self.channels_per_group, sh[-1])
  63. return x
  64. class MaskedConv1d(nn.Conv1d):
  65. """1D convolution with sequence masking
  66. """
  67. __constants__ = ["masked"]
  68. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  69. padding=0, dilation=1, groups=1, bias=False, use_mask=True,
  70. heads=-1):
  71. # Jasper refactor compat
  72. assert heads == -1 # Unsupported
  73. masked = use_mask
  74. super(MaskedConv1d, self).__init__(
  75. in_channels, out_channels, kernel_size, stride=stride,
  76. padding=padding, dilation=dilation, groups=groups, bias=bias)
  77. self.masked = masked
  78. def get_seq_len(self, lens):
  79. pad, ks = self.padding[0], self.kernel_size[0]
  80. return torch.div(lens + 2 * pad - self.dilation[0] * (ks - 1) - 1,
  81. self.stride[0], rounding_mode='trunc') + 1
  82. def forward(self, x, x_lens=None):
  83. if self.masked:
  84. max_len = x.size(2)
  85. idxs = torch.arange(max_len, dtype=x_lens.dtype, device=x.device)
  86. mask = idxs.expand(x_lens.size(0), max_len) >= x_lens.unsqueeze(1)
  87. x = x.masked_fill(mask.unsqueeze(1).to(device=x.device), 0)
  88. x_lens = self.get_seq_len(x_lens)
  89. return super(MaskedConv1d, self).forward(x), x_lens
  90. class JasperBlock(nn.Module):
  91. __constants__ = ["conv_mask", "separable", "res", "mconv"]
  92. def __init__(self, infilters, filters, repeat=3, kernel_size=11,
  93. kernel_size_factor=1, stride=1, dilation=1, padding='same',
  94. dropout=0.2, activation=None, residual=True, groups=1,
  95. separable=False, heads=-1, normalization="batch",
  96. norm_groups=1, residual_panes=[], use_conv_masks=False):
  97. super(JasperBlock, self).__init__()
  98. # Fix params being passed as list, but default to ints
  99. wrap = lambda v: [v] if type(v) is int else v
  100. kernel_size = wrap(kernel_size)
  101. dilation = wrap(dilation)
  102. padding = wrap(padding)
  103. stride = wrap(stride)
  104. if padding != "same":
  105. raise ValueError("currently only 'same' padding is supported")
  106. kernel_size_factor = float(kernel_size_factor)
  107. if type(kernel_size) in (list, tuple):
  108. kernel_size = [compute_new_kernel_size(k, kernel_size_factor)
  109. for k in kernel_size]
  110. else:
  111. kernel_size = compute_new_kernel_size(kernel_size,
  112. kernel_size_factor)
  113. padding_val = get_same_padding(kernel_size[0], stride[0], dilation[0])
  114. self.conv_mask = use_conv_masks
  115. self.separable = separable
  116. infilters_loop = infilters
  117. conv = nn.ModuleList()
  118. for _ in range(repeat - 1):
  119. conv.extend(
  120. self._get_conv_bn_layer(
  121. infilters_loop, filters, kernel_size=kernel_size,
  122. stride=stride, dilation=dilation, padding=padding_val,
  123. groups=groups, heads=heads, separable=separable,
  124. normalization=normalization, norm_groups=norm_groups)
  125. )
  126. conv.extend(self._get_act_dropout_layer(drop_prob=dropout,
  127. activation=activation))
  128. infilters_loop = filters
  129. conv.extend(
  130. self._get_conv_bn_layer(
  131. infilters_loop, filters, kernel_size=kernel_size, stride=stride,
  132. dilation=dilation, padding=padding_val, groups=groups,
  133. heads=heads, separable=separable, normalization=normalization,
  134. norm_groups=norm_groups)
  135. )
  136. self.mconv = conv
  137. res_panes = residual_panes.copy()
  138. self.dense_residual = residual
  139. if residual:
  140. res_list = nn.ModuleList()
  141. if len(residual_panes) == 0:
  142. res_panes = [infilters]
  143. self.dense_residual = False
  144. for ip in res_panes:
  145. res_list.append(nn.ModuleList(
  146. self._get_conv_bn_layer(ip, filters, kernel_size=1,
  147. normalization=normalization,
  148. norm_groups=norm_groups, stride=[1])
  149. ))
  150. self.res = res_list
  151. else:
  152. self.res = None
  153. self.mout = nn.Sequential(*self._get_act_dropout_layer(
  154. drop_prob=dropout, activation=activation))
  155. def _get_conv(self, in_channels, out_channels, kernel_size=11, stride=1,
  156. dilation=1, padding=0, bias=False, groups=1, heads=-1,
  157. separable=False):
  158. kw = {'in_channels': in_channels, 'out_channels': out_channels,
  159. 'kernel_size': kernel_size, 'stride': stride, 'dilation': dilation,
  160. 'padding': padding, 'bias': bias, 'groups': groups}
  161. if self.conv_mask:
  162. return MaskedConv1d(**kw, heads=heads, use_mask=self.conv_mask)
  163. else:
  164. return nn.Conv1d(**kw)
  165. def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size=11,
  166. stride=1, dilation=1, padding=0, bias=False,
  167. groups=1, heads=-1, separable=False,
  168. normalization="batch", norm_groups=1):
  169. if norm_groups == -1:
  170. norm_groups = out_channels
  171. if separable:
  172. layers = [
  173. self._get_conv(in_channels, in_channels, kernel_size,
  174. stride=stride, dilation=dilation, padding=padding,
  175. bias=bias, groups=in_channels, heads=heads),
  176. self._get_conv(in_channels, out_channels, kernel_size=1,
  177. stride=1, dilation=1, padding=0, bias=bias,
  178. groups=groups),
  179. ]
  180. else:
  181. layers = [
  182. self._get_conv(in_channels, out_channels, kernel_size,
  183. stride=stride, dilation=dilation,
  184. padding=padding, bias=bias, groups=groups)
  185. ]
  186. if normalization == "group":
  187. layers.append(nn.GroupNorm(num_groups=norm_groups,
  188. num_channels=out_channels))
  189. elif normalization == "instance":
  190. layers.append(nn.GroupNorm(num_groups=out_channels,
  191. num_channels=out_channels))
  192. elif normalization == "layer":
  193. layers.append(nn.GroupNorm(num_groups=1, num_channels=out_channels))
  194. elif normalization == "batch":
  195. layers.append(nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.1))
  196. else:
  197. raise ValueError(
  198. f"Normalization method ({normalization}) does not match"
  199. f" one of [batch, layer, group, instance]."
  200. )
  201. if groups > 1:
  202. layers.append(GroupShuffle(groups, out_channels))
  203. return layers
  204. def _get_act_dropout_layer(self, drop_prob=0.2, activation=None):
  205. if activation is None:
  206. activation = nn.Hardtanh(min_val=0.0, max_val=20.0)
  207. layers = [activation, nn.Dropout(p=drop_prob)]
  208. return layers
  209. def forward(self, xs, xs_lens=None):
  210. if not self.conv_mask:
  211. xs_lens = 0
  212. # compute forward convolutions
  213. out = xs[-1]
  214. lens = xs_lens
  215. for i, l in enumerate(self.mconv):
  216. # if we're doing masked convolutions, we need to pass in and
  217. # possibly update the sequence lengths
  218. # if (i % 4) == 0 and self.conv_mask:
  219. if isinstance(l, MaskedConv1d):
  220. out, lens = l(out, lens)
  221. else:
  222. out = l(out)
  223. # compute the residuals
  224. if self.res is not None:
  225. for i, layer in enumerate(self.res):
  226. res_out = xs[i]
  227. for j, res_layer in enumerate(layer):
  228. if isinstance(res_layer, MaskedConv1d):
  229. res_out, _ = res_layer(res_out, xs_lens)
  230. else:
  231. res_out = res_layer(res_out)
  232. out = out + res_out
  233. # compute the output
  234. out = self.mout(out)
  235. if self.res is not None and self.dense_residual:
  236. out = xs + [out]
  237. else:
  238. out = [out]
  239. return (out, lens) if self.conv_mask else (out, None)
  240. class JasperEncoder(nn.Module):
  241. __constants__ = ["use_conv_masks"]
  242. def __init__(self, in_feats, activation, frame_splicing=1,
  243. init='xavier_uniform', use_conv_masks=False, blocks=[]):
  244. super(JasperEncoder, self).__init__()
  245. self.use_conv_masks = use_conv_masks
  246. self.layers = nn.ModuleList()
  247. in_feats *= frame_splicing
  248. all_residual_panes = []
  249. for i, blk in enumerate(blocks):
  250. blk['activation'] = activations[activation]()
  251. has_residual_dense = blk.pop('residual_dense', False)
  252. if has_residual_dense:
  253. all_residual_panes += [in_feats]
  254. blk['residual_panes'] = all_residual_panes
  255. else:
  256. blk['residual_panes'] = []
  257. self.layers.append(
  258. JasperBlock(in_feats, use_conv_masks=use_conv_masks, **blk))
  259. in_feats = blk['filters']
  260. self.apply(lambda x: init_weights(x, mode=init))
  261. def forward(self, x, x_lens=None):
  262. out, out_lens = [x], x_lens
  263. for layer in self.layers:
  264. out, out_lens = layer(out, out_lens)
  265. return out, out_lens
  266. class JasperDecoderForCTC(nn.Module):
  267. def __init__(self, in_feats, n_classes, init='xavier_uniform'):
  268. super(JasperDecoderForCTC, self).__init__()
  269. self.layers = nn.Sequential(
  270. nn.Conv1d(in_feats, n_classes, kernel_size=1, bias=True),)
  271. self.apply(lambda x: init_weights(x, mode=init))
  272. def forward(self, enc_out):
  273. out = self.layers(enc_out[-1]).transpose(1, 2)
  274. return F.log_softmax(out, dim=2)
  275. class GreedyCTCDecoder(nn.Module):
  276. @torch.no_grad()
  277. def forward(self, log_probs):
  278. return log_probs.argmax(dim=-1, keepdim=False).int()
  279. class QuartzNet(nn.Module):
  280. def __init__(self, encoder_kw, decoder_kw, transpose_in=False):
  281. super(QuartzNet, self).__init__()
  282. self.transpose_in = transpose_in
  283. self.encoder = JasperEncoder(**encoder_kw)
  284. self.decoder = JasperDecoderForCTC(**decoder_kw)
  285. def forward(self, x, x_lens=None):
  286. if self.encoder.use_conv_masks:
  287. assert x_lens is not None
  288. enc, enc_lens = self.encoder(x, x_lens)
  289. out = self.decoder(enc)
  290. return out, enc_lens
  291. else:
  292. if self.transpose_in:
  293. x = x.transpose(1, 2)
  294. enc, _ = self.encoder(x)
  295. out = self.decoder(enc)
  296. return out # XXX torchscript refuses to output None
  297. # TODO Explicitly add x_lens=None for inference (now x can be a Tensor or tuple)
  298. def infer(self, x):
  299. if self.encoder.use_conv_masks:
  300. return self.forward(x)
  301. else:
  302. ret = self.forward(x[0])
  303. return ret, len(ret)
  304. class CTCLossNM:
  305. def __init__(self, n_classes):
  306. self._criterion = nn.CTCLoss(blank=n_classes-1, reduction='none')
  307. def __call__(self, log_probs, targets, input_length, target_length):
  308. input_length = input_length.long()
  309. target_length = target_length.long()
  310. targets = targets.long()
  311. loss = self._criterion(log_probs.transpose(1, 0), targets,
  312. input_length, target_length)
  313. # note that this is different from reduction = 'mean'
  314. # because we are not dividing by target lengths
  315. return torch.mean(loss)