# HPE-streaming / app.py
# antoniospoletojr
# updated code
# 38df5fe
import gradio as gr
import numpy as np
import cv2
import torch
from facenet_pytorch import MTCNN
from model import HPEnet
from torchvision import transforms
from scipy.spatial.transform import Rotation as R
from PIL import Image
from utils import draw_2D_axes
def detect_faces(image, margin=50):
    """Detect faces in an image and crop each one with a safety margin.

    Args:
        image: PIL image to search. Its ``width`` and ``height`` bound the
            expanded boxes and its ``crop`` method produces the face crops.
        margin: Pixels added on every side of each detected box before
            cropping. Defaults to 50, the value that used to be hard-coded.

    Returns:
        A ``(faces, centroids, sizes)`` tuple of parallel lists: the cropped
        face images, the ``[x, y]`` integer centre of each expanded box, and
        the width of each expanded box. ``(None, None, None)`` is returned
        when the detector reports no boxes.
    """
    # Detect face
    boxes, _ = mtcnn.detect(image)

    # If no boxes have been detected return
    if boxes is None:
        return None, None, None

    faces = []
    boxes_centroids = []
    sizes = []
    for box in boxes:
        # Expand the box by the margin, clamped to the image bounds. Work on
        # local values so the detector's output array is left untouched.
        left = max(0, box[0] - margin)
        top = max(0, box[1] - margin)
        right = min(image.width, box[2] + margin)
        bottom = min(image.height, box[3] + margin)

        # Centre and width are computed from the expanded box, not the raw one.
        boxes_centroids.append([int((left + right) / 2), int((top + bottom) / 2)])
        sizes.append(right - left)

        # Crop the face using the expanded box
        faces.append(image.crop((left, top, right, bottom)))

    return faces, boxes_centroids, sizes
def process(frame):
    """Estimate the head pose of every detected face and draw its axes.

    Args:
        frame: Image array for one streamed frame. It is handed unchanged to
            ``Image.fromarray`` for detection and to ``draw_2D_axes`` for
            drawing.

    Returns:
        The frame returned by ``draw_2D_axes`` after every detected face has
        been processed, or the input frame untouched when it is ``None`` or
        no face is detected.
    """
    # NOTE(review): presumably an HxWx3 uint8 RGB array as delivered by the
    # Gradio image input - confirm against the component configuration.
    if frame is None:
        return frame

    # The frame is given to PIL as-is. A BGR->RGB conversion used to be
    # computed here, but its result was immediately overwritten by this call,
    # so it never had any effect and has been removed.
    image = Image.fromarray(frame)

    # Detect face
    faces, centroids, sizes = detect_faces(image)
    if faces is None:
        return frame

    # The preprocessing pipeline does not depend on the face, so build it
    # once per frame instead of once per face.
    transform = transforms.Compose([
        transforms.PILToTensor(),
        transforms.Resize((200, 200)),
    ])

    for face, center, box_size in zip(faces, centroids, sizes):
        # Preprocess the image
        face_tensor = transform(face)

        # Standardize with the channel axis moved last, then move it back.
        face_tensor = face_tensor.permute(1, 2, 0)
        face_tensor = (face_tensor - mean) / std
        face_tensor = face_tensor.permute(2, 0, 1)
        face_tensor = face_tensor.type(torch.float32)

        # Run the inference
        with torch.inference_mode():
            face_tensor = face_tensor.unsqueeze(0).to(device)
            r1, r2, r3, _ = model(face_tensor)

        # Bring the outputs back to host memory before converting:
        # Tensor.numpy() raises for tensors that live on a CUDA device.
        r1 = r1.squeeze().cpu().numpy()
        r2 = r2.squeeze().cpu().numpy()
        r3 = r3.squeeze().cpu().numpy()

        # NOTE(review): the original comment describes r1, r2, r3 as the
        # COLUMNS of the rotation matrix, but np.array([...]) stacks them as
        # ROWS (i.e. the transpose). Behaviour is kept as-is because the
        # drawing code may rely on it - confirm against the model definition.
        rotation_matrix = np.array([r1, r2, r3])
        r = R.from_matrix(rotation_matrix)
        pitch, yaw, roll = r.as_euler('zyx', degrees=True)

        # Axes are drawn at the box centre with half the box width as length.
        size = box_size * 0.5
        frame = draw_2D_axes(frame, yaw, roll, pitch, center[0], center[1], size)

    return frame
# Run the pose network on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network on the target device and load model from checkpoint.
# The checkpoint is deserialized onto the CPU; load_state_dict copies the
# values into the parameters already placed on `device`, so no second
# `.to(device)` call is needed.
model = HPEnet().to(device)
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
model.eval()

# Face detector; kept on the CPU regardless of where the pose network runs.
mtcnn = MTCNN(keep_all=True, post_process=False, device='cpu')

# Standardization statistics. They are combined with CPU tensors in
# `process`, so force them onto the CPU even if they were saved from a GPU
# (which would otherwise fail to load on a CPU-only host).
mean = torch.load('mean.pt', map_location=torch.device('cpu'))
std = torch.load('std.pt', map_location=torch.device('cpu'))
# Gradio app: webcam frames are streamed into `process` and the frame it
# returns is shown as an image. Flagging is turned off.
demo = gr.Interface(
    fn=process,
    inputs=gr.Image(sources="webcam", streaming=True),
    outputs="image",
    live=True,
    allow_flagging="never",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()