# pokemon_classification/utils/inference_utils.py
# MostHumble: add inference script (eed12b2)
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import os
import random
from utils.data import CLASS_NAMES
# Function to find correctly and incorrectly classified images
def find_images(dataloader, model, device, num_correct, num_incorrect):
    """Collect samples the model classifies correctly and incorrectly.

    Iterates over ``dataloader`` with the model in eval mode and gradients
    disabled, stopping as soon as both quotas are filled.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        model: Callable producing per-class scores for a batch of images;
            the predicted class is the arg-max along dimension 1.
        device: Device the batches are moved to before the forward pass.
        num_correct: Maximum number of correctly classified samples to keep.
        num_incorrect: Maximum number of misclassified samples to keep.

    Returns:
        A 6-tuple ``(correct_images, correct_labels, correct_preds,
        incorrect_images, incorrect_labels, incorrect_preds)`` of lists whose
        entries have been moved to the CPU. A list may be shorter than its
        quota if the dataloader runs out first.
    """
    hits = {"images": [], "labels": [], "preds": []}
    misses = {"images": [], "labels": [], "preds": []}

    def quotas_met():
        return (
            len(hits["images"]) >= num_correct
            and len(misses["images"]) >= num_incorrect
        )

    def record(bucket, image, label, pred):
        bucket["images"].append(image.cpu())
        bucket["labels"].append(label.cpu())
        bucket["preds"].append(pred.cpu())

    model.eval()
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            batch_preds = torch.max(model(batch_images), 1)[1]

            for idx in range(batch_images.size(0)):
                sample = (batch_images[idx], batch_labels[idx], batch_preds[idx])
                if batch_preds[idx] == batch_labels[idx]:
                    if len(hits["images"]) < num_correct:
                        record(hits, *sample)
                elif len(misses["images"]) < num_incorrect:
                    record(misses, *sample)

                if quotas_met():
                    break

            # Leave the outer loop too once nothing more is needed.
            if quotas_met():
                break

    return (
        hits["images"],
        hits["labels"],
        hits["preds"],
        misses["images"],
        misses["labels"],
        misses["preds"],
    )
def find_images_from_path(data_path, model, device, num_correct=2, num_incorrect=2, label=None):
    """Randomly sample images from disk and save correct / mistaken predictions.

    ``data_path`` is expected to contain one sub-directory per class, named
    after an entry of ``CLASS_NAMES``. Images are drawn at random (first a
    class, then an image of that class) and classified one by one until both
    quotas are filled or no images remain. The selected image paths are then
    handed to ``save_images_by_class``.

    Args:
        data_path: Root directory holding the per-class image folders.
        model: Classifier called on a single preprocessed image batch.
        device: Device the preprocessed image is moved to.
        num_correct: Number of correctly classified images to collect.
        num_incorrect: Number of misclassified images to collect.
        label: If given, only images from this class folder are sampled;
            otherwise every known class folder under ``data_path`` is used.

    Raises:
        KeyError: If ``label`` is given, its folder contains images, and it is
            not an entry of ``CLASS_NAMES``.
    """
    correct_images_paths = []
    incorrect_images_paths = []
    correct_labels = []
    incorrect_labels = []
    label_to_idx = {name: idx for idx, name in enumerate(CLASS_NAMES)}
    model.eval()

    # Decide which class folders to draw from. When scanning everything,
    # ignore folders that are not a known class (e.g. ``.ipynb_checkpoints``);
    # they would otherwise fail the label lookup further down.
    if label:
        candidate_labels = [label]
    else:
        candidate_labels = [
            name for name in os.listdir(data_path) if name in label_to_idx
        ]

    # Map each class name to the image files still available for sampling.
    label_images = {}
    for name in candidate_labels:
        label_path = os.path.join(data_path, name)
        if not os.path.isdir(label_path):
            continue
        entries = (os.path.join(label_path, img) for img in os.listdir(label_path))
        # Keep regular files only; a nested directory cannot be opened as an image.
        label_images[name] = [path for path in entries if os.path.isfile(path)]

    # Randomly process images until we have enough samples
    with torch.no_grad():
        while len(correct_images_paths) < num_correct or len(incorrect_images_paths) < num_incorrect:
            # Only classes that still have unprocessed images are eligible.
            available_labels = [name for name, paths in label_images.items() if paths]
            if not available_labels:
                break

            selected_label = random.choice(available_labels)
            remaining = label_images[selected_label]
            # Pop by random index: picks and removes the image in one step.
            image_path = remaining.pop(random.randrange(len(remaining)))

            image = preprocess_image(image_path, (224, 224)).to(device)
            label_idx = label_to_idx[selected_label]
            outputs = model(image)
            _, pred = torch.max(outputs, 1)

            if pred == label_idx:
                if len(correct_images_paths) < num_correct:
                    correct_images_paths.append(image_path)
                    correct_labels.append(label_idx)
            elif len(incorrect_images_paths) < num_incorrect:
                incorrect_images_paths.append(image_path)
                incorrect_labels.append(label_idx)

    save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels)
def save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels):
    """Write classified images under ``predictions/<class>/<correct|mistake>/``.

    A ``correct`` and a ``mistake`` folder is created for every entry of
    ``CLASS_NAMES`` (even if it ends up empty). Each image is then opened with
    PIL and re-saved under the folder of its true class, keeping its original
    file name.

    Args:
        correct_images_paths: Paths of correctly classified images.
        correct_labels: True class indices (into ``CLASS_NAMES``), aligned
            with ``correct_images_paths``.
        incorrect_images_paths: Paths of misclassified images.
        incorrect_labels: True class indices, aligned with
            ``incorrect_images_paths``.
    """
    outcomes = (
        ("correct", correct_images_paths, correct_labels),
        ("mistake", incorrect_images_paths, incorrect_labels),
    )

    # Create the full output tree up front so every class has both folders.
    for class_name in CLASS_NAMES:
        for outcome, _, _ in outcomes:
            os.makedirs(os.path.join('predictions', class_name, outcome), exist_ok=True)

    for outcome, image_paths, labels in outcomes:
        for img_path, label in zip(image_paths, labels):
            destination = os.path.join(
                'predictions', CLASS_NAMES[label], outcome, os.path.basename(img_path)
            )
            # Context manager closes the source file handle once the copy is
            # written instead of leaving it to the garbage collector.
            with Image.open(img_path) as img:
                img.save(destination)
def show_samples(dataloader, model, device, num_correct=3, num_incorrect=3):
    """Plot correctly and incorrectly classified samples in a single column.

    Samples are gathered with ``find_images``; correctly classified ones are
    drawn first, followed by the misclassified ones. Only the samples that
    were actually found are plotted, so a dataloader that cannot fill a quota
    yields a shorter figure rather than an error.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        model: Classifier passed through to ``find_images``.
        device: Device passed through to ``find_images``.
        num_correct: Maximum number of correctly classified samples to show.
        num_incorrect: Maximum number of misclassified samples to show.
    """
    (
        correct_images,
        correct_labels,
        correct_preds,
        incorrect_images,
        incorrect_labels,
        incorrect_preds,
    ) = find_images(dataloader, model, device, num_correct, num_incorrect)

    # One entry per subplot: (title prefix, image, true label, prediction).
    panels = [
        ("Correctly Classified", image, label, pred)
        for image, label, pred in zip(correct_images, correct_labels, correct_preds)
    ]
    panels.extend(
        ("Incorrectly Classified", image, label, pred)
        for image, label, pred in zip(incorrect_images, incorrect_labels, incorrect_preds)
    )

    # plt.subplots rejects a zero-row grid, so there is nothing to draw.
    if not panels:
        return

    # squeeze=False keeps ``axes`` two-dimensional even for a single panel;
    # otherwise a lone subplot comes back as a bare Axes that cannot be indexed.
    _, axes = plt.subplots(
        len(panels), 1, figsize=(10, len(panels) * 5), squeeze=False
    )
    for ax, (heading, image, label, pred) in zip(axes[:, 0], panels):
        # Move the first dimension last for imshow
        # (assumes channel-first image tensors — TODO confirm).
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(f"{heading}: True Label = {label}, Predicted = {pred}")
        ax.axis("off")

    plt.tight_layout()
    plt.show()
# Function to preprocess image
def preprocess_image(image_path, img_shape):
    """Load an image file and turn it into a normalized single-image batch.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        img_shape: Target size passed to ``transforms.Resize``.

    Returns:
        A float tensor with a leading batch dimension of 1 added by
        ``unsqueeze(0)``.
    """
    # Force three channels: grayscale, palette or RGBA files would otherwise
    # not match the 3-channel mean/std used by Normalize below. The context
    # manager closes the file once the converted copy has been made.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Mean/std are the values commonly used with ImageNet-pretrained
    # torchvision models (assumes the model was trained with them — confirm).
    preprocess = transforms.Compose([
        transforms.Resize(img_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(image).unsqueeze(0)
# Function to predict
def predict(model, image):
    """Run a forward pass in eval mode without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        image: Input passed unchanged to ``model``.

    Returns:
        Whatever ``model(image)`` returns.
    """
    model.eval()
    with torch.no_grad():
        return model(image)
# Function to get model predictions for LIME
def batch_predict(model, images, device):
    """Return class probabilities for a collection of images.

    Each item is run through ``preprocess_image`` (resized to 224x224), the
    results are joined into one batch, and the model's outputs are converted
    to probabilities with a softmax over dimension 1.

    Args:
        model: Classifier to evaluate; switched to eval mode.
        images: Iterable of inputs accepted by ``preprocess_image``.
            NOTE(review): LIME usually passes image arrays, whereas
            ``preprocess_image`` opens its argument with ``Image.open`` —
            confirm the caller supplies something it can open.
        device: Device the batch is moved to before the forward pass.

    Returns:
        A NumPy array of probabilities, one row per input image.

    Raises:
        RuntimeError: If ``images`` is empty (nothing to concatenate).
    """
    model.eval()
    # preprocess_image already returns a (1, C, H, W) tensor, so concatenate
    # along the batch axis; stacking would add a fifth dimension.
    batch = torch.cat(
        [preprocess_image(image, (224, 224)) for image in images], dim=0
    )
    batch = batch.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()