"""Classify a single image with the trained Mini-Vision CIFAR-10 model.

Usage: python <this script> [IMAGE_PATH]

If IMAGE_PATH is omitted, the module-level ``image_path`` constant is used.
The script prints the CIFAR-10 label mapping, the preprocessed tensor shape,
the raw model output, the predicted index and the predicted class name.
"""
import sys

import torch
import torchvision
from PIL import Image

# NOTE(review): wildcard import kept deliberately. torch.load(...,
# weights_only=False) unpickles the whole model object, which requires its
# class to be resolvable by name; if the checkpoint was saved from a script
# that itself did `from model import *`, the class is recorded under
# `__main__` and must exist in this script's namespace. Presumably model.py
# defines the network class -- confirm.
from model import *

image_path = ""  # Your test image (or pass it as the first command-line argument)
checkpoint_path = "./Mini-Vision-V1.pth"


def main(argv=None):
    """Run single-image inference.

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. The first element, if present, is the image path.

    Exits with a usage message when no image path is given either on the
    command line or via ``image_path``.
    """
    argv = sys.argv[1:] if argv is None else argv
    path = argv[0] if argv else image_path
    if not path:
        # Image.open("") would fail with an unhelpful error; say what is missing.
        sys.exit("usage: python <script> IMAGE_PATH  (or set image_path in the script)")

    # The dataset is used only for its label mapping; download=False means the
    # CIFAR10 data must already exist under ./CIFAR10.
    test_data = torchvision.datasets.CIFAR10("CIFAR10", train=False, download=False)
    print(test_data.class_to_idx)

    # Open in a context manager so the file handle is released; convert()
    # returns a new in-memory image that remains valid after the block.
    with Image.open(path) as raw:
        print(raw)
        image = raw.convert("RGB")

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
    ])
    image = transform(image)  # RGB + Resize((32, 32)) -> tensor of shape (3, 32, 32)
    print(image.shape)

    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoints from a trusted source.
    # map_location="cpu" keeps a GPU-saved checkpoint loadable on CPU-only
    # hosts and on the same device as the CPU input tensor built above.
    model = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    image = image.unsqueeze(0)  # add batch dimension: (1, 3, 32, 32)
    model.eval()
    with torch.no_grad():
        output = model(image)
    print(output)

    prediction = output.argmax(1)
    print(prediction)
    # class_to_idx is built from `classes` in order, so the index maps back directly.
    print(test_data.classes[prediction.item()])


if __name__ == "__main__":
    main()