# Rbcloud's picture
# Upload 4 files
# a189a79
# raw
# history blame
# 730 Bytes
import torch
from torchvision import transforms
from PIL import Image, ImageFile
import pandas as pd
import os
import math
from model import ConvolutionalNet
from collections import Counter
from vector_dict import vector_dict
# Let PIL decode image files that are cut short instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Build the network and load the trained weights from 'model.pt'.
# map_location='cpu' lets a checkpoint that was saved from a GPU run be
# loaded on a machine without CUDA.
model = ConvolutionalNet()
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
# Put the network in inference mode. Without this, layers such as dropout or
# batch-norm (if ConvolutionalNet uses any) keep their training behaviour and
# make predictions noisy / batch-dependent.
model.eval()

# Per-image preprocessing used by get_prediction:
#   ToTensor  -> float tensor of shape (C, H, W) with values in [0, 1]
#   Normalize -> per-channel (x - mean) / std using the common ImageNet stats
#   Resize    -> spatial size 256 x 256
# NOTE(review): the order (normalize before resize) is kept as-is on the
# assumption that it matches the training pipeline — confirm before changing.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.Resize((256, 256))
])
def get_prediction(path):
    """Classify a single image file and return its label.

    Args:
        path: Location of the image to classify; passed straight to
            ``PIL.Image.open``, so anything it accepts works.

    Returns:
        The ``vector_dict`` entry for the highest-scoring class index
        (presumably the class label, e.g. a country name — confirm
        against ``vector_dict``).
    """
    # Decode eagerly inside a context manager so the file is closed, and
    # force 3 channels: PNGs are often palette, grayscale or RGBA, which
    # would not match the 3-element mean/std used by ``transform``.
    with Image.open(path) as img:
        rgb = img.convert('RGB')

    # The network is called on a batch, so add a leading batch axis:
    # (C, H, W) -> (1, C, H, W).
    batch = transform(rgb).unsqueeze(0)

    with torch.no_grad():
        scores = model(batch)

    # Class index with the largest score in the single output row.
    class_idx = scores.argmax(dim=1).item()
    return vector_dict[class_idx]
if __name__ == "__main__":
    # Smoke test: classify one sample image from the test split. Guarded so
    # that importing this module (e.g. to reuse get_prediction) does not
    # trigger inference and printing as a side effect.
    print(get_prediction('data/test/Afghanistan/39841.png'))