Spaces:
Runtime error
Runtime error
File size: 2,024 Bytes
58ba382 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from typing import List
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
# Registry mapping user-facing method names to pytorch_grad_cam CAM classes.
CAM_METHODS = {"GradCAM": GradCAM, "GradCAM++": GradCAMPlusPlus}
def generate_grad_cam_overlay(
    model: nn.Module,
    target_layer: nn.Module,
    input_tensor: torch.Tensor,  # (B, C, H, W)
    original_frames: List[Image.Image],
    target_class_idx: int,
    cam_method: str = "GradCAM",
    temporal_aggregation_method: str = "mean",
    target_frame_idx: int = -1,
) -> Image.Image:
    """
    Generate a Grad-CAM heatmap and overlay it on a representative frame.

    Args:
        model: Model to explain.
        target_layer: Layer whose activations/gradients drive the CAM.
        input_tensor: Batched input tensor of shape (B, C, H, W).
        original_frames: Frames corresponding to the input; the overlay is
            rendered onto one representative frame from this list.
        target_class_idx: Class index whose activation is explained.
        cam_method: Key into CAM_METHODS ("GradCAM" or "GradCAM++").
        temporal_aggregation_method: How to collapse the per-sample CAM maps
            (axis 0 of the CAM output) into a single 2-D heatmap: "mean" or
            "max".
        target_frame_idx: Frame to overlay on; -1 or any out-of-range index
            selects the middle frame.

    Returns:
        A PIL Image of the heatmap blended onto the representative frame.

    Raises:
        ValueError: If cam_method or temporal_aggregation_method is not
            supported, or if original_frames is empty.
    """
    if cam_method not in CAM_METHODS:
        raise ValueError(f"Unsupported CAM method: {cam_method}.")
    if not original_frames:
        raise ValueError("original_frames must contain at least one frame.")

    # Do not rebind the `cam_method` string parameter to a class object;
    # keep a separate name for clarity.
    cam_class = CAM_METHODS[cam_method]
    with cam_class(model=model, target_layers=[target_layer]) as cam:
        targets = [ClassifierOutputTarget(target_class_idx)]
        # grayscale_cam is a numpy array; first axis indexes the batch.
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

    # BUG FIX: the original handled only "mean"; any other value left
    # `aggregated_heatmap_2d` unassigned and crashed with a NameError.
    # Support "max" as well and fail loudly on anything else.
    if temporal_aggregation_method == "mean":
        aggregated_heatmap_2d = np.mean(grayscale_cam, axis=0)
    elif temporal_aggregation_method == "max":
        aggregated_heatmap_2d = np.max(grayscale_cam, axis=0)
    else:
        raise ValueError(
            f"Unsupported temporal aggregation method: {temporal_aggregation_method}."
        )

    # Pick the frame to draw on: the explicit index if valid, else the middle.
    if target_frame_idx == -1 or not (0 <= target_frame_idx < len(original_frames)):
        idx_representative_frame = len(original_frames) // 2
    else:
        idx_representative_frame = target_frame_idx
    representative_frame = original_frames[idx_representative_frame]

    # show_cam_on_image requires a float RGB image in [0, 1]; convert("RGB")
    # guards against grayscale ("L") or RGBA frames that would otherwise
    # produce the wrong channel count.
    rgb_image = np.array(representative_frame.convert("RGB")).astype(np.float32) / 255.0
    target_h, target_w = rgb_image.shape[:2]
    # cv2.resize takes dsize as (width, height), not (height, width).
    heatmap_resized = cv2.resize(
        aggregated_heatmap_2d, (target_w, target_h), interpolation=cv2.INTER_LINEAR
    )
    visualization = show_cam_on_image(
        rgb_image, heatmap_resized, use_rgb=True, image_weight=0.5
    )
    return Image.fromarray(visualization.astype(np.uint8))
|