minchul commited on
Commit
46a3b5e
1 Parent(s): f4a7259

Upload directory

Browse files
Files changed (1) hide show
  1. models/__init__.py +46 -0
models/__init__.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ def get_model(model_config, task=''):
4
+
5
+ if '/vit/' in model_config.yaml_path:
6
+ from .vit import load_model as load_vit_model
7
+ model = load_vit_model(model_config)
8
+ print('Loaded ViT model')
9
+ elif '/vit_irpe/' in model_config.yaml_path:
10
+ from .vit_irpe import load_model as load_vit_irpe_model
11
+ model = load_vit_irpe_model(model_config)
12
+ print('Loaded ViT model with iRPE')
13
+ elif '/vit_kprpe/' in model_config.yaml_path:
14
+ from .vit_kprpe import load_model as load_vit_kprpe_model
15
+ model = load_vit_kprpe_model(model_config)
16
+ print('Loaded ViT model with KPRPE')
17
+ elif '/iresnet/' in model_config.yaml_path:
18
+ from .iresnet import load_model as load_iresnet_model
19
+ model = load_iresnet_model(model_config)
20
+ print('Loaded iResNet model')
21
+ elif '/iresnet_insightface/' in model_config.yaml_path:
22
+ from .iresnet_insightface import load_model as load_iresnet_insightface_model
23
+ model = load_iresnet_insightface_model(model_config)
24
+ print('Loaded iResNet model')
25
+ elif '/part_fvit/' in model_config.yaml_path:
26
+ from .part_fvit import load_model as load_part_fvit_model
27
+ model = load_part_fvit_model(model_config)
28
+ print('Loaded PartFVIT model')
29
+ elif '/swin/' in model_config.yaml_path:
30
+ from .swin import load_model as load_swin_model
31
+ model = load_swin_model(model_config)
32
+ print('Loaded Swin model')
33
+ elif '/swin_kprpe/' in model_config.yaml_path:
34
+ from .swin_kprpe import load_model as load_swin_kprpe_model
35
+ model = load_swin_kprpe_model(model_config)
36
+ print('Loaded Swin model with KPRPE')
37
+ else:
38
+ raise NotImplementedError(f"Model {model_config.yaml_path} not implemented")
39
+ if model_config.start_from:
40
+ model.load_state_dict_from_path(model_config.start_from)
41
+
42
+ if model_config.freeze:
43
+ for param in model.parameters():
44
+ param.requires_grad = False
45
+
46
+ return model