|
|
@@ -275,6 +275,10 @@ def main():
|
|
|
|
|
|
# setup EMA
|
|
|
if args.ema_decay > 0:
|
|
|
+ # burried import, requires apex
|
|
|
+ from common.ema_utils import (apply_multi_tensor_ema,
|
|
|
+ init_multi_tensor_ema)
|
|
|
+
|
|
|
gen_ema = models.get_model('HiFi-GAN', gen_config, 'cuda').cuda()
|
|
|
mpd_ema = MultiPeriodDiscriminator(
|
|
|
periods=args.mpd_periods,
|
|
|
@@ -316,9 +320,9 @@ def main():
|
|
|
val_kwargs=dict(split=False),
|
|
|
batch_size=1)
|
|
|
if args.ema_decay > 0.0:
|
|
|
- gen_ema_params = utils.init_multi_tensor_ema(gen, gen_ema)
|
|
|
- mpd_ema_params = utils.init_multi_tensor_ema(mpd, mpd_ema)
|
|
|
- msd_ema_params = utils.init_multi_tensor_ema(msd, msd_ema)
|
|
|
+ gen_ema_params = init_multi_tensor_ema(gen, gen_ema)
|
|
|
+ mpd_ema_params = init_multi_tensor_ema(mpd, mpd_ema)
|
|
|
+ msd_ema_params = init_multi_tensor_ema(msd, msd_ema)
|
|
|
|
|
|
epochs_done = 0
|
|
|
|
|
|
@@ -428,9 +432,9 @@ def main():
|
|
|
metrics.accumulate()
|
|
|
|
|
|
if args.ema_decay > 0.0:
|
|
|
- utils.apply_multi_tensor_ema(args.ema_decay, *gen_ema_params)
|
|
|
- utils.apply_multi_tensor_ema(args.ema_decay, *mpd_ema_params)
|
|
|
- utils.apply_multi_tensor_ema(args.ema_decay, *msd_ema_params)
|
|
|
+ apply_multi_tensor_ema(args.ema_decay, *gen_ema_params)
|
|
|
+ apply_multi_tensor_ema(args.ema_decay, *mpd_ema_params)
|
|
|
+ apply_multi_tensor_ema(args.ema_decay, *msd_ema_params)
|
|
|
|
|
|
metrics.finish_iter() # done accumulating
|
|
|
if iters_all % args.step_logs_interval == 0:
|