MSc_02_PDL_A4 / viewpoint.py
maxjmohr's picture
Tasks 6-9: Add viewpoint step, add score models, finish app
258e5de
raw
history blame contribute delete
338 Bytes
import torch
def get_viewpoint(model, image, device):
    """Predict the viewpoint class index for the given input.

    The model is put into evaluation mode (and left in it), the input is
    moved to ``device``, and a single forward pass is run with gradient
    tracking disabled. The predicted viewpoint is the position of the
    largest model output along dimension 1.

    Args:
        model: A callable torch module producing per-class scores.
            Presumably shaped (batch, num_viewpoints); TODO confirm against
            the caller.
        image: Tensor accepted by ``model``. It is moved to ``device``
            before the forward pass.
        device: Target device for the input, e.g. ``"cpu"`` or a
            ``torch.device``.

    Returns:
        Tensor of indices of the maximum output along dimension 1.
    """
    # eval() switches modules such as Dropout/BatchNorm to inference
    # behaviour; note the model is NOT switched back afterwards.
    model.eval()
    with torch.no_grad():
        scores = model(image.to(device))
        # argmax gives the same answer for raw logits or softmax probabilities,
        # so no normalisation is needed before picking the winner.
        return scores.argmax(dim=1)