hyesulim committed on
Commit c574085 · verified · 1 Parent(s): cabf670

perf: improve delay

Files changed (1)
  1. app.py +696 -335

app.py CHANGED
@@ -2,7 +2,10 @@ import gzip
 import os
 import pickle
 from glob import glob
-from time import sleep
 
 import gradio as gr
 import numpy as np
@@ -11,47 +14,259 @@ import torch
 from PIL import Image, ImageDraw
 from plotly.subplots import make_subplots
 
 IMAGE_SIZE = 400
 DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
 GRID_NUM = 14
 pkl_root = "./data/out"
 preloaded_data = {}
 
 
-def preload_activation(image_name):
-    for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
-        image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
-        with gzip.open(image_file, "rb") as f:
-            preloaded_data[model] = pickle.load(f)
 
 
-def get_activation_distribution(image_name: str, model_type: str):
-    activation = get_data(image_name, model_type)[0]
 
     noisy_features_indices = (
         (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
     )
     activation[:, noisy_features_indices] = 0
-
     return activation
 
-
 def get_grid_loc(evt, image):
     # Get click coordinates
     x, y = evt._data["index"][0], evt._data["index"][1]
-
     cell_width = image.width // GRID_NUM
     cell_height = image.height // GRID_NUM
-
     grid_x = x // cell_width
     grid_y = y // cell_height
     return grid_x, grid_y, cell_width, cell_height
 
-
-def highlight_grid(evt: gr.EventData, image_name):
     image = data_dict[image_name]["image"]
     grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-
     highlighted_image = image.copy()
     draw = ImageDraw.Draw(highlighted_image)
     box = [
@@ -61,16 +276,14 @@ def highlight_grid(evt: gr.EventData, image_name):
         (grid_y + 1) * cell_height,
     ]
     draw.rectangle(box, outline="red", width=3)
-
     return highlighted_image
 
-
 def load_image(img_name):
-    return Image.open(data_dict[img_name]["image_path"]).resize(
-        (IMAGE_SIZE, IMAGE_SIZE)
-    )
-
 
 def plot_activations(
     all_activation,
     tile_activations=None,
@@ -80,19 +293,28 @@ def plot_activations(
     colors=("blue", "cyan"),
     model_name="CLIP",
 ):
     fig = go.Figure()
-
     def _add_scatter_with_annotation(fig, activations, model_name, color, label):
         fig.add_trace(
             go.Scatter(
-                x=np.arange(len(activations)),
-                y=activations,
                 mode="lines",
                 name=label,
                 line=dict(color=color, dash="solid"),
                 showlegend=True,
             )
         )
         top_neurons = np.argsort(activations)[::-1][:top_k]
         for idx in top_neurons:
             fig.add_annotation(
@@ -107,45 +329,46 @@ def plot_activations(
                 opacity=0.7,
             )
         return fig
-
-    label = f"{model_name.split('-')[-0]} Image-level"
     fig = _add_scatter_with_annotation(
         fig, all_activation, model_name, colors[0], label
     )
     if tile_activations is not None:
-        label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
         fig = _add_scatter_with_annotation(
             fig, tile_activations, model_name, colors[1], label
         )
-
     fig.update_layout(
         title="Activation Distribution",
         xaxis_title="SAE latent index",
         yaxis_title="Activation Value",
         template="plotly_white",
     )
-    fig.update_layout(
-        legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
-    )
-
     return fig
 
-
-def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
     activation = get_activation_distribution(selected_image, model_name)
     all_activation = activation.mean(0)
-
     tile_activations = None
     grid_x = None
     grid_y = None
-
-    if evt is not None:
-        if evt._data is not None:
-            image = data_dict[selected_image]["image"]
-            grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-            token_idx = grid_y * GRID_NUM + grid_x + 1
             tile_activations = activation[token_idx]
-
     fig = plot_activations(
         all_activation,
         tile_activations,
@@ -155,124 +378,291 @@ def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
         model_name=model_name,
         colors=colors,
     )
     return fig
 
-
-def plot_activation_distribution(
-    evt: gr.EventData, selected_image: str, model_name: str
-):
     fig = make_subplots(
         rows=2,
         cols=1,
         shared_xaxes=True,
         subplot_titles=["CLIP Activation", f"{model_name} Activation"],
     )
-
     fig_clip = get_activations(
         evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
     )
     fig_maple = get_activations(
         evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
     )
-
     def _attach_fig(fig, sub_fig, row, col, yref):
         for trace in sub_fig.data:
             fig.add_trace(trace, row=row, col=col)
-
         for annotation in sub_fig.layout.annotations:
             annotation.update(yref=yref)
             fig.add_annotation(annotation)
         return fig
-
     fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
     fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
-
     fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
     fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
     fig.update_yaxes(title_text="Activation Value", row=1, col=1)
     fig.update_yaxes(title_text="Activation Value", row=2, col=1)
     fig.update_layout(
-        # height=500,
-        # title="Activation Distributions",
         template="plotly_white",
         showlegend=True,
         legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
         margin=dict(l=20, r=20, t=40, b=20),
     )
-
     return fig
 
-
 def get_segmask(selected_image, slider_value, model_type):
-    image = data_dict[selected_image]["image"]
-    sae_act = get_data(selected_image, model_type)[0]
-    temp = sae_act[:, slider_value]
     try:
-        mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
     except Exception as e:
-        print(sae_act.shape, slider_value)
-    mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
-        0
-    ].numpy()
-    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
-
-    base_opacity = 30
-    image_array = np.array(image)[..., :3]
-    rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
-    rgba_overlay[..., :3] = image_array[..., :3]
-
-    darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
-    rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
-    rgba_overlay[..., 3] = 255  # Fully opaque
-
-    return rgba_overlay
-
 
 def get_top_images(slider_value, toggle_btn):
     def _get_images(dataset_path):
         top_image_paths = [
             os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
             os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
             os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
         ]
-        top_images = [
-            (
-                Image.open(path)
-                if os.path.exists(path)
-                else Image.new("RGB", (256, 256), (255, 255, 255))
-            )
-            for path in top_image_paths
-        ]
         return top_images
-
     if toggle_btn:
         top_images = _get_images("./data/top_images_masked")
     else:
         top_images = _get_images("./data/top_images")
     return top_images
 
-
 def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
-    slider_value = int(slider_value.split("-")[-1])
-    rgba_overlay = get_segmask(selected_image, slider_value, model_type)
-    top_images = get_top_images(slider_value, toggle_btn)
-
-    act_values = []
-    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-        act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
-        act_value = [str(round(value, 3)) for value in act_value]
-        act_value = " | ".join(act_value)
-        out = f"#### Activation values: {act_value}"
-        act_values.append(out)
-
-    return rgba_overlay, top_images, act_values
-
 
 def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
     rgba_overlay, top_images, act_values = show_activation_heatmap(
         selected_image, slider_value, "CLIP", toggle_btn
     )
-    sleep(0.1)
     return (
         rgba_overlay,
         top_images[0],
@@ -283,18 +673,19 @@ def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
         act_values[2],
     )
 
-
 def show_activation_heatmap_maple(selected_image, slider_value, model_name):
-    slider_value = int(slider_value.split("-")[-1])
-    rgba_overlay = get_segmask(selected_image, slider_value, model_name)
-    sleep(0.1)
     return rgba_overlay
 
-
 def get_init_radio_options(selected_image, model_name):
     clip_neuron_dict = {}
     maple_neuron_dict = {}
-
     def _get_top_actvation(selected_image, model_name, neuron_dict, top_k=5):
         activations = get_activation_distribution(selected_image, model_name).mean(0)
         top_neurons = list(np.argsort(activations)[::-1][:top_k])
@@ -304,127 +695,138 @@ def get_init_radio_options(selected_image, model_name):
             sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
         )
         return sorted_dict
-
-    clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
-    maple_neuron_dict = _get_top_actvation(
-        selected_image, model_name, maple_neuron_dict
-    )
-
     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
-
     return radio_choices
 
-
 def get_radio_names(clip_neuron_dict, maple_neuron_dict):
     clip_keys = list(clip_neuron_dict.keys())
     maple_keys = list(maple_neuron_dict.keys())
-
     common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
-    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
-
     common_keys.sort(
-        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
     )
-    clip_only_keys.sort(reverse=True)
-    maple_only_keys.sort(reverse=True)
-
     out = []
     out.extend([f"common-{i}" for i in common_keys[:5]])
     out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
     out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
-
     return out
 
-
-def update_radio_options(evt: gr.EventData, selected_image, model_name):
-    def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
-        top_neurons = list(np.argsort(activations)[::-1][:top_k])
-        for top_neuron in top_neurons:
-            neuron_dict[top_neuron] = activations[top_neuron]
-
-    def _get_top_actvation(evt, selected_image, model_name, neuron_dict):
         all_activation = get_activation_distribution(selected_image, model_name)
         image_activation = all_activation.mean(0)
-        _sort_and_save_top_k(image_activation, neuron_dict)
-
-        if evt is not None:
-            if evt._data is not None and isinstance(evt._data["index"], list):
-                image = data_dict[selected_image]["image"]
-                grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-                token_idx = grid_y * GRID_NUM + grid_x + 1
                 tile_activations = all_activation[token_idx]
-                _sort_and_save_top_k(tile_activations, neuron_dict)
-
-        sorted_dict = dict(
-            sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
-        )
-        return sorted_dict
-
-    clip_neuron_dict = {}
-    maple_neuron_dict = {}
-    clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
-    maple_neuron_dict = _get_top_actvation(
-        evt, selected_image, model_name, maple_neuron_dict
-    )
-
-    clip_keys = list(clip_neuron_dict.keys())
-    maple_keys = list(maple_neuron_dict.keys())
-
-    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
-    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
-
-    common_keys.sort(
-        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
-    )
-    clip_only_keys.sort(reverse=True)
-    maple_only_keys.sort(reverse=True)
-
-    out = []
-    out.extend([f"common-{i}" for i in common_keys[:5]])
-    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
-    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
-
-    radio_choices = gr.Radio(
-        choices=out, label="Top activating SAE latent", value=out[0]
     )
-    sleep(0.1)
-    return radio_choices
-
 
 def update_markdown(option_value):
     latent_idx = int(option_value.split("-")[-1])
     out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
     out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
     return out_1, out_2
 
-
-def get_data(image_name, model_name):
-    pkl_root = "./data/out"
-    data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
-    with gzip.open(data_dir, "rb") as f:
-        data = pickle.load(f)
-    out = data
-
-    return out
-
-
 def update_all(selected_image, slider_value, toggle_btn, model_name):
-    (
-        seg_mask_display,
-        top_image_1,
-        top_image_2,
-        top_image_3,
-        act_value_1,
-        act_value_2,
-        act_value_3,
-    ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
-    seg_mask_display_maple = show_activation_heatmap_maple(
-        selected_image, slider_value, model_name
-    )
     markdown_display, markdown_display_2 = update_markdown(slider_value)
-
     return (
         seg_mask_display,
         seg_mask_display_maple,
@@ -438,42 +840,17 @@ def update_all(selected_image, slider_value, toggle_btn, model_name):
         markdown_display_2,
     )
 
-
-def load_all_data(image_root, pkl_root):
-    image_files = glob(f"{image_root}/*")
-    data_dict = {}
-    for image_file in image_files:
-        image_name = os.path.basename(image_file).split(".")[0]
-        if image_file not in data_dict:
-            data_dict[image_name] = {
-                "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
-                "image_path": image_file,
-            }
-
-    sae_data_dict = {}
-    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
-        data = pickle.load(f)
-        sae_data_dict["mean_acts"] = data
-
-    sae_data_dict["mean_act_values"] = {}
-    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
-            data = pickle.load(f)
-            sae_data_dict["mean_act_values"][dataset] = data
-
-    return data_dict, sae_data_dict
-
-
 data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
 default_image_name = "christmas-imagenet"
 
-
 with gr.Blocks(
     theme=gr.themes.Citrus(),
     css="""
     .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
     .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
-    """,
 ) as demo:
     with gr.Row():
         with gr.Column():
@@ -485,21 +862,36 @@ with gr.Blocks(
                 label="Select Image",
             )
             image_display = gr.Image(
-                value=data_dict[default_image_name]["image"],
                 type="pil",
                 interactive=True,
             )
-
-            # Update image display when a new image is selected
             image_selector.change(
-                fn=lambda img_name: data_dict[img_name]["image"],
                 inputs=image_selector,
                 outputs=image_display,
             )
             image_display.select(
-                fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
             )
-
         with gr.Column():
             gr.Markdown("## SAE latent activations of CLIP and MaPLE")
             model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
@@ -508,139 +900,108 @@ with gr.Blocks(
                 value=model_options[0],
                 label="Select adapted model (MaPLe)",
             )
-            init_plot = plot_activation_distribution(
-                None, default_image_name, model_options[0]
-            )
             neuron_plot = gr.Plot(
-                label="Neuron Activation", value=init_plot, show_label=False
             )
-
-            image_selector.change(
-                fn=plot_activation_distribution,
                 inputs=[image_selector, model_selector],
                 outputs=neuron_plot,
             )
             image_display.select(
-                fn=plot_activation_distribution,
-                inputs=[image_selector, model_selector],
-                outputs=neuron_plot,
-            )
-            model_selector.change(
-                fn=load_image, inputs=[image_selector], outputs=image_display
-            )
-            model_selector.change(
-                fn=plot_activation_distribution,
                 inputs=[image_selector, model_selector],
                 outputs=neuron_plot,
             )
 
     with gr.Row():
         with gr.Column():
-            radio_names = get_init_radio_options(default_image_name, model_options[0])
-
-            feautre_idx = radio_names[0].split("-")[-1]
-            markdown_display = gr.Markdown(
-                f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
-            )
-            init_seg, init_tops, init_values = show_activation_heatmap(
-                default_image_name, radio_names[0], "CLIP"
-            )
-
             gr.Markdown("### Localize SAE latent activation using CLIP")
-            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
-            init_seg_maple, _, _ = show_activation_heatmap(
-                default_image_name, radio_names[0], model_options[0]
-            )
             gr.Markdown("### Localize SAE latent activation using MaPLE")
-            seg_mask_display_maple = gr.Image(
-                value=init_seg_maple, type="pil", show_label=False
-            )
-
         with gr.Column():
             gr.Markdown("## Top activating SAE latent index")
-
             radio_choices = gr.Radio(
-                choices=radio_names,
                 label="Top activating SAE latent",
                 interactive=True,
-                value=radio_names[0],
             )
-            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
-
-            markdown_display_2 = gr.Markdown(
-                f"## Top reference images for the selected SAE latent - {feautre_idx}"
             )
-
             gr.Markdown("### ImageNet")
-            top_image_1 = gr.Image(
-                value=init_tops[0], type="pil", label="ImageNet", show_label=False
-            )
-            act_value_1 = gr.Markdown(init_values[0])
-
             gr.Markdown("### ImageNet-Sketch")
-            top_image_2 = gr.Image(
-                value=init_tops[1],
-                type="pil",
-                label="ImageNet-Sketch",
-                show_label=False,
-            )
-            act_value_2 = gr.Markdown(init_values[1])
-
             gr.Markdown("### Caltech101")
-            top_image_3 = gr.Image(
-                value=init_tops[2], type="pil", label="Caltech101", show_label=False
-            )
-            act_value_3 = gr.Markdown(init_values[2])
-
         image_display.select(
             fn=update_radio_options,
             inputs=[image_selector, model_selector],
-            outputs=[radio_choices],
         )
-
         model_selector.change(
             fn=update_radio_options,
            inputs=[image_selector, model_selector],
-            outputs=[radio_choices],
        )
-
-        image_selector.select(
            fn=update_radio_options,
            inputs=[image_selector, model_selector],
-            outputs=[radio_choices],
        )
-
-        radio_choices.change(
-            fn=update_all,
-            inputs=[image_selector, radio_choices, toggle_btn, model_selector],
-            outputs=[
-                seg_mask_display,
-                seg_mask_display_maple,
-                top_image_1,
-                top_image_2,
-                top_image_3,
-                act_value_1,
-                act_value_2,
-                act_value_3,
-                markdown_display,
-                markdown_display_2,
-            ],
-        )
-
-        toggle_btn.change(
-            fn=show_activation_heatmap_clip,
-            inputs=[image_selector, radio_choices, toggle_btn],
-            outputs=[
-                seg_mask_display,
-                top_image_1,
-                top_image_2,
-                top_image_3,
-                act_value_1,
-                act_value_2,
-                act_value_3,
-            ],
-        )
-
-# Launch the app
-# demo.queue()
-demo.launch()
 
 import os
 import pickle
 from glob import glob
+from functools import lru_cache
+import concurrent.futures
+import threading
+import time
 
 import gradio as gr
 import numpy as np
 
 from PIL import Image, ImageDraw
 from plotly.subplots import make_subplots
 
+# Constants
 IMAGE_SIZE = 400
 DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
 GRID_NUM = 14
 pkl_root = "./data/out"
+
+# Global cache for preloaded data
 preloaded_data = {}
+data_dict = {}
+sae_data_dict = {}
+activation_cache = {}
+segmask_cache = {}
+top_images_cache = {}
 
+# Thread lock for thread-safe operations
+data_lock = threading.Lock()
 
+# Load data more efficiently
+def load_all_data(image_root, pkl_root):
+    """Load all necessary data with optimized caching"""
+    # Load image data
+    image_files = glob(f"{image_root}/*")
+    data_dict = {}
+
+    # Use thread pool for parallel image loading
+    def load_image_data(image_file):
+        image_name = os.path.basename(image_file).split(".")[0]
+        # Only load thumbnail for initial display, load full image on demand
+        thumbnail = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
+        return image_name, {
+            "image": thumbnail,
+            "image_path": image_file,
+        }
+
+    # Load images in parallel
+    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        results = executor.map(load_image_data, image_files)
+        for image_name, data in results:
+            data_dict[image_name] = data
+
+    # Load SAE data with minimal processing
+    sae_data_dict = {}
+
+    # Load mean acts only once
+    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+        sae_data_dict["mean_acts"] = pickle.load(f)
+
+    # Create dictionary for dataset values
+    sae_data_dict["mean_act_values"] = {}
+
+    # Load dataset values in parallel
+    def load_dataset_values(dataset):
+        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+            return dataset, pickle.load(f)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+        futures = [
+            executor.submit(load_dataset_values, dataset)
+            for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
+        ]
+        for future in concurrent.futures.as_completed(futures):
+            dataset, data = future.result()
+            sae_data_dict["mean_act_values"][dataset] = data
+
+    return data_dict, sae_data_dict
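
Two executor idioms appear in load_all_data: executor.map for the images, whose results come back in input order, and submit plus as_completed for the pickled values, whose results come back in completion order. A minimal, self-contained sketch of the difference; load and names below are hypothetical stand-ins for the loaders above:

import concurrent.futures

def load(name):
    # Stand-in for the gzip/pickle work done in load_all_data
    return name, len(name)

names = ["imagenet", "imagenet-sketch", "caltech101"]
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as ex:
    # map: results are yielded in input order
    print(list(ex.map(load, names)))
    # submit + as_completed: results are yielded as they finish
    futures = [ex.submit(load, n) for n in names]
    for fut in concurrent.futures.as_completed(futures):
        print(fut.result())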
 
+# Cache activation data with LRU cache
+@lru_cache(maxsize=32)
+def preload_activation(image_name, model_name):
+    """Preload and cache activation data for a specific image and model"""
+    image_file = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
+
+    try:
+        with gzip.open(image_file, "rb") as f:
+            return pickle.load(f)
+    except Exception as e:
+        print(f"Error loading {image_file}: {e}")
+        return None
+
+# Get activation with caching
+def get_data(image_name, model_type):
+    """Get activation data with caching for better performance"""
+    cache_key = f"{image_name}_{model_type}"
+
+    with data_lock:
+        if cache_key not in activation_cache:
+            activation_cache[cache_key] = preload_activation(image_name, model_type)
+
+    return activation_cache[cache_key]
+
+def get_activation_distribution(image_name, model_type):
+    """Get activation distribution with noise filtering"""
+    activation = get_data(image_name, model_type)
+
+    if activation is None:
+        # Return empty tensor if data loading failed
+        return torch.zeros((GRID_NUM * GRID_NUM + 1, 1000))
+
+    activation = activation[0]
+
+    # Filter out noisy features
     noisy_features_indices = (
         (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
     )
     activation[:, noisy_features_indices] = 0
+
     return activation
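
get_data layers a lock-guarded dict on top of preload_activation, which functools.lru_cache already memoizes. A minimal sketch of that double-layered pattern, with hypothetical expensive_load/cached_get stand-ins. Note two properties of the design: the dict layer is largely redundant given lru_cache, and holding the lock across the load serializes concurrent first-time loads.

import threading
from functools import lru_cache

lock = threading.Lock()
cache = {}

@lru_cache(maxsize=32)
def expensive_load(key):
    # Stand-in for gzip.open + pickle.load; lru_cache memoizes by argument
    return key.upper()

def cached_get(key):
    # Lock is held across the load, so concurrent first requests serialize
    with lock:
        if key not in cache:
            cache[key] = expensive_load(key)
    return cache[key]

print(cached_get("christmas-imagenet_CLIP"))  # loaded once; later calls hit the cache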
 
 def get_grid_loc(evt, image):
+    """Get grid location from click event"""
     # Get click coordinates
     x, y = evt._data["index"][0], evt._data["index"][1]
+
     cell_width = image.width // GRID_NUM
     cell_height = image.height // GRID_NUM
+
     grid_x = x // cell_width
     grid_y = y // cell_height
     return grid_x, grid_y, cell_width, cell_height
 
+def highlight_grid(evt, image_name):
+    """Highlight grid cell on click"""
     image = data_dict[image_name]["image"]
     grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+
     highlighted_image = image.copy()
     draw = ImageDraw.Draw(highlighted_image)
     box = [
 
         (grid_y + 1) * cell_height,
     ]
     draw.rectangle(box, outline="red", width=3)
+
     return highlighted_image
 
 def load_image(img_name):
+    """Load image by name"""
+    return data_dict[img_name]["image"]
 
+# Optimized plotting with fewer annotations
 def plot_activations(
     all_activation,
     tile_activations=None,
 
     colors=("blue", "cyan"),
     model_name="CLIP",
 ):
+    """Plot activations with optimized rendering"""
     fig = go.Figure()
+
     def _add_scatter_with_annotation(fig, activations, model_name, color, label):
+        # Only plot non-zero values to reduce points
+        non_zero_indices = np.where(np.abs(activations) > 1e-5)[0]
+        if len(non_zero_indices) == 0:
+            # If all values are near zero, use full array
+            non_zero_indices = np.arange(len(activations))
+
         fig.add_trace(
             go.Scatter(
+                x=non_zero_indices,
+                y=activations[non_zero_indices],
                 mode="lines",
                 name=label,
                 line=dict(color=color, dash="solid"),
                 showlegend=True,
             )
         )
+
+        # Only annotate the top_k activations
         top_neurons = np.argsort(activations)[::-1][:top_k]
         for idx in top_neurons:
             fig.add_annotation(
 
                 opacity=0.7,
             )
         return fig
+
+    label = f"{model_name.split('-')[-1]} Image-level"
     fig = _add_scatter_with_annotation(
         fig, all_activation, model_name, colors[0], label
     )
+
     if tile_activations is not None:
+        label = f"{model_name.split('-')[-1]} Tile ({grid_x}, {grid_y})"
         fig = _add_scatter_with_annotation(
             fig, tile_activations, model_name, colors[1], label
         )
+
+    # Optimize layout with minimal settings
     fig.update_layout(
         title="Activation Distribution",
         xaxis_title="SAE latent index",
         yaxis_title="Activation Value",
         template="plotly_white",
+        legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5),
     )
+
     return fig
 
+def get_activations(evt, selected_image, model_name, colors):
+    """Get activations for plotting"""
     activation = get_activation_distribution(selected_image, model_name)
     all_activation = activation.mean(0)
+
     tile_activations = None
     grid_x = None
     grid_y = None
+
+    if evt is not None and evt._data is not None:
+        image = data_dict[selected_image]["image"]
+        grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+        token_idx = grid_y * GRID_NUM + grid_x + 1
+        # Ensure token_idx is within bounds
+        if token_idx < activation.shape[0]:
            tile_activations = activation[token_idx]
+
     fig = plot_activations(
         all_activation,
         tile_activations,
 
         model_name=model_name,
         colors=colors,
     )
+
     return fig
 
+# Cache plot results
+@lru_cache(maxsize=16)
+def plot_activation_distribution(evt_data, selected_image, model_name):
+    """Plot activation distribution with caching"""
+    # Convert event data to hashable format for caching
+    if evt_data is not None:
+        evt = type('obj', (object,), {'_data': evt_data})
+    else:
+        evt = None
+
     fig = make_subplots(
         rows=2,
         cols=1,
         shared_xaxes=True,
         subplot_titles=["CLIP Activation", f"{model_name} Activation"],
     )
+
     fig_clip = get_activations(
         evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
     )
     fig_maple = get_activations(
         evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
     )
+
     def _attach_fig(fig, sub_fig, row, col, yref):
         for trace in sub_fig.data:
             fig.add_trace(trace, row=row, col=col)
+
         for annotation in sub_fig.layout.annotations:
             annotation.update(yref=yref)
             fig.add_annotation(annotation)
         return fig
+
     fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
     fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
+
+    # Optimize layout with minimal settings
     fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
     fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
     fig.update_yaxes(title_text="Activation Value", row=1, col=1)
     fig.update_yaxes(title_text="Activation Value", row=2, col=1)
     fig.update_layout(
 
         template="plotly_white",
         showlegend=True,
         legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
         margin=dict(l=20, r=20, t=40, b=20),
     )
+
     return fig
 
+# Cache segmentation masks
+@lru_cache(maxsize=32)
 def get_segmask(selected_image, slider_value, model_type):
+    """Generate segmentation mask with caching"""
     try:
+        # Check if image exists
+        if selected_image not in data_dict:
+            print(f"Image {selected_image} not found in data dictionary")
+            # Return blank mask with IMAGE_SIZE dimensions
+            return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
+
+        # Use cache if available
+        cache_key = f"{selected_image}_{slider_value}_{model_type}"
+        with data_lock:
+            if cache_key in segmask_cache:
+                return segmask_cache[cache_key]
+
+        # Get image
+        image = data_dict[selected_image]["image"]
+
+        # Get activation data
+        sae_act = get_data(selected_image, model_type)
+
+        if sae_act is None:
+            # Return blank mask if data loading failed
+            return np.zeros((image.height, image.width, 4), dtype=np.uint8)
+
+        # Handle array shape issues
+        try:
+            # Check array shape and dimensions
+            if isinstance(sae_act, tuple) and len(sae_act) > 0:
+                # First element of tuple
+                act_data = sae_act[0]
+            else:
+                # Direct array
+                act_data = sae_act
+
+            # Check if slider_value is within bounds
+            if slider_value >= act_data.shape[1]:
+                print(f"Slider value {slider_value} out of bounds for activation shape {act_data.shape}")
+                return np.zeros((image.height, image.width, 4), dtype=np.uint8)
+
+            # Get activation for specific latent
+            temp = act_data[:, slider_value]
+
+            # Skip first token (CLS token) and reshape to grid
+            if len(temp) > 1:  # Ensure we have enough tokens
+                mask = torch.Tensor(temp[1:].reshape(GRID_NUM, GRID_NUM)).view(1, 1, GRID_NUM, GRID_NUM)
+
+                # Upsample to image dimensions
+                mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
+
+                # Normalize mask values between 0 and 1
+                mask_min, mask_max = mask.min(), mask.max()
+                if mask_max > mask_min:  # Avoid division by zero
+                    mask = (mask - mask_min) / (mask_max - mask_min)
+                else:
+                    mask = np.zeros_like(mask)
+            else:
+                # Not enough tokens
+                print(f"Not enough tokens in activation data: {len(temp)}")
+                return np.zeros((image.height, image.width, 4), dtype=np.uint8)
+
+        except Exception as e:
+            print(f"Error processing activation data: {e}")
+            print(f"Shape info - sae_act: {type(sae_act)}, slider_value: {slider_value}")
+            return np.zeros((image.height, image.width, 4), dtype=np.uint8)
+
+        # Create RGBA overlay
+        try:
+            # Set base opacity for darkened areas
+            base_opacity = 30
+
+            # Convert image to numpy array
+            image_array = np.array(image)
+
+            # Handle grayscale images
+            if len(image_array.shape) == 2:
+                # Convert grayscale to RGB
+                image_array = np.stack([image_array] * 3, axis=-1)
+            elif image_array.shape[2] == 4:
+                # Use only RGB channels
+                image_array = image_array[..., :3]
+
+            # Create overlay
+            rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+            rgba_overlay[..., :3] = image_array
+
+            # Use vectorized operations for better performance
+            darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
+
+            # Create mask for darkened areas
+            mask_threshold = 0.1  # Adjust threshold if needed
+            mask_zero = mask < mask_threshold
+
+            # Apply darkening only to low-activation areas
+            rgba_overlay[mask_zero, :3] = darkened_image[mask_zero]
+
+            # Set alpha channel
+            rgba_overlay[..., 3] = 255  # Fully opaque
+
+            # Cache result for future use
+            with data_lock:
+                segmask_cache[cache_key] = rgba_overlay
+
+            return rgba_overlay
+
+        except Exception as e:
+            print(f"Error creating overlay: {e}")
+            return np.zeros((image.height, image.width, 4), dtype=np.uint8)
+
     except Exception as e:
+        print(f"Unexpected error in get_segmask: {e}")
+        # Return a blank image of standard size
+        return np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
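
The geometric core of get_segmask is the CLS-token drop, the GRID_NUM x GRID_NUM reshape, and the upsample to image resolution. A runnable sketch with synthetic activations, assuming the constants match the app's (GRID_NUM=14, 400x400 images):

import numpy as np
import torch

GRID_NUM, H, W = 14, 400, 400
tokens = torch.rand(GRID_NUM * GRID_NUM + 1)           # [CLS] + 196 patch tokens
mask = tokens[1:].reshape(1, 1, GRID_NUM, GRID_NUM)    # drop CLS, NCHW layout
mask = torch.nn.functional.interpolate(mask, (H, W))[0, 0].numpy()
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
print(mask.shape, mask.min(), mask.max())              # (400, 400) 0.0 1.0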
 
+# Cache top images
+@lru_cache(maxsize=32)
 def get_top_images(slider_value, toggle_btn):
+    """Get top images with caching"""
+    cache_key = f"{slider_value}_{toggle_btn}"
+
+    if cache_key in top_images_cache:
+        return top_images_cache[cache_key]
+
     def _get_images(dataset_path):
         top_image_paths = [
             os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
             os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
             os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
         ]
+
+        top_images = []
+        for path in top_image_paths:
+            if os.path.exists(path):
+                top_images.append(Image.open(path))
+            else:
+                top_images.append(Image.new("RGB", (256, 256), (255, 255, 255)))
+
         return top_images
+
     if toggle_btn:
         top_images = _get_images("./data/top_images_masked")
     else:
         top_images = _get_images("./data/top_images")
+
+    # Cache result
+    top_images_cache[cache_key] = top_images
+
     return top_images
 
 def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
+    """Show activation heatmap with optimized processing"""
+    try:
+        # Parse slider value safely
+        if not slider_value:
+            # Fallback to the first option if no slider value
+            radio_options = get_init_radio_options(selected_image, model_type)
+            if not radio_options:
+                # Create placeholder data if no options available
+                return (
+                    np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8),
+                    [Image.new("RGB", (256, 256), (255, 255, 255)) for _ in range(3)],
+                    ["#### Activation values: No data available"] * 3
+                )
+            slider_value = radio_options[0]
+
+        # Extract the integer value
+        try:
+            slider_value_int = int(slider_value.split("-")[-1])
+        except (ValueError, IndexError):
+            print(f"Error parsing slider value: {slider_value}")
+            slider_value_int = 0
+
+        # Process in parallel with thread pool and add timeout
+        results = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+            # Start both tasks
+            segmask_future = executor.submit(get_segmask, selected_image, slider_value_int, model_type)
+            top_images_future = executor.submit(get_top_images, slider_value_int, toggle_btn)
+
+            # Get results with timeout to prevent hanging
+            try:
+                rgba_overlay = segmask_future.result(timeout=5)
+            except (concurrent.futures.TimeoutError, Exception) as e:
+                print(f"Error or timeout generating segmentation mask: {e}")
+                rgba_overlay = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8)
+
+            try:
+                top_images = top_images_future.result(timeout=5)
+            except (concurrent.futures.TimeoutError, Exception) as e:
+                print(f"Error or timeout getting top images: {e}")
+                top_images = [Image.new("RGB", (256, 256), (255, 255, 255)) for _ in range(3)]
+
+        # Prepare activation values with error handling
+        act_values = []
+        for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+            try:
+                if dataset in sae_data_dict["mean_act_values"]:
+                    values = sae_data_dict["mean_act_values"][dataset]
+                    if slider_value_int < values.shape[0]:
+                        act_value = values[slider_value_int, :5]
+                        act_value = [str(round(value, 3)) for value in act_value]
+                        act_value = " | ".join(act_value)
+                        out = f"#### Activation values: {act_value}"
+                    else:
+                        out = f"#### Activation values: Index out of range"
+                else:
+                    out = f"#### Activation values: Dataset not available"
+            except Exception as e:
+                print(f"Error getting activation values for {dataset}: {e}")
+                out = f"#### Activation values: Error retrieving data"
+
+            act_values.append(out)
+
+        return rgba_overlay, top_images, act_values
+
+    except Exception as e:
+        print(f"Error in show_activation_heatmap: {e}")
+        # Return placeholder data in case of error
+        return (
+            np.zeros((IMAGE_SIZE, IMAGE_SIZE, 4), dtype=np.uint8),
+            [Image.new("RGB", (256, 256), (255, 255, 255)) for _ in range(3)],
+            ["#### Activation values: Error occurred"] * 3
+        )
 
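show_activation_heatmap guards each future with result(timeout=5) and substitutes placeholders on failure; incidentally, the except (concurrent.futures.TimeoutError, Exception) clause is redundant, since Exception already covers TimeoutError. The pattern in isolation, with a hypothetical slow stand-in:

import concurrent.futures
import time

def slow():
    time.sleep(0.2)  # stand-in for get_segmask / get_top_images
    return "result"

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
    fut = ex.submit(slow)
    try:
        print(fut.result(timeout=5))     # returns "result" well within the timeout
    except concurrent.futures.TimeoutError:
        print("placeholder instead")     # mirrors the blank-mask fallback above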
 def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
+    """Show CLIP activation heatmap"""
     rgba_overlay, top_images, act_values = show_activation_heatmap(
         selected_image, slider_value, "CLIP", toggle_btn
     )
+
     return (
         rgba_overlay,
         top_images[0],
 
         act_values[2],
     )
 
 def show_activation_heatmap_maple(selected_image, slider_value, model_name):
+    """Show MaPLE activation heatmap"""
+    slider_value_int = int(slider_value.split("-")[-1])
+    rgba_overlay = get_segmask(selected_image, slider_value_int, model_name)
+
     return rgba_overlay
 
+# Optimize radio options generation
 def get_init_radio_options(selected_image, model_name):
+    """Get initial radio options with optimized processing"""
     clip_neuron_dict = {}
     maple_neuron_dict = {}
+
     def _get_top_actvation(selected_image, model_name, neuron_dict, top_k=5):
         activations = get_activation_distribution(selected_image, model_name).mean(0)
         top_neurons = list(np.argsort(activations)[::-1][:top_k])
 
             sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
         )
         return sorted_dict
+
+    # Process in parallel
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        future_clip = executor.submit(_get_top_actvation, selected_image, "CLIP", {})
+        future_maple = executor.submit(_get_top_actvation, selected_image, model_name, {})
+
+        clip_neuron_dict = future_clip.result()
+        maple_neuron_dict = future_maple.result()
+
     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
+
     return radio_choices
 
 def get_radio_names(clip_neuron_dict, maple_neuron_dict):
+    """Get radio button names based on neuron activations"""
     clip_keys = list(clip_neuron_dict.keys())
     maple_keys = list(maple_neuron_dict.keys())
+
+    # Use set operations for better performance
     common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+    clip_only_keys = list(set(clip_keys) - set(maple_keys))
+    maple_only_keys = list(set(maple_keys) - set(clip_keys))
+
+    # Sort keys by activation values
     common_keys.sort(
+        key=lambda x: max(clip_neuron_dict.get(x, 0), maple_neuron_dict.get(x, 0)),
+        reverse=True
     )
+    clip_only_keys.sort(key=lambda x: clip_neuron_dict.get(x, 0), reverse=True)
+    maple_only_keys.sort(key=lambda x: maple_neuron_dict.get(x, 0), reverse=True)
+
+    # Limit number of choices to improve performance
     out = []
     out.extend([f"common-{i}" for i in common_keys[:5]])
     out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
     out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
+
     return out
 
+def update_radio_options(evt, selected_image, model_name):
+    """Update radio options based on user interaction"""
+    def _get_top_actvation(evt, selected_image, model_name):
+        neuron_dict = {}
         all_activation = get_activation_distribution(selected_image, model_name)
         image_activation = all_activation.mean(0)
+
+        # Get top activations from image-level
+        top_neurons = list(np.argsort(image_activation)[::-1][:5])
+        for top_neuron in top_neurons:
+            neuron_dict[top_neuron] = image_activation[top_neuron]
+
+        # Get top activations from tile-level if available
+        if evt is not None and evt._data is not None and isinstance(evt._data["index"], list):
+            image = data_dict[selected_image]["image"]
+            grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+            token_idx = grid_y * GRID_NUM + grid_x + 1
+
+            # Ensure token_idx is within bounds
+            if token_idx < all_activation.shape[0]:
                 tile_activations = all_activation[token_idx]
+                top_tile_neurons = list(np.argsort(tile_activations)[::-1][:5])
+                for top_neuron in top_tile_neurons:
+                    neuron_dict[top_neuron] = max(
+                        neuron_dict.get(top_neuron, 0),
+                        tile_activations[top_neuron]
+                    )
+
+        # Sort by activation value
+        return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
+
+    # Process in parallel
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        future_clip = executor.submit(_get_top_actvation, evt, selected_image, "CLIP")
+        future_maple = executor.submit(_get_top_actvation, evt, selected_image, model_name)
+
+        clip_neuron_dict = future_clip.result()
+        maple_neuron_dict = future_maple.result()
+
+    # Get radio choices
+    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
+
+    # Create radio component
+    radio = gr.Radio(
+        choices=radio_choices,
+        label="Top activating SAE latent",
+        value=radio_choices[0] if radio_choices else None
     )
+
+    return radio
 
 def update_markdown(option_value):
+    """Update markdown text"""
     latent_idx = int(option_value.split("-")[-1])
     out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
     out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
     return out_1, out_2
 
 def update_all(selected_image, slider_value, toggle_btn, model_name):
+    """Update all UI components in optimized way"""
+    # Use a thread pool to parallelize operations
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        # Start both tasks
+        clip_future = executor.submit(
+            show_activation_heatmap_clip,
+            selected_image,
+            slider_value,
+            toggle_btn
+        )
+
+        maple_future = executor.submit(
+            show_activation_heatmap_maple,
+            selected_image,
+            slider_value,
+            model_name
+        )
+
+        # Get results
+        (
+            seg_mask_display,
+            top_image_1,
+            top_image_2,
+            top_image_3,
+            act_value_1,
+            act_value_2,
+            act_value_3,
+        ) = clip_future.result()
+
+        seg_mask_display_maple = maple_future.result()
+
+    # Update markdown
     markdown_display, markdown_display_2 = update_markdown(slider_value)
+
     return (
         seg_mask_display,
         seg_mask_display_maple,
 
         markdown_display_2,
     )
 
+# Initialize data - load at startup
 data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
 default_image_name = "christmas-imagenet"
 
+# Define UI with lazy loading
 with gr.Blocks(
     theme=gr.themes.Citrus(),
     css="""
     .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
     .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
+    """,
 ) as demo:
     with gr.Row():
         with gr.Column():
 
                 label="Select Image",
             )
             image_display = gr.Image(
+                value=load_image(default_image_name),
                 type="pil",
                 interactive=True,
             )
+
+            # Update image display when a new image is selected (with debounce)
             image_selector.change(
+                fn=load_image,
                 inputs=image_selector,
                 outputs=image_display,
+                _js="""
+                function(img_name) {
+                    // Simple debounce
+                    clearTimeout(window._imageSelectTimeout);
+                    return new Promise((resolve) => {
+                        window._imageSelectTimeout = setTimeout(() => {
+                            resolve(img_name);
+                        }, 100);
+                    });
+                }
+                """
             )
+
+            # Handle grid highlighting
             image_display.select(
+                fn=highlight_grid,
+                inputs=[image_selector],
+                outputs=[image_display]
             )
+
         with gr.Column():
             gr.Markdown("## SAE latent activations of CLIP and MaPLE")
             model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
 
                 value=model_options[0],
                 label="Select adapted model (MaPLe)",
             )
+
+            # Initialize with a placeholder plot to avoid delays
             neuron_plot = gr.Plot(
+                label="Neuron Activation",
+                show_label=False
             )
+
+            # Add event handlers with proper data flow
+            def update_plot(evt, selected_image, model_name):
+                if hasattr(evt, '_data') and evt._data is not None:
+                    return plot_activation_distribution(
+                        tuple(map(tuple, evt._data.get('index', []))),
+                        selected_image,
+                        model_name
+                    )
+                return plot_activation_distribution(None, selected_image, model_name)
+
+            # Load initial plot after UI is rendered
+            gr.on(
+                [image_selector.change, model_selector.change],
+                fn=lambda img, model: plot_activation_distribution(None, img, model),
                 inputs=[image_selector, model_selector],
                 outputs=neuron_plot,
             )
+
+            # Update plot on image click
             image_display.select(
+                fn=update_plot,
                 inputs=[image_selector, model_selector],
                 outputs=neuron_plot,
             )
 
     with gr.Row():
         with gr.Column():
+            # Initialize radio options
+            radio_names = gr.State(value=get_init_radio_options(default_image_name, model_options[0]))
+
+            # Initialize markdown displays
+            markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent")
+
+            # Initialize segmentation displays
             gr.Markdown("### Localize SAE latent activation using CLIP")
+            seg_mask_display = gr.Image(type="pil", show_label=False)
+
             gr.Markdown("### Localize SAE latent activation using MaPLE")
+            seg_mask_display_maple = gr.Image(type="pil", show_label=False)
+
         with gr.Column():
             gr.Markdown("## Top activating SAE latent index")
+
+            # Initialize radio component
             radio_choices = gr.Radio(
                 label="Top activating SAE latent",
                 interactive=True,
             )
+
+            # Initialize as soon as UI loads
+            gr.on(
+                gr.Blocks.load,
+                fn=lambda: gr.Radio.update(
+                    choices=get_init_radio_options(default_image_name, model_options[0]),
+                    value=get_init_radio_options(default_image_name, model_options[0])[0]
+                ),
+                outputs=radio_choices
             )
+
+            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
+
+            markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent")
+
+            # Initialize image displays
             gr.Markdown("### ImageNet")
+            top_image_1 = gr.Image(type="pil", label="ImageNet", show_label=False)
+            act_value_1 = gr.Markdown()
+
             gr.Markdown("### ImageNet-Sketch")
+            top_image_2 = gr.Image(type="pil", label="ImageNet-Sketch", show_label=False)
+            act_value_2 = gr.Markdown()
+
             gr.Markdown("### Caltech101")
+            top_image_3 = gr.Image(type="pil", label="Caltech101", show_label=False)
+            act_value_3 = gr.Markdown()
+
+        # Update radio options on image interaction
         image_display.select(
             fn=update_radio_options,
             inputs=[image_selector, model_selector],
+            outputs=radio_choices,
         )
+
+        # Update radio options on model change
         model_selector.change(
             fn=update_radio_options,
             inputs=[image_selector, model_selector],
+            outputs=radio_choices,
         )
+
+        # Update radio options on image selection
+        image_selector.change(
             fn=update_radio_options,
             inputs=[image_selector, model_selector],
+            outputs=radio_choices,
         )
+
+        # Update all components when radio selection changes
+        radio_choices.change(
+            fn=update_all,
+            inputs=[image_selector, radio_choices, toggle_btn, model_selector],
+            outputs=[
+                seg_mask_display,
+                seg_mask_display_maple,
+                top_image_1,
+                top_image_2,
+                top_image_3,
+                act_value_1,
+                act_value_2,
+                act_value_3,
+                markdown_display,
+                markdown_display_2,
+            ],
+            _js="""
+            function(img, radio, toggle, model) {
+                // Add a small delay to prevent rapid UI updates
+                clearTimeout(window._radioTimeout);
+                return new Promise((resolve) => {
+                    window._radioTimeout = setTimeout(() => {
+                        resolve([img, radio, toggle, model]);
+                    }, 100);
+                });
+            }
+            """
+        )
+
+        # Update components when toggle button changes
+        toggle_btn.change(
+            fn=show_activation_heatmap_clip,
+            inputs=[image_selector, radio_choices, toggle_btn],
+            outputs=[
+                seg_mask_display,
+                top_image_1,
+                top_image_2,
+                top_image_3,
+                act_value_1,
+                act_value_2,
+                act_value_3,
+            ],
+            _js="""
+            function(img, radio, toggle) {
+                // Add a small delay to prevent rapid UI updates
+                clearTimeout(window._toggleTimeout);
+                return new Promise((resolve) => {
+                    window._toggleTimeout = setTimeout(() => {
+                        resolve([img, radio, toggle]);
+                    }, 100);
+                });
+            }
+            """
+        )
+
+        # Initialize UI with default values
+        default_options = get_init_radio_options(default_image_name, model_options[0])
+        if default_options:
+            default_option = default_options[0]
+
+            # Set initial values to avoid blank UI at start
+            gr.on(
+                gr.Blocks.load,
+                fn=lambda: update_all(
+                    default_image_name,
+                    default_option,
+                    False,
+                    model_options[0]
+                ),
+                outputs=[
+                    seg_mask_display,
+                    seg_mask_display_maple,
+                    top_image_1,
+                    top_image_2,
+                    top_image_3,
+                    act_value_1,
+                    act_value_2,
+                    act_value_3,
+                    markdown_display,
+                    markdown_display_2,
+                ],
+            )
+
+        # Add a status indicator to show processing state
+        status_indicator = gr.Markdown("Status: Ready")
+
+        # Add a refresh button to manually reload data if needed
+        refresh_btn = gr.Button("Refresh Data")
+
+        def reload_data():
+            global data_dict, sae_data_dict
+
+            # Update status
+            yield "Status: Reloading data..."
+
+            # Reload data
+            try:
+                data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
+                yield "Status: Data reloaded successfully!"
+            except Exception as e:
+                yield f"Status: Error reloading data - {str(e)}"
+
+        refresh_btn.click(
+            fn=reload_data,
+            inputs=[],
+            outputs=[status_indicator],
+            queue=False
+        )
+
+# Launch app with optimized settings
+demo.queue(concurrency_count=3, max_size=10)  # Balanced concurrency for better performance
+
+# Add startup message
+print("Starting visualization application...")
+print(f"Loaded {len(data_dict)} images and {len(sae_data_dict)} datasets")
+
+# Launch with proper error handling
+demo.launch(
+    share=False,  # Don't share publicly
+    debug=False,  # Disable debug mode for production
+    show_error=True,  # Show errors for debugging
+    quiet=False,  # Show startup messages
+    favicon_path=None,  # Default favicon
+    server_port=None,  # Use default port
+    server_name=None,  # Use default host
+    height=None,  # Use default height
+    width=None,  # Use default width
+    enable_queue=True,  # Enable queue for better performance
+)
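
The _js handlers above debounce on the client; _js is a Gradio 3.x keyword argument (renamed js in Gradio 4), so treat the exact parameter name as version-dependent. For comparison, the same idea can be sketched server-side in plain Python; debounce and handle below are hypothetical and not part of the app:

import time

def debounce(min_interval=0.1):
    """Drop calls arriving within min_interval seconds of the last accepted one."""
    def wrap(fn):
        state = {"t": float("-inf"), "out": None}
        def inner(*args, **kwargs):
            now = time.monotonic()
            if now - state["t"] >= min_interval:
                state["t"] = now
                state["out"] = fn(*args, **kwargs)
            return state["out"]  # dropped calls return the previous result
        return inner
    return wrap

@debounce(0.1)
def handle(value):
    return f"processed {value}"

print(handle("a"))  # runs
print(handle("b"))  # within 100 ms: call dropped, previous result returned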