# model.py
  1. from dataclasses import dataclass, asdict, replace
  2. from typing import Optional, Callable
  3. import os
  4. import torch
  5. import argparse
@dataclass
class ModelArch:
    """Base dataclass for model architecture descriptions.

    Empty here; an instance is carried in ``Model.arch`` and passed through
    to the model constructor unchanged. Concrete architectures presumably
    subclass this with their own fields — confirm against callers.
    """

    pass
  9. @dataclass
  10. class ModelParams:
  11. def parser(self, name):
  12. return argparse.ArgumentParser(
  13. description=f"{name} arguments", add_help=False, usage=""
  14. )
@dataclass
class OptimizerParams:
    """Base dataclass for optimizer settings.

    Empty here; stored on ``Model.optimizer_params`` and never read in this
    file — presumably consumed by training code elsewhere.
    """

    pass
@dataclass
class Model:
    """Everything needed to build one model and (optionally) load its weights."""

    # Factory invoked as ``constructor(arch=..., **asdict(params))``
    # (see EntryPoint.__call__ / create_entrypoint).
    constructor: Callable
    # Architecture description forwarded verbatim to ``constructor``.
    arch: ModelArch
    # Default hyperparameters; None means the model exposes no CLI parameters
    # (EntryPoint.parser returns None in that case).
    params: Optional[ModelParams]
    # Stored only; not read anywhere in this file.
    optimizer_params: Optional[OptimizerParams] = None
    # Remote URL for pretrained weights; required when calling with
    # ``pretrained=True``.
    checkpoint_url: Optional[str] = None
  25. class EntryPoint:
  26. def __init__(self, name: str, model: Model):
  27. self.name = name
  28. self.model = model
  29. def __call__(self, pretrained=False, pretrained_from_file=None, **kwargs):
  30. assert not (pretrained and (pretrained_from_file is not None))
  31. params = replace(self.model.params, **kwargs)
  32. model = self.model.constructor(arch=self.model.arch, **asdict(params))
  33. state_dict = None
  34. if pretrained:
  35. assert self.model.checkpoint_url is not None
  36. state_dict = torch.hub.load_state_dict_from_url(
  37. self.model.checkpoint_url, map_location=torch.device("cpu")
  38. )
  39. if pretrained_from_file is not None:
  40. if os.path.isfile(pretrained_from_file):
  41. print(
  42. "=> loading pretrained weights from '{}'".format(
  43. pretrained_from_file
  44. )
  45. )
  46. state_dict = torch.load(
  47. pretrained_from_file, map_location=torch.device("cpu")
  48. )
  49. else:
  50. print(
  51. "=> no pretrained weights found at '{}'".format(
  52. pretrained_from_file
  53. )
  54. )
  55. # Temporary fix to allow NGC checkpoint loading
  56. if state_dict is not None:
  57. state_dict = {
  58. k[len("module.") :] if k.startswith("module.") else k: v
  59. for k, v in state_dict.items()
  60. }
  61. def reshape(t, conv):
  62. if conv:
  63. if len(t.shape) == 4:
  64. return t
  65. else:
  66. return t.view(t.shape[0], -1, 1, 1)
  67. else:
  68. if len(t.shape) == 4:
  69. return t.view(t.shape[0], t.shape[1])
  70. else:
  71. return t
  72. state_dict = {
  73. k: reshape(
  74. v,
  75. conv=dict(model.named_modules())[
  76. ".".join(k.split(".")[:-2])
  77. ].use_conv,
  78. )
  79. if is_se_weight(k, v)
  80. else v
  81. for k, v in state_dict.items()
  82. }
  83. model.load_state_dict(state_dict)
  84. return model
  85. def parser(self):
  86. if self.model.params is None:
  87. return None
  88. parser = self.model.params.parser(self.name)
  89. parser.add_argument(
  90. "--pretrained-from-file",
  91. default=None,
  92. type=str,
  93. metavar="PATH",
  94. help="load weights from local file",
  95. )
  96. if self.model.checkpoint_url is not None:
  97. parser.add_argument(
  98. "--pretrained",
  99. default=False,
  100. action="store_true",
  101. help="load pretrained weights from NGC",
  102. )
  103. return parser
  104. def is_se_weight(key, value):
  105. return (key.endswith("squeeze.weight") or key.endswith("expand.weight"))
  106. def create_entrypoint(m: Model):
  107. def _ep(**kwargs):
  108. params = replace(m.params, **kwargs)
  109. return m.constructor(arch=m.arch, **asdict(params))
  110. return _ep