|
|
|
|
|
def get_model(model_config, task=''):
    """Build the backbone selected by ``model_config.yaml_path``.

    The architecture is chosen by which ``/<name>/`` directory component
    appears in ``model_config.yaml_path``. The matching submodule is imported
    lazily, so only the selected backbone's module is loaded.

    Args:
        model_config: Config object exposing ``yaml_path`` (str or path-like),
            ``start_from`` (falsy, or a checkpoint path passed to
            ``model.load_state_dict_from_path``) and ``freeze`` (truthy to
            disable gradients on every parameter).
        task: Unused here; kept so existing call sites keep working.

    Returns:
        The model returned by the selected submodule's ``load_model``,
        optionally initialised from ``start_from`` and/or frozen.

    Raises:
        NotImplementedError: If ``yaml_path`` names no known model directory.
    """
    # Accept path-like objects and Windows separators so the '/<name>/'
    # substring checks below behave the same on every platform.
    yaml_path = str(model_config.yaml_path).replace('\\', '/')

    # Each check is delimited by slashes on both sides, so '/vit/' never
    # matches '/vit_irpe/' (and likewise for iresnet / swin); order is safe.
    if '/vit/' in yaml_path:
        from .vit import load_model as load_vit_model
        model = load_vit_model(model_config)
        print('Loaded ViT model')
    elif '/vit_irpe/' in yaml_path:
        from .vit_irpe import load_model as load_vit_irpe_model
        model = load_vit_irpe_model(model_config)
        print('Loaded ViT model with iRPE')
    elif '/vit_kprpe/' in yaml_path:
        from .vit_kprpe import load_model as load_vit_kprpe_model
        model = load_vit_kprpe_model(model_config)
        print('Loaded ViT model with KPRPE')
    elif '/iresnet/' in yaml_path:
        from .iresnet import load_model as load_iresnet_model
        model = load_iresnet_model(model_config)
        print('Loaded iResNet model')
    elif '/iresnet_insightface/' in yaml_path:
        from .iresnet_insightface import load_model as load_iresnet_insightface_model
        model = load_iresnet_insightface_model(model_config)
        # Distinct message so logs show which iResNet variant was loaded.
        print('Loaded iResNet (InsightFace) model')
    elif '/part_fvit/' in yaml_path:
        from .part_fvit import load_model as load_part_fvit_model
        model = load_part_fvit_model(model_config)
        print('Loaded PartFVIT model')
    elif '/swin/' in yaml_path:
        from .swin import load_model as load_swin_model
        model = load_swin_model(model_config)
        print('Loaded Swin model')
    elif '/swin_kprpe/' in yaml_path:
        from .swin_kprpe import load_model as load_swin_kprpe_model
        model = load_swin_kprpe_model(model_config)
        print('Loaded Swin model with KPRPE')
    else:
        raise NotImplementedError(f"Model {model_config.yaml_path} not implemented")

    # Optional warm start from a checkpoint; the model type must provide
    # load_state_dict_from_path.
    if model_config.start_from:
        model.load_state_dict_from_path(model_config.start_from)

    # Freezing disables gradient tracking on every parameter (presumably a
    # torch nn.Module; confirm against the submodule loaders).
    if model_config.freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model
|
|