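"""Gradio webcam demo for real-time head pose estimation.

For every streamed frame: detect faces with MTCNN, run each crop through
HPEnet to get the three column vectors of a rotation matrix, convert the
matrix to Euler angles, and draw 2D pose axes on the frame.
"""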
import gradio as gr
import numpy as np
import cv2
import torch
from facenet_pytorch import MTCNN
from model import HPEnet
from torchvision import transforms
from scipy.spatial.transform import Rotation as R
from PIL import Image
from utils import draw_2D_axes


def detect_faces(image):
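    """Run MTCNN on a PIL image and return face crops, box centroids and
    box widths, or (None, None, None) if no face is detected."""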
    # Detect face
    boxes, _ = mtcnn.detect(image)
    boxes_centroids = []
    sizes = []
    faces = []

    # If no boxes were detected, return early
    if boxes is None:
        return None, None, None

    # Add margin to each box, calculate centroids and crop the face image
    margin = 50  # pixels of context to add around each detected box
    for i in range(len(boxes)):
        # Expand the box by the margin, clamping to the image bounds
        boxes[i][0] = max(0, boxes[i][0] - margin)
        boxes[i][1] = max(0, boxes[i][1] - margin)
        boxes[i][2] = min(image.width, boxes[i][2] + margin)
        boxes[i][3] = min(image.height, boxes[i][3] + margin)

        # Calculate centroids and sizes
        boxes_centroids.append([int((boxes[i][0] + boxes[i][2]) / 2),
                                int((boxes[i][1] + boxes[i][3]) / 2)])
        sizes.append(boxes[i][2] - boxes[i][0])
        # Crop the face using boxes
        faces.append(image.crop(boxes[i]))

    return faces, boxes_centroids, sizes

def process(frame):
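    """Estimate the head pose of every face in a frame and draw the pose
    axes; takes a numpy frame from the webcam stream and returns the
    annotated frame (unchanged if no face is found)."""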
    # Convert from OpenCV BGR to an RGB PIL image
    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)

    # Detect face
    faces, centroids, sizes = detect_faces(image)

    if faces is None:
        return frame

    # Preprocessing: to tensor, then resize to the 200x200 input HPEnet expects
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.Resize((200, 200)),
    ])

    for idx, face in enumerate(faces):
        face_tensor = transform(face)
        # Channels-last so the saved mean/std broadcast during standardization
        face_tensor = face_tensor.permute(1, 2, 0)

        # Standardize the tensor
        face_tensor = (face_tensor - mean) / std
        face_tensor = face_tensor.permute(2, 0, 1)
        face_tensor = face_tensor.type(torch.float32)

        # Run the inference
        with torch.inference_mode():
            face_tensor = face_tensor.unsqueeze(0).to(device)
            r1, r2, r3, _ = model(face_tensor)

            # r1, r2 and r3 are the columns of the predicted rotation matrix
            r1 = r1.squeeze().cpu().numpy()
            r2 = r2.squeeze().cpu().numpy()
            r3 = r3.squeeze().cpu().numpy()

            # Stack them as columns (np.array([r1, r2, r3]) would stack rows)
            rotation_matrix = np.column_stack((r1, r2, r3))

            r = R.from_matrix(rotation_matrix)
            pitch, yaw, roll = r.as_euler('zyx', degrees=True)

            # Draw the pose axes at the face centroid, scaled to the box width
            center = centroids[idx]
            size = sizes[idx] * 0.5
            frame = draw_2D_axes(frame, yaw, roll, pitch, center[0], center[1], size)

    return frame


# One-time setup: device, model weights, face detector, normalization stats
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the head pose model from its checkpoint
model = HPEnet()
model.load_state_dict(torch.load('model.pt', map_location=device))
model = model.to(device)
model.eval()

# MTCNN face detector; keep_all=True returns every face in the frame
mtcnn = MTCNN(keep_all=True, post_process=False, device='cpu')

# Mean/std tensors used to standardize the face crops
mean = torch.load('mean.pt')
std = torch.load('std.pt')

demo = gr.Interface(
    process, 
    gr.Image(sources="webcam", streaming=True), 
    "image",
    live=True,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()