File size: 2,160 Bytes
7faf1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import streamlit as st
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

from model import AgePredictResnet

# Checkpoint consumed by get_predictions().
# NOTE(review): "weigthed" looks like a typo, but this string must match the
# checkpoint's actual file name — only change it together with the file.
path = './final-models/resnet_101_weigthed.pt'

# Class-index -> human-readable label lookups used to decode the model outputs.
# Position in each tuple is the class index produced by the network.
age_dict = dict(enumerate((
    '0 to 10', '10 to 20', '20 to 30', '30 to 40', '40 to 50', '50 to 60',
    '60 to 70', '70 to 80', 'Above 80',
)))
sex_dict = dict(enumerate(('Male', 'Female')))
race_dict = dict(enumerate((
    'White', 'Black', 'Asian', 'Indian',
    'Others (like Hispanic, Latino, Middle Eastern etc)',
)))

# Cache the model as a single shared resource. ``st.experimental_memo`` (used
# previously) is Streamlit's data cache: it pickles the return value and hands
# back a freshly unpickled copy on every call, i.e. a full copy of the network
# per prediction. ``st.cache_resource`` only exists in newer Streamlit releases,
# so fall back to its predecessor ``st.experimental_singleton`` when missing.
_cache_model = getattr(st, 'cache_resource', None) or st.experimental_singleton


@_cache_model
def load_trained_model(model_path):
    """Build ``AgePredictResnet`` and load its weights from ``model_path`` on the CPU.

    The result is cached by Streamlit, so repeated calls with the same path reuse
    the same model object instead of reloading the checkpoint.

    Args:
        model_path: Path to a checkpoint whose ``torch.load`` result is a state
            dict for ``AgePredictResnet``.

    Returns:
        The model, switched to evaluation mode.
    """
    model = AgePredictResnet()
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # strict=False keeps the original tolerant loading, but ignoring the result
    # would hide a mismatched checkpoint (layers left randomly initialised), so
    # surface any missing/unexpected keys as a warning.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        import warnings
        warnings.warn(
            'Checkpoint %r does not fully match AgePredictResnet: '
            'missing keys=%s, unexpected keys=%s'
            % (model_path, incompatible.missing_keys, incompatible.unexpected_keys),
            RuntimeWarning,
        )
    model.eval()
    return model


# Preprocessing: resize to 256x256, convert to a tensor, and normalize with the
# given per-channel mean/std (three channels). Built once at import time rather
# than on every prediction.
_PREPROCESS = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _top2(probs):
    """Return the two largest probabilities of ``probs`` and their class indices.

    ``probs`` is a softmax output with classes on dim 1. Both returned lists are
    flattened NumPy scalars ordered from most to least likely.
    """
    top = torch.topk(probs, 2, dim=1)
    return list(top.values.numpy().reshape(-1)), list(top.indices.numpy().reshape(-1))


def get_predictions(input_image):
    """Predict age range, sex and race for a single image.

    Args:
        input_image: An image accepted by torchvision ``Resize``/``ToTensor``
            (presumably a PIL image from the app's uploader — confirm with caller).

    Returns:
        dict with keys:
          'Predicted Age range', 'Predicted Race': tuples of the two most
              likely labels, most likely first;
          'Age Probability', 'Race Probability': lists of the matching softmax
              probabilities, in the same order;
          'Predicted Sex': the most likely label;
          'Sex Probability': its softmax probability as a Python float.
    """
    # Normalize has one mean/std per channel, i.e. it requires exactly three
    # channels. PIL images may arrive as RGBA, palette or grayscale, which
    # ToTensor would turn into 4 or 1 channels and break normalization, so
    # coerce anything exposing ``convert`` to RGB first (RGB input is unchanged).
    if hasattr(input_image, 'convert'):
        input_image = input_image.convert('RGB')

    model = load_trained_model(path)
    batch = torch.unsqueeze(_PREPROCESS(input_image), 0)  # add a batch dimension of 1

    with torch.inference_mode():
        logits = model(batch)
        # assumes the model returns (age, sex, race) logits with classes on dim 1
        age_prob = F.softmax(logits[0], dim=1)
        sex_prob = F.softmax(logits[1], dim=1)
        race_prob = F.softmax(logits[2], dim=1)

    age_values, age_indices = _top2(age_prob)
    race_values, race_indices = _top2(race_prob)
    sex_index = torch.argmax(sex_prob, dim=1).item()

    return {
        'Predicted Age range': (age_dict[age_indices[0]], age_dict[age_indices[1]]),
        'Age Probability': age_values,
        'Predicted Sex': sex_dict[sex_index],
        'Sex Probability': sex_prob[0][sex_index].item(),
        'Predicted Race': (race_dict[race_indices[0]], race_dict[race_indices[1]]),
        'Race Probability': race_values,
    }