"""Classify a single image with a trained ConvolutionalNet and print its label."""
import math
import os
from collections import Counter

import pandas as pd
import torch
from PIL import Image, ImageFile
from torchvision import transforms

from model import ConvolutionalNet
from vector_dict import vector_dict

# Let PIL decode images whose files are cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

model = ConvolutionalNet()
# The model is never moved off the CPU, so map the checkpoint's tensors there;
# this lets a checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
# Put dropout / batch-norm layers (if the network has any) into inference mode.
model.eval()

# Preprocessing pipeline; the order (ToTensor -> Normalize -> Resize) is kept
# as-is — presumably it mirrors the training transform, TODO confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    # Per-channel mean/std for 3-channel (RGB) input.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.Resize((256, 256)),
])


def get_prediction(path):
    """Return the predicted class label for the image file at *path*.

    The image is converted to RGB, preprocessed with ``transform``, run through
    ``model`` as a batch of one, and the index of the highest-scoring output is
    mapped to a label through ``vector_dict``.
    """
    # ``with`` closes the file handle; convert('RGB') normalises RGBA, palette
    # and grayscale PNGs to the 3 channels that Normalize expects.
    with Image.open(path) as img:
        tensor = transform(img.convert('RGB'))
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    batch = tensor.unsqueeze(0)
    with torch.no_grad():
        pred = model(batch)
    # Argmax over dim 1; assumes the model returns (batch, num_classes) — verify.
    return vector_dict[torch.max(pred, 1)[1].item()]


if __name__ == '__main__':
    print(get_prediction('data/test/Afghanistan/39841.png'))