# Mini-Vision-V1 / test.py
# Uploaded by LWWZH: "Upload Mini-Vision-V1 of the Mini-Vision-Series"
# Commit c3f9ef7 (verified), 743 bytes.
import torch
import torchvision
from PIL import Image
from model import *
# Single-image inference check for the Mini-Vision-V1 checkpoint.
#
# Steps: print the CIFAR-10 class-name -> index mapping, load one image from
# `image_path`, resize it to 32x32 RGB, run it through the saved model and
# print the raw outputs plus the index of the highest-scoring class.

# The dataset is loaded only to obtain `class_to_idx`, so the index printed at
# the end can be mapped back to a label by eye. download=False means the
# dataset must already be present under ./CIFAR10.
test_data = torchvision.datasets.CIFAR10("CIFAR10", train=False, download=False)
print(test_data.class_to_idx)

image_path = ""  # Your test image
image = Image.open(image_path)
print(image)
# Force three channels so grayscale / RGBA / palette images fit the
# (1, 3, 32, 32) input shape used below.
image = image.convert("RGB")

# NOTE(review): no normalization is applied here; presumably this matches the
# preprocessing used at training time -- confirm against the training script.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)

# SECURITY: weights_only=False unpickles an arbitrary Python object, which can
# execute code embedded in the file. Only load checkpoints from a trusted
# source.
# map_location="cpu" keeps the model on the same device as the input tensor
# and lets a checkpoint saved from a GPU load on a CPU-only machine.
model = torch.load(
    "./Mini-Vision-V1.pth",
    weights_only=False,
    map_location="cpu",
)

# Add the leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image = torch.reshape(image, (1, 3, 32, 32))

# Switch to evaluation mode and skip gradient tracking for inference.
model.eval()
with torch.no_grad():
    output = model(image)

print(output)
# Index of the largest output along dim 1; look it up in the mapping printed
# at the top to get the class name.
print(output.argmax(1))