run_re.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005
  1. # coding=utf-8
  2. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """BERT finetuning runner."""
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import collections
  21. import csv
  22. import logging
  23. import os, sys
  24. import numpy as np
  25. import tensorflow as tf
  26. sys.path.append("/workspace/bert")
  27. import modeling
  28. import optimization
  29. import tokenization
  30. import time
  31. import horovod.tensorflow as hvd
  32. from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
  33. from utils.gpu_affinity import set_affinity
  34. import utils.dllogger_class
  35. from dllogger import Verbosity
  36. flags = tf.flags
  37. FLAGS = flags.FLAGS
  38. ## Required parameters
  39. flags.DEFINE_string(
  40. "data_dir", None,
  41. "The input data dir. Should contain the .tsv files (or other data files) "
  42. "for the task.")
  43. flags.DEFINE_string(
  44. "bert_config_file", None,
  45. "The config json file corresponding to the pre-trained BERT model. "
  46. "This specifies the model architecture.")
  47. flags.DEFINE_string("task_name", None, "The name of the task to train.")
  48. flags.DEFINE_string("vocab_file", None,
  49. "The vocabulary file that the BERT model was trained on.")
  50. flags.DEFINE_string(
  51. "output_dir", None,
  52. "The output directory where the model checkpoints will be written.")
  53. ## Other parameters
  54. flags.DEFINE_string(
  55. "dllog_path", "/results/bert_dllog.json",
  56. "filename where dllogger writes to")
  57. flags.DEFINE_string(
  58. "init_checkpoint", None,
  59. "Initial checkpoint (usually from a pre-trained BERT model).")
  60. flags.DEFINE_bool(
  61. "do_lower_case", True,
  62. "Whether to lower case the input text. Should be True for uncased "
  63. "models and False for cased models.")
  64. flags.DEFINE_integer(
  65. "max_seq_length", 128,
  66. "The maximum total input sequence length after WordPiece tokenization. "
  67. "Sequences longer than this will be truncated, and sequences shorter "
  68. "than this will be padded.")
  69. flags.DEFINE_bool("do_train", False, "Whether to run training.")
  70. flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
  71. flags.DEFINE_bool(
  72. "do_predict", False,
  73. "Whether to run the model in inference mode on the test set.")
  74. flags.DEFINE_integer("train_batch_size", 16, "Total batch size for training.")
  75. flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
  76. flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
  77. flags.DEFINE_float("learning_rate", 5e-6, "The initial learning rate for Adam.")
  78. flags.DEFINE_float("num_train_epochs", 3.0,
  79. "Total number of training epochs to perform.")
  80. flags.DEFINE_float(
  81. "warmup_proportion", 0.1,
  82. "Proportion of training to perform linear learning rate warmup for. "
  83. "E.g., 0.1 = 10% of training.")
  84. flags.DEFINE_integer("save_checkpoints_steps", 1000,
  85. "How often to save the model checkpoint.")
  86. flags.DEFINE_integer("iterations_per_loop", 1000,
  87. "How many steps to make in each estimator call.")
  88. tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
  89. flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
  90. flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")
  91. flags.DEFINE_bool("use_xla", True, "Whether to enable XLA JIT compilation.")
  92. class InputExample(object):
  93. """A single training/test example for simple sequence classification."""
  94. def __init__(self, guid, text_a, text_b=None, label=None):
  95. """Constructs a InputExample.
  96. Args:
  97. guid: Unique id for the example.
  98. text_a: string. The untokenized text of the first sequence. For single
  99. sequence tasks, only this sequence must be specified.
  100. text_b: (Optional) string. The untokenized text of the second sequence.
  101. Only must be specified for sequence pair tasks.
  102. label: (Optional) string. The label of the example. This should be
  103. specified for train and dev examples, but not for test examples.
  104. """
  105. self.guid = guid
  106. self.text_a = text_a
  107. self.text_b = text_b
  108. self.label = label
class PaddingInputExample(object):
  """Fake example so the num input examples is a multiple of the batch size.

  When running eval/predict on the TPU, we need to pad the number of examples
  to be a multiple of the batch size, because the TPU requires a fixed batch
  size. The alternative is to drop the last batch, which is bad because it means
  the entire output data won't be generated.

  We use this class instead of `None` because treating `None` as padding
  batches could cause silent errors.
  """
  118. class InputFeatures(object):
  119. """A single set of features of data."""
  120. def __init__(self,
  121. input_ids,
  122. input_mask,
  123. segment_ids,
  124. label_id,
  125. is_real_example=True):
  126. self.input_ids = input_ids
  127. self.input_mask = input_mask
  128. self.segment_ids = segment_ids
  129. self.label_id = label_id
  130. self.is_real_example = is_real_example
  131. class DataProcessor(object):
  132. """Base class for data converters for sequence classification data sets."""
  133. def get_train_examples(self, data_dir):
  134. """Gets a collection of `InputExample`s for the train set."""
  135. raise NotImplementedError()
  136. def get_dev_examples(self, data_dir):
  137. """Gets a collection of `InputExample`s for the dev set."""
  138. raise NotImplementedError()
  139. def get_test_examples(self, data_dir):
  140. """Gets a collection of `InputExample`s for prediction."""
  141. raise NotImplementedError()
  142. def get_labels(self):
  143. """Gets the list of labels for this data set."""
  144. raise NotImplementedError()
  145. @classmethod
  146. def _read_tsv(cls, input_file, quotechar=None):
  147. """Reads a tab separated value file."""
  148. with tf.io.gfile.GFile(input_file, "r") as f:
  149. reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
  150. lines = []
  151. for line in reader:
  152. lines.append(line)
  153. return lines
  154. class BioBERTChemprotProcessor(DataProcessor):
  155. """Processor for the BioBERT data set obtained from
  156. (https://github.com/arwhirang/recursive_chemprot/tree/master/Demo/tree_LSTM/data).
  157. """
  158. def get_train_examples(self, data_dir, file_name="trainingPosit_chem"):
  159. """See base class."""
  160. return self._create_examples(
  161. self._read_tsv(os.path.join(data_dir, file_name)), "train")
  162. def get_dev_examples(self, data_dir, file_name="developPosit_chem"):
  163. """See base class."""
  164. return self._create_examples(
  165. self._read_tsv(os.path.join(data_dir, file_name)), "dev")
  166. def get_test_examples(self, data_dir, file_name="testPosit_chem"):
  167. """See base class."""
  168. return self._create_examples(
  169. self._read_tsv(os.path.join(data_dir, file_name)), "test")
  170. def get_labels(self):
  171. """See base class."""
  172. return ["CPR:3", "CPR:4", "CPR:5", "CPR:6", "CPR:9", "False"]
  173. def _create_examples(self, lines, set_type):
  174. """Creates examples for the training and dev sets."""
  175. examples = []
  176. for (i, line) in enumerate(lines):
  177. guid = "%s-%s" % (set_type, i)
  178. if set_type == "test":
  179. text_a = tokenization.convert_to_unicode(line[1])
  180. label = "False"
  181. else:
  182. text_a = tokenization.convert_to_unicode(line[1])
  183. label = tokenization.convert_to_unicode(line[2])
  184. if label == "True":
  185. label = tokenization.convert_to_unicode(line[3])
  186. examples.append(
  187. InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
  188. return examples
  189. class _ChemProtProcessor(DataProcessor):
  190. """Processor for the ChemProt data set."""
  191. def get_train_examples(self, data_dir):
  192. """See base class."""
  193. return self._create_examples(
  194. self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  195. def get_dev_examples(self, data_dir, file_name="dev.tsv"):
  196. """See base class."""
  197. return self._create_examples(
  198. self._read_tsv(os.path.join(data_dir, file_name)), "dev")
  199. def get_test_examples(self, data_dir, file_name="test.tsv"):
  200. """See base class."""
  201. return self._create_examples(
  202. self._read_tsv(os.path.join(data_dir, file_name)), "test")
  203. def _create_examples(self, lines, set_type):
  204. """Creates examples for the training and dev sets."""
  205. examples = []
  206. for (i, line) in enumerate(lines):
  207. # skip header
  208. if i == 0:
  209. continue
  210. guid = line[0]
  211. text_a = tokenization.convert_to_unicode(line[1])
  212. if set_type == "test":
  213. label = self.get_labels()[-1]
  214. else:
  215. try:
  216. label = tokenization.convert_to_unicode(line[2])
  217. except IndexError:
  218. logging.exception(line)
  219. exit(1)
  220. examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
  221. return examples
  222. class ChemProtProcessor(_ChemProtProcessor):
  223. def get_labels(self):
  224. """See base class."""
  225. return ["CPR:3", "CPR:4", "CPR:5", "CPR:6", "CPR:9", "false"]
  226. class MedNLIProcessor(DataProcessor):
  227. def get_train_examples(self, data_dir):
  228. """See base class."""
  229. return self._create_examples(
  230. self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  231. def get_dev_examples(self, data_dir, file_name="dev.tsv"):
  232. """See base class."""
  233. return self._create_examples(
  234. self._read_tsv(os.path.join(data_dir, file_name)), "dev")
  235. def get_test_examples(self, data_dir, file_name="test.tsv"):
  236. """See base class."""
  237. return self._create_examples(
  238. self._read_tsv(os.path.join(data_dir, file_name)), "test")
  239. def get_labels(self):
  240. """See base class."""
  241. return ['contradiction', 'entailment', 'neutral']
  242. def _create_examples(self, lines, set_type):
  243. """Creates examples for the training and dev sets."""
  244. examples = []
  245. for (i, line) in enumerate(lines):
  246. if i == 0:
  247. continue
  248. guid = line[1]
  249. text_a = tokenization.convert_to_unicode(line[2])
  250. text_b = tokenization.convert_to_unicode(line[3])
  251. if set_type == "test":
  252. label = self.get_labels()[-1]
  253. else:
  254. label = tokenization.convert_to_unicode(line[0])
  255. examples.append(
  256. InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  257. return examples
def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
  """Converts a single `InputExample` into a single `InputFeatures`.

  Args:
    ex_index: int index of the example; the first 5 examples are logged.
    example: `InputExample` (or `PaddingInputExample`) to convert.
    label_list: list of all label strings; a label's position in this list
      defines its integer id.
    max_seq_length: fixed total sequence length after truncation and padding.
    tokenizer: WordPiece tokenizer providing `tokenize` and
      `convert_tokens_to_ids`.

  Returns:
    An `InputFeatures` whose `input_ids`, `input_mask` and `segment_ids`
    all have length `max_seq_length`; `is_real_example` is False only for
    padding examples.
  """
  # Padding examples produce all-zero features so batch sizes stay fixed.
  if isinstance(example, PaddingInputExample):
    return InputFeatures(
        input_ids=[0] * max_seq_length,
        input_mask=[0] * max_seq_length,
        segment_ids=[0] * max_seq_length,
        label_id=0,
        is_real_example=False)

  # Map each label string to its index in `label_list`.
  label_map = {}
  for (i, label) in enumerate(label_list):
    label_map[label] = i

  tokens_a = tokenizer.tokenize(example.text_a)
  tokens_b = None
  if example.text_b:
    tokens_b = tokenizer.tokenize(example.text_b)

  if tokens_b:
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
  else:
    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens_a) > max_seq_length - 2:
      tokens_a = tokens_a[0:(max_seq_length - 2)]

  # The convention in BERT is:
  # (a) For sequence pairs:
  #  tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
  #  type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
  # (b) For single sequences:
  #  tokens: [CLS] the dog is hairy . [SEP]
  #  type_ids: 0 0 0 0 0 0 0
  #
  # Where "type_ids" are used to indicate whether this is the first
  # sequence or the second sequence. The embedding vectors for `type=0` and
  # `type=1` were learned during pre-training and are added to the wordpiece
  # embedding vector (and position vector). This is not *strictly* necessary
  # since the [SEP] token unambiguously separates the sequences, but it makes
  # it easier for the model to learn the concept of sequences.
  #
  # For classification tasks, the first vector (corresponding to [CLS]) is
  # used as the "sentence vector". Note that this only makes sense because
  # the entire model is fine-tuned.
  tokens = []
  segment_ids = []
  tokens.append("[CLS]")
  segment_ids.append(0)
  for token in tokens_a:
    tokens.append(token)
    segment_ids.append(0)
  tokens.append("[SEP]")
  segment_ids.append(0)
  if tokens_b:
    # The second sequence (plus its trailing [SEP]) uses segment id 1.
    for token in tokens_b:
      tokens.append(token)
      segment_ids.append(1)
    tokens.append("[SEP]")
    segment_ids.append(1)

  input_ids = tokenizer.convert_tokens_to_ids(tokens)

  # The mask has 1 for real tokens and 0 for padding tokens. Only real
  # tokens are attended to.
  input_mask = [1] * len(input_ids)

  # Zero-pad up to the sequence length.
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  label_id = label_map[example.label]
  # Log the first few converted examples for debugging.
  if ex_index < 5:
    tf.compat.v1.logging.info("*** Example ***")
    tf.compat.v1.logging.info("guid: %s" % (example.guid))
    tf.compat.v1.logging.info("tokens: %s" % " ".join(
        [tokenization.printable_text(x) for x in tokens]))
    tf.compat.v1.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
    tf.compat.v1.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
    tf.compat.v1.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
    tf.compat.v1.logging.info("label: %s (id = %d)" % (example.label, label_id))

  feature = InputFeatures(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      label_id=label_id,
      is_real_example=True)
  return feature
  346. def file_based_convert_examples_to_features(
  347. examples, label_list, max_seq_length, tokenizer, output_file):
  348. """Convert a set of `InputExample`s to a TFRecord file."""
  349. writer = tf.python_io.TFRecordWriter(output_file)
  350. for (ex_index, example) in enumerate(examples):
  351. if ex_index % 10000 == 0:
  352. tf.compat.v1.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
  353. feature = convert_single_example(ex_index, example, label_list,
  354. max_seq_length, tokenizer)
  355. def create_int_feature(values):
  356. f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
  357. return f
  358. features = collections.OrderedDict()
  359. features["input_ids"] = create_int_feature(feature.input_ids)
  360. features["input_mask"] = create_int_feature(feature.input_mask)
  361. features["segment_ids"] = create_int_feature(feature.segment_ids)
  362. features["label_ids"] = create_int_feature([feature.label_id])
  363. features["is_real_example"] = create_int_feature(
  364. [int(feature.is_real_example)])
  365. tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  366. writer.write(tf_example.SerializeToString())
  367. writer.close()
def file_based_input_fn_builder(input_file, batch_size, seq_length, is_training,
                                drop_remainder, hvd=None):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Args:
    input_file: TFRecord file produced by
      `file_based_convert_examples_to_features`.
    batch_size: batch size used directly (the `params["batch_size"]` lookup
      is intentionally bypassed; see the commented line below).
    seq_length: fixed sequence length of the serialized features.
    is_training: if True, the dataset is sharded across Horovod ranks (when
      `hvd` is given), repeated indefinitely and shuffled.
    drop_remainder: whether to drop the final partial batch.
    hvd: optional Horovod module; used only for sharding during training.

  Returns:
    An `input_fn(params)` callable returning a `tf.data.Dataset` of decoded,
    batched examples.
  """
  # Schema of each serialized tf.Example.
  name_to_features = {
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.io.FixedLenFeature([], tf.int64),
      "is_real_example": tf.io.FixedLenFeature([], tf.int64),
  }

  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.to_int32(t)
      example[name] = t
    return example

  def input_fn(params):
    """The actual input function."""
    #batch_size = params["batch_size"]
    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      # Shard first so each rank sees a disjoint slice, then repeat/shuffle.
      if hvd is not None: d = d.shard(hvd.size(), hvd.rank())
      d = d.repeat()
      d = d.shuffle(buffer_size=100)
    d = d.apply(
        tf.contrib.data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))
    return d

  return input_fn
  406. def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  407. """Truncates a sequence pair in place to the maximum length."""
  408. # This is a simple heuristic which will always truncate the longer sequence
  409. # one token at a time. This makes more sense than truncating an equal percent
  410. # of tokens from each, since if one sequence is very short then each token
  411. # that's truncated likely contains more information than a longer sequence.
  412. while True:
  413. total_length = len(tokens_a) + len(tokens_b)
  414. if total_length <= max_length:
  415. break
  416. if len(tokens_a) > len(tokens_b):
  417. tokens_a.pop()
  418. else:
  419. tokens_b.pop()
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings):
  """Creates a classification model.

  Builds a BERT encoder, takes its pooled [CLS] output, and adds a single
  dense softmax classification layer on top.

  Args:
    bert_config: `modeling.BertConfig` for the encoder.
    is_training: bool; enables dropout on the pooled output when True.
    input_ids / input_mask / segment_ids: int32 feature tensors.
    labels: int32 tensor of label ids.
    num_labels: number of output classes.
    use_one_hot_embeddings: forwarded to `modeling.BertModel`.

  Returns:
    Tuple `(loss, per_example_loss, logits, probabilities)` where `loss` is
    the mean cross-entropy over the batch.
  """
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  # In the demo, we are doing a simple classification task on the entire
  # segment.
  #
  # If you want to use the token-level output, use model.get_sequence_output()
  # instead.
  output_layer = model.get_pooled_output()

  hidden_size = output_layer.shape[-1].value

  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):
    if is_training:
      # I.e., 0.1 dropout
      output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    probabilities = tf.nn.softmax(logits, axis=-1)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    # Standard softmax cross-entropy written out explicitly.
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

    return (loss, per_example_loss, logits, probabilities)
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate=None,
                     num_train_steps=None, num_warmup_steps=None,
                     use_one_hot_embeddings=False, hvd=None, amp=False):
  """Returns `model_fn` closure for TPUEstimator.

  Args:
    bert_config: `modeling.BertConfig` for the encoder.
    num_labels: number of output classes.
    init_checkpoint: optional checkpoint to warm-start from.
    learning_rate / num_train_steps / num_warmup_steps: optimizer schedule
      parameters; used only in TRAIN mode.
    use_one_hot_embeddings: forwarded to `create_model`.
    hvd: optional Horovod module for multi-GPU training.
    amp: whether automatic mixed precision is enabled; in EVAL/PREDICT modes
      the graph rewrite must still be triggered explicitly (see below).
  """

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.compat.v1.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    is_real_example = None
    if "is_real_example" in features:
      is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
      # No padding marker in the features: treat every example as real.
      is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (total_loss, per_example_loss, logits, probabilities) = create_model(
        bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
        num_labels, use_one_hot_embeddings)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    # Only rank 0 restores from the checkpoint when running under Horovod;
    # the other ranks get the weights via the broadcast hook set up in main.
    if init_checkpoint and (hvd is None or hvd.rank() == 0):
      (assignment_map, initialized_variable_names
       ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)

      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.compat.v1.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                                init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, hvd, False, amp)

      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.EVAL:
      dummy_op = tf.no_op()
      # Need to call mixed precision graph rewrite if fp16 to enable graph rewrite
      if amp:
        loss_scaler = tf.train.experimental.FixedLossScale(1)
        dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)

      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        # Padding examples carry weight 0 and so don't affect the metrics.
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

      eval_metric_ops = metric_fn(per_example_loss, label_ids, logits, is_real_example)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metric_ops=eval_metric_ops)
    else:
      # PREDICT mode.
      dummy_op = tf.no_op()
      # Need to call mixed precision graph rewrite if fp16 to enable graph rewrite
      if amp:
        dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimization.LAMBOptimizer(learning_rate=0.0))
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode, predictions={"probabilities": probabilities})
    return output_spec

  return model_fn
# This function is not used by this file but is still used by the Colab and
# people who depend on it.
def input_fn_builder(features, seq_length, is_training, drop_remainder):
  """Creates an `input_fn` closure to be passed to TPUEstimator.

  Unlike `file_based_input_fn_builder`, this variant keeps all features in
  memory as constants rather than reading a TFRecord file.

  Args:
    features: list of `InputFeatures` to serve.
    seq_length: fixed sequence length of each feature.
    is_training: if True, the dataset repeats indefinitely and is shuffled.
    drop_remainder: whether to drop the final partial batch.

  Returns:
    An `input_fn(params)` callable; `params["batch_size"]` sets the batch size.
  """

  # Transpose the list of per-example features into per-field lists so they
  # can be turned into single constant tensors below.
  all_input_ids = []
  all_input_mask = []
  all_segment_ids = []
  all_label_ids = []

  for feature in features:
    all_input_ids.append(feature.input_ids)
    all_input_mask.append(feature.input_mask)
    all_segment_ids.append(feature.segment_ids)
    all_label_ids.append(feature.label_id)

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    num_examples = len(features)

    # This is for demo purposes and does NOT scale to large data sets. We do
    # not use Dataset.from_generator() because that uses tf.py_func which is
    # not TPU compatible. The right way to load data is with TFRecordReader.
    d = tf.data.Dataset.from_tensor_slices({
        "input_ids":
            tf.constant(
                all_input_ids, shape=[num_examples, seq_length],
                dtype=tf.int32),
        "input_mask":
            tf.constant(
                all_input_mask,
                shape=[num_examples, seq_length],
                dtype=tf.int32),
        "segment_ids":
            tf.constant(
                all_segment_ids,
                shape=[num_examples, seq_length],
                dtype=tf.int32),
        "label_ids":
            tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32),
    })

    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)

    d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
    return d

  return input_fn
  573. # This function is not used by this file but is still used by the Colab and
  574. # people who depend on it.
  575. def convert_examples_to_features(examples, label_list, max_seq_length,
  576. tokenizer):
  577. """Convert a set of `InputExample`s to a list of `InputFeatures`."""
  578. features = []
  579. for (ex_index, example) in enumerate(examples):
  580. if ex_index % 10000 == 0:
  581. tf.compat.v1.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
  582. feature = convert_single_example(ex_index, example, label_list,
  583. max_seq_length, tokenizer)
  584. features.append(feature)
  585. return features
  586. def main(_):
  587. setup_xla_flags()
  588. tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  589. dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)
  590. if FLAGS.horovod:
  591. hvd.init()
  592. processors = {
  593. "chemprot": BioBERTChemprotProcessor,
  594. 'mednli': MedNLIProcessor,
  595. }
  596. tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
  597. FLAGS.init_checkpoint)
  598. if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
  599. raise ValueError(
  600. "At least one of `do_train`, `do_eval` or `do_predict' must be True.")
  601. bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  602. if FLAGS.max_seq_length > bert_config.max_position_embeddings:
  603. raise ValueError(
  604. "Cannot use sequence length %d because the BERT model "
  605. "was only trained up to sequence length %d" %
  606. (FLAGS.max_seq_length, bert_config.max_position_embeddings))
  607. tf.io.gfile.makedirs(FLAGS.output_dir)
  608. task_name = FLAGS.task_name.lower()
  609. if task_name not in processors:
  610. raise ValueError("Task not found: %s" % (task_name))
  611. processor = processors[task_name]()
  612. label_list = processor.get_labels()
  613. tokenizer = tokenization.FullTokenizer(
  614. vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
# NOTE(review): is_per_host is not used anywhere later in this chunk -- it
# looks like a leftover from the TPU variant of this script; confirm before
# removing.
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
# Defaults for single-process execution; overridden below under Horovod.
master_process = True
training_hooks = []
global_batch_size = FLAGS.train_batch_size
hvd_rank = 0
config = tf.compat.v1.ConfigProto()
if FLAGS.horovod:
# Multi-worker run: scale the effective (global) batch size by the number
# of workers and pin each process to its own local GPU.
global_batch_size = FLAGS.train_batch_size * hvd.size()
master_process = (hvd.rank() == 0)
hvd_rank = hvd.rank()
config.gpu_options.visible_device_list = str(hvd.local_rank())
if hvd.size() > 1:
# Broadcast rank 0's initial variables so all workers start in sync.
training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
if FLAGS.use_xla:
config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
if FLAGS.amp:
# Presumably required by the mixed-precision graph rewrite -- TODO confirm.
tf.enable_resource_variables()
# Only the master process writes checkpoints/summaries (model_dir=None for
# the others); keep_checkpoint_max=1 retains just the newest checkpoint.
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir if master_process else None,
session_config=config,
save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
keep_checkpoint_max=1)
if master_process:
# NOTE(review): "Configuaration" is a typo in this log message (runtime
# string -- left untouched here).
tf.compat.v1.logging.info("***** Configuaration *****")
# NOTE(review): FLAGS.__flags reaches into absl's private internals; the
# values themselves are read via getattr on the public FLAGS object.
for key in FLAGS.__flags.keys():
tf.compat.v1.logging.info(' {}: {}'.format(key, getattr(FLAGS, key)))
tf.compat.v1.logging.info("**************************")
train_examples = None
num_train_steps = None
num_warmup_steps = None
# Timing hook; its .total_time and .skipped fields feed the throughput
# report printed after training.
training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank))
if FLAGS.do_train:
train_examples = processor.get_train_examples(FLAGS.data_dir)
# Step count is derived from the *global* batch size so a multi-GPU run
# still covers num_train_epochs passes over the data.
num_train_steps = int(
len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
# Single-process default: one shard covering every training example.
start_index = 0
end_index = len(train_examples)
tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]
if FLAGS.horovod:
# One TFRecord shard per rank. Examples are split as evenly as possible:
# the first `remainder` ranks each take one extra example.
tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
num_examples_per_rank = len(train_examples) // hvd.size()
remainder = len(train_examples) % hvd.size()
if hvd.rank() < remainder:
start_index = hvd.rank() * (num_examples_per_rank+1)
end_index = start_index + num_examples_per_rank + 1
else:
start_index = hvd.rank() * num_examples_per_rank + remainder
end_index = start_index + (num_examples_per_rank)
# Build the Estimator model_fn. Under Horovod the base learning rate is
# scaled linearly with the number of workers (matching the scaled global
# batch size above).
model_fn = model_fn_builder(
bert_config=bert_config,
num_labels=len(label_list),
init_checkpoint=FLAGS.init_checkpoint,
learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate * hvd.size(),
num_train_steps=num_train_steps,
num_warmup_steps=num_warmup_steps,
use_one_hot_embeddings=False,
hvd=None if not FLAGS.horovod else hvd,
amp=FLAGS.amp)
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config)
if FLAGS.do_train:
# Serialize only this rank's slice of the examples into its own shard.
file_based_convert_examples_to_features(
train_examples[start_index:end_index], label_list, FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank])
tf.compat.v1.logging.info("***** Running training *****")
tf.compat.v1.logging.info(" Num examples = %d", len(train_examples))
tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
tf.compat.v1.logging.info(" Num steps = %d", num_train_steps)
train_input_fn = file_based_input_fn_builder(
input_file=tmp_filenames,
batch_size=FLAGS.train_batch_size,
seq_length=FLAGS.max_seq_length,
is_training=True,
drop_remainder=True,
hvd=None if not FLAGS.horovod else hvd)
train_start_time = time.time()
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=training_hooks)
train_time_elapsed = time.time() - train_start_time
# training_hooks[-1] is the LogTrainRunHook appended above; total_time
# excludes startup overhead and `skipped` counts the warmup steps it
# left out of that measurement.
train_time_wo_overhead = training_hooks[-1].total_time
avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
ss_sentences_per_second = (num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead
if master_process:
tf.compat.v1.logging.info("-----------------------------")
tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
num_train_steps * global_batch_size)
tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
(num_train_steps - training_hooks[-1].skipped) * global_batch_size)
tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
tf.compat.v1.logging.info("-----------------------------")
# Evaluation and prediction run only on the master process (single GPU).
if FLAGS.do_eval and master_process:
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
num_actual_eval_examples = len(eval_examples)
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
file_based_convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
tf.compat.v1.logging.info("***** Running evaluation *****")
# NOTE(review): the "padding" count is always 0 in this GPU path --
# eval_examples is never padded here (that was a TPU-only concern).
tf.compat.v1.logging.info(" Num examples = %d (%d actual, %d padding)",
len(eval_examples), num_actual_eval_examples,
len(eval_examples) - num_actual_eval_examples)
tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
# This tells the estimator to run through the entire set.
eval_steps = None
eval_drop_remainder = False
eval_input_fn = file_based_input_fn_builder(
input_file=eval_file,
batch_size=FLAGS.eval_batch_size,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=eval_drop_remainder)
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
# Log each metric and mirror it into eval_results.txt in the output dir.
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
tf.compat.v1.logging.info("***** Eval results *****")
for key in sorted(result.keys()):
tf.compat.v1.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if FLAGS.do_predict and master_process:
predict_examples = processor.get_test_examples(FLAGS.data_dir)
num_actual_predict_examples = len(predict_examples)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
file_based_convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer,
predict_file)
tf.compat.v1.logging.info("***** Running prediction*****")
# NOTE(review): as in the eval branch, the "padding" count is always 0 here.
tf.compat.v1.logging.info(" Num examples = %d (%d actual, %d padding)",
len(predict_examples), num_actual_predict_examples,
len(predict_examples) - num_actual_predict_examples)
tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
predict_drop_remainder = False
predict_input_fn = file_based_input_fn_builder(
input_file=predict_file,
batch_size=FLAGS.predict_batch_size,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=predict_drop_remainder)
# LogEvalRunHook records per-batch wall-clock times in .time_list for the
# latency statistics computed below.
eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
eval_start_time = time.time()
# Write one tab-separated row of class probabilities per example.
output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
with tf.io.gfile.GFile(output_predict_file, "w") as writer:
num_written_lines = 0
tf.compat.v1.logging.info("***** Predict results *****")
for prediction in estimator.predict(input_fn=predict_input_fn, hooks=eval_hooks,
yield_single_examples=True):
probabilities = prediction["probabilities"]
output_line = "\t".join(
str(class_probability)
for class_probability in probabilities) + "\n"
writer.write(output_line)
num_written_lines += 1
# NOTE(review): `assert` is stripped under python -O; a raise would be a
# more robust sanity check that every example produced a prediction.
assert num_written_lines == num_actual_predict_examples
eval_time_elapsed = time.time() - eval_start_time
time_list = eval_hooks[-1].time_list
# Sort batch times ascending so slice-then-max below yields percentile
# latencies (cf_XX = max of the fastest XX% of batches).
time_list.sort()
# Removing outliers (init/warmup) in throughput computation.
eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
num_sentences = (int(len(time_list) * 0.99)) * FLAGS.predict_batch_size
avg = np.mean(time_list)
cf_50 = max(time_list[:int(len(time_list) * 0.50)])
cf_90 = max(time_list[:int(len(time_list) * 0.90)])
cf_95 = max(time_list[:int(len(time_list) * 0.95)])
cf_99 = max(time_list[:int(len(time_list) * 0.99)])
cf_100 = max(time_list[:int(len(time_list) * 1)])
ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead
tf.compat.v1.logging.info("-----------------------------")
tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
eval_hooks[-1].count * FLAGS.predict_batch_size)
tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
num_sentences)
tf.compat.v1.logging.info("Summary Inference Statistics")
tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
tf.compat.v1.logging.info("-----------------------------")
  799. if __name__ == "__main__":
  800. flags.mark_flag_as_required("data_dir")
  801. flags.mark_flag_as_required("task_name")
  802. flags.mark_flag_as_required("vocab_file")
  803. flags.mark_flag_as_required("bert_config_file")
  804. flags.mark_flag_as_required("output_dir")
  805. tf.compat.v1.app.run()