"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models

Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""

import os
from typing import Any, Dict, Optional, Union

import timm

# register new models
from mivolo.model.mivolo_model import *  # noqa: F403, F401
from timm.layers import set_layer_config
from timm.models._factory import parse_model_name
from timm.models._helpers import load_state_dict, remap_checkpoint
from timm.models._hub import load_model_config_from_hf
from timm.models._pretrained import PretrainedCfg, split_model_name_tag
from timm.models._registry import is_model, model_entrypoint


def _drop_filtered_keys(state_dict, filter_keys):
    """Delete, in place, every key of ``state_dict`` containing any substring from ``filter_keys``."""
    for sd_key in list(state_dict.keys()):
        if any(filter_key in sd_key for filter_key in filter_keys):
            del state_dict[sd_key]


def _rename_state_dict_keys(state_dict, state_dict_map):
    """Rename keys of ``state_dict`` in place by substring replacement.

    ``state_dict_map`` maps ``new_substring -> old_substring``, e.g.
    ``{'patch_embed.conv1.': 'patch_embed.conv.'}`` renames every key containing
    ``'patch_embed.conv.'`` by replacing that substring with ``'patch_embed.conv1.'``.
    A key matching several entries yields one renamed copy per matching entry, and the
    original key is removed.

    All renamed entries are computed from the unmodified dict before anything is deleted
    or inserted. Identity mappings and mappings whose targets coincide with other source
    keys (e.g. swaps) therefore no longer drop or corrupt weights.
    """
    renamed = {}
    consumed = set()
    for state_k, value in state_dict.items():
        for target_k, target_v in state_dict_map.items():
            if target_v in state_k:
                renamed[state_k.replace(target_v, target_k)] = value
                consumed.add(state_k)
    for state_k in consumed:
        del state_dict[state_k]
    # Renamed entries win over any untouched key with the same name.
    state_dict.update(renamed)


def load_checkpoint(
    model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
):
    """Load weights from ``checkpoint_path`` into ``model``.

    Args:
        model: module to load weights into. For ``.npz``/``.npy`` files it must provide a
            ``load_pretrained(path)`` method.
        checkpoint_path: path of the checkpoint file.
        use_ema: forwarded to timm's ``load_state_dict``.
        strict: forwarded to ``model.load_state_dict``. It is forced to ``False`` whenever
            ``filter_keys`` is not ``None``, because filtered keys are expected to be missing.
        remap: if True, pass the state dict through timm's ``remap_checkpoint`` first.
        filter_keys: substrings; any state-dict key containing one of them is dropped.
        state_dict_map: ``{new_substring: old_substring}`` key-renaming rules, applied
            after filtering.

    Returns:
        The result of ``model.load_state_dict``, or ``None`` for numpy checkpoints.

    Raises:
        NotImplementedError: numpy checkpoint given and the model has no ``load_pretrained``.
    """
    if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
        # numpy checkpoint: delegate to the model's own loader
        if hasattr(model, "load_pretrained"):
            model.load_pretrained(checkpoint_path)
        else:
            raise NotImplementedError("Model cannot load numpy checkpoint")
        return

    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)
    if filter_keys:
        _drop_filtered_keys(state_dict, filter_keys)
    if state_dict_map is not None:
        _rename_state_dict_keys(state_dict, state_dict_map)

    incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
    return incompatible_keys


def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: str = "",
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    filter_keys=None,
    state_dict_map=None,
    **kwargs,
):
    """Create a model.

    Look up the model's registered entrypoint function and call it to build a new model.
    Optionally load weights from ``checkpoint_path`` afterwards.

    Args:
        model_name: registered model name, optionally with a pretrained tag, or
            ``hf-hub:path/architecture_name@revision`` to source the config from the
            Hugging Face Hub.
        pretrained: forwarded to the model entrypoint.
        pretrained_cfg: explicit pretrained config. When set, it takes priority over a
            tag in ``model_name``. It must not be set for ``hf-hub`` names.
        pretrained_cfg_overlay: forwarded to the model entrypoint.
        checkpoint_path: if non-empty, weights are loaded via ``load_checkpoint``.
        scriptable, exportable, no_jit: layer-config flags applied while the model is built.
        filter_keys, state_dict_map: forwarded to ``load_checkpoint``.
        **kwargs: extra entrypoint arguments; entries whose value is ``None`` are dropped.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: if ``model_name`` is not a registered model.
    """
    # Parameters that aren't supported by all models or are intended to only override model defaults if set
    # should default to None in command line args/cfg. Remove them if they are present and not set so that
    # non-supporting models don't break and default args remain in effect.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    model_source, model_name = parse_model_name(model_name)
    if model_source == "hf-hub":
        assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
        # load model weights + pretrained_cfg from Hugging Face hub.
        pretrained_cfg, model_name = load_model_config_from_hf(model_name)
    else:
        model_name, pretrained_tag = split_model_name_tag(model_name)
        if not pretrained_cfg:
            # a valid pretrained_cfg argument takes priority over tag in model name
            pretrained_cfg = pretrained_tag

    if not is_model(model_name):
        raise RuntimeError("Unknown model (%s)" % model_name)

    create_fn = model_entrypoint(model_name)
    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = create_fn(
            pretrained=pretrained,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **kwargs,
        )

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)

    return model