Commit caa3c6a
Committed by hysts (HF staff)
Parent: a5c1a92
Files changed (1):
  1. app.py +31 -43
app.py CHANGED
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import functools
 import sys
 from typing import Callable
 
@@ -66,25 +65,6 @@ def crop_face(image: np.ndarray, box: tuple[int, int, int, int]) -> np.ndarray:
     return image
 
 
-@spaces.GPU
-@torch.inference_mode()
-def predict(image: np.ndarray, transform: Callable, model: nn.Module, device: torch.device) -> np.ndarray:
-    indices = torch.arange(66).float().to(device)
-
-    image = PIL.Image.fromarray(image)
-    data = transform(image)
-    data = data.to(device)
-
-    # the output of the model is a tuple of 3 tensors (yaw, pitch, roll)
-    # the shape of each tensor is (1, 66)
-    out = model(data[None, ...])
-    out = torch.stack(out, dim=1)  # shape: (1, 3, 66)
-    out = F.softmax(out, dim=2)
-    out = (out * indices).sum(dim=2) * 3 - 99
-    out = out.cpu().numpy()[0]
-    return out
-
-
 def draw_axis(image: np.ndarray, pose: np.ndarray, origin: np.ndarray, length: int) -> None:
     # (yaw, pitch, roll) -> (roll, yaw, pitch)
     pose = pose[[2, 0, 1]]
@@ -99,19 +79,33 @@ def draw_axis(image: np.ndarray, pose: np.ndarray, origin: np.ndarray, length: int) -> None:
     cv2.line(image, tuple(origin), tuple(pts[2]), (255, 0, 0), 2)
 
 
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+face_detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
+face_detector.device = device
+face_detector.net.to(device)
+
+model_names = [
+    "hopenet_alpha1",
+    "hopenet_alpha2",
+    "hopenet_robust_alpha1",
+]
+models = {name: load_model(name, device) for name in model_names}
+transform = create_transform()
+
+
+@spaces.GPU
+@torch.inference_mode()
 def run(
     image: np.ndarray,
     model_name: str,
-    face_detector: RetinaFacePredictor,
-    models: dict[str, nn.Module],
-    transform: Callable,
-    device: torch.device,
 ) -> np.ndarray:
     model = models[model_name]
 
     # RGB -> BGR
     det_faces = face_detector(image[:, :, ::-1], rgb=False)
 
+    indices = torch.arange(66).float().to(device)
+
     res = image[:, :, ::-1].copy()
     for det_face in det_faces:
         box = np.round(det_face[:4]).astype(int)
@@ -119,8 +113,17 @@ def run(
         # RGB
         face_image = crop_face(image, box.tolist())
 
-        # (yaw, pitch, roll)
-        angles = predict(face_image, transform, model, device)
+        face_image = PIL.Image.fromarray(face_image)
+        data = transform(face_image)
+        data = data.to(device)
+
+        # the output of the model is a tuple of 3 tensors (yaw, pitch, roll)
+        # the shape of each tensor is (1, 66)
+        out = model(data[None, ...])
+        out = torch.stack(out, dim=1)  # shape: (1, 3, 66)
+        out = F.softmax(out, dim=2)
+        out = (out * indices).sum(dim=2) * 3 - 99
+        angles = out.cpu().numpy()[0]
 
         center = (box[:2] + box[2:]) // 2
         length = (box[3] - box[1]) // 2
@@ -129,21 +132,6 @@ def run(
     return res[:, :, ::-1]
 
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-face_detector = RetinaFacePredictor(threshold=0.8, device="cpu", model=RetinaFacePredictor.get_model("mobilenet0.25"))
-face_detector.device = device
-face_detector.net.to(device)
-
-model_names = [
-    "hopenet_alpha1",
-    "hopenet_alpha2",
-    "hopenet_robust_alpha1",
-]
-models = {name: load_model(name, device) for name in model_names}
-transform = create_transform()
-
-fn = functools.partial(run, face_detector=face_detector, models=models, transform=transform, device=device)
-
 examples = [["images/pexels-ksenia-chernaya-8535230.jpg", "hopenet_alpha1"]]
 
 with gr.Blocks(css="style.css") as demo:
@@ -159,10 +147,10 @@ with gr.Blocks(css="style.css") as demo:
         examples=examples,
         inputs=[image, model_name],
         outputs=result,
-        fn=fn,
+        fn=run,
     )
     run_button.click(
-        fn=fn,
+        fn=run,
         inputs=[image, model_name],
         outputs=result,
         api_name="run",
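Note: the block inlined into run above is the Hopenet-style expectation decoding. Each of the three heads (yaw, pitch, roll) emits logits over 66 bins of 3 degrees, and the softmax-weighted mean bin index is mapped back to degrees (bin i maps to 3 * i - 99). A minimal, self-contained sketch of just that step, with random tensors standing in for the model output; the names here are illustrative, not from app.py:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    yaw, pitch, roll = torch.randn(3, 1, 66)  # dummy (1, 66) logits per angle

    indices = torch.arange(66).float()
    out = torch.stack((yaw, pitch, roll), dim=1)  # (1, 3, 66)
    probs = F.softmax(out, dim=2)                 # per-angle bin probabilities
    # expected bin index, rescaled to degrees: 3 * i - 99
    angles = (probs * indices).sum(dim=2) * 3 - 99  # (1, 3), in degrees
    print(angles)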
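Note: the restructuring (hoisting device, detector, models, and transform to module level; deleting predict and functools.partial; moving @spaces.GPU and @torch.inference_mode() onto run) means Gradio now receives the @spaces.GPU-decorated callable directly, and one effect is that a single GPU allocation covers a whole request rather than one predict call per detected face. A hedged, minimal sketch of that pattern, with a trivial stand-in for the real model setup:

    import gradio as gr
    import spaces

    model = str.upper  # placeholder for state created once at import time

    @spaces.GPU  # on ZeroGPU Spaces, a GPU is attached for each call of run
    def run(text: str) -> str:
        return model(text)

    # pass the decorated function itself to Gradio, no partial wrapper
    demo = gr.Interface(fn=run, inputs="text", outputs="text")

    if __name__ == "__main__":
        demo.launch()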