import gradio as gr
import numpy as np
import cv2
import torch
from facenet_pytorch import MTCNN
from model import HPEnet
from torchvision import transforms
from scipy.spatial.transform import Rotation as R
from PIL import Image
from utils import draw_2D_axes


def detect_faces(image):
    # Detect face bounding boxes
    boxes, _ = mtcnn.detect(image)
    boxes_centroids = []
    sizes = []
    faces = []

    # If no boxes were detected, return early
    if boxes is None:
        return None, None, None

    # Add a margin to each box, compute its centroid, and crop the face image
    margin = 50
    for i in range(len(boxes)):
        # Expand the box while clamping to the image bounds
        boxes[i][0] = max(0, boxes[i][0] - margin)
        boxes[i][1] = max(0, boxes[i][1] - margin)
        boxes[i][2] = min(image.width, boxes[i][2] + margin)
        boxes[i][3] = min(image.height, boxes[i][3] + margin)

        # Centroid and width of the expanded box
        boxes_centroids.append([int((boxes[i][0] + boxes[i][2]) / 2),
                                int((boxes[i][1] + boxes[i][3]) / 2)])
        sizes.append(boxes[i][2] - boxes[i][0])

        # Crop the face from the full image
        faces.append(image.crop(boxes[i]))

    return faces, boxes_centroids, sizes


def process(frame):
    # Gradio streams RGB frames, so no BGR->RGB conversion is needed
    image = Image.fromarray(frame)

    # Detect faces
    faces, centroids, sizes = detect_faces(image)
    if faces is None:
        return frame

    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.Resize((200, 200)),
    ])

    for idx, face in enumerate(faces):
        # Preprocess the face crop
        face_tensor = transform(face)

        # Standardize with the training-set statistics
        # (broadcast over the channel-last layout)
        face_tensor = face_tensor.permute(1, 2, 0)
        face_tensor = (face_tensor - mean) / std
        face_tensor = face_tensor.permute(2, 0, 1)
        face_tensor = face_tensor.type(torch.float32)

        # Run the inference
        with torch.inference_mode():
            face_tensor = face_tensor.unsqueeze(0).to(device)
            r1, r2, r3, _ = model(face_tensor)

        # Assemble the rotation matrix from the predicted vectors r1, r2, r3
        # (.cpu() is required before .numpy() when running on GPU)
        r1 = r1.squeeze().cpu().numpy()
        r2 = r2.squeeze().cpu().numpy()
        r3 = r3.squeeze().cpu().numpy()
        rotation_matrix = np.array([r1, r2, r3])

        # Convert the rotation matrix to Euler angles in degrees
        r = R.from_matrix(rotation_matrix)
        pitch, yaw, roll = r.as_euler('zyx', degrees=True)

        # Draw the pose axes centered on the face, scaled to the box size
        center = centroids[idx]
        size = sizes[idx] * 0.5
        frame = draw_2D_axes(frame, yaw, roll, pitch, center[0], center[1], size)

    return frame


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model from its checkpoint and switch to evaluation mode
model = HPEnet()
model.load_state_dict(torch.load('model.pt', map_location=device))
model.to(device)
model.eval()

# Face detector and the standardization statistics used during training
mtcnn = MTCNN(keep_all=True, post_process=False, device='cpu')
mean = torch.load('mean.pt')
std = torch.load('std.pt')

demo = gr.Interface(
    process,
    gr.Image(sources=["webcam"], streaming=True),
    "image",
    live=True,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
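
# For reference, a minimal sketch of what a helper like utils.draw_2D_axes
# could look like. This is an assumption, not the project's actual
# implementation: it follows the common Hopenet-style convention of projecting
# the head's unit axes onto the image plane, with angles in degrees. The
# parameter order mirrors the call above (yaw, roll, pitch, center, size);
# colors are RGB tuples because Gradio frames are RGB. The sketch is unused by
# the demo and is kept only as documentation.
def draw_2D_axes_sketch(frame, yaw, roll, pitch, tx, ty, size):
    # Convert to radians; the negated yaw sign is part of the assumed convention
    yaw = -np.radians(yaw)
    pitch = np.radians(pitch)
    roll = np.radians(roll)

    # X axis (red), pointing to the right of the face
    x1 = size * (np.cos(yaw) * np.cos(roll)) + tx
    y1 = size * (np.cos(pitch) * np.sin(roll)
                 + np.cos(roll) * np.sin(pitch) * np.sin(yaw)) + ty
    # Y axis (green), pointing up
    x2 = size * (-np.cos(yaw) * np.sin(roll)) + tx
    y2 = size * (np.cos(pitch) * np.cos(roll)
                 - np.sin(pitch) * np.sin(yaw) * np.sin(roll)) + ty
    # Z axis (blue), pointing out of the screen
    x3 = size * np.sin(yaw) + tx
    y3 = size * (-np.cos(yaw) * np.sin(pitch)) + ty

    cv2.line(frame, (int(tx), int(ty)), (int(x1), int(y1)), (255, 0, 0), 3)
    cv2.line(frame, (int(tx), int(ty)), (int(x2), int(y2)), (0, 255, 0), 3)
    cv2.line(frame, (int(tx), int(ty)), (int(x3), int(y3)), (0, 0, 255), 3)
    return frame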