modeling.py 59 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278
  1. # coding=utf-8
  2. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BERT model."""
  16. from __future__ import absolute_import, division, print_function, unicode_literals
  17. import copy
  18. import json
  19. import logging
  20. import math
  21. import os
  22. import shutil
  23. import tarfile
  24. import tempfile
  25. import sys
  26. from io import open
  27. import torch
  28. from torch import nn
  29. from torch.nn import CrossEntropyLoss
  30. from torch.utils import checkpoint
  31. sys.path.append('/workspace/bert/')
  32. from file_utils import cached_path
  33. from torch.nn import Module
  34. from torch.nn.parameter import Parameter
  35. import torch.nn.functional as F
  36. import torch.nn.init as init
  37. logger = logging.getLogger(__name__)
  38. PRETRAINED_MODEL_ARCHIVE_MAP = {
  39. 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
  40. 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
  41. 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
  42. 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
  43. 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
  44. 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
  45. 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
  46. }
  47. CONFIG_NAME = 'bert_config.json'
  48. WEIGHTS_NAME = 'pytorch_model.bin'
  49. TF_WEIGHTS_NAME = 'model.ckpt'
  50. def load_tf_weights_in_bert(model, tf_checkpoint_path):
  51. """ Load tf checkpoints in a pytorch model
  52. """
  53. try:
  54. import re
  55. import numpy as np
  56. import tensorflow as tf
  57. except ImportError:
  58. print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
  59. "https://www.tensorflow.org/install/ for installation instructions.")
  60. raise
  61. tf_path = os.path.abspath(tf_checkpoint_path)
  62. print("Converting TensorFlow checkpoint from {}".format(tf_path))
  63. # Load weights from TF model
  64. init_vars = tf.train.list_variables(tf_path)
  65. names = []
  66. arrays = []
  67. for name, shape in init_vars:
  68. print("Loading TF weight {} with shape {}".format(name, shape))
  69. array = tf.train.load_variable(tf_path, name)
  70. names.append(name)
  71. arrays.append(array)
  72. for name, array in zip(names, arrays):
  73. name = name.split('/')
  74. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  75. # which are not required for using pretrained model
  76. if any(n in ["adam_v", "adam_m"] for n in name):
  77. print("Skipping {}".format("/".join(name)))
  78. continue
  79. pointer = model
  80. for m_name in name:
  81. if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
  82. l = re.split(r'_(\d+)', m_name)
  83. else:
  84. l = [m_name]
  85. if l[0] == 'kernel' or l[0] == 'gamma':
  86. pointer = getattr(pointer, 'weight')
  87. elif l[0] == 'output_bias' or l[0] == 'beta':
  88. pointer = getattr(pointer, 'bias')
  89. elif l[0] == 'output_weights':
  90. pointer = getattr(pointer, 'weight')
  91. else:
  92. pointer = getattr(pointer, l[0])
  93. if len(l) >= 2:
  94. num = int(l[1])
  95. pointer = pointer[num]
  96. if m_name[-11:] == '_embeddings':
  97. pointer = getattr(pointer, 'weight')
  98. elif m_name == 'kernel':
  99. array = np.ascontiguousarray(np.transpose(array))
  100. try:
  101. assert pointer.shape == array.shape
  102. except AssertionError as e:
  103. e.args += (pointer.shape, array.shape)
  104. raise
  105. print("Initialize PyTorch weight {}".format(name))
  106. pointer.data = torch.from_numpy(array)
  107. return model
  108. def gelu(x):
  109. return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
  110. #used only for triton inference
  111. def bias_gelu(bias, y):
  112. x = bias + y
  113. return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
  114. # used specifically for training since torch.nn.functional.gelu breaks ONNX export
  115. def bias_gelu_training(bias, y):
  116. x = bias + y
  117. return torch.nn.functional.gelu(x) # Breaks ONNX export
  118. def bias_tanh(bias, y):
  119. x = bias + y
  120. return torch.tanh(x)
  121. def swish(x):
  122. return x * torch.sigmoid(x)
  123. #torch.nn.functional.gelu(x) # Breaks ONNX export
  124. ACT2FN = {"gelu": gelu, "bias_gelu": bias_gelu, "bias_tanh": bias_tanh, "relu": torch.nn.functional.relu, "swish": swish}
  125. class LinearActivation(Module):
  126. r"""Fused Linear and activation Module.
  127. """
  128. __constants__ = ['bias']
  129. def __init__(self, in_features, out_features, act='gelu', bias=True):
  130. super(LinearActivation, self).__init__()
  131. self.in_features = in_features
  132. self.out_features = out_features
  133. self.act_fn = nn.Identity() #
  134. self.biased_act_fn = None #
  135. self.bias = None #
  136. if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)): # For TorchScript
  137. if bias and not 'bias' in act: # compatibility
  138. act = 'bias_' + act #
  139. self.biased_act_fn = ACT2FN[act] #
  140. else:
  141. self.act_fn = ACT2FN[act]
  142. else:
  143. self.act_fn = act
  144. self.weight = Parameter(torch.Tensor(out_features, in_features))
  145. if bias:
  146. self.bias = Parameter(torch.Tensor(out_features))
  147. else:
  148. self.register_parameter('bias', None)
  149. self.reset_parameters()
  150. def reset_parameters(self):
  151. init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  152. if self.bias is not None:
  153. fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
  154. bound = 1 / math.sqrt(fan_in)
  155. init.uniform_(self.bias, -bound, bound)
  156. def forward(self, input):
  157. if not self.bias is None:
  158. return self.biased_act_fn(self.bias, F.linear(input, self.weight, None))
  159. else:
  160. return self.act_fn(F.linear(input, self.weight, self.bias))
  161. def extra_repr(self):
  162. return 'in_features={}, out_features={}, bias={}'.format(
  163. self.in_features, self.out_features, self.bias is not None
  164. )
  165. class BertConfig(object):
  166. """Configuration class to store the configuration of a `BertModel`.
  167. """
  168. def __init__(self,
  169. vocab_size_or_config_json_file,
  170. hidden_size=768,
  171. num_hidden_layers=12,
  172. num_attention_heads=12,
  173. intermediate_size=3072,
  174. hidden_act="gelu",
  175. hidden_dropout_prob=0.1,
  176. attention_probs_dropout_prob=0.1,
  177. max_position_embeddings=512,
  178. type_vocab_size=2,
  179. initializer_range=0.02,
  180. output_all_encoded_layers=False):
  181. """Constructs BertConfig.
  182. Args:
  183. vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
  184. hidden_size: Size of the encoder layers and the pooler layer.
  185. num_hidden_layers: Number of hidden layers in the Transformer encoder.
  186. num_attention_heads: Number of attention heads for each attention layer in
  187. the Transformer encoder.
  188. intermediate_size: The size of the "intermediate" (i.e., feed-forward)
  189. layer in the Transformer encoder.
  190. hidden_act: The non-linear activation function (function or string) in the
  191. encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
  192. hidden_dropout_prob: The dropout probabilitiy for all fully connected
  193. layers in the embeddings, encoder, and pooler.
  194. attention_probs_dropout_prob: The dropout ratio for the attention
  195. probabilities.
  196. max_position_embeddings: The maximum sequence length that this model might
  197. ever be used with. Typically set this to something large just in case
  198. (e.g., 512 or 1024 or 2048).
  199. type_vocab_size: The vocabulary size of the `token_type_ids` passed into
  200. `BertModel`.
  201. initializer_range: The sttdev of the truncated_normal_initializer for
  202. initializing all weight matrices.
  203. """
  204. if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
  205. and isinstance(vocab_size_or_config_json_file, unicode)):
  206. with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
  207. json_config = json.loads(reader.read())
  208. for key, value in json_config.items():
  209. self.__dict__[key] = value
  210. elif isinstance(vocab_size_or_config_json_file, int):
  211. self.vocab_size = vocab_size_or_config_json_file
  212. self.hidden_size = hidden_size
  213. self.num_hidden_layers = num_hidden_layers
  214. self.num_attention_heads = num_attention_heads
  215. self.hidden_act = hidden_act
  216. self.intermediate_size = intermediate_size
  217. self.hidden_dropout_prob = hidden_dropout_prob
  218. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  219. self.max_position_embeddings = max_position_embeddings
  220. self.type_vocab_size = type_vocab_size
  221. self.initializer_range = initializer_range
  222. self.output_all_encoded_layers = output_all_encoded_layers
  223. else:
  224. raise ValueError("First argument must be either a vocabulary size (int)"
  225. "or the path to a pretrained model config file (str)")
  226. @classmethod
  227. def from_dict(cls, json_object):
  228. """Constructs a `BertConfig` from a Python dictionary of parameters."""
  229. config = BertConfig(vocab_size_or_config_json_file=-1)
  230. for key, value in json_object.items():
  231. config.__dict__[key] = value
  232. return config
  233. @classmethod
  234. def from_json_file(cls, json_file):
  235. """Constructs a `BertConfig` from a json file of parameters."""
  236. with open(json_file, "r", encoding='utf-8') as reader:
  237. text = reader.read()
  238. return cls.from_dict(json.loads(text))
  239. def __repr__(self):
  240. return str(self.to_json_string())
  241. def to_dict(self):
  242. """Serializes this instance to a Python dictionary."""
  243. output = copy.deepcopy(self.__dict__)
  244. return output
  245. def to_json_string(self):
  246. """Serializes this instance to a JSON string."""
  247. return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
  248. class BertNonFusedLayerNorm(nn.Module):
  249. def __init__(self, hidden_size, eps=1e-12):
  250. """Construct a layernorm module in the TF style (epsilon inside the square root).
  251. """
  252. super(BertNonFusedLayerNorm, self).__init__()
  253. self.weight = nn.Parameter(torch.ones(hidden_size))
  254. self.bias = nn.Parameter(torch.zeros(hidden_size))
  255. self.variance_epsilon = eps
  256. def forward(self, x):
  257. u = x.mean(-1, keepdim=True)
  258. s = (x - u).pow(2).mean(-1, keepdim=True)
  259. x = (x - u) / torch.sqrt(s + self.variance_epsilon)
  260. return self.weight * x + self.bias
  261. try:
  262. import apex
  263. #apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm')
  264. import apex.normalization
  265. from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
  266. #apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward')
  267. #BertLayerNorm = apex.normalization.FusedLayerNorm
  268. APEX_IS_AVAILABLE = True
  269. except ImportError:
  270. print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
  271. #BertLayerNorm = BertNonFusedLayerNorm
  272. APEX_IS_AVAILABLE = False
  273. class BertLayerNorm(Module):
  274. def __init__(self, hidden_size, eps=1e-12):
  275. super(BertLayerNorm, self).__init__()
  276. self.shape = torch.Size((hidden_size,))
  277. self.eps = eps
  278. self.weight = nn.Parameter(torch.ones(hidden_size))
  279. self.bias = nn.Parameter(torch.zeros(hidden_size))
  280. self.apex_enabled = APEX_IS_AVAILABLE
  281. @torch.jit.unused
  282. def fused_layer_norm(self, x):
  283. return FusedLayerNormAffineFunction.apply(
  284. x, self.weight, self.bias, self.shape, self.eps)
  285. def forward(self, x):
  286. if self.apex_enabled and not torch.jit.is_scripting():
  287. x = self.fused_layer_norm(x)
  288. else:
  289. u = x.mean(-1, keepdim=True)
  290. s = (x - u).pow(2).mean(-1, keepdim=True)
  291. x = (x - u) / torch.sqrt(s + self.eps)
  292. x = self.weight * x + self.bias
  293. return x
  294. class BertEmbeddings(nn.Module):
  295. """Construct the embeddings from word, position and token_type embeddings.
  296. """
  297. def __init__(self, config):
  298. super(BertEmbeddings, self).__init__()
  299. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
  300. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  301. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  302. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  303. # any TensorFlow checkpoint file
  304. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  305. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  306. def forward(self, input_ids, token_type_ids):
  307. seq_length = input_ids.size(1)
  308. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
  309. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  310. words_embeddings = self.word_embeddings(input_ids)
  311. position_embeddings = self.position_embeddings(position_ids)
  312. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  313. embeddings = words_embeddings + position_embeddings + token_type_embeddings
  314. embeddings = self.LayerNorm(embeddings)
  315. embeddings = self.dropout(embeddings)
  316. return embeddings
  317. class BertSelfAttention(nn.Module):
  318. def __init__(self, config):
  319. super(BertSelfAttention, self).__init__()
  320. if config.hidden_size % config.num_attention_heads != 0:
  321. raise ValueError(
  322. "The hidden size (%d) is not a multiple of the number of attention "
  323. "heads (%d)" % (config.hidden_size, config.num_attention_heads))
  324. self.num_attention_heads = config.num_attention_heads
  325. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  326. self.all_head_size = self.num_attention_heads * self.attention_head_size
  327. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  328. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  329. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  330. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  331. def transpose_for_scores(self, x):
  332. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  333. x = torch.reshape(x, new_x_shape)
  334. return x.permute(0, 2, 1, 3)
  335. def transpose_key_for_scores(self, x):
  336. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  337. x = torch.reshape(x, new_x_shape)
  338. return x.permute(0, 2, 3, 1)
  339. def forward(self, hidden_states, attention_mask):
  340. mixed_query_layer = self.query(hidden_states)
  341. mixed_key_layer = self.key(hidden_states)
  342. mixed_value_layer = self.value(hidden_states)
  343. query_layer = self.transpose_for_scores(mixed_query_layer)
  344. key_layer = self.transpose_key_for_scores(mixed_key_layer)
  345. value_layer = self.transpose_for_scores(mixed_value_layer)
  346. # Take the dot product between "query" and "key" to get the raw attention scores.
  347. attention_scores = torch.matmul(query_layer, key_layer)
  348. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  349. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  350. attention_scores = attention_scores + attention_mask
  351. # Normalize the attention scores to probabilities.
  352. attention_probs = F.softmax(attention_scores, dim=-1)
  353. # This is actually dropping out entire tokens to attend to, which might
  354. # seem a bit unusual, but is taken from the original Transformer paper.
  355. attention_probs = self.dropout(attention_probs)
  356. context_layer = torch.matmul(attention_probs, value_layer)
  357. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  358. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  359. context_layer = torch.reshape(context_layer, new_context_layer_shape)
  360. return context_layer
  361. class BertSelfOutput(nn.Module):
  362. def __init__(self, config):
  363. super(BertSelfOutput, self).__init__()
  364. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  365. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  366. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  367. def forward(self, hidden_states, input_tensor):
  368. hidden_states = self.dense(hidden_states)
  369. hidden_states = self.dropout(hidden_states)
  370. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  371. return hidden_states
  372. class BertAttention(nn.Module):
  373. def __init__(self, config):
  374. super(BertAttention, self).__init__()
  375. self.self = BertSelfAttention(config)
  376. self.output = BertSelfOutput(config)
  377. def forward(self, input_tensor, attention_mask):
  378. self_output = self.self(input_tensor, attention_mask)
  379. attention_output = self.output(self_output, input_tensor)
  380. return attention_output
  381. class BertIntermediate(nn.Module):
  382. def __init__(self, config):
  383. super(BertIntermediate, self).__init__()
  384. self.dense_act = LinearActivation(config.hidden_size, config.intermediate_size, act=config.hidden_act)
  385. def forward(self, hidden_states):
  386. hidden_states = self.dense_act(hidden_states)
  387. return hidden_states
  388. class BertOutput(nn.Module):
  389. def __init__(self, config):
  390. super(BertOutput, self).__init__()
  391. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  392. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  393. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  394. def forward(self, hidden_states, input_tensor):
  395. hidden_states = self.dense(hidden_states)
  396. hidden_states = self.dropout(hidden_states)
  397. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  398. return hidden_states
  399. class BertLayer(nn.Module):
  400. def __init__(self, config):
  401. super(BertLayer, self).__init__()
  402. self.attention = BertAttention(config)
  403. self.intermediate = BertIntermediate(config)
  404. self.output = BertOutput(config)
  405. def forward(self, hidden_states, attention_mask):
  406. attention_output = self.attention(hidden_states, attention_mask)
  407. intermediate_output = self.intermediate(attention_output)
  408. layer_output = self.output(intermediate_output, attention_output)
  409. return layer_output
  410. class BertEncoder(nn.Module):
  411. def __init__(self, config):
  412. super(BertEncoder, self).__init__()
  413. self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
  414. self.output_all_encoded_layers = config.output_all_encoded_layers
  415. self._checkpoint_activations = False
  416. @torch.jit.unused
  417. def checkpointed_forward(self, hidden_states, attention_mask):
  418. def custom(start, end):
  419. def custom_forward(*inputs):
  420. layers = self.layer[start:end]
  421. x_ = inputs[0]
  422. for layer in layers:
  423. x_ = layer(x_, inputs[1])
  424. return x_
  425. return custom_forward
  426. l = 0
  427. num_layers = len(self.layer)
  428. chunk_length = math.ceil(math.sqrt(num_layers))
  429. while l < num_layers:
  430. hidden_states = checkpoint.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
  431. l += chunk_length
  432. return hidden_states
  433. def forward(self, hidden_states, attention_mask):
  434. all_encoder_layers = []
  435. if self._checkpoint_activations:
  436. hidden_states = self.checkpointed_forward(hidden_states, attention_mask)
  437. else:
  438. for i,layer_module in enumerate(self.layer):
  439. hidden_states = layer_module(hidden_states, attention_mask)
  440. if self.output_all_encoded_layers:
  441. all_encoder_layers.append(hidden_states)
  442. if not self.output_all_encoded_layers or self._checkpoint_activations:
  443. all_encoder_layers.append(hidden_states)
  444. return all_encoder_layers
  445. class BertPooler(nn.Module):
  446. def __init__(self, config):
  447. super(BertPooler, self).__init__()
  448. self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act="tanh")
  449. def forward(self, hidden_states):
  450. # We "pool" the model by simply taking the hidden state corresponding
  451. # to the first token.
  452. first_token_tensor = hidden_states[:, 0]
  453. pooled_output = self.dense_act(first_token_tensor)
  454. return pooled_output
  455. class BertPredictionHeadTransform(nn.Module):
  456. def __init__(self, config):
  457. super(BertPredictionHeadTransform, self).__init__()
  458. self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act=config.hidden_act)
  459. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  460. def forward(self, hidden_states):
  461. hidden_states = self.dense_act(hidden_states)
  462. hidden_states = self.LayerNorm(hidden_states)
  463. return hidden_states
  464. class BertLMPredictionHead(nn.Module):
  465. def __init__(self, config, bert_model_embedding_weights):
  466. super(BertLMPredictionHead, self).__init__()
  467. self.transform = BertPredictionHeadTransform(config)
  468. # The output weights are the same as the input embeddings, but there is
  469. # an output-only bias for each token.
  470. self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
  471. bert_model_embedding_weights.size(0),
  472. bias=False)
  473. self.decoder.weight = bert_model_embedding_weights
  474. self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
  475. def forward(self, hidden_states):
  476. hidden_states = self.transform(hidden_states)
  477. hidden_states = self.decoder(hidden_states) + self.bias
  478. return hidden_states
  479. class BertOnlyMLMHead(nn.Module):
  480. def __init__(self, config, bert_model_embedding_weights):
  481. super(BertOnlyMLMHead, self).__init__()
  482. self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
  483. def forward(self, sequence_output):
  484. prediction_scores = self.predictions(sequence_output)
  485. return prediction_scores
  486. class BertOnlyNSPHead(nn.Module):
  487. def __init__(self, config):
  488. super(BertOnlyNSPHead, self).__init__()
  489. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  490. def forward(self, pooled_output):
  491. seq_relationship_score = self.seq_relationship(pooled_output)
  492. return seq_relationship_score
  493. class BertPreTrainingHeads(nn.Module):
  494. def __init__(self, config, bert_model_embedding_weights):
  495. super(BertPreTrainingHeads, self).__init__()
  496. self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
  497. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  498. def forward(self, sequence_output, pooled_output):
  499. prediction_scores = self.predictions(sequence_output)
  500. seq_relationship_score = self.seq_relationship(pooled_output)
  501. return prediction_scores, seq_relationship_score
  502. class BertPreTrainedModel(nn.Module):
  503. """ An abstract class to handle weights initialization and
  504. a simple interface for dowloading and loading pretrained models.
  505. """
  506. def __init__(self, config, *inputs, **kwargs):
  507. super(BertPreTrainedModel, self).__init__()
  508. if not isinstance(config, BertConfig):
  509. raise ValueError(
  510. "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
  511. "To create a model from a Google pretrained model use "
  512. "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
  513. self.__class__.__name__, self.__class__.__name__
  514. ))
  515. self.config = config
  516. def init_bert_weights(self, module):
  517. """ Initialize the weights.
  518. """
  519. if isinstance(module, (nn.Linear, nn.Embedding)):
  520. # Slightly different from the TF version which uses truncated_normal for initialization
  521. # cf https://github.com/pytorch/pytorch/pull/5617
  522. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  523. elif isinstance(module, BertLayerNorm):
  524. module.bias.data.zero_()
  525. module.weight.data.fill_(1.0)
  526. if isinstance(module, nn.Linear) and module.bias is not None:
  527. module.bias.data.zero_()
  528. def checkpoint_activations(self, val):
  529. def _apply_flag(module):
  530. if hasattr(module, "_checkpoint_activations"):
  531. module._checkpoint_activations=val
  532. self.apply(_apply_flag)
  533. def enable_apex(self, val):
  534. def _apply_flag(module):
  535. if hasattr(module, "apex_enabled"):
  536. module.apex_enabled=val
  537. self.apply(_apply_flag)
  538. @classmethod
  539. def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
  540. from_tf=False, *inputs, **kwargs):
  541. """
  542. Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
  543. Download and cache the pre-trained model file if needed.
  544. Params:
  545. pretrained_model_name_or_path: either:
  546. - a str with the name of a pre-trained model to load selected in the list of:
  547. . `bert-base-uncased`
  548. . `bert-large-uncased`
  549. . `bert-base-cased`
  550. . `bert-large-cased`
  551. . `bert-base-multilingual-uncased`
  552. . `bert-base-multilingual-cased`
  553. . `bert-base-chinese`
  554. - a path or url to a pretrained model archive containing:
  555. . `bert_config.json` a configuration file for the model
  556. . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
  557. - a path or url to a pretrained model archive containing:
  558. . `bert_config.json` a configuration file for the model
  559. . `model.chkpt` a TensorFlow checkpoint
  560. from_tf: should we load the weights from a locally saved TensorFlow checkpoint
  561. cache_dir: an optional path to a folder in which the pre-trained models will be cached.
  562. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
  563. *inputs, **kwargs: additional input for the specific Bert class
  564. (ex: num_labels for BertForSequenceClassification)
  565. """
  566. if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
  567. archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
  568. else:
  569. archive_file = pretrained_model_name_or_path
  570. # redirect to the cache, if necessary
  571. try:
  572. resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
  573. except EnvironmentError:
  574. logger.error(
  575. "Model name '{}' was not found in model name list ({}). "
  576. "We assumed '{}' was a path or url but couldn't find any file "
  577. "associated to this path or url.".format(
  578. pretrained_model_name_or_path,
  579. ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
  580. archive_file))
  581. return None
  582. if resolved_archive_file == archive_file:
  583. logger.info("loading archive file {}".format(archive_file))
  584. else:
  585. logger.info("loading archive file {} from cache at {}".format(
  586. archive_file, resolved_archive_file))
  587. tempdir = None
  588. if os.path.isdir(resolved_archive_file) or from_tf:
  589. serialization_dir = resolved_archive_file
  590. else:
  591. # Extract archive to temp dir
  592. tempdir = tempfile.mkdtemp()
  593. logger.info("extracting archive file {} to temp dir {}".format(
  594. resolved_archive_file, tempdir))
  595. with tarfile.open(resolved_archive_file, 'r:gz') as archive:
  596. archive.extractall(tempdir)
  597. serialization_dir = tempdir
  598. # Load config
  599. config_file = os.path.join(serialization_dir, CONFIG_NAME)
  600. config = BertConfig.from_json_file(config_file)
  601. logger.info("Model config {}".format(config))
  602. # Instantiate model.
  603. model = cls(config, *inputs, **kwargs)
  604. if state_dict is None and not from_tf:
  605. weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
  606. state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
  607. if tempdir:
  608. # Clean up temp dir
  609. shutil.rmtree(tempdir)
  610. if from_tf:
  611. # Directly load from a TensorFlow checkpoint
  612. weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
  613. return load_tf_weights_in_bert(model, weights_path)
  614. # Load from a PyTorch state_dict
  615. old_keys = []
  616. new_keys = []
  617. for key in state_dict.keys():
  618. new_key = None
  619. if 'gamma' in key:
  620. new_key = key.replace('gamma', 'weight')
  621. if 'beta' in key:
  622. new_key = key.replace('beta', 'bias')
  623. if new_key:
  624. old_keys.append(key)
  625. new_keys.append(new_key)
  626. for old_key, new_key in zip(old_keys, new_keys):
  627. state_dict[new_key] = state_dict.pop(old_key)
  628. missing_keys = []
  629. unexpected_keys = []
  630. error_msgs = []
  631. # copy state_dict so _load_from_state_dict can modify it
  632. metadata = getattr(state_dict, '_metadata', None)
  633. state_dict = state_dict.copy()
  634. if metadata is not None:
  635. state_dict._metadata = metadata
  636. def load(module, prefix=''):
  637. local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
  638. module._load_from_state_dict(
  639. state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
  640. for name, child in module._modules.items():
  641. if child is not None:
  642. load(child, prefix + name + '.')
  643. start_prefix = ''
  644. if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
  645. start_prefix = 'bert.'
  646. load(model, prefix=start_prefix)
  647. if len(missing_keys) > 0:
  648. logger.info("Weights of {} not initialized from pretrained model: {}".format(
  649. model.__class__.__name__, missing_keys))
  650. if len(unexpected_keys) > 0:
  651. logger.info("Weights from pretrained model not used in {}: {}".format(
  652. model.__class__.__name__, unexpected_keys))
  653. if len(error_msgs) > 0:
  654. raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
  655. model.__class__.__name__, "\n\t".join(error_msgs)))
  656. return model
  657. class BertModel(BertPreTrainedModel):
  658. """BERT model ("Bidirectional Embedding Representations from a Transformer").
  659. Params:
  660. config: a BertConfig class instance with the configuration to build a new model
  661. Inputs:
  662. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  663. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  664. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  665. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  666. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  667. a `sentence B` token (see BERT paper for more details).
  668. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  669. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  670. input sequence length in the current batch. It's the mask that we typically use for attention when
  671. a batch has varying length sentences.
  672. Outputs: Tuple of (encoded_layers, pooled_output)
  673. `encoded_layers`: controled by `output_all_encoded_layers` argument:
  674. - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
  675. of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
  676. encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
  677. - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
  678. to the last attention block of shape [batch_size, sequence_length, hidden_size],
  679. `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
  680. classifier pretrained on top of the hidden state associated to the first character of the
  681. input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
  682. Example usage:
  683. ```python
  684. # Already been converted into WordPiece token ids
  685. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  686. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  687. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  688. config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  689. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  690. model = modeling.BertModel(config=config)
  691. all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
  692. ```
  693. """
  694. def __init__(self, config):
  695. super(BertModel, self).__init__(config)
  696. self.embeddings = BertEmbeddings(config)
  697. self.encoder = BertEncoder(config)
  698. self.pooler = BertPooler(config)
  699. self.apply(self.init_bert_weights)
  700. self.output_all_encoded_layers = config.output_all_encoded_layers
  701. def forward(self, input_ids, token_type_ids, attention_mask):
  702. # We create a 3D attention mask from a 2D tensor mask.
  703. # Sizes are [batch_size, 1, 1, to_seq_length]
  704. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  705. # this attention mask is more simple than the triangular masking of causal attention
  706. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  707. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  708. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  709. # masked positions, this operation will create a tensor which is 0.0 for
  710. # positions we want to attend and -10000.0 for masked positions.
  711. # Since we are adding it to the raw scores before the softmax, this is
  712. # effectively the same as removing these entirely.
  713. extended_attention_mask = extended_attention_mask.to(dtype=self.embeddings.word_embeddings.weight.dtype) # fp16 compatibility
  714. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  715. embedding_output = self.embeddings(input_ids, token_type_ids)
  716. encoded_layers = self.encoder(embedding_output, extended_attention_mask)
  717. sequence_output = encoded_layers[-1]
  718. pooled_output = self.pooler(sequence_output)
  719. if not self.output_all_encoded_layers:
  720. encoded_layers = encoded_layers[-1:]
  721. return encoded_layers, pooled_output
  722. class BertForPreTraining(BertPreTrainedModel):
  723. """BERT model with pre-training heads.
  724. This module comprises the BERT model followed by the two pre-training heads:
  725. - the masked language modeling head, and
  726. - the next sentence classification head.
  727. Params:
  728. config: a BertConfig class instance with the configuration to build a new model.
  729. Inputs:
  730. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  731. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  732. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  733. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  734. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  735. a `sentence B` token (see BERT paper for more details).
  736. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  737. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  738. input sequence length in the current batch. It's the mask that we typically use for attention when
  739. a batch has varying length sentences.
  740. `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
  741. with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
  742. is only computed for the labels set in [0, ..., vocab_size]
  743. `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
  744. with indices selected in [0, 1].
  745. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
  746. Outputs:
  747. if `masked_lm_labels` and `next_sentence_label` are not `None`:
  748. Outputs the total_loss which is the sum of the masked language modeling loss and the next
  749. sentence classification loss.
  750. if `masked_lm_labels` or `next_sentence_label` is `None`:
  751. Outputs a tuple comprising
  752. - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
  753. - the next sentence classification logits of shape [batch_size, 2].
  754. Example usage:
  755. ```python
  756. # Already been converted into WordPiece token ids
  757. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  758. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  759. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  760. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  761. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  762. model = BertForPreTraining(config)
  763. masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
  764. ```
  765. """
  766. def __init__(self, config):
  767. super(BertForPreTraining, self).__init__(config)
  768. self.bert = BertModel(config)
  769. self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
  770. self.apply(self.init_bert_weights)
  771. def forward(self, input_ids, token_type_ids, attention_mask):
  772. encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
  773. sequence_output = encoded_layers[-1]
  774. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  775. return prediction_scores, seq_relationship_score
  776. class BertForMaskedLM(BertPreTrainedModel):
  777. """BERT model with the masked language modeling head.
  778. This module comprises the BERT model followed by the masked language modeling head.
  779. Params:
  780. config: a BertConfig class instance with the configuration to build a new model.
  781. Inputs:
  782. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  783. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  784. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  785. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  786. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  787. a `sentence B` token (see BERT paper for more details).
  788. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  789. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  790. input sequence length in the current batch. It's the mask that we typically use for attention when
  791. a batch has varying length sentences.
  792. `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
  793. with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
  794. is only computed for the labels set in [0, ..., vocab_size]
  795. Outputs:
  796. if `masked_lm_labels` is not `None`:
  797. Outputs the masked language modeling loss.
  798. if `masked_lm_labels` is `None`:
  799. Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
  800. Example usage:
  801. ```python
  802. # Already been converted into WordPiece token ids
  803. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  804. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  805. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  806. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  807. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  808. model = BertForMaskedLM(config)
  809. masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
  810. ```
  811. """
  812. def __init__(self, config):
  813. super(BertForMaskedLM, self).__init__(config)
  814. self.bert = BertModel(config)
  815. self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
  816. self.apply(self.init_bert_weights)
  817. def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
  818. encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
  819. sequence_output = encoded_layers[-1]
  820. prediction_scores = self.cls(sequence_output)
  821. if masked_lm_labels is not None:
  822. loss_fct = CrossEntropyLoss(ignore_index=-1)
  823. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
  824. return masked_lm_loss
  825. else:
  826. return prediction_scores
  827. class BertForNextSentencePrediction(BertPreTrainedModel):
  828. """BERT model with next sentence prediction head.
  829. This module comprises the BERT model followed by the next sentence classification head.
  830. Params:
  831. config: a BertConfig class instance with the configuration to build a new model.
  832. Inputs:
  833. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  834. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  835. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  836. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  837. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  838. a `sentence B` token (see BERT paper for more details).
  839. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  840. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  841. input sequence length in the current batch. It's the mask that we typically use for attention when
  842. a batch has varying length sentences.
  843. `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
  844. with indices selected in [0, 1].
  845. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
  846. Outputs:
  847. if `next_sentence_label` is not `None`:
  848. Outputs the total_loss which is the sum of the masked language modeling loss and the next
  849. sentence classification loss.
  850. if `next_sentence_label` is `None`:
  851. Outputs the next sentence classification logits of shape [batch_size, 2].
  852. Example usage:
  853. ```python
  854. # Already been converted into WordPiece token ids
  855. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  856. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  857. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  858. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  859. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  860. model = BertForNextSentencePrediction(config)
  861. seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
  862. ```
  863. """
  864. def __init__(self, config):
  865. super(BertForNextSentencePrediction, self).__init__(config)
  866. self.bert = BertModel(config)
  867. self.cls = BertOnlyNSPHead(config)
  868. self.apply(self.init_bert_weights)
  869. def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
  870. _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
  871. seq_relationship_score = self.cls( pooled_output)
  872. if next_sentence_label is not None:
  873. loss_fct = CrossEntropyLoss(ignore_index=-1)
  874. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  875. return next_sentence_loss
  876. else:
  877. return seq_relationship_score
  878. class BertForSequenceClassification(BertPreTrainedModel):
  879. """BERT model for classification.
  880. This module is composed of the BERT model with a linear layer on top of
  881. the pooled output.
  882. Params:
  883. `config`: a BertConfig class instance with the configuration to build a new model.
  884. `num_labels`: the number of classes for the classifier. Default = 2.
  885. Inputs:
  886. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  887. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  888. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  889. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  890. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  891. a `sentence B` token (see BERT paper for more details).
  892. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  893. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  894. input sequence length in the current batch. It's the mask that we typically use for attention when
  895. a batch has varying length sentences.
  896. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
  897. with indices selected in [0, ..., num_labels].
  898. Outputs:
  899. if `labels` is not `None`:
  900. Outputs the CrossEntropy classification loss of the output with the labels.
  901. if `labels` is `None`:
  902. Outputs the classification logits of shape [batch_size, num_labels].
  903. Example usage:
  904. ```python
  905. # Already been converted into WordPiece token ids
  906. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  907. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  908. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  909. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  910. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  911. num_labels = 2
  912. model = BertForSequenceClassification(config, num_labels)
  913. logits = model(input_ids, token_type_ids, input_mask)
  914. ```
  915. """
  916. def __init__(self, config, num_labels):
  917. super(BertForSequenceClassification, self).__init__(config)
  918. self.num_labels = num_labels
  919. self.bert = BertModel(config)
  920. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  921. self.classifier = nn.Linear(config.hidden_size, num_labels)
  922. self.apply(self.init_bert_weights)
  923. def forward(self, input_ids, token_type_ids=None, attention_mask=None):
  924. _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
  925. pooled_output = self.dropout(pooled_output)
  926. return self.classifier(pooled_output)
  927. class BertForMultipleChoice(BertPreTrainedModel):
  928. """BERT model for multiple choice tasks.
  929. This module is composed of the BERT model with a linear layer on top of
  930. the pooled output.
  931. Params:
  932. `config`: a BertConfig class instance with the configuration to build a new model.
  933. `num_choices`: the number of classes for the classifier. Default = 2.
  934. Inputs:
  935. `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
  936. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  937. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  938. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
  939. with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
  940. and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
  941. `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
  942. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  943. input sequence length in the current batch. It's the mask that we typically use for attention when
  944. a batch has varying length sentences.
  945. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
  946. with indices selected in [0, ..., num_choices].
  947. Outputs:
  948. if `labels` is not `None`:
  949. Outputs the CrossEntropy classification loss of the output with the labels.
  950. if `labels` is `None`:
  951. Outputs the classification logits of shape [batch_size, num_labels].
  952. Example usage:
  953. ```python
  954. # Already been converted into WordPiece token ids
  955. input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
  956. input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
  957. token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
  958. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  959. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  960. num_choices = 2
  961. model = BertForMultipleChoice(config, num_choices)
  962. logits = model(input_ids, token_type_ids, input_mask)
  963. ```
  964. """
  965. def __init__(self, config, num_choices):
  966. super(BertForMultipleChoice, self).__init__(config)
  967. self.num_choices = num_choices
  968. self.bert = BertModel(config)
  969. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  970. self.classifier = nn.Linear(config.hidden_size, 1)
  971. self.apply(self.init_bert_weights)
  972. def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
  973. flat_input_ids = input_ids.view(-1, input_ids.size(-1))
  974. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
  975. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
  976. _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask)
  977. pooled_output = self.dropout(pooled_output)
  978. logits = self.classifier(pooled_output)
  979. reshaped_logits = logits.view(-1, self.num_choices)
  980. if labels is not None:
  981. loss_fct = CrossEntropyLoss()
  982. loss = loss_fct(reshaped_logits, labels)
  983. return loss
  984. else:
  985. return reshaped_logits
  986. class BertForTokenClassification(BertPreTrainedModel):
  987. """BERT model for token-level classification.
  988. This module is composed of the BERT model with a linear layer on top of
  989. the full hidden state of the last layer.
  990. Params:
  991. `config`: a BertConfig class instance with the configuration to build a new model.
  992. `num_labels`: the number of classes for the classifier. Default = 2.
  993. Inputs:
  994. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  995. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  996. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  997. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  998. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  999. a `sentence B` token (see BERT paper for more details).
  1000. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  1001. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  1002. input sequence length in the current batch. It's the mask that we typically use for attention when
  1003. a batch has varying length sentences.
  1004. `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
  1005. with indices selected in [0, ..., num_labels].
  1006. Outputs:
  1007. if `labels` is not `None`:
  1008. Outputs the CrossEntropy classification loss of the output with the labels.
  1009. if `labels` is `None`:
  1010. Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
  1011. Example usage:
  1012. ```python
  1013. # Already been converted into WordPiece token ids
  1014. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  1015. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  1016. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  1017. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  1018. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  1019. num_labels = 2
  1020. model = BertForTokenClassification(config, num_labels)
  1021. logits = model(input_ids, token_type_ids, input_mask)
  1022. ```
  1023. """
  1024. def __init__(self, config, num_labels):
  1025. super(BertForTokenClassification, self).__init__(config)
  1026. self.num_labels = num_labels
  1027. self.bert = BertModel(config)
  1028. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1029. self.classifier = nn.Linear(config.hidden_size, num_labels)
  1030. self.apply(self.init_bert_weights)
  1031. def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
  1032. encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
  1033. sequence_output = encoded_layers[-1]
  1034. sequence_output = self.dropout(sequence_output)
  1035. logits = self.classifier(sequence_output)
  1036. if labels is not None:
  1037. loss_fct = CrossEntropyLoss()
  1038. # Only keep active parts of the loss
  1039. if attention_mask is not None:
  1040. active_loss = attention_mask.view(-1) == 1
  1041. active_logits = logits.view(-1, self.num_labels)[active_loss]
  1042. active_labels = labels.view(-1)[active_loss]
  1043. loss = loss_fct(active_logits, active_labels)
  1044. else:
  1045. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1046. return loss
  1047. else:
  1048. return logits
  1049. class BertForQuestionAnswering(BertPreTrainedModel):
  1050. """BERT model for Question Answering (span extraction).
  1051. This module is composed of the BERT model with a linear layer on top of
  1052. the sequence output that computes start_logits and end_logits
  1053. Params:
  1054. `config`: a BertConfig class instance with the configuration to build a new model.
  1055. Inputs:
  1056. `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
  1057. with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
  1058. `extract_features.py`, `run_classifier.py` and `run_squad.py`)
  1059. `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
  1060. types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
  1061. a `sentence B` token (see BERT paper for more details).
  1062. `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
  1063. selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
  1064. input sequence length in the current batch. It's the mask that we typically use for attention when
  1065. a batch has varying length sentences.
  1066. Outputs:
  1067. Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
  1068. position tokens of shape [batch_size, sequence_length].
  1069. Example usage:
  1070. ```python
  1071. # Already been converted into WordPiece token ids
  1072. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  1073. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  1074. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  1075. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  1076. num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  1077. model = BertForQuestionAnswering(config)
  1078. start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
  1079. ```
  1080. """
  1081. def __init__(self, config):
  1082. super(BertForQuestionAnswering, self).__init__(config)
  1083. self.bert = BertModel(config)
  1084. # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
  1085. # self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1086. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  1087. self.apply(self.init_bert_weights)
  1088. def forward(self, input_ids, token_type_ids, attention_mask):
  1089. encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
  1090. sequence_output = encoded_layers[-1]
  1091. logits = self.qa_outputs(sequence_output)
  1092. start_logits, end_logits = logits.split(1, dim=-1)
  1093. start_logits = start_logits.squeeze(-1)
  1094. end_logits = end_logits.squeeze(-1)
  1095. return start_logits, end_logits