|
import os |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
from PIL import Image |
|
from models import * |
|
from torchmetrics import ConfusionMatrix |
|
import matplotlib.pyplot as plt |
|
from configs import * |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the inference model: move the configured MODEL onto DEVICE, restore
# its trained weights from the checkpoint named after the model class, and
# switch it to evaluation mode.
_checkpoint_path = f"output/checkpoints/{type(MODEL).__name__}.pth"
model = MODEL.to(DEVICE)
model.load_state_dict(torch.load(_checkpoint_path, map_location=DEVICE))
model.eval()

# This script only runs inference, so gradient tracking is switched off.
torch.set_grad_enabled(False)
|
|
|
|
|
def predict_image(image_path, model=model, transform=preprocess):
    """Classify a single image, print per-class probabilities, and rank them.

    The image is opened, converted to RGB, passed through ``transform``,
    given a leading batch dimension, moved to ``DEVICE`` and fed to
    ``model``. The softmax of the first row of the output is scaled to
    percentages and ranked from most to least likely.

    Args:
        image_path: Path of the image file to classify.
        model: Classifier to run; defaults to the module-level loaded model.
        transform: Callable turning a PIL image into a tensor; defaults to
            ``preprocess``.

    Returns:
        A list of ``(class_name, probability)`` tuples sorted by probability
        in descending order, where ``probability`` is a 0-d tensor holding a
        percentage in [0, 100].

    Raises:
        ValueError: If the number of model outputs differs from the number of
            entries in ``CLASSES``.
    """
    classes = CLASSES

    print("---------------------------")
    print("Image path:", image_path)

    # Context manager releases the file handle deterministically;
    # convert() returns an independent image, so it stays valid afterwards.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    batch = transform(image).unsqueeze(0).to(DEVICE)

    # Don't depend solely on the module-level set_grad_enabled(False):
    # a caller may have re-enabled autograd before calling this.
    with torch.no_grad():
        output = model(batch)

    probabilities = torch.softmax(output, dim=1)[0] * 100

    # zip() would silently drop the surplus on a mismatch; fail loudly.
    if len(probabilities) != len(classes):
        raise ValueError(
            f"Model produced {len(probabilities)} outputs but "
            f"{len(classes)} classes are configured"
        )

    # Rank by index so the winning index is known directly, without a
    # name lookup that would misfire on duplicate class names.
    ranked = sorted(
        enumerate(probabilities), key=lambda item: item[1], reverse=True
    )
    sorted_classes = [(classes[idx], prob) for idx, prob in ranked]

    print("Probabilities for each class:")
    for class_label, class_prob in sorted_classes:
        print(f"{class_label}: {round(class_prob.item(), 2)}%")

    predicted_label = ranked[0][0]  # integer index into CLASSES
    predicted_class = sorted_classes[0][0]  # class name

    # The class name goes under "class" and its index under "label".
    print("Predicted class:", predicted_class)
    print("Predicted label:", predicted_label)

    return sorted_classes
|
|