Spaces:
Sleeping
Sleeping
antoniospoletojr
commited on
Commit
•
e2508c0
1
Parent(s):
a12e461
first commit
Browse files
app.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from facenet_pytorch import MTCNN
|
6 |
+
from model import HPEnet
|
7 |
+
from torchvision import transforms
|
8 |
+
from scipy.spatial.transform import Rotation as R
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
from utils import draw_2D_axes
|
12 |
+
|
13 |
+
|
14 |
+
def detect_faces(image):
|
15 |
+
# Detect face
|
16 |
+
boxes, _ = mtcnn.detect(image)
|
17 |
+
boxes_centroids = []
|
18 |
+
sizes = []
|
19 |
+
faces = []
|
20 |
+
|
21 |
+
# If no boxes have been detected return
|
22 |
+
if boxes is None:
|
23 |
+
return None, None, None
|
24 |
+
|
25 |
+
# Add margin to each box, calculate centroids and crop the face image
|
26 |
+
for i in range(len(boxes)):
|
27 |
+
# Add margin while safe checking
|
28 |
+
margin=50
|
29 |
+
boxes[i][0] = max(0, boxes[i][0] - margin)
|
30 |
+
boxes[i][1] = max(0, boxes[i][1] - margin)
|
31 |
+
boxes[i][2] = min(image.width, boxes[i][2] + margin)
|
32 |
+
boxes[i][3] = min(image.height, boxes[i][3] + margin)
|
33 |
+
|
34 |
+
# Calculate centroids and sizes
|
35 |
+
boxes_centroids.append([int((boxes[i][0] + boxes[i][2])/2), int((boxes[i][1] + boxes[i][3]) /2)])
|
36 |
+
sizes.append(boxes[i][2] - boxes[i][0])
|
37 |
+
# Crop the face using boxes
|
38 |
+
faces.append(image.crop(boxes[i]))
|
39 |
+
|
40 |
+
return faces, boxes_centroids, sizes
|
41 |
+
|
42 |
+
def process(frame):
|
43 |
+
# Convert from opencv to PIL
|
44 |
+
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
45 |
+
image = Image.fromarray(frame)
|
46 |
+
|
47 |
+
# Detect face
|
48 |
+
faces, centroids, sizes = detect_faces(image)
|
49 |
+
|
50 |
+
if faces is None:
|
51 |
+
return frame
|
52 |
+
|
53 |
+
for idx, face in enumerate(faces):
|
54 |
+
|
55 |
+
# Preprocess the image
|
56 |
+
transform = transforms.Compose([
|
57 |
+
transforms.PILToTensor(),
|
58 |
+
transforms.Resize((200, 200)),
|
59 |
+
])
|
60 |
+
face_tensor = transform(face)
|
61 |
+
face_tensor = face_tensor.permute(1, 2, 0)
|
62 |
+
|
63 |
+
# Standardize the tensor
|
64 |
+
face_tensor = (face_tensor - mean) / std
|
65 |
+
face_tensor = face_tensor.permute(2, 0, 1)
|
66 |
+
face_tensor = face_tensor.type(torch.float32)
|
67 |
+
|
68 |
+
# Run the inference
|
69 |
+
with torch.inference_mode():
|
70 |
+
face_tensor = face_tensor.unsqueeze(0).to(device)
|
71 |
+
r1, r2, r3, _ = model(face_tensor)
|
72 |
+
|
73 |
+
# Create a numpy matrix out of r1, r2, r3 (these vectors are the columns of the rotation matrix)
|
74 |
+
r1 = r1.squeeze().numpy()
|
75 |
+
r2 = r2.squeeze().numpy()
|
76 |
+
r3 = r3.squeeze().numpy()
|
77 |
+
|
78 |
+
rotation_matrix = np.array([r1, r2, r3])
|
79 |
+
|
80 |
+
r = R.from_matrix(rotation_matrix)
|
81 |
+
|
82 |
+
pitch, yaw, roll = r.as_euler('zyx', degrees=True)
|
83 |
+
|
84 |
+
center = centroids[idx]
|
85 |
+
size = sizes[idx]*0.5
|
86 |
+
|
87 |
+
frame = draw_2D_axes(frame, yaw, roll, pitch, center[0], center[1], size)
|
88 |
+
|
89 |
+
return frame
|
90 |
+
|
91 |
+
|
92 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
93 |
+
model = HPEnet().to(device)
|
94 |
+
# Load model from checkpoint
|
95 |
+
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
|
96 |
+
model.to(device);
|
97 |
+
model.eval()
|
98 |
+
mtcnn = MTCNN(keep_all=True, post_process=False, device='cpu')
|
99 |
+
mean = torch.load('mean.pt')
|
100 |
+
std = torch.load('std.pt')
|
101 |
+
|
102 |
+
demo = gr.Interface(
|
103 |
+
process,
|
104 |
+
gr.Image(sources="webcam", streaming=True),
|
105 |
+
"image",
|
106 |
+
live=True,
|
107 |
+
allow_flagging="never",
|
108 |
+
)
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
demo.launch()
|
mean.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:173d644e647a24e9d317b78ea47774a3bc60257952fc5f8f3a5951af56b80b80
|
3 |
+
size 961101
|
model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:997e75103549eab362cec3939c35c9ebe1cba720a18d166424d48425e41af20d
|
3 |
+
size 143750654
|
model.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchvision
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
class HPEnet(nn.Module):
|
8 |
+
def __init__(self, roll_bins=18, yaw_bins=29, pitch_bins=21):
|
9 |
+
super(HPEnet, self).__init__()
|
10 |
+
print("Loading the model...")
|
11 |
+
|
12 |
+
self.resnet = torchvision.models.resnet50(weights="ResNet50_Weights.DEFAULT") #ResNet50_Weights.DEFAULT
|
13 |
+
self.resnet.fc = nn.Linear(2048, 2048)
|
14 |
+
self.fc = nn.Linear(2048, 2048)
|
15 |
+
|
16 |
+
# Classification layers
|
17 |
+
self.fc_class = nn.Linear(2048, 1921)
|
18 |
+
|
19 |
+
# Regression layers
|
20 |
+
self.fc_r1 = nn.Linear(2048, 3)
|
21 |
+
self.fc_r2 = nn.Linear(2048, 3)
|
22 |
+
self.fc_r3 = nn.Linear(2048, 3)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
# Backbone
|
26 |
+
x = self.resnet(x)
|
27 |
+
|
28 |
+
# Dense layer
|
29 |
+
x = torch.nn.functional.relu(x)
|
30 |
+
x = self.fc(x)
|
31 |
+
|
32 |
+
# Regression layers
|
33 |
+
r1 = self.fc_r1(x)
|
34 |
+
r2 = self.fc_r2(x)
|
35 |
+
r3 = self.fc_r3(x)
|
36 |
+
|
37 |
+
# Classification layers
|
38 |
+
x = torch.nn.functional.relu(x)
|
39 |
+
x = self.fc_class(x)
|
40 |
+
|
41 |
+
return r1, r2, r3, x
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
opencv-python
|
3 |
+
scipy
|
4 |
+
facenet_pytorch
|
5 |
+
pillow
|
std.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8e1e3b7a9db0ac613f8c5d4deb0b73151d08aa485459096cdb39aa2e294f5fea
|
3 |
+
size 961096
|
utils.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
from math import cos, sin
|
6 |
+
from scipy.spatial.transform import Rotation as R
|
7 |
+
|
8 |
+
|
9 |
+
def plot_3D_rotation(rotation_matrix):
|
10 |
+
fig = go.Figure()
|
11 |
+
|
12 |
+
# Original axis orientation
|
13 |
+
axes_points = np.array([
|
14 |
+
[1, 0, 0, 0],
|
15 |
+
[0, 1, 0, 0],
|
16 |
+
[0, 0, 1, 0]
|
17 |
+
], dtype=np.float64)
|
18 |
+
|
19 |
+
# Plot original axes
|
20 |
+
fig.add_trace(go.Scatter3d(
|
21 |
+
x=[0, axes_points[0, 0]],
|
22 |
+
y=[0, axes_points[1, 0]],
|
23 |
+
z=[0, axes_points[2, 0]],
|
24 |
+
mode='lines+text',
|
25 |
+
line=dict(color='blue', width=6),
|
26 |
+
name='Canonical X-axis',
|
27 |
+
text=['', 'X axis'],
|
28 |
+
textposition='middle center',
|
29 |
+
))
|
30 |
+
|
31 |
+
fig.add_trace(go.Scatter3d(
|
32 |
+
x=[0, axes_points[0, 1]],
|
33 |
+
y=[0, axes_points[1, 1]],
|
34 |
+
z=[0, axes_points[2, 1]],
|
35 |
+
mode='lines+text',
|
36 |
+
line=dict(color='blue', width=6),
|
37 |
+
name='Canonical Z-axis',
|
38 |
+
text=['', 'Z axis'],
|
39 |
+
textposition='middle center',
|
40 |
+
))
|
41 |
+
|
42 |
+
fig.add_trace(go.Scatter3d(
|
43 |
+
x=[0, axes_points[0, 2]],
|
44 |
+
y=[0, axes_points[1, 2]],
|
45 |
+
z=[0, axes_points[2, 2]],
|
46 |
+
mode='lines+text',
|
47 |
+
line=dict(color='blue', width=6),
|
48 |
+
name='Canonical Y-axis',
|
49 |
+
text=['', 'Y axis'],
|
50 |
+
textposition='middle center',
|
51 |
+
))
|
52 |
+
|
53 |
+
# Apply rotation
|
54 |
+
axes_points = rotation_matrix @ axes_points
|
55 |
+
|
56 |
+
# Plot rotated axes
|
57 |
+
fig.add_trace(go.Scatter3d(
|
58 |
+
x=[0, axes_points[0, 0]],
|
59 |
+
y=[0, axes_points[1, 0]],
|
60 |
+
z=[0, axes_points[2, 0]],
|
61 |
+
mode='lines+text',
|
62 |
+
line=dict(color='red', width=6),
|
63 |
+
name='Rotated X\'-axis',
|
64 |
+
text=['', 'Rotated X axis'],
|
65 |
+
textposition='middle center',
|
66 |
+
))
|
67 |
+
|
68 |
+
fig.add_trace(go.Scatter3d(
|
69 |
+
x=[0, axes_points[0, 1]],
|
70 |
+
y=[0, axes_points[1, 1]],
|
71 |
+
z=[0, axes_points[2, 1]],
|
72 |
+
mode='lines+text',
|
73 |
+
line=dict(color='red', width=6),
|
74 |
+
name='Rotated Z\'-axis',
|
75 |
+
text=['', 'Rotated Z axis'],
|
76 |
+
textposition='middle center',
|
77 |
+
))
|
78 |
+
|
79 |
+
fig.add_trace(go.Scatter3d(
|
80 |
+
x=[0, axes_points[0, 2]],
|
81 |
+
y=[0, axes_points[1, 2]],
|
82 |
+
z=[0, axes_points[2, 2]],
|
83 |
+
mode='lines+text',
|
84 |
+
line=dict(color='red', width=6),
|
85 |
+
name='Rotated Y\'-axis',
|
86 |
+
text=['', 'Rotated Y axis'],
|
87 |
+
textposition='middle center',
|
88 |
+
))
|
89 |
+
|
90 |
+
# Retrieve pitch, yaw, roll from rotation matrix
|
91 |
+
r = R.from_matrix(rotation_matrix)
|
92 |
+
pitch, yaw, roll = r.as_euler('xzy', degrees=True)
|
93 |
+
|
94 |
+
# Set layout
|
95 |
+
fig.update_layout(
|
96 |
+
scene=dict(
|
97 |
+
xaxis=dict(title='X-axis', range=[-1.2, 1.2]),
|
98 |
+
yaxis=dict(title='Z-axis', range=[-1.2, 1.2]),
|
99 |
+
zaxis=dict(title='Y-axis', range=[-1.2, 1.2]),
|
100 |
+
xaxis_tickvals=np.arange(-1.2, 1.2, 0.6),
|
101 |
+
yaxis_tickvals=np.arange(-1.2, 1.2, 0.5),
|
102 |
+
zaxis_tickvals=np.arange(-1.2, 1.2, 0.5),
|
103 |
+
aspectmode='cube',
|
104 |
+
aspectratio=dict(x=1, y=1, z=1),
|
105 |
+
),
|
106 |
+
margin=dict(l=0, r=0, t=0, b=30),
|
107 |
+
)
|
108 |
+
# add annotation
|
109 |
+
fig.add_annotation(dict(font=dict(color='black',size=15),
|
110 |
+
x=-30,
|
111 |
+
y=50,
|
112 |
+
showarrow=False,
|
113 |
+
text=f"Pitch: {int(pitch)} - Yaw: {int(yaw)} - Roll: {int(roll)}",
|
114 |
+
textangle=0,
|
115 |
+
xanchor='left',
|
116 |
+
xref="paper",
|
117 |
+
yref="paper"))
|
118 |
+
return fig
|
119 |
+
|
120 |
+
|
121 |
+
def draw_2D_axes(img, roll, pitch, yaw, tdx=None, tdy=None, size=150.):
|
122 |
+
# Input is a cv2 image
|
123 |
+
# pose_params: (pitch, yaw, roll, tdx, tdy)
|
124 |
+
# Where (tdx, tdy) is the translation of the face.
|
125 |
+
# For pose we have [pitch yaw roll tdx tdy tdz scale_factor]
|
126 |
+
|
127 |
+
p = pitch * np.pi / 180
|
128 |
+
y = (yaw * np.pi / 180)
|
129 |
+
r = -roll * np.pi / 180
|
130 |
+
if tdx != None and tdy != None:
|
131 |
+
face_x = tdx - 0.50 * size
|
132 |
+
face_y = tdy - 0.50 * size
|
133 |
+
|
134 |
+
else:
|
135 |
+
height, width = img.shape[:2]
|
136 |
+
face_x = width / 2 - 0.5 * size
|
137 |
+
face_y = height / 2 - 0.5 * size
|
138 |
+
|
139 |
+
x1 = size * (cos(y) * cos(r)) + face_x
|
140 |
+
y1 = size * (cos(p) * sin(r) + cos(r) * sin(p) * sin(y)) + face_y
|
141 |
+
x2 = size * (-cos(y) * sin(r)) + face_x
|
142 |
+
y2 = size * (cos(p) * cos(r) - sin(p) * sin(y) * sin(r)) + face_y
|
143 |
+
x3 = size * (sin(y)) + face_x
|
144 |
+
y3 = size * (-cos(y) * sin(p)) + face_y
|
145 |
+
|
146 |
+
# Draw base in red
|
147 |
+
cv2.line(img, (int(face_x), int(face_y)), (int(x1),int(y1)),(0,0,255),3)
|
148 |
+
cv2.line(img, (int(face_x), int(face_y)), (int(x2),int(y2)),(0,0,255),3)
|
149 |
+
cv2.line(img, (int(x2), int(y2)), (int(x2+x1-face_x),int(y2+y1-face_y)),(0,0,255),3)
|
150 |
+
cv2.line(img, (int(x1), int(y1)), (int(x1+x2-face_x),int(y1+y2-face_y)),(0,0,255),3)
|
151 |
+
# Draw pillars in blue
|
152 |
+
cv2.line(img, (int(face_x), int(face_y)), (int(x3),int(y3)),(255,0,0),2)
|
153 |
+
cv2.line(img, (int(x1), int(y1)), (int(x1+x3-face_x),int(y1+y3-face_y)),(255,0,0),2)
|
154 |
+
cv2.line(img, (int(x2), int(y2)), (int(x2+x3-face_x),int(y2+y3-face_y)),(255,0,0),2)
|
155 |
+
cv2.line(img, (int(x2+x1-face_x),int(y2+y1-face_y)), (int(x3+x1+x2-2*face_x),int(y3+y2+y1-2*face_y)),(255,0,0),2)
|
156 |
+
# Draw top in green
|
157 |
+
cv2.line(img, (int(x3+x1-face_x),int(y3+y1-face_y)), (int(x3+x1+x2-2*face_x),int(y3+y2+y1-2*face_y)),(0,255,0),2)
|
158 |
+
cv2.line(img, (int(x2+x3-face_x),int(y2+y3-face_y)), (int(x3+x1+x2-2*face_x),int(y3+y2+y1-2*face_y)),(0,255,0),2)
|
159 |
+
cv2.line(img, (int(x3), int(y3)), (int(x3+x1-face_x),int(y3+y1-face_y)),(0,255,0),2)
|
160 |
+
cv2.line(img, (int(x3), int(y3)), (int(x3+x2-face_x),int(y3+y2-face_y)),(0,255,0),2)
|
161 |
+
|
162 |
+
return img
|