MSc_02_PDL_A4 / modernity.py
maxjmohr's picture
Tasks 6-9: Add viewpoint step, add score models, finish app
258e5de
raw
history blame contribute delete
798 Bytes
import torch
import torch.nn.functional as F
def get_year_modernity_score(model, image, device):
    """Predict a continuous modernity score and its year-group label for an image.

    The model's logits are converted to probabilities with a softmax over
    dim 1. The score is the probability-weighted mean of the class indices
    0..4. Rounding that score selects the year-group label.

    Note: this switches ``model`` into eval mode and leaves it there.

    Args:
        model: Callable torch module mapping ``image`` to class logits;
            presumably shape (1, 5) — TODO confirm with the caller.
        image: Input tensor for ``model``; moved to ``device`` before the
            forward pass.
        device: Device on which the image and the class-index weights live.

    Returns:
        Tuple ``(score, label)``. ``score`` is a Python float in [0, 4].
        ``label`` is the year-range string for the rounded score.
    """
    model.eval()

    # Class index -> the year range that class represents.
    labels = dict(
        enumerate(
            (
                "2000-2003",
                "2006-2008",
                "2009-2011",
                "2012-2014",
                "2015-2017",
            )
        )
    )

    with torch.no_grad():
        logits = model(image.to(device))
        probabilities = F.softmax(logits, dim=1)
        class_indices = torch.arange(len(labels), device=device)
        # Expected class index under the predicted distribution.
        score = (probabilities * class_indices).sum(dim=1)
        # torch.round rounds halves to even; .item() needs a one-element tensor.
        label = labels[score.round().item()]

    return score.item(), label