models/vit_kprpe/__init__.py
ADDED
@@ -0,0 +1,65 @@
+from ..base import BaseModel
+from .vit import VisionTransformerWithKPRPE
+from torchvision import transforms
+
+
+class ViTKPRPEModel(BaseModel):
+
+
+    """
+    Vision Transformer for face recognition model with KeyPoint Relative Position Encoding (KP-RPE).
+
+    ```
+    @article{kim2024keypoint,
+      title={KeyPoint Relative Position Encoding for Face Recognition},
+      author={Kim, Minchul and Su, Yiyang and Liu, Feng and Jain, Anil and Liu, Xiaoming},
+      journal={CVPR},
+      year={2024}
+    }
+    ```
+    """
+    def __init__(self, net, config):
+        super(ViTKPRPEModel, self).__init__(config)
+        self.net = net
+
+
+    @classmethod
+    def from_config(cls, config):
+
+        if config.name == 'small':
+            net = VisionTransformerWithKPRPE(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=12,
+                                             mlp_ratio=5, num_heads=8, drop_path_rate=0.1, norm_layer="ln",
+                                             mask_ratio=config.mask_ratio, rpe_config=config.rpe_config)
+        elif config.name == 'base':
+            net = VisionTransformerWithKPRPE(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=24,
+                                             mlp_ratio=3, num_heads=16, drop_path_rate=0.1, norm_layer="ln",
+                                             mask_ratio=config.mask_ratio, rpe_config=config.rpe_config)
+        else:
+            raise NotImplementedError
+
+        model = cls(net, config)
+        model.eval()
+        return model
+
+    def forward(self, x, *args, **kwargs):
+        if self.input_color_flip:
+            x = x.flip(1)
+        return self.net(x, *args, **kwargs)
+
+    def make_train_transform(self):
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ])
+        return transform
+
+    def make_test_transform(self):
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ])
+        return transform
+
+def load_model(model_config):
+    model = ViTKPRPEModel.from_config(model_config)
+    return model
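
For reference, a minimal usage sketch of `load_model` follows. The config fields shown (`name`, `output_dim`, `mask_ratio`, `rpe_config`) are exactly the ones read in `from_config` above; their values here are placeholders, and the real `rpe_config` schema, as well as any extra fields consumed by `BaseModel.__init__` (e.g. whatever sets `input_color_flip`), are defined elsewhere in the repository.

```python
# Hypothetical usage sketch: the config values below are illustrative
# placeholders, not the repository's actual defaults.
from types import SimpleNamespace

import torch
from PIL import Image

from models.vit_kprpe import load_model

config = SimpleNamespace(
    name='small',      # selects the depth-12 variant in from_config
    output_dim=512,    # passed to the backbone as num_classes
    mask_ratio=0.0,    # no token masking at inference
    rpe_config=None,   # placeholder; the real KP-RPE config object goes here
)

model = load_model(config)               # from_config returns the model in eval mode
transform = model.make_test_transform()  # ToTensor + Normalize(0.5, 0.5) maps pixels to [-1, 1]

img = Image.open('aligned_face.jpg').convert('RGB')  # expected to be a 112x112 aligned crop
with torch.no_grad():
    embedding = model(transform(img).unsqueeze(0))   # shape: (1, output_dim)
```

Since `forward` passes `*args`/`**kwargs` straight through to the backbone, a caller using KP-RPE with detected facial keypoints would supply them through that same call.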