run_pretraining.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721
  1. # coding=utf-8
  2. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Run masked LM/next sentence masked_lm pre-training for BERT."""
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import os
  21. import time
  22. import modeling
  23. import optimization
  24. import tensorflow as tf
  25. import glob
  26. from utils.utils import LogEvalRunHook, setup_xla_flags
  27. import utils.dllogger_class
  28. from utils.gpu_affinity import set_affinity
  29. from dllogger import Verbosity
  30. from tensorflow.core.protobuf import rewriter_config_pb2
  31. flags = tf.flags
  32. FLAGS = flags.FLAGS
  33. ## Required parameters
  34. flags.DEFINE_string(
  35. "bert_config_file", None,
  36. "The config json file corresponding to the pre-trained BERT model. "
  37. "This specifies the model architecture.")
  38. flags.DEFINE_string(
  39. "input_files_dir", None,
  40. "Directory with input files, comma separated or single directory.")
  41. flags.DEFINE_string(
  42. "eval_files_dir", None,
  43. "Directory with eval files, comma separated or single directory. ")
  44. flags.DEFINE_string(
  45. "output_dir", None,
  46. "The output directory where the model checkpoints will be written.")
  47. ## Other parameters
  48. flags.DEFINE_string(
  49. "dllog_path", "/results/bert_dllog.json",
  50. "filename where dllogger writes to")
  51. flags.DEFINE_string(
  52. "init_checkpoint", None,
  53. "Initial checkpoint (usually from a pre-trained BERT model).")
  54. flags.DEFINE_string(
  55. "optimizer_type", "lamb",
  56. "Optimizer used for training - LAMB or ADAM")
  57. flags.DEFINE_integer(
  58. "max_seq_length", 512,
  59. "The maximum total input sequence length after WordPiece tokenization. "
  60. "Sequences longer than this will be truncated, and sequences shorter "
  61. "than this will be padded. Must match data generation.")
  62. flags.DEFINE_integer(
  63. "max_predictions_per_seq", 80,
  64. "Maximum number of masked LM predictions per sequence. "
  65. "Must match data generation.")
  66. flags.DEFINE_bool("do_train", False, "Whether to run training.")
  67. flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
  68. flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
  69. flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
  70. flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
  71. flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
  72. flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
  73. flags.DEFINE_integer("save_checkpoints_steps", 1000,
  74. "How often to save the model checkpoint.")
  75. flags.DEFINE_integer("display_loss_steps", 1,
  76. "How often to print loss")
  77. flags.DEFINE_integer("iterations_per_loop", 1000,
  78. "How many steps to make in each estimator call.")
  79. flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
  80. flags.DEFINE_integer("num_accumulation_steps", 1,
  81. "Number of accumulation steps before gradient update."
  82. "Global batch size = num_accumulation_steps * train_batch_size")
  83. flags.DEFINE_bool("allreduce_post_accumulation", False, "Whether to all reduce after accumulation of N steps or after each step")
  84. flags.DEFINE_bool(
  85. "verbose_logging", False,
  86. "If true, all of the trainable parameters are printed")
  87. flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
  88. flags.DEFINE_bool("report_loss", True, "Whether to report total loss during training.")
  89. flags.DEFINE_bool("manual_fp16", False, "Whether to use fp32 or fp16 arithmetic on GPU. "
  90. "Manual casting is done instead of using AMP")
  91. flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")
  92. flags.DEFINE_bool("use_xla", True, "Whether to enable XLA JIT compilation.")
  93. flags.DEFINE_integer("init_loss_scale", 2**32, "Initial value of loss scale if mixed precision training")
  94. # report samples/sec, total loss and learning rate during training
  95. class _LogSessionRunHook(tf.estimator.SessionRunHook):
  96. def __init__(self, global_batch_size, num_accumulation_steps, dllogging, display_every=10,
  97. save_ckpt_steps=1000, report_loss=True, hvd_rank=-1):
  98. self.global_batch_size = global_batch_size
  99. self.display_every = display_every
  100. self.save_ckpt_steps = save_ckpt_steps
  101. self.hvd_rank = hvd_rank
  102. self.num_accumulation_steps = num_accumulation_steps
  103. self.dllogging = dllogging
  104. self.report_loss = report_loss
  105. def after_create_session(self, session, coord):
  106. self.elapsed_secs = 0.0 #elapsed seconds between every print
  107. self.count = 0 # number of global steps between every print
  108. self.all_count = 0 #number of steps (including accumulation) between every print
  109. self.loss = 0.0 # accumulation of loss in each step between every print
  110. self.total_time = 0.0 # total time taken to train (excluding warmup + ckpt saving steps)
  111. self.step_time = 0.0 # time taken per step
  112. self.init_global_step = session.run(tf.train.get_global_step()) # training starts at init_global_step
  113. self.skipped = 0
  114. self.final_loss = 0
  115. def before_run(self, run_context):
  116. self.t0 = time.time()
  117. if self.num_accumulation_steps <= 1:
  118. if FLAGS.manual_fp16 or FLAGS.amp:
  119. return tf.estimator.SessionRunArgs(
  120. fetches=['step_update:0', 'total_loss:0',
  121. 'learning_rate:0', 'nsp_loss:0',
  122. 'mlm_loss:0', 'loss_scale:0'])
  123. else:
  124. return tf.estimator.SessionRunArgs(
  125. fetches=['step_update:0', 'total_loss:0',
  126. 'learning_rate:0', 'nsp_loss:0',
  127. 'mlm_loss:0'])
  128. else:
  129. if FLAGS.manual_fp16 or FLAGS.amp:
  130. return tf.estimator.SessionRunArgs(
  131. fetches=['step_update:0', 'update_step:0', 'total_loss:0',
  132. 'learning_rate:0', 'nsp_loss:0',
  133. 'mlm_loss:0', 'loss_scale:0'])
  134. else:
  135. return tf.estimator.SessionRunArgs(
  136. fetches=['step_update:0', 'update_step:0', 'total_loss:0',
  137. 'learning_rate:0', 'nsp_loss:0',
  138. 'mlm_loss:0'])
  139. def after_run(self, run_context, run_values):
  140. run_time = time.time() - self.t0
  141. if self.num_accumulation_steps <=1:
  142. if FLAGS.manual_fp16 or FLAGS.amp:
  143. self.global_step, total_loss, lr, nsp_loss, mlm_loss, loss_scaler = run_values.results
  144. else:
  145. self.global_step, total_loss, lr, nsp_loss, mlm_loss = run_values. \
  146. results
  147. update_step = True
  148. else:
  149. if FLAGS.manual_fp16 or FLAGS.amp:
  150. self.global_step, update_step, total_loss, lr, nsp_loss, mlm_loss, loss_scaler = run_values.results
  151. else:
  152. self.global_step, update_step, total_loss, lr, nsp_loss, mlm_loss = run_values.\
  153. results
  154. self.elapsed_secs += run_time
  155. self.step_time += run_time
  156. print_step = self.global_step + 1 # One-based index for printing.
  157. self.loss += total_loss
  158. self.all_count += 1
  159. if update_step:
  160. self.count += 1
  161. # Removing first six steps after every checkpoint save from timing
  162. if (self.global_step - self.init_global_step) % self.save_ckpt_steps < 6:
  163. print("Skipping time record for ", self.global_step, " due to checkpoint-saving/warmup overhead")
  164. self.skipped += 1
  165. else:
  166. self.total_time += self.step_time
  167. self.step_time = 0.0 #Reset Step Time
  168. if (print_step == 1 or print_step % self.display_every == 0):
  169. dt = self.elapsed_secs / self.count
  170. sent_per_sec = self.global_batch_size / dt
  171. avg_loss_step = self.loss / self.all_count
  172. if self.hvd_rank >= 0 and FLAGS.report_loss:
  173. if FLAGS.manual_fp16 or FLAGS.amp:
  174. self.dllogging.logger.log(step=(print_step),
  175. data={"Rank": int(self.hvd_rank), "throughput_train": float(sent_per_sec),
  176. "mlm_loss":float(mlm_loss), "nsp_loss":float(nsp_loss),
  177. "total_loss":float(total_loss), "avg_loss_step":float(avg_loss_step),
  178. "learning_rate": str(lr), "loss_scaler":int(loss_scaler)},
  179. verbosity=Verbosity.DEFAULT)
  180. else:
  181. self.dllogging.logger.log(step=int(print_step),
  182. data={"Rank": int(self.hvd_rank), "throughput_train": float(sent_per_sec),
  183. "mlm_loss":float(mlm_loss), "nsp_loss":float(nsp_loss),
  184. "total_loss":float(total_loss), "avg_loss_step":float(avg_loss_step),
  185. "learning_rate": str(lr)},
  186. verbosity=Verbosity.DEFAULT)
  187. else:
  188. if FLAGS.manual_fp16 or FLAGS.amp:
  189. self.dllogging.logger.log(step=int(print_step),
  190. data={"throughput_train": float(sent_per_sec),
  191. "mlm_loss":float(mlm_loss), "nsp_loss":float(nsp_loss),
  192. "total_loss":float(total_loss), "avg_loss_step":float(avg_loss_step),
  193. "learning_rate": str(lr), "loss_scaler":int(loss_scaler)},
  194. verbosity=Verbosity.DEFAULT)
  195. else:
  196. self.dllogging.logger.log(step=int(print_step),
  197. data={"throughput_train": float(sent_per_sec),
  198. "mlm_loss":float(mlm_loss), "nsp_loss":float(nsp_loss),
  199. "total_loss":float(total_loss), "avg_loss_step":float(avg_loss_step),
  200. "learning_rate": str(lr)},
  201. verbosity=Verbosity.DEFAULT)
  202. self.elapsed_secs = 0.0
  203. self.count = 0
  204. self.loss = 0.0
  205. self.all_count = 0
  206. self.final_loss = avg_loss_step
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps,
                     use_one_hot_embeddings, hvd=None):
    """Returns the `model_fn` closure passed to `tf.estimator.Estimator`.

    (The closure follows the TPUEstimator-style signature of the original
    Google BERT code, but this file uses it with a plain Estimator.)

    Args:
      bert_config: `modeling.BertConfig` describing the architecture.
      init_checkpoint: Optional checkpoint path to initialize variables from.
      learning_rate: Peak learning rate forwarded to
        `optimization.create_optimizer`.
      num_train_steps: Total optimizer steps, forwarded to the optimizer.
      num_warmup_steps: Warmup steps, forwarded to the optimizer.
      use_one_hot_embeddings: Forwarded to `modeling.BertModel`.
      hvd: The `horovod.tensorflow` module, or None when not using Horovod.

    Returns:
      A `model_fn(features, labels, mode, params)` that builds the pre-training
      graph and returns a `tf.estimator.EstimatorSpec` for TRAIN or EVAL.
    """
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds BERT plus the masked-LM and next-sentence heads.

        Raises:
          ValueError: if `mode` is neither TRAIN nor EVAL.
        """
        tf.compat.v1.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        # With --manual_fp16 the encoder computes in float16; the heads cast
        # their inputs back to float32.
        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
            compute_type=tf.float16 if FLAGS.manual_fp16 else tf.float32)
        (masked_lm_loss,
         masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(), model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids,
             masked_lm_weights)
        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)
        # These op names are fetched by _LogSessionRunHook ('mlm_loss:0',
        # 'nsp_loss:0', 'total_loss:0'); keep them stable.
        masked_lm_loss = tf.identity(masked_lm_loss, name="mlm_loss")
        next_sentence_loss = tf.identity(next_sentence_loss, name="nsp_loss")
        total_loss = masked_lm_loss + next_sentence_loss
        total_loss = tf.identity(total_loss, name='total_loss')
        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        # Only rank 0 (or a non-Horovod run) initializes from the checkpoint.
        # Presumably the other ranks get the weights through the
        # BroadcastGlobalVariablesHook that main() registers.
        if init_checkpoint and (hvd is None or hvd.rank() == 0):
            print("Loading checkpoint", init_checkpoint)
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        if FLAGS.verbose_logging:
            tf.compat.v1.logging.info("**** Trainable Variables ****")
            for var in tvars:
                init_string = ""
                if var.name in initialized_variable_names:
                    init_string = ", *INIT_FROM_CKPT*"
                tf.compat.v1.logging.info(" %d :: name = %s, shape = %s%s", 0 if hvd is None else hvd.rank(), var.name, var.shape,
                                          init_string)
        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                hvd, FLAGS.manual_fp16, FLAGS.amp, FLAGS.num_accumulation_steps, FLAGS.optimizer_type, FLAGS.allreduce_post_accumulation, FLAGS.init_loss_scale)
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op)
        elif mode == tf.estimator.ModeKeys.EVAL:
            def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                          masked_lm_weights, next_sentence_example_loss,
                          next_sentence_log_probs, next_sentence_labels):
                """Computes the loss and accuracy of the model.

                Masked-LM accuracy and mean loss are weighted by
                `masked_lm_weights`, so padded prediction slots are ignored.
                Next-sentence metrics are unweighted.

                Returns:
                  Dict of `tf.metrics` (value, update_op) pairs.
                """
                # Flatten to one row per predicted token.
                masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                                 [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(
                    masked_lm_log_probs, axis=-1, output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)
                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(
                    next_sentence_log_probs, axis=-1, output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels, predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)
                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }
            eval_metric_ops = metric_fn(
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            )
            output_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=eval_metric_ops)
        else:
            # PREDICT (or any other mode) is not implemented for pre-training.
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
        return output_spec
    return model_fn
  314. def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
  315. label_ids, label_weights):
  316. """Get loss and log probs for the masked LM."""
  317. input_tensor = gather_indexes(input_tensor, positions)
  318. with tf.variable_scope("cls/predictions"):
  319. # We apply one more non-linear transformation before the output layer.
  320. # This matrix is not used after pre-training.
  321. with tf.variable_scope("transform"):
  322. input_tensor = tf.layers.dense(
  323. input_tensor,
  324. units=bert_config.hidden_size,
  325. activation=modeling.get_activation(bert_config.hidden_act),
  326. kernel_initializer=modeling.create_initializer(
  327. bert_config.initializer_range))
  328. input_tensor = modeling.layer_norm(input_tensor)
  329. # The output weights are the same as the input embeddings, but there is
  330. # an output-only bias for each token.
  331. output_bias = tf.get_variable(
  332. "output_bias",
  333. shape=[bert_config.vocab_size],
  334. initializer=tf.zeros_initializer())
  335. logits = tf.matmul(tf.cast(input_tensor, tf.float32), output_weights, transpose_b=True)
  336. logits = tf.nn.bias_add(logits, output_bias)
  337. log_probs = tf.nn.log_softmax(logits, axis=-1)
  338. label_ids = tf.reshape(label_ids, [-1])
  339. label_weights = tf.reshape(label_weights, [-1])
  340. one_hot_labels = tf.one_hot(
  341. label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
  342. # The `positions` tensor might be zero-padded (if the sequence is too
  343. # short to have the maximum number of predictions). The `label_weights`
  344. # tensor has a value of 1.0 for every real prediction and 0.0 for the
  345. # padding predictions.
  346. per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
  347. numerator = tf.reduce_sum(label_weights * per_example_loss)
  348. denominator = tf.reduce_sum(label_weights) + 1e-5
  349. loss = numerator / denominator
  350. return (loss, per_example_loss, log_probs)
  351. def get_next_sentence_output(bert_config, input_tensor, labels):
  352. """Get loss and log probs for the next sentence prediction."""
  353. # Simple binary classification. Note that 0 is "next sentence" and 1 is
  354. # "random sentence". This weight matrix is not used after pre-training.
  355. with tf.variable_scope("cls/seq_relationship"):
  356. output_weights = tf.get_variable(
  357. "output_weights",
  358. shape=[2, bert_config.hidden_size],
  359. initializer=modeling.create_initializer(bert_config.initializer_range))
  360. output_bias = tf.get_variable(
  361. "output_bias", shape=[2], initializer=tf.zeros_initializer())
  362. logits = tf.matmul(tf.cast(input_tensor, tf.float32), output_weights, transpose_b=True)
  363. logits = tf.nn.bias_add(logits, output_bias)
  364. log_probs = tf.nn.log_softmax(logits, axis=-1)
  365. labels = tf.reshape(labels, [-1])
  366. one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
  367. per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
  368. loss = tf.reduce_mean(per_example_loss)
  369. return (loss, per_example_loss, log_probs)
  370. def gather_indexes(sequence_tensor, positions):
  371. """Gathers the vectors at the specific positions over a minibatch."""
  372. sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  373. batch_size = sequence_shape[0]
  374. seq_length = sequence_shape[1]
  375. width = sequence_shape[2]
  376. flat_offsets = tf.reshape(
  377. tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  378. flat_positions = tf.reshape(positions + flat_offsets, [-1])
  379. flat_sequence_tensor = tf.reshape(sequence_tensor,
  380. [batch_size * seq_length, width])
  381. output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  382. return output_tensor
  383. def input_fn_builder(input_files,
  384. batch_size,
  385. max_seq_length,
  386. max_predictions_per_seq,
  387. is_training,
  388. num_cpu_threads=4,
  389. hvd=None):
  390. """Creates an `input_fn` closure to be passed to Estimator."""
  391. def input_fn():
  392. """The actual input function."""
  393. name_to_features = {
  394. "input_ids":
  395. tf.io.FixedLenFeature([max_seq_length], tf.int64),
  396. "input_mask":
  397. tf.io.FixedLenFeature([max_seq_length], tf.int64),
  398. "segment_ids":
  399. tf.io.FixedLenFeature([max_seq_length], tf.int64),
  400. "masked_lm_positions":
  401. tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
  402. "masked_lm_ids":
  403. tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
  404. "masked_lm_weights":
  405. tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
  406. "next_sentence_labels":
  407. tf.io.FixedLenFeature([1], tf.int64),
  408. }
  409. # For training, we want a lot of parallel reading and shuffling.
  410. # For eval, we want no shuffling and parallel reading doesn't matter.
  411. if is_training:
  412. d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
  413. if hvd is not None: d = d.shard(hvd.size(), hvd.rank())
  414. d = d.repeat()
  415. d = d.shuffle(buffer_size=len(input_files))
  416. # `cycle_length` is the number of parallel files that get read.
  417. cycle_length = min(num_cpu_threads, len(input_files))
  418. # `sloppy` mode means that the interleaving is not exact. This adds
  419. # even more randomness to the training pipeline.
  420. d = d.apply(
  421. tf.contrib.data.parallel_interleave(
  422. tf.data.TFRecordDataset,
  423. sloppy=is_training,
  424. cycle_length=cycle_length))
  425. d = d.shuffle(buffer_size=100)
  426. else:
  427. d = tf.data.TFRecordDataset(input_files)
  428. # Since we evaluate for a fixed number of steps we don't want to encounter
  429. # out-of-range exceptions.
  430. d = d.repeat()
  431. # We must `drop_remainder` on training because the TPU requires fixed
  432. # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
  433. # and we *don't* want to drop the remainder, otherwise we wont cover
  434. # every sample.
  435. d = d.apply(
  436. tf.contrib.data.map_and_batch(
  437. lambda record: _decode_record(record, name_to_features),
  438. batch_size=batch_size,
  439. num_parallel_batches=num_cpu_threads,
  440. drop_remainder=True if is_training else False))
  441. return d
  442. return input_fn
  443. def _decode_record(record, name_to_features):
  444. """Decodes a record to a TensorFlow example."""
  445. example = tf.parse_single_example(record, name_to_features)
  446. # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  447. # So cast all int64 to int32.
  448. for name in list(example.keys()):
  449. t = example[name]
  450. if t.dtype == tf.int64:
  451. t = tf.to_int32(t)
  452. example[name] = t
  453. return example
  454. def main(_):
  455. setup_xla_flags()
  456. tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  457. dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)
  458. if not FLAGS.do_train and not FLAGS.do_eval:
  459. raise ValueError("At least one of `do_train` or `do_eval` must be True.")
  460. if FLAGS.horovod:
  461. import horovod.tensorflow as hvd
  462. hvd.init()
  463. bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  464. tf.io.gfile.makedirs(FLAGS.output_dir)
  465. input_files = []
  466. for input_file_dir in FLAGS.input_files_dir.split(","):
  467. input_files.extend(tf.io.gfile.glob(os.path.join(input_file_dir, "*")))
  468. if FLAGS.horovod and len(input_files) < hvd.size():
  469. raise ValueError("Input Files must be sharded")
  470. if FLAGS.amp and FLAGS.manual_fp16:
  471. raise ValueError("AMP and Manual Mixed Precision Training are both activated! Error")
  472. is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  473. config = tf.compat.v1.ConfigProto()
  474. if FLAGS.horovod:
  475. config.gpu_options.visible_device_list = str(hvd.local_rank())
  476. set_affinity(hvd.local_rank())
  477. if hvd.rank() == 0:
  478. tf.compat.v1.logging.info("***** Configuaration *****")
  479. for key in FLAGS.__flags.keys():
  480. tf.compat.v1.logging.info(' {}: {}'.format(key, getattr(FLAGS, key)))
  481. tf.compat.v1.logging.info("**************************")
  482. # config.gpu_options.per_process_gpu_memory_fraction = 0.7
  483. if FLAGS.use_xla:
  484. config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
  485. config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
  486. if FLAGS.amp:
  487. tf.enable_resource_variables()
  488. run_config = tf.estimator.RunConfig(
  489. model_dir=FLAGS.output_dir,
  490. session_config=config,
  491. save_checkpoints_steps=FLAGS.save_checkpoints_steps if not FLAGS.horovod or hvd.rank() == 0 else None,
  492. save_summary_steps=FLAGS.save_checkpoints_steps if not FLAGS.horovod or hvd.rank() == 0 else None,
  493. # This variable controls how often estimator reports examples/sec.
  494. # Default value is every 100 steps.
  495. # When --report_loss is True, we set to very large value to prevent
  496. # default info reporting from estimator.
  497. # Ideally we should set it to None, but that does not work.
  498. log_step_count_steps=10000 if FLAGS.report_loss else 100)
  499. model_fn = model_fn_builder(
  500. bert_config=bert_config,
  501. init_checkpoint=FLAGS.init_checkpoint,
  502. learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate*hvd.size(),
  503. num_train_steps=FLAGS.num_train_steps,
  504. num_warmup_steps=FLAGS.num_warmup_steps,
  505. use_one_hot_embeddings=False,
  506. hvd=None if not FLAGS.horovod else hvd)
  507. estimator = tf.estimator.Estimator(
  508. model_fn=model_fn,
  509. config=run_config)
  510. if FLAGS.do_train:
  511. training_hooks = []
  512. if FLAGS.horovod and hvd.size() > 1:
  513. training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
  514. if (not FLAGS.horovod or hvd.rank() == 0):
  515. global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps if not FLAGS.horovod else FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
  516. log_hook = _LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss)
  517. training_hooks.append(log_hook)
  518. tf.compat.v1.logging.info("***** Running training *****")
  519. tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
  520. train_input_fn = input_fn_builder(
  521. input_files=input_files,
  522. batch_size=FLAGS.train_batch_size,
  523. max_seq_length=FLAGS.max_seq_length,
  524. max_predictions_per_seq=FLAGS.max_predictions_per_seq,
  525. is_training=True,
  526. hvd=None if not FLAGS.horovod else hvd)
  527. train_start_time = time.time()
  528. estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=FLAGS.num_train_steps)
  529. train_time_elapsed = time.time() - train_start_time
  530. if (not FLAGS.horovod or hvd.rank() == 0):
  531. train_time_wo_overhead = training_hooks[-1].total_time
  532. avg_sentences_per_second = FLAGS.num_train_steps * global_batch_size * 1.0 / train_time_elapsed
  533. ss_sentences_per_second = (FLAGS.num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead
  534. tf.compat.v1.logging.info("-----------------------------")
  535. tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
  536. FLAGS.num_train_steps * global_batch_size)
  537. tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
  538. (FLAGS.num_train_steps - training_hooks[-1].skipped) * global_batch_size)
  539. tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
  540. tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
  541. dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
  542. if log_hook.final_loss != 0:
  543. dllogging.logger.log(step=(), data={"total_loss": log_hook.final_loss}, verbosity=Verbosity.DEFAULT)
  544. tf.compat.v1.logging.info("-----------------------------")
  545. if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0):
  546. tf.compat.v1.logging.info("***** Running evaluation *****")
  547. tf.compat.v1.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
  548. eval_files = []
  549. for eval_file_dir in FLAGS.eval_files_dir.split(","):
  550. eval_files.extend(tf.io.gfile.glob(os.path.join(eval_file_dir, "*")))
  551. eval_input_fn = input_fn_builder(
  552. input_files=eval_files,
  553. batch_size=FLAGS.eval_batch_size,
  554. max_seq_length=FLAGS.max_seq_length,
  555. max_predictions_per_seq=FLAGS.max_predictions_per_seq,
  556. is_training=False,
  557. hvd=None if not FLAGS.horovod else hvd)
  558. eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
  559. eval_start_time = time.time()
  560. result = estimator.evaluate(
  561. input_fn=eval_input_fn, steps=FLAGS.max_eval_steps, hooks=eval_hooks)
  562. eval_time_elapsed = time.time() - eval_start_time
  563. time_list = eval_hooks[-1].time_list
  564. time_list.sort()
  565. # Removing outliers (init/warmup) in throughput computation.
  566. eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
  567. num_sentences = (int(len(time_list) * 0.99)) * FLAGS.eval_batch_size
  568. ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead
  569. tf.compat.v1.logging.info("-----------------------------")
  570. tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
  571. eval_hooks[-1].count * FLAGS.eval_batch_size)
  572. tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
  573. num_sentences)
  574. tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
  575. tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
  576. tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
  577. tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
  578. tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
  579. dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
  580. tf.compat.v1.logging.info("-----------------------------")
  581. output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
  582. with tf.io.gfile.GFile(output_eval_file, "w") as writer:
  583. tf.compat.v1.logging.info("***** Eval results *****")
  584. for key in sorted(result.keys()):
  585. tf.compat.v1.logging.info(" %s = %s", key, str(result[key]))
  586. writer.write("%s = %s\n" % (key, str(result[key])))
  587. if __name__ == "__main__":
  588. flags.mark_flag_as_required("input_files_dir")
  589. if FLAGS.do_eval:
  590. flags.mark_flag_as_required("eval_files_dir")
  591. flags.mark_flag_as_required("bert_config_file")
  592. flags.mark_flag_as_required("output_dir")
  593. if FLAGS.use_xla and FLAGS.manual_fp16:
  594. print('WARNING! Combining --use_xla with --manual_fp16 may prevent convergence.')
  595. print(' This warning message will be removed when the underlying')
  596. print(' issues have been fixed and you are running a TF version')
  597. print(' that has that fix.')
  598. tf.compat.v1.app.run()