bien.nguyen1
add model
cde7e09
raw
history blame
No virus
2.05 kB
import gradio as gr
from model import SixDRepNet
import os
import numpy as np
import torch
from torchvision import transforms
import utils
import time
# Preprocessing pipeline applied to each input image before inference:
# resize to 224, center-crop to 224x224, convert to a tensor, then
# normalize per channel. The mean/std values appear to be the standard
# ImageNet statistics -- confirm against the backbone's training setup.
_preprocess_steps = [
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transformations = transforms.Compose(_preprocess_steps)
# Build the network with a RepVGG-A0 backbone. No backbone weight file is
# given and pretrained=False because the full set of weights is restored
# from the checkpoint below.
model = SixDRepNet(
    backbone_name='RepVGG-A0',
    backbone_file='',
    deploy=True,
    pretrained=False,
)

# Load the checkpoint onto the CPU. The path is relative, so the file is
# resolved against the current working directory.
saved_state_dict = torch.load("weights_ALFW_A0.pth", map_location='cpu')

# The checkpoint is either a wrapper holding the weights under
# 'model_state_dict', or the state dict itself.
_weights = (
    saved_state_dict['model_state_dict']
    if 'model_state_dict' in saved_state_dict
    else saved_state_dict
)
model.load_state_dict(_weights)

# Switch to inference mode (BN uses moving mean/var).
model.eval()
# Angle threshold, in degrees, that predict() compares the predicted pitch
# and yaw against when building the direction label.
th = 15
def predict(img):
    """Estimate head pose for one image and describe it as text.

    Args:
        img: PIL image (the Gradio input is configured with type="pil").
            It is converted to RGB before preprocessing.

    Returns:
        A multi-line string with the predicted yaw and pitch in degrees,
        a coarse direction label (any of UP/DOWN and LEFT/RIGHT, empty
        when both angles are within +/- ``th`` degrees), and the model
        forward-pass time in milliseconds.
    """
    img = img.convert('RGB')
    # Add a leading batch dimension: the model is fed a batch of one image.
    batch = transformations(img).unsqueeze(0)

    with torch.no_grad():
        # Time only the forward pass; perf_counter is a monotonic,
        # high-resolution clock intended for measuring intervals.
        start = time.perf_counter()
        R_pred = model(batch)
        end = time.perf_counter()
        timemilis = (end - start) * 1000

        # Convert the model output to Euler angles, radians -> degrees.
        euler = utils.compute_euler_angles_from_rotation_matrices(
            R_pred, use_gpu=False) * 180 / np.pi
        # Column 0 is used as pitch and column 1 as yaw; .item() is valid
        # because the batch holds exactly one image.
        p_pred_deg = euler[:, 0].cpu().item()
        y_pred_deg = euler[:, 1].cpu().item()

    # Build the direction label. Angles inside [-th, th] count as neutral
    # and contribute nothing; the negative side must be compared against
    # -th, otherwise every non-"UP"/non-"LEFT" pose is labelled
    # "DOWN"/"RIGHT", including a frontal face.
    direction_str = ""
    if p_pred_deg > th:
        direction_str = "UP "
    elif p_pred_deg < -th:
        direction_str = "DOWN "
    if y_pred_deg > th:
        direction_str += "LEFT"
    elif y_pred_deg < -th:
        direction_str += "RIGHT"

    return f"Yaw: {y_pred_deg:0.1f} \n Pitch: {p_pred_deg:0.1f}\n Direction: {direction_str} \n Time: {timemilis:0.2f}ms"
# Gradio UI: a single PIL image input, a text box showing the string
# returned by predict(), and four bundled example images.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    examples=["face_left.jpg", "face_right.jpg", "face_up.jpg", "face_down.jpg"],
)
# share=True asks Gradio to also create a public share link.
demo.launch(share=True)