|
|
@@ -434,8 +434,8 @@ def main():
|
|
|
|
|
|
metrics.finish_iter() # done accumulating
|
|
|
if iters_all % args.step_logs_interval == 0:
|
|
|
- logger.log((epoch, iter_, iters_num), metrics,
|
|
|
- scope='train', tb_iter=iters_all)
|
|
|
+ logger.log((epoch, iter_, iters_num), metrics, scope='train',
|
|
|
+ tb_iter=iters_all, flush_log=True)
|
|
|
|
|
|
assert is_last_accum_step
|
|
|
metrics.finish_epoch()
|
|
|
@@ -443,7 +443,8 @@ def main():
|
|
|
|
|
|
if epoch % args.validation_interval == 0:
|
|
|
validate(args, gen, mel_spec, mpd, msd, val_loader, val_metrics)
|
|
|
- logger.log((epoch,), val_metrics, scope='val', tb_iter=iters_all)
|
|
|
+ logger.log((epoch,), val_metrics, scope='val', tb_iter=iters_all,
|
|
|
+ flush_log=True)
|
|
|
|
|
|
# validation samples
|
|
|
if epoch % args.samples_interval == 0 and args.local_rank == 0:
|
|
|
@@ -477,6 +478,7 @@ def main():
|
|
|
gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d, epoch,
|
|
|
train_state, args, gen_config, train_setup,
|
|
|
gen_ema=gen_ema, mpd_ema=mpd_ema, msd_ema=msd_ema)
|
|
|
+ logger.flush()
|
|
|
|
|
|
sched_g.step()
|
|
|
sched_d.step()
|
|
|
@@ -488,10 +490,10 @@ def main():
|
|
|
|
|
|
# finished training
|
|
|
if epochs_done > 0:
|
|
|
- logger.log((), metrics, scope='train_benchmark')
|
|
|
+ logger.log((), metrics, scope='train_benchmark', flush_log=True)
|
|
|
if epoch % args.validation_interval != 0: # val metrics are not up-to-date
|
|
|
validate(args, gen, mel_spec, mpd, msd, val_loader, val_metrics)
|
|
|
- logger.log((), val_metrics, scope='val')
|
|
|
+ logger.log((), val_metrics, scope='val', flush_log=True)
|
|
|
else:
|
|
|
print_once(f'Finished without training after epoch {args.epochs}.')
|
|
|
|