Felix Konrad committed on
Commit
2ec5753
·
1 Parent(s): 57c8491

Added proper Cosine-Similarity Computation + Visualization

Browse files
Files changed (1) hide show
  1. app.py +70 -17
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import matplotlib.pyplot as plt
 
2
  import numpy as np
3
  import gradio as gr
4
  from transformers import AutoModel, AutoImageProcessor
@@ -7,29 +8,68 @@ import torch
7
 
8
  # Global state to store loaded model + processor
9
  state = {
 
10
  "model": None,
11
  "processor": None,
12
  "repo_id": None,
13
  }
14
 
15
 
16
- def plot_similarity_heatmap(sim_array: np.ndarray):
17
  """
18
- sim_array: 2D numpy array of shape (h, w)
19
- Returns a PIL image that can be displayed in Gradio
20
  """
21
- fig, ax = plt.subplots(figsize=(5, 5))
22
- cax = ax.imshow(sim_array, cmap='viridis')
23
- ax.set_xticks([])
24
- ax.set_yticks([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- fig.colorbar(cax)
27
- fig.canvas.draw()
28
- img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
29
- img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
30
 
31
- plt.close(fig)
32
- return img
 
 
 
 
 
 
33
 
34
 
35
  def load_model(repo_id: str, revision: str = None):
@@ -44,6 +84,8 @@ def load_model(repo_id: str, revision: str = None):
44
  model.to("cuda")
45
  else:
46
  model.to("cpu")
 
 
47
  # Store in global state
48
  state["model"] = model
49
  state["processor"] = processor
@@ -58,10 +100,21 @@ def display_image(image: Image):
58
  """
59
  return image
60
 
 
 
 
 
 
 
 
 
61
  # Build the Gradio interface
62
  with gr.Blocks() as demo:
63
  gr.Markdown("# Dynamic ViT Loader Template")
64
 
 
 
 
65
  with gr.Row():
66
  repo_input = gr.Textbox(label="Hugging Face model repo ID", placeholder="e.g. google/vit-base-patch16-224")
67
  revision_input = gr.Textbox(label="Revision (optional)", placeholder="branch, tag, or commit hash")
@@ -72,13 +125,13 @@ with gr.Blocks() as demo:
72
  image_output = gr.Image(label="Displayed Image")
73
 
74
  # cos-sim visualization:
75
- # sim_array is your (h, w) numpy array
76
- sim_array = np.random.normal((128, 128))
77
- heatmap_img = plot_similarity_heatmap(sim_array)
78
- gr.Image(value=heatmap_img, label="Cosine Similarity Heatmap")
79
 
80
  # Button clicks / image upload handlers
81
  load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
82
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
83
 
 
 
 
84
  demo.launch()
 
1
  import matplotlib.pyplot as plt
2
+ import matplotlib.cm as cm
3
  import numpy as np
4
  import gradio as gr
5
  from transformers import AutoModel, AutoImageProcessor
 
8
 
9
  # Global state to store loaded model + processor
10
  state = {
11
+ "model_type": None,
12
  "model": None,
13
  "processor": None,
14
  "repo_id": None,
15
  }
16
 
17
 
18
def similarity_heatmap(image):
    """
    Compute the cosine similarity between the CLS token and each patch token.

    Uses the model and processor currently stored in the global ``state``;
    the caller is responsible for making sure a model has been loaded.

    Args:
        image: Input image in any format the loaded processor accepts.

    Returns:
        A numpy array of shape (H_patch, W_patch) with one cosine-similarity
        value per patch, laid out on the patch grid of the processed image.
    """
    model, processor = state["model"], state["processor"]

    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(model.device)  # shape: (1, 3, H, W)

    # get ViT patch size (from model config)
    patch_size = model.config.patch_size  # usually 16

    # Patch grid of the *processed* image (needed to fold the token sequence
    # back into a 2D map)
    H_patch = pixel_values.shape[2] // patch_size
    W_patch = pixel_values.shape[3] // patch_size

    with torch.no_grad():
        outputs = model(pixel_values)
        last_hidden_state = outputs.last_hidden_state  # (1, seq_len, hidden_dim)
        cls_token = last_hidden_state[:, 0, :]  # shape: (1, hidden_dim)
        # NOTE(review): assumes token 0 is CLS and all remaining tokens are
        # patches (no register/distillation tokens) - confirm per model type.
        patch_tokens = last_hidden_state[:, 1:, :]  # shape: (1, num_patches, hidden_dim)

        # cosine_similarity clamps the norms with a small eps, so an all-zero
        # token yields 0 instead of NaN from a division by zero.
        cos_sim = torch.nn.functional.cosine_similarity(
            cls_token.unsqueeze(1), patch_tokens, dim=-1
        )  # shape: (1, num_patches)

    cos_sim = cos_sim.reshape((H_patch, W_patch))
    # The tensor lives on model.device (possibly CUDA); it must be moved to
    # host memory before it can be converted to a numpy array.
    return cos_sim.detach().cpu().numpy()
46
+
47
def overlay_cosine_grid_on_image(cos_grid: np.ndarray, image: Image.Image, alpha=0.5, colormap="viridis"):
    """
    Blend a colour-mapped cosine-similarity grid over an image.

    Args:
        cos_grid: (H_patch, W_patch) numpy array of cosine similarities.
        image: PIL.Image to overlay on. A numpy image array is also accepted
            and converted to a PIL image first.
        alpha: Blending factor passed to ``Image.blend`` (0 keeps only the
            image, 1 keeps only the heatmap).
        colormap: Matplotlib colormap name.

    Returns:
        An RGBA PIL image of the same size as ``image``.
    """
    # NOTE(review): the Gradio image component's `type` is not visible here;
    # if it delivers numpy arrays, `.size` / `.convert` below would misbehave,
    # so normalise to a PIL image up front.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Normalize cosine values to [0, 1] for the colormap; the epsilon guards
    # against a constant grid (max == min).
    norm_grid = (cos_grid - cos_grid.min()) / (cos_grid.max() - cos_grid.min() + 1e-8)

    # Apply colormap. `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9;
    # `plt.get_cmap` is the supported lookup-by-name entry point.
    cmap = plt.get_cmap(colormap)
    heatmap_rgba = cmap(norm_grid)  # shape: (H_patch, W_patch, 4)

    # Convert to RGB 0-255 (drop the colormap's alpha channel)
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    heatmap_img = Image.fromarray(heatmap_rgb)

    # Resize heatmap to match original image size
    heatmap_resized = heatmap_img.resize(image.size, resample=Image.BILINEAR)

    # Blend with original image; both operands must share mode and size
    blended = Image.blend(image.convert("RGBA"), heatmap_resized.convert("RGBA"), alpha=alpha)

    return blended
72
+
73
 
74
 
75
  def load_model(repo_id: str, revision: str = None):
 
84
  model.to("cuda")
85
  else:
86
  model.to("cpu")
87
+
88
+ model.eval()
89
  # Store in global state
90
  state["model"] = model
91
  state["processor"] = processor
 
100
  """
101
  return image
102
 
103
def visualize_cosine_heatmap(image: Image):
    """
    Gradio handler: compute the CLS-to-patch cosine-similarity heatmap for
    ``image`` and return it blended over the image.

    Returns None (leaving the output component empty) when no model has been
    loaded yet or when no image has been provided, instead of raising inside
    the processor/model call.
    """
    if state["model"] is None or image is None:
        return None  # or placeholder image

    cos_grid = similarity_heatmap(image)
    return overlay_cosine_grid_on_image(cos_grid, image)
110
+
111
  # Build the Gradio interface
112
  with gr.Blocks() as demo:
113
  gr.Markdown("# Dynamic ViT Loader Template")
114
 
115
+ # TODO: Add drop-down menu (or something else) for user to allow choosing model type (e.g. DINOv2, Google ViT-Base etc.)
116
+ # ...
117
+
118
  with gr.Row():
119
  repo_input = gr.Textbox(label="Hugging Face model repo ID", placeholder="e.g. google/vit-base-patch16-224")
120
  revision_input = gr.Textbox(label="Revision (optional)", placeholder="branch, tag, or commit hash")
 
125
  image_output = gr.Image(label="Displayed Image")
126
 
127
  # cos-sim visualization:
128
+ heatmap_output = gr.Image(label="Cosine Similarity Heatmap")
 
 
 
129
 
130
  # Button clicks / image upload handlers
131
  load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
132
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
133
 
134
+ compute_btn = gr.Button("Compute Heatmap")
135
+ compute_btn.click(fn=visualize_cosine_heatmap, inputs=image_input, outputs=heatmap_output)
136
+
137
  demo.launch()