ema_utils.py 879 B

1234567891011121314151617181920212223242526
  1. import amp_C
  2. import torch
  3. def apply_ema_decay(model, ema_model, decay):
  4. if not decay:
  5. return
  6. st = model.state_dict()
  7. add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
  8. for k, v in ema_model.state_dict().items():
  9. if add_module and not k.startswith('module.'):
  10. k = 'module.' + k
  11. v.copy_(decay * v + (1 - decay) * st[k])
  12. def init_multi_tensor_ema(model, ema_model):
  13. model_weights = list(model.state_dict().values())
  14. ema_model_weights = list(ema_model.state_dict().values())
  15. ema_overflow_buf = torch.cuda.IntTensor([0])
  16. return model_weights, ema_model_weights, ema_overflow_buf
  17. def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
  18. amp_C.multi_tensor_axpby(
  19. 65536, overflow_buf, [ema_weights, model_weights, ema_weights],
  20. decay, 1-decay, -1)