# Author: niks-salodkar
# Commit 7faf1c4: "added code and files" (file size 3.69 kB; hosting page scan: "No virus")
import os
import warnings

import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch import nn
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
class AgePredictResnet(nn.Module):
    """Multi-task classifier: one ResNet-101 backbone shared by age, gender and race heads.

    The backbone's final ``fc`` layer is replaced by a 2048 -> 512 projection. The
    ReLU-activated 512-d embedding feeds three independent heads, each consisting of
    two hidden layers (512 -> 256 -> 128, each followed by dropout and ReLU) and an
    output layer that returns unnormalised logits (no softmax is applied here).

    The submodule attribute names (``model``, ``age_linear1`` ... ``race_out``) determine
    the ``state_dict`` keys, so they must not be renamed or existing checkpoints will
    no longer match.

    Args:
        num_age_classes: Output size of the age head (default 9).
        num_gender_classes: Output size of the gender head (default 2).
        num_race_classes: Output size of the race head (default 5).
        dropout_p: Dropout probability applied after every hidden head layer
            (default 0.4). Dropout is only active in training mode.
    """

    def __init__(self, *, num_age_classes=9, num_gender_classes=2, num_race_classes=5,
                 dropout_p=0.4):
        super().__init__()
        # Backbone is constructed without pretrained weights; in this file the trained
        # weights are loaded afterwards with load_state_dict.
        self.model = torchvision.models.resnet101()
        # ResNet-101's pooled feature vector has 2048 channels; project it to 512.
        self.model.fc = nn.Linear(2048, 512)

        self.age_linear1 = nn.Linear(512, 256)
        self.age_linear2 = nn.Linear(256, 128)
        self.age_out = nn.Linear(128, num_age_classes)

        self.gender_linear1 = nn.Linear(512, 256)
        self.gender_linear2 = nn.Linear(256, 128)
        self.gender_out = nn.Linear(128, num_gender_classes)

        self.race_linear1 = nn.Linear(512, 256)
        self.race_linear2 = nn.Linear(256, 128)
        self.race_out = nn.Linear(128, num_race_classes)

        # Parameter-free modules shared by every head (they add no state_dict keys).
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)

    def _head(self, features, hidden1, hidden2, output):
        """Apply one task head: two (linear -> dropout -> ReLU) stages, then the logit layer."""
        hidden = self.activation(self.dropout(hidden1(features)))
        hidden = self.activation(self.dropout(hidden2(hidden)))
        return output(hidden)

    def forward(self, x):
        """Return ``(age_logits, gender_logits, race_logits)`` for an image batch.

        ``x`` is an (N, 3, H, W) float tensor, since the ResNet stem expects 3 input
        channels. The preprocessing used in this file is a 256x256 resize plus ImageNet
        mean/std normalisation; confirm it matches training. Each returned tensor has
        shape (N, num_<task>_classes) and holds raw logits.
        """
        # ReLU on the shared 512-d embedding before it is fed to the heads.
        features = self.activation(self.model(x))
        age_out = self._head(features, self.age_linear1, self.age_linear2, self.age_out)
        gender_out = self._head(features, self.gender_linear1, self.gender_linear2,
                                self.gender_out)
        race_out = self._head(features, self.race_linear1, self.race_linear2, self.race_out)
        return age_out, gender_out, race_out
def main(model_path='./final-models/resnet_101_weigthed.pt',
         image_path='../../age_prediction/data/wild_images/part1/50_1_1_20170110120147003.jpg'):
    """Run AgePredictResnet on one image and print age/sex/race predictions.

    Loads the checkpoint at ``model_path`` onto the CPU and preprocesses ``image_path``
    (RGB conversion, 256x256 resize, ImageNet mean/std normalisation). It then prints:
    the input tensor shape; the raw predictions tuple
    ``((age_top2_probs, age_top2_idx), (sex_idx, sex_prob), (race_top2_probs, race_top2_idx))``;
    and a human-readable summary dict.
    """
    # Label tables, indexed by the class index each head outputs.
    age_labels = ('0 to 10', '10 to 20', '20 to 30', '30 to 40', '40 to 50', '50 to 60',
                  '60 to 70', '70 to 80', 'Above 80')
    sex_labels = ('Male', 'Female')
    race_labels = ('White', 'Black', 'Asian', 'Indian',
                   'Others (like Hispanic, Latino, Middle Eastern etc)')

    model = AgePredictResnet()
    # NOTE(security): torch.load unpickles the file, so only load trusted checkpoints
    # (on torch >= 1.13, passing weights_only=True restricts loading to tensors).
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # strict=False tolerates key mismatches, but ignoring them silently would leave
    # any missing layers randomly initialised, so report both kinds of mismatch.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        warnings.warn(f'checkpoint is missing {len(incompatible.missing_keys)} key(s), '
                      f'e.g. {incompatible.missing_keys[:5]}; those layers keep random weights')
    if incompatible.unexpected_keys:
        warnings.warn(f'checkpoint has {len(incompatible.unexpected_keys)} unexpected key(s), '
                      f'e.g. {incompatible.unexpected_keys[:5]}; they were ignored')
    model.eval()

    preprocess = Compose([
        Resize((256, 256)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # convert('RGB') guarantees 3 channels; grayscale/RGBA/palette images would otherwise
    # break Normalize and the ResNet stem. The context manager closes the file handle.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    batch = torch.unsqueeze(preprocess(image), 0)  # add batch dimension
    print(batch.shape)

    with torch.inference_mode():
        age_logits, sex_logits, race_logits = model(batch)
        age_prob = F.softmax(age_logits, dim=1)
        sex_prob = F.softmax(sex_logits, dim=1)
        race_prob = F.softmax(race_logits, dim=1)
        top2_age = torch.topk(age_prob, 2, dim=1)
        top2_race = torch.topk(race_prob, 2, dim=1)

        # The batch holds a single image, so take row 0 and convert to Python numbers.
        age_top_p, age_top_idx = top2_age.values[0].tolist(), top2_age.indices[0].tolist()
        race_top_p, race_top_idx = top2_race.values[0].tolist(), top2_race.indices[0].tolist()
        sex_idx = int(torch.argmax(sex_prob, dim=1)[0].item())
        sex_p = sex_prob[0, sex_idx].item()

    all_predictions = ((age_top_p, age_top_idx), (sex_idx, sex_p), (race_top_p, race_top_idx))
    print(all_predictions)

    pred_dict = {
        'Predicted Age range': tuple(age_labels[i] for i in age_top_idx),
        'Age Probability': age_top_p,
        'Predicted Sex': sex_labels[sex_idx],
        'Sex Probability': sex_p,
        'Predicted Race': tuple(race_labels[i] for i in race_top_idx),
        'Race Probability': race_top_p,
    }
    print(pred_dict)


if __name__ == '__main__':
    main()