pyt_mha.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections import defaultdict
  15. import torch
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from common.fairseq.modules.multihead_attention import RotaryEmbedding
  19. def mha_state_dict_to_fairseq(sd):
  20. """Concatenate q, k, v matrices and load as usual."""
  21. new_sd = {}
  22. qkv = defaultdict(dict)
  23. for key, val in sd.items():
  24. fields = key.split('.')
  25. if len(fields) < 2:
  26. continue
  27. prefix = '.'.join(fields[:-2] + [""])
  28. module, param = fields[-2:]
  29. if module in ['q_proj', 'k_proj', 'v_proj']:
  30. qkv[prefix][module + '.' + param] = val
  31. else:
  32. new_sd[key] = val
  33. for prefix, param_dict in qkv.items():
  34. # Stitch qkv params together
  35. assert len(param_dict) == 6
  36. new_sd[f"{prefix}qkv.weight"] = torch.cat(
  37. [param_dict[f"{k}_proj.weight"] for k in ["q", "k", "v"]], dim=0)
  38. new_sd[f"{prefix}qkv.bias"] = torch.cat(
  39. [param_dict[f"{k}_proj.bias"] for k in ["q", "k", "v"]], dim=0)
  40. return new_sd
  41. class PytMultiheadAttention(nn.Module):
  42. """Drop-in replacement for Fairseq MHA.
  43. Calls torch.nn.functional with combined qkv.
  44. """
  45. def __init__(
  46. self,
  47. embed_dim,
  48. num_heads,
  49. dropout=0.0,
  50. bias=True,
  51. self_attention=True,
  52. rotary_embeddings=False,
  53. ):
  54. super().__init__()
  55. assert self_attention
  56. assert not rotary_embeddings, "Not yet supported"
  57. self.embed_dim = embed_dim
  58. self.num_heads = num_heads
  59. self.rotary_embeddings = rotary_embeddings
  60. if self.rotary_embeddings:
  61. self.rotary_freq = RotaryEmbedding(embed_dim)
  62. self.head_dim = embed_dim // num_heads
  63. assert (
  64. self.head_dim * num_heads == self.embed_dim
  65. ), "embed_dim must be divisible by num_heads"
  66. self.qkv = nn.Linear(embed_dim, 3 * num_heads * self.head_dim,
  67. bias=bias)
  68. self.dropatt = nn.Dropout(dropout)
  69. self.out_proj = nn.Linear(num_heads * self.head_dim, embed_dim,
  70. bias=bias)
  71. self.reset_parameters()
  72. def hook(state_dict, prefix, *args, **kwargs):
  73. this_keys = {k for k in state_dict.keys() if k.startswith(prefix)}
  74. new_sd = {k: v for k, v in state_dict.items() if k in this_keys}
  75. for k in this_keys:
  76. del state_dict[k]
  77. state_dict.update(mha_state_dict_to_fairseq(new_sd))
  78. self._register_load_state_dict_pre_hook(hook)
  79. def forward(self, query, key=None, value=None, key_padding_mask=None,
  80. attn_mask=None):
  81. return F.multi_head_attention_forward(
  82. query,
  83. key,
  84. value,
  85. self.embed_dim,
  86. self.num_heads,
  87. self.qkv.weight,
  88. self.qkv.bias,
  89. None,
  90. None,
  91. False,
  92. self.dropatt.p,
  93. self.out_proj.weight,
  94. self.out_proj.bias,
  95. training=self.training,
  96. key_padding_mask=key_padding_mask,
  97. need_weights=False,
  98. attn_mask=attn_mask,
  99. average_attn_weights=False,
  100. )
  101. def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
  102. """Split q, k, v matrices for bwd compatibility with Fairseq."""
  103. sd = super().state_dict(*args, destination, prefix, keep_vars)
  104. for key in list(sd.keys()):
  105. if not (key.endswith(".qkv.weight") or key.endswith(".qkv.bias")):
  106. continue
  107. *pref, qkv, param = key.split(".")
  108. pref = ".".join(pref)
  109. assert qkv == "qkv"
  110. q, k, v = torch.chunk(sd.pop(key), 3, dim=0)
  111. sd[f"{pref}.q_proj.{param}"] = q
  112. sd[f"{pref}.k_proj.{param}"] = k
  113. sd[f"{pref}.v_proj.{param}"] = v
  114. return sd
  115. def reset_parameters(self):
  116. # Init as in Fairseq with qkv_same_dim=True and separate qkv projs
  117. t = self.qkv.weight.size(0) // 3
  118. nn.init.xavier_uniform_(self.qkv.weight[0*t:1*t], gain=1 / (2 ** 0.5))
  119. nn.init.xavier_uniform_(self.qkv.weight[1*t:2*t], gain=1 / (2 ** 0.5))
  120. nn.init.xavier_uniform_(self.qkv.weight[2*t:3*t], gain=1 / (2 ** 0.5))
  121. nn.init.xavier_uniform_(self.out_proj.weight)
  122. if self.out_proj.bias is not None:
  123. nn.init.constant_(self.out_proj.bias, 0.0)
  124. class Fp32Softmax(nn.Softmax):
  125. def forward(self, x):
  126. return F.softmax(x.float(), dim=self.dim).type_as(x)
  127. class SlowMultiHeadAttention(nn.Module):
  128. """Drop-in replacement for Fairseq MHA."""
  129. def __init__(self,
  130. embed_dim,
  131. num_heads,
  132. dropout=0.0,
  133. bias=True,
  134. self_attention=True,
  135. rotary_embeddings=None,
  136. fp32_softmax=False,
  137. ):
  138. super().__init__()
  139. n_head = num_heads
  140. d_model = embed_dim
  141. d_head = embed_dim // n_head
  142. dropatt = dropout
  143. pre_lnorm = False
  144. assert self_attention
  145. assert rotary_embeddings is None, "Rotary embs not yet supported"
  146. self.embed_dim = embed_dim
  147. self.num_heads = num_heads
  148. self.n_head = n_head
  149. self.d_model = d_model
  150. self.d_head = d_head
  151. self.scale = 1 / (d_head ** 0.5)
  152. self.pre_lnorm = pre_lnorm
  153. self.qkv = nn.Linear(d_model, 3 * n_head * d_head, bias=bias)
  154. self.dropatt = nn.Dropout(dropatt)
  155. self.proj = nn.Linear(n_head * d_head, d_model, bias=bias)
  156. self.layer_norm = nn.LayerNorm(d_model, elementwise_affine=False)
  157. self.softmax = Fp32Softmax(dim=2) if fp32_softmax else nn.Softmax(dim=2)
  158. def state_dict(self):
  159. """Convert QKV to be compatible with Fairseq"""
  160. sd = super().state_dict()
  161. ret = {}
  162. for key, val in sd.items():
  163. fields = key.split('.')
  164. if len(fields) < 2:
  165. continue
  166. prefix = '.'.join(fields[:-2] + [""])
  167. module, param = fields[-2:]
  168. if module == 'qkv':
  169. q, k, v = torch.chunk(val, 3, dim=0)
  170. ret[f"{prefix}q_proj.{param}"] = q
  171. ret[f"{prefix}k_proj.{param}"] = k
  172. ret[f"{prefix}v_proj.{param}"] = v
  173. else:
  174. ret[key] = val
  175. return ret
  176. def load_state_dict(self, sd):
  177. from collections import defaultdict
  178. ret = {}
  179. qkv = defaultdict(dict)
  180. for key, val in sd.items():
  181. fields = key.split('.')
  182. if len(fields) < 2:
  183. continue
  184. prefix = '.'.join(fields[:-2] + [""])
  185. module, param = fields[-2:]
  186. if module in ['q_proj', 'k_proj', 'v_proj']:
  187. qkv[prefix][module + '.' + param] = val
  188. else:
  189. ret[key] = val
  190. for prefix, param_dict in qkv.items():
  191. # Stitch qkv params together
  192. assert len(param_dict) == 6
  193. ret[f"{prefix}qkv.weight"] = torch.cat(
  194. [param_dict[f"{k}_proj.weight"] for k in ["q", "k", "v"]],
  195. dim=0)
  196. ret[f"{prefix}qkv.bias"] = torch.cat(
  197. [param_dict[f"{k}_proj.bias"] for k in ["q", "k", "v"]],
  198. dim=0)
  199. super().load_state_dict(ret)
  200. def forward(self, inp, attn_mask=None):
  201. inp = inp.permute(1, 0, 2) # (T, B, H) -> (B, T, H)
  202. if self.pre_lnorm:
  203. inp = self.layer_norm(inp)
  204. n_head, d_head = self.n_head, self.d_head
  205. head_q, head_k, head_v = torch.chunk(self.qkv(inp), 3, dim=2)
  206. head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
  207. head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
  208. head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
  209. q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
  210. k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
  211. v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
  212. attn_score = torch.bmm(q, k.transpose(1, 2))
  213. attn_score.mul_(self.scale)
  214. if attn_mask is not None:
  215. attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
  216. attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
  217. attn_score.masked_fill_(attn_mask.to(torch.bool), -float('inf'))
  218. attn_prob = self.softmax(attn_score)
  219. attn_prob = self.dropatt(attn_prob)
  220. attn_vec = torch.bmm(attn_prob, v)
  221. attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
  222. attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
  223. inp.size(0), inp.size(1), n_head * d_head)
  224. output = self.proj(attn_vec)
  225. return output.permute(1, 0, 2) # (B, T, H) -> (T, B, H)