Felix Konrad committed on
Commit
41c94d8
Β·
1 Parent(s): 5fa1af0

Please work

Browse files
Files changed (1) hide show
  1. app.py +87 -73
app.py CHANGED
@@ -6,9 +6,6 @@ import gradio as gr
6
  from transformers import AutoModel, AutoImageProcessor
7
  from PIL import Image
8
  import torch
9
- from huggingface_hub import hf_hub_download
10
-
11
-
12
 
13
  os.environ["HF_HUB_OFFLINE"] = "0"
14
 
@@ -20,10 +17,9 @@ state = {
20
  "repo_id": None,
21
  }
22
 
23
-
24
  def similarity_heatmap(image):
25
  """
26
- ...
27
  """
28
  model, processor = state["model"], state["processor"]
29
 
@@ -76,117 +72,135 @@ def overlay_cosine_grid_on_image(cos_grid: np.ndarray, image: Image.Image, alpha
76
 
77
  return blended
78
 
79
-
80
-
81
def load_model_dropdown(choice: str):
    """Load one of the predefined models listed in SUPPORTED_MODELS.

    Args:
        choice: Key into SUPPORTED_MODELS identifying the model to load.

    Returns:
        A human-readable status message (success or error text).
    """
    repo_path = SUPPORTED_MODELS[choice]
    try:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        loaded_model = AutoModel.from_pretrained(repo_path)
        loaded_processor = AutoImageProcessor.from_pretrained(repo_path)
        loaded_model.to(target_device)
        loaded_model.eval()

        # Publish the freshly loaded objects to the shared app state.
        state["model"] = loaded_model
        state["processor"] = loaded_processor
        state["repo_id"] = choice
        return f"Successfully loaded model: {choice}"
    except Exception as e:
        return f"Error loading model {choice}: {e}"
98
-
99
-
100
-
101
  def load_model(repo_id: str, revision: str = None):
102
  """
103
- Load a Hugging Face model + processor from Hub using huggingface_hub.
104
  Works with any public repo_id.
105
  """
106
  try:
107
- # Explicitly download model + processor files to local cache
108
- model_path = hf_hub_download(
109
- repo_id=repo_id,
110
- revision=revision,
111
- filename="pytorch_model.bin", # default filename for weights
112
- cache_dir="./model_cache"
113
- )
114
- config_path = hf_hub_download(
115
- repo_id=repo_id,
116
  revision=revision,
117
- filename="config.json",
118
- cache_dir="./model_cache"
119
  )
120
- processor_path = hf_hub_download(
121
- repo_id=repo_id,
 
122
  revision=revision,
123
- filename="preprocessor_config.json",
124
- cache_dir="./model_cache"
125
  )
126
 
127
- # Load with transformers (it will reuse the local cache)
128
- model = AutoModel.from_pretrained(repo_id, revision=revision, cache_dir="./model_cache")
129
- processor = AutoImageProcessor.from_pretrained(repo_id, revision=revision, cache_dir="./model_cache")
130
-
131
- if torch.cuda.is_available():
132
- model.to("cuda")
133
- else:
134
- model.to("cpu")
135
-
136
  model.eval()
 
 
 
 
 
 
137
  state["model"] = model
138
  state["processor"] = processor
139
  state["repo_id"] = repo_id
140
- return f"Successfully loaded model '{repo_id}'"
 
 
 
 
 
 
 
 
 
 
141
  except Exception as e:
142
- return f"Error loading model: {e}"
143
 
144
  def display_image(image: Image):
145
  """
146
- Simply returns the uploaded image (you can process it later).
147
  """
148
  return image
149
 
150
  def visualize_cosine_heatmap(image: Image):
 
 
 
151
  if state["model"] is None:
152
- return None # or placeholder image
153
-
154
- cos_grid = similarity_heatmap(image)
155
- blended = overlay_cosine_grid_on_image(cos_grid, image)
156
- return blended
157
 
 
 
 
 
 
 
 
158
 
159
  # Gradio UI
160
- with gr.Blocks() as demo:
161
  gr.Markdown("# ViT CLS-Visualizer")
162
  gr.Markdown(
163
  "Enter the Hugging Face model repo ID (must be public), upload an image, "
164
  "and visualize the cosine similarity between the CLS token and patches."
165
  )
 
 
 
 
 
 
 
166
 
167
  with gr.Row():
168
  repo_input = gr.Textbox(
169
  label="Hugging Face Model Repo ID",
170
- placeholder="e.g. google/vit-base-patch16-224"
 
171
  )
172
  revision_input = gr.Textbox(
173
  label="Revision (optional)",
174
  placeholder="branch, tag, or commit hash"
175
  )
176
- load_btn = gr.Button("Load Model")
 
177
  load_status = gr.Textbox(label="Model Status", interactive=False)
178
 
179
  with gr.Row():
180
- image_input = gr.Image(type="pil", label="Upload Image")
181
- image_output = gr.Image(label="Uploaded Image")
182
-
183
- with gr.Row():
184
- compute_btn = gr.Button("Compute Heatmap")
185
- heatmap_output = gr.Image(label="Cosine Similarity Heatmap")
 
186
 
187
  # Events
188
- load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
189
- image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
190
- compute_btn.click(fn=visualize_cosine_heatmap, inputs=image_input, outputs=heatmap_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- demo.launch()
 
 
6
  from transformers import AutoModel, AutoImageProcessor
7
  from PIL import Image
8
  import torch
 
 
 
9
 
10
  os.environ["HF_HUB_OFFLINE"] = "0"
11
 
 
17
  "repo_id": None,
18
  }
19
 
 
20
  def similarity_heatmap(image):
21
  """
22
+ Compute cosine similarity between CLS token and patch tokens
23
  """
24
  model, processor = state["model"], state["processor"]
25
 
 
72
 
73
  return blended
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
def load_model(repo_id: str, revision: str = None):
    """
    Load a Hugging Face model + processor from the Hub.

    Works with any public repo_id.

    Args:
        repo_id: Hub repository ID, e.g. "google/vit-base-patch16-224".
        revision: Optional branch, tag, or commit hash. A blank/whitespace
            string (what an empty Gradio textbox submits) is treated as None.

    Returns:
        A human-readable status message describing success or failure.
    """
    try:
        # Normalize the revision: a Gradio textbox yields "" when left empty,
        # and from_pretrained() expects None rather than an empty string.
        # BUG FIX: the previous check (`if revision and revision.strip() == ""`)
        # could never fire for "" because "" is falsy; "" was passed through.
        if revision is not None and revision.strip() == "":
            revision = None

        # Load model and processor directly (they handle caching automatically)
        model = AutoModel.from_pretrained(
            repo_id,
            revision=revision,
            cache_dir="./model_cache",
            trust_remote_code=True  # Some models might need this
        )
        processor = AutoImageProcessor.from_pretrained(
            repo_id,
            revision=revision,
            cache_dir="./model_cache",
            trust_remote_code=True
        )

        # Move to appropriate device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()

        # Validate it's a Vision Transformer; bail out before touching state
        # so any previously loaded model remains usable.
        if not hasattr(model.config, 'patch_size'):
            return f"❌ Model '{repo_id}' doesn't appear to be a Vision Transformer (no patch_size in config)"

        # Update global state
        state["model"] = model
        state["processor"] = processor
        state["repo_id"] = repo_id
        state["model_type"] = "custom"

        return f"✅ Successfully loaded model '{repo_id}' on {device}"

    except OSError as e:
        # Map the common Hub failure modes to friendlier messages.
        if "Repository not found" in str(e):
            return f"❌ Repository '{repo_id}' not found. Please check the repo ID."
        elif "offline" in str(e).lower():
            return f"❌ Network error. Please check your internet connection."
        else:
            return f"❌ Error accessing model: {str(e)}"
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
126
 
127
def display_image(image: Image):
    """Echo the uploaded image back to the UI unchanged."""
    return image
132
 
133
def visualize_cosine_heatmap(image: Image):
    """
    Generate and overlay a cosine-similarity heatmap on the input image.

    Returns None when no model has been loaded yet, or when heatmap
    generation raises (the error is printed rather than propagated so the
    UI callback never crashes).
    """
    if state["model"] is None:
        # Nothing to visualize without a loaded model.
        return None

    try:
        return overlay_cosine_grid_on_image(similarity_heatmap(image), image)
    except Exception as e:
        print(f"Error generating heatmap: {e}")
        return None
147
 
148
# Gradio UI
with gr.Blocks(title="ViT CLS Visualizer") as demo:
    gr.Markdown("# ViT CLS-Visualizer")
    gr.Markdown(
        "Enter the Hugging Face model repo ID (must be public), upload an image, "
        "and visualize the cosine similarity between the CLS token and patches."
    )

    gr.Markdown("### Popular Vision Transformer models to try:")
    gr.Markdown(
        "- `google/vit-base-patch16-224`\n"
        "- `facebook/deit-base-distilled-patch16-224`\n"
        "- `microsoft/dit-base`"
    )

    # Model-selection row: repo ID, optional revision, and the load trigger.
    with gr.Row():
        repo_input = gr.Textbox(
            label="Hugging Face Model Repo ID",
            placeholder="e.g. google/vit-base-patch16-224",
            value="google/vit-base-patch16-224",
        )
        revision_input = gr.Textbox(
            label="Revision (optional)",
            placeholder="branch, tag, or commit hash",
        )
        load_btn = gr.Button("Load Model", variant="primary")

    load_status = gr.Textbox(label="Model Status", interactive=False)

    # Image row: upload/preview on the left, heatmap controls on the right.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            image_output = gr.Image(label="Uploaded Image")

        with gr.Column():
            compute_btn = gr.Button("Compute Heatmap", variant="primary")
            heatmap_output = gr.Image(label="Cosine Similarity Heatmap")

    # Events
    load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
    image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
    compute_btn.click(fn=visualize_cosine_heatmap, inputs=image_input, outputs=heatmap_output)

if __name__ == "__main__":
    demo.launch()