# Age-Prediction-Demo / inference.py
# Author: niks-salodkar — commit 7faf1c4 ("added code and files"), 2.16 kB, malware scan: no virus.
import streamlit as st
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from model import AgePredictResnet
# Checkpoint that get_predictions() passes to load_trained_model().
path = './final-models/resnet_101_weigthed.pt'

# Class index -> human-readable label, one mapping per model output head.
# Indices are assigned in the order the labels are listed below.
age_dict = dict(enumerate((
    '0 to 10',
    '10 to 20',
    '20 to 30',
    '30 to 40',
    '40 to 50',
    '50 to 60',
    '60 to 70',
    '70 to 80',
    'Above 80',
)))
sex_dict = dict(enumerate(('Male', 'Female')))
race_dict = dict(enumerate((
    'White',
    'Black',
    'Asian',
    'Indian',
    'Others (like Hispanic, Latino, Middle Eastern etc)',
)))
# Streamlit's recommended cache for ML models is the *resource* cache
# (st.cache_resource, formerly st.experimental_singleton): it stores the object
# itself and returns the same instance on every call. st.experimental_memo (now
# st.cache_data) instead pickles the return value and unpickles a fresh copy on
# every call, i.e. the whole network is deserialised on each prediction.
# Use whichever resource-cache API this Streamlit version provides.
_cache_model_resource = getattr(st, 'cache_resource', None) or st.experimental_singleton


@_cache_model_resource
def load_trained_model(model_path):
    """Build ``AgePredictResnet``, load weights from ``model_path`` on CPU, return it in eval mode.

    The result is cached once per ``model_path`` and the same instance is handed
    to every caller, so callers must not mutate it (get_predictions only runs it
    under ``torch.inference_mode``).

    Args:
        model_path: filesystem path of the saved ``state_dict``.

    Returns:
        The model, switched to evaluation mode.
    """
    import warnings

    model = AgePredictResnet()
    # torch.load unpickles the file; it must come from a trusted source.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # strict=False is kept for compatibility with the existing checkpoint, but a
    # key mismatch would otherwise silently leave parameters at their random
    # initialisation, so report it.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f'Checkpoint {model_path!r} does not fully match AgePredictResnet: '
            f'missing keys={incompatible.missing_keys}, '
            f'unexpected keys={incompatible.unexpected_keys}'
        )
    model.eval()
    return model
def _top2(prob):
    """Return the two largest entries of ``prob`` (along dim 1) as (values, indices) lists.

    Values/indices are flattened numpy scalars; with a batch of one this yields
    exactly two of each, best first.
    """
    top = torch.topk(prob, 2, dim=1)
    return list(top.values.numpy().reshape(-1)), list(top.indices.numpy().reshape(-1))


def get_predictions(input_image):
    """Run the cached model on a single image and return human-readable predictions.

    Args:
        input_image: image passed to torchvision's ``Resize``/``ToTensor``;
            presumably a PIL image from the Streamlit uploader — TODO confirm
            against the caller.

    Returns:
        dict with keys:
            'Predicted Age range': tuple of the two most likely age labels, best first.
            'Age Probability': list of their softmax probabilities.
            'Predicted Sex': most likely sex label.
            'Sex Probability': its softmax probability (Python float).
            'Predicted Race': tuple of the two most likely race labels, best first.
            'Race Probability': list of their softmax probabilities.
    """
    model = load_trained_model(path)

    # ToTensor keeps the source channel count (4 for RGBA, 1 for 'L'/'P'),
    # while Normalize below has exactly three per-channel statistics, so any
    # non-RGB PIL image would raise. Normalise the mode first.
    if hasattr(input_image, 'convert'):
        input_image = input_image.convert('RGB')

    transforms = Compose([
        Resize((256, 256)),
        ToTensor(),
        # Standard ImageNet mean/std; presumably matches the training-time
        # preprocessing — TODO confirm.
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension: the model is fed a batch of one image.
    batch = torch.unsqueeze(transforms(input_image), 0)

    with torch.inference_mode():
        outputs = model(batch)
        # Outputs are indexed as (age, sex, race) heads.
        age_prob = F.softmax(outputs[0], dim=1)
        sex_prob = F.softmax(outputs[1], dim=1)
        race_prob = F.softmax(outputs[2], dim=1)

    age_values, age_indices = _top2(age_prob)
    race_values, race_indices = _top2(race_prob)
    sex_idx = torch.argmax(sex_prob, dim=1).item()

    return {
        'Predicted Age range': (age_dict[age_indices[0]], age_dict[age_indices[1]]),
        'Age Probability': age_values,
        'Predicted Sex': sex_dict[sex_idx],
        'Sex Probability': sex_prob[0][sex_idx].item(),
        'Predicted Race': (race_dict[race_indices[0]], race_dict[race_indices[1]]),
        'Race Probability': race_values,
    }