| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- from dataclasses import dataclass, asdict, replace
- from typing import Optional, Callable
- import os
- import torch
- import argparse
@dataclass
class ModelArch:
    """Base type for a model's architecture description.

    Declares no fields itself; instances are passed to a model constructor
    as the ``arch`` keyword argument.
    """
@dataclass
class ModelParams:
    """Base type for a model's constructor hyperparameters.

    Declares no fields itself; the fields of an instance are expanded with
    ``asdict`` into keyword arguments for a model constructor.
    """

    def parser(self, name):
        """Create the (initially empty) argument parser for model ``name``.

        Help is disabled (``add_help=False``) and the usage string is blank;
        callers may add arguments to the returned parser.
        """
        description = f"{name} arguments"
        return argparse.ArgumentParser(
            add_help=False,
            usage="",
            description=description,
        )
@dataclass
class OptimizerParams:
    """Base type for per-model optimizer settings.

    Declares no fields itself.
    """
@dataclass
class Model:
    """Bundle of everything needed to build (and optionally restore) a model."""

    # Called as ``constructor(arch=arch, **asdict(params))``.
    constructor: Callable
    # Architecture description forwarded as the ``arch`` keyword.
    arch: ModelArch
    # Default constructor hyperparameters; may be None.
    params: Optional[ModelParams]
    # Optimizer settings; not read by the code visible here.
    optimizer_params: Optional[OptimizerParams] = None
    # Location of pretrained weights; must be set to load with pretrained=True.
    checkpoint_url: Optional[str] = None
class EntryPoint:
    """Named, callable model factory with optional pretrained-weight loading.

    Calling an instance builds ``self.model.constructor`` from the model's
    default ``params`` (with keyword overrides) and, on request, loads a state
    dict from ``self.model.checkpoint_url`` or from a local file.
    """

    def __init__(self, name: str, model: Model):
        """
        Args:
            name: Model name; used in the argument-parser description.
            model: Constructor, architecture, default params and optional
                checkpoint URL for this entry point.
        """
        self.name = name
        self.model = model

    def __call__(self, pretrained=False, pretrained_from_file=None, **kwargs):
        """Construct the model and optionally load pretrained weights.

        Args:
            pretrained: If True, download weights from
                ``self.model.checkpoint_url``.
            pretrained_from_file: Path to a local checkpoint. If the path is
                not an existing file, a message is printed and the model keeps
                its freshly constructed weights.
            **kwargs: Overrides for fields of ``self.model.params``.

        Returns:
            The constructed model, with weights loaded when any were found.

        Raises:
            AssertionError: If both ``pretrained`` and ``pretrained_from_file``
                are given, or ``pretrained`` is set but no checkpoint URL is
                configured.
        """
        # NOTE(review): validation uses ``assert`` and disappears under
        # ``python -O``; kept as-is so the raised exception type is unchanged.
        assert not (pretrained and (pretrained_from_file is not None)), (
            "pretrained and pretrained_from_file are mutually exclusive"
        )
        params = replace(self.model.params, **kwargs)
        model = self.model.constructor(arch=self.model.arch, **asdict(params))

        state_dict = self._load_state_dict(pretrained, pretrained_from_file)
        if state_dict is not None:
            model.load_state_dict(self._adapt_state_dict(model, state_dict))
        return model

    def _load_state_dict(self, pretrained, pretrained_from_file):
        """Fetch a raw state dict from the checkpoint URL or a local file.

        Returns None when no source was requested or the file does not exist.
        """
        state_dict = None
        if pretrained:
            assert self.model.checkpoint_url is not None, (
                f"{self.name}: pretrained=True requires a checkpoint_url"
            )
            # NOTE(review): torch.load (also used by load_state_dict_from_url)
            # unpickles its input unless weights_only=True; only load trusted
            # checkpoints.
            state_dict = torch.hub.load_state_dict_from_url(
                self.model.checkpoint_url, map_location=torch.device("cpu")
            )
        if pretrained_from_file is not None:
            if os.path.isfile(pretrained_from_file):
                print(f"=> loading pretrained weights from '{pretrained_from_file}'")
                state_dict = torch.load(
                    pretrained_from_file, map_location=torch.device("cpu")
                )
            else:
                # Best-effort: report and continue with untrained weights.
                print(f"=> no pretrained weights found at '{pretrained_from_file}'")
        return state_dict

    @staticmethod
    def _adapt_state_dict(model, state_dict):
        """Normalize checkpoint keys and SE-weight shapes to match ``model``.

        * Strips a leading ``"module."`` from every key (temporary fix to allow
          NGC checkpoint loading; presumably such checkpoints were saved from
          a wrapped model).
        * Reshapes squeeze/expand weights (per ``is_se_weight``) to a 4-D
          ``(out, in, 1, 1)`` layout or a 2-D ``(out, in)`` layout, according
          to the ``use_conv`` flag of the module two name levels above the
          weight (e.g. ``x.se`` for ``x.se.squeeze.weight``).
        """
        prefix = "module."
        state_dict = {
            k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()
        }

        # Build the name -> module map once instead of once per SE key.
        modules = dict(model.named_modules())

        def reshape(t, conv):
            if conv:
                return t if len(t.shape) == 4 else t.view(t.shape[0], -1, 1, 1)
            # Assumes a 4-D SE weight has 1x1 trailing dims so it can be
            # flattened to (out, in).
            return t.view(t.shape[0], t.shape[1]) if len(t.shape) == 4 else t

        return {
            k: reshape(v, conv=modules[".".join(k.split(".")[:-2])].use_conv)
            if is_se_weight(k, v)
            else v
            for k, v in state_dict.items()
        }

    def parser(self):
        """Return an argparse parser for this entry point, or None.

        Starts from ``self.model.params.parser(self.name)``, adds
        ``--pretrained-from-file`` and, when a checkpoint URL is configured,
        ``--pretrained``. Returns None when the model has no params.
        """
        if self.model.params is None:
            return None
        parser = self.model.params.parser(self.name)
        parser.add_argument(
            "--pretrained-from-file",
            default=None,
            type=str,
            metavar="PATH",
            help="load weights from local file",
        )
        if self.model.checkpoint_url is not None:
            parser.add_argument(
                "--pretrained",
                default=False,
                action="store_true",
                help="load pretrained weights from NGC",
            )
        return parser
def is_se_weight(key, value):
    """Tell whether state-dict entry ``key`` is a squeeze/expand weight.

    Only the key's suffix is inspected; ``value`` is unused.
    """
    se_suffixes = ("squeeze.weight", "expand.weight")
    return key.endswith(se_suffixes)
def create_entrypoint(m: Model):
    """Return a bare factory for ``m`` that never loads weights.

    The factory accepts keyword overrides for fields of ``m.params`` and
    calls ``m.constructor(arch=m.arch, **fields)``.
    """

    def _build_model(**overrides):
        merged = replace(m.params, **overrides)
        ctor_kwargs = asdict(merged)
        return m.constructor(arch=m.arch, **ctor_kwargs)

    return _build_model
|