|
import torch |
|
from torchvision import transforms |
|
from PIL import Image, ImageFile |
|
import pandas as pd |
|
import os |
|
import math |
|
from model import ConvolutionalNet |
|
from collections import Counter |
|
from vector_dict import vector_dict |
|
|
|
# Tell Pillow to load truncated image files instead of raising an error
# while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Build the network and load the trained weights from 'model.pt'.
# map_location='cpu' lets a checkpoint that was saved from a GPU process be
# deserialized on a CPU-only machine; load_state_dict copies the tensors into
# the freshly constructed (CPU) parameters either way, so the resulting model
# is the same as before.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model = ConvolutionalNet()
model.load_state_dict(torch.load('model.pt', map_location='cpu'))

# Switch to inference mode. Without this, any Dropout / BatchNorm layers that
# ConvolutionalNet may contain (not visible here) would keep training-time
# behaviour during prediction. This is a no-op for models without such layers.
model.eval()
|
|
|
# Per-image preprocessing applied before inference, in this order:
#   1. PIL image -> float tensor (ToTensor; values scaled to [0, 1] for
#      8-bit images),
#   2. per-channel normalization with the given mean / std,
#   3. resize the (already normalized) tensor to a fixed 256x256.
# NOTE(review): the mean/std values are the commonly used ImageNet channel
# statistics; presumably they match the training preprocessing — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        transforms.Resize(size=(256, 256)),
    ]
)
|
|
|
def get_prediction(path):
    """Classify a single image file and return its label.

    The image is preprocessed with the module-level ``transform``, run
    through the module-level ``model`` with gradients disabled, and the index
    of the highest output score is mapped to a label via ``vector_dict``.

    Args:
        path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``vector_dict[i]``, where ``i`` is the arg-max over dimension 1 of the
        model output.
    """
    # The context manager closes the underlying file handle once the pixels
    # have been read (Image.open is lazy and otherwise leaves it open).
    with Image.open(path) as img:
        # PNGs can be RGBA, palette ('P') or grayscale; the 3-value Normalize
        # in `transform` only accepts exactly three channels, so force RGB.
        tensor = transform(img.convert('RGB'))

    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    # NOTE(review): assumes ConvolutionalNet expects batched NCHW input and
    # returns (batch, num_classes) — confirm against model.py.
    batch = tensor.unsqueeze(0)

    with torch.no_grad():
        scores = model(batch)

    # argmax over dim 1 returns the index of the first maximal score, the same
    # as torch.max(scores, 1)[1].
    class_index = scores.argmax(dim=1).item()
    return vector_dict[class_index]
|
|
|
if __name__ == '__main__':
    # Script entry point: classify one image and print its label. It is
    # guarded so that importing get_prediction from this module does not run
    # a prediction on a hard-coded file.
    import sys

    # An optional first CLI argument overrides the default sample image.
    image_path = sys.argv[1] if len(sys.argv) > 1 else 'data/test/Afghanistan/39841.png'
    print(get_prediction(image_path))
|
|