hysts (HF staff) committed
Commit: c4cd192
Parent: 700188c
Files changed (6):
  1. .gitignore +1 -0
  2. .gitmodules +6 -0
  3. app.py +222 -0
  4. deep-head-pose +1 -0
  5. face_detection +1 -0
  6. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ images
.gitmodules ADDED
@@ -0,0 +1,6 @@
+ [submodule "face_detection"]
+ 	path = face_detection
+ 	url = https://github.com/ibug-group/face_detection
+ [submodule "deep-head-pose"]
+ 	path = deep-head-pose
+ 	url = https://github.com/natanielruiz/deep-head-pose
app.py ADDED
@@ -0,0 +1,222 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import argparse
+ import functools
+ import os
+ import pathlib
+ import sys
+ import tarfile
+ from typing import Callable
+
+ import cv2
+ import gradio as gr
+ import huggingface_hub
+ import numpy as np
+ import PIL.Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+ import torchvision.transforms as T
+ from scipy.spatial.transform import Rotation
+
+ sys.path.insert(0, 'face_detection')
+ sys.path.insert(0, 'deep-head-pose/code')
+
+ from hopenet import Hopenet
+ from ibug.face_detection import RetinaFacePredictor
+
+ TITLE = 'natanielruiz/deep-head-pose'
+ DESCRIPTION = 'This is a demo for https://github.com/natanielruiz/deep-head-pose.'
+ ARTICLE = None
+
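+ # HF access token, used below to download the model weights and the sample images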
+ TOKEN = os.environ['TOKEN']
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--live', action='store_true')
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     parser.add_argument('--allow-flagging', type=str, default='never')
+     parser.add_argument('--allow-screenshot', action='store_true')
+     return parser.parse_args()
+
+
+ def load_sample_images() -> list[pathlib.Path]:
+     image_dir = pathlib.Path('images')
+     if not image_dir.exists():
+         image_dir.mkdir()
+         dataset_repo = 'hysts/input-images'
+         filenames = ['001.tar']
+         for name in filenames:
+             path = huggingface_hub.hf_hub_download(dataset_repo,
+                                                    name,
+                                                    repo_type='dataset',
+                                                    use_auth_token=TOKEN)
+             with tarfile.open(path) as f:
+                 f.extractall(image_dir.as_posix())
+     return sorted(image_dir.rglob('*.jpg'))
+
+
+ def load_model(model_name: str, device: torch.device) -> nn.Module:
+     path = huggingface_hub.hf_hub_download('hysts/Hopenet',
+                                            f'models/{model_name}.pkl',
+                                            use_auth_token=TOKEN)
+     state_dict = torch.load(path, map_location='cpu')
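+     # Hopenet: ResNet-50 backbone ([3, 4, 6, 3] bottleneck blocks) with three
+     # 66-bin classification heads, one each for yaw, pitch and roll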
+     model = Hopenet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 66)
+     model.load_state_dict(state_dict)
+     model.to(device)
+     model.eval()
+     return model
+
+
+ def create_transform() -> Callable:
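+     # standard ImageNet preprocessing: resize and center-crop to 224,
+     # then normalize with the ImageNet mean and standard deviation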
+     transform = T.Compose([
+         T.Resize(224),
+         T.CenterCrop(224),
+         T.ToTensor(),
+         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+     return transform
+
+
+ def crop_face(image: np.ndarray, box: tuple[int, int, int, int]) -> np.ndarray:
+     x0, y0, x1, y1 = box
+     w = x1 - x0
+     h = y1 - y0
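+     # expand the detection box by half the width on each side, 3/4 of the
+     # height upward and 1/4 downward, then clip it to the image bounds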
+     x0 -= 2 * w // 4
+     x1 += 2 * w // 4
+     y0 -= 3 * h // 4
+     y1 += h // 4
+     x0 = max(x0, 0)
+     y0 = max(y0, 0)
+     x1 = min(x1, image.shape[1])
+     y1 = min(y1, image.shape[0])
+     image = image[y0:y1, x0:x1]
+     return image
+
+
+ @torch.inference_mode()
+ def predict(image: np.ndarray, transform: Callable, model: nn.Module,
+             device: torch.device) -> np.ndarray:
+     indices = torch.arange(66).float().to(device)
+
+     image = PIL.Image.fromarray(image)
+     data = transform(image)
+     data = data.to(device)
+
+     # the output of the model is a tuple of 3 tensors (yaw, pitch, roll)
+     # the shape of each tensor is (1, 66)
+     out = model(data[None, ...])
+     out = torch.stack(out, dim=1)  # shape: (1, 3, 66)
+     out = F.softmax(out, dim=2)
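+     # decode each angle as the softmax expectation over the 66 bins
+     # (3-degree bins spanning roughly [-99, 99] degrees)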
+     out = (out * indices).sum(dim=2) * 3 - 99
+     out = out.cpu().numpy()[0]
+     return out
+
+
+ def draw_axis(image: np.ndarray, pose: np.ndarray, origin: np.ndarray,
+               length: int) -> None:
+     # (yaw, pitch, roll) -> (roll, yaw, pitch)
+     pose = pose[[2, 0, 1]]
+     pose *= np.array([1, -1, 1])
+     rot = Rotation.from_euler('zyx', pose, degrees=True)
+
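+     # each row is the image-plane (x, y) projection of one rotated coordinate axis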
+     vectors = rot.as_matrix().T[:, :2]  # shape: (3, 2)
+     pts = np.round(vectors * length + origin).astype(int)
+
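+     # BGR colors: x axis in red, y axis in green, z axis in blue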
+     cv2.line(image, tuple(origin), tuple(pts[0]), (0, 0, 255), 3)
+     cv2.line(image, tuple(origin), tuple(pts[1]), (0, 255, 0), 3)
+     cv2.line(image, tuple(origin), tuple(pts[2]), (255, 0, 0), 2)
+
+
+ def run(image: np.ndarray, model_name: str, face_detector: RetinaFacePredictor,
+         models: dict[str, nn.Module], transform: Callable,
+         device: torch.device) -> np.ndarray:
+     model = models[model_name]
+
+     # RGB -> BGR
+     det_faces = face_detector(image[:, :, ::-1], rgb=False)
+
+     res = image[:, :, ::-1].copy()
+     for det_face in det_faces:
+         box = np.round(det_face[:4]).astype(int)
+
+         # RGB
+         face_image = crop_face(image, box.tolist())
+
+         # (yaw, pitch, roll)
+         angles = predict(face_image, transform, model, device)
+
+         center = (box[:2] + box[2:]) // 2
+         length = (box[3] - box[1]) // 2
+         draw_axis(res, angles, center, length)
+
+     return res[:, :, ::-1]
+
+
+ def main():
+     gr.close_all()
+
+     args = parse_args()
+     device = torch.device(args.device)
+
+     face_detector = RetinaFacePredictor(
+         threshold=0.8,
+         device=device,
+         model=RetinaFacePredictor.get_model('mobilenet0.25'))
+
+     model_names = [
+         'hopenet_alpha1',
+         'hopenet_alpha2',
+         'hopenet_robust_alpha1',
+     ]
+     models = {name: load_model(name, device) for name in model_names}
+
+     transform = create_transform()
+
+     func = functools.partial(run,
+                              face_detector=face_detector,
+                              models=models,
+                              transform=transform,
+                              device=device)
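+     # copy run's metadata (e.g. __name__) onto the partial, which lacks it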
+     func = functools.update_wrapper(func, run)
+
+     image_paths = load_sample_images()
+     examples = [[path.as_posix(), model_names[0]] for path in image_paths]
+
+     gr.Interface(
+         func,
+         [
+             gr.inputs.Image(type='numpy', label='Input'),
+             gr.inputs.Radio(model_names,
+                             type='value',
+                             default=model_names[0],
+                             label='Model'),
+         ],
+         gr.outputs.Image(type='numpy', label='Output'),
+         examples=examples,
+         title=TITLE,
+         description=DESCRIPTION,
+         article=ARTICLE,
+         theme=args.theme,
+         allow_screenshot=args.allow_screenshot,
+         allow_flagging=args.allow_flagging,
+         live=args.live,
+     ).launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
deep-head-pose ADDED
@@ -0,0 +1 @@
+ Subproject commit f7bbb9981c2953c2eca67748d6492a64c8243946
face_detection ADDED
@@ -0,0 +1 @@
+ Subproject commit bc1e392b11d731fa20b1397c8ff3faed5e7fc76e
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ numpy==1.22.3
+ opencv-python-headless==4.5.5.64
+ Pillow==9.1.0
+ scipy==1.8.0
+ torch==1.11.0
+ torchvision==0.12.0