# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

from common.utils import mask_from_lens
from common.text.symbols import pad_idx, symbols


class PositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding (Vaswani et al., 2017)."""
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
        if bsz is not None:
            return pos_emb[None, :, :].expand(bsz, -1, -1)
        else:
            return pos_emb[None, :, :]
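
# Illustrative usage (not part of the original file): with demb=4 and a
# length-3 position sequence, the returned embedding has shape [1, 3, 4];
# passing bsz expands it along the batch dimension.
#
#   pe = PositionalEmbedding(4)
#   pe(torch.arange(3.0)).shape          # torch.Size([1, 3, 4])
#   pe(torch.arange(3.0), bsz=8).shape   # torch.Size([8, 3, 4])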


class PositionwiseFF(nn.Module):
    """Position-wise feed-forward block with residual connection and LayerNorm."""
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = self.CoreNet(inp)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output


class PositionwiseConvFF(nn.Module):
    """Position-wise feed-forward block built from 1-D convolutions."""
    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
        super(PositionwiseConvFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
            nn.ReLU(),
            # nn.Dropout(dropout),  # worse convergence
            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        return self._forward(inp)

    def _forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            # normalize over d_model first, then move channels for Conv1d,
            # which expects [bsz x d_model x L]
            core_out = self.layer_norm(inp).transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            # Conv1d expects [bsz x d_model x L], so transpose around CoreNet
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output


class MultiHeadAttn(nn.Module):
    """Multi-head self-attention with optional pre-LayerNorm."""
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.scale = 1 / (d_head ** 0.5)
        self.pre_lnorm = pre_lnorm

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp, attn_mask=None):
        return self._forward(inp, attn_mask)

    def _forward(self, inp, attn_mask=None):
        residual = inp

        if self.pre_lnorm:
            # layer normalization
            inp = self.layer_norm(inp)

        n_head, d_head = self.n_head, self.d_head

        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=-1)
        head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
        head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
        head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)

        # fold heads into the batch dimension, head-major: [n_head * bsz, L, d_head]
        # (must match the head-major mask repeat and attn_vec view below)
        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)
            attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
            attn_score.masked_fill_(attn_mask, -float('inf'))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropatt(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
            inp.size(0), inp.size(1), n_head * d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization
            output = self.layer_norm(residual + attn_out)

        return output

    # disabled; slower
    # NOTE: this variant relies on separate `self.q_net` / `self.kv_net`
    # projections, which are not defined in __init__ above; kept for reference.
    def forward_einsum(self, h, attn_mask=None):
        # multihead attention
        # head tensors below are batch-first: [bsz x hlen x n_head x d_head]
        c = h

        if self.pre_lnorm:
            # layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [bsz x n_head x qlen x klen]
        # attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
        attn_score = torch.einsum('bind,bjnd->bnij', (head_q, head_k))
        attn_score.mul_(self.scale)

        if attn_mask is not None and attn_mask.any().item():
            attn_score.masked_fill_(attn_mask[:, None, None, :], -float('inf'))

        # [bsz x n_head x qlen x klen]
        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.dropatt(attn_prob)

        # [bsz x n_head x qlen x klen] x [bsz x klen x n_head x d_head]
        # -> [bsz x qlen x n_head x d_head]
        attn_vec = torch.einsum('bnij,bjnd->bind', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = h + attn_out
        else:
            # residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output


class TransformerLayer(nn.Module):
    """A single FFT block: multi-head self-attention followed by a convolutional
    position-wise feed-forward module, with padding positions zeroed out."""
    def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
                 **kwargs):
        super(TransformerLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, mask=None):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output


class FFTransformer(nn.Module):
    """Stack of FFT blocks with sinusoidal positional embeddings and an optional
    input token embedding (used when the input is symbol indices)."""
    def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
                 dropout, dropatt, dropemb=0.0, embed_input=True, d_embed=None,
                 pre_lnorm=False):
        super(FFTransformer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        if embed_input:
            self.word_emb = nn.Embedding(len(symbols), d_embed or d_model,
                                         padding_idx=pad_idx)
        else:
            self.word_emb = None

        self.pos_emb = PositionalEmbedding(self.d_model)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(n_layer):
            self.layers.append(
                TransformerLayer(
                    n_head, d_model, d_head, d_inner, kernel_size, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
            )

    def forward(self, dec_inp, seq_lens=None):
        if self.word_emb is None:
            inp = dec_inp
            mask = mask_from_lens(seq_lens).unsqueeze(2)
        else:
            inp = self.word_emb(dec_inp)
            # [bsz x L x 1]
            mask = (dec_inp != pad_idx).unsqueeze(2)

        pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask
        out = self.drop(inp + pos_emb)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask
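

if __name__ == '__main__':
    # Minimal shape-check sketch, not part of the original module. It assumes the
    # repo's common.text.symbols provides `symbols` and `pad_idx` (imported at the
    # top of this file); the hyperparameters below are illustrative only.
    model = FFTransformer(n_layer=2, n_head=2, d_model=64, d_head=32,
                          d_inner=256, kernel_size=3, dropout=0.1, dropatt=0.1)
    model.eval()  # disable dropout for a deterministic sanity check

    # Fake token ids: batch of 4 sequences, 37 symbols each
    # (low=1 assumes pad_idx is 0, so no position is treated as padding).
    tokens = torch.randint(low=1, high=len(symbols), size=(4, 37))
    with torch.no_grad():
        out, mask = model(tokens)
    print(out.shape, mask.shape)  # expected: [4, 37, 64] and [4, 37, 1]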