|
@@ -214,6 +214,7 @@ def run(exe,
|
|
|
last_step = args.last_step_of_checkpoint
|
|
last_step = args.last_step_of_checkpoint
|
|
|
train_iter = 0
|
|
train_iter = 0
|
|
|
epoch = 0
|
|
epoch = 0
|
|
|
|
|
+ train_time_raw = 0
|
|
|
if progress is None:
|
|
if progress is None:
|
|
|
progress = dict()
|
|
progress = dict()
|
|
|
else:
|
|
else:
|
|
@@ -229,9 +230,12 @@ def run(exe,
|
|
|
f"Only {max_steps - last_step} steps will be performed in this run due to the limit of --max-steps."
|
|
f"Only {max_steps - last_step} steps will be performed in this run due to the limit of --max-steps."
|
|
|
)
|
|
)
|
|
|
else:
|
|
else:
|
|
|
- max_steps = args.steps_this_run + last_step
|
|
|
|
|
|
|
+ steps_this_run = args.steps_this_run
|
|
|
|
|
+ if args.benchmark:
|
|
|
|
|
+ steps_this_run = min(steps_this_run, args.benchmark_warmup_steps + args.benchmark_steps)
|
|
|
|
|
+ max_steps = steps_this_run + last_step
|
|
|
logging.warning(
|
|
logging.warning(
|
|
|
- f"{args.steps_this_run} steps will be performed in this run.")
|
|
|
|
|
|
|
+ f"{steps_this_run} steps will be performed in this run.")
|
|
|
|
|
|
|
|
total_samples = 0
|
|
total_samples = 0
|
|
|
raw_train_start = time.time()
|
|
raw_train_start = time.time()
|
|
@@ -272,6 +276,7 @@ def run(exe,
|
|
|
|
|
|
|
|
if train_iter % (save_steps * gradient_merge_steps
|
|
if train_iter % (save_steps * gradient_merge_steps
|
|
|
) == 0 or global_step >= max_steps:
|
|
) == 0 or global_step >= max_steps:
|
|
|
|
|
+ train_time_raw = time.time() - raw_train_start
|
|
|
if trainer_id == 0:
|
|
if trainer_id == 0:
|
|
|
model_path = os.path.join(
|
|
model_path = os.path.join(
|
|
|
args.output_dir, args.bert_model, "phase1"
|
|
args.output_dir, args.bert_model, "phase1"
|
|
@@ -287,9 +292,7 @@ def run(exe,
|
|
|
if len(most_recent_ckpts_paths) > 3:
|
|
if len(most_recent_ckpts_paths) > 3:
|
|
|
ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
|
|
ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
|
|
|
shutil.rmtree(ckpt_to_be_removed)
|
|
shutil.rmtree(ckpt_to_be_removed)
|
|
|
- if (global_step >= max_steps) or (
|
|
|
|
|
- args.benchmark and global_step >=
|
|
|
|
|
- args.benchmark_steps + args.benchmark_warmup_steps):
|
|
|
|
|
- train_time_raw = time.time() - raw_train_start
|
|
|
|
|
- return global_step, loss_return[0].item(), train_time_raw
|
|
|
|
|
|
|
+ if global_step >= max_steps:
|
|
|
|
|
+ actual_steps_this_run = global_step - last_step
|
|
|
|
|
+ return global_step, actual_steps_this_run, loss_return[0].item(), train_time_raw
|
|
|
epoch += 1
|
|
epoch += 1
|