File size: 2,052 Bytes
cf3ce1b
cde7e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf3ce1b
 
 
 
cde7e09
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
from model import SixDRepNet
import os

import numpy as np
import torch
from torchvision import transforms
import utils
import time

# Preprocessing applied to every input image before inference: resize so the
# shorter side is 224 px, center-crop to 224x224, convert to a tensor and
# normalize with the standard ImageNet per-channel mean/std values.
transformations = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# RepVGG-A0 backbone constructed directly in deploy mode. pretrained=False and
# an empty backbone_file presumably skip any separate backbone weights, since
# the complete checkpoint is loaded just below — TODO confirm against model.py.
model = SixDRepNet(backbone_name='RepVGG-A0',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

# The checkpoint path is a bare file name, so it is resolved relative to the
# current working directory. Tensors are mapped onto the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
saved_state_dict = torch.load("weights_ALFW_A0.pth", map_location='cpu')

# Accept either a checkpoint dict that stores the weights under
# 'model_state_dict' or a bare state dict.
state_dict = (saved_state_dict['model_state_dict']
              if 'model_state_dict' in saved_state_dict
              else saved_state_dict)
model.load_state_dict(state_dict)

# Test the Model
model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).

# Angle threshold in degrees that predict() uses when turning pitch/yaw into a
# coarse direction label.
th = 15

def head_direction(pitch_deg, yaw_deg, threshold):
    """Return a coarse direction label for a head pose.

    Pitch above ``threshold`` maps to "UP" and below ``-threshold`` to "DOWN".
    Yaw above ``threshold`` maps to "LEFT" and below ``-threshold`` to "RIGHT".
    Angles inside the dead zone ``[-threshold, threshold]`` contribute no label,
    so a roughly frontal face yields an empty string.

    Args:
        pitch_deg: Pitch angle in degrees.
        yaw_deg: Yaw angle in degrees.
        threshold: Non-negative dead-zone half-width in degrees.

    Returns:
        "UP"/"DOWN" and/or "LEFT"/"RIGHT" joined by a single space, or "".
    """
    parts = []
    if pitch_deg > threshold:
        parts.append("UP")
    elif pitch_deg < -threshold:  # symmetric dead zone, not `< threshold`
        parts.append("DOWN")

    if yaw_deg > threshold:
        parts.append("LEFT")
    elif yaw_deg < -threshold:
        parts.append("RIGHT")

    return " ".join(parts)


def predict(img):
    """Estimate the head pose in one image and format it for display.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A multi-line string with yaw and pitch (degrees), a coarse direction
        label and the model forward-pass time in milliseconds. Returns a short
        message instead when no image was provided.
    """
    if img is None:
        return "No image provided."

    batch = transformations(img.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        # perf_counter is monotonic and high-resolution; suited to intervals.
        start = time.perf_counter()
        R_pred = model(batch)
        elapsed_ms = (time.perf_counter() - start) * 1000

        # Convert the predicted rotation to Euler angles in degrees.
        euler = utils.compute_euler_angles_from_rotation_matrices(
            R_pred, use_gpu=False) * 180 / np.pi

    # Column 0 is read as pitch and column 1 as yaw; .item() requires the
    # batch of exactly one image built above.
    p_pred_deg = euler[:, 0].cpu().item()
    y_pred_deg = euler[:, 1].cpu().item()
    direction_str = head_direction(p_pred_deg, y_pred_deg, th)

    return (f"Yaw: {y_pred_deg:0.1f} \n Pitch: {p_pred_deg:0.1f}\n "
            f"Direction: {direction_str} \n Time: {elapsed_ms:0.2f}ms")


# Sample inputs offered in the UI; bare relative file names.
EXAMPLE_IMAGES = ["face_left.jpg", "face_right.jpg", "face_up.jpg", "face_down.jpg"]

# Web UI: a single PIL image in, the formatted pose text from predict() out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    examples=EXAMPLE_IMAGES,
)

# share=True asks Gradio to also create a public share link.
demo.launch(share=True)