"""Multi-task ResNet-101 that predicts age bucket, sex and race from a face image.

Running this file as a script loads trained weights, preprocesses one sample
image and prints the top predictions for each task.
"""
import os
import warnings

from PIL import Image
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


class AgePredictResnet(nn.Module):
    """ResNet-101 backbone with three classification heads.

    The backbone's ``fc`` layer is replaced by a 512-d projection that is shared by
    three independent MLP heads:

    * age: 9 logits
    * gender: 2 logits
    * race: 5 logits

    Attribute names are part of the checkpoint format (state_dict keys), so they
    must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet101()
        # Replace the ImageNet classifier with a shared 512-d embedding.
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

        self.age_linear1 = nn.Linear(512, 256)
        self.age_linear2 = nn.Linear(256, 128)
        self.age_out = nn.Linear(128, 9)

        self.gender_linear1 = nn.Linear(512, 256)
        self.gender_linear2 = nn.Linear(256, 128)
        self.gender_out = nn.Linear(128, 2)

        self.race_linear1 = nn.Linear(512, 256)
        self.race_linear2 = nn.Linear(256, 128)
        self.race_out = nn.Linear(128, 5)

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def _head(self, features, linear1, linear2, out_layer):
        """Run one task head: two (linear -> dropout -> ReLU) stages, then logits."""
        hidden = self.activation(self.dropout(linear1(features)))
        hidden = self.activation(self.dropout(linear2(hidden)))
        return out_layer(hidden)

    def forward(self, x):
        """Return ``(age_logits, gender_logits, race_logits)`` for the image batch ``x``."""
        features = self.activation(self.model(x))
        # Heads are evaluated in the same order as before (age, gender, race) so
        # dropout RNG consumption in training mode is unchanged.
        age_logits = self._head(features, self.age_linear1, self.age_linear2, self.age_out)
        gender_logits = self._head(features, self.gender_linear1, self.gender_linear2, self.gender_out)
        race_logits = self._head(features, self.race_linear1, self.race_linear2, self.race_out)
        return age_logits, gender_logits, race_logits


AGE_LABELS = {
    0: '0 to 10',
    1: '10 to 20',
    2: '20 to 30',
    3: '30 to 40',
    4: '40 to 50',
    5: '50 to 60',
    6: '60 to 70',
    7: '70 to 80',
    8: 'Above 80'
}

SEX_LABELS = {0: 'Male', 1: 'Female'}

RACE_LABELS = {
    0: 'White',
    1: 'Black',
    2: 'Asian',
    3: 'Indian',
    4: 'Others (like Hispanic, Latino, Middle Eastern etc)'
}


def build_transforms():
    """Return the preprocessing pipeline: 256x256 resize, tensor, ImageNet normalization."""
    return Compose([
        Resize((256, 256)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def load_model(weights_path):
    """Build ``AgePredictResnet``, load CPU weights from ``weights_path`` and set eval mode.

    Loading stays non-strict (as before), but any key mismatch is now reported,
    because missing keys mean those layers keep their random initialization.
    """
    model = AgePredictResnet()
    # NOTE(security): torch.load unpickles the file. Only load trusted checkpoints
    # (on torch >= 1.13, pass weights_only=True to restrict unpickling).
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        warnings.warn(
            f"{len(incompatible.missing_keys)} parameter(s) missing from checkpoint "
            f"{weights_path!r}; they stay randomly initialized: {incompatible.missing_keys}"
        )
    if incompatible.unexpected_keys:
        warnings.warn(
            f"{len(incompatible.unexpected_keys)} unexpected key(s) in checkpoint "
            f"{weights_path!r} were ignored: {incompatible.unexpected_keys}"
        )
    model.eval()
    return model


def load_image(path):
    """Open the image at ``path``, close the file, and return an in-memory RGB copy.

    Converting to RGB makes grayscale/RGBA/palette images match the 3-channel
    normalization and the 3-channel ResNet input.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


def predict(model, image_tensor, top_k=2):
    """Run ``model`` on a single-image batch and return the raw top predictions.

    Args:
        model: a model returning ``(age_logits, sex_logits, race_logits)``.
        image_tensor: a batch whose first sample is scored (only sample 0 is read).
        top_k: how many age and race candidates to keep. It must not exceed the
            smallest of those heads' class counts.

    Returns:
        ``((age_probs, age_indices), (sex_index, sex_prob), (race_probs, race_indices))``
        The probability/index lists are Python floats/ints, sorted by descending
        probability.
    """
    with torch.inference_mode():
        age_logits, sex_logits, race_logits = model(image_tensor)
        age_prob = F.softmax(age_logits, dim=1)
        sex_prob = F.softmax(sex_logits, dim=1)
        race_prob = F.softmax(race_logits, dim=1)

        top_age = torch.topk(age_prob, top_k, dim=1)
        top_race = torch.topk(race_prob, top_k, dim=1)
        sex_index = int(torch.argmax(sex_prob, dim=1)[0])
        sex_probability = sex_prob[0, sex_index].item()

    return (
        (top_age.values[0].tolist(), top_age.indices[0].tolist()),
        (sex_index, sex_probability),
        (top_race.values[0].tolist(), top_race.indices[0].tolist()),
    )


def format_predictions(predictions):
    """Map the raw output of ``predict`` to human-readable labels and probabilities."""
    (age_probs, age_indices), (sex_index, sex_probability), (race_probs, race_indices) = predictions
    return {
        'Predicted Age range': tuple(AGE_LABELS[i] for i in age_indices),
        'Age Probability': age_probs,
        'Predicted Sex': SEX_LABELS[sex_index],
        'Sex Probability': sex_probability,
        'Predicted Race': tuple(RACE_LABELS[i] for i in race_indices),
        'Race Probability': race_probs,
    }


def main():
    """Load the trained model, score the bundled sample image and print the results."""
    trained_model_path = os.path.join('.', 'final-models', 'resnet_101_weigthed.pt')
    model = load_model(trained_model_path)

    sample_image = load_image('../../age_prediction/data/wild_images/part1/50_1_1_20170110120147003.jpg')
    transformed_image = build_transforms()(sample_image).unsqueeze(0)  # add batch dim
    print(transformed_image.shape)

    all_predictions = predict(model, transformed_image)
    print(all_predictions)
    print(format_predictions(all_predictions))


if __name__ == '__main__':
    main()