Răsfoiți Sursa

[HiFi-GAN/PyT] Import amp_C (apex) only when necessary

Adrian Lancucki 3 ani în urmă
părinte
comite
91c1de23f7

+ 26 - 0
PyTorch/SpeechSynthesis/HiFiGAN/common/ema_utils.py

@@ -0,0 +1,26 @@
+import amp_C
+import torch
+
+
+def apply_ema_decay(model, ema_model, decay):
+    if not decay:
+        return
+    st = model.state_dict()
+    add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
+    for k, v in ema_model.state_dict().items():
+        if add_module and not k.startswith('module.'):
+            k = 'module.' + k
+        v.copy_(decay * v + (1 - decay) * st[k])
+
+
+def init_multi_tensor_ema(model, ema_model):
+    model_weights = list(model.state_dict().values())
+    ema_model_weights = list(ema_model.state_dict().values())
+    ema_overflow_buf = torch.cuda.IntTensor([0])
+    return model_weights, ema_model_weights, ema_overflow_buf
+
+
+def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
+    amp_C.multi_tensor_axpby(
+        65536, overflow_buf, [ema_weights, model_weights, ema_weights],
+        decay, 1-decay, -1)

+ 0 - 25
PyTorch/SpeechSynthesis/HiFiGAN/common/utils.py

@@ -49,7 +49,6 @@ from typing import Optional
 
 import soundfile  # flac
 
-import amp_C
 import matplotlib
 
 matplotlib.use("Agg")
@@ -97,30 +96,6 @@ def adjust_fine_tuning_lr(args, ckpt_d):
             param_group['lr'] = new_v
 
 
-def apply_ema_decay(model, ema_model, decay):
-    if not decay:
-        return
-    st = model.state_dict()
-    add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
-    for k, v in ema_model.state_dict().items():
-        if add_module and not k.startswith('module.'):
-            k = 'module.' + k
-        v.copy_(decay * v + (1 - decay) * st[k])
-
-
-def init_multi_tensor_ema(model, ema_model):
-    model_weights = list(model.state_dict().values())
-    ema_model_weights = list(ema_model.state_dict().values())
-    ema_overflow_buf = torch.cuda.IntTensor([0])
-    return model_weights, ema_model_weights, ema_overflow_buf
-
-
-def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
-    amp_C.multi_tensor_axpby(
-        65536, overflow_buf, [ema_weights, model_weights, ema_weights],
-        decay, 1-decay, -1)
-
-
 def init_distributed(args, world_size, rank):
     assert torch.cuda.is_available(), "Distributed mode requires CUDA."
     print(f"{args.local_rank}: Initializing distributed training")

+ 10 - 6
PyTorch/SpeechSynthesis/HiFiGAN/train.py

@@ -275,6 +275,10 @@ def main():
 
     # setup EMA
     if args.ema_decay > 0:
+        # burried import, requires apex
+        from common.ema_utils import (apply_multi_tensor_ema,
+                                      init_multi_tensor_ema)
+
         gen_ema = models.get_model('HiFi-GAN', gen_config, 'cuda').cuda()
         mpd_ema = MultiPeriodDiscriminator(
             periods=args.mpd_periods,
@@ -316,9 +320,9 @@ def main():
                                          val_kwargs=dict(split=False),
                                          batch_size=1)
     if args.ema_decay > 0.0:
-        gen_ema_params = utils.init_multi_tensor_ema(gen, gen_ema)
-        mpd_ema_params = utils.init_multi_tensor_ema(mpd, mpd_ema)
-        msd_ema_params = utils.init_multi_tensor_ema(msd, msd_ema)
+        gen_ema_params = init_multi_tensor_ema(gen, gen_ema)
+        mpd_ema_params = init_multi_tensor_ema(mpd, mpd_ema)
+        msd_ema_params = init_multi_tensor_ema(msd, msd_ema)
 
     epochs_done = 0
 
@@ -428,9 +432,9 @@ def main():
             metrics.accumulate()
 
             if args.ema_decay > 0.0:
-                utils.apply_multi_tensor_ema(args.ema_decay, *gen_ema_params)
-                utils.apply_multi_tensor_ema(args.ema_decay, *mpd_ema_params)
-                utils.apply_multi_tensor_ema(args.ema_decay, *msd_ema_params)
+                apply_multi_tensor_ema(args.ema_decay, *gen_ema_params)
+                apply_multi_tensor_ema(args.ema_decay, *mpd_ema_params)
+                apply_multi_tensor_ema(args.ema_decay, *msd_ema_params)
 
             metrics.finish_iter()  # done accumulating
             if iters_all % args.step_logs_interval == 0: