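"""Generate Grad-CAM heatmap overlays for an EfficientNet-B0 classifier
on a random sample of test images and save them under results/gradcam_b0/."""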
import os, sys
import torch
import numpy as np
import cv2
from torchvision import transforms
from PIL import Image

# Add model import path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models.efficientnet_b0 import EfficientNetB0Classifier

# --- Grad-CAM Class ---
class GradCAM:
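    """Minimal Grad-CAM implementation.

    Forward/backward hooks on ``target_layer`` capture its activations and
    gradients; the CAM is the ReLU of the gradient-weighted channel sum.
    """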
    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self.hooks = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient)
        ]

    def save_activation(self, module, input, output):
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def __call__(self, input_tensor):
        input_tensor.requires_grad_()
        self.model.zero_grad()

        output = self.model(input_tensor)
        output = output.squeeze()
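        # Backprop from the single logit (binary head) or the class-0 logit
        # to get gradients of the target score w.r.t. the layer activations.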
        score = output if output.ndim == 0 else output[0]
        score.backward(retain_graph=True)

        print(f"β–Ά ACTIVATIONS: {self.activations is not None}, GRADIENTS: {self.gradients is not None}")

        if self.gradients is None:
            raise RuntimeError("❌ Gradients not captured. Try another target layer.")

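        # Grad-CAM: channel weights are the spatially averaged gradients;
        # the map is the ReLU of the weighted sum over channels, rescaled to [0, 1].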
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * self.activations, dim=1).squeeze()
        cam = torch.clamp(cam, min=0)
        cam -= cam.min()
        cam /= cam.max() + 1e-8  # guard against division by zero for an all-zero map
        return cam.cpu().numpy()

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

# --- Apply Grad-CAM to image and save overlay ---
def apply_gradcam_on_image(img_path, model, cam_extractor, transform, save_path, device):
    img = Image.open(img_path).convert("RGB")
    input_tensor = transform(img).unsqueeze(0).to(device)

    cam = cam_extractor(input_tensor)
    
    # Apply threshold to focus on high-activation areas
    threshold = 0.5
    cam[cam < threshold] = 0

    raw = cv2.imread(img_path)
    raw = cv2.resize(raw, (380, 380))

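    # Scale the CAM to 8-bit, resize to the display size, and apply the JET colormap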
    heatmap = np.uint8(255 * cam)
    heatmap = cv2.resize(heatmap, (380, 380))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Increase heatmap intensity and reduce image opacity
    overlay = cv2.addWeighted(raw, 0.3, heatmap, 0.7, 0)
    
    # Save the blended overlay
    cv2.imwrite(save_path, overlay)

# --- Main ---
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = EfficientNetB0Classifier()
    model.load_state_dict(torch.load("results_efficientnet_b0/efficientnet_best9912.pth", map_location=device))
    model.to(device)

    # Image transform
    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # ✅ Try an earlier layer for better feature focus
    target_layer = model.base_model.features[5][0].block[0]  # Using layer 5 instead of 6
    gradcam = GradCAM(model, target_layer)

    # Random selection of test images
    image_paths = np.load("test_paths.npy", allow_pickle=True).astype(str)
    np.random.seed(42)
    selected_indices = np.random.choice(len(image_paths), 5, replace=False)

    os.makedirs("results/gradcam_b0", exist_ok=True)

    for i in selected_indices:
        input_path = image_paths[i]
        output_path = f"results/gradcam_b0/gradcam_{i}.png"
        apply_gradcam_on_image(input_path, model, gradcam, transform, output_path, device)
        print(f"βœ… Saved Grad-CAM: {output_path}")

    gradcam.remove_hooks()