| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230 |
- # coding=utf-8
- # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
- # Copyright 2018 The Google AI Language Team Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Run BERT on SQuAD 1.1 and SQuAD 2.0."""
- from __future__ import absolute_import, division, print_function
- import collections
- import json
- import math
- import os
- import random
- import shutil
- import time
- import horovod.tensorflow as hvd
- import numpy as np
- import six
- import tensorflow as tf
- from tensorflow.python.client import device_lib
- import modeling
- import optimization
- import tokenization
- from utils.create_squad_data import *
- from utils.utils import LogEvalRunHook, LogTrainRunHook, setup_xla_flags
- from utils.gpu_affinity import set_affinity
- import utils.dllogger_class
- from dllogger import Verbosity
# Short alias for TensorFlow's flag module; extract_run_squad_flags() registers
# every command-line flag through it.
flags = tf.flags
# Module-level handle to the parsed flag values. It is None here, while the
# functions below read attributes from it, so it must be assigned before they
# run (presumably to the value returned by extract_run_squad_flags(); the
# assignment is not visible in this chunk — confirm in the entry point).
FLAGS = None
def extract_run_squad_flags():
    """Register every command-line flag used by this script and return the flag values.

    Flags are defined on the module-level ``flags`` alias. ``vocab_file``,
    ``bert_config_file`` and ``output_dir`` are marked as required.

    Returns:
      ``flags.FLAGS``, the flag-values object holding every flag defined here.
    """
    ## Required parameters
    flags.DEFINE_string(
        "bert_config_file", None,
        "The config json file corresponding to the pre-trained BERT model. "
        "This specifies the model architecture.")

    flags.DEFINE_string("vocab_file", None,
                        "The vocabulary file that the BERT model was trained on.")

    flags.DEFINE_string(
        "output_dir", None,
        "The output directory where the model checkpoints will be written.")

    ## Other parameters
    flags.DEFINE_string(
        "dllog_path", "/results/bert_dllog.json",
        "filename where dllogger writes to")

    flags.DEFINE_string("train_file", None,
                        "SQuAD json for training. E.g., train-v1.1.json")

    flags.DEFINE_string(
        "predict_file", None,
        "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")

    flags.DEFINE_string(
        "eval_script", None,
        "SQuAD evaluate.py file to compute f1 and exact_match E.g., evaluate-v1.1.py")

    flags.DEFINE_string(
        "init_checkpoint", None,
        "Initial checkpoint (usually from a pre-trained BERT model).")

    flags.DEFINE_bool(
        "do_lower_case", True,
        "Whether to lower case the input text. Should be True for uncased "
        "models and False for cased models.")

    flags.DEFINE_integer(
        "max_seq_length", 384,
        "The maximum total input sequence length after WordPiece tokenization. "
        "Sequences longer than this will be truncated, and sequences shorter "
        "than this will be padded.")

    flags.DEFINE_integer(
        "doc_stride", 128,
        "When splitting up a long document into chunks, how much stride to "
        "take between chunks.")

    flags.DEFINE_integer(
        "max_query_length", 64,
        "The maximum number of tokens for the question. Questions longer than "
        "this will be truncated to this length.")

    flags.DEFINE_bool("do_train", False, "Whether to run training.")

    flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

    flags.DEFINE_integer("train_batch_size", 8, "Total batch size for training.")

    flags.DEFINE_integer("predict_batch_size", 8,
                         "Total batch size for predictions.")

    flags.DEFINE_float("learning_rate", 5e-6, "The initial learning rate for Adam.")

    flags.DEFINE_bool("use_trt", False, "Whether to use TF-TRT")

    flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")

    flags.DEFINE_float("num_train_epochs", 3.0,
                       "Total number of training epochs to perform.")

    flags.DEFINE_float(
        "warmup_proportion", 0.1,
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10% of training.")

    flags.DEFINE_integer("save_checkpoints_steps", 5000,
                         "How often to save the model checkpoint.")

    flags.DEFINE_integer("display_loss_steps", 10,
                         "How often to print loss from estimator")

    flags.DEFINE_integer("iterations_per_loop", 1000,
                         "How many steps to make in each estimator call.")

    # The two literals are concatenated implicitly, so the separator must be
    # part of the first one (it was previously rendered as "updateGlobal").
    flags.DEFINE_integer("num_accumulation_steps", 1,
                         "Number of accumulation steps before gradient update. "
                         "Global batch size = num_accumulation_steps * train_batch_size")

    flags.DEFINE_integer(
        "n_best_size", 20,
        "The total number of n-best predictions to generate in the "
        "nbest_predictions.json output file.")

    flags.DEFINE_integer(
        "max_answer_length", 30,
        "The maximum length of an answer that can be generated. This is needed "
        "because the start and end predictions are not conditioned on one another.")

    flags.DEFINE_bool(
        "verbose_logging", False,
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")

    flags.DEFINE_bool(
        "version_2_with_negative", False,
        "If true, the SQuAD examples contain some that do not have an answer.")

    flags.DEFINE_float(
        "null_score_diff_threshold", 0.0,
        "If null_score - best_non_null is greater than the threshold predict null.")

    flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")

    flags.DEFINE_bool("use_xla", True, "Whether to enable XLA JIT compilation.")

    flags.DEFINE_integer("num_eval_iterations", None,
                         "How many eval iterations to run - performs inference on subset")

    # Triton Specific flags
    flags.DEFINE_bool("export_triton", False, "Whether to export saved model or run inference with Triton")
    flags.DEFINE_string("triton_model_name", "bert", "exports to appropriate directory for Triton")
    flags.DEFINE_integer("triton_model_version", 1, "exports to appropriate directory for Triton")
    # Previously carried a copy-pasted help text unrelated to what this flag holds.
    flags.DEFINE_string("triton_server_url", "localhost:8001",
                        "Address (host:port) of the Triton inference server to send requests to")
    flags.DEFINE_bool("triton_model_overwrite", False, "If True, will overwrite an existing directory with the specified 'model_name' and 'version_name'")
    flags.DEFINE_integer("triton_max_batch_size", 8, "Specifies the 'max_batch_size' in the Triton model config. See the Triton documentation for more info.")
    flags.DEFINE_float("triton_dyn_batching_delay", 0, "Determines the dynamic_batching queue delay in milliseconds(ms) for the Triton model config. Use '0' or '-1' to specify static batching. See the Triton documentation for more info.")
    flags.DEFINE_integer("triton_engine_count", 1, "Specifies the 'instance_group' count value in the Triton model config. See the Triton documentation for more info.")

    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    return flags.FLAGS
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings):
    """Build BERT with a SQuAD span head and return per-token start/end logits.

    One dense projection maps each token's final hidden state to two scores.
    Score 0 becomes the start logit and score 1 the end logit.

    Returns:
      ``(start_logits, end_logits)``, each shaped ``[batch_size, seq_length]``.
      The splitting op is named ``'unstack'``, so the outputs are addressable
      as ``unstack:0`` and ``unstack:1``; the TF-TRT path depends on that name.
    """
    bert = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
        compute_type=tf.float32)

    sequence_output = bert.get_sequence_output()
    batch, seq_len, width = modeling.get_shape_list(sequence_output, expected_rank=3)

    # Row 0 of the projection scores span starts; row 1 scores span ends.
    span_weights = tf.get_variable(
        "cls/squad/output_weights", [2, width],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    span_bias = tf.get_variable(
        "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

    flat_hidden = tf.reshape(sequence_output, [batch * seq_len, width])
    flat_logits = tf.nn.bias_add(
        tf.matmul(flat_hidden, span_weights, transpose_b=True), span_bias)

    # [batch * seq, 2] -> [batch, seq, 2] -> [2, batch, seq] so axis 0 can be split.
    per_position = tf.transpose(tf.reshape(flat_logits, [batch, seq_len, 2]), [2, 0, 1])
    start_logits, end_logits = tf.unstack(per_position, axis=0, name='unstack')
    return (start_logits, end_logits)
def get_frozen_tftrt_model(bert_config, shape, use_one_hot_embeddings, init_checkpoint):
    """Build, freeze and TF-TRT-convert the SQuAD model; return the converted GraphDef.

    Steps:
      1. Build ``create_model`` on int32 placeholders named ``input_ids``,
         ``input_mask`` and ``segment_ids``, each with shape ``shape``.
      2. Initialize the weights from ``init_checkpoint`` and log which
         trainable variables were restored.
      3. Freeze the variables into constants, keeping the ``unstack`` node.
      4. Convert the graph with ``TrtGraphConverter``, in FP16 when
         ``FLAGS.amp`` is set and FP32 otherwise.

    Side effect: the converted graph is also serialized to
    ``frozen_modelTRT.pb`` in the current working directory.
    """
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.allow_growth = True
    # Freezing keeps whatever is needed to compute this node (start/end logits).
    outputs_to_keep = ['unstack']

    with tf.Session(config=session_config) as sess:
        placeholders = {
            name: tf.placeholder(tf.int32, shape, name)
            for name in ('input_ids', 'input_mask', 'segment_ids')
        }
        create_model(bert_config=bert_config,
                     is_training=False,
                     input_ids=placeholders['input_ids'],
                     input_mask=placeholders['input_mask'],
                     segment_ids=placeholders['segment_ids'],
                     use_one_hot_embeddings=use_one_hot_embeddings)

        trainable = tf.trainable_variables()
        assignment_map, restored_names = modeling.get_assignment_map_from_checkpoint(
            trainable, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        sess.run(tf.global_variables_initializer())
        print("LOADED!")

        tf.compat.v1.logging.info("**** Trainable Variables ****")
        for variable in trainable:
            suffix = (", *INIT_FROM_CKPT*" if variable.name in restored_names
                      else ", *NOTTTTTTTTTTTTTTTTTTTTT")
            tf.compat.v1.logging.info(" name = %s, shape = %s%s",
                                      variable.name, variable.shape, suffix)

        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), outputs_to_keep)
        nodes_before = len(frozen.node)

        print('Converting graph using TensorFlow-TensorRT...')
        from tensorflow.python.compiler.tensorrt import trt_convert as trt
        converter = trt.TrtGraphConverter(
            input_graph_def=frozen,
            nodes_blacklist=outputs_to_keep,
            max_workspace_size_bytes=(4096 << 20) - 1000,
            precision_mode="FP16" if FLAGS.amp else "FP32",
            minimum_segment_size=4,
            is_dynamic_op=True,
            maximum_cached_engines=1000)
        converted = converter.convert()

        print('Total node count before and after TF-TRT conversion:',
              nodes_before, '->', len(converted.node))
        engine_count = sum(1 for node in converted.node if str(node.op) == 'TRTEngineOp')
        print('TRT node count:', engine_count)

        with tf.io.gfile.GFile("frozen_modelTRT.pb", "wb") as out_file:
            out_file.write(converted.SerializeToString())

    return converted
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps,
                     hvd=None, amp=False, use_one_hot_embeddings=False):
    """Return a ``model_fn`` closure for ``tf.estimator.Estimator``.

    Behaviour of the returned closure:
      * Any non-TRAIN mode with ``FLAGS.use_trt``: the model is converted by
        ``get_frozen_tftrt_model``, imported into the current graph, and the
        logits are returned as predictions.
      * TRAIN: the loss is the mean of the start- and end-position
        cross-entropies; the train op comes from ``optimization.create_optimizer``.
      * PREDICT: returns ``unique_ids`` together with the start and end logits.
      * Any other mode (outside the TF-TRT path) raises ``ValueError``.

    Checkpoint weights from ``init_checkpoint`` are restored only when ``hvd``
    is None or ``hvd.rank() == 0``.
    """

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for Estimator."""
        if FLAGS.verbose_logging:
            tf.compat.v1.logging.info("*** Features ***")
            for feature_name in sorted(features):
                tf.compat.v1.logging.info(
                    " name = %s, shape = %s" % (feature_name, features[feature_name].shape))

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        if not is_training and FLAGS.use_trt:
            trt_graph = get_frozen_tftrt_model(
                bert_config, input_ids.shape, use_one_hot_embeddings, init_checkpoint)
            start_logits, end_logits = tf.import_graph_def(
                trt_graph,
                input_map={'input_ids': input_ids,
                           'input_mask': input_mask,
                           'segment_ids': segment_ids},
                return_elements=['unstack:0', 'unstack:1'],
                name='')
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    "unique_ids": unique_ids,
                    "start_logits": start_logits,
                    "end_logits": end_logits,
                })

        start_logits, end_logits = create_model(
            bert_config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        trainable = tf.trainable_variables()
        restored_names = {}
        if init_checkpoint and (hvd is None or hvd.rank() == 0):
            assignment_map, restored_names = modeling.get_assignment_map_from_checkpoint(
                trainable, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        if FLAGS.verbose_logging:
            tf.compat.v1.logging.info("**** Trainable Variables ****")
            for variable in trainable:
                marker = ", *INIT_FROM_CKPT*" if variable.name in restored_names else ""
                tf.compat.v1.logging.info(" %d name = %s, shape = %s%s",
                                          0 if hvd is None else hvd.rank(),
                                          variable.name, variable.shape, marker)

        if mode == tf.estimator.ModeKeys.TRAIN:
            seq_length = modeling.get_shape_list(input_ids)[1]

            def position_loss(logits, positions):
                # Cross-entropy between the predicted distribution over token
                # positions and the one-hot gold position, averaged over the batch.
                targets = tf.one_hot(positions, depth=seq_length, dtype=tf.float32)
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                return -tf.reduce_mean(tf.reduce_sum(targets * log_probs, axis=-1))

            total_loss = (position_loss(start_logits, features["start_positions"]) +
                          position_loss(end_logits, features["end_positions"])) / 2.0
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                hvd, False, amp, FLAGS.num_accumulation_steps)
            return tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_op)

        if mode == tf.estimator.ModeKeys.PREDICT:
            tf.no_op()
            # Need to call mixed precision graph rewrite if fp16 to enable graph rewrite
            if amp:
                fixed_scale = tf.train.experimental.FixedLossScale(1)
                tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimization.LAMBOptimizer(learning_rate=0.0), fixed_scale)
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    "unique_ids": tf.identity(unique_ids),
                    "start_logits": start_logits,
                    "end_logits": end_logits,
                })

        raise ValueError(
            "Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return model_fn
def input_fn_builder(input_file, batch_size, seq_length, is_training, drop_remainder, hvd=None):
    """Create an ``input_fn`` closure that feeds SQuAD TFRecords to the Estimator.

    Each record is parsed with:
      * a scalar ``unique_ids``;
      * ``input_ids``, ``input_mask`` and ``segment_ids`` of length ``seq_length``;
      * when ``is_training``, also scalar ``start_positions`` and ``end_positions``.

    All int64 features are cast to int32, and examples are batched by
    ``batch_size`` (``drop_remainder`` is forwarded to the batching op).

    Training input is read with 4 parallel readers, sharded across Horovod
    workers when ``hvd`` is given, filtered of read errors, shuffled with a
    buffer of 100, and repeated indefinitely. Prediction input is read once,
    in file order.
    """
    scalar_feature = tf.io.FixedLenFeature([], tf.int64)
    name_to_features = {
        "unique_ids": scalar_feature,
        "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    }
    if is_training:
        for label_name in ("start_positions", "end_positions"):
            name_to_features[label_name] = tf.io.FixedLenFeature([], tf.int64)

    def _decode_record(record, name_to_features):
        """Parse one serialized tf.Example, casting int64 features to int32."""
        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        parsed = tf.parse_single_example(record, name_to_features)
        return {
            key: tf.to_int32(tensor) if tensor.dtype == tf.int64 else tensor
            for key, tensor in parsed.items()
        }

    def input_fn():
        """Build the batched ``tf.data`` dataset."""
        if not is_training:
            # Prediction: no shuffling, and parallel reading doesn't matter.
            dataset = tf.data.TFRecordDataset(input_file)
        else:
            dataset = tf.data.TFRecordDataset(input_file, num_parallel_reads=4)
            if hvd is not None:
                dataset = dataset.shard(hvd.size(), hvd.rank())
            dataset = (dataset
                       .apply(tf.data.experimental.ignore_errors())
                       .shuffle(buffer_size=100)
                       .repeat())
        return dataset.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder))

    return input_fn
# Raw model output for one feature: its unique id plus the per-token start and
# end logits that get_predictions() indexes by token position.
RawResult = collections.namedtuple(
    "RawResult", "unique_id start_logits end_logits")
def get_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length,
                    do_lower_case, version_2_with_negative, verbose_logging,
                    null_score_diff_threshold=None):
    """Turn raw start/end logits into final answer texts for every example.

    Args:
      all_examples: examples exposing ``qas_id`` and ``doc_tokens``.
      all_features: features exposing ``example_index``, ``unique_id``,
        ``tokens``, ``token_to_orig_map`` and ``token_is_max_context``.
      all_results: ``RawResult``-like objects; matched to features by ``unique_id``.
      n_best_size: number of candidate start/end indexes considered per
        feature, and maximum number of n-best entries kept per example.
      max_answer_length: longest allowed answer span, in tokens.
      do_lower_case: forwarded to ``get_final_text``.
      version_2_with_negative: if True, the null ("no answer") prediction is
        also considered (SQuAD 2.0).
      verbose_logging: forwarded to ``get_final_text``.
      null_score_diff_threshold: SQuAD 2.0 only. Predict "" when
        ``null_score - best_non_null_score`` exceeds this value. If None (the
        default), ``FLAGS.null_score_diff_threshold`` is used, falling back to
        0.0 when the flags are unavailable.

    Returns:
      A tuple ``(all_predictions, all_nbest_json, scores_diff_json)`` of
      OrderedDicts keyed by ``qas_id``:
        * the chosen answer text;
        * the n-best list (text, probability, start/end logit);
        * for SQuAD 2.0 only, the null score difference.
    """
    # Resolve the threshold once rather than once per example.
    if version_2_with_negative and null_score_diff_threshold is None:
        try:
            null_score_diff_threshold = FLAGS.null_score_diff_threshold
        except Exception:  # FLAGS may be None or not yet parsed.
            null_score_diff_threshold = 0.0

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {result.unique_id: result for result in all_results}

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit"])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # Keep track of the minimum null score (start+end logit at position 0)
        # across all document slices of this example.
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # If we could have irrelevant answers, get the min score of irrelevant.
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g.
                    # predict that the start of the span is in the question.
                    # We throw out all invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))

        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit))
        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
                tok_text = " ".join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")

                # Clean whitespace.
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
                if final_text in seen_predictions:
                    continue
                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit))

        # If we didn't include the empty option in the n-best, include it.
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="", start_logit=null_start_logit,
                        end_logit=null_end_logit))
        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if best_non_null_entry is None and entry.text:
                best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        elif best_non_null_entry is None:
            # Only the null answer survived filtering (previously this crashed
            # with AttributeError on None). The null answer is the prediction;
            # the recorded difference treats the missing non-null candidate as
            # having a total logit of 0.
            scores_diff_json[example.qas_id] = score_null
            all_predictions[example.qas_id] = ""
        else:
            # Predict "" iff the null score - the score of best non-null > threshold.
            score_diff = score_null - best_non_null_entry.start_logit - (
                best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text

        all_nbest_json[example.qas_id] = nbest_json

    return all_predictions, all_nbest_json, scores_diff_json
def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file,
                      version_2_with_negative, verbose_logging):
    """Compute final predictions via ``get_predictions`` and write them as JSON.

    Files written:
      * ``output_prediction_file``: the chosen answer per ``qas_id``.
      * ``output_nbest_file``: the n-best lists.
      * ``output_null_log_odds_file``: for SQuAD 2.0 only
        (``version_2_with_negative``), the null score differences.

    Each file contains one JSON object indented by 4 spaces, followed by a newline.
    """
    tf.compat.v1.logging.info("Writing predictions to: %s", output_prediction_file)
    tf.compat.v1.logging.info("Writing nbest to: %s", output_nbest_file)

    predictions, nbest_json, null_odds_json = get_predictions(
        all_examples, all_features, all_results, n_best_size, max_answer_length,
        do_lower_case, version_2_with_negative, verbose_logging)

    destinations = [
        (output_prediction_file, predictions),
        (output_nbest_file, nbest_json),
    ]
    if version_2_with_negative:
        destinations.append((output_null_log_odds_file, null_odds_json))

    for path, payload in destinations:
        with tf.io.gfile.GFile(path, "w") as writer:
            writer.write(json.dumps(payload, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging):
    """Project a normalized predicted span back onto the original text.

    ``pred_text`` is the detokenized WordPiece prediction, already normalized
    by the tokenizer (e.g. lower-cased, accents stripped). ``orig_text`` is
    the whitespace-tokenized original span covering the prediction, which may
    contain extra characters. For example, with pred ``steve smith`` and
    original ``Steve Smith's``, the wanted result is ``Steve Smith``.

    Heuristic:
      1. Re-tokenize ``orig_text`` with ``BasicTokenizer``.
      2. Find ``pred_text`` in the re-tokenized string.
      3. If both strings have equal length once spaces are removed, assume a
         one-to-one character alignment and map the match back onto
         ``orig_text``.

    Whenever a step fails, ``orig_text`` is returned unchanged, and the reason
    is logged if ``verbose_logging`` is set.
    """

    def _strip_spaces(text):
        # Drop " " characters, and map each index of the stripped string back
        # to its index in ``text``.
        kept = [(pos, ch) for pos, ch in enumerate(text) if ch != " "]
        stripped = "".join(ch for _, ch in kept)
        index_map = collections.OrderedDict(
            (ns_pos, pos) for ns_pos, (pos, _) in enumerate(kept))
        return (stripped, index_map)

    tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            tf.compat.v1.logging.info(
                "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
    tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            tf.compat.v1.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
                                      orig_ns_text, tok_ns_text)
        return orig_text

    # Invert the tokenized map: index in tok_text -> index in the stripped text.
    tok_s_to_ns_map = {s_pos: ns_pos for ns_pos, s_pos in six.iteritems(tok_ns_to_s_map)}

    def _to_orig(tok_position):
        # Map an index in tok_text to an index in orig_text, or None if unmappable.
        ns_position = tok_s_to_ns_map.get(tok_position)
        if ns_position is None:
            return None
        return orig_ns_to_s_map.get(ns_position)

    orig_start_position = _to_orig(start_position)
    if orig_start_position is None:
        if verbose_logging:
            tf.compat.v1.logging.info("Couldn't map start position")
        return orig_text

    orig_end_position = _to_orig(end_position)
    if orig_end_position is None:
        if verbose_logging:
            tf.compat.v1.logging.info("Couldn't map end position")
        return orig_text

    return orig_text[orig_start_position:(orig_end_position + 1)]
- def _get_best_indexes(logits, n_best_size):
- """Get the n-best logits from a list."""
- index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
- best_indexes = []
- for i in range(len(index_and_score)):
- if i >= n_best_size:
- break
- best_indexes.append(index_and_score[i][0])
- return best_indexes
- def _compute_softmax(scores):
- """Compute softmax probability over raw logits."""
- if not scores:
- return []
- max_score = None
- for score in scores:
- if max_score is None or score > max_score:
- max_score = score
- exp_scores = []
- total_sum = 0.0
- for score in scores:
- x = math.exp(score - max_score)
- exp_scores.append(x)
- total_sum += x
- probs = []
- for score in exp_scores:
- probs.append(score / total_sum)
- return probs
def validate_flags_or_throw(bert_config):
    """Validate the parsed FLAGS against each other and against ``bert_config``.

    Checks:
      * Casing consistency: delegated to
        ``tokenization.validate_case_matches_checkpoint``.
      * At least one of ``do_train``, ``do_predict`` or ``export_triton`` is set.
      * The input file for each requested phase is given.
      * ``max_seq_length`` fits the model's position embeddings.
      * ``max_seq_length`` leaves room for the query plus 3 extra tokens.

    Raises:
      ValueError: if any of the checks performed here fails.
    """
    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.export_triton:
        # The message previously named a nonexistent `export_SavedModel` flag.
        raise ValueError(
            "At least one of `do_train` or `do_predict` or `export_triton` must be True.")

    if FLAGS.do_train and not FLAGS.train_file:
        raise ValueError(
            "If `do_train` is True, then `train_file` must be specified.")

    if FLAGS.do_predict and not FLAGS.predict_file:
        raise ValueError(
            "If `do_predict` is True, then `predict_file` must be specified.")

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
        raise ValueError(
            "The max_seq_length (%d) must be greater than max_query_length "
            "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))
def export_model(estimator, export_dir, init_checkpoint):
    """Export a checkpoint as a SavedModel laid out for a Triton model repository.

    Produced layout (names and numbers come from ``FLAGS``)::

        <export_dir>/triton_models/<triton_model_name>/config.pbtxt
        <export_dir>/triton_models/<triton_model_name>/<triton_model_version>/model.savedmodel

    An existing ``model.savedmodel`` directory or ``config.pbtxt`` is replaced
    only when ``FLAGS.triton_model_overwrite`` is true. Otherwise an error is
    printed and the function returns without writing that artifact.

    Args:
      estimator: estimator whose ``export_savedmodel`` produces the SavedModel.
      export_dir: base directory for both the raw export and the Triton tree.
      init_checkpoint: checkpoint whose weights are exported.
    """
    def serving_input_fn():
        # Raw receiver: clients feed these int32 tensors directly
        # (no tf.Example parsing).
        unique_ids = tf.compat.v1.placeholder(tf.int32, [None], name='unique_ids')
        input_ids = tf.compat.v1.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
        input_mask = tf.compat.v1.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
        segment_ids = tf.compat.v1.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
        return tf.estimator.export.build_raw_serving_input_receiver_fn({
            'unique_ids': unique_ids,
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': segment_ids,
        })()

    saved_dir = estimator.export_savedmodel(
        export_dir,
        serving_input_fn,
        assets_extra=None,
        as_text=False,
        checkpoint_path=init_checkpoint,
        strip_default_attrs=False)
    # The estimator export API returns the directory as bytes. Normalise it to
    # str so it can be combined with the str paths below.
    saved_dir = os.fsdecode(saved_dir)

    model_name = FLAGS.triton_model_name
    model_folder = os.path.join(export_dir, "triton_models", model_name)
    version_folder = os.path.join(model_folder, str(FLAGS.triton_model_version))
    final_model_folder = os.path.join(version_folder, "model.savedmodel")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(version_folder, exist_ok=True)

    if not os.path.exists(final_model_folder):
        os.rename(saved_dir, final_model_folder)
        print("Model saved to dir", final_model_folder)
    elif FLAGS.triton_model_overwrite:
        shutil.rmtree(final_model_folder)
        os.rename(saved_dir, final_model_folder)
        print("WARNING: Existing model was overwritten. Model dir: {}".format(final_model_folder))
    else:
        print("ERROR: Could not save Triton model. Folder already exists. Use '--triton_model_overwrite=True' if you would like to overwrite an existing model. Model dir: {}".format(final_model_folder))
        return

    # Now build the config for Triton. Check to make sure we can overwrite it,
    # if it exists.
    config_filename = os.path.join(model_folder, "config.pbtxt")
    if os.path.exists(config_filename) and not FLAGS.triton_model_overwrite:
        print("ERROR: Could not save Triton model config. Config file already exists. Use '--triton_model_overwrite=True' if you would like to overwrite an existing model config. Model config: {}".format(config_filename))
        return

    optimization_str = ""
    if FLAGS.amp:
        optimization_str = r"""
optimization {
execution_accelerators
{
gpu_execution_accelerator :
[ {
name : "auto_mixed_precision"
} ]
}
}"""

    config_template = r"""
name: "{model_name}"
platform: "tensorflow_savedmodel"
max_batch_size: {max_batch_size}
{optimization_str}
input [
{{
name: "unique_ids"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: {{ shape: [ ] }}
}},
{{
name: "segment_ids"
data_type: TYPE_INT32
dims: {seq_length}
}},
{{
name: "input_ids"
data_type: TYPE_INT32
dims: {seq_length}
}},
{{
name: "input_mask"
data_type: TYPE_INT32
dims: {seq_length}
}}
]
output [
{{
name: "end_logits"
data_type: TYPE_FP32
dims: {seq_length}
}},
{{
name: "start_logits"
data_type: TYPE_FP32
dims: {seq_length}
}}
]
{dynamic_batching}
instance_group [
{{
count: {engine_count}
}}
]"""

    batching_str = ""
    max_batch_size = FLAGS.triton_max_batch_size
    if FLAGS.triton_dyn_batching_delay > 0:
        # Prefer full and half-full batches. Drop non-positive sizes and
        # duplicates. With max_batch_size == 1, half of it would otherwise be
        # an invalid preferred batch size of 0.
        pref_batch_sizes = sorted(
            size for size in {max_batch_size // 2, max_batch_size} if size > 0)
        # The delay flag is scaled by 1000 into the microsecond field.
        batching_str = r"""
dynamic_batching {{
preferred_batch_size: [{0}]
max_queue_delay_microseconds: {1}
}}""".format(", ".join(str(size) for size in pref_batch_sizes),
             int(FLAGS.triton_dyn_batching_delay * 1000.0))

    config_values = {
        "model_name": model_name,
        "max_batch_size": max_batch_size,
        "seq_length": FLAGS.max_seq_length,
        "dynamic_batching": batching_str,
        "engine_count": FLAGS.triton_engine_count,
        "optimization_str": optimization_str,
    }

    with open(config_filename, "w", encoding="utf-8") as config_file:
        config_file.write(config_template.format_map(config_values))
def _summarize_eval_latencies(time_list, batch_size):
    """Summarise per-batch inference timings collected by ``LogEvalRunHook``.

    The slowest 1% of batches is dropped as init/warm-up outliers. At least
    one sample is always kept, so very short runs (e.g. a single batch) no
    longer divide by zero or take ``max()`` of an empty slice. For two or more
    batches the numbers match the previous inline computation.

    Args:
      time_list: per-batch durations. Presumably seconds, since the caller
        reports them multiplied by 1000 as milliseconds.
      batch_size: prediction batch size, used to convert batches to sentences.

    Returns:
      ``None`` if ``time_list`` is empty. Otherwise a dict with ``total_time``,
      ``num_sentences``, ``throughput``, ``avg`` and ``cf_50`` / ``cf_90`` /
      ``cf_95`` / ``cf_99`` / ``cf_100``.
    """
    times = sorted(time_list)
    if not times:
        return None

    def fastest(fraction):
        # The fastest `fraction` of batches, never fewer than one.
        return times[:max(int(len(times) * fraction), 1)]

    kept = fastest(0.99)
    total_time = sum(kept)
    num_sentences = len(kept) * batch_size
    return {
        "total_time": total_time,
        "num_sentences": num_sentences,
        "throughput": num_sentences * 1.0 / total_time if total_time > 0 else 0.0,
        "avg": np.mean(times),
        "cf_50": max(fastest(0.50)),
        "cf_90": max(fastest(0.90)),
        "cf_95": max(fastest(0.95)),
        "cf_99": max(kept),
        "cf_100": times[-1],
    }


def _parse_eval_scores(eval_out):
    """Extract ``(exact_match, f1)`` from the evaluation script's stdout.

    The output is first parsed as a JSON object. Its ``exact_match`` key (or
    ``exact``, presumably the SQuAD v2.0 script's name — confirm) and ``f1``
    key are read. If that fails, falls back to the legacy positional parsing
    of ``str(eval_out)``.
    """
    import json

    text = eval_out.decode("utf-8", "replace") if isinstance(eval_out, bytes) else str(eval_out)
    try:
        results = json.loads(text.strip())
    except ValueError:
        results = None
    if (isinstance(results, dict) and "f1" in results
            and ("exact_match" in results or "exact" in results)):
        exact = results["exact_match"] if "exact_match" in results else results["exact"]
        return float(exact), float(results["f1"])

    # Legacy fallback: positional parsing of the stringified output.
    scores = str(eval_out).strip()
    exact_match = float(scores.split(":")[1].split(",")[0])
    f1 = float(scores.split(":")[2].split("}")[0])
    return exact_match, f1


def main(_):
    """Run SQuAD fine-tuning, Triton export and/or prediction as selected by FLAGS.

    Phases, each gated by a flag and run in this order:
      * training (``do_train``);
      * Triton SavedModel export (``export_triton``, master rank only);
      * prediction plus optional evaluation (``do_predict``, master rank only).

    With ``FLAGS.horovod``, the training examples are split into contiguous
    per-rank shards.

    Args:
      _: unused argv forwarded by ``tf.app.run``.
    """
    setup_xla_flags()
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    dllogging = utils.dllogger_class.dllogger_class(FLAGS.dllog_path)

    if FLAGS.horovod:
        hvd.init()

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    validate_flags_or_throw(bert_config)
    tf.io.gfile.makedirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    master_process = True
    training_hooks = []
    global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
    hvd_rank = 0
    config = tf.compat.v1.ConfigProto()
    learning_rate = FLAGS.learning_rate

    if FLAGS.horovod:
        tf.compat.v1.logging.info("Multi-GPU training with TF Horovod")
        tf.compat.v1.logging.info("hvd.size() = %d hvd.rank() = %d", hvd.size(), hvd.rank())
        global_batch_size = FLAGS.train_batch_size * hvd.size() * FLAGS.num_accumulation_steps
        # Learning rate is scaled linearly with the number of workers.
        learning_rate = learning_rate * hvd.size()
        master_process = (hvd.rank() == 0)
        hvd_rank = hvd.rank()
        # Each process only sees the GPU matching its local rank.
        config.gpu_options.visible_device_list = str(hvd.local_rank())
        set_affinity(hvd.local_rank())
        if hvd.size() > 1:
            # Start every rank from rank 0's variable values.
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
    if FLAGS.use_xla:
        config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
    if FLAGS.amp:
        tf.enable_resource_variables()

    # Only the master rank writes checkpoints and summaries.
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir if master_process else None,
        session_config=config,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None,
        save_summary_steps=FLAGS.save_checkpoints_steps if master_process else None,
        log_step_count_steps=FLAGS.display_loss_steps,
        keep_checkpoint_max=1)

    if master_process:
        tf.compat.v1.logging.info("***** Configuration *****")
        for key in FLAGS.__flags.keys():
            tf.compat.v1.logging.info(' {}: {}'.format(key, getattr(FLAGS, key)))
        tf.compat.v1.logging.info("**************************")

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    # Keep a direct reference so its timings can be read after training,
    # independent of hook ordering.
    train_log_hook = LogTrainRunHook(global_batch_size, hvd_rank, FLAGS.save_checkpoints_steps)
    training_hooks.append(train_log_hook)

    # Prepare Training Data
    if FLAGS.do_train:
        train_examples = read_squad_examples(
            input_file=FLAGS.train_file, is_training=True,
            version_2_with_negative=FLAGS.version_2_with_negative)
        num_train_steps = int(
            len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

        # Pre-shuffle the input to avoid having to make a very large shuffle
        # buffer in the `input_fn`.
        rng = random.Random(12345)
        rng.shuffle(train_examples)

        start_index = 0
        end_index = len(train_examples)
        tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record")]

        if FLAGS.horovod:
            # One record file per rank. Each rank converts a contiguous shard,
            # and the first `remainder` ranks take one extra example.
            tmp_filenames = [os.path.join(FLAGS.output_dir, "train.tf_record{}".format(i)) for i in range(hvd.size())]
            num_examples_per_rank = len(train_examples) // hvd.size()
            remainder = len(train_examples) % hvd.size()
            if hvd.rank() < remainder:
                start_index = hvd.rank() * (num_examples_per_rank + 1)
                end_index = start_index + num_examples_per_rank + 1
            else:
                start_index = hvd.rank() * num_examples_per_rank + remainder
                end_index = start_index + num_examples_per_rank

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        hvd=None if not FLAGS.horovod else hvd,
        amp=FLAGS.amp)

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config)

    if FLAGS.do_train:
        # We write to a temporary file to avoid storing very large constant
        # tensors in memory.
        train_writer = FeatureWriter(
            filename=tmp_filenames[hvd_rank],
            is_training=True)
        convert_examples_to_features(
            examples=train_examples[start_index:end_index],
            tokenizer=tokenizer,
            max_seq_length=FLAGS.max_seq_length,
            doc_stride=FLAGS.doc_stride,
            max_query_length=FLAGS.max_query_length,
            is_training=True,
            output_fn=train_writer.process_feature,
            verbose_logging=FLAGS.verbose_logging)
        train_writer.close()

        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info(" Num orig examples = %d", end_index - start_index)
        tf.compat.v1.logging.info(" Num split examples = %d", train_writer.num_features)
        tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info(" Num steps = %d", num_train_steps)
        tf.compat.v1.logging.info(" LR = %f", learning_rate)
        del train_examples

        train_input_fn = input_fn_builder(
            input_file=tmp_filenames,
            batch_size=FLAGS.train_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
            hvd=None if not FLAGS.horovod else hvd)

        train_start_time = time.time()
        estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=num_train_steps)
        train_time_elapsed = time.time() - train_start_time
        train_time_wo_overhead = train_log_hook.total_time
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
        ss_sentences_per_second = (num_train_steps - train_log_hook.skipped) * global_batch_size * 1.0 / train_time_wo_overhead

        if master_process:
            tf.compat.v1.logging.info("-----------------------------")
            tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                                      num_train_steps * global_batch_size)
            tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                                      (num_train_steps - train_log_hook.skipped) * global_batch_size)
            tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
            tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
            dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
            tf.compat.v1.logging.info("-----------------------------")

    if FLAGS.export_triton and master_process:
        export_model(estimator, FLAGS.output_dir, FLAGS.init_checkpoint)

    if FLAGS.do_predict and master_process:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file, is_training=False,
            version_2_with_negative=FLAGS.version_2_with_negative)

        # Perform evaluation on subset, useful for profiling
        if FLAGS.num_eval_iterations is not None:
            eval_examples = eval_examples[:FLAGS.num_eval_iterations * FLAGS.predict_batch_size]

        eval_writer = FeatureWriter(
            filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
            is_training=False)
        eval_features = []

        def append_feature(feature):
            # Keep features in memory for write_predictions and also
            # serialise them for the input pipeline.
            eval_features.append(feature)
            eval_writer.process_feature(feature)

        convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=FLAGS.max_seq_length,
            doc_stride=FLAGS.doc_stride,
            max_query_length=FLAGS.max_query_length,
            is_training=False,
            output_fn=append_feature,
            verbose_logging=FLAGS.verbose_logging)
        eval_writer.close()

        tf.compat.v1.logging.info("***** Running predictions *****")
        tf.compat.v1.logging.info(" Num orig examples = %d", len(eval_examples))
        tf.compat.v1.logging.info(" Num split examples = %d", len(eval_features))
        tf.compat.v1.logging.info(" Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = input_fn_builder(
            input_file=eval_writer.filename,
            batch_size=FLAGS.predict_batch_size,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False)

        all_results = []
        eval_log_hook = LogEvalRunHook(FLAGS.predict_batch_size)
        eval_hooks = [eval_log_hook]
        eval_start_time = time.time()
        for result in estimator.predict(
                predict_input_fn, yield_single_examples=True, hooks=eval_hooks):
            if len(all_results) % 1000 == 0:
                tf.compat.v1.logging.info("Processing example: %d" % (len(all_results)))
            unique_id = int(result["unique_ids"])
            start_logits = [float(x) for x in result["start_logits"].flat]
            end_logits = [float(x) for x in result["end_logits"].flat]
            all_results.append(
                RawResult(
                    unique_id=unique_id,
                    start_logits=start_logits,
                    end_logits=end_logits))
        eval_time_elapsed = time.time() - eval_start_time

        stats = _summarize_eval_latencies(eval_log_hook.time_list, FLAGS.predict_batch_size)

        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                                  eval_log_hook.count * FLAGS.predict_batch_size)
        if stats is None:
            tf.compat.v1.logging.info("No per-batch inference timings were recorded; skipping latency statistics.")
        else:
            tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", stats["total_time"],
                                      stats["num_sentences"])
            tf.compat.v1.logging.info("Summary Inference Statistics")
            tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
            tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
            tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
            tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", stats["cf_50"] * 1000)
            tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", stats["cf_90"] * 1000)
            tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", stats["cf_95"] * 1000)
            tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", stats["cf_99"] * 1000)
            tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", stats["cf_100"] * 1000)
            tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", stats["avg"] * 1000)
            tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", stats["throughput"])
            dllogging.logger.log(step=(), data={"throughput_val": stats["throughput"]}, verbosity=Verbosity.DEFAULT)
        tf.compat.v1.logging.info("-----------------------------")

        output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
        output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

        write_predictions(eval_examples, eval_features, all_results,
                          FLAGS.n_best_size, FLAGS.max_answer_length,
                          FLAGS.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file,
                          FLAGS.version_2_with_negative, FLAGS.verbose_logging)

        if FLAGS.eval_script:
            import sys
            import subprocess
            eval_out = subprocess.check_output([sys.executable, FLAGS.eval_script,
                                                FLAGS.predict_file, output_prediction_file])
            exact_match, f1 = _parse_eval_scores(eval_out)
            dllogging.logger.log(step=(), data={"f1": f1}, verbosity=Verbosity.DEFAULT)
            dllogging.logger.log(step=(), data={"exact_match": exact_match}, verbosity=Verbosity.DEFAULT)
            print(str(eval_out))
if __name__ == "__main__":
    # Parse command-line flags into the module-level FLAGS global.
    # validate_flags_or_throw(), export_model() and main() all read it, so it
    # must be assigned before tf.app.run() transfers control to main().
    FLAGS = extract_run_squad_flags()
    # With no explicit `main` argument, tf.app.run() calls this module's main().
    tf.app.run()
|