# saudi-date-classifier / src/explainability.py
# Author: Rashidbm — initial deployment (commit 6276d4c)
"""
Explainability: Grad-CAM heatmaps and t-SNE feature visualization.
Usage:
python -m src.explainability --gradcam
python -m src.explainability --tsne
python -m src.explainability --all
"""
import argparse
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from sklearn.manifold import TSNE
from tqdm import tqdm
def reshape_transform(tensor, height=14, width=14):
    """Reshape ViT token output into a spatial feature map for Grad-CAM.

    A ViT emits (B, 1 + N, C): a CLS token followed by N patch tokens.
    Grad-CAM expects a CNN-style map (B, C, H, W), so the CLS token is
    dropped and the patch tokens are folded back into their grid.

    Args:
        tensor: ViT hidden states of shape (B, 1 + N, C).
        height: Patch-grid height (14 for a 224px ViT-B/16).
        width: Patch-grid width.

    Returns:
        Tensor of shape (B, C, height, width). If N does not equal
        height * width, a square grid is inferred from N instead, so
        the function also works for other input resolutions.

    Raises:
        ValueError: if N matches neither the requested grid nor a
            perfect square.
    """
    # Drop the CLS token (position 0); keep only patch tokens.
    patches = tensor[:, 1:, :]
    num_tokens = patches.size(1)
    if num_tokens != height * width:
        # Fall back to a square grid inferred from the token count.
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(
                f"Cannot reshape {num_tokens} patch tokens into a "
                f"{height}x{width} or square spatial grid"
            )
        height = width = side
    # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
    grid = patches.reshape(patches.size(0), height, width, patches.size(2))
    return grid.permute(0, 3, 1, 2)
from src.dataset import DateFruitDataset, get_val_transforms
from src.models.vit_pretrained import PretrainedViTClassifier
from src.utils import load_config, get_device, seed_everything
# ImageNet channel statistics applied by the validation transforms; kept
# here so denormalize() can invert the normalization for display.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Saudi date varieties; index order is assumed to match the dataset's
# integer label encoding (used as `labels == i` in the t-SNE plot) —
# TODO(review): confirm against DateFruitDataset's label mapping.
CLASS_NAMES = [
    "Ajwa", "Galaxy", "Medjool", "Meneifi", "Nabtat Ali",
    "Rutab", "Shaishe", "Sokari", "Sugaey",
]
def load_vit_model(checkpoint_path: str, device: torch.device) -> PretrainedViTClassifier:
    """Restore a trained ViT classifier from a checkpoint file.

    The checkpoint is expected to be a dict with a "model_state_dict"
    entry (as written by the training script). The returned model is
    moved to `device` and switched to eval mode.
    """
    classifier = PretrainedViTClassifier(num_classes=9)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    classifier.load_state_dict(checkpoint["model_state_dict"])
    # .to() and .eval() both return the module, so chaining is safe.
    return classifier.to(device).eval()
def denormalize(tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet normalization on a (3, H, W) tensor.

    Returns an (H, W, 3) float numpy array clamped to [0, 1],
    suitable for matplotlib's imshow.
    """
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    restored = tensor.cpu().mul(std).add(mean).clamp(0, 1)
    return restored.permute(1, 2, 0).numpy()
def generate_gradcam(
    model: PretrainedViTClassifier,
    dataset: DateFruitDataset,
    device: torch.device,
    samples_per_class: int = 2,
    save_path: str = "results/gradcam_grid.png",
) -> None:
    """Generate a Grad-CAM grid (original | heatmap overlay) per variety.

    Args:
        model: Trained ViT classifier (eval mode).
        dataset: Dataset yielding (image_tensor, label, variety_name).
        device: Device for the forward/backward passes.
        samples_per_class: Images shown per variety (two columns each).
        save_path: Output PNG path; parent directories are created.
    """
    # ViT target layer: the layernorm before the last attention block;
    # reshape_transform folds patch tokens back into a spatial grid.
    target_layer = model.backbone.vit.encoder.layer[-1].layernorm_before
    cam = GradCAM(model=model, target_layers=[target_layer], reshape_transform=reshape_transform)

    num_classes = len(CLASS_NAMES)
    fig, axes = plt.subplots(
        num_classes, samples_per_class * 2,
        figsize=(4 * samples_per_class * 2, 3 * num_classes),
    )
    # Fix: hide every axis up front so varieties with fewer than
    # `samples_per_class` examples don't leave empty frames/ticks behind.
    for ax in axes.flat:
        ax.axis("off")

    # Collect up to `samples_per_class` dataset indices per variety.
    # Fix: stop scanning once every slot is filled instead of always
    # walking the whole dataset; skip unknown varieties instead of
    # raising KeyError.
    class_indices = {v: [] for v in CLASS_NAMES}
    remaining = num_classes * samples_per_class
    for idx in range(len(dataset)):
        _, _, variety = dataset[idx]
        bucket = class_indices.get(variety)
        if bucket is not None and len(bucket) < samples_per_class:
            bucket.append(idx)
            remaining -= 1
            if remaining == 0:
                break

    for row, variety in enumerate(CLASS_NAMES):
        for col, idx in enumerate(class_indices[variety]):
            image_tensor, _, _ = dataset[idx]
            input_tensor = image_tensor.unsqueeze(0).to(device)
            # targets=None -> explain the model's own top prediction.
            grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
            rgb_img = denormalize(image_tensor)
            # Original image
            ax_orig = axes[row, col * 2]
            ax_orig.imshow(rgb_img)
            ax_orig.set_title(variety, fontsize=10, fontweight="bold")
            # Grad-CAM overlay
            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
            ax_cam = axes[row, col * 2 + 1]
            ax_cam.imshow(cam_image)
            ax_cam.set_title("Grad-CAM", fontsize=10)

    plt.suptitle("Grad-CAM: What the ViT Model Focuses On", fontsize=16, fontweight="bold", y=1.01)
    plt.tight_layout()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {save_path}")
def generate_tsne(
    model: PretrainedViTClassifier,
    dataset: DateFruitDataset,
    device: torch.device,
    save_path: str = "results/tsne.png",
) -> None:
    """Project ViT CLS-token embeddings to 2D with t-SNE and plot them.

    One point per dataset image, colored by variety; the figure is
    written to `save_path`.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, labels, _ in tqdm(loader, desc="Extracting features"):
            # The CLS token (position 0) summarizes the whole image.
            hidden = model.backbone.vit(pixel_values=images.to(device)).last_hidden_state
            feature_batches.append(hidden[:, 0, :].cpu().numpy())
            label_batches.append(labels.numpy())

    features = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)
    print(f"Running t-SNE on {features.shape[0]} samples, {features.shape[1]} dimensions...")

    # Perplexity must be below the sample count; cap it for small sets.
    reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, len(features) - 1))
    embeddings = reducer.fit_transform(features)

    _, ax = plt.subplots(figsize=(12, 10))
    palette = plt.cm.tab10(np.linspace(0, 1, len(CLASS_NAMES)))
    for class_id, variety in enumerate(CLASS_NAMES):
        points = embeddings[labels == class_id]
        ax.scatter(
            points[:, 0],
            points[:, 1],
            c=[palette[class_id]],
            label=variety,
            alpha=0.7,
            s=50,
            edgecolors="white",
            linewidth=0.5,
        )
    ax.legend(fontsize=10, loc="best")
    ax.set_title("t-SNE: How ViT Clusters Saudi Date Varieties", fontsize=14, fontweight="bold")
    ax.set_xlabel("t-SNE Dimension 1")
    ax.set_ylabel("t-SNE Dimension 2")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {save_path}")
def main():
    """CLI entry point: parse flags, load the model, run visualizations."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--gradcam", action="store_true")
    parser.add_argument("--tsne", action="store_true")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/vit_best_model.pth")
    args = parser.parse_args()

    run_gradcam = args.gradcam or args.all
    run_tsne = args.tsne or args.all
    if not (run_gradcam or run_tsne):
        print("Specify --gradcam, --tsne, or --all")
        return

    config = load_config()
    seed_everything(42)
    device = get_device()
    print(f"Device: {device}")

    print(f"Loading model from {args.checkpoint}...")
    model = load_vit_model(args.checkpoint, device)

    dataset = DateFruitDataset("data/test.csv", transform=get_val_transforms(config))
    print(f"Test set: {len(dataset)} images")

    if run_gradcam:
        print("\nGenerating Grad-CAM...")
        generate_gradcam(model, dataset, device)
    if run_tsne:
        print("\nGenerating t-SNE...")
        generate_tsne(model, dataset, device)
# Script entry point, e.g. `python -m src.explainability --all`.
if __name__ == "__main__":
    main()