Quellcode durchsuchen

[BERT/Paddle] fix some issues on throughput measurement

Shijie Wang vor 2 Jahren
Ursprung
Commit
eb4b0ab2c6

+ 10 - 7
PaddlePaddle/LanguageModeling/BERT/program.py

@@ -214,6 +214,7 @@ def run(exe,
     last_step = args.last_step_of_checkpoint
     train_iter = 0
     epoch = 0
+    train_time_raw = 0
     if progress is None:
         progress = dict()
     else:
@@ -229,9 +230,12 @@ def run(exe,
                 f"Only {max_steps - last_step} steps will be performed in this run due to the limit of --max-steps."
             )
         else:
-            max_steps = args.steps_this_run + last_step
+            steps_this_run = args.steps_this_run
+            if args.benchmark:
+                steps_this_run = min(steps_this_run, args.benchmark_warmup_steps + args.benchmark_steps)
+            max_steps = steps_this_run + last_step
             logging.warning(
-                f"{args.steps_this_run} steps will be performed in this run.")
+                f"{steps_this_run} steps will be performed in this run.")
 
     total_samples = 0
     raw_train_start = time.time()
@@ -272,6 +276,7 @@ def run(exe,
 
             if train_iter % (save_steps * gradient_merge_steps
                              ) == 0 or global_step >= max_steps:
+                train_time_raw = time.time() - raw_train_start
                 if trainer_id == 0:
                     model_path = os.path.join(
                         args.output_dir, args.bert_model, "phase1"
@@ -287,9 +292,7 @@ def run(exe,
                     if len(most_recent_ckpts_paths) > 3:
                         ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                         shutil.rmtree(ckpt_to_be_removed)
-            if (global_step >= max_steps) or (
-                    args.benchmark and global_step >=
-                    args.benchmark_steps + args.benchmark_warmup_steps):
-                train_time_raw = time.time() - raw_train_start
-                return global_step, loss_return[0].item(), train_time_raw
+            if global_step >= max_steps:
+                actual_steps_this_run = global_step - last_step
+                return global_step, actual_steps_this_run, loss_return[0].item(), train_time_raw
         epoch += 1

+ 3 - 3
PaddlePaddle/LanguageModeling/BERT/run_pretraining.py

@@ -86,7 +86,7 @@ def main():
     if args.amp:
         optimizer.amp_init(device)
 
-    global_step, final_loss, train_time_raw = program.run(
+    global_step, actual_steps_this_run, final_loss, train_time_raw = program.run(
         exe, main_program, args, lr_scheduler, loss, train_dataloader,
         progress)
 
@@ -94,10 +94,10 @@ def main():
         e2e_time = time.time() - now
         if args.benchmark:
             training_perf = args.batch_size * args.gradient_merge_steps * (
-                global_step - args.benchmark_warmup_steps
+                actual_steps_this_run - args.benchmark_warmup_steps
             ) * get_num_trainers() / train_time_raw
         else:
-            training_perf = args.batch_size * args.gradient_merge_steps * global_step * get_num_trainers(
+            training_perf = args.batch_size * args.gradient_merge_steps * actual_steps_this_run * get_num_trainers(
             ) / train_time_raw
         dllogger.log(step=tuple(),
                      data={