modeling.py 59 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282
  1. # coding=utf-8
  2. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BERT model."""
  16. from __future__ import absolute_import, division, print_function, unicode_literals
  17. import copy
  18. import json
  19. import logging
  20. import math
  21. import os
  22. import shutil
  23. import tarfile
  24. import tempfile
  25. import sys
  26. from io import open
  27. import torch
  28. from torch import nn
  29. from torch.nn import CrossEntropyLoss
  30. from torch.utils import checkpoint
  31. sys.path.append('/workspace/bert/')
  32. from file_utils import cached_path
  33. from torch.nn import Module
  34. from torch.nn.parameter import Parameter
  35. import torch.nn.functional as F
  36. import torch.nn.init as init
  37. logger = logging.getLogger(__name__)
  38. PRETRAINED_MODEL_ARCHIVE_MAP = {
  39. 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
  40. 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
  41. 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
  42. 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
  43. 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
  44. 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
  45. 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
  46. }
  47. CONFIG_NAME = 'bert_config.json'
  48. WEIGHTS_NAME = 'pytorch_model.bin'
  49. TF_WEIGHTS_NAME = 'model.ckpt'
  50. def load_tf_weights_in_bert(model, tf_checkpoint_path):
  51. """ Load tf checkpoints in a pytorch model
  52. """
  53. try:
  54. import re
  55. import numpy as np
  56. import tensorflow as tf
  57. except ImportError:
  58. print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
  59. "https://www.tensorflow.org/install/ for installation instructions.")
  60. raise
  61. tf_path = os.path.abspath(tf_checkpoint_path)
  62. print("Converting TensorFlow checkpoint from {}".format(tf_path))
  63. # Load weights from TF model
  64. init_vars = tf.train.list_variables(tf_path)
  65. names = []
  66. arrays = []
  67. for name, shape in init_vars:
  68. print("Loading TF weight {} with shape {}".format(name, shape))
  69. array = tf.train.load_variable(tf_path, name)
  70. names.append(name)
  71. arrays.append(array)
  72. for name, array in zip(names, arrays):
  73. name = name.split('/')
  74. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  75. # which are not required for using pretrained model
  76. if any(n in ["adam_v", "adam_m"] for n in name):
  77. print("Skipping {}".format("/".join(name)))
  78. continue
  79. pointer = model
  80. for m_name in name:
  81. if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
  82. l = re.split(r'_(\d+)', m_name)
  83. else:
  84. l = [m_name]
  85. if l[0] == 'kernel' or l[0] == 'gamma':
  86. pointer = getattr(pointer, 'weight')
  87. elif l[0] == 'output_bias' or l[0] == 'beta':
  88. pointer = getattr(pointer, 'bias')
  89. elif l[0] == 'output_weights':
  90. pointer = getattr(pointer, 'weight')
  91. else:
  92. pointer = getattr(pointer, l[0])
  93. if len(l) >= 2:
  94. num = int(l[1])
  95. pointer = pointer[num]
  96. if m_name[-11:] == '_embeddings':
  97. pointer = getattr(pointer, 'weight')
  98. elif m_name == 'kernel':
  99. array = np.ascontiguousarray(np.transpose(array))
  100. try:
  101. assert pointer.shape == array.shape
  102. except AssertionError as e:
  103. e.args += (pointer.shape, array.shape)
  104. raise
  105. print("Initialize PyTorch weight {}".format(name))
  106. pointer.data = torch.from_numpy(array)
  107. return model
  108. def gelu(x):
  109. return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
  110. #used only for triton inference
  111. def bias_gelu(bias, y):
  112. x = bias + y
  113. return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
  114. # used specifically for training since torch.nn.functional.gelu breaks ONNX export
  115. def bias_gelu_training(bias, y):
  116. x = bias + y
  117. return torch.nn.functional.gelu(x) # Breaks ONNX export
  118. def bias_tanh(bias, y):
  119. x = bias + y
  120. return torch.tanh(x)
  121. def swish(x):
  122. return x * torch.sigmoid(x)
  123. #torch.nn.functional.gelu(x) # Breaks ONNX export
  124. ACT2FN = {"gelu": gelu, "bias_gelu": bias_gelu, "bias_tanh": bias_tanh, "relu": torch.nn.functional.relu, "swish": swish}
  125. class LinearActivation(Module):
  126. r"""Fused Linear and activation Module.
  127. """
  128. __constants__ = ['bias']
  129. def __init__(self, in_features, out_features, act='gelu', bias=True):
  130. super(LinearActivation, self).__init__()
  131. self.in_features = in_features
  132. self.out_features = out_features
  133. self.act_fn = nn.Identity() #
  134. self.biased_act_fn = None #
  135. self.bias = None #
  136. if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)): # For TorchScript
  137. if bias and not 'bias' in act: # compatibility
  138. act = 'bias_' + act #
  139. self.biased_act_fn = ACT2FN[act] #
  140. else:
  141. self.act_fn = ACT2FN[act]
  142. else:
  143. self.act_fn = act
  144. self.weight = Parameter(torch.Tensor(out_features, in_features))
  145. if bias:
  146. self.bias = Parameter(torch.Tensor(out_features))
  147. else:
  148. self.register_parameter('bias', None)
  149. self.reset_parameters()
  150. def reset_parameters(self):
  151. init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  152. if self.bias is not None:
  153. fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
  154. bound = 1 / math.sqrt(fan_in)
  155. init.uniform_(self.bias, -bound, bound)
  156. def forward(self, input):
  157. if not self.bias is None:
  158. return self.biased_act_fn(self.bias, F.linear(input, self.weight, None))
  159. else:
  160. return self.act_fn(F.linear(input, self.weight, self.bias))
  161. def extra_repr(self):
  162. return 'in_features={}, out_features={}, bias={}'.format(
  163. self.in_features, self.out_features, self.bias is not None
  164. )
  165. class BertConfig(object):
  166. """Configuration class to store the configuration of a `BertModel`.
  167. """
  168. def __init__(self,
  169. vocab_size_or_config_json_file,
  170. hidden_size=768,
  171. num_hidden_layers=12,
  172. num_attention_heads=12,
  173. intermediate_size=3072,
  174. hidden_act="gelu",
  175. hidden_dropout_prob=0.1,
  176. attention_probs_dropout_prob=0.1,
  177. max_position_embeddings=512,
  178. type_vocab_size=2,
  179. initializer_range=0.02,
  180. output_all_encoded_layers=False):
  181. """Constructs BertConfig.
  182. Args:
  183. vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
  184. hidden_size: Size of the encoder layers and the pooler layer.
  185. num_hidden_layers: Number of hidden layers in the Transformer encoder.
  186. num_attention_heads: Number of attention heads for each attention layer in
  187. the Transformer encoder.
  188. intermediate_size: The size of the "intermediate" (i.e., feed-forward)
  189. layer in the Transformer encoder.
  190. hidden_act: The non-linear activation function (function or string) in the
  191. encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
  192. hidden_dropout_prob: The dropout probabilitiy for all fully connected
  193. layers in the embeddings, encoder, and pooler.
  194. attention_probs_dropout_prob: The dropout ratio for the attention
  195. probabilities.
  196. max_position_embeddings: The maximum sequence length that this model might
  197. ever be used with. Typically set this to something large just in case
  198. (e.g., 512 or 1024 or 2048).
  199. type_vocab_size: The vocabulary size of the `token_type_ids` passed into
  200. `BertModel`.
  201. initializer_range: The sttdev of the truncated_normal_initializer for
  202. initializing all weight matrices.
  203. """
  204. if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
  205. and isinstance(vocab_size_or_config_json_file, unicode)):
  206. with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
  207. json_config = json.loads(reader.read())
  208. for key, value in json_config.items():
  209. self.__dict__[key] = value
  210. elif isinstance(vocab_size_or_config_json_file, int):
  211. self.vocab_size = vocab_size_or_config_json_file
  212. self.hidden_size = hidden_size
  213. self.num_hidden_layers = num_hidden_layers
  214. self.num_attention_heads = num_attention_heads
  215. self.hidden_act = hidden_act
  216. self.intermediate_size = intermediate_size
  217. self.hidden_dropout_prob = hidden_dropout_prob
  218. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  219. self.max_position_embeddings = max_position_embeddings
  220. self.type_vocab_size = type_vocab_size
  221. self.initializer_range = initializer_range
  222. self.output_all_encoded_layers = output_all_encoded_layers
  223. else:
  224. raise ValueError("First argument must be either a vocabulary size (int)"
  225. "or the path to a pretrained model config file (str)")
  226. @classmethod
  227. def from_dict(cls, json_object):
  228. """Constructs a `BertConfig` from a Python dictionary of parameters."""
  229. config = BertConfig(vocab_size_or_config_json_file=-1)
  230. for key, value in json_object.items():
  231. config.__dict__[key] = value
  232. return config
  233. @classmethod
  234. def from_json_file(cls, json_file):
  235. """Constructs a `BertConfig` from a json file of parameters."""
  236. with open(json_file, "r", encoding='utf-8') as reader:
  237. text = reader.read()
  238. return cls.from_dict(json.loads(text))
  239. def __repr__(self):
  240. return str(self.to_json_string())
  241. def to_dict(self):
  242. """Serializes this instance to a Python dictionary."""
  243. output = copy.deepcopy(self.__dict__)
  244. return output
  245. def to_json_string(self):
  246. """Serializes this instance to a JSON string."""
  247. return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
  248. class BertNonFusedLayerNorm(nn.Module):
  249. def __init__(self, hidden_size, eps=1e-12):
  250. """Construct a layernorm module in the TF style (epsilon inside the square root).
  251. """
  252. super(BertNonFusedLayerNorm, self).__init__()
  253. self.weight = nn.Parameter(torch.ones(hidden_size))
  254. self.bias = nn.Parameter(torch.zeros(hidden_size))
  255. self.variance_epsilon = eps
  256. def forward(self, x):
  257. u = x.mean(-1, keepdim=True)
  258. s = (x - u)
  259. s = s * s
  260. s = s.mean(-1, keepdim=True)
  261. x = (x - u) / torch.sqrt(s + self.variance_epsilon)
  262. return self.weight * x + self.bias
  263. try:
  264. import apex
  265. #apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm')
  266. import apex.normalization
  267. from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
  268. #apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward')
  269. #BertLayerNorm = apex.normalization.FusedLayerNorm
  270. APEX_IS_AVAILABLE = True
  271. except ImportError:
  272. print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
  273. #BertLayerNorm = BertNonFusedLayerNorm
  274. APEX_IS_AVAILABLE = False
  275. class BertLayerNorm(Module):
  276. def __init__(self, hidden_size, eps=1e-12):
  277. super(BertLayerNorm, self).__init__()
  278. self.shape = torch.Size((hidden_size,))
  279. self.eps = eps
  280. self.weight = nn.Parameter(torch.ones(hidden_size))
  281. self.bias = nn.Parameter(torch.zeros(hidden_size))
  282. self.apex_enabled = APEX_IS_AVAILABLE
  283. @torch.jit.unused
  284. def fused_layer_norm(self, x):
  285. return FusedLayerNormAffineFunction.apply(
  286. x, self.weight, self.bias, self.shape, self.eps)
  287. def forward(self, x):
  288. if self.apex_enabled and not torch.jit.is_scripting():
  289. x = self.fused_layer_norm(x)
  290. else:
  291. u = x.mean(-1, keepdim=True)
  292. s = (x - u)
  293. s = s * s
  294. s = s.mean(-1, keepdim=True)
  295. x = (x - u) / torch.sqrt(s + self.eps)
  296. x = self.weight * x + self.bias
  297. return x
  298. class BertEmbeddings(nn.Module):
  299. """Construct the embeddings from word, position and token_type embeddings.
  300. """
  301. def __init__(self, config):
  302. super(BertEmbeddings, self).__init__()
  303. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
  304. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  305. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  306. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  307. # any TensorFlow checkpoint file
  308. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  309. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  310. def forward(self, input_ids, token_type_ids):
  311. seq_length = input_ids.size(1)
  312. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
  313. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  314. words_embeddings = self.word_embeddings(input_ids)
  315. position_embeddings = self.position_embeddings(position_ids)
  316. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  317. embeddings = words_embeddings + position_embeddings + token_type_embeddings
  318. embeddings = self.LayerNorm(embeddings)
  319. embeddings = self.dropout(embeddings)
  320. return embeddings
  321. class BertSelfAttention(nn.Module):
  322. def __init__(self, config):
  323. super(BertSelfAttention, self).__init__()
  324. if config.hidden_size % config.num_attention_heads != 0:
  325. raise ValueError(
  326. "The hidden size (%d) is not a multiple of the number of attention "
  327. "heads (%d)" % (config.hidden_size, config.num_attention_heads))
  328. self.num_attention_heads = config.num_attention_heads
  329. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  330. self.all_head_size = self.num_attention_heads * self.attention_head_size
  331. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  332. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  333. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  334. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  335. def transpose_for_scores(self, x):
  336. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  337. x = torch.reshape(x, new_x_shape)
  338. return x.permute(0, 2, 1, 3)
  339. def transpose_key_for_scores(self, x):
  340. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  341. x = torch.reshape(x, new_x_shape)
  342. return x.permute(0, 2, 3, 1)
  343. def forward(self, hidden_states, attention_mask):
  344. mixed_query_layer = self.query(hidden_states)
  345. mixed_key_layer = self.key(hidden_states)
  346. mixed_value_layer = self.value(hidden_states)
  347. query_layer = self.transpose_for_scores(mixed_query_layer)
  348. key_layer = self.transpose_key_for_scores(mixed_key_layer)
  349. value_layer = self.transpose_for_scores(mixed_value_layer)
  350. # Take the dot product between "query" and "key" to get the raw attention scores.
  351. attention_scores = torch.matmul(query_layer, key_layer)
  352. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  353. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  354. attention_scores = attention_scores + attention_mask
  355. # Normalize the attention scores to probabilities.
  356. attention_probs = F.softmax(attention_scores, dim=-1)
  357. # This is actually dropping out entire tokens to attend to, which might
  358. # seem a bit unusual, but is taken from the original Transformer paper.
  359. attention_probs = self.dropout(attention_probs)
  360. context_layer = torch.matmul(attention_probs, value_layer)
  361. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  362. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  363. context_layer = torch.reshape(context_layer, new_context_layer_shape)
  364. return context_layer
  365. class BertSelfOutput(nn.Module):
  366. def __init__(self, config):
  367. super(BertSelfOutput, self).__init__()
  368. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  369. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  370. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  371. def forward(self, hidden_states, input_tensor):
  372. hidden_states = self.dense(hidden_states)
  373. hidden_states = self.dropout(hidden_states)
  374. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  375. return hidden_states
  376. class BertAttention(nn.Module):
  377. def __init__(self, config):
  378. super(BertAttention, self).__init__()
  379. self.self = BertSelfAttention(config)
  380. self.output = BertSelfOutput(config)
  381. def forward(self, input_tensor, attention_mask):
  382. self_output = self.self(input_tensor, attention_mask)
  383. attention_output = self.output(self_output, input_tensor)
  384. return attention_output
  385. class BertIntermediate(nn.Module):
  386. def __init__(self, config):
  387. super(BertIntermediate, self).__init__()
  388. self.dense_act = LinearActivation(config.hidden_size, config.intermediate_size, act=config.hidden_act)
  389. def forward(self, hidden_states):
  390. hidden_states = self.dense_act(hidden_states)
  391. return hidden_states
  392. class BertOutput(nn.Module):
  393. def __init__(self, config):
  394. super(BertOutput, self).__init__()
  395. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  396. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  397. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  398. def forward(self, hidden_states, input_tensor):
  399. hidden_states = self.dense(hidden_states)
  400. hidden_states = self.dropout(hidden_states)
  401. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  402. return hidden_states
  403. class BertLayer(nn.Module):
  404. def __init__(self, config):
  405. super(BertLayer, self).__init__()
  406. self.attention = BertAttention(config)
  407. self.intermediate = BertIntermediate(config)
  408. self.output = BertOutput(config)
  409. def forward(self, hidden_states, attention_mask):
  410. attention_output = self.attention(hidden_states, attention_mask)
  411. intermediate_output = self.intermediate(attention_output)
  412. layer_output = self.output(intermediate_output, attention_output)
  413. return layer_output
  414. class BertEncoder(nn.Module):
  415. def __init__(self, config):
  416. super(BertEncoder, self).__init__()
  417. self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
  418. self.output_all_encoded_layers = config.output_all_encoded_layers
  419. self._checkpoint_activations = False
  420. @torch.jit.unused
  421. def checkpointed_forward(self, hidden_states, attention_mask):
  422. def custom(start, end):
  423. def custom_forward(*inputs):
  424. layers = self.layer[start:end]
  425. x_ = inputs[0]
  426. for layer in layers:
  427. x_ = layer(x_, inputs[1])
  428. return x_
  429. return custom_forward
  430. l = 0
  431. num_layers = len(self.layer)
  432. chunk_length = math.ceil(math.sqrt(num_layers))
  433. while l < num_layers:
  434. hidden_states = checkpoint.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
  435. l += chunk_length
  436. return hidden_states
  437. def forward(self, hidden_states, attention_mask):
  438. all_encoder_layers = []
  439. if self._checkpoint_activations:
  440. hidden_states = self.checkpointed_forward(hidden_states, attention_mask)
  441. else:
  442. for i,layer_module in enumerate(self.layer):
  443. hidden_states = layer_module(hidden_states, attention_mask)
  444. if self.output_all_encoded_layers:
  445. all_encoder_layers.append(hidden_states)
  446. if not self.output_all_encoded_layers or self._checkpoint_activations:
  447. all_encoder_layers.append(hidden_states)
  448. return all_encoder_layers
  449. class BertPooler(nn.Module):
  450. def __init__(self, config):
  451. super(BertPooler, self).__init__()
  452. self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act="tanh")
  453. def forward(self, hidden_states):
  454. # We "pool" the model by simply taking the hidden state corresponding
  455. # to the first token.
  456. first_token_tensor = hidden_states[:, 0]
  457. pooled_output = self.dense_act(first_token_tensor)
  458. return pooled_output
  459. class BertPredictionHeadTransform(nn.Module):
  460. def __init__(self, config):
  461. super(BertPredictionHeadTransform, self).__init__()
  462. self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act=config.hidden_act)
  463. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  464. def forward(self, hidden_states):
  465. hidden_states = self.dense_act(hidden_states)
  466. hidden_states = self.LayerNorm(hidden_states)
  467. return hidden_states
  468. class BertLMPredictionHead(nn.Module):
  469. def __init__(self, config, bert_model_embedding_weights):
  470. super(BertLMPredictionHead, self).__init__()
  471. self.transform = BertPredictionHeadTransform(config)
  472. # The output weights are the same as the input embeddings, but there is
  473. # an output-only bias for each token.
  474. self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
  475. bert_model_embedding_weights.size(0),
  476. bias=False)
  477. self.decoder.weight = bert_model_embedding_weights
  478. self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
  479. def forward(self, hidden_states):
  480. hidden_states = self.transform(hidden_states)
  481. hidden_states = self.decoder(hidden_states) + self.bias
  482. return hidden_states
  483. class BertOnlyMLMHead(nn.Module):
  484. def __init__(self, config, bert_model_embedding_weights):
  485. super(BertOnlyMLMHead, self).__init__()
  486. self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
  487. def forward(self, sequence_output):
  488. prediction_scores = self.predictions(sequence_output)
  489. return prediction_scores
  490. class BertOnlyNSPHead(nn.Module):
  491. def __init__(self, config):
  492. super(BertOnlyNSPHead, self).__init__()
  493. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  494. def forward(self, pooled_output):
  495. seq_relationship_score = self.seq_relationship(pooled_output)
  496. return seq_relationship_score
  497. class BertPreTrainingHeads(nn.Module):
  498. def __init__(self, config, bert_model_embedding_weights):
  499. super(BertPreTrainingHeads, self).__init__()
  500. self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
  501. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  502. def forward(self, sequence_output, pooled_output):
  503. prediction_scores = self.predictions(sequence_output)
  504. seq_relationship_score = self.seq_relationship(pooled_output)
  505. return prediction_scores, seq_relationship_score
  506. class BertPreTrainedModel(nn.Module):
  507. """ An abstract class to handle weights initialization and
  508. a simple interface for dowloading and loading pretrained models.
  509. """
  510. def __init__(self, config, *inputs, **kwargs):
  511. super(BertPreTrainedModel, self).__init__()
  512. if not isinstance(config, BertConfig):
  513. raise ValueError(
  514. "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
  515. "To create a model from a Google pretrained model use "
  516. "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
  517. self.__class__.__name__, self.__class__.__name__
  518. ))
  519. self.config = config
  520. def init_bert_weights(self, module):
  521. """ Initialize the weights.
  522. """
  523. if isinstance(module, (nn.Linear, nn.Embedding)):
  524. # Slightly different from the TF version which uses truncated_normal for initialization
  525. # cf https://github.com/pytorch/pytorch/pull/5617
  526. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  527. elif isinstance(module, BertLayerNorm):
  528. module.bias.data.zero_()
  529. module.weight.data.fill_(1.0)
  530. if isinstance(module, nn.Linear) and module.bias is not None:
  531. module.bias.data.zero_()
  532. def checkpoint_activations(self, val):
  533. def _apply_flag(module):
  534. if hasattr(module, "_checkpoint_activations"):
  535. module._checkpoint_activations=val
  536. self.apply(_apply_flag)
  537. def enable_apex(self, val):
  538. def _apply_flag(module):
  539. if hasattr(module, "apex_enabled"):
  540. module.apex_enabled=val
  541. self.apply(_apply_flag)
  542. @classmethod
  543. def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
  544. from_tf=False, *inputs, **kwargs):
  545. """
  546. Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
  547. Download and cache the pre-trained model file if needed.
  548. Params:
  549. pretrained_model_name_or_path: either:
  550. - a str with the name of a pre-trained model to load selected in the list of:
  551. . `bert-base-uncased`
  552. . `bert-large-uncased`
  553. . `bert-base-cased`
  554. . `bert-large-cased`
  555. . `bert-base-multilingual-uncased`
  556. . `bert-base-multilingual-cased`
  557. . `bert-base-chinese`
  558. - a path or url to a pretrained model archive containing:
  559. . `bert_config.json` a configuration file for the model
  560. . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
  561. - a path or url to a pretrained model archive containing:
  562. . `bert_config.json` a configuration file for the model
  563. . `model.chkpt` a TensorFlow checkpoint
  564. from_tf: should we load the weights from a locally saved TensorFlow checkpoint
  565. cache_dir: an optional path to a folder in which the pre-trained models will be cached.
  566. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
  567. *inputs, **kwargs: additional input for the specific Bert class
  568. (ex: num_labels for BertForSequenceClassification)
  569. """
  570. if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
  571. archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
  572. else:
  573. archive_file = pretrained_model_name_or_path
  574. # redirect to the cache, if necessary
  575. try:
  576. resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
  577. except EnvironmentError:
  578. logger.error(
  579. "Model name '{}' was not found in model name list ({}). "
  580. "We assumed '{}' was a path or url but couldn't find any file "
  581. "associated to this path or url.".format(
  582. pretrained_model_name_or_path,
  583. ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
  584. archive_file))
  585. return None
  586. if resolved_archive_file == archive_file:
  587. logger.info("loading archive file {}".format(archive_file))
  588. else:
  589. logger.info("loading archive file {} from cache at {}".format(
  590. archive_file, resolved_archive_file))
  591. tempdir = None
  592. if os.path.isdir(resolved_archive_file) or from_tf:
  593. serialization_dir = resolved_archive_file
  594. else:
  595. # Extract archive to temp dir
  596. tempdir = tempfile.mkdtemp()
  597. logger.info("extracting archive file {} to temp dir {}".format(
  598. resolved_archive_file, tempdir))
  599. with tarfile.open(resolved_archive_file, 'r:gz') as archive:
  600. archive.extractall(tempdir)
  601. serialization_dir = tempdir
  602. # Load config
  603. config_file = os.path.join(serialization_dir, CONFIG_NAME)
  604. config = BertConfig.from_json_file(config_file)
  605. logger.info("Model config {}".format(config))
  606. # Instantiate model.
  607. model = cls(config, *inputs, **kwargs)
  608. if state_dict is None and not from_tf:
  609. weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
  610. state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
  611. if tempdir:
  612. # Clean up temp dir
  613. shutil.rmtree(tempdir)
  614. if from_tf:
  615. # Directly load from a TensorFlow checkpoint
  616. weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
  617. return load_tf_weights_in_bert(model, weights_path)
  618. # Load from a PyTorch state_dict
  619. old_keys = []
  620. new_keys = []
  621. for key in state_dict.keys():
  622. new_key = None
  623. if 'gamma' in key:
  624. new_key = key.replace('gamma', 'weight')
  625. if 'beta' in key:
  626. new_key = key.replace('beta', 'bias')
  627. if new_key:
  628. old_keys.append(key)
  629. new_keys.append(new_key)
  630. for old_key, new_key in zip(old_keys, new_keys):
  631. state_dict[new_key] = state_dict.pop(old_key)
  632. missing_keys = []
  633. unexpected_keys = []
  634. error_msgs = []
  635. # copy state_dict so _load_from_state_dict can modify it
  636. metadata = getattr(state_dict, '_metadata', None)
  637. state_dict = state_dict.copy()
  638. if metadata is not None:
  639. state_dict._metadata = metadata
  640. def load(module, prefix=''):
  641. local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
  642. module._load_from_state_dict(
  643. state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
  644. for name, child in module._modules.items():
  645. if child is not None:
  646. load(child, prefix + name + '.')
  647. start_prefix = ''
  648. if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
  649. start_prefix = 'bert.'
  650. load(model, prefix=start_prefix)
  651. if len(missing_keys) > 0:
  652. logger.info("Weights of {} not initialized from pretrained model: {}".format(
  653. model.__class__.__name__, missing_keys))
  654. if len(unexpected_keys) > 0:
  655. logger.info("Weights from pretrained model not used in {}: {}".format(
  656. model.__class__.__name__, unexpected_keys))
  657. if len(error_msgs) > 0:
  658. raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
  659. model.__class__.__name__, "\n\t".join(error_msgs)))
  660. return model
  661. class BertModel(BertPreTrainedModel):
  662. """BERT model ("Bidirectional Embedding Representations from a Transformer").
  663. Params:
  664. config: a BertConfig class instance with the configuration to build a new model
  665. Inputs:
  666. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  667. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  668. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  669. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  670. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  671. a `sentence B` token (see BERT paper for more details).
  672. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  673. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  674. input sequence length in the current batch. It's the mask that we typically use for attention when
  675. a batch has varying length sentences.
  676. Outputs: Tuple of (encoded_layers, pooled_output)
  677. `encoded_layers`: controled by `output_all_encoded_layers` argument:
  678. - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
  679. of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
  680. encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
  681. - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
  682. to the last attention block of shape [batch_size, sequence_length, hidden_size],
  683. `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
  684. classifier pretrained on top of the hidden state associated to the first character of the
  685. input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
  686. Example usage:
  687. ```python
  688. # Already been converted into WordPiece token ids
  689. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  690. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  691. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  692. config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  693. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  694. model = modeling.BertModel(config=config)
  695. all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
  696. ```
  697. """
  698. def __init__(self, config):
  699. super(BertModel, self).__init__(config)
  700. self.embeddings = BertEmbeddings(config)
  701. self.encoder = BertEncoder(config)
  702. self.pooler = BertPooler(config)
  703. self.apply(self.init_bert_weights)
  704. self.output_all_encoded_layers = config.output_all_encoded_layers
  705. def forward(self, input_ids, token_type_ids, attention_mask):
  706. # We create a 3D attention mask from a 2D tensor mask.
  707. # Sizes are [batch_size, 1, 1, to_seq_length]
  708. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  709. # this attention mask is more simple than the triangular masking of causal attention
  710. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  711. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  712. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  713. # masked positions, this operation will create a tensor which is 0.0 for
  714. # positions we want to attend and -10000.0 for masked positions.
  715. # Since we are adding it to the raw scores before the softmax, this is
  716. # effectively the same as removing these entirely.
  717. extended_attention_mask = extended_attention_mask.to(dtype=self.embeddings.word_embeddings.weight.dtype) # fp16 compatibility
  718. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  719. embedding_output = self.embeddings(input_ids, token_type_ids)
  720. encoded_layers = self.encoder(embedding_output, extended_attention_mask)
  721. sequence_output = encoded_layers[-1]
  722. pooled_output = self.pooler(sequence_output)
  723. if not self.output_all_encoded_layers:
  724. encoded_layers = encoded_layers[-1:]
  725. return encoded_layers, pooled_output
  726. class BertForPreTraining(BertPreTrainedModel):
  727. """BERT model with pre-training heads.
  728. This module comprises the BERT model followed by the two pre-training heads:
  729. - the masked language modeling head, and
  730. - the next sentence classification head.
  731. Params:
  732. config: a BertConfig class instance with the configuration to build a new model.
  733. Inputs:
  734. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  735. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  736. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  737. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  738. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  739. a `sentence B` token (see BERT paper for more details).
  740. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  741. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  742. input sequence length in the current batch. It's the mask that we typically use for attention when
  743. a batch has varying length sentences.
  744. `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
  745. with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
  746. is only computed for the labels set in [0, ..., vocab_size]
  747. `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
  748. with indices selected in [0, 1].
  749. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
  750. Outputs:
  751. if `masked_lm_labels` and `next_sentence_label` are not `None`:
  752. Outputs the total_loss which is the sum of the masked language modeling loss and the next
  753. sentence classification loss.
  754. if `masked_lm_labels` or `next_sentence_label` is `None`:
  755. Outputs a tuple comprising
  756. - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
  757. - the next sentence classification logits of shape [batch_size, 2].
  758. Example usage:
  759. ```python
  760. # Already been converted into WordPiece token ids
  761. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  762. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  763. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  764. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  765. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  766. model = BertForPreTraining(config)
  767. masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
  768. ```
  769. """
  770. def __init__(self, config):
  771. super(BertForPreTraining, self).__init__(config)
  772. self.bert = BertModel(config)
  773. self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
  774. self.apply(self.init_bert_weights)
  775. def forward(self, input_ids, token_type_ids, attention_mask):
  776. encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
  777. sequence_output = encoded_layers[-1]
  778. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  779. return prediction_scores, seq_relationship_score
  780. class BertForMaskedLM(BertPreTrainedModel):
  781. """BERT model with the masked language modeling head.
  782. This module comprises the BERT model followed by the masked language modeling head.
  783. Params:
  784. config: a BertConfig class instance with the configuration to build a new model.
  785. Inputs:
  786. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  787. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  788. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  789. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  790. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  791. a `sentence B` token (see BERT paper for more details).
  792. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  793. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  794. input sequence length in the current batch. It's the mask that we typically use for attention when
  795. a batch has varying length sentences.
  796. `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
  797. with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
  798. is only computed for the labels set in [0, ..., vocab_size]
  799. Outputs:
  800. if `masked_lm_labels` is not `None`:
  801. Outputs the masked language modeling loss.
  802. if `masked_lm_labels` is `None`:
  803. Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
  804. Example usage:
  805. ```python
  806. # Already been converted into WordPiece token ids
  807. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  808. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  809. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  810. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  811. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  812. model = BertForMaskedLM(config)
  813. masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
  814. ```
  815. """
  816. def __init__(self, config):
  817. super(BertForMaskedLM, self).__init__(config)
  818. self.bert = BertModel(config)
  819. self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
  820. self.apply(self.init_bert_weights)
  821. def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
  822. encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
  823. sequence_output = encoded_layers[-1]
  824. prediction_scores = self.cls(sequence_output)
  825. if masked_lm_labels is not None:
  826. loss_fct = CrossEntropyLoss(ignore_index=-1)
  827. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
  828. return masked_lm_loss
  829. else:
  830. return prediction_scores
  831. class BertForNextSentencePrediction(BertPreTrainedModel):
  832. """BERT model with next sentence prediction head.
  833. This module comprises the BERT model followed by the next sentence classification head.
  834. Params:
  835. config: a BertConfig class instance with the configuration to build a new model.
  836. Inputs:
  837. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  838. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  839. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  840. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  841. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  842. a `sentence B` token (see BERT paper for more details).
  843. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  844. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  845. input sequence length in the current batch. It's the mask that we typically use for attention when
  846. a batch has varying length sentences.
  847. `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
  848. with indices selected in [0, 1].
  849. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
  850. Outputs:
  851. if `next_sentence_label` is not `None`:
  852. Outputs the total_loss which is the sum of the masked language modeling loss and the next
  853. sentence classification loss.
  854. if `next_sentence_label` is `None`:
  855. Outputs the next sentence classification logits of shape [batch_size, 2].
  856. Example usage:
  857. ```python
  858. # Already been converted into WordPiece token ids
  859. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  860. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  861. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  862. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  863. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  864. model = BertForNextSentencePrediction(config)
  865. seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
  866. ```
  867. """
  868. def __init__(self, config):
  869. super(BertForNextSentencePrediction, self).__init__(config)
  870. self.bert = BertModel(config)
  871. self.cls = BertOnlyNSPHead(config)
  872. self.apply(self.init_bert_weights)
  873. def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
  874. _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
  875. seq_relationship_score = self.cls( pooled_output)
  876. if next_sentence_label is not None:
  877. loss_fct = CrossEntropyLoss(ignore_index=-1)
  878. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  879. return next_sentence_loss
  880. else:
  881. return seq_relationship_score
  882. class BertForSequenceClassification(BertPreTrainedModel):
  883. """BERT model for classification.
  884. This module is composed of the BERT model with a linear layer on top of
  885. the pooled output.
  886. Params:
  887. `config`: a BertConfig class instance with the configuration to build a new model.
  888. `num_labels`: the number of classes for the classifier. Default = 2.
  889. Inputs:
  890. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  891. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  892. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  893. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  894. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  895. a `sentence B` token (see BERT paper for more details).
  896. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  897. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  898. input sequence length in the current batch. It's the mask that we typically use for attention when
  899. a batch has varying length sentences.
  900. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
  901. with indices selected in [0, ..., num_labels].
  902. Outputs:
  903. if `labels` is not `None`:
  904. Outputs the CrossEntropy classification loss of the output with the labels.
  905. if `labels` is `None`:
  906. Outputs the classification logits of shape [batch_size, num_labels].
  907. Example usage:
  908. ```python
  909. # Already been converted into WordPiece token ids
  910. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  911. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  912. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  913. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  914. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  915. num_labels = 2
  916. model = BertForSequenceClassification(config, num_labels)
  917. logits = model(input_ids, token_type_ids, input_mask)
  918. ```
  919. """
  920. def __init__(self, config, num_labels):
  921. super(BertForSequenceClassification, self).__init__(config)
  922. self.num_labels = num_labels
  923. self.bert = BertModel(config)
  924. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  925. self.classifier = nn.Linear(config.hidden_size, num_labels)
  926. self.apply(self.init_bert_weights)
  927. def forward(self, input_ids, token_type_ids=None, attention_mask=None):
  928. _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
  929. pooled_output = self.dropout(pooled_output)
  930. return self.classifier(pooled_output)
  931. class BertForMultipleChoice(BertPreTrainedModel):
  932. """BERT model for multiple choice tasks.
  933. This module is composed of the BERT model with a linear layer on top of
  934. the pooled output.
  935. Params:
  936. `config`: a BertConfig class instance with the configuration to build a new model.
  937. `num_choices`: the number of classes for the classifier. Default = 2.
  938. Inputs:
  939. `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
  940. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  941. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  942. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
  943. with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
  944. and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
  945. `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
  946. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  947. input sequence length in the current batch. It's the mask that we typically use for attention when
  948. a batch has varying length sentences.
  949. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
  950. with indices selected in [0, ..., num_choices].
  951. Outputs:
  952. if `labels` is not `None`:
  953. Outputs the CrossEntropy classification loss of the output with the labels.
  954. if `labels` is `None`:
  955. Outputs the classification logits of shape [batch_size, num_labels].
  956. Example usage:
  957. ```python
  958. # Already been converted into WordPiece token ids
  959. input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
  960. input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
  961. token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
  962. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  963. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  964. num_choices = 2
  965. model = BertForMultipleChoice(config, num_choices)
  966. logits = model(input_ids, token_type_ids, input_mask)
  967. ```
  968. """
  969. def __init__(self, config, num_choices):
  970. super(BertForMultipleChoice, self).__init__(config)
  971. self.num_choices = num_choices
  972. self.bert = BertModel(config)
  973. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  974. self.classifier = nn.Linear(config.hidden_size, 1)
  975. self.apply(self.init_bert_weights)
  976. def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
  977. flat_input_ids = input_ids.view(-1, input_ids.size(-1))
  978. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
  979. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
  980. _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask)
  981. pooled_output = self.dropout(pooled_output)
  982. logits = self.classifier(pooled_output)
  983. reshaped_logits = logits.view(-1, self.num_choices)
  984. if labels is not None:
  985. loss_fct = CrossEntropyLoss()
  986. loss = loss_fct(reshaped_logits, labels)
  987. return loss
  988. else:
  989. return reshaped_logits
  990. class BertForTokenClassification(BertPreTrainedModel):
  991. """BERT model for token-level classification.
  992. This module is composed of the BERT model with a linear layer on top of
  993. the full hidden state of the last layer.
  994. Params:
  995. `config`: a BertConfig class instance with the configuration to build a new model.
  996. `num_labels`: the number of classes for the classifier. Default = 2.
  997. Inputs:
  998. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  999. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  1000. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  1001. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  1002. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  1003. a `sentence B` token (see BERT paper for more details).
  1004. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  1005. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  1006. input sequence length in the current batch. It's the mask that we typically use for attention when
  1007. a batch has varying length sentences.
  1008. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
  1009. with indices selected in [0, ..., num_labels].
  1010. Outputs:
  1011. if `labels` is not `None`:
  1012. Outputs the CrossEntropy classification loss of the output with the labels.
  1013. if `labels` is `None`:
  1014. Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
  1015. Example usage:
  1016. ```python
  1017. # Already been converted into WordPiece token ids
  1018. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  1019. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  1020. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  1021. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  1022. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  1023. num_labels = 2
  1024. model = BertForTokenClassification(config, num_labels)
  1025. logits = model(input_ids, token_type_ids, input_mask)
  1026. ```
  1027. """
  1028. def __init__(self, config, num_labels):
  1029. super(BertForTokenClassification, self).__init__(config)
  1030. self.num_labels = num_labels
  1031. self.bert = BertModel(config)
  1032. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1033. self.classifier = nn.Linear(config.hidden_size, num_labels)
  1034. self.apply(self.init_bert_weights)
  1035. def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
  1036. encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
  1037. sequence_output = encoded_layers[-1]
  1038. sequence_output = self.dropout(sequence_output)
  1039. logits = self.classifier(sequence_output)
  1040. if labels is not None:
  1041. loss_fct = CrossEntropyLoss()
  1042. # Only keep active parts of the loss
  1043. if attention_mask is not None:
  1044. active_loss = attention_mask.view(-1) == 1
  1045. active_logits = logits.view(-1, self.num_labels)[active_loss]
  1046. active_labels = labels.view(-1)[active_loss]
  1047. loss = loss_fct(active_logits, active_labels)
  1048. else:
  1049. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1050. return loss
  1051. else:
  1052. return logits
  1053. class BertForQuestionAnswering(BertPreTrainedModel):
  1054. """BERT model for Question Answering (span extraction).
  1055. This module is composed of the BERT model with a linear layer on top of
  1056. the sequence output that computes start_logits and end_logits
  1057. Params:
  1058. `config`: a BertConfig class instance with the configuration to build a new model.
  1059. Inputs:
  1060. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  1061. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  1062. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  1063. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  1064. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  1065. a `sentence B` token (see BERT paper for more details).
  1066. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  1067. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  1068. input sequence length in the current batch. It's the mask that we typically use for attention when
  1069. a batch has varying length sentences.
  1070. Outputs:
  1071. Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
  1072. position tokens of shape [batch_size, sequence_length].
  1073. Example usage:
  1074. ```python
  1075. # Already been converted into WordPiece token ids
  1076. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  1077. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  1078. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  1079. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  1080. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  1081. model = BertForQuestionAnswering(config)
  1082. start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
  1083. ```
  1084. """
  1085. def __init__(self, config):
  1086. super(BertForQuestionAnswering, self).__init__(config)
  1087. self.bert = BertModel(config)
  1088. # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
  1089. # self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1090. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  1091. self.apply(self.init_bert_weights)
  1092. def forward(self, input_ids, token_type_ids, attention_mask):
  1093. encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
  1094. sequence_output = encoded_layers[-1]
  1095. logits = self.qa_outputs(sequence_output)
  1096. start_logits, end_logits = logits.split(1, dim=-1)
  1097. start_logits = start_logits.squeeze(-1)
  1098. end_logits = end_logits.squeeze(-1)
  1099. return start_logits, end_logits