Felix Konrad commited on
Commit
1323bb7
·
1 Parent(s): 2ec5753

Added HF_HUB_OFFLINE env variable.

Browse files
Files changed (1) hide show
  1. app.py +45 -15
app.py CHANGED
@@ -5,8 +5,12 @@ import gradio as gr
5
  from transformers import AutoModel, AutoImageProcessor
6
  from PIL import Image
7
  import torch
 
8
 
9
- # Global state to store loaded model + processor
 
 
 
10
  state = {
11
  "model_type": None,
12
  "model": None,
@@ -14,6 +18,13 @@ state = {
14
  "repo_id": None,
15
  }
16
 
 
 
 
 
 
 
 
17
 
18
  def similarity_heatmap(image):
19
  """
@@ -72,6 +83,26 @@ def overlay_cosine_grid_on_image(cos_grid: np.ndarray, image: Image.Image, alpha
72
 
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def load_model(repo_id: str, revision: str = None):
76
  """
77
  Load a Hugging Face model and processor from a repo ID.
@@ -108,30 +139,29 @@ def visualize_cosine_heatmap(image: Image):
108
  blended = overlay_cosine_grid_on_image(cos_grid, image)
109
  return blended
110
 
111
- # Build the Gradio interface
 
112
  with gr.Blocks() as demo:
113
- gr.Markdown("# Dynamic ViT Loader Template")
114
-
115
- # TODO: Add drop-down menu (or something else) for user to allow choosing model type (e.g. DINOv2, Google ViT-Base etc.)
116
- # ...
117
 
118
  with gr.Row():
119
- repo_input = gr.Textbox(label="Hugging Face model repo ID", placeholder="e.g. google/vit-base-patch16-224")
120
- revision_input = gr.Textbox(label="Revision (optional)", placeholder="branch, tag, or commit hash")
 
 
 
121
  load_btn = gr.Button("Load Model")
122
  load_status = gr.Textbox(label="Model Status", interactive=False)
123
-
124
- image_input = gr.Image(type="pil", label="Upload Image")
125
- image_output = gr.Image(label="Displayed Image")
126
 
127
- # cos-sim visualization:
 
128
  heatmap_output = gr.Image(label="Cosine Similarity Heatmap")
129
 
130
- # Button clicks / image upload handlers
131
- load_btn.click(fn=load_model, inputs=[repo_input, revision_input], outputs=load_status)
132
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
133
 
134
  compute_btn = gr.Button("Compute Heatmap")
135
  compute_btn.click(fn=visualize_cosine_heatmap, inputs=image_input, outputs=heatmap_output)
136
 
137
- demo.launch()
 
5
  from transformers import AutoModel, AutoImageProcessor
6
  from PIL import Image
7
  import torch
8
+ import os
9
 
10
+
11
+ os.environ["HF_HUB_OFFLINE"] = "0"
12
+
13
+ # Global state to store loaded model + processors
14
  state = {
15
  "model_type": None,
16
  "model": None,
 
18
  "repo_id": None,
19
  }
20
 
21
+ # Predefined supported models (must also exist locally in your Space repo)
22
+ SUPPORTED_MODELS = {
23
+ "Google ViT-Base (patch16-224)": "./models/vit-base-patch16-224",
24
+ "Facebook DINO (ViT-S/16)": "./models/dino-vits16",
25
+ "OpenAI CLIP (ViT-B/32)": "./models/clip-vit-base-patch32",
26
+ }
27
+
28
 
29
  def similarity_heatmap(image):
30
  """
 
83
 
84
 
85
 
86
+ def load_model_dropdown(choice: str):
87
+ """
88
+ Load one of the predefined models.
89
+ """
90
+ repo_path = SUPPORTED_MODELS[choice]
91
+ try:
92
+ model = AutoModel.from_pretrained(repo_path)
93
+ processor = AutoImageProcessor.from_pretrained(repo_path)
94
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
95
+ model.eval()
96
+
97
+ state["model"] = model
98
+ state["processor"] = processor
99
+ state["repo_id"] = choice
100
+ return f"Successfully loaded model: {choice}"
101
+ except Exception as e:
102
+ return f"Error loading model {choice}: {e}"
103
+
104
+
105
+
106
  def load_model(repo_id: str, revision: str = None):
107
  """
108
  Load a Hugging Face model and processor from a repo ID.
 
139
  blended = overlay_cosine_grid_on_image(cos_grid, image)
140
  return blended
141
 
142
+
143
+ # Gradio UI
144
  with gr.Blocks() as demo:
145
+ gr.Markdown("# ViT CLS-Visualizer")
 
 
 
146
 
147
  with gr.Row():
148
+ model_choice = gr.Dropdown(
149
+ choices=list(SUPPORTED_MODELS.keys()),
150
+ label="Choose a Vision Transformer model",
151
+ value=list(SUPPORTED_MODELS.keys())[0],
152
+ )
153
  load_btn = gr.Button("Load Model")
154
  load_status = gr.Textbox(label="Model Status", interactive=False)
 
 
 
155
 
156
+ image_input = gr.Image(type="pil", label="Upload Image")
157
+ image_output = gr.Image(label="Uploaded Image")
158
  heatmap_output = gr.Image(label="Cosine Similarity Heatmap")
159
 
160
+ # Events
161
+ load_btn.click(fn=load_model_dropdown, inputs=model_choice, outputs=load_status)
162
  image_input.change(fn=display_image, inputs=image_input, outputs=image_output)
163
 
164
  compute_btn = gr.Button("Compute Heatmap")
165
  compute_btn.click(fn=visualize_cosine_heatmap, inputs=image_input, outputs=heatmap_output)
166
 
167
+ demo.launch()