rwightman (HF staff) committed
Commit 6458094
1 Parent(s): 8c53c8a

Update app.py

Files changed (1)
  1. app.py +114 -14
app.py CHANGED
@@ -65,27 +65,91 @@ def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, f
     masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
     return masked_image.astype(np.uint8)
 
-def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image]:
-    """Visualize attention maps for the given image and model."""
+
+def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
+    # based on https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
+    result = torch.eye(attentions[0].size(-1))
+    with torch.no_grad():
+        for attention in attentions:
+            if head_fusion.startswith('mean'):
+                # mean_std fusion doesn't appear to make sense with rollout
+                attention_heads_fused = attention.mean(dim=0)
+            elif head_fusion == "max":
+                attention_heads_fused = attention.amax(dim=0)
+            elif head_fusion == "min":
+                attention_heads_fused = attention.amin(dim=0)
+            else:
+                raise ValueError("Attention head fusion type not supported")
+
+            # Discard the lowest attentions, but don't discard the prefix tokens
+            flat = attention_heads_fused.view(-1)
+            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
+            print(indices)
+            print(indices.shape)
+            indices = indices[indices >= num_prefix_tokens]
+            flat[indices] = 0
+
+            I = torch.eye(attention_heads_fused.size(-1))
+            a = (attention_heads_fused + 1.0 * I) / 2
+            a = a / a.sum(dim=-1)
+            result = torch.matmul(a, result)
+
+    # Look at the total attention between the prefix tokens (usually class tokens)
+    # and the image patches
+    # FIXME this is token 0 vs non-prefix right now, need to cover other cases (> 1 prefix, no prefix, etc)
+    mask = result[0, num_prefix_tokens:]
+    width = int(mask.size(-1) ** 0.5)
+    mask = mask.reshape(width, width).numpy()
+    mask = mask / np.max(mask)
+    return mask
+
+
+def visualize_attention(
+    image: Image.Image,
+    model_name: str,
+    head_fusion: str,
+    discard_ratio: float,
+) -> Tuple[List[Image.Image], Image.Image]:
+    """Visualize attention maps and rollout for the given image and model."""
     model, extractor = load_model(model_name)
     attention_maps = process_image(image, model, extractor)
-    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 0)
 
+    # FIXME handle wider range of models that may not have num_prefix_tokens attr
+    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified
+
     # Convert PIL Image to numpy array
     image_np = np.array(image)
 
     # Create visualizations
     visualizations = []
+    attentions_for_rollout = []
     for layer_name, attn_map in attention_maps.items():
         print(f"Attention map shape for {layer_name}: {attn_map.shape}")
-
-        # Remove the CLS token attention and average over heads
-        attn_map = attn_map[0, :, 0, num_prefix_tokens:].mean(0)  # Shape: (seq_len-1,)
-
+        attn_map = attn_map[0]  # Remove batch dimension
+
+        attentions_for_rollout.append(attn_map)
+
+        attn_map = attn_map[:, :, num_prefix_tokens:]  # Remove prefix tokens for visualization
+
+        if head_fusion == 'mean_std':
+            attn_map = attn_map.mean(0) / attn_map.std(0)
+        elif head_fusion == 'mean':
+            attn_map = attn_map.mean(0)
+        elif head_fusion == 'max':
+            attn_map = attn_map.amax(0)
+        elif head_fusion == 'min':
+            attn_map = attn_map.amin(0)
+        else:
+            raise ValueError(f"Invalid head fusion method: {head_fusion}")
+
+        # Use the first token's attention (usually the class token)
+        # FIXME handle different prefix token scenarios
+        attn_map = attn_map[0]
+
         # Reshape the attention map to 2D
-        num_patches = int(np.sqrt(attn_map.shape[0]))
+        num_patches = int(attn_map.shape[0] ** 0.5)
         attn_map = attn_map.reshape(num_patches, num_patches)
-
+
         # Interpolate to match image size
         attn_map = torch.tensor(attn_map).unsqueeze(0).unsqueeze(0)
         attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
@@ -116,18 +180,54 @@ def visualize_attention(image: Image.Image, model_name: str) -> List[Image.Image
         visualizations.append(vis_image)
         plt.close(fig)
 
-    return visualizations
+    # Calculate rollout
+    rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)
+
+    # Create rollout visualization
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
+
+    # Original image
+    ax1.imshow(image_np)
+    ax1.set_title("Original Image")
+    ax1.axis('off')
+
+    # Rollout overlay
+    rollout_mask_pil = Image.fromarray((rollout_mask * 255).astype(np.uint8))
+    rollout_mask_resized = np.array(rollout_mask_pil.resize((image_np.shape[1], image_np.shape[0]), Image.BICUBIC)) / 255.0
+    masked_image = apply_mask(image_np, rollout_mask_resized, color=(1, 0, 0))  # Red mask
+    ax2.imshow(masked_image)
+    ax2.set_title('Attention Rollout')
+    ax2.axis('off')
+
+    plt.tight_layout()
+
+    # Convert plot to image
+    fig.canvas.draw()
+    rollout_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
+    plt.close(fig)
+
+    return visualizations, rollout_image
+
 
 # Create Gradio interface
 iface = gr.Interface(
     fn=visualize_attention,
     inputs=[
         gr.Image(type="pil", label="Input Image"),
-        gr.Dropdown(choices=get_attention_models(), label="Select Model")
+        gr.Dropdown(choices=get_attention_models(), label="Select Model"),
+        gr.Dropdown(
+            choices=['mean_std', 'mean', 'max', 'min'],
+            label="Head Fusion Method",
+            value='mean'  # Default value
+        ),
+        gr.Slider(0, 1, 0.9, label="Discard Ratio", info="Ratio of lowest attentions to discard")
+    ],
+    outputs=[
+        gr.Gallery(label="Attention Maps"),
+        gr.Image(label="Attention Rollout")
     ],
-    outputs=gr.Gallery(label="Attention Maps"),
-    title="Attention Map Visualizer for timm Models. NOTE: This is a WIP.",
+    title="Attention Map Visualizer for timm Models",
     description="Upload an image and select a timm model to visualize its attention maps."
 )
 
-iface.launch(debug=True)
+iface.launch()
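For context on the new rollout helper: attention rollout (Abnar & Zuidema, 2020) estimates how attention flows from the output tokens back to the input patches by fusing each layer's heads, adding the residual connection as an identity matrix, re-normalizing, and multiplying the per-layer matrices together. Below is a minimal sketch of that recurrence, not the app's exact code: it assumes a list of per-layer (num_heads, seq_len, seq_len) attention tensors, uses mean head fusion, skips the discard-ratio filtering above, and row-normalizes with keepdim=True (the committed code follows vit-explain's a / a.sum(dim=-1)).

import torch

def rollout_sketch(attentions):
    # attentions: list of per-layer (num_heads, seq_len, seq_len) tensors
    seq_len = attentions[0].size(-1)
    result = torch.eye(seq_len)
    with torch.no_grad():
        for attn in attentions:
            a = attn.mean(dim=0)                 # fuse heads by averaging
            a = (a + torch.eye(seq_len)) / 2     # add the residual path as identity
            a = a / a.sum(dim=-1, keepdim=True)  # keep each row summing to 1
            result = a @ result                  # compose with earlier layers
    return result

# Toy usage: 2 layers, 3 heads, 5 tokens (1 class token + 4 patches)
attns = [torch.softmax(torch.randn(3, 5, 5), dim=-1) for _ in range(2)]
mask = rollout_sketch(attns)[0, 1:]  # class-token attention to each patch
print(mask / mask.max())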
 
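The new Head Fusion dropdown selects how the per-head attention maps collapse into a single map before plotting, mirroring the if/elif chain in the diff above. A toy illustration of the four options on a random (heads, seq, seq) tensor (shapes only; the values are meaningless):

import torch

attn = torch.softmax(torch.randn(4, 6, 6), dim=-1)  # 4 heads, 6 tokens
fused = {
    'mean': attn.mean(0),                    # average over heads
    'mean_std': attn.mean(0) / attn.std(0),  # mean scaled by per-position spread
    'max': attn.amax(0),                     # strongest head per position
    'min': attn.amin(0),                     # weakest head per position
}
for name, m in fused.items():
    print(name, tuple(m.shape))  # each fused map is (6, 6)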
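Likewise, the Discard Ratio slider is passed through to rollout, which zeroes the lowest-scoring fraction of fused attention entries before composing layers, the intent being to suppress noisy weak connections. The filtering step in isolation, on a small example:

import torch

flat = torch.tensor([0.05, 0.40, 0.10, 0.25, 0.20])
discard_ratio = 0.4  # drop the lowest 40% of entries
_, idx = flat.topk(int(flat.numel() * discard_ratio), largest=False)
flat[idx] = 0
print(flat)  # tensor([0.0000, 0.4000, 0.0000, 0.2500, 0.2000])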