Sfoglia il codice sorgente

[HiFi-GAN/PyT] Explicitly flush logs

Adrian Lancucki 3 anni fa
parent
commit
66da9fdc5c
1 ha cambiato i file con 7 aggiunte e 5 eliminazioni
  1. 7 5
      PyTorch/SpeechSynthesis/HiFi-GAN/train.py

+ 7 - 5
PyTorch/SpeechSynthesis/HiFi-GAN/train.py

@@ -434,8 +434,8 @@ def main():
 
             metrics.finish_iter()  # done accumulating
             if iters_all % args.step_logs_interval == 0:
-                logger.log((epoch, iter_, iters_num), metrics,
-                           scope='train', tb_iter=iters_all)
+                logger.log((epoch, iter_, iters_num), metrics, scope='train',
+                           tb_iter=iters_all, flush_log=True)
 
         assert is_last_accum_step
         metrics.finish_epoch()
@@ -443,7 +443,8 @@ def main():
 
         if epoch % args.validation_interval == 0:
             validate(args, gen, mel_spec, mpd, msd, val_loader, val_metrics)
-            logger.log((epoch,), val_metrics, scope='val', tb_iter=iters_all)
+            logger.log((epoch,), val_metrics, scope='val', tb_iter=iters_all,
+                       flush_log=True)
 
         # validation samples
         if epoch % args.samples_interval == 0 and args.local_rank == 0:
@@ -477,6 +478,7 @@ def main():
             gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d, epoch,
             train_state, args, gen_config, train_setup,
             gen_ema=gen_ema, mpd_ema=mpd_ema, msd_ema=msd_ema)
+        logger.flush()
 
         sched_g.step()
         sched_d.step()
@@ -488,10 +490,10 @@ def main():
 
     # finished training
     if epochs_done > 0:
-        logger.log((), metrics, scope='train_benchmark')
+        logger.log((), metrics, scope='train_benchmark', flush_log=True)
         if epoch % args.validation_interval != 0:  # val metrics are not up-to-date
             validate(args, gen, mel_spec, mpd, msd, val_loader, val_metrics)
-        logger.log((), val_metrics, scope='val')
+        logger.log((), val_metrics, scope='val', flush_log=True)
     else:
         print_once(f'Finished without training after epoch {args.epochs}.')