import os os.system("pip install git+https://github.com/elliottzheng/face-detection.git@master") os.system("git clone https://github.com/thohemp/6DRepNet") import sys sys.path.append("6DRepNet") import numpy as np import gradio as gr import torch from huggingface_hub import hf_hub_download from face_detection import RetinaFace from model import SixDRepNet import utils import cv2 from PIL import Image snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth") model = SixDRepNet(backbone_name='RepVGG-B1g2', backbone_file='', deploy=True, pretrained=False) detector = RetinaFace(0) saved_state_dict = torch.load(os.path.join( snapshot_path), map_location='cpu') if 'model_state_dict' in saved_state_dict: model.load_state_dict(saved_state_dict['model_state_dict']) else: model.load_state_dict(saved_state_dict) model.cuda(0) model.eval() def predict(frame): faces = detector(frame) for box, landmarks, score in faces: # Print the location of each face in this image if score < .95: continue x_min = int(box[0]) y_min = int(box[1]) x_max = int(box[2]) y_max = int(box[3]) bbox_width = abs(x_max - x_min) bbox_height = abs(y_max - y_min) x_min = max(0,x_min-int(0.2*bbox_height)) y_min = max(0,y_min-int(0.2*bbox_width)) x_max = x_max+int(0.2*bbox_height) y_max = y_max+int(0.2*bbox_width) img = frame[y_min:y_max,x_min:x_max] img = cv2.resize(img, (244, 244))/255.0 img = img.transpose(2, 0, 1) img = torch.from_numpy(img).type(torch.FloatTensor) img = torch.Tensor(img).cuda(0) img=img.unsqueeze(0) R_pred = model(img) euler = utils.compute_euler_angles_from_rotation_matrices( R_pred)*180/np.pi p_pred_deg = euler[:, 0].cpu() y_pred_deg = euler[:, 1].cpu() r_pred_deg = euler[:, 2].cpu() return utils.plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg, x_min + int(.5*(x_max-x_min)), y_min + int(.5*(y_max-y_min)), size = bbox_width) title = "6D Rotation Representation for Unconstrained Head Pose Estimation" description = "Gradio demo for 6DRepNet. To use it, simply click the camera picture. Read more at the links below." article = "
Github Repo | Paper
" image_flip_css = """ .input-image .image-preview img{ -webkit-transform: scaleX(-1); transform: scaleX(-1) !important; } .output-image img { -webkit-transform: scaleX(-1); transform: scaleX(-1) !important; } """ iface = gr.Interface( fn=predict, inputs=gr.inputs.Image(label="Input Image", source="webcam"), outputs='image', live=True, title=title, description=description, article=article, css = image_flip_css ) iface.launch()