hyesulim committed
Commit cabf670 · verified · 1 Parent(s): 9fc25a3

test: revert to base

Files changed (1):
  app.py +340 -436

app.py CHANGED
@@ -2,162 +2,56 @@ import gzip
  import os
  import pickle
  from glob import glob
- import threading
- import psutil
- from functools import lru_cache
- import concurrent.futures
- from typing import Dict, Tuple, List, Optional
  from time import sleep

  import gradio as gr
  import numpy as np
  import torch
  from PIL import Image, ImageDraw
- import plotly.graph_objects as go
  from plotly.subplots import make_subplots

- # Constants
  IMAGE_SIZE = 400
  DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
  GRID_NUM = 14
- PKL_ROOT = "./data/out"
-
- # Global cache with better type hints and error handling
- class Cache:
-     def __init__(self):
-         self.data: Dict[str, Dict] = {
-             'data_dict': {},
-             'sae_data_dict': {},
-             'model_data': {},
-             'segmasks': {},
-             'top_images': {},
-             'precomputed_activations': {}
-         }
-
-     def get(self, category: str, key: str, default=None):
-         try:
-             return self.data[category].get(key, default)
-         except KeyError:
-             return default
-
-     def set(self, category: str, key: str, value):
-         try:
-             self.data[category][key] = value
-         except KeyError:
-             self.data[category] = {key: value}
-
-     def clear_category(self, category: str):
-         if category in self.data:
-             self.data[category].clear()
-
- _CACHE = Cache()
-
- def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
-     """Load all data with optimized parallel processing."""
-     def load_image_file(image_file: str) -> Optional[Dict]:
-         try:
-             image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
-             return {
-                 "image": image,
-                 "image_path": image_file,
-             }
-         except Exception as e:
-             print(f"Error loading image {image_file}: {e}")
-             return None
-
-     # Load images in parallel
-     with concurrent.futures.ThreadPoolExecutor() as executor:
-         future_to_file = {
-             executor.submit(load_image_file, image_file): image_file
-             for image_file in glob(f"{image_root}/*")
-         }
-
-         for future in concurrent.futures.as_completed(future_to_file):
-             try:
-                 image_file = future_to_file[future]
-                 image_name = os.path.basename(image_file).split(".")[0]
-                 result = future.result()
-                 if result:
-                     _CACHE.set('data_dict', image_name, result)
-             except Exception as e:
-                 print(f"Error processing image future: {e}")
-
-     # Load SAE data
-     try:
-         with open("./data/sae_data/mean_acts.pkl", "rb") as f:
-             _CACHE.set('sae_data_dict', "mean_acts", pickle.load(f))
-     except Exception as e:
-         print(f"Error loading mean_acts.pkl: {e}")
-
-     # Load mean act values
-     datasets = ["imagenet", "imagenet-sketch", "caltech101"]
-     for dataset in datasets:
-         try:
-             with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
-                 if "mean_act_values" not in _CACHE.data['sae_data_dict']:
-                     _CACHE.set('sae_data_dict', "mean_act_values", {})
-                 _CACHE.data['sae_data_dict']["mean_act_values"][dataset] = pickle.load(f)
-         except Exception as e:
-             print(f"Error loading mean act values for {dataset}: {e}")
-
-     return _CACHE.data['data_dict'], _CACHE.data['sae_data_dict']
-
- @lru_cache(maxsize=1024)
- def get_data(image_name: str, model_name: str) -> np.ndarray:
-     """Get model data with caching."""
-     cache_key = f"{model_name}_{image_name}"
-     if cache_key not in _CACHE.data['model_data']:
-         try:
-             data_dir = f"{PKL_ROOT}/{model_name}/{image_name}.pkl.gz"
-             with gzip.open(data_dir, "rb") as f:
-                 _CACHE.data['model_data'][cache_key] = pickle.load(f)
-         except Exception as e:
-             print(f"Error loading model data for {cache_key}: {e}")
-             return np.array([])
-     return _CACHE.data['model_data'][cache_key]
-
- @lru_cache(maxsize=1024)
- def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
-     """Get activation distribution with memory optimization."""
-     try:
-         data = get_data(image_name, model_type)
-         if isinstance(data, (list, tuple)):
-             activation = data[0]
-         else:
-             activation = data
-
-         if not isinstance(activation, np.ndarray):
-             activation = np.array(activation)
-
-         mean_acts = _CACHE.get('sae_data_dict', "mean_acts", {}).get("imagenet", np.array([]))
-
-         if mean_acts.size > 0 and activation.size > 0:
-             noisy_features_indices = np.where(mean_acts > 0.1)[0]
-             if activation.ndim >= 2:
-                 activation[:, noisy_features_indices] = 0
-
-         return activation
-     except Exception as e:
-         print(f"Error getting activation distribution: {e}")
-         return np.array([])

- def get_grid_loc(evt: gr.SelectData, image: Image.Image) -> Tuple[int, int, int, int]:
-     """Get grid location from click event."""
-     x, y = evt.index[0], evt.index[1]
      cell_width = image.width // GRID_NUM
      cell_height = image.height // GRID_NUM
      grid_x = x // cell_width
      grid_y = y // cell_height
      return grid_x, grid_y, cell_width, cell_height

- def highlight_grid(evt: gr.SelectData, image_name: str) -> Image.Image:
-     """Highlight selected grid cell."""
-     image = _CACHE.get('data_dict', image_name, {}).get("image")
-     if not image:
-         return None
-
      grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-
      highlighted_image = image.copy()
      draw = ImageDraw.Draw(highlighted_image)
      box = [
@@ -167,18 +61,25 @@ def highlight_grid(evt: gr.SelectData, image_name: str) -> Image.Image:
          (grid_y + 1) * cell_height,
      ]
      draw.rectangle(box, outline="red", width=3)
      return highlighted_image

  def plot_activations(
-     all_activation: np.ndarray,
-     tile_activations: Optional[np.ndarray] = None,
-     grid_x: Optional[int] = None,
-     grid_y: Optional[int] = None,
-     top_k: int = 5,
-     colors: Tuple[str, str] = ("blue", "cyan"),
-     model_name: str = "CLIP",
- ) -> go.Figure:
-     """Plot activation distributions."""
      fig = go.Figure()

      def _add_scatter_with_annotation(fig, activations, model_name, color, label):
@@ -207,90 +108,59 @@ def plot_activations(
          )
          return fig

-     label = f"{model_name.split('-')[-1]} Image-level"
-     fig = _add_scatter_with_annotation(fig, all_activation, model_name, colors[0], label)
-
      if tile_activations is not None:
-         label = f"{model_name.split('-')[-1]} Tile ({grid_x}, {grid_y})"
-         fig = _add_scatter_with_annotation(fig, tile_activations, model_name, colors[1], label)

      fig.update_layout(
          title="Activation Distribution",
          xaxis_title="SAE latent index",
          yaxis_title="Activation Value",
          template="plotly_white",
          legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
      )

      return fig

- def get_segmask(selected_image: str, slider_value: int, model_type: str) -> Optional[np.ndarray]:
-     """Get segmentation mask with caching."""
-     cache_key = f"{selected_image}_{slider_value}_{model_type}"
-     cached_mask = _CACHE.get('segmasks', cache_key)
-     if cached_mask is not None:
-         return cached_mask
-
-     try:
-         image = _CACHE.get('data_dict', selected_image, {}).get("image")
-         if image is None:
-             return None
-
-         sae_act = get_data(selected_image, model_type)[0]
-         temp = sae_act[:, slider_value]
-
-         mask = torch.tensor(temp[1:].reshape(14, 14)).view(1, 1, 14, 14)
-         mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
-
-         if mask.size == 0:
-             return None
-
-         mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
-
-         base_opacity = 30
-         image_array = np.array(image)[..., :3]
-         rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
-         rgba_overlay[..., :3] = image_array
-
-         darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
-         rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
-         rgba_overlay[..., 3] = 255
-
-         _CACHE.set('segmasks', cache_key, rgba_overlay)
-         return rgba_overlay
-
-     except Exception as e:
-         print(f"Error generating segmentation mask: {e}")
-         return None
-
- def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
-     """Get top images with caching."""
-     cache_key = f"{slider_value}_{toggle_btn}"
-     cached_images = _CACHE.get('top_images', cache_key)
-     if cached_images is not None:
-         return cached_images
-
-     dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"
-     paths = [
-         os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
-         for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
-     ]
-
-     images = [
-         Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
-         for path in paths
-     ]
-
-     _CACHE.set('top_images', cache_key, images)
-     return images

- # UI Event Handlers
  def plot_activation_distribution(
-     evt: Optional[gr.EventData],
-     selected_image: str,
-     model_name: str
- ) -> go.Figure:
-     """Plot activation distributions for both models."""
      fig = make_subplots(
          rows=2,
          cols=1,
@@ -298,37 +168,17 @@ def plot_activation_distribution(
          subplot_titles=["CLIP Activation", f"{model_name} Activation"],
      )

-     def get_activations(evt, selected_image, model_name, colors):
-         activation = get_activation_distribution(selected_image, model_name)
-         all_activation = activation.mean(0)
-
-         tile_activations = None
-         grid_x = None
-         grid_y = None
-
-         if evt is not None and evt._data is not None:
-             image = _CACHE.get('data_dict', selected_image, {}).get("image")
-             if image:
-                 grid_x, grid_y, _, _ = get_grid_loc(evt, image)
-                 token_idx = grid_y * GRID_NUM + grid_x + 1
-                 tile_activations = activation[token_idx]
-
-         return plot_activations(
-             all_activation,
-             tile_activations,
-             grid_x,
-             grid_y,
-             top_k=5,
-             model_name=model_name,
-             colors=colors,
-         )
-
-     fig_clip = get_activations(evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef"))
-     fig_maple = get_activations(evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4"))

      def _attach_fig(fig, sub_fig, row, col, yref):
          for trace in sub_fig.data:
              fig.add_trace(trace, row=row, col=col)
          for annotation in sub_fig.layout.annotations:
              annotation.update(yref=yref)
              fig.add_annotation(annotation)
@@ -342,6 +192,8 @@ def plot_activation_distribution(
      fig.update_yaxes(title_text="Activation Value", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=2, col=1)
      fig.update_layout(
          template="plotly_white",
          showlegend=True,
          legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
@@ -350,12 +202,73 @@ def plot_activation_distribution(

      return fig

- def show_activation_heatmap_clip(
-     selected_image: str,
-     slider_value: str,
-     toggle_btn: bool
- ):
-     """Show activation heatmap for CLIP model."""
      rgba_overlay, top_images, act_values = show_activation_heatmap(
          selected_image, slider_value, "CLIP", toggle_btn
      )
@@ -370,68 +283,49 @@ def show_activation_heatmap_clip(
          act_values[2],
      )

- def show_activation_heatmap(
-     selected_image: str,
-     slider_value: str,
-     model_type: str,
-     toggle_btn: bool = False
- ) -> Tuple[np.ndarray, List[Image.Image], List[str]]:
-     """Show activation heatmap with segmentation mask and top images."""
-     slider_value = int(slider_value.split("-")[-1])
-     rgba_overlay = get_segmask(selected_image, slider_value, model_type)
-     top_images = get_top_images(slider_value, toggle_btn)
-
-     act_values = []
-     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-         act_value = _CACHE.get('sae_data_dict', "mean_act_values", {}).get(dataset, np.array([]))[slider_value, :5]
-         act_value = [str(round(value, 3)) for value in act_value]
-         act_value = " | ".join(act_value)
-         out = f"#### Activation values: {act_value}"
-         act_values.append(out)
-
-     return rgba_overlay, top_images, act_values

- def show_activation_heatmap_maple(
-     selected_image: str,
-     slider_value: str,
-     model_name: str
- ) -> np.ndarray:
-     """Show activation heatmap for MaPLE model."""
      slider_value = int(slider_value.split("-")[-1])
      rgba_overlay = get_segmask(selected_image, slider_value, model_name)
      sleep(0.1)
      return rgba_overlay

- def get_init_radio_options(selected_image: str, model_name: str) -> List[str]:
-     """Get initial radio options for UI."""
      clip_neuron_dict = {}
      maple_neuron_dict = {}

-     def _get_top_activation(selected_image: str, model_name: str, neuron_dict: Dict, top_k: int = 5) -> Dict:
          activations = get_activation_distribution(selected_image, model_name).mean(0)
          top_neurons = list(np.argsort(activations)[::-1][:top_k])
          for top_neuron in top_neurons:
              neuron_dict[top_neuron] = activations[top_neuron]
-         return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

-     clip_neuron_dict = _get_top_activation(selected_image, "CLIP", clip_neuron_dict)
-     maple_neuron_dict = _get_top_activation(selected_image, model_name, maple_neuron_dict)

-     return get_radio_names(clip_neuron_dict, maple_neuron_dict)

- def get_radio_names(
-     clip_neuron_dict: Dict[int, float],
-     maple_neuron_dict: Dict[int, float]
- ) -> List[str]:
-     """Generate radio button names based on neuron activations."""
      clip_keys = list(clip_neuron_dict.keys())
      maple_keys = list(maple_neuron_dict.keys())

      common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-     clip_only_keys = list(set(clip_keys) - set(maple_keys))
-     maple_only_keys = list(set(maple_keys) - set(clip_keys))

-     common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
      clip_only_keys.sort(reverse=True)
      maple_only_keys.sort(reverse=True)
@@ -442,54 +336,81 @@ def get_radio_names(

      return out

- def update_radio_options(
-     evt: Optional[gr.EventData],
-     selected_image: str,
-     model_name: str
- ) -> gr.Radio:
-     """Update radio options based on user interaction."""
-     clip_neuron_dict = {}
-     maple_neuron_dict = {}

-     def _get_top_activation(evt, selected_image, model_name, neuron_dict):
          all_activation = get_activation_distribution(selected_image, model_name)
          image_activation = all_activation.mean(0)
-         top_neurons = list(np.argsort(image_activation)[::-1][:5])
-         for top_neuron in top_neurons:
-             neuron_dict[top_neuron] = image_activation[top_neuron]

-         if evt is not None and evt._data is not None and isinstance(evt._data["index"], list):
-             image = _CACHE.get('data_dict', selected_image, {}).get("image")
-             if image:
-                 grid_x, grid_y, _, _ = get_grid_loc(evt, image)
                  token_idx = grid_y * GRID_NUM + grid_x + 1
                  tile_activations = all_activation[token_idx]
-                 top_tile_neurons = list(np.argsort(tile_activations)[::-1][:5])
-                 for top_neuron in top_tile_neurons:
-                     neuron_dict[top_neuron] = tile_activations[top_neuron]

-         return dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))

-     clip_neuron_dict = _get_top_activation(evt, selected_image, "CLIP", clip_neuron_dict)
-     maple_neuron_dict = _get_top_activation(evt, selected_image, model_name, maple_neuron_dict)

-     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
-     return gr.Radio(choices=radio_choices, label="Top activating SAE latent", value=radio_choices[0])

- def update_markdown(option_value: str) -> Tuple[str, str]:
-     """Update markdown text based on selected option."""
      latent_idx = int(option_value.split("-")[-1])
      out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
      out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
      return out_1, out_2

- def update_all(
-     selected_image: str,
-     slider_value: str,
-     toggle_btn: bool,
-     model_name: str
- ) -> Tuple:
-     """Update all UI components."""
      (
          seg_mask_display,
          top_image_1,
@@ -499,7 +420,6 @@ def update_all(
          act_value_2,
          act_value_3,
      ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
-
      seg_mask_display_maple = show_activation_heatmap_maple(
          selected_image, slider_value, model_name
      )
@@ -518,66 +438,66 @@ def update_all(
          markdown_display_2,
      )

- def monitor_memory_usage():
-     """Monitor memory usage and clean cache if necessary."""
-     process = psutil.Process()
-     mem_info = process.memory_info()
-     mem_percent = process.memory_percent()
-
-     print(f"""
-     Memory Usage:
-     - RSS: {mem_info.rss / (1024**2):.2f} MB
-     - VMS: {mem_info.vms / (1024**2):.2f} MB
-     - Percent: {mem_percent:.1f}%
-     - Cache Sizes: {[len(cache) for cache in _CACHE.data.values()]}
-     """)
-
-     if mem_percent > 80:
-         print("Memory usage too high, clearing caches...")
-         _CACHE.clear_category('segmasks')
-         _CACHE.clear_category('top_images')
-         _CACHE.clear_category('precomputed_activations')
-
- def start_memory_monitor(interval: int = 300):
-     """Start periodic memory monitoring."""
-     monitor_memory_usage()
-     threading.Timer(interval, start_memory_monitor).start()
-
- # Initialize the application
- data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
  default_image_name = "christmas-imagenet"

- # Create the Gradio interface
  with gr.Blocks(
      theme=gr.themes.Citrus(),
      css="""
      .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
-     .image-row img { width: auto; height: 50px; }
      """,
  ) as demo:
      with gr.Row():
          with gr.Column():
              gr.Markdown("## Select input image and patch on the image")
              image_selector = gr.Dropdown(
-                 choices=list(_CACHE.data['data_dict'].keys()),
                  value=default_image_name,
                  label="Select Image",
              )
              image_display = gr.Image(
-                 value=_CACHE.get('data_dict', default_image_name, {}).get("image"),
                  type="pil",
                  interactive=True,
              )

              image_selector.change(
-                 fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
                  inputs=image_selector,
                  outputs=image_display,
              )
              image_display.select(
-                 fn=highlight_grid,
-                 inputs=[image_selector],
-                 outputs=[image_display]
              )

          with gr.Column():
@@ -588,8 +508,12 @@ with gr.Blocks(
                  value=model_options[0],
                  label="Select adapted model (MaPLe)",
              )
-             init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
-             neuron_plot = gr.Plot(value=init_plot, show_label=False)

              image_selector.change(
                  fn=plot_activation_distribution,
@@ -602,9 +526,7 @@ with gr.Blocks(
                  outputs=neuron_plot,
              )
              model_selector.change(
-                 fn=lambda img_name: _CACHE.get('data_dict', img_name, {}).get("image"),
-                 inputs=[image_selector],
-                 outputs=image_display,
              )
              model_selector.change(
                  fn=plot_activation_distribution,
@@ -615,9 +537,10 @@ with gr.Blocks(
      with gr.Row():
          with gr.Column():
              radio_names = get_init_radio_options(default_image_name, model_options[0])
-             feature_idx = radio_names[0].split("-")[-1]
              markdown_display = gr.Markdown(
-                 f"## Segmentation mask for the selected SAE latent - {feature_idx}"
              )
              init_seg, init_tops, init_values = show_activation_heatmap(
                  default_image_name, radio_names[0], "CLIP"
@@ -629,10 +552,13 @@ with gr.Blocks(
                  default_image_name, radio_names[0], model_options[0]
              )
              gr.Markdown("### Localize SAE latent activation using MaPLE")
-             seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)

          with gr.Column():
              gr.Markdown("## Top activating SAE latent index")
              radio_choices = gr.Radio(
                  choices=radio_names,
                  label="Top activating SAE latent",
@@ -640,103 +566,81 @@ with gr.Blocks(
                  value=radio_names[0],
              )
              toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
              markdown_display_2 = gr.Markdown(
-                 f"## Top reference images for the selected SAE latent - {feature_idx}"
              )

              gr.Markdown("### ImageNet")
-             top_image_1 = gr.Image(value=init_tops[0], type="pil", show_label=False)
              act_value_1 = gr.Markdown(init_values[0])

              gr.Markdown("### ImageNet-Sketch")
-             top_image_2 = gr.Image(value=init_tops[1], type="pil", show_label=False)
              act_value_2 = gr.Markdown(init_values[1])

              gr.Markdown("### Caltech101")
-             top_image_3 = gr.Image(value=init_tops[2], type="pil", show_label=False)
              act_value_3 = gr.Markdown(init_values[2])

-     # Event handlers
      image_display.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
      model_selector.change(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
      image_selector.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
-     radio_choices.change(
-         fn=update_all,
-         inputs=[image_selector, radio_choices, toggle_btn, model_selector],
-         outputs=[
-             seg_mask_display,
-             seg_mask_display_maple,
-             top_image_1,
-             top_image_2,
-             top_image_3,
-             act_value_1,
-             act_value_2,
-             act_value_3,
-             markdown_display,
-             markdown_display_2,
-         ],
-     )

-     toggle_btn.change(
-         fn=show_activation_heatmap_clip,
-         inputs=[image_selector, radio_choices, toggle_btn],
-         outputs=[
-             seg_mask_display,
-             top_image_1,
-             top_image_2,
-             top_image_3,
-             act_value_1,
-             act_value_2,
-             act_value_3,
-         ],
-     )

- if __name__ == "__main__":
-     # Initialize memory monitoring
-     start_memory_monitor()
-
-     # Get system memory info
-     mem = psutil.virtual_memory()
-     total_ram_gb = mem.total / (1024**3)
-
-     try:
-         print("Starting application initialization...")
-
-         # Precompute common data
-         print("Precomputing activation patterns...")
-         for image_name in _CACHE.data['data_dict'].keys():
-             for model_name in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
-                 try:
-                     activation = get_activation_distribution(image_name, model_name)
-                     cache_key = f"activation_{model_name}_{image_name}"
-                     _CACHE.set('precomputed_activations', cache_key, activation.mean(0))
-                 except Exception as e:
-                     print(f"Error precomputing activation for {image_name}, {model_name}: {e}")
-
-         print("Starting Gradio interface...")
-         # Launch the app with optimized settings
-         demo.queue(max_size=min(20, int(total_ram_gb)))
-         demo.launch(
-             server_name="0.0.0.0",
-             server_port=7860,
-             share=False,
-             show_error=True,
-             max_threads=min(16, psutil.cpu_count())
          )
-     except Exception as e:
-         print(f"Critical error during startup: {e}")
-         # Attempt to clean up resources
-         _CACHE.data.clear()
-         raise
  import os
  import pickle
  from glob import glob
  from time import sleep

  import gradio as gr
  import numpy as np
+ import plotly.graph_objects as go
  import torch
  from PIL import Image, ImageDraw
  from plotly.subplots import make_subplots

  IMAGE_SIZE = 400
  DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
  GRID_NUM = 14
+ pkl_root = "./data/out"
+ preloaded_data = {}
+
+
+ def preload_activation(image_name):
+     for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+         image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
+         with gzip.open(image_file, "rb") as f:
+             preloaded_data[model] = pickle.load(f)
+
+
+ def get_activation_distribution(image_name: str, model_type: str):
+     activation = get_data(image_name, model_type)[0]
+
+     noisy_features_indices = (
+         (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
+     )
+     activation[:, noisy_features_indices] = 0
+
+     return activation
+
+
+ def get_grid_loc(evt, image):
+     # Get click coordinates
+     x, y = evt._data["index"][0], evt._data["index"][1]

      cell_width = image.width // GRID_NUM
      cell_height = image.height // GRID_NUM
+
      grid_x = x // cell_width
      grid_y = y // cell_height
      return grid_x, grid_y, cell_width, cell_height

+
+ def highlight_grid(evt: gr.EventData, image_name):
+     image = data_dict[image_name]["image"]
      grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+
      highlighted_image = image.copy()
      draw = ImageDraw.Draw(highlighted_image)
      box = [

          (grid_y + 1) * cell_height,
      ]
      draw.rectangle(box, outline="red", width=3)
+
      return highlighted_image

+
+ def load_image(img_name):
+     return Image.open(data_dict[img_name]["image_path"]).resize(
+         (IMAGE_SIZE, IMAGE_SIZE)
+     )
+
+
  def plot_activations(
+     all_activation,
+     tile_activations=None,
+     grid_x=None,
+     grid_y=None,
+     top_k=5,
+     colors=("blue", "cyan"),
+     model_name="CLIP",
+ ):
      fig = go.Figure()

      def _add_scatter_with_annotation(fig, activations, model_name, color, label):

  )
109
  return fig
110
 
111
+ label = f"{model_name.split('-')[-0]} Image-level"
112
+ fig = _add_scatter_with_annotation(
113
+ fig, all_activation, model_name, colors[0], label
114
+ )
115
  if tile_activations is not None:
116
+ label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
117
+ fig = _add_scatter_with_annotation(
118
+ fig, tile_activations, model_name, colors[1], label
119
+ )
120
 
121
  fig.update_layout(
122
  title="Activation Distribution",
123
  xaxis_title="SAE latent index",
124
  yaxis_title="Activation Value",
125
  template="plotly_white",
126
+ )
127
+ fig.update_layout(
128
  legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
129
  )
130
 
131
  return fig
132
 
 
 
 
 
 
 
133
 
134
+ def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
+     activation = get_activation_distribution(selected_image, model_name)
+     all_activation = activation.mean(0)
+
+     tile_activations = None
+     grid_x = None
+     grid_y = None
+
+     if evt is not None:
+         if evt._data is not None:
+             image = data_dict[selected_image]["image"]
+             grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
+             token_idx = grid_y * GRID_NUM + grid_x + 1
+             tile_activations = activation[token_idx]
+
+     fig = plot_activations(
+         all_activation,
+         tile_activations,
+         grid_x,
+         grid_y,
+         top_k=5,
+         model_name=model_name,
+         colors=colors,
+     )
+     return fig
+

  def plot_activation_distribution(
+     evt: gr.EventData, selected_image: str, model_name: str
+ ):
      fig = make_subplots(
          rows=2,
          cols=1,
          subplot_titles=["CLIP Activation", f"{model_name} Activation"],
      )

+     fig_clip = get_activations(
+         evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
+     )
+     fig_maple = get_activations(
+         evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
+     )

      def _attach_fig(fig, sub_fig, row, col, yref):
          for trace in sub_fig.data:
              fig.add_trace(trace, row=row, col=col)
+
          for annotation in sub_fig.layout.annotations:
              annotation.update(yref=yref)
              fig.add_annotation(annotation)

      fig.update_yaxes(title_text="Activation Value", row=1, col=1)
      fig.update_yaxes(title_text="Activation Value", row=2, col=1)
      fig.update_layout(
+         # height=500,
+         # title="Activation Distributions",
          template="plotly_white",
          showlegend=True,
          legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),

      return fig

+
+ def get_segmask(selected_image, slider_value, model_type):
+     image = data_dict[selected_image]["image"]
+     sae_act = get_data(selected_image, model_type)[0]
+     temp = sae_act[:, slider_value]
+     try:
+         mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
+     except Exception as e:
+         print(sae_act.shape, slider_value)
+     mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
+         0
+     ].numpy()
+     mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+     base_opacity = 30
+     image_array = np.array(image)[..., :3]
+     rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+     rgba_overlay[..., :3] = image_array[..., :3]
+
+     darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
+     rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+     rgba_overlay[..., 3] = 255  # Fully opaque
+
+     return rgba_overlay
+
+
+ def get_top_images(slider_value, toggle_btn):
+     def _get_images(dataset_path):
+         top_image_paths = [
+             os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
+             os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
+             os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
+         ]
+         top_images = [
+             (
+                 Image.open(path)
+                 if os.path.exists(path)
+                 else Image.new("RGB", (256, 256), (255, 255, 255))
+             )
+             for path in top_image_paths
+         ]
+         return top_images
+
+     if toggle_btn:
+         top_images = _get_images("./data/top_images_masked")
+     else:
+         top_images = _get_images("./data/top_images")
+     return top_images
+
+
+ def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
+     slider_value = int(slider_value.split("-")[-1])
+     rgba_overlay = get_segmask(selected_image, slider_value, model_type)
+     top_images = get_top_images(slider_value, toggle_btn)
+
+     act_values = []
+     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+         act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
+         act_value = [str(round(value, 3)) for value in act_value]
+         act_value = " | ".join(act_value)
+         out = f"#### Activation values: {act_value}"
+         act_values.append(out)
+
+     return rgba_overlay, top_images, act_values
+
+
+ def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
      rgba_overlay, top_images, act_values = show_activation_heatmap(
          selected_image, slider_value, "CLIP", toggle_btn
      )

          act_values[2],
      )

+ def show_activation_heatmap_maple(selected_image, slider_value, model_name):
      slider_value = int(slider_value.split("-")[-1])
      rgba_overlay = get_segmask(selected_image, slider_value, model_name)
      sleep(0.1)
      return rgba_overlay

+
+ def get_init_radio_options(selected_image, model_name):
      clip_neuron_dict = {}
      maple_neuron_dict = {}

+     def _get_top_activation(selected_image, model_name, neuron_dict, top_k=5):
          activations = get_activation_distribution(selected_image, model_name).mean(0)
          top_neurons = list(np.argsort(activations)[::-1][:top_k])
          for top_neuron in top_neurons:
              neuron_dict[top_neuron] = activations[top_neuron]
+         sorted_dict = dict(
+             sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
+         )
+         return sorted_dict

+     clip_neuron_dict = _get_top_activation(selected_image, "CLIP", clip_neuron_dict)
+     maple_neuron_dict = _get_top_activation(
+         selected_image, model_name, maple_neuron_dict
+     )
+
+     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)

+     return radio_choices

+
+ def get_radio_names(clip_neuron_dict, maple_neuron_dict):
      clip_keys = list(clip_neuron_dict.keys())
      maple_keys = list(maple_neuron_dict.keys())

      common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
+     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))

+     common_keys.sort(
+         key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
+     )
      clip_only_keys.sort(reverse=True)
      maple_only_keys.sort(reverse=True)

      return out

+ def update_radio_options(evt: gr.EventData, selected_image, model_name):
+     def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
+         top_neurons = list(np.argsort(activations)[::-1][:top_k])
+         for top_neuron in top_neurons:
+             neuron_dict[top_neuron] = activations[top_neuron]
+
+     def _get_top_activation(evt, selected_image, model_name, neuron_dict):
          all_activation = get_activation_distribution(selected_image, model_name)
          image_activation = all_activation.mean(0)
+         _sort_and_save_top_k(image_activation, neuron_dict)

+         if evt is not None:
+             if evt._data is not None and isinstance(evt._data["index"], list):
+                 image = data_dict[selected_image]["image"]
+                 grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
                  token_idx = grid_y * GRID_NUM + grid_x + 1
                  tile_activations = all_activation[token_idx]
+                 _sort_and_save_top_k(tile_activations, neuron_dict)

+         sorted_dict = dict(
+             sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
+         )
+         return sorted_dict
+
+     clip_neuron_dict = {}
+     maple_neuron_dict = {}
+     clip_neuron_dict = _get_top_activation(evt, selected_image, "CLIP", clip_neuron_dict)
+     maple_neuron_dict = _get_top_activation(
+         evt, selected_image, model_name, maple_neuron_dict
+     )

+     clip_keys = list(clip_neuron_dict.keys())
+     maple_keys = list(maple_neuron_dict.keys())
+
+     common_keys = list(set(clip_keys).intersection(set(maple_keys)))
+     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
+     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
+
+     common_keys.sort(
+         key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
+     )
+     clip_only_keys.sort(reverse=True)
+     maple_only_keys.sort(reverse=True)
+
+     out = []
+     out.extend([f"common-{i}" for i in common_keys[:5]])
+     out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
+     out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
+
+     radio_choices = gr.Radio(
+         choices=out, label="Top activating SAE latent", value=out[0]
+     )
+     sleep(0.1)
+     return radio_choices

+ def update_markdown(option_value):
      latent_idx = int(option_value.split("-")[-1])
      out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
      out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
      return out_1, out_2

+
+ def get_data(image_name, model_name):
+     pkl_root = "./data/out"
+     data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
+     with gzip.open(data_dir, "rb") as f:
+         data = pickle.load(f)
+     out = data
+
+     return out
+
+
+ def update_all(selected_image, slider_value, toggle_btn, model_name):
      (
          seg_mask_display,
          top_image_1,

          act_value_2,
          act_value_3,
      ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
      seg_mask_display_maple = show_activation_heatmap_maple(
          selected_image, slider_value, model_name
      )

          markdown_display_2,
      )

+
+ def load_all_data(image_root, pkl_root):
+     image_files = glob(f"{image_root}/*")
+     data_dict = {}
+     for image_file in image_files:
+         image_name = os.path.basename(image_file).split(".")[0]
+         if image_file not in data_dict:
+             data_dict[image_name] = {
+                 "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
+                 "image_path": image_file,
+             }
+
+     sae_data_dict = {}
+     with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+         data = pickle.load(f)
+         sae_data_dict["mean_acts"] = data
+
+     sae_data_dict["mean_act_values"] = {}
+     for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
+         with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+             data = pickle.load(f)
+             sae_data_dict["mean_act_values"][dataset] = data
+
+     return data_dict, sae_data_dict
+
+
+ data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
  default_image_name = "christmas-imagenet"

+
  with gr.Blocks(
      theme=gr.themes.Citrus(),
      css="""
      .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
+     .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
      """,
  ) as demo:
      with gr.Row():
          with gr.Column():
+             # Left View: Image selection and click handling
              gr.Markdown("## Select input image and patch on the image")
              image_selector = gr.Dropdown(
+                 choices=list(data_dict.keys()),
                  value=default_image_name,
                  label="Select Image",
              )
              image_display = gr.Image(
+                 value=data_dict[default_image_name]["image"],
                  type="pil",
                  interactive=True,
              )

+             # Update image display when a new image is selected
              image_selector.change(
+                 fn=lambda img_name: data_dict[img_name]["image"],
                  inputs=image_selector,
                  outputs=image_display,
              )
              image_display.select(
+                 fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
              )

          with gr.Column():

                  value=model_options[0],
                  label="Select adapted model (MaPLe)",
              )
+             init_plot = plot_activation_distribution(
+                 None, default_image_name, model_options[0]
+             )
+             neuron_plot = gr.Plot(
+                 label="Neuron Activation", value=init_plot, show_label=False
+             )

              image_selector.change(
                  fn=plot_activation_distribution,

                  outputs=neuron_plot,
              )
              model_selector.change(
+                 fn=load_image, inputs=[image_selector], outputs=image_display
              )
              model_selector.change(
                  fn=plot_activation_distribution,

      with gr.Row():
          with gr.Column():
              radio_names = get_init_radio_options(default_image_name, model_options[0])
+
+             feature_idx = radio_names[0].split("-")[-1]
              markdown_display = gr.Markdown(
+                 f"## Segmentation mask for the selected SAE latent - {feature_idx}"
              )
              init_seg, init_tops, init_values = show_activation_heatmap(
                  default_image_name, radio_names[0], "CLIP"

                  default_image_name, radio_names[0], model_options[0]
              )
              gr.Markdown("### Localize SAE latent activation using MaPLE")
+             seg_mask_display_maple = gr.Image(
+                 value=init_seg_maple, type="pil", show_label=False
+             )

          with gr.Column():
              gr.Markdown("## Top activating SAE latent index")
+
              radio_choices = gr.Radio(
                  choices=radio_names,
                  label="Top activating SAE latent",
                  value=radio_names[0],
              )
              toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
+
              markdown_display_2 = gr.Markdown(
+                 f"## Top reference images for the selected SAE latent - {feature_idx}"
              )

              gr.Markdown("### ImageNet")
+             top_image_1 = gr.Image(
+                 value=init_tops[0], type="pil", label="ImageNet", show_label=False
+             )
              act_value_1 = gr.Markdown(init_values[0])

              gr.Markdown("### ImageNet-Sketch")
+             top_image_2 = gr.Image(
+                 value=init_tops[1],
+                 type="pil",
+                 label="ImageNet-Sketch",
+                 show_label=False,
+             )
              act_value_2 = gr.Markdown(init_values[1])

              gr.Markdown("### Caltech101")
+             top_image_3 = gr.Image(
+                 value=init_tops[2], type="pil", label="Caltech101", show_label=False
+             )
              act_value_3 = gr.Markdown(init_values[2])

      image_display.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
+
      model_selector.change(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )
+
      image_selector.select(
          fn=update_radio_options,
          inputs=[image_selector, model_selector],
          outputs=[radio_choices],
      )

+     radio_choices.change(
+         fn=update_all,
+         inputs=[image_selector, radio_choices, toggle_btn, model_selector],
+         outputs=[
+             seg_mask_display,
+             seg_mask_display_maple,
+             top_image_1,
+             top_image_2,
+             top_image_3,
+             act_value_1,
+             act_value_2,
+             act_value_3,
+             markdown_display,
+             markdown_display_2,
+         ],
+     )

+     toggle_btn.change(
+         fn=show_activation_heatmap_clip,
+         inputs=[image_selector, radio_choices, toggle_btn],
+         outputs=[
+             seg_mask_display,
+             top_image_1,
+             top_image_2,
+             top_image_3,
+             act_value_1,
+             act_value_2,
+             act_value_3,
+         ],
      )
+
+ # Launch the app
+ # demo.queue()
+ demo.launch()