# pokemon_classification/utils/interpretability.py
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from lime import lime_image
from skimage.segmentation import mark_boundaries

def unnormalize(image):
    """Invert the ImageNet normalization so the image can be displayed."""
    # Channel-wise statistics used by the normalization transform (ImageNet).
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
    # Accept either a torch tensor or a NumPy array in (H, W, C) layout.
    if not isinstance(image, torch.Tensor):
        image = torch.tensor(image, dtype=torch.float32)
    # Move the statistics to the image's device so the broadcast also works on GPU.
    mean = mean.to(image.device)
    std = std.to(image.device)
    # Undo (x - mean) / std and clamp to the valid display range.
    return (image * std + mean).clamp(0.0, 1.0)
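
# Example (sketch): `unnormalize` inverts the standard torchvision transform
#   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# and expects channel-last input, e.g.:
#   x = torch.zeros(224, 224, 3)   # zeros in normalized space = channel means
#   unnormalize(x)[0, 0]           # -> tensor([0.4850, 0.4560, 0.4060])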

def lime_interpret_image_inference(args, model, image, device):
    """Explain the model's prediction for one image with LIME and save the plots."""
    # Helper: unnormalize a tensor/array and move it to NumPy for matplotlib.
    def prepare_for_plot(image):
        return unnormalize(image).cpu().numpy()

    # Remove the batch dimension and rearrange to (H, W, C):
    # from [1, 3, 224, 224] to [224, 224, 3].
    image = image.squeeze(0).permute(1, 2, 0)
    # Convert to a NumPy array; LIME expects channel-last NumPy input on the CPU.
    image_np = image.cpu().numpy()
    # Initialize LIME explainer
    explainer = lime_image.LimeImageExplainer()
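
    # LIME segments the image into superpixels, perturbs them on/off, and fits
    # a local linear surrogate to the model's outputs on those perturbations;
    # the per-superpixel weights plotted below come from that surrogate.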
    # Define the prediction function handed to LIME.
    def predict_fn(x):
        # LIME passes a float64 NumPy batch in (B, H, W, C); convert it to a
        # float32 torch batch in (B, C, H, W) on the model's device.
        x_tensor = torch.tensor(x, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
        with torch.no_grad():
            preds = model(x_tensor)
        # Assuming the model returns raw logits, convert them to the class
        # probabilities LIME expects.
        return torch.softmax(preds, dim=1).cpu().numpy()
    # Run LIME explanation
    explanation = explainer.explain_instance(
        image_np,
        predict_fn,
        top_labels=5,
        hide_color=0,
        num_samples=5000
    )
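
    # Note: num_samples is the number of perturbed images LIME evaluates;
    # 5000 gives a fairly stable explanation but is slow on CPU, and lowering
    # it (e.g. to 1000) trades explanation fidelity for speed.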
    # Create a 2x2 subplot
    fig, axs = plt.subplots(2, 2, figsize=(15, 15))
    # Plot the original image
    axs[0, 0].imshow(prepare_for_plot(image))
    axs[0, 0].set_title("Original Image")
    # Plot the superpixels that contributed most positively to the top label.
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True
    )
    axs[0, 1].imshow(mark_boundaries(prepare_for_plot(temp), mask))
    axs[0, 1].set_title("Top Positive Features")
    # Plot the features that contributed most positively and negatively.
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=False, num_features=1000,
        hide_rest=False, min_weight=0.1
    )
    axs[1, 0].imshow(mark_boundaries(prepare_for_plot(temp), mask))
    axs[1, 0].set_title("Top Positive and Negative Features")
    # Plot a heatmap of the features: local_exp[ind] maps each superpixel id in
    # explanation.segments to its LIME weight, so the vectorized lookup paints
    # every pixel with its superpixel's weight.
    ind = explanation.top_labels[0]
    dict_heatmap = dict(explanation.local_exp[ind])
    heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
    im = axs[1, 1].imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max())
    axs[1, 1].set_title("Feature Heatmap")
    fig.colorbar(im, ax=axs[1, 1])
    plt.tight_layout()
    # If classification mode is enabled, save in the appropriate directory.
    if args.classify:
        # Extract the class name and correctness from the image path, which is
        # expected to look like .../<class_name>/<correct|mistake>/<image file>.
        path_parts = args.image_path.split(os.sep)
        class_name = path_parts[-3]
        correctness = path_parts[-2]
        assert correctness in ['correct', 'mistake'], "The image path should contain 'correct' or 'mistake'"
        # Create the full save path under the explanations directory.
        save_path = os.path.join('explanations', class_name, correctness, os.path.basename(args.image_path))
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        # Save the explanation figure.
        plt.savefig(save_path, dpi=300)
        print(f"Explanation saved at {save_path}")
    else:
        # Otherwise save under ./explanations/ with the same file name as the image.
        os.makedirs("./explanations", exist_ok=True)
        save_path = f"./explanations/{os.path.basename(args.image_path)}"
        plt.savefig(save_path)
        print(f"Explanation saved at {save_path}")