| | """ |
| | Grad-CAM Tool - Visual explanation of ConvNeXt predictions |
| | Shows which regions of the image the model focuses on. |
| | """ |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from PIL import Image |
| | from torchvision import transforms |
| | from typing import Optional, Tuple |
| | import cv2 |
| |
|
| |
|
class GradCAM:
    """
    Grad-CAM implementation for ConvNeXt model.

    Hooks one layer of the model to capture its output during the forward
    pass and the gradient flowing into that output during the backward
    pass, then combines the two into a class-activation heatmap.
    """

    def __init__(self, model, target_layer=None):
        """
        Args:
            model: ConvNeXtDualEncoder model. It is called as
                ``model(image_tensor, derm_tensor, metadata)``.
            target_layer: Layer to extract gradients from
                (default: ``model.backbone.stages[-1]``).
        """
        self.model = model
        self.gradients = None
        self.activations = None

        if target_layer is None:
            target_layer = model.backbone.stages[-1]

        # Keep the handles so the hooks can be detached via remove_hooks();
        # otherwise they stay attached to the layer for the model's lifetime.
        self._hook_handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks registered by this instance."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, input, output):
        """Save activations during forward pass"""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Save gradients during backward pass"""
        self.gradients = grad_output[0].detach()

    def generate(
        self,
        image_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        derm_tensor: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """
        Generate Grad-CAM heatmap.

        Args:
            image_tensor: Input image tensor [1, 3, H, W]
            target_class: Class index to visualize (default: predicted class)
            derm_tensor: Optional dermoscopy image tensor
            metadata: Optional metadata tensor

        Returns:
            CAM heatmap as a float32 numpy array with the spatial size of the
            hooked layer's output, normalized to 0-1. A map with no variation
            (all zeros or any constant) is returned as all zeros.

        Raises:
            RuntimeError: If the hooks did not capture activations and
                gradients during this call.
        """
        self.model.eval()

        # Forget results of earlier calls so a hook that fails to fire is
        # detected instead of silently reusing stale tensors.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs autograd even when the caller is inside no_grad().
        with torch.enable_grad():
            output = self.model(image_tensor, derm_tensor, metadata)

            if target_class is None:
                target_class = output.argmax(dim=1).item()

            self.model.zero_grad()
            output[0, target_class].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "check that the target layer is used in the forward pass "
                "and that hooks have not been removed."
            )

        # First (only) sample of the batch: [C, H, W]
        gradients = self.gradients[0].to(torch.float32)
        activations = self.activations[0].to(torch.float32)

        # Channel weights = gradients averaged over the spatial dimensions
        weights = gradients.mean(dim=(1, 2))

        # Weighted sum of the activation maps over channels -> [H, W]
        cam = (weights.view(-1, 1, 1) * activations).sum(dim=0)

        # Keep only regions with a positive influence on the target class
        cam = F.relu(cam).cpu().numpy()

        # Min-max normalize; guard the division so a constant map yields
        # zeros rather than NaN from 0/0.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam

    def overlay(
        self,
        image: np.ndarray,
        cam: np.ndarray,
        alpha: float = 0.5,
        colormap: int = cv2.COLORMAP_JET
    ) -> np.ndarray:
        """
        Overlay CAM heatmap on original image.

        Args:
            image: Original image [H, W, 3] RGB uint8
            cam: CAM heatmap [H, W] float 0-1
            alpha: Overlay transparency (weight of the heatmap; the image
                gets ``1 - alpha``)
            colormap: OpenCV colormap

        Returns:
            Overlaid image [H, W, 3] RGB uint8
        """
        height, width = image.shape[:2]

        # cv2.resize takes the target size as (width, height)
        cam_resized = cv2.resize(cam, (width, height))

        # Clip before the uint8 cast so out-of-range values cannot wrap around
        cam_uint8 = np.uint8(255 * np.clip(cam_resized, 0.0, 1.0))

        # applyColorMap returns BGR; convert to match the RGB input image
        heatmap = cv2.applyColorMap(cam_uint8, colormap)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        blended = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)

        return blended
|
| |
|
class GradCAMTool:
    """
    High-level Grad-CAM tool for ConvNeXt classifier.

    Wraps a classifier object (exposing ``model``, ``device`` and ``load()``)
    and a GradCAM instance, and handles image preprocessing, prediction and
    heatmap overlay in one call.
    """

    def __init__(self, classifier=None, image_size: int = 384):
        """
        Args:
            classifier: ConvNeXtClassifier instance (will create one if None)
            image_size: Side length in pixels that images are resized to,
                both for the model input and for the returned overlay.
        """
        self.classifier = classifier
        self.gradcam = None
        self.loaded = False
        self.image_size = image_size

        # Preprocessing: resize, convert to tensor, normalize with the
        # mean/std constants below.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def load(self):
        """Load classifier and setup Grad-CAM"""
        if self.loaded:
            return

        if self.classifier is None:
            # Imported lazily so the module can be imported without the
            # classifier package being loaded.
            from models.convnext_classifier import ConvNeXtClassifier
            self.classifier = ConvNeXtClassifier()
            self.classifier.load()

        self.gradcam = GradCAM(self.classifier.model)
        self.loaded = True

    def generate_heatmap(
        self,
        image: Image.Image,
        target_class: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray, int, float]:
        """
        Generate Grad-CAM heatmap for an image.

        Args:
            image: PIL Image
            target_class: Class to visualize (default: predicted)

        Returns:
            Tuple of (overlay_image, cam_heatmap, predicted_class, confidence).
            ``predicted_class`` and ``confidence`` always describe the model's
            top prediction, even when ``target_class`` selects another class
            for the heatmap.
        """
        if not self.loaded:
            self.load()

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Display copy of the image at the same size as the model input
        size = (self.image_size, self.image_size)
        image_np = np.array(image.resize(size))
        image_tensor = self.transform(image).unsqueeze(0).to(self.classifier.device)

        # The Grad-CAM pass below runs in eval mode; put the model in eval
        # mode here too so the reported prediction comes from the same mode.
        self.classifier.model.eval()
        with torch.no_grad():
            logits = self.classifier.model(image_tensor)
            probs = torch.softmax(logits, dim=1)[0]
            pred_class = probs.argmax().item()
            confidence = probs[pred_class].item()

        if target_class is None:
            target_class = pred_class

        cam = self.gradcam.generate(image_tensor, target_class)

        overlay = self.gradcam.overlay(image_np, cam, alpha=0.5)

        return overlay, cam, pred_class, confidence

    def analyze(
        self,
        image: Image.Image,
        target_class: Optional[int] = None
    ) -> dict:
        """
        Full analysis with Grad-CAM visualization.

        Args:
            image: PIL Image
            target_class: Class to visualize (default: predicted)

        Returns:
            Dict with keys ``overlay`` (PIL Image), ``cam``,
            ``predicted_class``, ``predicted_class_full``, ``confidence``
            and ``class_index``.
        """
        from models.convnext_classifier import CLASS_NAMES, CLASS_FULL_NAMES

        overlay, cam, pred_class, confidence = self.generate_heatmap(image, target_class)

        return {
            "overlay": Image.fromarray(overlay),
            "cam": cam,
            "predicted_class": CLASS_NAMES[pred_class],
            "predicted_class_full": CLASS_FULL_NAMES[CLASS_NAMES[pred_class]],
            "confidence": confidence,
            "class_index": pred_class,
        }

    def __call__(self, image: Image.Image, target_class: Optional[int] = None) -> dict:
        """Shorthand for :meth:`analyze`."""
        return self.analyze(image, target_class)
|
| |
|
| | |
# Module-wide cache for the shared tool; populated on first request.
_gradcam_instance = None


def get_gradcam_tool() -> GradCAMTool:
    """Return the shared GradCAMTool, constructing it on the first call."""
    global _gradcam_instance
    if _gradcam_instance is not None:
        return _gradcam_instance

    # First request: build the tool (its model is loaded later, by the tool itself).
    _gradcam_instance = GradCAMTool()
    return _gradcam_instance
| |
|
| |
|
def gradcam_output_path(image_path: str) -> str:
    """
    Return the path the Grad-CAM overlay for ``image_path`` is saved to.

    The result lives next to the input file and is named
    ``<stem>_gradcam.png``. Only the file name's own final suffix is
    dropped, so dots in directory names are left alone.
    """
    from pathlib import Path

    source = Path(image_path)
    return str(source.with_name(f"{source.stem}_gradcam.png"))


if __name__ == "__main__":
    import sys

    print("Grad-CAM Tool Test")
    print("=" * 50)

    tool = GradCAMTool()
    print("Loading model...")
    tool.load()
    print("Model loaded!")

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        print(f"\nAnalyzing: {image_path}")

        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the with-block exits.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
        result = tool.analyze(image)

        print(f"\nPrediction: {result['predicted_class']} ({result['confidence']:.1%})")
        print(f"Full name: {result['predicted_class_full']}")

        # Save the overlay next to the input image
        output_path = gradcam_output_path(image_path)
        result["overlay"].save(output_path)
        print(f"\nGrad-CAM overlay saved to: {output_path}")
|