@@ -21,6 +21,11 @@ from utils.log_uniform_sampler import sample_logits
 from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
 
 
+@torch.jit.script
+def add_and_scale(tensor1, tensor2, alpha: float):
+    return alpha * (tensor1 + tensor2)
+
+
 class PositionalEmbedding(nn.Module):
     def __init__(self, demb):
         super(PositionalEmbedding, self).__init__()
@@ -122,7 +127,7 @@ class MultiHeadAttn(nn.Module):
         # [bsz x n_head x qlen x klen]
         attn_score = torch.einsum('ibnd,jbnd->bnij', (head_q, head_k))
         attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -266,11 +271,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         BD = self._rel_shift(BD)
 
         # [bsz x n_head x qlen x klen]
-        attn_score = AC + BD
-        attn_score.mul_(self.scale)
+        attn_score = add_and_scale(AC, BD, self.scale)
 
         # compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -354,11 +358,10 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         BD = self._rel_shift(B_ + D_)
 
         # [bsz x qlen x klen x n_head]
-        attn_score = AC + BD
-        attn_score.mul_(self.scale)
+        attn_score = add_and_scale(AC, BD, self.scale)
 
         # compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None, None, :, :], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -627,6 +630,12 @@ class MemTransformerLM(nn.Module):
                     self.n_layer, self.max_klen, self.d_model).zero_())
 
     def reset_length(self, tgt_len, ext_len, mem_len):
+        if tgt_len < 1:
+            raise RuntimeError(f'tgt_len should be >= 1, but got {tgt_len}')
+        if ext_len < 0:
+            raise RuntimeError(f'ext_len should be >= 0, but got {ext_len}')
+        if mem_len < 0:
+            raise RuntimeError(f'mem_len should be >= 0, but got {mem_len}')
         self.tgt_len = tgt_len
         self.mem_len = mem_len
         self.ext_len = ext_len
@@ -634,7 +643,7 @@ class MemTransformerLM(nn.Module):
     def init_mems(self):
         if self.mem_len > 0:
             param = next(self.parameters())
-            mems = torch.empty(self.n_layer + 1, 0, dtype=param.dtype,
+            mems = torch.empty(self.n_layer, 0, dtype=param.dtype,
                                device=param.device)
             return mems
         else:
@@ -654,14 +663,21 @@ class MemTransformerLM(nn.Module):
         # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
         # to `mlen + qlen - self.ext_len`.
         with torch.no_grad():
-            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
-            beg_idx = max(0, end_idx - self.mem_len)
             stacked = torch.stack(hids)
-            if mems.numel():
-                cat = torch.cat([mems, stacked], dim=1)
+            if (
+                self.mem_len == self.tgt_len
+                and self.ext_len == 0
+                and stacked.size(1) == self.mem_len
+            ):
+                new_mems = stacked.detach()
             else:
-                cat = stacked
-            new_mems = cat[:, beg_idx:end_idx].detach()
+                end_idx = mlen + max(0, qlen - self.ext_len)
+                beg_idx = max(0, end_idx - self.mem_len)
+                if mems.numel():
+                    cat = torch.cat([mems, stacked], dim=1)
+                else:
+                    cat = stacked
+                new_mems = cat[:, beg_idx:end_idx].detach()
 
         return new_mems
 
@@ -697,18 +713,17 @@ class MemTransformerLM(nn.Module):
             core_out = self.drop(word_emb)
             pos_emb = self.drop(pos_emb)
 
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 mems_i = None if mems is None else mems[i]
                 core_out = layer(core_out, pos_emb, self.r_w_bias,
                                  self.r_r_bias, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out.detach())
         # learnable
         elif self.attn_type == 1:
             core_out = self.drop(word_emb)
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 if self.clamp_len > 0:
                     r_emb = self.r_emb[i][-self.clamp_len:]
                     r_bias = self.r_bias[i][-self.clamp_len:]
@@ -718,7 +733,6 @@ class MemTransformerLM(nn.Module):
                 mems_i = None if mems is None else mems[i]
                 core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
-                hids.append(core_out.detach())
         # absolute
         elif self.attn_type == 2:
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
@@ -729,19 +743,18 @@ class MemTransformerLM(nn.Module):
 
             core_out = self.drop(word_emb + pos_emb[-qlen:])
 
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and len(mems_i) and i == 0:
                     mems_i += pos_emb[:mlen]
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out.detach())
         elif self.attn_type == 3:
             core_out = self.drop(word_emb)
 
-            hids.append(core_out.detach())
             for i, layer in enumerate(self.layers):
+                hids.append(core_out.detach())
                 mems_i = None if mems is None else mems[i]
                 if mems_i is not None and len(mems_i) and mlen > 0:
                     cur_emb = self.r_emb[i][:-qlen]
@@ -756,7 +769,6 @@ class MemTransformerLM(nn.Module):
 
                 core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i)
-                hids.append(core_out.detach())
 
         core_out = self.drop(core_out)
 