# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import collections
  21. import csv
  22. import os
  23. import modeling
  24. import optimization
  25. import tokenization
  26. import tensorflow as tf
  27. import horovod.tensorflow as hvd
  28. import time
  29. from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
  30. import utils.dllogger_class
  31. from dllogger import Verbosity
  32. from utils.create_glue_data import *
  33. import numpy as np
  34. import tf_metrics
  35. flags = tf.flags
  36. FLAGS = flags.FLAGS
  37. ## Required parameters
  38. flags.DEFINE_string(
  39. "data_dir", None,
  40. "The input data dir. Should contain the .tsv files (or other data files) "
  41. "for the task.")
  42. flags.DEFINE_string(
  43. "bert_config_file", None,
  44. "The config json file corresponding to the pre-trained BERT model. "
  45. "This specifies the model architecture.")
  46. flags.DEFINE_string("task_name", None, "The name of the task to train.")
  47. flags.DEFINE_string("vocab_file", None,
  48. "The vocabulary file that the BERT model was trained on.")
  49. flags.DEFINE_string(
  50. "output_dir", None,
  51. "The output directory where the model checkpoints will be written.")
  52. ## Other parameters
  53. flags.DEFINE_string(
  54. "dllog_path", "/results/bert_dllog.json",
  55. "filename where dllogger writes to")
  56. flags.DEFINE_string(
  57. "optimizer_type", "lamb",
  58. "Optimizer type : adam or lamb")
  59. flags.DEFINE_string(
  60. "init_checkpoint", None,
  61. "Initial checkpoint (usually from a pre-trained BERT model).")
  62. flags.DEFINE_bool(
  63. "do_lower_case", True,
  64. "Whether to lower case the input text. Should be True for uncased "
  65. "models and False for cased models.")
  66. flags.DEFINE_integer(
  67. "max_seq_length", 128,
  68. "The maximum total input sequence length after WordPiece tokenization. "
  69. "Sequences longer than this will be truncated, and sequences shorter "
  70. "than this will be padded.")
  71. flags.DEFINE_bool("do_train", False, "Whether to run training.")
  72. flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
  73. flags.DEFINE_bool(
  74. "do_predict", False,
  75. "Whether to run the model in inference mode on the test set.")
  76. flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
  77. flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
  78. flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
  79. flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
  80. flags.DEFINE_bool("use_trt", False, "Whether to use TF-TRT")
  81. flags.DEFINE_float("num_train_epochs", 3.0,
  82. "Total number of training epochs to perform.")
  83. flags.DEFINE_float(
  84. "warmup_proportion", 0.1,
  85. "Proportion of training to perform linear learning rate warmup for. "
  86. "E.g., 0.1 = 10% of training.")
  87. flags.DEFINE_integer("save_checkpoints_steps", 1000,
  88. "How often to save the model checkpoint.")
  89. flags.DEFINE_integer("display_loss_steps", 10,
  90. "How often to print loss from estimator")
  91. flags.DEFINE_integer("iterations_per_loop", 1000,
  92. "How many steps to make in each estimator call.")
  93. flags.DEFINE_integer("num_accumulation_steps", 1,
  94. "Number of accumulation steps before gradient update"
  95. "Global batch size = num_accumulation_steps * train_batch_size")
  96. flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")
  97. flags.DEFINE_bool("use_xla", True, "Whether to enable XLA JIT compilation.")
  98. flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
  99. flags.DEFINE_bool(
  100. "verbose_logging", False,
  101. "If true, all of the warnings related to data processing will be printed. "
  102. "A number of warnings are expected for a normal SQuAD evaluation.")
  103. def file_based_input_fn_builder(input_file, batch_size, seq_length, is_training,
  104. drop_remainder, hvd=None):
  105. """Creates an `input_fn` closure to be passed to Estimator."""
  106. name_to_features = {
  107. "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
  108. "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
  109. "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
  110. "label_ids": tf.io.FixedLenFeature([], tf.int64),
  111. }
  112. def _decode_record(record, name_to_features):
  113. """Decodes a record to a TensorFlow example."""
  114. example = tf.parse_single_example(record, name_to_features)
  115. # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  116. # So cast all int64 to int32.
  117. for name in list(example.keys()):
  118. t = example[name]
  119. if t.dtype == tf.int64:
  120. t = tf.to_int32(t)
  121. example[name] = t
  122. return example
  123. def input_fn():
  124. """The actual input function."""
  125. # For training, we want a lot of parallel reading and shuffling.
  126. # For eval, we want no shuffling and parallel reading doesn't matter.
  127. d = tf.data.TFRecordDataset(input_file)
  128. if is_training:
  129. if hvd is not None: d = d.shard(hvd.size(), hvd.rank())
  130. d = d.repeat()
  131. d = d.shuffle(buffer_size=100)
  132. d = d.apply(
  133. tf.contrib.data.map_and_batch(
  134. lambda record: _decode_record(record, name_to_features),
  135. batch_size=batch_size,
  136. drop_remainder=drop_remainder))
  137. return d
  138. return input_fn
  139. def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
  140. labels, num_labels, use_one_hot_embeddings):
  141. """Creates a classification model."""
  142. model = modeling.BertModel(
  143. config=bert_config,
  144. is_training=is_training,
  145. input_ids=input_ids,
  146. input_mask=input_mask,
  147. token_type_ids=segment_ids,
  148. use_one_hot_embeddings=use_one_hot_embeddings,
  149. compute_type=tf.float32)
  150. # In the demo, we are doing a simple classification task on the entire
  151. # segment.
  152. #
  153. # If you want to use the token-level output, use model.get_sequence_output()
  154. # instead.
  155. output_layer = model.get_pooled_output()
  156. hidden_size = output_layer.shape[-1].value
  157. output_weights = tf.get_variable(
  158. "output_weights", [num_labels, hidden_size],
  159. initializer=tf.truncated_normal_initializer(stddev=0.02))
  160. output_bias = tf.get_variable(
  161. "output_bias", [num_labels], initializer=tf.zeros_initializer())
  162. with tf.variable_scope("loss"):
  163. if is_training:
  164. # I.e., 0.1 dropout
  165. output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
  166. logits = tf.matmul(output_layer, output_weights, transpose_b=True)
  167. logits = tf.nn.bias_add(logits, output_bias, name='cls_logits')
  168. probabilities = tf.nn.softmax(logits, axis=-1, name='cls_probabilities')
  169. log_probs = tf.nn.log_softmax(logits, axis=-1)
  170. one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
  171. per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1, name='cls_per_example_loss')
  172. loss = tf.reduce_mean(per_example_loss, name='cls_loss')
  173. return (loss, per_example_loss, logits, probabilities)
def get_frozen_tftrt_model(bert_config, shape, num_labels, use_one_hot_embeddings, init_checkpoint):
    """Builds the classifier graph, restores `init_checkpoint`, freezes it, and
    converts it with TF-TRT for accelerated inference.

    Args:
        bert_config: `modeling.BertConfig` describing the model.
        shape: static shape for the int32 input placeholders (batch, seq_len).
        num_labels: number of output classes.
        use_one_hot_embeddings: forwarded to `create_model`.
        init_checkpoint: checkpoint to restore weights from before freezing.

    Returns:
        The frozen, TRT-converted `GraphDef` (also written to
        frozen_modelTRT.pb in the working directory).
    """
    tf_config = tf.compat.v1.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    # Output tensors to keep when freezing; these must match the names that
    # create_model assigns inside its "loss" variable scope.
    output_node_names = ['loss/cls_loss', 'loss/cls_per_example_loss', 'loss/cls_logits', 'loss/cls_probabilities']
    with tf.Session(config=tf_config) as tf_sess:
        # Placeholders stand in for the Estimator features when re-building
        # the inference graph in this throwaway session.
        input_ids = tf.placeholder(tf.int32, shape, 'input_ids')
        input_mask = tf.placeholder(tf.int32, shape, 'input_mask')
        segment_ids = tf.placeholder(tf.int32, shape, 'segment_ids')
        label_ids = tf.placeholder(tf.int32, (None), 'label_ids')
        create_model(bert_config, False, input_ids, input_mask, segment_ids, label_ids,
                     num_labels, use_one_hot_embeddings)
        # Map pre-trained variables onto the freshly built graph, then run the
        # initializer so every variable holds its checkpoint value.
        tvars = tf.trainable_variables()
        (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        tf_sess.run(tf.global_variables_initializer())
        # NOTE(review): "LOADED!" and the "*NOTTTT..." marker below look like
        # leftover debug output — consider aligning with the logging used in
        # model_fn_builder.
        print("LOADED!")
        tf.compat.v1.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            else:
                init_string = ", *NOTTTTTTTTTTTTTTTTTTTTT"
            tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string)
        # Freeze: bake the restored variable values into constants so the
        # graph can be handed to the TRT converter.
        frozen_graph = tf.graph_util.convert_variables_to_constants(tf_sess,
                                                                    tf_sess.graph.as_graph_def(), output_node_names)
        num_nodes = len(frozen_graph.node)
        print('Converting graph using TensorFlow-TensorRT...')
        from tensorflow.python.compiler.tensorrt import trt_convert as trt
        converter = trt.TrtGraphConverter(
            input_graph_def=frozen_graph,
            # Keep the output nodes out of TRT engines so they stay addressable.
            nodes_blacklist=output_node_names,
            max_workspace_size_bytes=(4096 << 20) - 1000,
            precision_mode="FP16" if FLAGS.amp else "FP32",
            minimum_segment_size=4,
            is_dynamic_op=True,
            maximum_cached_engines=1000
        )
        frozen_graph = converter.convert()

        print('Total node count before and after TF-TRT conversion:',
              num_nodes, '->', len(frozen_graph.node))
        print('TRT node count:',
              len([1 for n in frozen_graph.node if str(n.op) == 'TRTEngineOp']))

        with tf.io.gfile.GFile("frozen_modelTRT.pb", "wb") as f:
            f.write(frozen_graph.SerializeToString())

    return frozen_graph
  220. def model_fn_builder(task_name, bert_config, num_labels, init_checkpoint, learning_rate,
  221. num_train_steps, num_warmup_steps,
  222. use_one_hot_embeddings, hvd=None):
  223. """Returns `model_fn` closure for Estimator."""
  224. def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
  225. """The `model_fn` for Estimator."""
  226. def metric_fn(per_example_loss, label_ids, logits):
  227. predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
  228. if task_name == "cola":
  229. FN, FN_op = tf.metrics.false_negatives(labels=label_ids, predictions=predictions)
  230. FP, FP_op = tf.metrics.false_positives(labels=label_ids, predictions=predictions)
  231. TP, TP_op = tf.metrics.true_positives(labels=label_ids, predictions=predictions)
  232. TN, TN_op = tf.metrics.true_negatives(labels=label_ids, predictions=predictions)
  233. MCC = (TP * TN - FP * FN) / ((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) ** 0.5
  234. MCC_op = tf.group(FN_op, TN_op, TP_op, FP_op, tf.identity(MCC, name="MCC"))
  235. return {"MCC": (MCC, MCC_op)}
  236. elif task_name == "mrpc":
  237. accuracy = tf.metrics.accuracy(
  238. labels=label_ids, predictions=predictions)
  239. loss = tf.metrics.mean(values=per_example_loss)
  240. f1 = tf_metrics.f1(labels=label_ids, predictions=predictions, num_classes=2, pos_indices=[1])
  241. return {
  242. "eval_accuracy": accuracy,
  243. "eval_f1": f1,
  244. "eval_loss": loss,
  245. }
  246. else:
  247. accuracy = tf.metrics.accuracy(
  248. labels=label_ids, predictions=predictions)
  249. loss = tf.metrics.mean(values=per_example_loss)
  250. return {
  251. "eval_accuracy": accuracy,
  252. "eval_loss": loss,
  253. }
  254. tf.compat.v1.logging.info("*** Features ***")
  255. tf.compat.v1.logging.info("*** Features ***")
  256. for name in sorted(features.keys()):
  257. tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
  258. input_ids = features["input_ids"]
  259. input_mask = features["input_mask"]
  260. segment_ids = features["segment_ids"]
  261. label_ids = features["label_ids"]
  262. is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  263. if not is_training and FLAGS.use_trt:
  264. trt_graph = get_frozen_tftrt_model(bert_config, input_ids.shape, num_labels, use_one_hot_embeddings, init_checkpoint)
  265. (total_loss, per_example_loss, logits, probabilities) = tf.import_graph_def(trt_graph,
  266. input_map={'input_ids':input_ids, 'input_mask':input_mask, 'segment_ids':segment_ids, 'label_ids':label_ids},
  267. return_elements=['loss/cls_loss:0', 'loss/cls_per_example_loss:0', 'loss/cls_logits:0', 'loss/cls_probabilities:0'],
  268. name='')
  269. if mode == tf.estimator.ModeKeys.PREDICT:
  270. predictions = {"probabilities": probabilities}
  271. output_spec = tf.estimator.EstimatorSpec(
  272. mode=mode, predictions=predictions)
  273. elif mode == tf.estimator.ModeKeys.EVAL:
  274. eval_metric_ops = metric_fn(per_example_loss, label_ids, logits)
  275. output_spec = tf.estimator.EstimatorSpec(
  276. mode=mode,
  277. loss=total_loss,
  278. eval_metric_ops=eval_metric_ops)
  279. return output_spec
  280. (total_loss, per_example_loss, logits, probabilities) = create_model(
  281. bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
  282. num_labels, use_one_hot_embeddings)
  283. tvars = tf.trainable_variables()
  284. initialized_variable_names = {}
  285. if init_checkpoint and (hvd is None or hvd.rank() == 0):
  286. (assignment_map, initialized_variable_names
  287. ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
  288. tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  289. if FLAGS.verbose_logging:
  290. tf.compat.v1.logging.info("**** Trainable Variables ****")
  291. for var in tvars:
  292. init_string = ""
  293. if var.name in initialized_variable_names:
  294. init_string = ", *INIT_FROM_CKPT*"
  295. tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
  296. init_string)
  297. output_spec = None
  298. if mode == tf.estimator.ModeKeys.TRAIN:
  299. train_op = optimization.create_optimizer(
  300. total_loss, learning_rate, num_train_steps, num_warmup_steps,
  301. hvd, False, FLAGS.amp, FLAGS.num_accumulation_steps, FLAGS.optimizer_type)
  302. output_spec = tf.estimator.EstimatorSpec(
  303. mode=mode,
  304. loss=total_loss,
  305. train_op=train_op)
  306. elif mode == tf.estimator.ModeKeys.EVAL:
  307. dummy_op = tf.no_op()
  308. # Need to call mixed precision graph rewrite if fp16 to enable graph rewrite
  309. if FLAGS.amp:
  310. loss_scaler = tf.train.experimental.FixedLossScale(1)
  311. dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
  312. optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)
  313. eval_metric_ops = metric_fn(per_example_loss, label_ids, logits)
  314. output_spec = tf.estimator.EstimatorSpec(
  315. mode=mode,
  316. loss=total_loss,
  317. eval_metric_ops=eval_metric_ops)
  318. else:
  319. dummy_op = tf.no_op()
  320. # Need to call mixed precision graph rewrite if fp16 to enable graph rewrite
  321. if FLAGS.amp:
  322. dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
  323. optimization.LAMBOptimizer(learning_rate=0.0))
  324. output_spec = tf.estimator.EstimatorSpec(
  325. mode=mode, predictions=probabilities)
  326. return output_spec
  327. return model_fn
  328. # This function is not used by this file but is still used by the Colab and
  329. # people who depend on it.
  330. def input_fn_builder(features, batch_size, seq_length, is_training, drop_remainder, hvd=None):
  331. """Creates an `input_fn` closure to be passed to Estimator."""
  332. all_input_ids = []
  333. all_input_mask = []
  334. all_segment_ids = []
  335. all_label_ids = []
  336. for feature in features:
  337. all_input_ids.append(feature.input_ids)
  338. all_input_mask.append(feature.input_mask)
  339. all_segment_ids.append(feature.segment_ids)
  340. all_label_ids.append(feature.label_id)
  341. def input_fn():
  342. """The actual input function."""
  343. num_examples = len(features)
  344. # This is for demo purposes and does NOT scale to large data sets. We do
  345. # not use Dataset.from_generator() because that uses tf.py_func which is
  346. # not TPU compatible. The right way to load data is with TFRecordReader.
  347. d = tf.data.Dataset.from_tensor_slices({
  348. "input_ids":
  349. tf.constant(
  350. all_input_ids, shape=[num_examples, seq_length],
  351. dtype=tf.int32),
  352. "input_mask":
  353. tf.constant(
  354. all_input_mask,
  355. shape=[num_examples, seq_length],
  356. dtype=tf.int32),
  357. "segment_ids":
  358. tf.constant(
  359. all_segment_ids,
  360. shape=[num_examples, seq_length],
  361. dtype=tf.int32),
  362. "label_ids":
  363. tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32),
  364. })
  365. if is_training:
  366. if hvd is not None: d = d.shard(hvd.size(), hvd.rank())
  367. d = d.repeat()
  368. d = d.shuffle(buffer_size=100)
  369. d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
  370. return d
  371. return input_fn
def main(_):
    """Entry point: builds the Estimator, then runs train / eval / predict
    according to the do_* flags."""
    setup_xla_flags()

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    # dllogger sink for structured throughput/accuracy reporting.
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if FLAGS.horovod:
        hvd.init()

    # GLUE task name -> processor class (provided by utils.create_glue_data).
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    # Positions beyond max_position_embeddings have no trained embedding.
    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.io.gfile.makedirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    # Single-process defaults; overridden below when Horovod is enabled.
    master_process = True
    training_hooks = []
    global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
    hvd_rank = 0

    config = tf.compat.v1.ConfigProto()
    if FLAGS.horovod:
        tf.compat.v1.logging.info("Multi-GPU training with TF Horovod")
        tf.compat.v1.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(), hvd.rank())
        global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
        # Only rank 0 writes checkpoints/summaries and runs eval/predict.
        master_process = (hvd.rank() == 0)
        hvd_rank = hvd.rank()
        # Pin each rank to its own local GPU.
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        if hvd.size() > 1:
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
        if FLAGS.amp:
            tf.enable_resource_variables()
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir if master_process else None,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
        save_summary_steps=FLAGS.save_checkpoints_steps if master_process else None,
        log_step_count_steps=FLAGS.display_loss_steps,
        keep_checkpoint_max=1)

    if master_process:
        tf.compat.v1.logging.info("***** Configuaration *****")
        for key in FLAGS.__flags.keys():
            tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
        tf.compat.v1.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    # LogTrainRunHook is appended last so training_hooks[-1] below always
    # refers to it when reading timing counters.
    training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank, FLAGS.save_checkpoints_steps, num_steps_ignore_xla=25))

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        start_index = 0
        end_index = len(train_examples)
        tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

        if FLAGS.horovod:
            # One TFRecord shard per rank; split examples as evenly as
            # possible, the first `remainder` ranks taking one extra example.
            tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
            num_examples_per_rank = len(train_examples) // hvd.size()
            remainder = len(train_examples) % hvd.size()
            if hvd.rank() < remainder:
                start_index = hvd.rank() * (num_examples_per_rank + 1)
                end_index = start_index + num_examples_per_rank + 1
            else:
                start_index = hvd.rank() * num_examples_per_rank + remainder
                end_index = start_index + (num_examples_per_rank)

    model_fn = model_fn_builder(
        task_name=task_name,
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        # Linear LR scaling with the number of Horovod workers.
        learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate * hvd.size(),
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_one_hot_embeddings=False,
        hvd=None if not FLAGS.horovod else hvd)

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config)

    if FLAGS.do_train:
        # Each rank serializes only its own slice of the training examples.
        file_based_convert_examples_to_features(
            train_examples[start_index:end_index], label_list, FLAGS.max_seq_length, tokenizer, tmp_filenames[hvd_rank])
        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(train_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=tmp_filenames,
            batch_size=FLAGS.train_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            hvd=None if not FLAGS.horovod else hvd)
        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=training_hooks)
        train_time_elapsed = time.time() - train_start_time
        # Wall-clock time minus startup/checkpoint overhead, as measured by
        # the LogTrainRunHook appended above.
        train_time_wo_overhead = training_hooks[-1].total_time
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (training_hooks[-1].count - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                                      num_train_steps * global_batch_size)
            tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                                      (training_hooks[-1].count - training_hooks[-1].skipped) * global_batch_size)
            tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
            tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.do_eval and master_process:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(
            eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

        tf.compat.v1.logging.info("***** Running evaluation *****")
        tf.compat.v1.logging.info("  Num examples = %d", len(eval_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_drop_remainder = False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            batch_size=FLAGS.eval_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        eval_hooks = [LogEvalRunHook(FLAGS.eval_batch_size)]
        eval_start_time = time.time()
        result = estimator.evaluate(input_fn=eval_input_fn, hooks=eval_hooks)

        eval_time_elapsed = time.time() - eval_start_time

        time_list = eval_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation: only the
        # fastest 80% of per-batch timings count toward throughput.
        eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.8)])
        num_sentences = (int(len(time_list) * 0.8)) * FLAGS.eval_batch_size

        avg = np.mean(time_list)
        # Latency percentiles read off the sorted per-batch timings.
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                                  eval_hooks[-1].count * FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                                  num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics on EVAL set")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.eval_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
        tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            tf.compat.v1.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                dllogging.logger.log(step=(), data={key: float(result[key])}, verbosity=Verbosity.DEFAULT)
                tf.compat.v1.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict and master_process:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length, tokenizer,
                                                predict_file)

        tf.compat.v1.logging.info("***** Running prediction*****")
        tf.compat.v1.logging.info("  Num examples = %d", len(predict_examples))
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            batch_size=FLAGS.predict_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        predict_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
        predict_start_time = time.time()

        # Write one tab-separated row of class probabilities per batch item.
        output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
        with tf.io.gfile.GFile(output_predict_file, "w") as writer:
            tf.compat.v1.logging.info("***** Predict results *****")
            for prediction in estimator.predict(input_fn=predict_input_fn, hooks=predict_hooks,
                                                yield_single_examples=False):
                output_line = "\t".join(
                    str(class_probability) for class_probability in prediction) + "\n"
                writer.write(output_line)

        predict_time_elapsed = time.time() - predict_start_time

        time_list = predict_hooks[-1].time_list
        time_list.sort()
        # Removing outliers (init/warmup) in throughput computation.
        predict_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.8)])
        num_sentences = (int(len(time_list) * 0.8)) * FLAGS.predict_batch_size

        avg = np.mean(time_list)
        # Latency percentiles read off the sorted per-batch timings.
        cf_50 = max(time_list[:int(len(time_list) * 0.50)])
        cf_90 = max(time_list[:int(len(time_list) * 0.90)])
        cf_95 = max(time_list[:int(len(time_list) * 0.95)])
        cf_99 = max(time_list[:int(len(time_list) * 0.99)])
        cf_100 = max(time_list[:int(len(time_list) * 1)])
        ss_sentences_per_second = num_sentences * 1.0 / predict_time_wo_overhead

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", predict_time_elapsed,
                                  predict_hooks[-1].count * FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", predict_time_wo_overhead,
                                  num_sentences)
        tf.compat.v1.logging.info("Summary Inference Statistics on TEST SET")
        tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
        tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
        tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
        tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
        tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
        tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")
  608. if __name__ == "__main__":
  609. flags.mark_flag_as_required("data_dir")
  610. flags.mark_flag_as_required("task_name")
  611. flags.mark_flag_as_required("vocab_file")
  612. flags.mark_flag_as_required("bert_config_file")
  613. flags.mark_flag_as_required("output_dir")
  614. tf.compat.v1.app.run()