Spaces:
Runtime error
Runtime error
raedinkhaled
commited on
Commit
•
5174b1f
1
Parent(s):
b6c245f
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torchvision import models, transforms
|
9 |
+
from torchvision.models.feature_extraction import create_feature_extractor
|
10 |
+
from transformers import ViTForImageClassification
|
11 |
+
|
12 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
13 |
+
|
14 |
+
labels = json.loads(Path("labels.json").read_text())
|
15 |
+
|
16 |
+
# Load ResNet-50
|
17 |
+
resnet50 = models.resnet50(pretrained=True).to(device)
|
18 |
+
resnet50.eval()
|
19 |
+
|
20 |
+
# Create ResNet feature extractor
|
21 |
+
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
|
22 |
+
fc_layer_weights = resnet50.fc.weight
|
23 |
+
|
24 |
+
# Load ViT
|
25 |
+
vit = ViTForImageClassification.from_pretrained("raedinkhaled/vit-base-mri").to(
|
26 |
+
device
|
27 |
+
)
|
28 |
+
vit.eval()
|
29 |
+
|
30 |
+
|
31 |
+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
32 |
+
|
33 |
+
preprocess = transforms.Compose(
|
34 |
+
[transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
|
35 |
+
)
|
36 |
+
|
37 |
+
examples = sorted([f.as_posix() for f in Path("examples").glob("*")])
|
38 |
+
|
39 |
+
|
40 |
+
def get_cam(img_tensor):
|
41 |
+
output = feature_extractor(img_tensor)
|
42 |
+
cnn_features = output["layer4"].squeeze()
|
43 |
+
class_id = output["fc"].argmax()
|
44 |
+
|
45 |
+
cam = fc_layer_weights[class_id].matmul(cnn_features.flatten(1))
|
46 |
+
cam = cam.reshape(cnn_features.shape[1], cnn_features.shape[2])
|
47 |
+
|
48 |
+
return cam.cpu().numpy(), labels[class_id]
|
49 |
+
|
50 |
+
|
51 |
+
def get_attention_mask(img_tensor):
|
52 |
+
result = vit(img_tensor, output_attentions=True)
|
53 |
+
class_id = result[0].argmax()
|
54 |
+
attention_probs = torch.stack(result[1]).squeeze(1)
|
55 |
+
|
56 |
+
# Average the attention at each layer over all heads
|
57 |
+
attention_probs = torch.mean(attention_probs, dim=1)
|
58 |
+
residual = torch.eye(attention_probs.size(-1)).to(device)
|
59 |
+
attention_probs = 0.5 * attention_probs + 0.5 * residual
|
60 |
+
|
61 |
+
# normalize by layer
|
62 |
+
attention_probs = attention_probs / attention_probs.sum(dim=-1).unsqueeze(-1)
|
63 |
+
|
64 |
+
attention_rollout = attention_probs[0]
|
65 |
+
|
66 |
+
for i in range(1, attention_probs.size(0)):
|
67 |
+
attention_rollout = torch.matmul(attention_probs[i], attention_rollout)
|
68 |
+
|
69 |
+
# Attention between cls token and patches
|
70 |
+
mask = attention_rollout[0, 1:]
|
71 |
+
mask_size = np.sqrt(mask.size(0)).astype(int)
|
72 |
+
mask = mask.reshape(mask_size, mask_size)
|
73 |
+
|
74 |
+
return mask.cpu().numpy(), labels[class_id]
|
75 |
+
|
76 |
+
|
77 |
+
def plot_mask_on_image(image, mask):
|
78 |
+
# min-max normalization
|
79 |
+
mask = (mask - mask.min()) / mask.max()
|
80 |
+
mask = (255 * mask).astype(np.uint8)
|
81 |
+
mask = cv2.resize(mask, image.size)
|
82 |
+
|
83 |
+
heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
|
84 |
+
result = heatmap * 0.3 + np.array(image) * 0.5
|
85 |
+
return result.astype(np.uint8)
|
86 |
+
|
87 |
+
|
88 |
+
def inference(img):
|
89 |
+
img_tensor = preprocess(img).unsqueeze(0).to(device)
|
90 |
+
|
91 |
+
with torch.no_grad():
|
92 |
+
cam, resnet_label = get_cam(img_tensor)
|
93 |
+
attention_mask, vit_label = get_attention_mask(img_tensor)
|
94 |
+
|
95 |
+
cam_result = plot_mask_on_image(img, cam)
|
96 |
+
rollout_result = plot_mask_on_image(img, attention_mask)
|
97 |
+
|
98 |
+
return resnet_label, cam_result, vit_label, rollout_result
|
99 |
+
|
100 |
+
if __name__ == "__main__":
|
101 |
+
interface = gr.Interface(
|
102 |
+
fn=inference,
|
103 |
+
inputs=gr.inputs.Image(type="pil", label="Input Image"),
|
104 |
+
outputs=[
|
105 |
+
gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
|
106 |
+
gr.outputs.Image(type="auto", label="ResNet CAM"),
|
107 |
+
gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
|
108 |
+
gr.outputs.Image(type="auto", label="raedinkhaled/vit-base-mri CAM"),
|
109 |
+
],
|
110 |
+
examples=examples,
|
111 |
+
title="Transformer Explainability On Our Pre Trained Model",
|
112 |
+
live=True,
|
113 |
+
)
|
114 |
+
interface.launch()
|