File size: 2,024 Bytes
58ba382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
from typing import List

from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus

# Registry of the supported CAM algorithms: maps the user-facing method name
# to the pytorch_grad_cam class that implements it.
CAM_METHODS = {
    "GradCAM": GradCAM,
    "GradCAM++": GradCAMPlusPlus,
}


def generate_grad_cam_overlay(
    model: nn.Module,
    target_layer: nn.Module,
    input_tensor: torch.Tensor,  # (B, C, H, W)
    original_frames: List[Image.Image],
    target_class_idx: int,
    cam_method: str = "GradCAM",
    temporal_aggregation_method: str = "mean",
    target_frame_idx: int = -1,
):
    """
    Generate a CAM heatmap and overlay it on one of the original frames.

    The CAM is computed for every element of ``input_tensor``, the per-element
    heatmaps are collapsed into a single 2D map along the first axis, and that
    map is resized to the chosen frame and blended onto it.

    Args:
        model: Network to explain.
        target_layer: Layer of ``model`` whose activations drive the CAM.
        input_tensor: Model input, annotated as ``(B, C, H, W)``.
        original_frames: Non-empty list of PIL frames to draw the overlay on.
        target_class_idx: Class index the CAM is computed for.
        cam_method: Key of ``CAM_METHODS`` selecting the CAM algorithm.
        temporal_aggregation_method: How the per-element heatmaps are merged:
            ``"mean"`` (default) or ``"max"`` along the first axis.
        target_frame_idx: Index into ``original_frames`` of the frame to
            overlay. ``-1`` or any out-of-range value selects the middle frame.

    Returns:
        A PIL Image containing the heatmap blended onto the selected frame.

    Raises:
        ValueError: If ``cam_method`` or ``temporal_aggregation_method`` is
            unsupported, or if ``original_frames`` is empty.
    """
    # Validate everything up front so bad arguments fail before the (costly)
    # forward/backward pass instead of part-way through post-processing.
    if cam_method not in CAM_METHODS:
        raise ValueError(f"Unsupported CAM method: {cam_method}.")
    if temporal_aggregation_method not in ("mean", "max"):
        raise ValueError(
            "Unsupported temporal aggregation method: "
            f"{temporal_aggregation_method}."
        )
    if not original_frames:
        raise ValueError("original_frames must contain at least one frame.")

    cam_cls = CAM_METHODS[cam_method]

    # The context manager releases the hooks registered on the model; only the
    # CAM computation itself needs to happen while they are attached.
    with cam_cls(model=model, target_layers=[target_layer]) as cam:
        # NOTE(review): a single target is passed regardless of batch size;
        # confirm this matches the shape of the model's output.
        targets = [ClassifierOutputTarget(target_class_idx)]
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

    # Collapse the per-element heatmaps into a single 2D map.
    if temporal_aggregation_method == "mean":
        aggregated_heatmap_2d = np.mean(grayscale_cam, axis=0)
    else:  # "max"
        aggregated_heatmap_2d = np.max(grayscale_cam, axis=0)

    # Fall back to the middle frame when no valid index was requested.
    if 0 <= target_frame_idx < len(original_frames):
        idx_representative_frame = target_frame_idx
    else:
        idx_representative_frame = len(original_frames) // 2

    # Force 3-channel RGB so grayscale/RGBA/palette frames can be blended with
    # the 3-channel colour heatmap.
    representative_frame = original_frames[idx_representative_frame].convert("RGB")
    rgb_image = np.array(representative_frame) / 255.0

    # cv2.resize takes the destination size as (width, height).
    target_h, target_w = rgb_image.shape[:2]
    heatmap_resized = cv2.resize(
        aggregated_heatmap_2d, (target_w, target_h), interpolation=cv2.INTER_LINEAR
    )

    visualization = show_cam_on_image(
        rgb_image, heatmap_resized, use_rgb=True, image_weight=0.5
    )

    return Image.fromarray(visualization.astype(np.uint8))