run_squad.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230
  1. # coding=utf-8
  2. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  3. # Copyright 2018 The Google AI Language Team Authors.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Run BERT on SQuAD 1.1 and SQuAD 2.0."""
  17. from __future__ import absolute_import, division, print_function
  18. import collections
  19. import json
  20. import math
  21. import os
  22. import random
  23. import shutil
  24. import time
  25. import horovod.tensorflow as hvd
  26. import numpy as np
  27. import six
  28. import tensorflow as tf
  29. from tensorflow.python.client import device_lib
  30. import modeling
  31. import optimization
  32. import tokenization
  33. from utils.create_squad_data import *
  34. from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
  35. from utils.gpu_affinity import set_affinity
  36. import utils.dllogger_class
  37. from dllogger import Verbosity
  38. flags = tf.flags
  39. FLAGS = None
def extract_run_squad_flags():
  """Defines every command-line flag used by run_squad and returns FLAGS.

  Required flags (`vocab_file`, `bert_config_file`, `output_dir`) are marked
  as such at the bottom so absl fails fast when they are missing.
  """
  ## Required parameters
  flags.DEFINE_string(
      "bert_config_file", None,
      "The config json file corresponding to the pre-trained BERT model. "
      "This specifies the model architecture.")
  flags.DEFINE_string("vocab_file", None,
                      "The vocabulary file that the BERT model was trained on.")
  flags.DEFINE_string(
      "output_dir", None,
      "The output directory where the model checkpoints will be written.")

  ## Other parameters
  flags.DEFINE_string(
      "dllog_path", "/results/bert_dllog.json",
      "filename where dllogger writes to")
  flags.DEFINE_string("train_file", None,
                      "SQuAD json for training. E.g., train-v1.1.json")
  flags.DEFINE_string(
      "predict_file", None,
      "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
  flags.DEFINE_string(
      "eval_script", None,
      "SQuAD evaluate.py file to compute f1 and exact_match E.g., evaluate-v1.1.py")
  flags.DEFINE_string(
      "init_checkpoint", None,
      "Initial checkpoint (usually from a pre-trained BERT model).")
  flags.DEFINE_bool(
      "do_lower_case", True,
      "Whether to lower case the input text. Should be True for uncased "
      "models and False for cased models.")
  flags.DEFINE_integer(
      "max_seq_length", 384,
      "The maximum total input sequence length after WordPiece tokenization. "
      "Sequences longer than this will be truncated, and sequences shorter "
      "than this will be padded.")
  flags.DEFINE_integer(
      "doc_stride", 128,
      "When splitting up a long document into chunks, how much stride to "
      "take between chunks.")
  flags.DEFINE_integer(
      "max_query_length", 64,
      "The maximum number of tokens for the question. Questions longer than "
      "this will be truncated to this length.")
  flags.DEFINE_bool("do_train", False, "Whether to run training.")
  flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
  flags.DEFINE_integer("train_batch_size", 8, "Total batch size for training.")
  flags.DEFINE_integer("predict_batch_size", 8,
                       "Total batch size for predictions.")
  flags.DEFINE_float("learning_rate", 5e-6, "The initial learning rate for Adam.")
  flags.DEFINE_bool("use_trt", False, "Whether to use TF-TRT")
  flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
  flags.DEFINE_float("num_train_epochs", 3.0,
                     "Total number of training epochs to perform.")
  flags.DEFINE_float(
      "warmup_proportion", 0.1,
      "Proportion of training to perform linear learning rate warmup for. "
      "E.g., 0.1 = 10% of training.")
  flags.DEFINE_integer("save_checkpoints_steps", 5000,
                       "How often to save the model checkpoint.")
  flags.DEFINE_integer("display_loss_steps", 10,
                       "How often to print loss from estimator")
  flags.DEFINE_integer("iterations_per_loop", 1000,
                       "How many steps to make in each estimator call.")
  flags.DEFINE_integer("num_accumulation_steps", 1,
                       "Number of accumulation steps before gradient update"
                       "Global batch size = num_accumulation_steps * train_batch_size")
  flags.DEFINE_integer(
      "n_best_size", 20,
      "The total number of n-best predictions to generate in the "
      "nbest_predictions.json output file.")
  flags.DEFINE_integer(
      "max_answer_length", 30,
      "The maximum length of an answer that can be generated. This is needed "
      "because the start and end predictions are not conditioned on one another.")
  flags.DEFINE_bool(
      "verbose_logging", False,
      "If true, all of the warnings related to data processing will be printed. "
      "A number of warnings are expected for a normal SQuAD evaluation.")
  flags.DEFINE_bool(
      "version_2_with_negative", False,
      "If true, the SQuAD examples contain some that do not have an answer.")
  flags.DEFINE_float(
      "null_score_diff_threshold", 0.0,
      "If null_score - best_non_null is greater than the threshold predict null.")
  flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")
  flags.DEFINE_bool("use_xla", True, "Whether to enable XLA JIT compilation.")
  flags.DEFINE_integer("num_eval_iterations", None,
                       "How many eval iterations to run - performs inference on subset")

  # Triton Specific flags
  flags.DEFINE_bool("export_triton", False, "Whether to export saved model or run inference with Triton")
  flags.DEFINE_string("triton_model_name", "bert", "exports to appropriate directory for Triton")
  flags.DEFINE_integer("triton_model_version", 1, "exports to appropriate directory for Triton")
  flags.DEFINE_string("triton_server_url", "localhost:8001", "exports to appropriate directory for Triton")
  flags.DEFINE_bool("triton_model_overwrite", False, "If True, will overwrite an existing directory with the specified 'model_name' and 'version_name'")
  flags.DEFINE_integer("triton_max_batch_size", 8, "Specifies the 'max_batch_size' in the Triton model config. See the Triton documentation for more info.")
  flags.DEFINE_float("triton_dyn_batching_delay", 0, "Determines the dynamic_batching queue delay in milliseconds(ms) for the Triton model config. Use '0' or '-1' to specify static batching. See the Triton documentation for more info.")
  flags.DEFINE_integer("triton_engine_count", 1, "Specifies the 'instance_group' count value in the Triton model config. See the Triton documentation for more info.")

  # These flags have no usable defaults; fail fast at startup if absent.
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  return flags.FLAGS
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings):
  """Creates a classification model.

  Builds a BERT encoder and a SQuAD span-prediction head on top: one dense
  projection from each token's final hidden state to two logits (answer-span
  start and end).

  Args:
    bert_config: config object passed to `modeling.BertModel`.
    is_training: bool; True enables training-mode behavior in the encoder.
    input_ids: int32 Tensor of shape [batch_size, seq_length].
    input_mask: int32 Tensor of shape [batch_size, seq_length].
    segment_ids: int32 Tensor of shape [batch_size, seq_length].
    use_one_hot_embeddings: bool, forwarded to `modeling.BertModel`.

  Returns:
    A `(start_logits, end_logits)` tuple, each of shape
    [batch_size, seq_length].
  """
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      compute_type=tf.float32)

  # Token-level encoder output, rank 3: [batch, seq, hidden].
  final_hidden = model.get_sequence_output()

  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]
  hidden_size = final_hidden_shape[2]

  # Two weight rows (start, end), shared across every token position.
  output_weights = tf.get_variable(
      "cls/squad/output_weights", [2, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

  # Flatten to 2-D so one matmul scores all tokens of all examples at once.
  final_hidden_matrix = tf.reshape(final_hidden,
                                   [batch_size * seq_length, hidden_size])
  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)

  logits = tf.reshape(logits, [batch_size, seq_length, 2])
  # Move the (start, end) axis first so unstack below splits into two tensors.
  logits = tf.transpose(logits, [2, 0, 1])

  # Named 'unstack' on purpose: the TF-TRT path looks these nodes up by name.
  unstacked_logits = tf.unstack(logits, axis=0, name='unstack')

  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

  return (start_logits, end_logits)
def get_frozen_tftrt_model(bert_config, shape, use_one_hot_embeddings, init_checkpoint):
  """Builds the inference graph, restores `init_checkpoint`, freezes the
  variables to constants, and converts the frozen graph with TF-TRT.

  Args:
    bert_config: config forwarded to `create_model`.
    shape: static shape used for the three int32 input placeholders.
    use_one_hot_embeddings: bool, forwarded to `create_model`.
    init_checkpoint: checkpoint path whose weights are loaded into the graph.

  Returns:
    The TF-TRT-converted frozen `GraphDef`. Also writes it to
    "frozen_modelTRT.pb" in the current working directory as a side effect.
  """
  tf_config = tf.compat.v1.ConfigProto()
  tf_config.gpu_options.allow_growth = True

  # The two span-logit outputs; must match the 'unstack' op in create_model.
  output_node_names = ['unstack']

  with tf.Session(config=tf_config) as tf_sess:
    input_ids = tf.placeholder(tf.int32, shape, 'input_ids')
    input_mask = tf.placeholder(tf.int32, shape, 'input_mask')
    segment_ids = tf.placeholder(tf.int32, shape, 'segment_ids')

    (start_logits, end_logits) = create_model(bert_config=bert_config,
                                              is_training=False,
                                              input_ids=input_ids,
                                              input_mask=input_mask,
                                              segment_ids=segment_ids,
                                              use_one_hot_embeddings=use_one_hot_embeddings)

    tvars = tf.trainable_variables()
    (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    tf_sess.run(tf.global_variables_initializer())
    print("LOADED!")
    tf.compat.v1.logging.info("**** Trainable Variables ****")
    # Log which variables actually came from the checkpoint.
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      else:
        init_string = ", *NOTTTTTTTTTTTTTTTTTTTTT"
      tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string)

    # Bake variable values into constants so the graph is self-contained.
    frozen_graph = tf.graph_util.convert_variables_to_constants(tf_sess,
        tf_sess.graph.as_graph_def(), output_node_names)

    num_nodes = len(frozen_graph.node)
    print('Converting graph using TensorFlow-TensorRT...')
    from tensorflow.python.compiler.tensorrt import trt_convert as trt
    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        nodes_blacklist=output_node_names,
        max_workspace_size_bytes=(4096 << 20) - 1000,
        precision_mode = "FP16" if FLAGS.amp else "FP32",
        minimum_segment_size=4,
        is_dynamic_op=True,
        maximum_cached_engines=1000
    )
    frozen_graph = converter.convert()

    print('Total node count before and after TF-TRT conversion:',
          num_nodes, '->', len(frozen_graph.node))
    print('TRT node count:',
          len([1 for n in frozen_graph.node if str(n.op) == 'TRTEngineOp']))

    with tf.io.gfile.GFile("frozen_modelTRT.pb", "wb") as f:
      f.write(frozen_graph.SerializeToString())

  return frozen_graph
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps,
                     hvd=None, amp=False, use_one_hot_embeddings=False):
  """Returns `model_fn` closure for Estimator.

  The closure supports TRAIN and PREDICT modes only. In PREDICT mode with
  FLAGS.use_trt set, inference runs through a frozen TF-TRT graph instead
  of the regular Python-built graph.

  Args:
    bert_config: config forwarded to `create_model`.
    init_checkpoint: optional checkpoint to warm-start from.
    learning_rate: base learning rate for the optimizer.
    num_train_steps: total optimizer steps.
    num_warmup_steps: linear-warmup steps.
    hvd: optional Horovod module; only rank 0 restores the checkpoint.
    amp: bool, enables the mixed-precision graph rewrite.
    use_one_hot_embeddings: forwarded to `create_model`.
  """

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for Estimator."""
    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("*** Features ***")
      for name in sorted(features.keys()):
        tf.compat.v1.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # TF-TRT inference path: import the frozen engine graph and return early,
    # skipping the normal graph construction and checkpoint restore below.
    if not is_training and FLAGS.use_trt:
      trt_graph = get_frozen_tftrt_model(bert_config, input_ids.shape, use_one_hot_embeddings, init_checkpoint)
      (start_logits, end_logits) = tf.import_graph_def(trt_graph,
          input_map={'input_ids':input_ids, 'input_mask':input_mask, 'segment_ids':segment_ids},
          return_elements=['unstack:0', 'unstack:1'],
          name='')
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode, predictions=predictions)
      return output_spec

    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    # With Horovod, only rank 0 restores; other ranks get the weights via
    # the broadcast performed elsewhere in the training setup.
    if init_checkpoint and (hvd is None or hvd.rank() == 0):
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    if FLAGS.verbose_logging:
      tf.compat.v1.logging.info("**** Trainable Variables ****")
      for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
          init_string = ", *INIT_FROM_CKPT*"
        tf.compat.v1.logging.info(" %d name = %s, shape = %s%s", 0 if hvd is None else hvd.rank(), var.name, var.shape,
                                  init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      seq_length = modeling.get_shape_list(input_ids)[1]

      def compute_loss(logits, positions):
        # Cross-entropy of the true position against the per-token logits.
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)

      # Average of start- and end-position losses.
      total_loss = (start_loss + end_loss) / 2.0

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, hvd, False, amp, FLAGS.num_accumulation_steps)

      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      dummy_op = tf.no_op()
      # Need to call mixed precision graph rewrite if fp16 to enable graph rewrite
      if amp:
        loss_scaler = tf.train.experimental.FixedLossScale(1)
        dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)
      predictions = {
          "unique_ids": tf.identity(unique_ids),
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode, predictions=predictions)
    else:
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return output_spec

  return model_fn
def input_fn_builder(input_file, batch_size, seq_length, is_training, drop_remainder, hvd=None):
  """Creates an `input_fn` closure to be passed to Estimator.

  Args:
    input_file: TFRecord file(s) of serialized tf.Examples.
    batch_size: examples per batch.
    seq_length: fixed token length of each feature vector.
    is_training: bool; enables sharding, shuffling and infinite repeat.
    drop_remainder: bool; drop the final partial batch if True.
    hvd: optional Horovod module; when set, each worker reads its own shard.
  """
  # Schema of the serialized tf.Examples read from `input_file`.
  name_to_features = {
      "unique_ids": tf.io.FixedLenFeature([], tf.int64),
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
  }

  # Answer-position labels exist only in training records.
  if is_training:
    name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)

  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.to_int32(t)
      example[name] = t

    return example

  def input_fn():
    """The actual input function."""
    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if is_training:
      d = tf.data.TFRecordDataset(input_file, num_parallel_reads=4)
      # Disjoint shard per Horovod worker.
      if hvd is not None: d = d.shard(hvd.size(), hvd.rank())
      d = d.apply(tf.data.experimental.ignore_errors())
      d = d.shuffle(buffer_size=100)
      d = d.repeat()
    else:
      d = tf.data.TFRecordDataset(input_file)

    d = d.apply(
        tf.contrib.data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))
    return d

  return input_fn
# One model output per feature: the feature's unique id plus the raw
# per-token start/end logits produced for it.
RawResult = collections.namedtuple("RawResult",
                                   ["unique_id", "start_logits", "end_logits"])
def get_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length,
                    do_lower_case, version_2_with_negative, verbose_logging):
  """Get final predictions.

  Turns raw per-feature logits into answer strings by scoring candidate
  (start, end) spans, de-tokenizing them back to the original text, and
  keeping the n-best per example.

  Args:
    all_examples: SQuAD examples, indexed by position.
    all_features: sliding-window features; each carries `example_index`.
    all_results: `RawResult`s keyed by `unique_id`.
    n_best_size: how many candidate indexes/spans to consider and keep.
    max_answer_length: spans longer than this are discarded.
    do_lower_case: forwarded to `get_final_text` for de-tokenization.
    version_2_with_negative: SQuAD 2.0 mode — allow "no answer" predictions.
    verbose_logging: forwarded to `get_final_text`.

  Returns:
    (all_predictions, all_nbest_json, scores_diff_json), each an OrderedDict
    keyed by example `qas_id` (scores_diff_json is only populated in
    SQuAD 2.0 mode).
  """
  # A document split into several windows yields several features per example.
  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "PrelimPrediction",
      ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min mull score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      result = unique_id_to_result[feature.unique_id]
      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
      # if we could have irrelevant answers, get the min score of irrelevant
      if version_2_with_negative:
        # Position 0 ([CLS]) scores the "no answer" hypothesis.
        feature_null_score = result.start_logits[0] + result.end_logits[0]
        if feature_null_score < score_null:
          score_null = feature_null_score
          min_null_feature_index = feature_index
          null_start_logit = result.start_logits[0]
          null_end_logit = result.end_logits[0]
      for start_index in start_indexes:
        for end_index in end_indexes:
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index >= len(feature.tokens):
            continue
          if end_index >= len(feature.tokens):
            continue
          if start_index not in feature.token_to_orig_map:
            continue
          if end_index not in feature.token_to_orig_map:
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=result.start_logits[start_index],
                  end_logit=result.end_logits[end_index]))

    if version_2_with_negative:
      # Null prediction is encoded as a (0, 0) span.
      prelim_predictions.append(
          _PrelimPrediction(
              feature_index=min_null_feature_index,
              start_index=0,
              end_index=0,
              start_logit=null_start_logit,
              end_logit=null_end_logit))
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit"])

    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index > 0:  # this is a non-null prediction
        tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
        orig_doc_start = feature.token_to_orig_map[pred.start_index]
        orig_doc_end = feature.token_to_orig_map[pred.end_index]
        orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
        tok_text = " ".join(tok_tokens)

        # De-tokenize WordPieces that have been split off.
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")

        # Clean whitespace
        tok_text = tok_text.strip()
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)

        final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
        if final_text in seen_predictions:
          continue

        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

    # if we didn't include the empty option in the n-best, include it
    if version_2_with_negative:
      if "" not in seen_predictions:
        nbest.append(
            _NbestPrediction(
                text="", start_logit=null_start_logit,
                end_logit=null_end_logit))
    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

    assert len(nbest) >= 1

    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
      total_scores.append(entry.start_logit + entry.end_logit)
      if not best_non_null_entry:
        if entry.text:
          best_non_null_entry = entry

    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_logit"] = entry.start_logit
      output["end_logit"] = entry.end_logit
      nbest_json.append(output)

    assert len(nbest_json) >= 1

    if not version_2_with_negative:
      all_predictions[example.qas_id] = nbest_json[0]["text"]
    else:
      # predict "" iff the null score - the score of best non-null > threshold
      # NOTE(review): if nbest contains only empty-text entries,
      # best_non_null_entry stays None and the line below raises — verify
      # upstream guarantees a non-null candidate in SQuAD 2.0 mode.
      score_diff = score_null - best_non_null_entry.start_logit - (
          best_non_null_entry.end_logit)
      scores_diff_json[example.qas_id] = score_diff
      # FLAGS may not be parsed when this is used as a library; fall back to 0.
      try:
        null_score_diff_threshold = FLAGS.null_score_diff_threshold
      except:
        null_score_diff_threshold = 0.0
      if score_diff > null_score_diff_threshold:
        all_predictions[example.qas_id] = ""
      else:
        all_predictions[example.qas_id] = best_non_null_entry.text
    all_nbest_json[example.qas_id] = nbest_json

  return all_predictions, all_nbest_json, scores_diff_json
  507. def write_predictions(all_examples, all_features, all_results, n_best_size,
  508. max_answer_length, do_lower_case, output_prediction_file,
  509. output_nbest_file, output_null_log_odds_file,
  510. version_2_with_negative, verbose_logging):
  511. """Write final predictions to the json file and log-odds of null if needed."""
  512. tf.compat.v1.logging.info("Writing predictions to: %s" % (output_prediction_file))
  513. tf.compat.v1.logging.info("Writing nbest to: %s" % (output_nbest_file))
  514. all_predictions, all_nbest_json, scores_diff_json = get_predictions(all_examples, all_features,
  515. all_results, n_best_size, max_answer_length, do_lower_case, version_2_with_negative, verbose_logging)
  516. with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
  517. writer.write(json.dumps(all_predictions, indent=4) + "\n")
  518. with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
  519. writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
  520. if version_2_with_negative:
  521. with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
  522. writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging):
  """Project the tokenized prediction back to the original text.

  Args:
    pred_text: the de-tokenized, normalized predicted span.
    orig_text: the raw original-text span covering the prediction.
    do_lower_case: whether the model's tokenizer lower-cases.
    verbose_logging: log details when the alignment heuristic fails.

  Returns:
    The best-effort original-casing answer string; falls back to
    `orig_text` whenever the character alignment cannot be established.
  """
  # When we created the data, we kept track of the alignment between original
  # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
  # now `orig_text` contains the span of our original text corresponding to the
  # span that we predicted.
  #
  # However, `orig_text` may contain extra characters that we don't want in
  # our prediction.
  #
  # For example, let's say:
  #   pred_text = steve smith
  #   orig_text = Steve Smith's
  #
  # We don't want to return `orig_text` because it contains the extra "'s".
  #
  # We don't want to return `pred_text` because it's already been normalized
  # (the SQuAD eval script also does punctuation stripping/lower casing but
  # our tokenizer does additional normalization like stripping accent
  # characters).
  #
  # What we really want to return is "Steve Smith".
  #
  # Therefore, we have to apply a semi-complicated alignment heuristic between
  # `pred_text` and `orig_text` to get a character-to-character alignment. This
  # can fail in certain cases in which case we just return `orig_text`.

  def _strip_spaces(text):
    # Returns `text` with spaces removed, plus a map from each position in
    # the stripped string back to its position in `text`.
    ns_chars = []
    ns_to_s_map = collections.OrderedDict()
    for (i, c) in enumerate(text):
      if c == " ":
        continue
      ns_to_s_map[len(ns_chars)] = i
      ns_chars.append(c)
    ns_text = "".join(ns_chars)
    return (ns_text, ns_to_s_map)

  # We first tokenize `orig_text`, strip whitespace from the result
  # and `pred_text`, and check if they are the same length. If they are
  # NOT the same length, the heuristic has failed. If they are the same
  # length, we assume the characters are one-to-one aligned.
  tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)

  tok_text = " ".join(tokenizer.tokenize(orig_text))

  start_position = tok_text.find(pred_text)
  if start_position == -1:
    if verbose_logging:
      tf.compat.v1.logging.info(
          "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
    return orig_text
  end_position = start_position + len(pred_text) - 1

  (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

  if len(orig_ns_text) != len(tok_ns_text):
    if verbose_logging:
      tf.compat.v1.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
                                orig_ns_text, tok_ns_text)
    return orig_text

  # We then project the characters in `pred_text` back to `orig_text` using
  # the character-to-character alignment.
  tok_s_to_ns_map = {}
  for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
    tok_s_to_ns_map[tok_index] = i

  orig_start_position = None
  if start_position in tok_s_to_ns_map:
    ns_start_position = tok_s_to_ns_map[start_position]
    if ns_start_position in orig_ns_to_s_map:
      orig_start_position = orig_ns_to_s_map[ns_start_position]

  if orig_start_position is None:
    if verbose_logging:
      tf.compat.v1.logging.info("Couldn't map start position")
    return orig_text

  orig_end_position = None
  if end_position in tok_s_to_ns_map:
    ns_end_position = tok_s_to_ns_map[end_position]
    if ns_end_position in orig_ns_to_s_map:
      orig_end_position = orig_ns_to_s_map[ns_end_position]

  if orig_end_position is None:
    if verbose_logging:
      tf.compat.v1.logging.info("Couldn't map end position")
    return orig_text

  output_text = orig_text[orig_start_position:(orig_end_position + 1)]
  return output_text
  604. def _get_best_indexes(logits, n_best_size):
  605. """Get the n-best logits from a list."""
  606. index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
  607. best_indexes = []
  608. for i in range(len(index_and_score)):
  609. if i >= n_best_size:
  610. break
  611. best_indexes.append(index_and_score[i][0])
  612. return best_indexes
  613. def _compute_softmax(scores):
  614. """Compute softmax probability over raw logits."""
  615. if not scores:
  616. return []
  617. max_score = None
  618. for score in scores:
  619. if max_score is None or score > max_score:
  620. max_score = score
  621. exp_scores = []
  622. total_sum = 0.0
  623. for score in scores:
  624. x = math.exp(score - max_score)
  625. exp_scores.append(x)
  626. total_sum += x
  627. probs = []
  628. for score in exp_scores:
  629. probs.append(score / total_sum)
  630. return probs
  631. def validate_flags_or_throw(bert_config):
  632. """Validate the input FLAGS or throw an exception."""
  633. tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
  634. FLAGS.init_checkpoint)
  635. if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.export_triton:
  636. raise ValueError("At least one of `do_train` or `do_predict` or `export_SavedModel` must be True.")
  637. if FLAGS.do_train:
  638. if not FLAGS.train_file:
  639. raise ValueError(
  640. "If `do_train` is True, then `train_file` must be specified.")
  641. if FLAGS.do_predict:
  642. if not FLAGS.predict_file:
  643. raise ValueError(
  644. "If `do_predict` is True, then `predict_file` must be specified.")
  645. if FLAGS.max_seq_length > bert_config.max_position_embeddings:
  646. raise ValueError(
  647. "Cannot use sequence length %d because the BERT model "
  648. "was only trained up to sequence length %d" %
  649. (FLAGS.max_seq_length, bert_config.max_position_embeddings))
  650. if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
  651. raise ValueError(
  652. "The max_seq_length (%d) must be greater than max_query_length "
  653. "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))
  654. def export_model(estimator, export_dir, init_checkpoint):
  655. """Exports a checkpoint in SavedModel format in a directory structure compatible with Triton."""
  656. def serving_input_fn():
  657. label_ids = tf.placeholder(tf.int32, [None,], name='unique_ids')
  658. input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
  659. input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
  660. segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
  661. input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
  662. 'unique_ids': label_ids,
  663. 'input_ids': input_ids,
  664. 'input_mask': input_mask,
  665. 'segment_ids': segment_ids,
  666. })()
  667. return input_fn
  668. saved_dir = estimator.export_savedmodel(
  669. export_dir,
  670. serving_input_fn,
  671. assets_extra=None,
  672. as_text=False,
  673. checkpoint_path=init_checkpoint,
  674. strip_default_attrs=False)
  675. model_name = FLAGS.triton_model_name
  676. model_folder = export_dir + "/triton_models/" + model_name
  677. version_folder = model_folder + "/" + str(FLAGS.triton_model_version)
  678. final_model_folder = version_folder + "/model.savedmodel"
  679. if not os.path.exists(version_folder):
  680. os.makedirs(version_folder)
  681. if (not os.path.exists(final_model_folder)):
  682. os.rename(saved_dir, final_model_folder)
  683. print("Model saved to dir", final_model_folder)
  684. else:
  685. if (FLAGS.triton_model_overwrite):
  686. shutil.rmtree(final_model_folder)
  687. os.rename(saved_dir, final_model_folder)
  688. print("WARNING: Existing model was overwritten. Model dir: {}".format(final_model_folder))
  689. else:
  690. print("ERROR: Could not save Triton model. Folder already exists. Use '--triton_model_overwrite=True' if you would like to overwrite an existing model. Model dir: {}".format(final_model_folder))
  691. return
  692. # Now build the config for Triton. Check to make sure we can overwrite it, if it exists
  693. config_filename = os.path.join(model_folder, "config.pbtxt")
  694. optimization_str = ""
  695. if FLAGS.amp:
  696. optimization_str = r"""
  697. optimization {
  698. execution_accelerators
  699. {
  700. gpu_execution_accelerator :
  701. [ {
  702. name : "auto_mixed_precision"
  703. } ]
  704. }
  705. }"""
  706. if (os.path.exists(config_filename) and not FLAGS.triton_model_overwrite):
  707. print("ERROR: Could not save Triton model config. Config file already exists. Use '--triton_model_overwrite=True' if you would like to overwrite an existing model config. Model config: {}".format(config_filename))
  708. return
  709. config_template = r"""
  710. name: "{model_name}"
  711. platform: "tensorflow_savedmodel"
  712. max_batch_size: {max_batch_size}
  713. {optimization_str}
  714. input [
  715. {{
  716. name: "unique_ids"
  717. data_type: TYPE_INT32
  718. dims: [ 1 ]
  719. reshape: {{ shape: [ ] }}
  720. }},
  721. {{
  722. name: "segment_ids"
  723. data_type: TYPE_INT32
  724. dims: {seq_length}
  725. }},
  726. {{
  727. name: "input_ids"
  728. data_type: TYPE_INT32
  729. dims: {seq_length}
  730. }},
  731. {{
  732. name: "input_mask"
  733. data_type: TYPE_INT32
  734. dims: {seq_length}
  735. }}
  736. ]
  737. output [
  738. {{
  739. name: "end_logits"
  740. data_type: TYPE_FP32
  741. dims: {seq_length}
  742. }},
  743. {{
  744. name: "start_logits"
  745. data_type: TYPE_FP32
  746. dims: {seq_length}
  747. }}
  748. ]
  749. {dynamic_batching}
  750. instance_group [
  751. {{
  752. count: {engine_count}
  753. }}
  754. ]"""
  755. batching_str = ""
  756. max_batch_size = FLAGS.triton_max_batch_size
  757. if (FLAGS.triton_dyn_batching_delay > 0):
  758. # Use only full and half full batches
  759. pref_batch_size = [int(max_batch_size / 2.0), max_batch_size]
  760. batching_str = r"""
  761. dynamic_batching {{
  762. preferred_batch_size: [{0}]
  763. max_queue_delay_microseconds: {1}
  764. }}""".format(", ".join([str(x) for x in pref_batch_size]), int(FLAGS.triton_dyn_batching_delay * 1000.0))
  765. config_values = {
  766. "model_name": model_name,
  767. "max_batch_size": max_batch_size,
  768. "seq_length": FLAGS.max_seq_length,
  769. "dynamic_batching": batching_str,
  770. "engine_count": FLAGS.triton_engine_count,
  771. "optimization_str":optimization_str,
  772. }
  773. with open(model_folder + "/config.pbtxt", "w") as file:
  774. final_config_str = config_template.format_map(config_values)
  775. file.write(final_config_str)
def main(_):
  """Train and/or evaluate BERT on SQuAD, driven entirely by FLAGS.

  Builds a tf.estimator.Estimator from the configured BERT checkpoint, then
  depending on flags: trains (optionally multi-GPU via Horovod), exports a
  Triton-compatible SavedModel, and/or runs prediction followed by the
  external SQuAD evaluation script.
  """
  setup_xla_flags()

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

  if FLAGS.horovod:
    hvd.init()

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  validate_flags_or_throw(bert_config)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  # Defaults for single-process execution; overridden below when Horovod is on.
  master_process = True
  training_hooks = []
  global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
  hvd_rank = 0

  config = tf.compat.v1.ConfigProto()
  learning_rate = FLAGS.learning_rate
  if FLAGS.horovod:
    tf.compat.v1.logging.info("Multi-GPU training with TF Horovod")
    tf.compat.v1.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(), hvd.rank())
    # Scale the effective batch size and learning rate linearly with the
    # number of workers; only rank 0 writes checkpoints/summaries.
    global_batch_size = FLAGS.train_batch_size * hvd.size() * FLAGS.num_accumulation_steps
    learning_rate = learning_rate * hvd.size()
    master_process = (hvd.rank() == 0)
    hvd_rank = hvd.rank()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    set_affinity(hvd.local_rank())
    if hvd.size() > 1:
      # Rank 0 broadcasts the initial variable values to all other ranks.
      training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
  if FLAGS.use_xla:
    config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
    if FLAGS.amp:
      tf.enable_resource_variables()
  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir if master_process else None,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
      save_summary_steps=FLAGS.save_checkpoints_steps if master_process else None,
      log_step_count_steps=FLAGS.display_loss_steps,
      keep_checkpoint_max=1)

  if master_process:
    # NOTE(review): "Configuaration" is a typo in the log string; left as-is
    # here because this edit changes comments only.
    tf.compat.v1.logging.info("***** Configuaration *****")
    for key in FLAGS.__flags.keys():
      tf.compat.v1.logging.info(' {}: {}'.format(key, getattr(FLAGS, key)))
    tf.compat.v1.logging.info("**************************")

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  training_hooks.append(LogTrainRunHook(global_batch_size, hvd_rank, FLAGS.save_checkpoints_steps))

  # Prepare Training Data
  if FLAGS.do_train:
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True,
        version_2_with_negative=FLAGS.version_2_with_negative)
    num_train_steps = int(
        len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`. Fixed seed keeps the shard split
    # deterministic across ranks.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

    start_index = 0
    end_index = len(train_examples)
    tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

    if FLAGS.horovod:
      # Shard the examples across ranks: the first `remainder` ranks take one
      # extra example each so every example is assigned exactly once.
      tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
      num_examples_per_rank = len(train_examples) // hvd.size()
      remainder = len(train_examples) % hvd.size()
      if hvd.rank() < remainder:
        start_index = hvd.rank() * (num_examples_per_rank+1)
        end_index = start_index + num_examples_per_rank + 1
      else:
        start_index = hvd.rank() * num_examples_per_rank + remainder
        end_index = start_index + (num_examples_per_rank)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      hvd=None if not FLAGS.horovod else hvd,
      amp=FLAGS.amp)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config)

  if FLAGS.do_train:
    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.
    train_writer = FeatureWriter(
        filename=tmp_filenames[hvd_rank],
        is_training=True)
    convert_examples_to_features(
        examples=train_examples[start_index:end_index],
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=True,
        output_fn=train_writer.process_feature,
        verbose_logging=FLAGS.verbose_logging)
    train_writer.close()

    tf.compat.v1.logging.info("***** Running training *****")
    tf.compat.v1.logging.info(" Num orig examples = %d", end_index - start_index)
    tf.compat.v1.logging.info(" Num split examples = %d", train_writer.num_features)
    tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
    tf.compat.v1.logging.info(" Num steps = %d", num_train_steps)
    tf.compat.v1.logging.info(" LR = %f", learning_rate)
    # Features are on disk now; free the in-memory examples.
    del train_examples

    train_input_fn = input_fn_builder(
        input_file=tmp_filenames,
        batch_size=FLAGS.train_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        hvd=None if not FLAGS.horovod else hvd)

    train_start_time = time.time()
    estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=num_train_steps)
    train_time_elapsed = time.time() - train_start_time
    # Overhead-free timing comes from the LogTrainRunHook appended last to
    # `training_hooks` above.
    train_time_wo_overhead = training_hooks[-1].total_time
    avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
    ss_sentences_per_second = (num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead

    if master_process:
      tf.compat.v1.logging.info("-----------------------------")
      tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                                num_train_steps * global_batch_size)
      tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                                (num_train_steps - training_hooks[-1].skipped) * global_batch_size)
      tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
      tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
      dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
      tf.compat.v1.logging.info("-----------------------------")

  if FLAGS.export_triton and master_process:
    export_model(estimator, FLAGS.output_dir, FLAGS.init_checkpoint)

  if FLAGS.do_predict and master_process:
    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False,
        version_2_with_negative=FLAGS.version_2_with_negative)

    # Perform evaluation on subset, useful for profiling
    if FLAGS.num_eval_iterations is not None:
      eval_examples = eval_examples[:FLAGS.num_eval_iterations*FLAGS.predict_batch_size]

    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False)
    eval_features = []

    def append_feature(feature):
      # Keep an in-memory copy (used by write_predictions below) while also
      # streaming each feature to the TFRecord file.
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature,
        verbose_logging=FLAGS.verbose_logging)
    eval_writer.close()

    tf.compat.v1.logging.info("***** Running predictions *****")
    tf.compat.v1.logging.info(" Num orig examples = %d", len(eval_examples))
    tf.compat.v1.logging.info(" Num split examples = %d", len(eval_features))
    tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        batch_size=FLAGS.predict_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)

    all_results = []
    eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
    eval_start_time = time.time()
    for result in estimator.predict(
        predict_input_fn, yield_single_examples=True, hooks=eval_hooks):
      if len(all_results) % 1000 == 0:
        tf.compat.v1.logging.info("Processing example: %d" % (len(all_results)))
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    eval_time_elapsed = time.time() - eval_start_time

    time_list = eval_hooks[-1].time_list
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation: drop the
    # slowest 1% of per-batch times.
    eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
    num_sentences = (int(len(time_list) * 0.99)) * FLAGS.predict_batch_size

    # Latency percentiles over the sorted per-batch times (in seconds).
    avg = np.mean(time_list)
    cf_50 = max(time_list[:int(len(time_list) * 0.50)])
    cf_90 = max(time_list[:int(len(time_list) * 0.90)])
    cf_95 = max(time_list[:int(len(time_list) * 0.95)])
    cf_99 = max(time_list[:int(len(time_list) * 0.99)])
    cf_100 = max(time_list[:int(len(time_list) * 1)])
    ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                              eval_hooks[-1].count * FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                              num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
    tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    dllogging.logger.log(step=(), data={"throughput_val": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
    tf.compat.v1.logging.info("-----------------------------")

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file,
                      FLAGS.version_2_with_negative, FLAGS.verbose_logging)

    if FLAGS.eval_script:
      import sys
      import subprocess
      # Run the official SQuAD evaluation script as a subprocess and parse
      # exact-match / F1 out of its printed dict, e.g.
      # {"exact_match": 81.0, "f1": 88.5}. NOTE(review): this split-based
      # parsing assumes that exact output format — verify against the script.
      eval_out = subprocess.check_output([sys.executable, FLAGS.eval_script,
                                          FLAGS.predict_file, output_prediction_file])
      scores = str(eval_out).strip()
      exact_match = float(scores.split(":")[1].split(",")[0])
      f1 = float(scores.split(":")[2].split("}")[0])
      dllogging.logger.log(step=(), data={"f1": f1}, verbosity=Verbosity.DEFAULT)
      dllogging.logger.log(step=(), data={"exact_match": exact_match}, verbosity=Verbosity.DEFAULT)
      print(str(eval_out))
if __name__ == "__main__":
  # Parse command-line flags into the module-level FLAGS, then let
  # tf.app.run() dispatch to main().
  FLAGS = extract_run_squad_flags()
  tf.app.run()