Felix Konrad committed on
Commit
5fa1af0
·
1 Parent(s): 1323bb7

Using hf_hub_download.

Browse files
Files changed (1) hide show
  1. app.py +50 -25
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import matplotlib.pyplot as plt
2
  import matplotlib.cm as cm
3
  import numpy as np
@@ -5,7 +6,8 @@ import gradio as gr
5
  from transformers import AutoModel, AutoImageProcessor
6
  from PIL import Image
7
  import torch
8
- import os
 
9
 
10
 
11
  os.environ["HF_HUB_OFFLINE"] = "0"
@@ -18,13 +20,6 @@ state = {
18
  "repo_id": None,
19
  }
20
 
21
- # Predefined supported models (must also exist locally in your Space repo)
22
- SUPPORTED_MODELS = {
23
- "Google ViT-Base (patch16-224)": "./models/vit-base-patch16-224",
24
- "Facebook DINO (ViT-S/16)": "./models/dino-vits16",
25
- "OpenAI CLIP (ViT-B/32)": "./models/clip-vit-base-patch32",
26
- }
27
-
28
 
29
  def similarity_heatmap(image):
30
  """
@@ -105,19 +100,40 @@ def load_model_dropdown(choice: str):
105
 
106
  def load_model(repo_id: str, revision: str = None):
107
  """
108
- Load a Hugging Face model and processor from a repo ID.
 
109
  """
110
  try:
111
- model = AutoModel.from_pretrained(repo_id, revision=revision, trust_remote_code=False)
112
- processor = AutoImageProcessor.from_pretrained(repo_id, revision=revision, trust_remote_code=False)
113
- # Move model to CPU/GPU if needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  if torch.cuda.is_available():
115
  model.to("cuda")
116
  else:
117
  model.to("cpu")
118
 
119
- model.eval()
120
- # Store in global state
121
  state["model"] = model
122
  state["processor"] = processor
123
  state["repo_id"] = repo_id
@@ -143,25 +159,34 @@ def visualize_cosine_heatmap(image: Image):
143
  # Gradio UI
144
  with gr.Blocks() as demo:
145
  gr.Markdown("# ViT CLS-Visualizer")
 
 
 
 
146
 
147
  with gr.Row():
148
- model_choice = gr.Dropdown(
149
- choices=list(SUPPORTED_MODELS.keys()),
150
- label="Choose a Vision Transformer model",
151
- value=list(SUPPORTED_MODELS.keys())[0],
 
 
 
152
  )
153
  load_btn = gr.Button("Load Model")
154
  load_status = gr.Textbox(label="Model Status", interactive=False)
155
 
156
- image_input = gr.Image(type="pil", label="Upload Image")
157
- image_output = gr.Image(label="Uploaded Image")
158
- heatmap_output = gr.Image(label="Cosine Similarity Heatmap")
 
 
 
 
159
 
160
  # Events
161
- load_btn.click(fn=load_model_dropdown, inputs=model_choice, outputs=load_status)
162
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
163
-
164
- compute_btn = gr.Button("Compute Heatmap")
165
  compute_btn.click(fn=visualize_cosine_heatmap, inputs=image_input, outputs=heatmap_output)
166
 
167
- demo.launch()
 
1
+ import os
2
  import matplotlib.pyplot as plt
3
  import matplotlib.cm as cm
4
  import numpy as np
 
6
  from transformers import AutoModel, AutoImageProcessor
7
  from PIL import Image
8
  import torch
9
+ from huggingface_hub import hf_hub_download
10
+
11
 
12
 
13
  os.environ["HF_HUB_OFFLINE"] = "0"
 
20
  "repo_id": None,
21
  }
22
 
 
 
 
 
 
 
 
23
 
24
  def similarity_heatmap(image):
25
  """
 
100
 
101
  def load_model(repo_id: str, revision: str = None):
102
  """
103
+ Load a Hugging Face model + processor from Hub using huggingface_hub.
104
+ Works with any public repo_id.
105
  """
106
  try:
107
+ # Explicitly download model + processor files to local cache
108
+ model_path = hf_hub_download(
109
+ repo_id=repo_id,
110
+ revision=revision,
111
+ filename="pytorch_model.bin", # default filename for weights
112
+ cache_dir="./model_cache"
113
+ )
114
+ config_path = hf_hub_download(
115
+ repo_id=repo_id,
116
+ revision=revision,
117
+ filename="config.json",
118
+ cache_dir="./model_cache"
119
+ )
120
+ processor_path = hf_hub_download(
121
+ repo_id=repo_id,
122
+ revision=revision,
123
+ filename="preprocessor_config.json",
124
+ cache_dir="./model_cache"
125
+ )
126
+
127
+ # Load with transformers (it will reuse the local cache)
128
+ model = AutoModel.from_pretrained(repo_id, revision=revision, cache_dir="./model_cache")
129
+ processor = AutoImageProcessor.from_pretrained(repo_id, revision=revision, cache_dir="./model_cache")
130
+
131
  if torch.cuda.is_available():
132
  model.to("cuda")
133
  else:
134
  model.to("cpu")
135
 
136
+ model.eval()
 
137
  state["model"] = model
138
  state["processor"] = processor
139
  state["repo_id"] = repo_id
 
159
  # Gradio UI
160
  with gr.Blocks() as demo:
161
  gr.Markdown("# ViT CLS-Visualizer")
162
+ gr.Markdown(
163
+ "Enter the Hugging Face model repo ID (must be public), upload an image, "
164
+ "and visualize the cosine similarity between the CLS token and patches."
165
+ )
166
 
167
  with gr.Row():
168
+ repo_input = gr.Textbox(
169
+ label="Hugging Face Model Repo ID",
170
+ placeholder="e.g. google/vit-base-patch16-224"
171
+ )
172
+ revision_input = gr.Textbox(
173
+ label="Revision (optional)",
174
+ placeholder="branch, tag, or commit hash"
175
  )
176
  load_btn = gr.Button("Load Model")
177
  load_status = gr.Textbox(label="Model Status", interactive=False)
178
 
179
+ with gr.Row():
180
+ image_input = gr.Image(type="pil", label="Upload Image")
181
+ image_output = gr.Image(label="Uploaded Image")
182
+
183
+ with gr.Row():
184
+ compute_btn = gr.Button("Compute Heatmap")
185
+ heatmap_output = gr.Image(label="Cosine Similarity Heatmap")
186
 
187
  # Events
188
+ load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
189
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
 
 
190
  compute_btn.click(fn=visualize_cosine_heatmap, inputs=image_input, outputs=heatmap_output)
191
 
192
+ demo.launch()