hyesulim committed
Commit cf1e3e9 · verified · 1 Parent(s): b494b53

test: fix typo

Files changed (1): app.py +401 -548
app.py CHANGED
@@ -2,196 +2,153 @@ import gzip
  import os
  import pickle
  from glob import glob
- from time import sleep
-
  from functools import lru_cache
  import concurrent.futures
- from typing import Dict, Tuple, List

  import gradio as gr
  import numpy as np
- import plotly.graph_objects as go
  import torch
  from PIL import Image, ImageDraw
  from plotly.subplots import make_subplots

  IMAGE_SIZE = 400
  DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
  GRID_NUM = 14
- pkl_root = "./data/out"
- preloaded_data = {}
-
-
- # Global cache for data
- _CACHE = {
-     'data_dict': {},
-     'sae_data_dict': {},
-     'model_data': {},
-     'segmasks': {},
-     'top_images': {},
-     'precomputed_activations' = {}
-
- }

  def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
      """Load all data with optimized parallel processing."""
      # Load images in parallel
      with concurrent.futures.ThreadPoolExecutor() as executor:
-         image_files = glob(f"{image_root}/*")
          future_to_file = {
-             executor.submit(_load_image_file, image_file): image_file
-             for image_file in image_files
          }

          for future in concurrent.futures.as_completed(future_to_file):
-             image_file = future_to_file[future]
-             image_name = os.path.basename(image_file).split(".")[0]
-             result = future.result()
-             if result is not None:
-                 _CACHE['data_dict'][image_name] = result

      # Load SAE data
-     with open("./data/sae_data/mean_acts.pkl", "rb") as f:
-         _CACHE['sae_data_dict']["mean_acts"] = pickle.load(f)
-
-     # Load mean act values in parallel
-     datasets = ["imagenet", "imagenet-sketch", "caltech101"]
-     _CACHE['sae_data_dict']["mean_act_values"] = {}
-
-     with concurrent.futures.ThreadPoolExecutor() as executor:
-         future_to_dataset = {
-             executor.submit(_load_mean_act_values, dataset): dataset
-             for dataset in datasets
-         }
-
-         for future in concurrent.futures.as_completed(future_to_dataset):
-             dataset = future_to_dataset[future]
-             result = future.result()
-             if result is not None:
-                 _CACHE['sae_data_dict']["mean_act_values"][dataset] = result
-
-     return _CACHE['data_dict'], _CACHE['sae_data_dict']
-
- def _load_image_file(image_file: str) -> Dict:
-     """Helper function to load a single image file."""
      try:
-         image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
-         return {
-             "image": image,
-             "image_path": image_file,
-         }
      except Exception as e:
-         print(f"Error loading {image_file}: {e}")
-         return None

- def _load_mean_act_values(dataset: str) -> np.ndarray:
-     """Helper function to load mean act values for a dataset."""
-     try:
-         with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
-             return pickle.load(f)
-     except Exception as e:
-         print(f"Error loading mean act values for {dataset}: {e}")
-         return None

  @lru_cache(maxsize=1024)
  def get_data(image_name: str, model_name: str) -> np.ndarray:
-     """Cached function to get model data."""
      cache_key = f"{model_name}_{image_name}"
-     if cache_key not in _CACHE['model_data']:
-         data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
-         with gzip.open(data_dir, "rb") as f:
-             _CACHE['model_data'][cache_key] = pickle.load(f)
-     return _CACHE['model_data'][cache_key]

  @lru_cache(maxsize=1024)
  def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
-     """Cached function to get activation distribution."""
-     activation = get_data(image_name, model_type)[0]
-     noisy_features_indices = (
-         (_CACHE['sae_data_dict']["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
-     )
-     activation[:, noisy_features_indices] = 0
-     return activation
-
- @lru_cache(maxsize=1024)
- def get_segmask(selected_image: str, slider_value: int, model_type: str) -> np.ndarray:
-     """Cached function to get segmentation mask."""
-     cache_key = f"{selected_image}_{slider_value}_{model_type}"
-     if cache_key not in _CACHE['segmasks']:
-         image = _CACHE['data_dict'][selected_image]["image"]
-         sae_act = get_data(selected_image, model_type)[0]
-         temp = sae_act[:, slider_value]
-
-         mask = torch.Tensor(temp[1:].reshape(14, 14)).view(1, 1, 14, 14)
-         mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
-         mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
-
-         base_opacity = 30
-         image_array = np.array(image)[..., :3]
-         rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
-         rgba_overlay[..., :3] = image_array[..., :3]
-
-         darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
-         rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
-         rgba_overlay[..., 3] = 255
-
-         _CACHE['segmasks'][cache_key] = rgba_overlay
-
-     return _CACHE['segmasks'][cache_key]
-
- @lru_cache(maxsize=1024)
- def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
-     """Cached function to get top images."""
-     cache_key = f"{slider_value}_{toggle_btn}"
-     if cache_key not in _CACHE['top_images']:
-         dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"
-         paths = [
-             os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
-             for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
-         ]
-
-         _CACHE['top_images'][cache_key] = [
-             Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
-             for path in paths
-         ]

-     return _CACHE['top_images'][cache_key]
-
-
- # def preload_activation(image_name):
- #     for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
- #         image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
- #         with gzip.open(image_file, "rb") as f:
- #             preloaded_data[model] = pickle.load(f)
-
-
- # def get_activation_distribution(image_name: str, model_type: str):
- #     activation = get_data(image_name, model_type)[0]
-
- #     noisy_features_indices = (
- #         (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
- #     )
- #     activation[:, noisy_features_indices] = 0
-
- #     return activation
-

- def get_grid_loc(evt, image):
-     # Get click coordinates
      x, y = evt._data["index"][0], evt._data["index"][1]
-
      cell_width = image.width // GRID_NUM
      cell_height = image.height // GRID_NUM
-
      grid_x = x // cell_width
      grid_y = y // cell_height
      return grid_x, grid_y, cell_width, cell_height

-
- def highlight_grid(evt: gr.EventData, image_name):
-     image = data_dict[image_name]["image"]
      grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-
      highlighted_image = image.copy()
      draw = ImageDraw.Draw(highlighted_image)
      box = [
@@ -201,25 +158,18 @@ def highlight_grid(evt: gr.EventData, image_name):
          (grid_y + 1) * cell_height,
      ]
      draw.rectangle(box, outline="red", width=3)
-
      return highlighted_image

-
- def load_image(img_name):
-     return Image.open(data_dict[img_name]["image_path"]).resize(
-         (IMAGE_SIZE, IMAGE_SIZE)
-     )
-
-
  def plot_activations(
-     all_activation,
-     tile_activations=None,
-     grid_x=None,
-     grid_y=None,
-     top_k=5,
-     colors=("blue", "cyan"),
-     model_name="CLIP",
- ):
      fig = go.Figure()

      def _add_scatter_with_annotation(fig, activations, model_name, color, label):
@@ -248,59 +198,90 @@ def plot_activations(
          )
          return fig

-     label = f"{model_name.split('-')[-0]} Image-level"
-     fig = _add_scatter_with_annotation(
-         fig, all_activation, model_name, colors[0], label
-     )
      if tile_activations is not None:
-         label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
-         fig = _add_scatter_with_annotation(
-             fig, tile_activations, model_name, colors[1], label
-         )

      fig.update_layout(
          title="Activation Distribution",
          xaxis_title="SAE latent index",
          yaxis_title="Activation Value",
          template="plotly_white",
-     )
-     fig.update_layout(
          legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
      )

      return fig

- def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
-     activation = get_activation_distribution(selected_image, model_name)
-     all_activation = activation.mean(0)
-
-     tile_activations = None
-     grid_x = None
-     grid_y = None
-
-     if evt is not None:
-         if evt._data is not None:
-             image = data_dict[selected_image]["image"]
-             grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-             token_idx = grid_y * GRID_NUM + grid_x + 1
-             tile_activations = activation[token_idx]
-
-     fig = plot_activations(
-         all_activation,
-         tile_activations,
-         grid_x,
-         grid_y,
-         top_k=5,
-         model_name=model_name,
-         colors=colors,
-     )
-     return fig


  def plot_activation_distribution(
-     evt: gr.EventData, selected_image: str, model_name: str
- ):
      fig = make_subplots(
          rows=2,
          cols=1,
@@ -308,17 +289,37 @@ def plot_activation_distribution(
          subplot_titles=["CLIP Activation", f"{model_name} Activation"],
      )

-     fig_clip = get_activations(
-         evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
-     )
-     fig_maple = get_activations(
-         evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
-     )

      def _attach_fig(fig, sub_fig, row, col, yref):
          for trace in sub_fig.data:
              fig.add_trace(trace, row=row, col=col)
-
          for annotation in sub_fig.layout.annotations:
              annotation.update(yref=yref)
              fig.add_annotation(annotation)
@@ -332,8 +333,6 @@ def plot_activation_distribution(
      fig.update_yaxes(title_text="Activation Value", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=2, col=1)
      fig.update_layout(
-         # height=500,
-         # title="Activation Distributions",
          template="plotly_white",
          showlegend=True,
          legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
@@ -342,73 +341,12 @@ def plot_activation_distribution(

      return fig

-
- # def get_segmask(selected_image, slider_value, model_type):
- #     image = data_dict[selected_image]["image"]
- #     sae_act = get_data(selected_image, model_type)[0]
- #     temp = sae_act[:, slider_value]
- #     try:
- #         mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
- #     except Exception as e:
- #         print(sae_act.shape, slider_value)
- #     mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
- #         0
- #     ].numpy()
- #     mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
-
- #     base_opacity = 30
- #     image_array = np.array(image)[..., :3]
- #     rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
- #     rgba_overlay[..., :3] = image_array[..., :3]
-
- #     darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
- #     rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
- #     rgba_overlay[..., 3] = 255  # Fully opaque
-
- #     return rgba_overlay
-
-
- # def get_top_images(slider_value, toggle_btn):
- #     def _get_images(dataset_path):
- #         top_image_paths = [
- #             os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
- #             os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
- #             os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
- #         ]
- #         top_images = [
- #             (
- #                 Image.open(path)
- #                 if os.path.exists(path)
- #                 else Image.new("RGB", (256, 256), (255, 255, 255))
- #             )
- #             for path in top_image_paths
- #         ]
- #         return top_images
-
- #     if toggle_btn:
- #         top_images = _get_images("./data/top_images_masked")
- #     else:
- #         top_images = _get_images("./data/top_images")
- #     return top_images
-
-
- def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
-     slider_value = int(slider_value.split("-")[-1])
-     rgba_overlay = get_segmask(selected_image, slider_value, model_type)
-     top_images = get_top_images(slider_value, toggle_btn)
-
-     act_values = []
-     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-         act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
-         act_value = [str(round(value, 3)) for value in act_value]
-         act_value = " | ".join(act_value)
-         out = f"#### Activation values: {act_value}"
-         act_values.append(out)
-
-     return rgba_overlay, top_images, act_values
-
-
- def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
      rgba_overlay, top_images, act_values = show_activation_heatmap(
          selected_image, slider_value, "CLIP", toggle_btn
      )
@@ -423,49 +361,68 @@ def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
          act_values[2],
      )

- def show_activation_heatmap_maple(selected_image, slider_value, model_name):
      slider_value = int(slider_value.split("-")[-1])
      rgba_overlay = get_segmask(selected_image, slider_value, model_name)
      sleep(0.1)
      return rgba_overlay

-
- def get_init_radio_options(selected_image, model_name):
      clip_neuron_dict = {}
      maple_neuron_dict = {}

-     def _get_top_actvation(selected_image, model_name, neuron_dict, top_k=5):
          activations = get_activation_distribution(selected_image, model_name).mean(0)
          top_neurons = list(np.argsort(activations)[::-1][:top_k])
          for top_neuron in top_neurons:
              neuron_dict[top_neuron] = activations[top_neuron]
-         sorted_dict = dict(
-             sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
-         )
-         return sorted_dict
-
-     clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
-     maple_neuron_dict = _get_top_actvation(
-         selected_image, model_name, maple_neuron_dict
-     )
-
-     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)

-     return radio_choices


- def get_radio_names(clip_neuron_dict, maple_neuron_dict):
      clip_keys = list(clip_neuron_dict.keys())
      maple_keys = list(maple_neuron_dict.keys())

      common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
-     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))

-     common_keys.sort(
-         key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
-     )
      clip_only_keys.sort(reverse=True)
      maple_only_keys.sort(reverse=True)
@@ -476,81 +433,54 @@ def get_radio_names(clip_neuron_dict, maple_neuron_dict):

      return out

- def update_radio_options(evt: gr.EventData, selected_image, model_name):
-     def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
-         top_neurons = list(np.argsort(activations)[::-1][:top_k])
-         for top_neuron in top_neurons:
-             neuron_dict[top_neuron] = activations[top_neuron]
-
-     def _get_top_actvation(evt, selected_image, model_name, neuron_dict):
          all_activation = get_activation_distribution(selected_image, model_name)
          image_activation = all_activation.mean(0)
-         _sort_and_save_top_k(image_activation, neuron_dict)

-         if evt is not None:
-             if evt._data is not None and isinstance(evt._data["index"], list):
-                 image = data_dict[selected_image]["image"]
-                 grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
                  token_idx = grid_y * GRID_NUM + grid_x + 1
                  tile_activations = all_activation[token_idx]
-                 _sort_and_save_top_k(tile_activations, neuron_dict)
-
-         sorted_dict = dict(
-             sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
-         )
-         return sorted_dict
-
-     clip_neuron_dict = {}
-     maple_neuron_dict = {}
-     clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
-     maple_neuron_dict = _get_top_actvation(
-         evt, selected_image, model_name, maple_neuron_dict
-     )

-     clip_keys = list(clip_neuron_dict.keys())
-     maple_keys = list(maple_neuron_dict.keys())
-
-     common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
-     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
-
-     common_keys.sort(
-         key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
-     )
-     clip_only_keys.sort(reverse=True)
-     maple_only_keys.sort(reverse=True)

-     out = []
-     out.extend([f"common-{i}" for i in common_keys[:5]])
-     out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
-     out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
-
-     radio_choices = gr.Radio(
-         choices=out, label="Top activating SAE latent", value=out[0]
-     )
-     sleep(0.1)
-     return radio_choices

- def update_markdown(option_value):
      latent_idx = int(option_value.split("-")[-1])
      out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
      out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
      return out_1, out_2

-
- def get_data(image_name, model_name):
-     pkl_root = "./data/out"
-     data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
-     with gzip.open(data_dir, "rb") as f:
-         data = pickle.load(f)
-     out = data
-
-     return out
-
-
- def update_all(selected_image, slider_value, toggle_btn, model_name):
      (
          seg_mask_display,
          top_image_1,
@@ -560,6 +490,7 @@ def update_all(selected_image, slider_value, toggle_btn, model_name):
          act_value_2,
          act_value_3,
      ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
      seg_mask_display_maple = show_activation_heatmap_maple(
          selected_image, slider_value, model_name
      )
@@ -578,101 +509,67 @@ def update_all(selected_image, slider_value, toggle_btn, model_name):
          markdown_display_2,
      )

-
- def load_all_data(image_root, pkl_root):
-     image_files = glob(f"{image_root}/*")
-     data_dict = {}
-     for image_file in image_files:
-         image_name = os.path.basename(image_file).split(".")[0]
-         if image_file not in data_dict:
-             data_dict[image_name] = {
-                 "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
-                 "image_path": image_file,
-             }
-
-     sae_data_dict = {}
-     with open("./data/sae_data/mean_acts.pkl", "rb") as f:
-         data = pickle.load(f)
-         sae_data_dict["mean_acts"] = data
-
-     sae_data_dict["mean_act_values"] = {}
-     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-         with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
-             data = pickle.load(f)
-             sae_data_dict["mean_act_values"][dataset] = data
-
-     return data_dict, sae_data_dict
-
-
- def preload_all_model_data():
-     """Preload all model data into memory at startup"""
-     print("Preloading model data...")
-     for image_name in data_dict.keys():
-         for model_name in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
-             try:
-                 data = get_data(image_name, model_name)
-                 cache_key = f"{model_name}_{image_name}"
-                 _CACHE['model_data'][cache_key] = data
-             except Exception as e:
-                 print(f"Error preloading {cache_key}: {e}")
-
- def precompute_activations():
-     """Precompute and cache common activation patterns"""
-     print("Precomputing activations...")
-     for image_name in data_dict.keys():
-         for model_name in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
-             activation = get_activation_distribution(image_name, model_name)
-             cache_key = f"activation_{model_name}_{image_name}"
-             _CACHE['precomputed_activations'][cache_key] = activation.mean(0)
-
-
- def precompute_segmasks():
-     """Precompute common segmentation masks"""
-     print("Precomputing segmentation masks...")
-     for image_name in data_dict.keys():
-         for model_type in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
-             for slider_value in range(0, 100):  # Adjust range as needed
-                 try:
-                     mask = get_segmask(image_name, slider_value, model_type)
-                     cache_key = f"{image_name}_{slider_value}_{model_type}"
-                     _CACHE['segmasks'][cache_key] = mask
-                 except Exception as e:
-                     print(f"Error precomputing mask {cache_key}: {e}")
-
- data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
  default_image_name = "christmas-imagenet"

  with gr.Blocks(
      theme=gr.themes.Citrus(),
      css="""
      .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
-     .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
      """,
  ) as demo:
      with gr.Row():
          with gr.Column():
-             # Left View: Image selection and click handling
              gr.Markdown("## Select input image and patch on the image")
              image_selector = gr.Dropdown(
-                 choices=list(data_dict.keys()),
                  value=default_image_name,
                  label="Select Image",
              )
              image_display = gr.Image(
-                 value=data_dict[default_image_name]["image"],
                  type="pil",
                  interactive=True,
              )

-             # Update image display when a new image is selected
              image_selector.change(
-                 fn=lambda img_name: data_dict[img_name]["image"],
                  inputs=image_selector,
                  outputs=image_display,
              )
              image_display.select(
-                 fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
              )

          with gr.Column():
@@ -683,12 +580,8 @@ with gr.Blocks(
                  value=model_options[0],
                  label="Select adapted model (MaPLe)",
              )
-             init_plot = plot_activation_distribution(
-                 None, default_image_name, model_options[0]
-             )
-             neuron_plot = gr.Plot(
-                 label="Neuron Activation", value=init_plot, show_label=False
-             )

              image_selector.change(
                  fn=plot_activation_distribution,
@@ -701,7 +594,9 @@ with gr.Blocks(
                  outputs=neuron_plot,
              )
              model_selector.change(
-                 fn=load_image, inputs=[image_selector], outputs=image_display
              )
              model_selector.change(
                  fn=plot_activation_distribution,
@@ -712,10 +607,9 @@ with gr.Blocks(
      with gr.Row():
          with gr.Column():
              radio_names = get_init_radio_options(default_image_name, model_options[0])
-
-             feautre_idx = radio_names[0].split("-")[-1]
              markdown_display = gr.Markdown(
-                 f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
              )
              init_seg, init_tops, init_values = show_activation_heatmap(
                  default_image_name, radio_names[0], "CLIP"
@@ -727,13 +621,10 @@ with gr.Blocks(
                  default_image_name, radio_names[0], model_options[0]
              )
              gr.Markdown("### Localize SAE latent activation using MaPLE")
-             seg_mask_display_maple = gr.Image(
-                 value=init_seg_maple, type="pil", show_label=False
-             )

          with gr.Column():
              gr.Markdown("## Top activating SAE latent index")
-
              radio_choices = gr.Radio(
                  choices=radio_names,
                  label="Top activating SAE latent",
@@ -741,144 +632,106 @@ with gr.Blocks(
                  value=radio_names[0],
              )
              toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
-
              markdown_display_2 = gr.Markdown(
-                 f"## Top reference images for the selected SAE latent - {feautre_idx}"
              )

              gr.Markdown("### ImageNet")
-             top_image_1 = gr.Image(
-                 value=init_tops[0], type="pil", label="ImageNet", show_label=False
-             )
              act_value_1 = gr.Markdown(init_values[0])

              gr.Markdown("### ImageNet-Sketch")
-             top_image_2 = gr.Image(
-                 value=init_tops[1],
-                 type="pil",
-                 label="ImageNet-Sketch",
-                 show_label=False,
-             )
              act_value_2 = gr.Markdown(init_values[1])

              gr.Markdown("### Caltech101")
-             top_image_3 = gr.Image(
-                 value=init_tops[2], type="pil", label="Caltech101", show_label=False
-             )
              act_value_3 = gr.Markdown(init_values[2])

      image_display.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
-
      model_selector.change(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
-
      image_selector.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )

-     radio_choices.change(
-         fn=update_all,
-         inputs=[image_selector, radio_choices, toggle_btn, model_selector],
-         outputs=[
-             seg_mask_display,
-             seg_mask_display_maple,
-             top_image_1,
-             top_image_2,
-             top_image_3,
-             act_value_1,
-             act_value_2,
-             act_value_3,
-             markdown_display,
-             markdown_display_2,
-         ],
-     )
-
-     toggle_btn.change(
-         fn=show_activation_heatmap_clip,
-         inputs=[image_selector, radio_choices, toggle_btn],
-         outputs=[
-             seg_mask_display,
-             top_image_1,
-             top_image_2,
-             top_image_3,
-             act_value_1,
-             act_value_2,
-             act_value_3,
-         ],
-     )
-
- # Launch the app
- # demo.queue()
- # demo.launch()
-
-
- # if __name__ == "__main__":
- #     demo.queue()  # Enable queuing for better handling of concurrent users
- #     demo.launch(
- #         server_name="0.0.0.0",  # Allow external access
- #         server_port=7860,
- #         share=False,  # Set to True if you want to create a public URL
- #         show_error=True,
- #         # Optimize concurrency
- #         max_threads=8,  # Adjust based on your CPU cores
- #     )

  if __name__ == "__main__":
-     import psutil

      # Get system memory info
      mem = psutil.virtual_memory()
      total_ram_gb = mem.total / (1024**3)

-     # Configure cache sizes based on available RAM
-     cache_size = int(total_ram_gb * 100)  # Rough estimate: 100 entries per GB
-
-     # Precompute all data
-     print("Starting precomputation...")
-     preload_all_model_data()
-     precompute_activations()
-     precompute_segmasks()
-     print("Precomputation complete!")
-
-     # Memory monitoring function
-     def monitor_memory_usage():
-         """Monitor and log memory usage"""
-         process = psutil.Process()
-         mem_info = process.memory_info()
-         print(f"""
-         Memory Usage:
-         - RSS: {mem_info.rss / (1024**2):.2f} MB
-         - VMS: {mem_info.vms / (1024**2):.2f} MB
-         - Cache Size: {len(_CACHE['model_data'])} entries
-         """)
-
-     # Start periodic monitoring
-     def start_memory_monitor():
-         threading.Timer(300.0, start_memory_monitor).start()  # Every 5 minutes
-         monitor_memory_usage()
-
-     # Start the monitoring
-     import threading
-     start_memory_monitor()
-
-     # Launch the app with memory-optimized settings
-     demo.queue(max_size=min(20, int(total_ram_gb)))  # Scale queue with RAM
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=False,
-         show_error=True,
-         max_threads=min(16, psutil.cpu_count()),  # Scale threads with CPU
-         websocket_ping_timeout=60,
-         preventive_refresh=True,
-         memory_limit_mb=int(total_ram_gb * 1024 * 0.8)  # Use up to 80% of RAM
-     )
  import os
  import pickle
  from glob import glob
+ import threading
+ import psutil
  from functools import lru_cache
  import concurrent.futures
+ from typing import Dict, Tuple, List, Optional
+ from time import sleep

  import gradio as gr
  import numpy as np
  import torch
  from PIL import Image, ImageDraw
+ import plotly.graph_objects as go
  from plotly.subplots import make_subplots

+ # Constants
  IMAGE_SIZE = 400
  DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
  GRID_NUM = 14
+ PKL_ROOT = "./data/out"
+
+ # Global cache with better type hints and error handling
+ class Cache:
+     def __init__(self):
+         self.data: Dict[str, Dict] = {
+             'data_dict': {},
+             'sae_data_dict': {},
+             'model_data': {},
+             'segmasks': {},
+             'top_images': {},
+             'precomputed_activations': {}
+         }
+
+     def get(self, category: str, key: str, default=None):
+         try:
+             return self.data[category].get(key, default)
+         except KeyError:
+             return default
+
+     def set(self, category: str, key: str, value):
+         try:
+             self.data[category][key] = value
+         except KeyError:
+             self.data[category] = {key: value}
+
+     def clear_category(self, category: str):
+         if category in self.data:
+             self.data[category].clear()
+
+ _CACHE = Cache()
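A quick sketch of how the new `Cache` helper behaves (the methods match the class above; the sample keys are made up for illustration):

```python
# Illustrative only: exercises Cache.get/set/clear_category as defined above.
cache = Cache()

cache.set('model_data', "CLIP_christmas-imagenet", {"acts": [0.1, 0.2]})  # hypothetical entry
print(cache.get('model_data', "CLIP_christmas-imagenet"))    # -> {'acts': [0.1, 0.2]}
print(cache.get('model_data', "missing-key", default=None))  # -> None

# Unknown categories fall back to the default on get, and are created on set
# via the KeyError branch; clear_category empties one bucket only.
cache.set('extra', "k", 1)
cache.clear_category('segmasks')
```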
  def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
      """Load all data with optimized parallel processing."""
+     def load_image_file(image_file: str) -> Optional[Dict]:
+         try:
+             image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
+             return {
+                 "image": image,
+                 "image_path": image_file,
+             }
+         except Exception as e:
+             print(f"Error loading image {image_file}: {e}")
+             return None
+
      # Load images in parallel
      with concurrent.futures.ThreadPoolExecutor() as executor:
          future_to_file = {
+             executor.submit(load_image_file, image_file): image_file
+             for image_file in glob(f"{image_root}/*")
          }

          for future in concurrent.futures.as_completed(future_to_file):
+             try:
+                 image_file = future_to_file[future]
+                 image_name = os.path.basename(image_file).split(".")[0]
+                 result = future.result()
+                 if result:
+                     _CACHE.set('data_dict', image_name, result)
+             except Exception as e:
+                 print(f"Error processing image future: {e}")

      # Load SAE data
      try:
+         with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+             _CACHE.set('sae_data_dict', "mean_acts", pickle.load(f))
      except Exception as e:
+         print(f"Error loading mean_acts.pkl: {e}")

+     # Load mean act values
+     datasets = ["imagenet", "imagenet-sketch", "caltech101"]
+     for dataset in datasets:
+         try:
+             with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+                 if "mean_act_values" not in _CACHE.data['sae_data_dict']:
+                     _CACHE.set('sae_data_dict', "mean_act_values", {})
+                 _CACHE.data['sae_data_dict']["mean_act_values"][dataset] = pickle.load(f)
+         except Exception as e:
+             print(f"Error loading mean act values for {dataset}: {e}")
+
+     return _CACHE.data['data_dict'], _CACHE.data['sae_data_dict']

  @lru_cache(maxsize=1024)
  def get_data(image_name: str, model_name: str) -> np.ndarray:
+     """Get model data with caching."""
      cache_key = f"{model_name}_{image_name}"
+     if cache_key not in _CACHE.data['model_data']:
+         try:
+             data_dir = f"{PKL_ROOT}/{model_name}/{image_name}.pkl.gz"
+             with gzip.open(data_dir, "rb") as f:
+                 _CACHE.data['model_data'][cache_key] = pickle.load(f)
+         except Exception as e:
+             print(f"Error loading model data for {cache_key}: {e}")
+             return np.array([])
+     return _CACHE.data['model_data'][cache_key]

  @lru_cache(maxsize=1024)
  def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
+     """Get activation distribution with memory optimization."""
+     try:
+         activation = get_data(image_name, model_type)[0]
+         mean_acts = _CACHE.get('sae_data_dict', "mean_acts", {}).get("imagenet", np.array([]))

+         if mean_acts.size > 0:
+             noisy_features_indices = (mean_acts > 0.1).nonzero()[0]
+             activation[:, noisy_features_indices] = 0
+
+         return activation
+     except Exception as e:
+         print(f"Error getting activation distribution: {e}")
+         return np.array([])
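Note that `get_data` layers `functools.lru_cache` on top of the manual `_CACHE` bucket, so repeat lookups are memoized twice over. A minimal, self-contained sketch of just the `lru_cache` layer (standard-library behavior; the function and values here are illustrative, not from the app):

```python
from functools import lru_cache

@lru_cache(maxsize=1024)
def expensive_lookup(image_name: str, model_name: str) -> str:
    print("miss:", image_name, model_name)  # runs only on a cache miss
    return f"{model_name}_{image_name}"     # same key format as get_data

expensive_lookup("christmas-imagenet", "CLIP")  # prints "miss: ..."
expensive_lookup("christmas-imagenet", "CLIP")  # served from the LRU cache
print(expensive_lookup.cache_info())            # hits=1, misses=1, maxsize=1024
```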
+ def get_grid_loc(evt: gr.EventData, image: Image.Image) -> Tuple[int, int, int, int]:
+     """Get grid location from click event."""
      x, y = evt._data["index"][0], evt._data["index"][1]
      cell_width = image.width // GRID_NUM
      cell_height = image.height // GRID_NUM
      grid_x = x // cell_width
      grid_y = y // cell_height
      return grid_x, grid_y, cell_width, cell_height

+ def highlight_grid(evt: gr.EventData, image_name: str) -> Image.Image:
+     """Highlight selected grid cell."""
+     image = _CACHE.get('data_dict', image_name, {}).get("image")
+     if not image:
+         return None
+
      grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+
      highlighted_image = image.copy()
      draw = ImageDraw.Draw(highlighted_image)
      box = [
          (grid_y + 1) * cell_height,
      ]
      draw.rectangle(box, outline="red", width=3)
      return highlighted_image
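To make the patch arithmetic concrete: with `GRID_NUM = 14`, a click at pixel coordinates maps to a grid cell, and the cell to a token index with a `+1` offset, consistent with `get_segmask` dropping the first token via `temp[1:]` (presumably the class token). A worked sketch, assuming a 400×400 image and an arbitrary click position:

```python
# Worked example of the grid math above (click position is illustrative).
IMAGE_SIZE, GRID_NUM = 400, 14

x, y = 250, 110                       # hypothetical click, in pixels
cell_width = IMAGE_SIZE // GRID_NUM   # 28
cell_height = IMAGE_SIZE // GRID_NUM  # 28

grid_x, grid_y = x // cell_width, y // cell_height  # (8, 3)
token_idx = grid_y * GRID_NUM + grid_x + 1          # 3 * 14 + 8 + 1 = 51

# get_segmask reshapes tokens 1..196 back into the 14x14 grid,
# so token_idx - 1 recovers the flattened cell index.
assert token_idx - 1 == grid_y * GRID_NUM + grid_x
```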
  def plot_activations(
+     all_activation: np.ndarray,
+     tile_activations: Optional[np.ndarray] = None,
+     grid_x: Optional[int] = None,
+     grid_y: Optional[int] = None,
+     top_k: int = 5,
+     colors: Tuple[str, str] = ("blue", "cyan"),
+     model_name: str = "CLIP",
+ ) -> go.Figure:
+     """Plot activation distributions."""
      fig = go.Figure()

      def _add_scatter_with_annotation(fig, activations, model_name, color, label):
          )
          return fig

+     label = f"{model_name.split('-')[-1]} Image-level"
+     fig = _add_scatter_with_annotation(fig, all_activation, model_name, colors[0], label)
+
      if tile_activations is not None:
+         label = f"{model_name.split('-')[-1]} Tile ({grid_x}, {grid_y})"
+         fig = _add_scatter_with_annotation(fig, tile_activations, model_name, colors[1], label)

      fig.update_layout(
          title="Activation Distribution",
          xaxis_title="SAE latent index",
          yaxis_title="Activation Value",
          template="plotly_white",
          legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
      )

      return fig
+ def get_segmask(selected_image: str, slider_value: int, model_type: str) -> Optional[np.ndarray]:
+     """Get segmentation mask with caching."""
+     cache_key = f"{selected_image}_{slider_value}_{model_type}"
+     cached_mask = _CACHE.get('segmasks', cache_key)
+     if cached_mask is not None:
+         return cached_mask

+     try:
+         image = _CACHE.get('data_dict', selected_image, {}).get("image")
+         if image is None:
+             return None
+
+         sae_act = get_data(selected_image, model_type)[0]
+         temp = sae_act[:, slider_value]
+
+         mask = torch.tensor(temp[1:].reshape(14, 14)).view(1, 1, 14, 14)
+         mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
+
+         if mask.size == 0:
+             return None
+
+         mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+         base_opacity = 30
+         image_array = np.array(image)[..., :3]
+         rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+         rgba_overlay[..., :3] = image_array
+
+         darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
+         rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+         rgba_overlay[..., 3] = 255

+         _CACHE.set('segmasks', cache_key, rgba_overlay)
+         return rgba_overlay
+
+     except Exception as e:
+         print(f"Error generating segmentation mask: {e}")
+         return None
+
+ def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
+     """Get top images with caching."""
+     cache_key = f"{slider_value}_{toggle_btn}"
+     cached_images = _CACHE.get('top_images', cache_key)
+     if cached_images is not None:
+         return cached_images
+
+     dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"
+     paths = [
+         os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
+         for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
+     ]
+
+     images = [
+         Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
+         for path in paths
+     ]
+
+     _CACHE.set('top_images', cache_key, images)
+     return images

+ # UI Event Handlers
  def plot_activation_distribution(
+     evt: Optional[gr.EventData],
+     selected_image: str,
+     model_name: str
+ ) -> go.Figure:
+     """Plot activation distributions for both models."""
      fig = make_subplots(
          rows=2,
          cols=1,
          subplot_titles=["CLIP Activation", f"{model_name} Activation"],
      )

+     def get_activations(evt, selected_image, model_name, colors):
+         activation = get_activation_distribution(selected_image, model_name)
+         all_activation = activation.mean(0)
+
+         tile_activations = None
+         grid_x = None
+         grid_y = None
+
+         if evt is not None and evt._data is not None:
+             image = _CACHE.get('data_dict', selected_image, {}).get("image")
+             if image:
+                 grid_x, grid_y, _, _ = get_grid_loc(evt, image)
+                 token_idx = grid_y * GRID_NUM + grid_x + 1
+                 tile_activations = activation[token_idx]
+
+         return plot_activations(
+             all_activation,
+             tile_activations,
+             grid_x,
+             grid_y,
+             top_k=5,
+             model_name=model_name,
+             colors=colors,
+         )
+
+     fig_clip = get_activations(evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef"))
+     fig_maple = get_activations(evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4"))

      def _attach_fig(fig, sub_fig, row, col, yref):
          for trace in sub_fig.data:
              fig.add_trace(trace, row=row, col=col)
          for annotation in sub_fig.layout.annotations:
              annotation.update(yref=yref)
              fig.add_annotation(annotation)
      fig.update_yaxes(title_text="Activation Value", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=2, col=1)
      fig.update_layout(
          template="plotly_white",
          showlegend=True,
          legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),

      return fig

+ def show_activation_heatmap_clip(
+     selected_image: str,
+     slider_value: str,
+     toggle_btn: bool
+ ):
+     """Show activation heatmap for CLIP model."""
      rgba_overlay, top_images, act_values = show_activation_heatmap(
          selected_image, slider_value, "CLIP", toggle_btn
      )
          act_values[2],
      )

+ def show_activation_heatmap(
+     selected_image: str,
+     slider_value: str,
+     model_type: str,
+     toggle_btn: bool = False
+ ) -> Tuple[np.ndarray, List[Image.Image], List[str]]:
+     """Show activation heatmap with segmentation mask and top images."""
+     slider_value = int(slider_value.split("-")[-1])
+     rgba_overlay = get_segmask(selected_image, slider_value, model_type)
+     top_images = get_top_images(slider_value, toggle_btn)
+
+     act_values = []
+     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+         act_value = _CACHE.get('sae_data_dict', "mean_act_values", {}).get(dataset, np.array([]))[slider_value, :5]
+         act_value = [str(round(value, 3)) for value in act_value]
+         act_value = " | ".join(act_value)
+         out = f"#### Activation values: {act_value}"
+         act_values.append(out)

+     return rgba_overlay, top_images, act_values
+
+ def show_activation_heatmap_maple(
+     selected_image: str,
+     slider_value: str,
+     model_name: str
+ ) -> np.ndarray:
+     """Show activation heatmap for MaPLE model."""
      slider_value = int(slider_value.split("-")[-1])
      rgba_overlay = get_segmask(selected_image, slider_value, model_name)
      sleep(0.1)
      return rgba_overlay

+ def get_init_radio_options(selected_image: str, model_name: str) -> List[str]:
+     """Get initial radio options for UI."""
      clip_neuron_dict = {}
      maple_neuron_dict = {}

+     def _get_top_activation(selected_image: str, model_name: str, neuron_dict: Dict, top_k: int = 5) -> Dict:
          activations = get_activation_distribution(selected_image, model_name).mean(0)
          top_neurons = list(np.argsort(activations)[::-1][:top_k])
          for top_neuron in top_neurons:
              neuron_dict[top_neuron] = activations[top_neuron]
+         return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

+     clip_neuron_dict = _get_top_activation(selected_image, "CLIP", clip_neuron_dict)
+     maple_neuron_dict = _get_top_activation(selected_image, model_name, maple_neuron_dict)

+     return get_radio_names(clip_neuron_dict, maple_neuron_dict)

+ def get_radio_names(
+     clip_neuron_dict: Dict[int, float],
+     maple_neuron_dict: Dict[int, float]
+ ) -> List[str]:
+     """Generate radio button names based on neuron activations."""
      clip_keys = list(clip_neuron_dict.keys())
      maple_keys = list(maple_neuron_dict.keys())

      common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+     clip_only_keys = list(set(clip_keys) - set(maple_keys))
+     maple_only_keys = list(set(maple_keys) - set(clip_keys))

+     common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
      clip_only_keys.sort(reverse=True)
      maple_only_keys.sort(reverse=True)

      return out

+ def update_radio_options(
+     evt: Optional[gr.EventData],
+     selected_image: str,
+     model_name: str
+ ) -> gr.Radio:
+     """Update radio options based on user interaction."""
+     clip_neuron_dict = {}
+     maple_neuron_dict = {}

+     def _get_top_activation(evt, selected_image, model_name, neuron_dict):
          all_activation = get_activation_distribution(selected_image, model_name)
          image_activation = all_activation.mean(0)
+         top_neurons = list(np.argsort(image_activation)[::-1][:5])
+         for top_neuron in top_neurons:
+             neuron_dict[top_neuron] = image_activation[top_neuron]

+         if evt is not None and evt._data is not None and isinstance(evt._data["index"], list):
+             image = _CACHE.get('data_dict', selected_image, {}).get("image")
+             if image:
+                 grid_x, grid_y, _, _ = get_grid_loc(evt, image)
                  token_idx = grid_y * GRID_NUM + grid_x + 1
                  tile_activations = all_activation[token_idx]
+                 top_tile_neurons = list(np.argsort(tile_activations)[::-1][:5])
+                 for top_neuron in top_tile_neurons:
+                     neuron_dict[top_neuron] = tile_activations[top_neuron]

+         return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

+     clip_neuron_dict = _get_top_activation(evt, selected_image, "CLIP", clip_neuron_dict)
+     maple_neuron_dict = _get_top_activation(evt, selected_image, model_name, maple_neuron_dict)

+     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
+     return gr.Radio(choices=radio_choices, label="Top activating SAE latent", value=radio_choices[0])

+ def update_markdown(option_value: str) -> Tuple[str, str]:
+     """Update markdown text based on selected option."""
      latent_idx = int(option_value.split("-")[-1])
      out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
      out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
      return out_1, out_2

+ def update_all(
+     selected_image: str,
+     slider_value: str,
+     toggle_btn: bool,
+     model_name: str
+ ) -> Tuple:
+     """Update all UI components."""
      (
          seg_mask_display,
          top_image_1,
          act_value_2,
          act_value_3,
      ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
+
      seg_mask_display_maple = show_activation_heatmap_maple(
          selected_image, slider_value, model_name
      )
          markdown_display_2,
      )

+ def monitor_memory_usage():
+     """Monitor memory usage and clean cache if necessary."""
+     process = psutil.Process()
+     mem_info = process.memory_info()
+     mem_percent = process.memory_percent()
+
+     print(f"""
+     Memory Usage:
+     - RSS: {mem_info.rss / (1024**2):.2f} MB
+     - VMS: {mem_info.vms / (1024**2):.2f} MB
+     - Percent: {mem_percent:.1f}%
+     - Cache Sizes: {[len(cache) for cache in _CACHE.data.values()]}
+     """)
+
+     if mem_percent > 80:
+         print("Memory usage too high, clearing caches...")
+         _CACHE.clear_category('segmasks')
+         _CACHE.clear_category('top_images')
+         _CACHE.clear_category('precomputed_activations')
+
+ def start_memory_monitor(interval: int = 300):
+     """Start periodic memory monitoring."""
+     monitor_memory_usage()
+     threading.Timer(interval, start_memory_monitor).start()
+
+ # Initialize the application
+ data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
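One caveat on the `threading.Timer` re-arm pattern in `start_memory_monitor`: each timer is a fresh non-daemon thread, which can keep the process alive at shutdown. A possible variant (a sketch, not part of this commit) that marks the timer as a daemon:

```python
import threading

def start_memory_monitor_daemon(interval: float = 300.0) -> None:
    """Re-arming monitor, as above, but on a daemon thread (sketch only)."""
    # monitor_memory_usage() would be called here, as in the app code.
    timer = threading.Timer(interval, start_memory_monitor_daemon, args=(interval,))
    timer.daemon = True  # don't block interpreter exit
    timer.start()
```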
  default_image_name = "christmas-imagenet"

+ # Create the Gradio interface
  with gr.Blocks(
      theme=gr.themes.Citrus(),
      css="""
      .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
+     .image-row img { width: auto; height: 50px; }
      """,
  ) as demo:
      with gr.Row():
          with gr.Column():
              gr.Markdown("## Select input image and patch on the image")
              image_selector = gr.Dropdown(
+                 choices=list(_CACHE.data['data_dict'].keys()),
                  value=default_image_name,
                  label="Select Image",
              )
              image_display = gr.Image(
+                 value=_CACHE.get('data_dict', default_image_name, {}).get("image"),
                  type="pil",
                  interactive=True,
              )

              image_selector.change(
+                 fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
                  inputs=image_selector,
                  outputs=image_display,
              )
              image_display.select(
+                 fn=highlight_grid,
+                 inputs=[image_selector],
+                 outputs=[image_display]
              )

          with gr.Column():
                  value=model_options[0],
                  label="Select adapted model (MaPLe)",
              )
+             init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
+             neuron_plot = gr.Plot(value=init_plot, show_label=False)

              image_selector.change(
                  fn=plot_activation_distribution,
                  outputs=neuron_plot,
              )
              model_selector.change(
+                 fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
+                 inputs=[image_selector],
+                 outputs=image_display,
              )
              model_selector.change(
                  fn=plot_activation_distribution,
      with gr.Row():
          with gr.Column():
              radio_names = get_init_radio_options(default_image_name, model_options[0])
+             feature_idx = radio_names[0].split("-")[-1]
              markdown_display = gr.Markdown(
+                 f"## Segmentation mask for the selected SAE latent - {feature_idx}"
              )
              init_seg, init_tops, init_values = show_activation_heatmap(
                  default_image_name, radio_names[0], "CLIP"
                  default_image_name, radio_names[0], model_options[0]
              )
              gr.Markdown("### Localize SAE latent activation using MaPLE")
+             seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)

          with gr.Column():
              gr.Markdown("## Top activating SAE latent index")
              radio_choices = gr.Radio(
                  choices=radio_names,
                  label="Top activating SAE latent",
                  value=radio_names[0],
              )
              toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
              markdown_display_2 = gr.Markdown(
+                 f"## Top reference images for the selected SAE latent - {feature_idx}"
              )

              gr.Markdown("### ImageNet")
+             top_image_1 = gr.Image(value=init_tops[0], type="pil", show_label=False)
              act_value_1 = gr.Markdown(init_values[0])

              gr.Markdown("### ImageNet-Sketch")
+             top_image_2 = gr.Image(value=init_tops[1], type="pil", show_label=False)
              act_value_2 = gr.Markdown(init_values[1])

              gr.Markdown("### Caltech101")
+             top_image_3 = gr.Image(value=init_tops[2], type="pil", show_label=False)
              act_value_3 = gr.Markdown(init_values[2])

+     # Event handlers
      image_display.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
      model_selector.change(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
      image_selector.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
+     radio_choices.change(
+         fn=update_all,
+         inputs=[image_selector, radio_choices, toggle_btn, model_selector],
+         outputs=[
+             seg_mask_display,
+             seg_mask_display_maple,
+             top_image_1,
+             top_image_2,
+             top_image_3,
+             act_value_1,
+             act_value_2,
+             act_value_3,
+             markdown_display,
+             markdown_display_2,
+         ],
+     )

+     toggle_btn.change(
+         fn=show_activation_heatmap_clip,
+         inputs=[image_selector, radio_choices, toggle_btn],
+         outputs=[
+             seg_mask_display,
+             top_image_1,
+             top_image_2,
+             top_image_3,
+             act_value_1,
+             act_value_2,
+             act_value_3,
+         ],
+     )
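The radio labels built by `get_radio_names` encode the latent index in their suffix (`common-12`, `CLIP-40`, `MaPLE-7` are illustrative values), and the handlers wired above recover it with `split("-")[-1]`, the same parsing used by `update_markdown` and `show_activation_heatmap`:

```python
# How the option strings produced by get_radio_names round-trip (illustrative).
for option_value in ["common-12", "CLIP-40", "MaPLE-7"]:
    latent_idx = int(option_value.split("-")[-1])  # same parsing as update_markdown
    print(option_value, "->", latent_idx)          # common-12 -> 12, etc.
```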
  if __name__ == "__main__":
+     # Initialize memory monitoring
+     start_memory_monitor()

      # Get system memory info
      mem = psutil.virtual_memory()
      total_ram_gb = mem.total / (1024**3)

+     try:
+         print("Starting application initialization...")
+
+         # Precompute common data
+         print("Precomputing activation patterns...")
+         for image_name in _CACHE.data['data_dict'].keys():
+             for model_name in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+                 try:
+                     activation = get_activation_distribution(image_name, model_name)
+                     cache_key = f"activation_{model_name}_{image_name}"
+                     _CACHE.set('precomputed_activations', cache_key, activation.mean(0))
+                 except Exception as e:
+                     print(f"Error precomputing activation for {image_name}, {model_name}: {e}")
+
+         print("Starting Gradio interface...")
+         # Launch the app with optimized settings
+         demo.queue(max_size=min(20, int(total_ram_gb)))
+         demo.launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             share=False,
+             show_error=True,
+             max_threads=min(16, psutil.cpu_count()),
+             websocket_ping_timeout=60,
+             preventive_refresh=True,
+             memory_limit_mb=int(total_ram_gb * 1024 * 0.8)  # Use up to 80% of RAM
+         )
+     except Exception as e:
+         print(f"Critical error during startup: {e}")
+         # Attempt to clean up resources
+         _CACHE.data.clear()
+         raise