antoniospoletojr committed on
Commit
e2508c0
1 Parent(s): a12e461

first commit

Files changed (7)
  1. app.py +110 -0
  2. mean.pt +3 -0
  3. model.pt +3 -0
  4. model.py +41 -0
  5. requirements.txt +9 -0
  6. std.pt +3 -0
  7. utils.py +159 -0
app.py ADDED
@@ -0,0 +1,110 @@
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import torch
+ from facenet_pytorch import MTCNN
+ from model import HPEnet
+ from torchvision import transforms
+ from scipy.spatial.transform import Rotation as R
+ from PIL import Image
+ from utils import draw_2D_axes
+
+
+ def detect_faces(image):
+     # Detect faces
+     boxes, _ = mtcnn.detect(image)
+     boxes_centroids = []
+     sizes = []
+     faces = []
+
+     # If no boxes have been detected, return
+     if boxes is None:
+         return None, None, None
+
+     # Add a margin to each box, compute the centroids, and crop the face image
+     for i in range(len(boxes)):
+         # Add the margin while clamping to the image bounds
+         margin = 50
+         boxes[i][0] = max(0, boxes[i][0] - margin)
+         boxes[i][1] = max(0, boxes[i][1] - margin)
+         boxes[i][2] = min(image.width, boxes[i][2] + margin)
+         boxes[i][3] = min(image.height, boxes[i][3] + margin)
+
+         # Compute centroid and size
+         boxes_centroids.append([int((boxes[i][0] + boxes[i][2]) / 2), int((boxes[i][1] + boxes[i][3]) / 2)])
+         sizes.append(boxes[i][2] - boxes[i][0])
+         # Crop the face using the box
+         faces.append(image.crop(boxes[i]))
+
+     return faces, boxes_centroids, sizes
+
+ def process(frame):
+     # Convert from OpenCV BGR to a PIL RGB image
+     image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+     image = Image.fromarray(image)
+
+     # Detect faces
+     faces, centroids, sizes = detect_faces(image)
+
+     if faces is None:
+         return frame
+
+     for idx, face in enumerate(faces):
+
+         # Preprocess the image
+         transform = transforms.Compose([
+             transforms.PILToTensor(),
+             transforms.Resize((200, 200)),
+         ])
+         face_tensor = transform(face)
+         face_tensor = face_tensor.permute(1, 2, 0)
+
+         # Standardize the tensor
+         face_tensor = (face_tensor - mean) / std
+         face_tensor = face_tensor.permute(2, 0, 1)
+         face_tensor = face_tensor.type(torch.float32)
+
+         # Run the inference
+         with torch.inference_mode():
+             face_tensor = face_tensor.unsqueeze(0).to(device)
+             r1, r2, r3, _ = model(face_tensor)
+
+         # Build a numpy matrix out of r1, r2, r3 (these vectors are the columns of the rotation matrix)
+         r1 = r1.squeeze().cpu().numpy()
+         r2 = r2.squeeze().cpu().numpy()
+         r3 = r3.squeeze().cpu().numpy()
+
+         rotation_matrix = np.array([r1, r2, r3])
+
+         r = R.from_matrix(rotation_matrix)
+
+         pitch, yaw, roll = r.as_euler('zyx', degrees=True)
+
+         center = centroids[idx]
+         size = sizes[idx] * 0.5
+
+         frame = draw_2D_axes(frame, yaw, roll, pitch, center[0], center[1], size)
+
+     return frame
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = HPEnet().to(device)
+ # Load the model weights from the checkpoint
+ model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
+ model.to(device)
+ model.eval()
+ mtcnn = MTCNN(keep_all=True, post_process=False, device='cpu')
+ mean = torch.load('mean.pt', map_location='cpu')
+ std = torch.load('std.pt', map_location='cpu')
+
+ demo = gr.Interface(
+     process,
+     gr.Image(sources="webcam", streaming=True),
+     "image",
+     live=True,
+     allow_flagging="never",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
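For reference, the rotation-to-Euler step in process() can be exercised in isolation. A minimal sketch, using the identity rotation as a stand-in for the model's three predicted vectors:

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    # Stand-ins for the model's predicted vectors r1, r2, r3
    r1, r2, r3 = np.eye(3)

    rotation_matrix = np.array([r1, r2, r3])
    pitch, yaw, roll = R.from_matrix(rotation_matrix).as_euler('zyx', degrees=True)
    print(pitch, yaw, roll)  # 0.0 0.0 0.0 for the identity rotation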
mean.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:173d644e647a24e9d317b78ea47774a3bc60257952fc5f8f3a5951af56b80b80
+ size 961101
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:997e75103549eab362cec3939c35c9ebe1cba720a18d166424d48425e41af20d
+ size 143750654
model.py ADDED
@@ -0,0 +1,41 @@
+ import torchvision
+ import torch.nn as nn
+ import torch
+ import math
+
+
+ class HPEnet(nn.Module):
+     def __init__(self, roll_bins=18, yaw_bins=29, pitch_bins=21):
+         super(HPEnet, self).__init__()
+         print("Loading the model...")
+
+         self.resnet = torchvision.models.resnet50(weights="ResNet50_Weights.DEFAULT")
+         self.resnet.fc = nn.Linear(2048, 2048)
+         self.fc = nn.Linear(2048, 2048)
+
+         # Classification layer
+         self.fc_class = nn.Linear(2048, 1921)
+
+         # Regression layers
+         self.fc_r1 = nn.Linear(2048, 3)
+         self.fc_r2 = nn.Linear(2048, 3)
+         self.fc_r3 = nn.Linear(2048, 3)
+
+     def forward(self, x):
+         # Backbone
+         x = self.resnet(x)
+
+         # Dense layer
+         x = torch.nn.functional.relu(x)
+         x = self.fc(x)
+
+         # Regression layers: the three columns of the rotation matrix
+         r1 = self.fc_r1(x)
+         r2 = self.fc_r2(x)
+         r3 = self.fc_r3(x)
+
+         # Classification layer
+         x = torch.nn.functional.relu(x)
+         x = self.fc_class(x)
+
+         return r1, r2, r3, x
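A quick shape check of the network's outputs; a minimal sketch (the 200x200 input matches the resize applied in app.py, and instantiating HPEnet downloads the pretrained ResNet-50 weights):

    import torch
    from model import HPEnet

    model = HPEnet()
    model.eval()

    # One dummy 200x200 RGB crop, shaped like the preprocessed faces in app.py
    x = torch.randn(1, 3, 200, 200)
    with torch.inference_mode():
        r1, r2, r3, logits = model(x)

    print(r1.shape, r2.shape, r3.shape, logits.shape)
    # torch.Size([1, 3]) for each rotation-matrix column, torch.Size([1, 1921]) for the class logits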
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchvision
+ gradio
+ opencv-python
+ scipy
+ facenet_pytorch
+ pillow
+ plotly
+ numpy
std.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e1e3b7a9db0ac613f8c5d4deb0b73151d08aa485459096cdb39aa2e294f5fea
+ size 961096
utils.py ADDED
@@ -0,0 +1,159 @@
+ import numpy as np
+ import cv2
+ import plotly.graph_objects as go
+ from math import cos, sin
+ from scipy.spatial.transform import Rotation as R
+
+
+ def plot_3D_rotation(rotation_matrix):
+     fig = go.Figure()
+
+     # Original axis orientation
+     axes_points = np.array([
+         [1, 0, 0, 0],
+         [0, 1, 0, 0],
+         [0, 0, 1, 0]
+     ], dtype=np.float64)
+
+     # Plot original axes
+     fig.add_trace(go.Scatter3d(
+         x=[0, axes_points[0, 0]],
+         y=[0, axes_points[1, 0]],
+         z=[0, axes_points[2, 0]],
+         mode='lines+text',
+         line=dict(color='blue', width=6),
+         name='Canonical X-axis',
+         text=['', 'X axis'],
+         textposition='middle center',
+     ))
+
+     fig.add_trace(go.Scatter3d(
+         x=[0, axes_points[0, 1]],
+         y=[0, axes_points[1, 1]],
+         z=[0, axes_points[2, 1]],
+         mode='lines+text',
+         line=dict(color='blue', width=6),
+         name='Canonical Z-axis',
+         text=['', 'Z axis'],
+         textposition='middle center',
+     ))
+
+     fig.add_trace(go.Scatter3d(
+         x=[0, axes_points[0, 2]],
+         y=[0, axes_points[1, 2]],
+         z=[0, axes_points[2, 2]],
+         mode='lines+text',
+         line=dict(color='blue', width=6),
+         name='Canonical Y-axis',
+         text=['', 'Y axis'],
+         textposition='middle center',
+     ))
+
+     # Apply rotation
+     axes_points = rotation_matrix @ axes_points
+
+     # Plot rotated axes
+     fig.add_trace(go.Scatter3d(
+         x=[0, axes_points[0, 0]],
+         y=[0, axes_points[1, 0]],
+         z=[0, axes_points[2, 0]],
+         mode='lines+text',
+         line=dict(color='red', width=6),
+         name="Rotated X'-axis",
+         text=['', 'Rotated X axis'],
+         textposition='middle center',
+     ))
+
+     fig.add_trace(go.Scatter3d(
+         x=[0, axes_points[0, 1]],
+         y=[0, axes_points[1, 1]],
+         z=[0, axes_points[2, 1]],
+         mode='lines+text',
+         line=dict(color='red', width=6),
+         name="Rotated Z'-axis",
+         text=['', 'Rotated Z axis'],
+         textposition='middle center',
+     ))
+
+     fig.add_trace(go.Scatter3d(
+         x=[0, axes_points[0, 2]],
+         y=[0, axes_points[1, 2]],
+         z=[0, axes_points[2, 2]],
+         mode='lines+text',
+         line=dict(color='red', width=6),
+         name="Rotated Y'-axis",
+         text=['', 'Rotated Y axis'],
+         textposition='middle center',
+     ))
+
+     # Retrieve pitch, yaw, roll from the rotation matrix
+     r = R.from_matrix(rotation_matrix)
+     pitch, yaw, roll = r.as_euler('xzy', degrees=True)
+
+     # Set layout
+     fig.update_layout(
+         scene=dict(
+             xaxis=dict(title='X-axis', range=[-1.2, 1.2]),
+             yaxis=dict(title='Z-axis', range=[-1.2, 1.2]),
+             zaxis=dict(title='Y-axis', range=[-1.2, 1.2]),
+             xaxis_tickvals=np.arange(-1.2, 1.2, 0.6),
+             yaxis_tickvals=np.arange(-1.2, 1.2, 0.5),
+             zaxis_tickvals=np.arange(-1.2, 1.2, 0.5),
+             aspectmode='cube',
+             aspectratio=dict(x=1, y=1, z=1),
+         ),
+         margin=dict(l=0, r=0, t=0, b=30),
+     )
+     # Add the Euler-angle annotation
+     fig.add_annotation(dict(font=dict(color='black', size=15),
+                             x=-30,
+                             y=50,
+                             showarrow=False,
+                             text=f"Pitch: {int(pitch)} - Yaw: {int(yaw)} - Roll: {int(roll)}",
+                             textangle=0,
+                             xanchor='left',
+                             xref="paper",
+                             yref="paper"))
+     return fig
+
+
+ def draw_2D_axes(img, roll, pitch, yaw, tdx=None, tdy=None, size=150.):
+     # Input is a cv2 image (numpy array).
+     # roll, pitch, yaw are in degrees; (tdx, tdy) is the position of the
+     # face centre in the image, and size scales the drawn cube.
+
+     p = pitch * np.pi / 180
+     y = yaw * np.pi / 180
+     r = -roll * np.pi / 180
+     if tdx is not None and tdy is not None:
+         face_x = tdx - 0.50 * size
+         face_y = tdy - 0.50 * size
+     else:
+         height, width = img.shape[:2]
+         face_x = width / 2 - 0.5 * size
+         face_y = height / 2 - 0.5 * size
+
+     x1 = size * (cos(y) * cos(r)) + face_x
+     y1 = size * (cos(p) * sin(r) + cos(r) * sin(p) * sin(y)) + face_y
+     x2 = size * (-cos(y) * sin(r)) + face_x
+     y2 = size * (cos(p) * cos(r) - sin(p) * sin(y) * sin(r)) + face_y
+     x3 = size * (sin(y)) + face_x
+     y3 = size * (-cos(y) * sin(p)) + face_y
+
+     # Draw base in red
+     cv2.line(img, (int(face_x), int(face_y)), (int(x1), int(y1)), (0, 0, 255), 3)
+     cv2.line(img, (int(face_x), int(face_y)), (int(x2), int(y2)), (0, 0, 255), 3)
+     cv2.line(img, (int(x2), int(y2)), (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), (0, 0, 255), 3)
+     cv2.line(img, (int(x1), int(y1)), (int(x1 + x2 - face_x), int(y1 + y2 - face_y)), (0, 0, 255), 3)
+     # Draw pillars in blue
+     cv2.line(img, (int(face_x), int(face_y)), (int(x3), int(y3)), (255, 0, 0), 2)
+     cv2.line(img, (int(x1), int(y1)), (int(x1 + x3 - face_x), int(y1 + y3 - face_y)), (255, 0, 0), 2)
+     cv2.line(img, (int(x2), int(y2)), (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), (255, 0, 0), 2)
+     cv2.line(img, (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (255, 0, 0), 2)
+     # Draw top in green
+     cv2.line(img, (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2)
+     cv2.line(img, (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2)
+     cv2.line(img, (int(x3), int(y3)), (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), (0, 255, 0), 2)
+     cv2.line(img, (int(x3), int(y3)), (int(x3 + x2 - face_x), int(y3 + y2 - face_y)), (0, 255, 0), 2)
+
+     return img
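draw_2D_axes can be tried without a face detector; a minimal sketch that renders the pose cube on a blank canvas (the angle values are arbitrary):

    import numpy as np
    import cv2
    from utils import draw_2D_axes

    # Blank 640x480 BGR canvas with the cube centred at (320, 240)
    canvas = np.zeros((480, 640, 3), dtype=np.uint8)
    out = draw_2D_axes(canvas, roll=15, pitch=-10, yaw=30, tdx=320, tdy=240, size=120.)
    cv2.imwrite('axes_demo.png', out)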