# Provenance (captured from the hosting page): "Upload directory" by minchul,
# commit 99ff494 (verified), 2 kB.
# Architecture markers searched for in ``model_config.yaml_path``, in priority
# order (dicts preserve insertion order), mapped to the message printed once the
# model is built. Each marker ``x`` is matched as the path component ``/x/`` and
# is also the name of the sibling module ``.x`` whose ``load_model`` builds it.
_MODEL_REGISTRY = {
    'vit': 'Loaded ViT model',
    'vit_irpe': 'Loaded ViT model with iRPE',
    'vit_kprpe': 'Loaded ViT model with KPRPE',
    'iresnet': 'Loaded iResNet model',
    'iresnet_insightface': 'Loaded iResNet model (InsightFace)',
    'part_fvit': 'Loaded PartFVIT model',
    'swin': 'Loaded Swin model',
    'swin_kprpe': 'Loaded Swin model with KPRPE',
}


def get_model(model_config, task=''):
    """Build the model whose architecture directory appears in the config path.

    The architecture is selected by the first marker in ``_MODEL_REGISTRY``
    found as a ``/<marker>/`` component of ``model_config.yaml_path``. The
    matching sibling module is imported lazily, so only the chosen backbone's
    dependencies are loaded.

    Args:
        model_config: config object exposing ``yaml_path``, ``start_from`` and
            ``freeze``.
        task: unused; kept so existing callers that pass it keep working.

    Returns:
        The constructed model. If ``model_config.start_from`` is truthy, the
        model's ``load_state_dict_from_path`` is called with it. If
        ``model_config.freeze`` is truthy, every parameter gets
        ``requires_grad = False``.

    Raises:
        NotImplementedError: if no known architecture marker is in the path.
    """
    import importlib

    # Normalise Windows separators so '/<arch>/' also matches paths built with
    # os.path.join on Windows.
    yaml_path = model_config.yaml_path.replace('\\', '/')

    for arch, message in _MODEL_REGISTRY.items():
        if f'/{arch}/' in yaml_path:
            # Same as `from .<arch> import load_model`, resolved against this
            # module's package.
            module = importlib.import_module(f'.{arch}', package=__package__)
            model = module.load_model(model_config)
            print(message)
            break
    else:
        raise NotImplementedError(f"Model {model_config.yaml_path} not implemented")

    if model_config.start_from:
        model.load_state_dict_from_path(model_config.start_from)
    if model_config.freeze:
        # Freeze the whole model by excluding every parameter from gradient updates.
        for param in model.parameters():
            param.requires_grad = False
    return model