pawlo2013 commited on
Commit
66b11d3
·
1 Parent(s): 4326ce4

added attention rollout for visualisation of the ViT prediction

Browse files
.history/app_20240617182329.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import ViTForImageClassification, ViTImageProcessor
6
+ from datasets import load_dataset
7
+
8
+ # Model and processor configuration
9
+ model_name_or_path = "google/vit-base-patch16-224-in21k"
10
+ processor = ViTImageProcessor.from_pretrained(model_name_or_path)
11
+
12
+ # Load dataset (adjust dataset_path accordingly)
13
+ dataset_path = "pawlo2013/chest_xray"
14
+ train_dataset = load_dataset(dataset_path, split="train")
15
+ class_names = train_dataset.features["label"].names
16
+
17
+ # Load ViT model
18
+ model = ViTForImageClassification.from_pretrained(
19
+ "./models",
20
+ num_labels=len(class_names),
21
+ id2label={str(i): label for i, label in enumerate(class_names)},
22
+ label2id={label: i for i, label in enumerate(class_names)},
23
+ )
24
+
25
+ # Set model to evaluation mode
26
+ model.eval()
27
+
28
+
29
+ # Define the classification function
30
+ def classify_image(img_path):
31
+ img = Image.open(img_path)
32
+ processed_input = processor(images=img, return_tensors="pt")
33
+ with torch.no_grad():
34
+ outputs = model(**processed_input)
35
+ logits = outputs.logits
36
+ probabilities = torch.softmax(logits, dim=1)[0].tolist()
37
+
38
+ result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
39
+ filename = os.path.basename(img_path).split(".")[0]
40
+ return {"filename": filename, "probabilities": result}
41
+
42
+
43
+ def format_output(output):
44
+ return f"{output['filename']}", output["probabilities"]
45
+
46
+
47
+ # Function to load examples from a folder
48
+ def load_examples_from_folder(folder_path):
49
+ examples = []
50
+ for file in os.listdir(folder_path):
51
+ if file.endswith((".png", ".jpg", ".jpeg")):
52
+ examples.append(os.path.join(folder_path, file))
53
+ return examples
54
+
55
+
56
+ # Define the path to the examples folder
57
+ examples_folder = "./examples"
58
+ examples = load_examples_from_folder(examples_folder)
59
+
60
+ # Create the Gradio interface
61
+ iface = gr.Interface(
62
+ fn=lambda img: format_output(classify_image(img)),
63
+ inputs=gr.Image(type="filepath"),
64
+ outputs=[gr.Textbox(label="True Label (from filename)"), gr.Label()],
65
+ examples=examples,
66
+ title="Pneumonia X-Ray 3-Class Classification with Vision Transformer (ViT) using data augmentation",
67
+ description="Upload an X-ray image to classify it as normal, viral or bacterial pneumonia. Checkout the model in more details at https://huggingface.co/pawlo2013/vit-pneumonia-x-ray_3_class. The examples presented are take from the test set of [Kermany et al. (2018) dataset.](https://data.mendeley.com/datasets/rscbjbr9sj/2)",
68
+ )
69
+
70
+ # Launch the app
71
+ if __name__ == "__main__":
72
+ iface.launch()
.history/app_20240617182353.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import ViTForImageClassification, ViTImageProcessor
6
+ from datasets import load_dataset
7
+
8
+ # Model and processor configuration
9
+ model_name_or_path = "google/vit-base-patch16-224-in21k"
10
+ processor = ViTImageProcessor.from_pretrained(model_name_or_path)
11
+
12
+ # Load dataset (adjust dataset_path accordingly)
13
+ dataset_path = "pawlo2013/chest_xray"
14
+ train_dataset = load_dataset(dataset_path, split="train")
15
+ class_names = train_dataset.features["label"].names
16
+
17
+ # Load ViT model
18
+ model = ViTForImageClassification.from_pretrained(
19
+ "./models",
20
+ num_labels=len(class_names),
21
+ id2label={str(i): label for i, label in enumerate(class_names)},
22
+ label2id={label: i for i, label in enumerate(class_names)},
23
+ )
24
+
25
+ # Set model to evaluation mode
26
+ model.eval()
27
+
28
+
29
+ # Define the classification function
30
+ def classify_image(img_path):
31
+ img = Image.open(img_path)
32
+ processed_input = processor(images=img, return_tensors="pt")
33
+ with torch.no_grad():
34
+ outputs = model(**processed_input)
35
+ logits = outputs.logits
36
+ probabilities = torch.softmax(logits, dim=1)[0].tolist()
37
+
38
+ result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
39
+ filename = os.path.basename(img_path).split(".")[0]
40
+ return {"filename": filename, "probabilities": result}
41
+
42
+
43
+ def format_output(output):
44
+ return f"{output['filename']}", output["probabilities"]
45
+
46
+
47
+ # Function to load examples from a folder
48
+ def load_examples_from_folder(folder_path):
49
+ examples = []
50
+ for file in os.listdir(folder_path):
51
+ if file.endswith((".png", ".jpg", ".jpeg")):
52
+ examples.append(os.path.join(folder_path, file))
53
+ return examples
54
+
55
+
56
+ # Define the path to the examples folder
57
+ examples_folder = "./examples"
58
+ examples = load_examples_from_folder(examples_folder)
59
+
60
+ # Create the Gradio interface
61
+ iface = gr.Interface(
62
+ fn=lambda img: format_output(classify_image(img)),
63
+ inputs=gr.Image(type="filepath"),
64
+ outputs=[gr.Textbox(label="True Label (from filename)"), gr.Label()],
65
+ examples=examples,
66
+ title="Pneumonia X-Ray 3-Class Classification with Vision Transformer (ViT) using data augmentation",
67
+ description="Upload an X-ray image to classify it as normal, viral or bacterial pneumonia. Checkout the model in more details [here] (https://huggingface.co/pawlo2013/vit-pneumonia-x-ray_3_class). The examples presented are take from the test set of [Kermany et al. (2018) dataset.](https://data.mendeley.com/datasets/rscbjbr9sj/2)",
68
+ )
69
+
70
+ # Launch the app
71
+ if __name__ == "__main__":
72
+ iface.launch()
.history/app_20240617182506.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import ViTForImageClassification, ViTImageProcessor
6
+ from datasets import load_dataset
7
+
8
+ # Model and processor configuration
9
+ model_name_or_path = "google/vit-base-patch16-224-in21k"
10
+ processor = ViTImageProcessor.from_pretrained(model_name_or_path)
11
+
12
+ # Load dataset (adjust dataset_path accordingly)
13
+ dataset_path = "pawlo2013/chest_xray"
14
+ train_dataset = load_dataset(dataset_path, split="train")
15
+ class_names = train_dataset.features["label"].names
16
+
17
+ # Load ViT model
18
+ model = ViTForImageClassification.from_pretrained(
19
+ "./models",
20
+ num_labels=len(class_names),
21
+ id2label={str(i): label for i, label in enumerate(class_names)},
22
+ label2id={label: i for i, label in enumerate(class_names)},
23
+ )
24
+
25
+ # Set model to evaluation mode
26
+ model.eval()
27
+
28
+
29
+ # Define the classification function
30
+ def classify_image(img_path):
31
+ img = Image.open(img_path)
32
+ processed_input = processor(images=img, return_tensors="pt")
33
+ with torch.no_grad():
34
+ outputs = model(**processed_input)
35
+ logits = outputs.logits
36
+ probabilities = torch.softmax(logits, dim=1)[0].tolist()
37
+
38
+ result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
39
+ filename = os.path.basename(img_path).split(".")[0]
40
+ return {"filename": filename, "probabilities": result}
41
+
42
+
43
+ def format_output(output):
44
+ return f"{output['filename']}", output["probabilities"]
45
+
46
+
47
+ # Function to load examples from a folder
48
+ def load_examples_from_folder(folder_path):
49
+ examples = []
50
+ for file in os.listdir(folder_path):
51
+ if file.endswith((".png", ".jpg", ".jpeg")):
52
+ examples.append(os.path.join(folder_path, file))
53
+ return examples
54
+
55
+
56
+ # Define the path to the examples folder
57
+ examples_folder = "./examples"
58
+ examples = load_examples_from_folder(examples_folder)
59
+
60
+ # Create the Gradio interface
61
+ iface = gr.Interface(
62
+ fn=lambda img: format_output(classify_image(img)),
63
+ inputs=gr.Image(type="filepath"),
64
+ outputs=[gr.Textbox(label="True Label (from filename)"), gr.Label()],
65
+ examples=examples,
66
+ title="Pneumonia X-Ray 3-Class Classification with Vision Transformer (ViT) using data augmentation",
67
+ description="Upload an X-ray image to classify it as normal, viral or bacterial pneumonia. Checkout the model in more details [here](https://huggingface.co/pawlo2013/vit-pneumonia-x-ray_3_class). The examples presented are take from the test set of [Kermany et al. (2018) dataset.](https://data.mendeley.com/datasets/rscbjbr9sj/2)",
68
+ )
69
+
70
+ # Launch the app
71
+ if __name__ == "__main__":
72
+ iface.launch()
app.py CHANGED
@@ -4,6 +4,10 @@ from PIL import Image
4
  import torch
5
  from transformers import ViTForImageClassification, ViTImageProcessor
6
  from datasets import load_dataset
 
 
 
 
7
 
8
  # Model and processor configuration
9
  model_name_or_path = "google/vit-base-patch16-224-in21k"
@@ -27,21 +31,37 @@ model.eval()
27
 
28
 
29
  # Define the classification function
30
- def classify_image(img_path):
31
- img = Image.open(img_path)
32
- processed_input = processor(images=img, return_tensors="pt")
 
 
 
 
33
  with torch.no_grad():
34
  outputs = model(**processed_input)
35
  logits = outputs.logits
36
  probabilities = torch.softmax(logits, dim=1)[0].tolist()
 
 
37
 
38
  result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
39
  filename = os.path.basename(img_path).split(".")[0]
40
- return {"filename": filename, "probabilities": result}
 
 
 
 
 
 
41
 
42
 
43
  def format_output(output):
44
- return f"{output['filename']}", output["probabilities"]
 
 
 
 
45
 
46
 
47
  # Function to load examples from a folder
@@ -53,20 +73,94 @@ def load_examples_from_folder(folder_path):
53
  return examples
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Define the path to the examples folder
57
  examples_folder = "./examples"
58
  examples = load_examples_from_folder(examples_folder)
59
 
60
  # Create the Gradio interface
61
  iface = gr.Interface(
62
- fn=lambda img: format_output(classify_image(img)),
63
  inputs=gr.Image(type="filepath"),
64
- outputs=[gr.Textbox(label="True Label (from filename)"), gr.Label()],
 
 
 
 
65
  examples=examples,
66
  title="Pneumonia X-Ray 3-Class Classification with Vision Transformer (ViT) using data augmentation",
67
- description="Upload an X-ray image to classify it as normal, viral or bacterial pneumonia. Checkout the model in more details at https://huggingface.co/pawlo2013/vit-pneumonia-x-ray_3_class",
68
  )
69
-
70
  # Launch the app
71
  if __name__ == "__main__":
72
  iface.launch()
 
4
  import torch
5
  from transformers import ViTForImageClassification, ViTImageProcessor
6
  from datasets import load_dataset
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import cv2
10
+
11
 
12
  # Model and processor configuration
13
  model_name_or_path = "google/vit-base-patch16-224-in21k"
 
31
 
32
 
33
  # Define the classification function
34
+ # Define the classification function
35
+ def classify_and_visualize(
36
+ img_path, device="cpu", discard_ratio=0.9, head_fusion="mean"
37
+ ):
38
+ img = Image.open(img_path).convert("RGB")
39
+ processed_input = processor(images=img, return_tensors="pt").to(device)
40
+
41
  with torch.no_grad():
42
  outputs = model(**processed_input)
43
  logits = outputs.logits
44
  probabilities = torch.softmax(logits, dim=1)[0].tolist()
45
+ prediction = torch.argmax(logits, dim=-1).item()
46
+ predicted_class = class_names[prediction]
47
 
48
  result = {class_name: prob for class_name, prob in zip(class_names, probabilities)}
49
  filename = os.path.basename(img_path).split(".")[0]
50
+
51
+ # Generate attention heatmap
52
+ heatmap_img = show_final_layer_attention_maps(
53
+ model, processed_input, device, discard_ratio, head_fusion
54
+ )
55
+
56
+ return {"filename": filename, "probabilities": result, "heatmap": heatmap_img}
57
 
58
 
59
  def format_output(output):
60
+ return (
61
+ f"{output['filename']}",
62
+ output["probabilities"],
63
+ gr.Image(value=output["heatmap"]),
64
+ )
65
 
66
 
67
  # Function to load examples from a folder
 
73
  return examples
74
 
75
 
76
+ # Function to show final layer attention maps
77
+ def show_final_layer_attention_maps(
78
+ model, tensor, device, discard_ratio=0.6, head_fusion="max", only_last_layer=False
79
+ ):
80
+ # Create a DataLoader with batch size equal to the number of images
81
+ image = tensor["pixel_values"].to(device).squeeze(0)
82
+
83
+ # Iterate over the samples
84
+ with torch.no_grad():
85
+ # Forward pass through the model
86
+ outputs = model(**tensor, output_attentions=True)
87
+
88
+ # Scale image to [0, 1]
89
+ image = image - image.min()
90
+ image = image / image.max()
91
+
92
+ # Initialize the result tensor and recursively fuse the attention maps
93
+ result = torch.eye(outputs.attentions[0].size(-1)).to(device)
94
+ if only_last_layer:
95
+ attention_list = outputs.attentions[-1].unsqueeze(0).to(device)
96
+ else:
97
+ attention_list = outputs.attentions
98
+
99
+ for attention in attention_list:
100
+ if head_fusion == "mean":
101
+ attention_heads_fused = attention.mean(axis=1)
102
+ elif head_fusion == "max":
103
+ attention_heads_fused = attention.max(axis=1)[0]
104
+ elif head_fusion == "min":
105
+ attention_heads_fused = attention.min(axis=1)[0]
106
+
107
+ flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
108
+ _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
109
+ indices = indices[indices != 0]
110
+ flat[0, indices] = 0
111
+
112
+ I = torch.eye(attention_heads_fused.size(-1)).to(device)
113
+ a = (attention_heads_fused + 1.0 * I) / 2
114
+ a = a / a.sum(dim=-1)
115
+
116
+ result = torch.matmul(a, result)
117
+
118
+ mask = result[0, 0, 1:]
119
+ # In case of 224x224 image, this brings us from 196 to 14
120
+ width = int(mask.size(-1) ** 0.5)
121
+ mask = mask.reshape(width, width).cpu().numpy()
122
+ mask = mask / np.max(mask)
123
+
124
+ mask = cv2.resize(mask, (224, 224))
125
+
126
+ # Normalize mask to [0, 1] for visualization
127
+ mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
128
+ heatmap = plt.cm.jet(mask)[:, :, :3] # Apply colormap
129
+
130
+ # Superimpose heatmap on the original image
131
+ showed_img = image.permute(1, 2, 0).detach().cpu().numpy()
132
+ showed_img = (showed_img - np.min(showed_img)) / (
133
+ np.max(showed_img) - np.min(showed_img)
134
+ ) # Normalize image
135
+ superimposed_img = (
136
+ heatmap * 0.4 + showed_img * 0.6
137
+ ) # Combine heatmap with original image
138
+
139
+ # Plot attention map
140
+ superimposed_img_pil = Image.fromarray(
141
+ (superimposed_img * 255).astype(np.uint8)
142
+ )
143
+
144
+ return superimposed_img_pil
145
+
146
+
147
  # Define the path to the examples folder
148
  examples_folder = "./examples"
149
  examples = load_examples_from_folder(examples_folder)
150
 
151
  # Create the Gradio interface
152
  iface = gr.Interface(
153
+ fn=lambda img: format_output(classify_and_visualize(img)),
154
  inputs=gr.Image(type="filepath"),
155
+ outputs=[
156
+ gr.Textbox(label="True Label (from filename)"),
157
+ gr.Label(),
158
+ gr.Image(label="Attention Heatmap"),
159
+ ],
160
  examples=examples,
161
  title="Pneumonia X-Ray 3-Class Classification with Vision Transformer (ViT) using data augmentation",
162
+ description="Upload an X-ray image to classify it as normal, viral or bacterial pneumonia. Checkout the model in more details [here](https://huggingface.co/pawlo2013/vit-pneumonia-x-ray_3_class). The examples presented are taken from the test set of [Kermany et al. (2018) dataset.](https://data.mendeley.com/datasets/rscbjbr9sj/2.) The attention heatmap over all layers of the transfomer done by the attention rollout techinique by the implementation of [jacobgil](https://github.com/jacobgil/vit-explain).",
163
  )
 
164
  # Launch the app
165
  if __name__ == "__main__":
166
  iface.launch()
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
  torch
2
  transformers
3
- datasets
 
 
 
 
 
 
1
  torch
2
  transformers
3
+ datasets
4
+ numpy
5
+ cv2
6
+ PIL
7
+ os
8
+ matplotlib