# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import numpy as np | |
import cv2 | |
import torch | |
from facenet_pytorch import MTCNN | |
from model import HPEnet | |
from torchvision import transforms | |
from scipy.spatial.transform import Rotation as R | |
from PIL import Image | |
from utils import draw_2D_axes | |
def detect_faces(image, margin=50):
    """Detect faces in a PIL image and crop each one with a safety margin.

    Args:
        image: PIL image to search. Its ``width`` and ``height`` bound the
            padded boxes.
        margin: Number of pixels added on every side of each detected box
            before cropping. The padded box is clamped to the image borders.

    Returns:
        A ``(faces, centroids, sizes)`` tuple of parallel lists: the cropped
        face images, the ``[x, y]`` integer centre of each padded box, and the
        width of each padded box. All three entries are ``None`` when the
        detector reports no boxes.
    """
    boxes, _ = mtcnn.detect(image)

    # Nothing detected: signal it with a triple of None, as callers expect.
    if boxes is None:
        return None, None, None

    faces = []
    boxes_centroids = []
    sizes = []
    for box in boxes:
        # Pad the box, clamping so the crop never leaves the image.
        left = max(0, box[0] - margin)
        top = max(0, box[1] - margin)
        right = min(image.width, box[2] + margin)
        bottom = min(image.height, box[3] + margin)

        boxes_centroids.append(
            [int((left + right) / 2), int((top + bottom) / 2)]
        )
        sizes.append(right - left)
        faces.append(image.crop((left, top, right, bottom)))

    return faces, boxes_centroids, sizes
def process(frame):
    """Estimate the head pose of every face in a frame and draw its axes.

    Args:
        frame: Image array handed over by the Gradio image input, or ``None``
            when the callback fires without a frame.

    Returns:
        The frame with pose axes drawn for each detected face. The input is
        returned untouched when it is ``None`` or no face is detected.
    """
    # Guard against a callback without image data; Image.fromarray would
    # otherwise raise on None.
    if frame is None:
        return frame

    # The frame is used as delivered. The previous BGR->RGB conversion result
    # was immediately overwritten, so the detector has always seen the raw
    # frame; the dead conversion is dropped without changing behaviour.
    image = Image.fromarray(frame)

    faces, centroids, sizes = detect_faces(image)
    if faces is None:
        return frame

    # Loop-invariant preprocessing: PIL image -> tensor, resized to 200x200.
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.Resize((200, 200)),
    ])

    for face, center, box_size in zip(faces, centroids, sizes):
        face_tensor = transform(face)

        # Standardize with the channel axis last, then restore channel-first.
        # NOTE(review): presumably mean/std broadcast over the last (channel)
        # axis -- confirm against how mean.pt / std.pt were produced.
        face_tensor = face_tensor.permute(1, 2, 0)
        face_tensor = (face_tensor - mean) / std
        face_tensor = face_tensor.permute(2, 0, 1).type(torch.float32)

        with torch.inference_mode():
            r1, r2, r3, _ = model(face_tensor.unsqueeze(0).to(device))

        # Bring the outputs back to host memory first: Tensor.numpy() raises
        # for tensors that live on a CUDA device.
        r1 = r1.squeeze().cpu().numpy()
        r2 = r2.squeeze().cpu().numpy()
        r3 = r3.squeeze().cpu().numpy()

        # NOTE(review): the original comment calls r1, r2, r3 the *columns* of
        # the rotation matrix, yet np.array([...]) stacks them as *rows* (the
        # transpose). Left unchanged -- confirm which convention the model and
        # draw_2D_axes expect.
        rotation_matrix = np.array([r1, r2, r3])
        r = R.from_matrix(rotation_matrix)

        # NOTE(review): as_euler('zyx') yields angles in z, y, x order; the
        # names below and the argument order passed on are kept exactly as
        # they were -- verify the naming against draw_2D_axes.
        pitch, yaw, roll = r.as_euler('zyx', degrees=True)

        # Axes are drawn at the box centre, scaled to half the box width.
        frame = draw_2D_axes(
            frame, yaw, roll, pitch, center[0], center[1], box_size * 0.5
        )

    return frame
# Run the network on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network on the target device, then load the checkpoint weights.
# The checkpoint is mapped to CPU on load; load_state_dict copies the values
# into the parameters already placed on `device`, so no second .to(device)
# is needed.
model = HPEnet().to(device)
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
model.eval()

# Face detector; kept on the CPU and configured to return every face found.
mtcnn = MTCNN(keep_all=True, post_process=False, device='cpu')

# Standardization statistics. They are applied to CPU tensors before the
# input is moved to `device`, so pin them to the CPU regardless of the device
# they were saved from (this also lets CUDA-saved files load on CPU-only hosts).
mean = torch.load('mean.pt', map_location=torch.device('cpu'))
std = torch.load('std.pt', map_location=torch.device('cpu'))
# Gradio UI: a streaming webcam image goes into `process`, and the returned
# frame is shown as an image. Runs live, with flagging turned off.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(sources="webcam", streaming=True),
    outputs="image",
    live=True,
    allow_flagging="never",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()