|
from lime import lime_image |
|
from skimage.segmentation import mark_boundaries |
|
import matplotlib.pyplot as plt |
|
from utils.inference_utils import predict |
|
import os |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
# Per-channel statistics used to undo input normalisation (these values match
# the commonly used ImageNet RGB mean / standard deviation).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def unnormalize(image):
    """Undo per-channel mean/std normalisation of a channels-last image.

    Computes ``image * std + mean`` where ``mean`` and ``std`` are
    three-element vectors, so the last dimension of ``image`` must be
    broadcastable against 3 (e.g. ``(H, W, 3)``).

    Args:
        image: A ``torch.Tensor`` (on any device) or anything accepted by
            ``torch.as_tensor`` such as a NumPy array. Non-tensor input is
            converted to a ``float32`` tensor first.

    Returns:
        A ``torch.Tensor`` with the un-normalised values. For tensor input the
        result lives on the same device as ``image``.
    """
    if not isinstance(image, torch.Tensor):
        image = torch.as_tensor(image, dtype=torch.float32)

    # Build the statistics on the image's own device; mixing a CPU constant
    # with e.g. a CUDA tensor would raise a device-mismatch error.
    mean = torch.tensor(_NORM_MEAN, dtype=torch.float32).to(image.device)
    std = torch.tensor(_NORM_STD, dtype=torch.float32).to(image.device)

    return image * std + mean
|
|
|
|
|
|
|
def lime_interpret_image_inference(args, model, image, device):
    """Explain a model prediction for one image with LIME and save a 2x2 figure.

    The figure contains the original image, the top positive super-pixels,
    the positive and negative super-pixels, and a per-segment weight heatmap
    for the top predicted label.

    Args:
        args: Namespace-like object. ``args.image_path`` names the source image
            and determines the output file name. When ``args.classify`` is
            truthy the path must look like
            ``.../<class_name>/<correct|mistake>/<file>`` and the figure is
            written to ``explanations/<class_name>/<correct|mistake>/<file>``;
            otherwise it is written to ``./explanations/<file>``.
        model: Callable taking a ``(N, C, H, W)`` tensor on ``device`` and
            returning a tensor of per-class scores. It is called as-is; this
            function does not switch it to eval mode.
        image: Normalised image tensor; after ``squeeze(0)`` it must be
            three-dimensional and is treated as ``(C, H, W)``.
        device: Device the perturbed batches are moved to before inference.

    Raises:
        ValueError: If ``args.classify`` is truthy and ``args.image_path`` does
            not have a ``correct`` or ``mistake`` parent directory.
    """

    def prepare_for_plot(img):
        # Undo normalisation, then clip to [0, 1]: imshow clips float RGB data
        # to that range anyway, this just avoids the warning.
        return np.clip(unnormalize(img).cpu().numpy(), 0.0, 1.0)

    # Channels-last layout is what LIME and matplotlib expect.
    # detach() so the NumPy conversion also works for tensors requiring grad.
    image = image.detach().squeeze(0).permute(1, 2, 0)
    image_np = image.cpu().numpy()

    explainer = lime_image.LimeImageExplainer()

    def predict_fn(x):
        # LIME hands over a channels-last NumPy batch; give the model the same
        # dtype as the original input, in channels-first layout.
        x_tensor = torch.as_tensor(x, dtype=image.dtype).permute(0, 3, 1, 2).to(device)
        # Only forward passes are needed, so skip building autograd graphs.
        with torch.no_grad():
            preds = model(x_tensor)
        return preds.detach().cpu().numpy()

    explanation = explainer.explain_instance(
        image_np,
        predict_fn,
        top_labels=5,
        hide_color=0,
        num_samples=5000,
    )
    top_label = explanation.top_labels[0]

    fig, axs = plt.subplots(2, 2, figsize=(15, 15))

    # Use the CPU copy so un-normalising never mixes devices.
    axs[0, 0].imshow(prepare_for_plot(image_np))
    axs[0, 0].set_title("Original Image")

    temp, mask = explanation.get_image_and_mask(
        top_label,
        positive_only=True,
        num_features=5,
        hide_rest=True,
    )
    # Un-normalise first, then draw boundaries, so the boundary colour itself
    # is not shifted by the un-normalisation.
    axs[0, 1].imshow(mark_boundaries(prepare_for_plot(temp), mask))
    axs[0, 1].set_title("Top Positive Features")

    temp, mask = explanation.get_image_and_mask(
        top_label,
        positive_only=False,
        num_features=1000,
        hide_rest=False,
        min_weight=0.1,
    )
    axs[1, 0].imshow(mark_boundaries(prepare_for_plot(temp), mask))
    axs[1, 0].set_title("Top Positive and Negative Features")

    # Map every segment id to its LIME weight for the top label.
    dict_heatmap = dict(explanation.local_exp[top_label])
    heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
    # Symmetric colour range around zero; using the absolute maximum keeps
    # vmin <= vmax even when every weight is negative.
    limit = np.abs(heatmap).max()
    im = axs[1, 1].imshow(heatmap, cmap='RdBu', vmin=-limit, vmax=limit)
    axs[1, 1].set_title("Feature Heatmap")
    fig.colorbar(im, ax=axs[1, 1])

    plt.tight_layout()

    file_name = os.path.basename(args.image_path)

    if args.classify:
        # Expected layout: .../<class_name>/<correct|mistake>/<file>
        path_parts = os.path.normpath(args.image_path).split(os.sep)
        if len(path_parts) < 3 or path_parts[-2] not in ('correct', 'mistake'):
            raise ValueError("The image path should contain 'correct' or 'mistake'")
        class_name = path_parts[-3]
        correctness = path_parts[-2]

        save_path = os.path.join('explanations', class_name, correctness, file_name)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        plt.savefig(save_path, dpi=300)
        print(f"Explanation saved at {save_path}")
    else:
        os.makedirs("./explanations", exist_ok=True)
        plt.savefig(f"./explanations/{file_name}")
        print(f"Explanation saved at ./explanations/{file_name}")
    # NOTE(review): the figure is left open, as before, in case callers call
    # plt.show() afterwards; callers invoking this in a loop may want to
    # plt.close() it to release memory.
|
|