import json
import math
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import ViTForImageClassification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Index -> class name. Both models' argmax indices are looked up in this one
# list, so this assumes they share the same class ordering.
labels = json.loads(Path("imagenet-simple-labels.json").read_text(encoding="utf-8"))

# Load ResNet-50.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; left as-is so older torchvision versions keep working.
resnet50 = models.resnet50(pretrained=True).to(device)
resnet50.eval()

# Expose the last conv stage (spatial feature maps) and the classifier logits
# from a single forward pass.
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
# Final linear layer weights; row `c` weights the layer4 channels for class `c`.
fc_layer_weights = resnet50.fc.weight

# Load ViT.
vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(
    device
)
vit.eval()

# ResNet-50 preprocessing: ImageNet mean/std normalisation.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

# ViT preprocessing: google/vit-base-patch16-224 uses per-channel mean = std = 0.5
# (its preprocessor config), not the ImageNet statistics above. Feeding it
# ResNet-normalised input skews its predictions and attention maps.
vit_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
vit_preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), vit_normalize]
)

examples = sorted(f.as_posix() for f in Path("examples").glob("*"))


def get_cam(img_tensor):
    """Compute a class activation map (CAM) for ResNet-50's top prediction.

    The CAM is the predicted class's fc weight vector applied to the channel
    vector at every spatial location of the ``layer4`` feature maps.

    Args:
        img_tensor: Preprocessed image batch on ``device``; only the first
            item is used (the caller passes a batch of one).

    Returns:
        ``(cam, label)``: ``cam`` is a 2-D numpy array over the ``layer4``
        spatial grid (unnormalised, may contain negative values) and
        ``label`` is the predicted class name.
    """
    output = feature_extractor(img_tensor)
    cnn_features = output["layer4"][0]  # (channels, height, width)
    class_id = output["fc"][0].argmax().item()
    height, width = cnn_features.shape[1:]
    cam = fc_layer_weights[class_id] @ cnn_features.flatten(1)
    cam = cam.reshape(height, width)
    return cam.cpu().numpy(), labels[class_id]


def get_attention_mask(img_tensor):
    """Compute an attention-rollout saliency map for ViT's top prediction.

    The map is built in four steps:

    1. Average each layer's attention over its heads.
    2. Mix in an identity matrix to account for the residual connections.
    3. Renormalise the rows.
    4. Multiply the per-layer matrices together.

    The CLS token's row, excluding the CLS column itself, then scores each
    image patch.

    Args:
        img_tensor: Preprocessed image batch of size one on ``device``.

    Returns:
        ``(mask, label)``: ``mask`` is a square 2-D numpy array with one value
        per image patch and ``label`` is the predicted class name.
    """
    outputs = vit(img_tensor, output_attentions=True)
    class_id = outputs.logits[0].argmax().item()
    # Stack per-layer attentions and drop the batch dim of one:
    # (layers, heads, tokens, tokens).
    attention_probs = torch.stack(outputs.attentions).squeeze(1)
    # Average the attention at each layer over all heads.
    attention_probs = attention_probs.mean(dim=1)
    # Mix in the identity to model the residual (skip) connections.
    residual = torch.eye(attention_probs.size(-1), device=attention_probs.device)
    attention_probs = 0.5 * attention_probs + 0.5 * residual
    # Renormalise each row so it sums to one again.
    attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)

    # Propagate attention through the layers, first layer innermost.
    attention_rollout = attention_probs[0]
    for layer_attention in attention_probs[1:]:
        attention_rollout = layer_attention @ attention_rollout

    # Attention between the CLS token (index 0) and the patch tokens.
    mask = attention_rollout[0, 1:]
    mask_size = math.isqrt(mask.numel())
    mask = mask.reshape(mask_size, mask_size)
    return mask.cpu().numpy(), labels[class_id]


def plot_mask_on_image(image, mask):
    """Overlay a saliency mask on an image as a JET heatmap.

    Args:
        image: PIL image to draw on.
        mask: 2-D numpy array of saliency scores at any resolution. It is
            min-max scaled to [0, 255] and resized to the image size.

    Returns:
        ``uint8`` RGB array with the image's height and width and 3 channels.
    """
    # Min-max normalisation to [0, 1]. Dividing by the range (not just the
    # maximum) keeps every value <= 1 even when the mask has negative entries,
    # so the uint8 cast below cannot wrap around. A constant mask maps to 0.
    mask = mask - mask.min()
    peak = mask.max()
    if peak > 0:
        mask = mask / peak
    mask = (255 * mask).astype(np.uint8)
    # PIL's size is (width, height), the same order cv2.resize's dsize uses.
    mask = cv2.resize(mask, image.size)
    # applyColorMap produces BGR; convert so it matches the RGB image below.
    heatmap = cv2.cvtColor(
        cv2.applyColorMap(mask, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
    )
    rgb = np.asarray(image.convert("RGB"))
    result = heatmap * 0.3 + rgb * 0.5
    return result.astype(np.uint8)


def inference(img):
    """Classify ``img`` with ResNet-50 and ViT and visualise their saliency.

    Args:
        img: PIL image from the Gradio input, or ``None`` if it is empty.

    Returns:
        ``(resnet_label, cam_overlay, vit_label, rollout_overlay)``, or four
        ``None`` values when no image is provided.
    """
    if img is None:
        # A live interface may invoke the callback with an empty input.
        return None, None, None, None
    # Drop alpha / expand grayscale so the models and overlay get 3 channels.
    img = img.convert("RGB")
    resnet_tensor = preprocess(img).unsqueeze(0).to(device)
    vit_tensor = vit_preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        cam, resnet_label = get_cam(resnet_tensor)
        attention_mask, vit_label = get_attention_mask(vit_tensor)
    cam_result = plot_mask_on_image(img, cam)
    rollout_result = plot_mask_on_image(img, attention_mask)
    return resnet_label, cam_result, vit_label, rollout_result


if __name__ == "__main__":
    # NOTE(review): gr.inputs / gr.outputs is the legacy Gradio component API;
    # it requires a Gradio version that still provides those namespaces.
    interface = gr.Interface(
        fn=inference,
        inputs=gr.inputs.Image(type="pil", label="Input Image"),
        outputs=[
            gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
            gr.outputs.Image(type="auto", label="ResNet CAM"),
            gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
            gr.outputs.Image(type="auto", label="Rollout Attn Flow"),
        ],
        examples=examples,
        title="CNN - Transformer Explainability",
        live=True,
    )
    interface.launch()