"""Streamlit inference helpers for the AgePredictResnet age/sex/race classifier.

Loads a trained checkpoint once and converts the model's three classification
outputs into human-readable labels together with their softmax probabilities.
"""
import warnings

import streamlit as st
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

from model import AgePredictResnet

path = './final-models/resnet_101_weigthed.pt'

age_dict = {
    0: '0 to 10',
    1: '10 to 20',
    2: '20 to 30',
    3: '30 to 40',
    4: '40 to 50',
    5: '50 to 60',
    6: '60 to 70',
    7: '70 to 80',
    8: 'Above 80',
}

sex_dict = {0: 'Male', 1: 'Female'}

race_dict = {
    0: 'White',
    1: 'Black',
    2: 'Asian',
    3: 'Indian',
    4: 'Others (like Hispanic, Latino, Middle Eastern etc)',
}

# Preprocessing does not depend on the input, so build the pipeline once rather
# than on every prediction. Mean/std are the standard ImageNet values.
_TRANSFORMS = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# A model should be cached as a shared resource: st.experimental_memo pickles the
# return value and hands back a fresh copy on every call, i.e. it re-deserialises
# the whole network on each rerun. st.cache_resource exists from Streamlit 1.18;
# older releases provide the equivalent st.experimental_singleton.
_cache_model = getattr(st, 'cache_resource', None) or st.experimental_singleton


@_cache_model
def load_trained_model(model_path):
    """Build an AgePredictResnet, load CPU weights from *model_path*, return it in eval mode.

    The caching decorator keeps one model instance per distinct ``model_path``.
    """
    model = AgePredictResnet()
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # strict=False tolerates key mismatches between checkpoint and model; report
    # them instead of silently running layers that kept their initial weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f'Checkpoint {model_path!r} does not match AgePredictResnet exactly: '
            f'missing keys={incompatible.missing_keys}, '
            f'unexpected keys={incompatible.unexpected_keys}'
        )
    model.eval()
    return model


def _top2(probs):
    """Return the two highest probabilities (best first) and their class indices."""
    values, indices = torch.topk(probs, 2, dim=1)
    return list(values.numpy().reshape(-1)), indices.reshape(-1).tolist()


def get_predictions(input_image):
    """Predict the age range, sex and race for a single image.

    Args:
        input_image: a PIL image. It is converted to RGB first so greyscale,
            RGBA and palette images match the 3-channel normalisation.

    Returns:
        dict with keys:
          'Predicted Age range' -- tuple of the two most likely age_dict labels, best first
          'Age Probability'     -- list of their two softmax probabilities
          'Predicted Sex'       -- most likely sex_dict label
          'Sex Probability'     -- its softmax probability (float)
          'Predicted Race'      -- tuple of the two most likely race_dict labels, best first
          'Race Probability'    -- list of their two softmax probabilities
    """
    model = load_trained_model(path)

    # Normalize expects exactly 3 channels; RGBA/L/P-mode images would fail.
    if hasattr(input_image, 'convert'):
        input_image = input_image.convert('RGB')
    batch = torch.unsqueeze(_TRANSFORMS(input_image), 0)  # batch of one image

    with torch.inference_mode():
        # NOTE(review): assumes the model returns (age, sex, race) logits, each
        # shaped (1, n_classes) -- consistent with the dict lookups below.
        logits = model(batch)
        age_prob = F.softmax(logits[0], dim=1)
        sex_prob = F.softmax(logits[1], dim=1)
        race_prob = F.softmax(logits[2], dim=1)

        age_values, age_indices = _top2(age_prob)
        race_values, race_indices = _top2(race_prob)
        sex_index = torch.argmax(sex_prob, dim=1).item()
        sex_value = sex_prob[0][sex_index].item()

    return {
        'Predicted Age range': (age_dict[age_indices[0]], age_dict[age_indices[1]]),
        'Age Probability': age_values,
        'Predicted Sex': sex_dict[sex_index],
        'Sex Probability': sex_value,
        'Predicted Race': (race_dict[race_indices[0]], race_dict[race_indices[1]]),
        'Race Probability': race_values,
    }