Spaces:
Sleeping
Sleeping
| """ | |
| Grad-CAM visualization for model interpretability. | |
| """ | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import Union | |
| import matplotlib.pyplot as plt | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from .dataset import get_transforms | |
| from .config import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES | |
def get_gradcam(model, target_layer=None):
    """Build a GradCAM explainer hooked onto one layer of ``model``.

    Args:
        model: The network to explain.
        target_layer: Module whose activations/gradients GradCAM should use.
            When omitted, ``model.backbone.features[-1]`` is used (intended as
            the last convolutional stage of an EfficientNet backbone).

    Returns:
        A ``GradCAM`` instance wrapping ``model`` with a single target layer.
    """
    layer = model.backbone.features[-1] if target_layer is None else target_layer
    return GradCAM(model=model, target_layers=[layer])
def denormalize_image(tensor: torch.Tensor, mean=None, std=None) -> np.ndarray:
    """Undo per-channel normalization and return a displayable image in [0, 1].

    Args:
        tensor: Normalized channels-first image tensor of shape (3, H, W); the
            3-channel layout is required by the ``view(3, 1, 1)`` below. It may
            live on any device and may require grad (it is detached first).
        mean: Per-channel means that were subtracted during normalization.
            Defaults to ``IMAGENET_MEAN``.
        std: Per-channel standard deviations that were divided out.
            Defaults to ``IMAGENET_STD``.

    Returns:
        NumPy array of shape (H, W, 3) with values clipped to [0, 1].
    """
    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD

    # detach() so .numpy() also works on tensors that are part of an autograd graph.
    img = tensor.detach().cpu()
    mean_t = torch.as_tensor(mean, device=img.device).view(3, 1, 1)
    std_t = torch.as_tensor(std, device=img.device).view(3, 1, 1)

    img = img * std_t + mean_t
    # Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib / OpenCV-style consumers.
    return np.clip(img.permute(1, 2, 0).numpy(), 0, 1)
def _load_rgb_image(image: Union[str, Path, Image.Image]) -> Image.Image:
    """Return ``image`` as an RGB PIL image, reading it from disk if given a path."""
    if isinstance(image, (str, Path)):
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(image) as opened:
            return opened.convert('RGB')
    if image.mode != 'RGB':
        # Grayscale/RGBA/palette inputs are converted so the tensor has 3 channels,
        # which the 3-channel denormalization in this module assumes.
        return image.convert('RGB')
    return image


def generate_gradcam(
    model,
    image: Union[str, Path, Image.Image],
    device: torch.device
) -> tuple:
    """Generate a Grad-CAM heatmap explaining the model's prediction for one image.

    Args:
        model: Binary classifier producing a single logit for the image
            (``.item()`` below requires exactly one output element). It is
            switched to eval mode as a side effect.
        image: Path to an image file, or an already loaded PIL image.
        device: Device the input tensor is moved to before inference.

    Returns:
        Tuple ``(cam_image, pred_class, confidence, rgb_img)``:
            cam_image: RGB heatmap overlay from ``show_cam_on_image``.
            pred_class: ``CLASS_NAMES[1]`` if the sigmoid probability exceeds
                0.5, otherwise ``CLASS_NAMES[0]``.
            confidence: Probability assigned to ``pred_class``.
            rgb_img: Denormalized input as an (H, W, 3) float array in [0, 1].
    """
    model.eval()

    pil_image = _load_rgb_image(image)
    transform = get_transforms(is_training=False)
    img_tensor = transform(pil_image).unsqueeze(0).to(device)

    # Prediction pass (no gradients needed).
    with torch.no_grad():
        prob = torch.sigmoid(model(img_tensor)).item()
    is_positive = prob > 0.5
    pred_class = CLASS_NAMES[1] if is_positive else CLASS_NAMES[0]
    confidence = prob if is_positive else 1 - prob

    # The model has a single logit, which is evidence for CLASS_NAMES[1]. With
    # targets=None, pytorch-grad-cam targets the argmax output index, which is
    # always 0 here, so the map would always explain the positive class. Negate
    # the logit for negative predictions so the heatmap explains the class
    # actually predicted (same as the library's BinaryClassifierOutputTarget).
    sign = 1.0 if is_positive else -1.0
    targets = [lambda model_output: model_output * sign]

    cam = get_gradcam(model)
    grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0]

    rgb_img = denormalize_image(img_tensor[0])
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    return cam_image, pred_class, confidence, rgb_img
def plot_gradcam(
    model,
    image_path: Union[str, Path],
    true_label: str,
    device: torch.device,
    save_path: Union[str, Path, None] = None
):
    """Plot an image next to its Grad-CAM overlay and return the prediction.

    Args:
        model: Classifier passed through to ``generate_gradcam``.
        image_path: Path of the image to explain.
        true_label: Ground-truth class name. It is shown in the title and compared
            with the prediction to color the Grad-CAM title green (match) or red.
        device: Device used for inference.
        save_path: If truthy, the figure is also written there at 150 dpi.

    Returns:
        Tuple ``(pred_class, confidence)`` as produced by ``generate_gradcam``.
    """
    cam_image, pred_class, confidence, original = generate_gradcam(model, image_path, device)

    fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(10, 4))

    ax_orig.imshow(original)
    ax_orig.set_title(f"Original\nTrue: {true_label}")
    ax_orig.axis('off')

    color = 'green' if pred_class == true_label else 'red'
    ax_cam.imshow(cam_image)
    ax_cam.set_title(f"Grad-CAM\nPred: {pred_class} ({confidence:.1%})", color=color)
    ax_cam.axis('off')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure. Otherwise every call leaves one figure registered with
    # pyplot, which accumulates in a long-running process.
    plt.close(fig)
    return pred_class, confidence
def plot_gradcam_grid(
    model,
    image_paths: list,
    true_labels: list,
    device: torch.device,
    save_path: Union[str, Path, None] = None,
    title: str = "Grad-CAM Visualizations"
):
    """Plot a grid with one row per image: the original and its Grad-CAM overlay.

    Args:
        model: Classifier passed through to ``generate_gradcam``.
        image_paths: Paths of the images to explain, one grid row each.
        true_labels: Ground-truth class names aligned with ``image_paths``.
        device: Device used for inference.
        save_path: If truthy, the figure is also written there at 150 dpi.
        title: Figure-level title.

    Raises:
        ValueError: If ``image_paths`` is empty or its length differs from
            ``true_labels``.
    """
    n = len(image_paths)
    if n == 0:
        raise ValueError("image_paths must contain at least one image")
    if len(true_labels) != n:
        raise ValueError(
            "image_paths and true_labels must have the same length "
            f"(got {n} and {len(true_labels)})"
        )

    # squeeze=False keeps `axes` 2-D with shape (n, 2), even when n == 1.
    fig, axes = plt.subplots(n, 2, figsize=(8, 3 * n), squeeze=False)

    for (ax_orig, ax_cam), path, true_label in zip(axes, image_paths, true_labels):
        cam_image, pred_class, confidence, original = generate_gradcam(model, path, device)

        ax_orig.imshow(original)
        ax_orig.set_title(f"True: {true_label}")
        ax_orig.axis('off')

        # Green title when the prediction matches the ground truth, red otherwise.
        color = 'green' if pred_class == true_label else 'red'
        ax_cam.imshow(cam_image)
        ax_cam.set_title(f"Pred: {pred_class} ({confidence:.1%})", color=color)
        ax_cam.axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)