Просмотр исходного кода

[BERT/TF] Add final loss metrics

Bobby Chen 3 лет назад
Родитель
Сommit
71aab7bda8
1 измененных файлов с 10 добавлено и 5 удалено
  1. 10 5
      TensorFlow/LanguageModeling/BERT/run_pretraining.py

+ 10 - 5
TensorFlow/LanguageModeling/BERT/run_pretraining.py

@@ -104,7 +104,7 @@ flags.DEFINE_integer("iterations_per_loop", 1000,
 flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
 
 flags.DEFINE_integer("num_accumulation_steps", 1,
-                     "Number of accumulation steps before gradient update." 
+                     "Number of accumulation steps before gradient update."
                       "Global batch size = num_accumulation_steps * train_batch_size")
 
 flags.DEFINE_bool("allreduce_post_accumulation", False, "Whether to all reduce after accumulation of N steps or after each step")
@@ -146,6 +146,7 @@ class _LogSessionRunHook(tf.estimator.SessionRunHook):
     self.step_time = 0.0 # time taken per step
     self.init_global_step = session.run(tf.train.get_global_step()) # training starts at init_global_step
     self.skipped = 0
+    self.final_loss = 0
 
   def before_run(self, run_context):
     self.t0 = time.time()
@@ -246,6 +247,7 @@ class _LogSessionRunHook(tf.estimator.SessionRunHook):
             self.count = 0
             self.loss = 0.0
             self.all_count = 0
+            self.final_loss = avg_loss_step
 
 def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                      num_train_steps, num_warmup_steps,
@@ -280,8 +282,8 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
 
     (masked_lm_loss,
      masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
-         bert_config, model.get_sequence_output(), model.get_embedding_table(), 
-         masked_lm_positions, masked_lm_ids, 
+         bert_config, model.get_sequence_output(), model.get_embedding_table(),
+         masked_lm_positions, masked_lm_ids,
          masked_lm_weights)
 
     (next_sentence_loss, next_sentence_example_loss,
@@ -582,7 +584,7 @@ def main(_):
       tf.compat.v1.logging.info("**************************")
 
 #    config.gpu_options.per_process_gpu_memory_fraction = 0.7
-  if FLAGS.use_xla: 
+  if FLAGS.use_xla:
       config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
       config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.NO_MEM_OPT
       if FLAGS.amp:
@@ -620,7 +622,8 @@ def main(_):
       training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))
     if (not FLAGS.horovod or hvd.rank() == 0):
       global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps if not FLAGS.horovod else FLAGS.train_batch_size * FLAGS.num_accumulation_steps * hvd.size()
-      training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss))
+      log_hook = _LogSessionRunHook(global_batch_size, FLAGS.num_accumulation_steps, dllogging, FLAGS.display_loss_steps, FLAGS.save_checkpoints_steps, FLAGS.report_loss)
+      training_hooks.append(log_hook)
 
     tf.compat.v1.logging.info("***** Running training *****")
     tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
@@ -649,6 +652,8 @@ def main(_):
         tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
         tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
         dllogging.logger.log(step=(), data={"throughput_train": ss_sentences_per_second}, verbosity=Verbosity.DEFAULT)
+        if log_hook.final_loss != 0:
+          dllogging.logger.log(step=(), data={"total_loss": log_hook.final_loss}, verbosity=Verbosity.DEFAULT)
         tf.compat.v1.logging.info("-----------------------------")
 
   if FLAGS.do_eval and (not FLAGS.horovod or hvd.rank() == 0):