limhyesu98 committed
Commit 4cf80d2 · Parent: 5e15f55
Files changed (4)
  1. .gitattributes +21 -0
  2. README.md +12 -3
  3. app.py +503 -0
  4. requirements.txt +3 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
+<<<<<<< HEAD
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+=======
+chrismas-imagnet.pkl filter=lfs diff=lfs merge=lfs -text
+dog-imagenet.pkl filter=lfs diff=lfs merge=lfs -text
+dog-mmvp.pkl filter=lfs diff=lfs merge=lfs -text
+golden_gate_bridge.pkl filter=lfs diff=lfs merge=lfs -text
+hen-imagenet-r.pkl filter=lfs diff=lfs merge=lfs -text
+hen-imagenet.pkl filter=lfs diff=lfs merge=lfs -text
+kayaking-ucf.pkl filter=lfs diff=lfs merge=lfs -text
+owl-imagenet-sketch.pkl filter=lfs diff=lfs merge=lfs -text
+owl-imagenet.pkl filter=lfs diff=lfs merge=lfs -text
+paphiopedilum-micranthum.pkl filter=lfs diff=lfs merge=lfs -text
+phalaenopsis-aphrodite.pkl filter=lfs diff=lfs merge=lfs -text
+text-1.pkl filter=lfs diff=lfs merge=lfs -text
+text-2.pkl filter=lfs diff=lfs merge=lfs -text
+text-3.pkl filter=lfs diff=lfs merge=lfs -text
+vegetation-land-eurosat.pkl filter=lfs diff=lfs merge=lfs -text
+data/sae_data/mean_act_values_caltech101.pkl.gz filter=lfs diff=lfs merge=lfs -text
+data/sae_data/mean_act_values_imagenet-sketch.pkl.gz filter=lfs diff=lfs merge=lfs -text
+data/sae_data/mean_act_values_imagenet.pkl.gz filter=lfs diff=lfs merge=lfs -text
+>>>>>>> master
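
Note: as committed, .gitattributes (and README.md below) still contains unresolved merge conflict markers (<<<<<<< HEAD / ======= / >>>>>>> master). A minimal resolution sketch, assuming both sides should be kept (the default LFS patterns plus the new .pkl entries):

    # remove the three marker lines from each file, keeping the intended content
    git add .gitattributes README.md
    git commit -m "Remove leftover merge conflict markers"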
README.md CHANGED
@@ -1,10 +1,19 @@
 ---
+<<<<<<< HEAD
 title: Patchsae Demo
-emoji: 📚
-colorFrom: blue
-colorTo: purple
+emoji: 😻
+colorFrom: red
+colorTo: gray
 sdk: gradio
 sdk_version: 5.8.0
+=======
+title: Paper14240
+emoji: 📈
+colorFrom: blue
+colorTo: pink
+sdk: gradio
+sdk_version: 5.5.0
+>>>>>>> master
 app_file: app.py
 pinned: false
 ---
app.py ADDED
@@ -0,0 +1,503 @@
+import gzip
+import os
+import pickle
+from glob import glob
+from time import sleep
+
+import gradio as gr
+import numpy as np
+import plotly.graph_objects as go
+import torch
+from PIL import Image, ImageDraw
+from plotly.subplots import make_subplots
+
+IMAGE_SIZE = 400
+DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
+GRID_NUM = 14
+pkl_root = "./data/out"
+preloaded_data = {}
+
+
+def preload_activation(image_name):
+    # Warm the in-memory cache with the activation pickle of every model.
+    for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+        image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
+        with gzip.open(image_file, "rb") as f:
+            preloaded_data[model] = pickle.load(f)
+
+
+def get_activation_distribution(image_name: str, model_type: str):
+    activation = get_data(image_name, model_type)[0]
+
+    # Zero out "noisy" SAE latents: those whose mean activation on ImageNet
+    # exceeds 0.1 fire almost everywhere and would dominate the plots.
+    noisy_features_indices = (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
+    activation[:, noisy_features_indices] = 0
+
+    return activation
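+# Shape sketch (an inference from how `activation` is indexed in this file,
+# not something the data asserts): `get_data(...)[0]` behaves as an array of
+# shape (1 + GRID_NUM * GRID_NUM, n_latents) = (197, n_latents), i.e. one CLS
+# token followed by 196 patch tokens. For example:
+#     activation = get_activation_distribution("christmas-imagenet", "CLIP")
+#     image_level = activation.mean(0)  # average over all tokens
+#     patch_level = activation[1:]      # per-patch activations only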
+
+
+def get_grid_loc(evt, image):
+    # Get click coordinates
+    x, y = evt._data["index"][0], evt._data["index"][1]
+
+    cell_width = image.width // GRID_NUM
+    cell_height = image.height // GRID_NUM
+
+    grid_x = x // cell_width
+    grid_y = y // cell_height
+    return grid_x, grid_y, cell_width, cell_height
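+# Worked example: with IMAGE_SIZE = 400 and GRID_NUM = 14, each cell is
+# 400 // 14 = 28 px, so a click at (x=210, y=190) falls in cell
+# (grid_x, grid_y) = (210 // 28, 190 // 28) = (7, 6); the handlers below then
+# read token index 6 * 14 + 7 + 1 = 92, where the +1 skips the CLS token.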
+
+
+def highlight_grid(evt: gr.EventData, image_name):
+    image = data_dict[image_name]["image"]
+    grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+
+    # Outline the clicked cell in red on a copy of the image.
+    highlighted_image = image.copy()
+    draw = ImageDraw.Draw(highlighted_image)
+    box = [grid_x * cell_width, grid_y * cell_height, (grid_x + 1) * cell_width, (grid_y + 1) * cell_height]
+    draw.rectangle(box, outline="red", width=3)
+
+    return highlighted_image
+
+
+def load_image(img_name):
+    return Image.open(data_dict[img_name]["image_path"]).resize((IMAGE_SIZE, IMAGE_SIZE))
+
+
+def plot_activations(
+    all_activation, tile_activations=None, grid_x=None, grid_y=None, top_k=5, colors=("blue", "cyan"), model_name="CLIP"
+):
+    fig = go.Figure()
+
+    def _add_scatter_with_annotation(fig, activations, model_name, color, label):
+        fig.add_trace(
+            go.Scatter(
+                x=np.arange(len(activations)),
+                y=activations,
+                mode="lines",
+                name=label,
+                line=dict(color=color, dash="solid"),
+                showlegend=True,
+            )
+        )
+        # Annotate the indices of the top-k most active SAE latents.
+        top_neurons = np.argsort(activations)[::-1][:top_k]
+        for idx in top_neurons:
+            fig.add_annotation(
+                x=idx,
+                y=activations[idx],
+                text=str(idx),
+                showarrow=True,
+                arrowhead=2,
+                ax=0,
+                ay=-15,
+                arrowcolor=color,
+                opacity=0.7,
+            )
+        return fig
+
+    label = f"{model_name.split('-')[0]} Image-level"
+    fig = _add_scatter_with_annotation(fig, all_activation, model_name, colors[0], label)
+    if tile_activations is not None:
+        label = f"{model_name.split('-')[0]} Tile ({grid_x}, {grid_y})"
+        fig = _add_scatter_with_annotation(fig, tile_activations, model_name, colors[1], label)
+
+    fig.update_layout(
+        title="Activation Distribution",
+        xaxis_title="SAE latent index",
+        yaxis_title="Activation Value",
+        template="plotly_white",
+    )
+    fig.update_layout(legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5))
+
+    return fig
+
+
+def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
+    activation = get_activation_distribution(selected_image, model_name)
+    all_activation = activation.mean(0)
+
+    tile_activations = None
+    grid_x = None
+    grid_y = None
+
+    if evt is not None:
+        if evt._data is not None:
+            image = data_dict[selected_image]["image"]
+            grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+            # +1 skips the CLS token at index 0.
+            token_idx = grid_y * GRID_NUM + grid_x + 1
+            tile_activations = activation[token_idx]
+
+    fig = plot_activations(
+        all_activation, tile_activations, grid_x, grid_y, top_k=5, model_name=model_name, colors=colors
+    )
+    return fig
+
+
+def plot_activation_distribution(evt: gr.EventData, selected_image: str, model_name: str):
+    # Stack the CLIP plot above the MaPLE plot with a shared x-axis.
+    fig = make_subplots(
+        rows=2,
+        cols=1,
+        shared_xaxes=True,
+        subplot_titles=["CLIP Activation", f"{model_name} Activation"],
+    )
+
+    fig_clip = get_activations(evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef"))
+    fig_maple = get_activations(evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4"))
+
+    def _attach_fig(fig, sub_fig, row, col, yref):
+        for trace in sub_fig.data:
+            fig.add_trace(trace, row=row, col=col)
+
+        for annotation in sub_fig.layout.annotations:
+            annotation.update(yref=yref)
+            fig.add_annotation(annotation)
+        return fig
+
+    fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
+    fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
+
+    fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
+    fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
+    fig.update_yaxes(title_text="Activation Value", row=1, col=1)
+    fig.update_yaxes(title_text="Activation Value", row=2, col=1)
+    fig.update_layout(
+        template="plotly_white",
+        showlegend=True,
+        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
+        margin=dict(l=20, r=20, t=40, b=20),
+    )
+
+    return fig
+
+
+def get_segmask(selected_image, slider_value, model_type):
+    image = data_dict[selected_image]["image"]
+    sae_act = get_data(selected_image, model_type)[0]
+    temp = sae_act[:, slider_value]
+    try:
+        # Drop the CLS token, then lay the patch activations out on the grid.
+        mask = torch.Tensor(temp[1:].reshape(GRID_NUM, GRID_NUM)).view(1, 1, GRID_NUM, GRID_NUM)
+    except Exception:
+        # Re-raise so `mask` is never used uninitialized below.
+        print(sae_act.shape, slider_value)
+        raise
+    mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
+    # Min-max normalize to [0, 1]; the epsilon guards against an all-zero mask.
+    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+    base_opacity = 30
+    image_array = np.array(image)[..., :3]
+    rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+    rgba_overlay[..., :3] = image_array[..., :3]
+
+    # Darken the pixels where the latent does not activate at all.
+    darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
+    rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+    rgba_overlay[..., 3] = 255  # Fully opaque
+
+    return rgba_overlay
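+# Usage sketch (illustrative; assumes the demo's ./data assets are present,
+# and 123 is an arbitrary latent index):
+#     overlay = get_segmask("christmas-imagenet", 123, "CLIP")
+#     Image.fromarray(overlay).save("overlay.png")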
+
+
+def get_top_images(slider_value, toggle_btn):
+    def _get_images(dataset_path):
+        top_image_paths = [
+            os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
+            os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
+            os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
+        ]
+        # Fall back to a blank white image when a reference image is missing.
+        top_images = [
+            Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
+            for path in top_image_paths
+        ]
+        return top_images
+
+    if toggle_btn:
+        top_images = _get_images("./data/top_images_masked")
+    else:
+        top_images = _get_images("./data/top_images")
+    return top_images
+
+
+def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
+    slider_value = int(slider_value.split("-")[-1])
+    rgba_overlay = get_segmask(selected_image, slider_value, model_type)
+    top_images = get_top_images(slider_value, toggle_btn)
+
+    act_values = []
+    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+        act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
+        act_value = [str(round(value, 3)) for value in act_value]
+        act_value = " | ".join(act_value)
+        out = f"#### Activation values: {act_value}"
+        act_values.append(out)
+
+    return rgba_overlay, top_images, act_values
+
+
+def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
+    rgba_overlay, top_images, act_values = show_activation_heatmap(selected_image, slider_value, "CLIP", toggle_btn)
+    sleep(0.1)
+    return (rgba_overlay, top_images[0], top_images[1], top_images[2], act_values[0], act_values[1], act_values[2])
+
+
+def show_activation_heatmap_maple(selected_image, slider_value, model_name):
+    slider_value = int(slider_value.split("-")[-1])
+    rgba_overlay = get_segmask(selected_image, slider_value, model_name)
+    sleep(0.1)
+    return rgba_overlay
+
+
+def get_init_radio_options(selected_image, model_name):
+    clip_neuron_dict = {}
+    maple_neuron_dict = {}
+
+    def _get_top_activation(selected_image, model_name, neuron_dict, top_k=5):
+        activations = get_activation_distribution(selected_image, model_name).mean(0)
+        top_neurons = list(np.argsort(activations)[::-1][:top_k])
+        for top_neuron in top_neurons:
+            neuron_dict[top_neuron] = activations[top_neuron]
+        sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
+        return sorted_dict
+
+    clip_neuron_dict = _get_top_activation(selected_image, "CLIP", clip_neuron_dict)
+    maple_neuron_dict = _get_top_activation(selected_image, model_name, maple_neuron_dict)
+
+    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
+
+    return radio_choices
+
+
+def get_radio_names(clip_neuron_dict, maple_neuron_dict):
+    clip_keys = list(clip_neuron_dict.keys())
+    maple_keys = list(maple_neuron_dict.keys())
+
+    # Split the top latents into those shared by both models and those unique to one.
+    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+    clip_only_keys = list(set(clip_keys) - set(maple_keys))
+    maple_only_keys = list(set(maple_keys) - set(clip_keys))
+
+    common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
+    clip_only_keys.sort(reverse=True)
+    maple_only_keys.sort(reverse=True)
+
+    out = []
+    out.extend([f"common-{i}" for i in common_keys[:5]])
+    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
+    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
+
+    return out
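+# For example, get_radio_names({1234: 5.0, 42: 3.1}, {1234: 4.2, 7: 2.8})
+# returns ["common-1234", "CLIP-42", "MaPLE-7"]: the prefix encodes whether a
+# latent is top-activating for both models or for only one of them.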
+
+
+def update_radio_options(evt: gr.EventData, selected_image, model_name):
+    def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
+        top_neurons = list(np.argsort(activations)[::-1][:top_k])
+        for top_neuron in top_neurons:
+            neuron_dict[top_neuron] = activations[top_neuron]
+
+    def _get_top_activation(evt, selected_image, model_name, neuron_dict):
+        all_activation = get_activation_distribution(selected_image, model_name)
+        image_activation = all_activation.mean(0)
+        _sort_and_save_top_k(image_activation, neuron_dict)
+
+        # When a patch was clicked, also include that tile's top latents.
+        if evt is not None:
+            if evt._data is not None and isinstance(evt._data["index"], list):
+                image = data_dict[selected_image]["image"]
+                grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+                token_idx = grid_y * GRID_NUM + grid_x + 1
+                tile_activations = all_activation[token_idx]
+                _sort_and_save_top_k(tile_activations, neuron_dict)
+
+        sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
+        return sorted_dict
+
+    clip_neuron_dict = {}
+    maple_neuron_dict = {}
+    clip_neuron_dict = _get_top_activation(evt, selected_image, "CLIP", clip_neuron_dict)
+    maple_neuron_dict = _get_top_activation(evt, selected_image, model_name, maple_neuron_dict)
+
+    # Reuse get_radio_names to build the common/CLIP/MaPLE label list.
+    out = get_radio_names(clip_neuron_dict, maple_neuron_dict)
+
+    radio_choices = gr.Radio(choices=out, label="Top activating SAE latent", value=out[0])
+    sleep(0.1)
+    return radio_choices
+
+
+def update_markdown(option_value):
+    latent_idx = int(option_value.split("-")[-1])
+    out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
+    out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
+    return out_1, out_2
+
+
+def get_data(image_name, model_name):
+    data_file = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
+    with gzip.open(data_file, "rb") as f:
+        return pickle.load(f)
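+# Payload sketch (an assumption: every caller uses `get_data(...)[0]`, so each
+# gzipped pickle appears to hold a sequence whose first element is the
+# token-by-latent SAE activation array; nothing else in it is used here).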
+
+
+def load_all_data(image_root, pkl_root):
+    image_files = glob(f"{image_root}/*")
+    data_dict = {}
+    for image_file in image_files:
+        image_name = os.path.basename(image_file).split(".")[0]
+        if image_name not in data_dict:
+            data_dict[image_name] = {
+                "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
+                "image_path": image_file,
+            }
+
+    sae_data_dict = {}
+    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+        data = pickle.load(f)
+        sae_data_dict["mean_acts"] = data
+
+    sae_data_dict["mean_act_values"] = {}
+    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+            data = pickle.load(f)
+            sae_data_dict["mean_act_values"][dataset] = data
+
+    return data_dict, sae_data_dict
+
+
+data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
+default_image_name = "christmas-imagenet"
+
+
+with gr.Blocks(
+    theme=gr.themes.Citrus(),
+    css="""
+    .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
+    .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
+    """,
+) as demo:
+    with gr.Row():
+        with gr.Column():
+            # Left View: Image selection and click handling
+            gr.Markdown("## Select input image and patch on the image")
+            image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
+            image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
+
+            # Update image display when a new image is selected
+            image_selector.change(
+                fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
+            )
+            image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
+
+        with gr.Column():
+            gr.Markdown("## SAE latent activations of CLIP and MaPLE")
+            model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
+            model_selector = gr.Dropdown(
+                choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
+            )
+            init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
+            neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
+
+            image_selector.change(
+                fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
+            )
+            image_display.select(
+                fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
+            )
+            model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
+            model_selector.change(
+                fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
+            )
+
+    with gr.Row():
+        with gr.Column():
+            radio_names = get_init_radio_options(default_image_name, model_options[0])
+
+            feature_idx = radio_names[0].split("-")[-1]
+            markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feature_idx}")
+            init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
+
+            gr.Markdown("### Localize SAE latent activation using CLIP")
+            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
+            init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
+            gr.Markdown("### Localize SAE latent activation using MaPLE")
+            seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
+
+        with gr.Column():
+            gr.Markdown("## Top activating SAE latent index")
+
+            radio_choices = gr.Radio(
+                choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
+            )
+            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
+
+            markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feature_idx}")
+
+            gr.Markdown("### ImageNet")
+            top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
+            act_value_1 = gr.Markdown(init_values[0])
+
+            gr.Markdown("### ImageNet-Sketch")
+            top_image_2 = gr.Image(value=init_tops[1], type="pil", label="ImageNet-Sketch", show_label=False)
+            act_value_2 = gr.Markdown(init_values[1])
+
+            gr.Markdown("### Caltech101")
+            top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
+            act_value_3 = gr.Markdown(init_values[2])
+
+    image_display.select(
+        fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
+    )
+
+    model_selector.change(
+        fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
+    )
+
+    image_selector.select(
+        fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
+    )
+
+    radio_choices.change(
+        fn=update_markdown,
+        inputs=[radio_choices],
+        outputs=[markdown_display, markdown_display_2],
+        queue=True,
+    )
+
+    radio_choices.change(
+        fn=show_activation_heatmap_clip,
+        inputs=[image_selector, radio_choices, toggle_btn],
+        outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
+        queue=True,
+    )
+
+    radio_choices.change(
+        fn=show_activation_heatmap_maple,
+        inputs=[image_selector, radio_choices, model_selector],
+        outputs=[seg_mask_display_maple],
+        queue=True,
+    )
+
+    toggle_btn.change(
+        fn=show_activation_heatmap_clip,
+        inputs=[image_selector, radio_choices, toggle_btn],
+        outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
+        queue=True,
+    )
+
+# Launch the app
+demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
+torch
+matplotlib
+plotly
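
A minimal local-run sketch, assuming gradio, numpy, and Pillow arrive as transitive dependencies of the Spaces runtime rather than from this file (note that matplotlib is pinned here but never imported by app.py):

    pip install torch matplotlib plotly gradio
    python app.py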