minchul committed on
Commit
9f88867
·
verified ·
1 Parent(s): c3dba1d

Upload model

Browse files
Files changed (3) hide show
  1. config.json +35 -0
  2. model.safetensors +3 -0
  3. wrapper.py +30 -0
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CVLFaceRecognitionModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "wrapper.ModelConfig",
7
+ "AutoModel": "wrapper.CVLFaceRecognitionModel"
8
+ },
9
+ "conf": {
10
+ "color_space": "RGB",
11
+ "freeze": false,
12
+ "input_size": [
13
+ 3,
14
+ 112,
15
+ 112
16
+ ],
17
+ "mask_ratio": 0.0,
18
+ "name": "base",
19
+ "output_dim": 512,
20
+ "rpe_config": {
21
+ "ctx_type": "rel_keypoint_splithead_unshared",
22
+ "method": "product",
23
+ "mode": "ctx",
24
+ "name": "KPRPE_shared",
25
+ "num_keypoints": 5,
26
+ "ratio": 1.9,
27
+ "rpe_on": "k",
28
+ "shared_head": true
29
+ },
30
+ "start_from": "",
31
+ "yaml_path": "models/vit_kprpe/configs/v1_base_kprpe_splithead_unshared.yaml"
32
+ },
33
+ "torch_dtype": "float32",
34
+ "transformers_version": "4.33.0"
35
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c6d37ea874c2f38ffc9a7f0e9247efc994c3fb5c12d044759ac294e19d127f7
3
+ size 460344344
wrapper.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from transformers import PretrainedConfig
3
+ from omegaconf import OmegaConf
4
+ from models import get_model
5
+ import yaml
6
+
7
class ModelConfig(PretrainedConfig):
    """Hugging Face config wrapper whose `conf` dict is loaded from a YAML file.

    NOTE(review): the YAML path is hard-coded, and `self.conf` is always
    re-read from disk after `super().__init__`, so it overwrites any `conf`
    supplied via kwargs (e.g. values deserialized from config.json) —
    confirm this is intentional.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Fix: use a context manager so the file handle is closed promptly;
        # the original `yaml.safe_load(open(...))` never closed the file.
        with open('pretrained_model/model.yaml') as f:
            self.conf = dict(yaml.safe_load(f))
15
+
16
+
17
class CVLFaceRecognitionModel(PreTrainedModel):
    """Thin `PreTrainedModel` wrapper around the CVLFace recognition backbone.

    Builds the backbone from `cfg.conf` via `get_model` and loads its
    weights from a hard-coded checkpoint path; `forward` delegates directly
    to the wrapped backbone.
    """

    config_class = ModelConfig

    def __init__(self, cfg):
        super().__init__(cfg)
        # Convert the plain config dict into an OmegaConf object expected
        # by the project-local model factory.
        backbone_conf = OmegaConf.create(cfg.conf)
        self.model = get_model(backbone_conf)
        self.model.load_state_dict_from_path('pretrained_model/model.pt')

    def forward(self, *args, **kwargs):
        # Pure pass-through: all positional/keyword arguments go to the
        # wrapped backbone unchanged.
        return self.model(*args, **kwargs)
28
+
29
+
30
+