monkey_patch.py 613 B

1234567891011121314151617181920
  1. import re
  2. import warnings
  3. import torch
  4. major, minor, *_ = re.search('(\d+)\.(\d+)', torch.__version__).groups()
  5. if int(major) >= 1 and int(minor) >= 12:
  6. # Mutes 'UserWarning: positional arguments and argument "destination"
  7. # are deprecated. nn.Module.state_dict will not accept them in the future.'
  8. def state_dict(self, *args, **kwargs):
  9. warnings.filterwarnings("ignore")
  10. ret = self._state_dict(*args, **kwargs)
  11. warnings.filterwarnings("default")
  12. return ret
  13. torch.nn.Module._state_dict = torch.nn.Module.state_dict
  14. torch.nn.Module.state_dict = state_dict