| """ |
| Thyroid Grad-CAM Visualization (Fixed for SwinV2) |
| """ |
| import os, sys, math, json, random, warnings, traceback |
| warnings.filterwarnings("ignore") |
|
|
| import numpy as np |
| from PIL import Image |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| import torch |
| import torch.nn.functional as F |
| from datasets import load_dataset |
| from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
|
# Hugging Face Hub identifiers for the dataset and the fine-tuned model.
HF_USERNAME = "Johnyquest7"
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = HF_USERNAME + "/ML-Inter_thyroid"

# Run configuration.
OUTPUT_DIR = "./gradcam_outputs"  # directory the Grad-CAM PNGs are written to
SEED = 42  # used for the RNGs below and for the dataset shuffle/split in main()
BATCH_SIZE = 16  # images per forward pass during the evaluation sweep

# Seed every RNG the script relies on so sample selection is reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
|
|
def _predict_all(model, processor, dataset, device, batch_size):
    """Run batched, gradient-free inference over ``dataset``.

    Args:
        model: classification model whose output exposes ``.logits``.
        processor: image processor producing ``pixel_values``.
        dataset: indexable dataset whose items have ``"image"`` (PIL image)
            and ``"label"`` entries.
        device: torch device the model lives on.
        batch_size: number of images per forward pass.

    Returns:
        ``(y_true, y_pred)`` as 1-D numpy arrays of class indices, aligned with
        ``dataset`` order (both empty when ``dataset`` is empty).
    """
    logit_batches, labels = [], []
    for start in range(0, len(dataset), batch_size):
        batch_items = [dataset[j] for j in range(start, min(start + batch_size, len(dataset)))]
        images = [item["image"].convert("RGB") for item in batch_items]
        inputs = processor(images, return_tensors="pt")
        with torch.no_grad():
            outputs = model(pixel_values=inputs["pixel_values"].to(device))
        logit_batches.append(outputs.logits.cpu().numpy())
        labels.extend(item["label"] for item in batch_items)

    if not logit_batches:
        # np.argmax(..., axis=1) would fail on an empty 1-D array.
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.array(labels), np.concatenate(logit_batches).argmax(axis=1)


def _compute_cam(feat, grads, out_size):
    """Build a Grad-CAM heat map in [0, 1] from token activations and gradients.

    Args:
        feat: (num_tokens, channels) activations of the hooked layer for one image.
        grads: gradient of the target logit w.r.t. ``feat``; same shape.
        out_size: (height, width) the map is bilinearly upsampled to.

    Returns:
        numpy array of shape ``out_size`` with values in [0, 1].

    Raises:
        ValueError: if ``num_tokens`` is not a perfect square, because the 2-D
            token grid cannot then be recovered from the flat sequence.
    """
    # Grad-CAM channel weights: gradients averaged over the token (spatial) axis.
    weights = grads.mean(dim=0)  # (C,)
    cam = feat @ weights  # (N,) weighted sum of channels per token

    num_tokens = cam.shape[0]
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        raise ValueError(f"Cannot reshape {num_tokens} tokens into a square grid")
    cam = F.relu(cam.reshape(side, side))

    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)

    cam = F.interpolate(cam[None, None], size=tuple(out_size), mode="bilinear", align_corners=False)
    return cam[0, 0].cpu().numpy()


def _to_display_image(pixel_values, processor):
    """Convert a (1, 3, H, W) model input back to an (H, W, 3) float image in [0, 1].

    Undoes the processor's mean/std normalization when the processor reports
    one. Min-max rescaling across all channels would shift each channel
    differently and tint the image. Falls back to global min-max scaling if the
    normalization parameters are unavailable.
    """
    img = pixel_values[0].detach().cpu().permute(1, 2, 0).numpy()
    mean = getattr(processor, "image_mean", None)
    std = getattr(processor, "image_std", None)
    if getattr(processor, "do_normalize", False) and mean is not None and std is not None:
        img = img * np.asarray(std, dtype=img.dtype) + np.asarray(mean, dtype=img.dtype)
        return np.clip(img, 0.0, 1.0)
    return (img - img.min()) / (img.max() - img.min() + 1e-8)


def _safe_filename_part(text):
    """Replace characters that are unsafe in file names (e.g. '/', spaces) with '_'."""
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(text))


def _save_overlay(img_np, cam, title, fname):
    """Save ``cam`` blended over ``img_np`` to ``fname``; the figure is always closed."""
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(img_np)
        ax.imshow(cam, cmap="jet", alpha=0.5)
        ax.set_title(title)
        ax.axis("off")
        fig.savefig(fname, bbox_inches="tight", dpi=150)
    finally:
        plt.close(fig)


def main():
    """Evaluate the classifier on a held-out split and save Grad-CAM overlays.

    Loads MODEL_NAME and DATASET_NAME, shuffles the data and takes a stratified
    20% test split (both seeded with SEED), then predicts on it. It writes one
    PNG per selected sample into OUTPUT_DIR, covering up to 5 correctly and
    5 incorrectly predicted test images. Failures on individual samples are
    reported and skipped.
    """
    samples_per_group = 5

    print("=" * 60)
    print("Thyroid Grad-CAM Visualization")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    id2label = model.config.id2label

    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    # NOTE(review): this must reproduce the training script's split exactly;
    # otherwise "test" images may have been seen during training. Confirm.
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    print(f"Test samples: {len(test_ds)}")

    y_true, y_pred = _predict_all(model, processor, test_ds, device, BATCH_SIZE)

    correct_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t == p]
    incorrect_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    chosen_correct = correct_idx[:samples_per_group]
    chosen_incorrect = incorrect_idx[:samples_per_group]
    selected = chosen_correct + chosen_incorrect
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(chosen_correct)} correct, {len(chosen_incorrect)} incorrect)...")

    # Hook the final LayerNorm over the last stage's token sequence. Its output is
    # the full feature map that the pooler averages and the classifier consumes.
    # In SwinV2's post-norm blocks, blocks[-1].layernorm_after only normalizes the
    # MLP residual branch, so hooking it ignores the attention/residual path.
    target_layer = model.swinv2.layernorm
    gradcam_data = {}

    def fwd_hook(module, inputs, output):
        gradcam_data["feat"] = output.detach()
        if output.requires_grad:
            # Tensor hook: called with d(score)/d(output) during backward().
            def save_grad(grad):
                gradcam_data["grad"] = grad.detach()

            output.register_hook(save_grad)

    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    try:
        for idx in selected:
            try:
                item = test_ds[idx]
                img = item["image"].convert("RGB")
                label = item["label"]
                inputs = processor(img, return_tensors="pt")
                # requires_grad on the input keeps a graph even if parameters are frozen.
                img_tensor = inputs["pixel_values"].to(device).requires_grad_(True)

                gradcam_data.clear()  # never reuse a previous sample's activations/gradients
                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                target_class = int(y_pred[idx])
                outputs.logits[0, target_class].backward()

                if "feat" not in gradcam_data or "grad" not in gradcam_data:
                    raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

                cam = _compute_cam(
                    gradcam_data["feat"][0],
                    gradcam_data["grad"][0],
                    img_tensor.shape[-2:],  # match the model input so the overlay aligns
                )
                img_np = _to_display_image(img_tensor, processor)

                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                status = "CORRECT" if y_true[idx] == y_pred[idx] else "WRONG"
                fname = os.path.join(
                    OUTPUT_DIR,
                    f"gradcam_{status}_sample{idx}_"
                    f"{_safe_filename_part(pred_name)}_vs_{_safe_filename_part(true_name)}.png",
                )
                _save_overlay(img_np, cam, f"{status}: Pred={pred_name} | True={true_name}", fname)
                print(f" Saved {fname}")
            except Exception as e:
                # Best-effort per sample: report and continue with the rest.
                print(f" Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        fwd_handle.remove()

    print("\nGrad-CAM complete.")
|
|
if __name__ == "__main__":
    # Run the visualization only when executed as a script, not when imported.
    main()
|
|