|
@@ -104,7 +104,7 @@ flags.DEFINE_integer("iterations_per_loop", 1000,
|
|
|
flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
|
|
flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
|
|
|
|
|
|
|
|
flags.DEFINE_integer("num_accumulation_steps", 1,
|
|
flags.DEFINE_integer("num_accumulation_steps", 1,
|
|
|
- "Number of accumulation steps before gradient update."
|
|
|
|
|
|
|
+ "Number of accumulation steps before gradient update."
|
|
|
"Global batch size = num_accumulation_steps * train_batch_size")
|
|
"Global batch size = num_accumulation_steps * train_batch_size")
|
|
|
|
|
|
|
|
flags.DEFINE_bool("allreduce_post_accumulation", False, "Whether to all reduce after accumulation of N steps or after each step")
|
|
flags.DEFINE_bool("allreduce_post_accumulation", False, "Whether to all reduce after accumulation of N steps or after each step")
|
|
@@ -146,6 +146,7 @@ class _LogSessionRunHook(tf.estimator.SessionRunHook):
|
|
|
self.step_time = 0.0 # time taken per step
|
|
self.step_time = 0.0 # time taken per step
|
|
|
self.init_global_step = session.run(tf.train.get_global_step()) # training starts at init_global_step
|
|
self.init_global_step = session.run(tf.train.get_global_step()) # training starts at init_global_step
|
|
|
self.skipped = 0
|
|
self.skipped = 0
|
|
|
|
|
+ self.final_loss = 0
|
|
|
|
|
|
|
|
def before_run(self, run_context):
|
|
def before_run(self, run_context):
|
|
|
self.t0 = time.time()
|
|
self.t0 = time.time()
|
|
@@ -246,6 +247,7 @@ class _LogSessionRunHook(tf.estimator.SessionRunHook):
|
|
|
self.count = 0
|
|
self.count = 0
|
|
|
self.loss = 0.0
|
|
self.loss = 0.0
|
|
|
self.all_count = 0
|
|
self.all_count = 0
|
|
|
|
|
+ self.final_loss = avg_loss_step
|
|
|
|
|
|
|
|
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
|
|
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
|
|
|
num_train_steps, num_warmup_steps,
|
|
num_train_steps, num_warmup_steps,
|
|
@@ -280,8 +282,8 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
|
|
|
|
|
|
|
|
(masked_lm_loss,
|
|
(masked_lm_loss,
|
|
|
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
|
|
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
|
|
|
- bert_config, model.get_sequence_output(), model.get_embedding_table(),
|
|
|
|
|
- masked_lm_positions, masked_lm_ids,
|
|
|
|
|
|
|
+ bert_config, model.get_sequence_output(), model.get_embedding_table(),
|
|
|
|
|
+ masked_lm_positions, masked_lm_ids,
|
|
|
masked_lm_weights)
|
|
masked_lm_weights)
|
|
|
|
|
|
|
|
(next_sentence_loss, next_sentence_example_loss,
|
|
(next_sentence_loss, next_sentence_example_loss,
|
|
@@ -582,7 +584,7 @@ def main(_):
|
|
|
tf.compat.v1.logging.info("**************************")
|
|
tf.compat.v1.logging.info("**************************")
|
|
|
|
|
|
|
|
# config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
|
# config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
|
|
- if FLAGS.use_xla:
|
|
|
|
|
|
|
+ if FLAGS.use_xla:
|
|
|
config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
|
|
config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
|
|
|
config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
|
|
config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
|
|
|
if FLAGS.amp:
|
|
if FLAGS.amp:
|
|
@@ -620,7 +622,8 @@ def main(_):
|
|
|
training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
|
|
training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
|
|
|
if (not FLAGS.horovod or hvd.rank() == 0):
|
|
if (not FLAGS.horovod or hvd.rank() == 0):
|
|
|
global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps if not FLAGS.horovod else FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
|
|
global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps if not FLAGS.horovod else FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
|
|
|
- training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss))
|
|
|
|
|
|
|
+ log_hook = _LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss)
|
|
|
|
|
+ training_hooks.append(log_hook)
|
|
|
|
|
|
|
|
tf.compat.v1.logging.info("***** Running training *****")
|
|
tf.compat.v1.logging.info("***** Running training *****")
|
|
|
tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
|
tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
|
@@ -649,6 +652,8 @@ def main(_):
|
|
|
tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
|
|
tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
|
|
|
tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
|
|
tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
|
|
|
dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
|
|
dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
|
|
|
|
|
+ if log_hook.final_loss != 0:
|
|
|
|
|
+ dllogging.logger.log(step=(), data={"total_loss": log_hook.final_loss}, verbosity=Verbosity.DEFAULT)
|
|
|
tf.compat.v1.logging.info("-----------------------------")
|
|
tf.compat.v1.logging.info("-----------------------------")
|
|
|
|
|
|
|
|
if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0):
|
|
if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0):
|