gender-age / mivolo /model /create_timm_model.py
MuGeminorum
upl base codes
29dba9b
raw
history blame
No virus
4.35 kB
"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""
import os
from typing import Any, Dict, Optional, Union
import timm
# register new models
from mivolo.model.mivolo_model import * # noqa: F403, F401
from timm.layers import set_layer_config
from timm.models._factory import parse_model_name
from timm.models._helpers import load_state_dict, remap_checkpoint
from timm.models._hub import load_model_config_from_hf
from timm.models._pretrained import PretrainedCfg, split_model_name_tag
from timm.models._registry import is_model, model_entrypoint
def load_checkpoint(
    model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
):
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    Args:
        model: Target model. It must provide ``load_state_dict``; for numpy
            checkpoints (``.npz`` / ``.npy``) it must provide ``load_pretrained``.
        checkpoint_path: Path of the checkpoint file. Its extension
            (case-insensitive) selects the numpy or the state-dict code path.
        use_ema: Forwarded to the ``load_state_dict`` helper imported from timm.
        strict: Forwarded to ``model.load_state_dict``. It is overridden with
            ``False`` whenever ``filter_keys`` is not ``None``, because removed
            entries would otherwise be reported as missing keys.
        remap: When true, the loaded state dict is passed through
            ``remap_checkpoint(model, state_dict)`` first.
        filter_keys: Optional collection of substrings. Every state-dict entry
            whose key contains at least one of them is dropped.
        state_dict_map: Optional mapping ``{new_substring: old_substring}``.
            Every key containing ``old_substring`` is re-inserted with that
            substring replaced by ``new_substring`` and the old key is removed,
            e.g. ``{"patch_embed.conv1.": "patch_embed.conv."}``.

    Returns:
        The value returned by ``model.load_state_dict`` for regular checkpoints,
        or ``None`` for numpy checkpoints.

    Raises:
        NotImplementedError: If a numpy checkpoint is given but ``model`` has no
            ``load_pretrained`` attribute.
    """
    if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if not hasattr(model, "load_pretrained"):
            raise NotImplementedError("Model cannot load numpy checkpoint")
        # Call the model's own loader: the module-level
        # timm.models._model_builder.load_pretrained expects a model as its first
        # argument, so handing it only the path never loaded these weights.
        model.load_pretrained(checkpoint_path)
        return

    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)

    if filter_keys:
        # Iterate over a snapshot of the keys because entries are deleted on the fly.
        for sd_key in list(state_dict.keys()):
            if any(filter_key in sd_key for filter_key in filter_keys):
                del state_dict[sd_key]

    if state_dict_map is not None:
        # 'patch_embed.conv1.' : 'patch_embed.conv.'
        replaced = []
        for state_k in list(state_dict.keys()):
            for target_k, target_v in state_dict_map.items():
                if target_v in state_k:
                    state_dict[state_k.replace(target_v, target_k)] = state_dict[state_k]
                    replaced.append(state_k)
        # Old keys are removed only after all renames were added; a key may be
        # listed more than once when several map entries matched it.
        for old_key in replaced:
            state_dict.pop(old_key, None)

    # Filtering removes entries on purpose, so strict loading is disabled in that case.
    incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
    return incompatible_keys
def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: str = "",
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    filter_keys=None,
    state_dict_map=None,
    **kwargs,
):
    """Build a registered model and optionally load a checkpoint into it.

    The model name is resolved to its registered entrypoint function, which is
    called with the pretrained settings and the remaining keyword arguments.

    Args:
        model_name: Name handed to ``parse_model_name``. When the parsed source
            is ``"hf-hub"`` the pretrained config and architecture name come from
            ``load_model_config_from_hf``; otherwise the name is split into
            architecture name and pretrained tag via ``split_model_name_tag``.
        pretrained: Forwarded to the entrypoint function.
        pretrained_cfg: Forwarded to the entrypoint function. Must be unset for
            ``hf-hub`` names; when unset otherwise, the tag from the model name is used.
        pretrained_cfg_overlay: Forwarded to the entrypoint function.
        checkpoint_path: When non-empty, ``load_checkpoint`` is called on the
            created model with this path.
        scriptable: Passed to ``set_layer_config`` around model creation.
        exportable: Passed to ``set_layer_config`` around model creation.
        no_jit: Passed to ``set_layer_config`` around model creation.
        filter_keys: Forwarded to ``load_checkpoint``.
        state_dict_map: Forwarded to ``load_checkpoint``.
        **kwargs: Extra entrypoint arguments; entries whose value is ``None`` are dropped.

    Returns:
        The model object returned by the entrypoint function.

    Raises:
        RuntimeError: If the resolved model name is not a registered model.
    """
    # Arguments left at None mean "keep the model's own default", so they are
    # not forwarded; this also keeps models that lack the argument from failing.
    entrypoint_kwargs = {}
    for arg_name, arg_value in kwargs.items():
        if arg_value is not None:
            entrypoint_kwargs[arg_name] = arg_value

    source, arch_name = parse_model_name(model_name)
    if source == "hf-hub":
        assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
        # Names of the form `hf-hub:path/architecture_name@revision` take both
        # the weights location and the pretrained_cfg from the Hugging Face hub.
        pretrained_cfg, arch_name = load_model_config_from_hf(arch_name)
    else:
        arch_name, name_tag = split_model_name_tag(arch_name)
        # An explicitly passed pretrained_cfg wins over the tag in the model name.
        if not pretrained_cfg:
            pretrained_cfg = name_tag

    if not is_model(arch_name):
        raise RuntimeError("Unknown model (%s)" % arch_name)

    build_fn = model_entrypoint(arch_name)
    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = build_fn(
            pretrained=pretrained,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **entrypoint_kwargs,
        )

    if not checkpoint_path:
        return model

    load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
    return model