# fish-freshness-classifier / gradcam_efficientnetb0.py
# Uploaded by roqueselopeta
# Initial commit with clean project files (c17bef1)
import os, sys
import torch
import numpy as np
import cv2
from torchvision import transforms
from PIL import Image
# Add model import path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models.efficientnet_b0 import EfficientNetB0Classifier
# --- Grad-CAM Class ---
class GradCAM:
    """Compute Grad-CAM heatmaps for one layer of a model.

    A forward hook records the target layer's output and a full backward
    hook records the gradient of the score with respect to that output.
    Calling the instance runs one forward and one backward pass and combines
    the two recordings into a heatmap.
    """

    def __init__(self, model, target_layer):
        """Switch ``model`` to eval mode and attach hooks to ``target_layer``."""
        self.model = model.eval()
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self.hooks = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the layer output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the layer output."""
        grad = grad_output[0]
        # Detach so the stored tensor never keeps an autograd graph alive.
        self.gradients = None if grad is None else grad.detach()

    def __call__(self, input_tensor):
        """Return the Grad-CAM heatmap for ``input_tensor`` as a NumPy array.

        The back-propagated score is the squeezed model output itself when it
        is a scalar, otherwise its first element. Channel weights are the
        gradients averaged over dims 2 and 3, so the target layer is expected
        to produce a 4-D output. The weighted activations are summed over
        dim 1, negative values are clipped, and the result is rescaled so its
        maximum is 1. A map with no positive evidence is returned as all
        zeros.

        Note that ``input_tensor`` is modified in place: it is marked as
        requiring gradients.

        Raises:
            RuntimeError: if the hooks did not capture gradients or
                activations for this call.
        """
        # Forget the previous call so a hook that fails to fire is detected
        # instead of silently reusing stale tensors.
        self.activations = None
        self.gradients = None

        input_tensor.requires_grad_()
        self.model.zero_grad()
        output = self.model(input_tensor).squeeze()
        score = output if output.ndim == 0 else output[0]
        # The graph is local to this call, so there is no reason to retain it.
        score.backward()
        print(f"β–Ά ACTIVATIONS: {self.activations is not None}, GRADIENTS: {self.gradients is not None}")
        if self.gradients is None:
            raise RuntimeError("❌ Gradients not captured. Try another target layer.")
        if self.activations is None:
            raise RuntimeError("Activations not captured. Try another target layer.")

        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * self.activations, dim=1).squeeze()
        cam = torch.clamp(cam, min=0)
        cam = cam - cam.min()
        peak = cam.max()
        # An all-zero map would otherwise become 0 / 0 = NaN.
        if peak > 0:
            cam = cam / peak
        return cam.cpu().numpy()

    def remove_hooks(self):
        """Detach both hooks from the target layer; safe to call repeatedly."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []
# --- Apply Grad-CAM to image and save overlay ---
def apply_gradcam_on_image(img_path, model, cam_extractor, transform, save_path,
                           threshold=0.5, output_size=(380, 380)):
    """Overlay a thresholded Grad-CAM heatmap on an image and save the result.

    Args:
        img_path: Path of the image to visualise.
        model: Model being explained. It is only used to find the device the
            input tensor has to be moved to.
        cam_extractor: Callable that maps a batched input tensor to a 2-D
            heatmap with values in [0, 1] (for example a ``GradCAM`` instance).
        transform: Preprocessing applied to the PIL image; must return a
            tensor without a batch dimension.
        save_path: File the blended overlay is written to.
        threshold: Heatmap values below this are set to zero so only
            high-activation areas are coloured.
        output_size: ``(width, height)`` of the saved overlay.

    Raises:
        OSError: If OpenCV reports that the overlay could not be written.
    """
    # Close the file handle as soon as the pixels are decoded.
    with Image.open(img_path) as handle:
        img = handle.convert("RGB")

    input_tensor = transform(img).unsqueeze(0)
    # Take the device from the model instead of relying on a global that only
    # exists when this file is run as a script.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    cam = np.asarray(cam_extractor(input_tensor), dtype=np.float32)
    # Guard the uint8 cast below against NaN/inf coming from the extractor.
    cam = np.nan_to_num(cam, nan=0.0, posinf=1.0, neginf=0.0)
    # Apply threshold to focus on high-activation areas
    cam = np.where(cam < threshold, 0.0, cam)

    size = tuple(output_size)
    # Reuse the pixels that were fed to the model (converted to OpenCV's BGR
    # order) so the background and the heatmap come from the same decode.
    raw = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
    raw = cv2.resize(raw, size)

    heatmap = np.uint8(255 * np.clip(cam, 0.0, 1.0))
    heatmap = cv2.resize(heatmap, size)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Increase heatmap intensity and reduce image opacity
    overlay = cv2.addWeighted(raw, 0.3, heatmap, 0.7, 0)

    # Save the overlay; imwrite signals failure through its return value.
    if not cv2.imwrite(save_path, overlay):
        raise OSError(f"Could not write Grad-CAM overlay to {save_path}")
# --- Main ---
if __name__ == "__main__":
    # Kept at module level: other code in this file may read `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = EfficientNetB0Classifier()
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load("results_efficientnet_b0/efficientnet_best9912.pth", map_location=device))
    model.to(device)

    # Image transform: resize to 380x380, convert to tensor, then normalise
    # with the commonly used ImageNet channel mean/std.
    transform = transforms.Compose([
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Try an earlier layer for better feature focus
    target_layer = model.base_model.features[5][0].block[0]  # Using layer 5 instead of 6
    gradcam = GradCAM(model, target_layer)

    try:
        # Random selection of test images
        # NOTE(review): allow_pickle=True unpickles the file's contents; only
        # load a test_paths.npy that comes from a trusted source.
        image_paths = np.load("test_paths.npy", allow_pickle=True).astype(str)
        np.random.seed(42)
        # Never ask for more samples than exist; choice(..., replace=False)
        # raises when the population is smaller than the sample.
        sample_count = min(5, len(image_paths))
        selected_indices = np.random.choice(len(image_paths), sample_count, replace=False)

        os.makedirs("results/gradcam_b0", exist_ok=True)
        for i in selected_indices:
            input_path = image_paths[i]
            output_path = f"results/gradcam_b0/gradcam_{i}.png"
            apply_gradcam_on_image(input_path, model, gradcam, transform, output_path)
            print(f"βœ… Saved Grad-CAM: {output_path}")
    finally:
        # Detach the hooks even if one of the images fails.
        gradcam.remove_hooks()