hyesulim commited on
Commit
cfec577
·
verified ·
1 Parent(s): 928847c

test: tried merging radio_choices.change

Browse files
Files changed (1) hide show
  1. app.py +214 -72
app.py CHANGED
@@ -28,7 +28,9 @@ def preload_activation(image_name):
28
  def get_activation_distribution(image_name: str, model_type: str):
29
  activation = get_data(image_name, model_type)[0]
30
 
31
- noisy_features_indices = (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
 
 
32
  activation[:, noisy_features_indices] = 0
33
 
34
  return activation
@@ -52,18 +54,31 @@ def highlight_grid(evt: gr.EventData, image_name):
52
 
53
  highlighted_image = image.copy()
54
  draw = ImageDraw.Draw(highlighted_image)
55
- box = [grid_x * cell_width, grid_y * cell_height, (grid_x + 1) * cell_width, (grid_y + 1) * cell_height]
 
 
 
 
 
56
  draw.rectangle(box, outline="red", width=3)
57
 
58
  return highlighted_image
59
 
60
 
61
  def load_image(img_name):
62
- return Image.open(data_dict[img_name]["image_path"]).resize((IMAGE_SIZE, IMAGE_SIZE))
 
 
63
 
64
 
65
  def plot_activations(
66
- all_activation, tile_activations=None, grid_x=None, grid_y=None, top_k=5, colors=("blue", "cyan"), model_name="CLIP"
 
 
 
 
 
 
67
  ):
68
  fig = go.Figure()
69
 
@@ -94,10 +109,14 @@ def plot_activations(
94
  return fig
95
 
96
  label = f"{model_name.split('-')[-0]} Image-level"
97
- fig = _add_scatter_with_annotation(fig, all_activation, model_name, colors[0], label)
 
 
98
  if tile_activations is not None:
99
  label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
100
- fig = _add_scatter_with_annotation(fig, tile_activations, model_name, colors[1], label)
 
 
101
 
102
  fig.update_layout(
103
  title="Activation Distribution",
@@ -105,7 +124,9 @@ def plot_activations(
105
  yaxis_title="Activation Value",
106
  template="plotly_white",
107
  )
108
- fig.update_layout(legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5))
 
 
109
 
110
  return fig
111
 
@@ -126,12 +147,20 @@ def get_activations(evt: gr.EventData, selected_image: str, model_name: str, col
126
  tile_activations = activation[token_idx]
127
 
128
  fig = plot_activations(
129
- all_activation, tile_activations, grid_x, grid_y, top_k=5, model_name=model_name, colors=colors
 
 
 
 
 
 
130
  )
131
  return fig
132
 
133
 
134
- def plot_activation_distribution(evt: gr.EventData, selected_image: str, model_name: str):
 
 
135
  fig = make_subplots(
136
  rows=2,
137
  cols=1,
@@ -139,8 +168,12 @@ def plot_activation_distribution(evt: gr.EventData, selected_image: str, model_n
139
  subplot_titles=["CLIP Activation", f"{model_name} Activation"],
140
  )
141
 
142
- fig_clip = get_activations(evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef"))
143
- fig_maple = get_activations(evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4"))
 
 
 
 
144
 
145
  def _attach_fig(fig, sub_fig, row, col, yref):
146
  for trace in sub_fig.data:
@@ -178,7 +211,9 @@ def get_segmask(selected_image, slider_value, model_type):
178
  mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
179
  except Exception as e:
180
  print(sae_act.shape, slider_value)
181
- mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
 
 
182
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
183
 
184
  base_opacity = 30
@@ -201,7 +236,11 @@ def get_top_images(slider_value, toggle_btn):
201
  os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
202
  ]
203
  top_images = [
204
- Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
 
 
 
 
205
  for path in top_image_paths
206
  ]
207
  return top_images
@@ -230,9 +269,19 @@ def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn
230
 
231
 
232
  def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
233
- rgba_overlay, top_images, act_values = show_activation_heatmap(selected_image, slider_value, "CLIP", toggle_btn)
 
 
234
  sleep(0.1)
235
- return (rgba_overlay, top_images[0], top_images[1], top_images[2], act_values[0], act_values[1], act_values[2])
 
 
 
 
 
 
 
 
236
 
237
 
238
  def show_activation_heatmap_maple(selected_image, slider_value, model_name):
@@ -251,11 +300,15 @@ def get_init_radio_options(selected_image, model_name):
251
  top_neurons = list(np.argsort(activations)[::-1][:top_k])
252
  for top_neuron in top_neurons:
253
  neuron_dict[top_neuron] = activations[top_neuron]
254
- sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
 
 
255
  return sorted_dict
256
 
257
  clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
258
- maple_neuron_dict = _get_top_actvation(selected_image, model_name, maple_neuron_dict)
 
 
259
 
260
  radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
261
 
@@ -270,7 +323,9 @@ def get_radio_names(clip_neuron_dict, maple_neuron_dict):
270
  clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
271
  maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
272
 
273
- common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
 
 
274
  clip_only_keys.sort(reverse=True)
275
  maple_only_keys.sort(reverse=True)
276
 
@@ -301,13 +356,17 @@ def update_radio_options(evt: gr.EventData, selected_image, model_name):
301
  tile_activations = all_activation[token_idx]
302
  _sort_and_save_top_k(tile_activations, neuron_dict)
303
 
304
- sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
 
 
305
  return sorted_dict
306
 
307
  clip_neuron_dict = {}
308
  maple_neuron_dict = {}
309
  clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
310
- maple_neuron_dict = _get_top_actvation(evt, selected_image, model_name, maple_neuron_dict)
 
 
311
 
312
  clip_keys = list(clip_neuron_dict.keys())
313
  maple_keys = list(maple_neuron_dict.keys())
@@ -316,7 +375,9 @@ def update_radio_options(evt: gr.EventData, selected_image, model_name):
316
  clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
317
  maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
318
 
319
- common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
 
 
320
  clip_only_keys.sort(reverse=True)
321
  maple_only_keys.sort(reverse=True)
322
 
@@ -325,7 +386,9 @@ def update_radio_options(evt: gr.EventData, selected_image, model_name):
325
  out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
326
  out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
327
 
328
- radio_choices = gr.Radio(choices=out, label="Top activating SAE latent", value=out[0])
 
 
329
  sleep(0.1)
330
  return radio_choices
331
 
@@ -347,6 +410,35 @@ def get_data(image_name, model_name):
347
  return out
348
 
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  def load_all_data(image_root, pkl_root):
351
  image_files = glob(f"{image_root}/*")
352
  data_dict = {}
@@ -387,33 +479,59 @@ with gr.Blocks(
387
  with gr.Column():
388
  # Left View: Image selection and click handling
389
  gr.Markdown("## Select input image and patch on the image")
390
- image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
391
- image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
 
 
 
 
 
 
 
 
392
 
393
  # Update image display when a new image is selected
394
  image_selector.change(
395
- fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
 
 
 
 
 
396
  )
397
- image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
398
 
399
  with gr.Column():
400
  gr.Markdown("## SAE latent activations of CLIP and MaPLE")
401
  model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
402
  model_selector = gr.Dropdown(
403
- choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
 
 
 
 
 
 
 
 
404
  )
405
- init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
406
- neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
407
 
408
  image_selector.change(
409
- fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
410
  )
411
  image_display.select(
412
- fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
 
 
 
413
  )
414
- model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
415
  model_selector.change(
416
- fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
417
  )
418
 
419
  with gr.Row():
@@ -421,82 +539,106 @@ with gr.Blocks(
421
  radio_names = get_init_radio_options(default_image_name, model_options[0])
422
 
423
  feautre_idx = radio_names[0].split("-")[-1]
424
- markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
425
- init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
 
 
 
 
426
 
427
  gr.Markdown("### Localize SAE latent activation using CLIP")
428
  seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
429
- init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
 
 
430
  gr.Markdown("### Localize SAE latent activation using MaPLE")
431
- seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
 
 
432
 
433
  with gr.Column():
434
  gr.Markdown("## Top activating SAE latent index")
435
 
436
  radio_choices = gr.Radio(
437
- choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
 
 
 
438
  )
439
  toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
440
 
441
- markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
 
 
442
 
443
  gr.Markdown("### ImageNet")
444
- top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
 
 
445
  act_value_1 = gr.Markdown(init_values[0])
446
 
447
  gr.Markdown("### ImageNet-Sketch")
448
- top_image_2 = gr.Image(value=init_tops[1], type="pil", label="ImageNet-Sketch", show_label=False)
 
 
 
 
 
449
  act_value_2 = gr.Markdown(init_values[1])
450
 
451
  gr.Markdown("### Caltech101")
452
- top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
 
 
453
  act_value_3 = gr.Markdown(init_values[2])
454
 
455
  image_display.select(
456
- fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
457
  )
458
 
459
  model_selector.change(
460
- fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
461
  )
462
 
463
  image_selector.select(
464
- fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
465
  )
466
 
467
  radio_choices.change(
468
- fn=update_markdown,
469
- inputs=[radio_choices],
470
- outputs=[markdown_display, markdown_display_2],
471
- queue=True,
472
- )
473
-
474
- radio_choices.change(
475
- fn=show_activation_heatmap_clip,
476
- inputs=[image_selector, radio_choices, toggle_btn],
477
- outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
478
- queue=True,
479
- )
480
-
481
- radio_choices.change(
482
- fn=show_activation_heatmap_maple,
483
- inputs=[image_selector, radio_choices, model_selector],
484
- outputs=[seg_mask_display_maple],
485
- queue=True,
486
  )
487
 
488
- # toggle_btn.change(
489
- # fn=get_top_images,
490
- # inputs=[radio_choices, toggle_btn],
491
- # outputs=[top_image_1, top_image_2, top_image_3],
492
- # queue=True,
493
- # )
494
-
495
  toggle_btn.change(
496
  fn=show_activation_heatmap_clip,
497
  inputs=[image_selector, radio_choices, toggle_btn],
498
- outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
499
- queue=True,
 
 
 
 
 
 
 
500
  )
501
 
502
  # Launch the app
 
28
  def get_activation_distribution(image_name: str, model_type: str):
29
  activation = get_data(image_name, model_type)[0]
30
 
31
+ noisy_features_indices = (
32
+ (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
33
+ )
34
  activation[:, noisy_features_indices] = 0
35
 
36
  return activation
 
54
 
55
  highlighted_image = image.copy()
56
  draw = ImageDraw.Draw(highlighted_image)
57
+ box = [
58
+ grid_x * cell_width,
59
+ grid_y * cell_height,
60
+ (grid_x + 1) * cell_width,
61
+ (grid_y + 1) * cell_height,
62
+ ]
63
  draw.rectangle(box, outline="red", width=3)
64
 
65
  return highlighted_image
66
 
67
 
68
  def load_image(img_name):
69
+ return Image.open(data_dict[img_name]["image_path"]).resize(
70
+ (IMAGE_SIZE, IMAGE_SIZE)
71
+ )
72
 
73
 
74
  def plot_activations(
75
+ all_activation,
76
+ tile_activations=None,
77
+ grid_x=None,
78
+ grid_y=None,
79
+ top_k=5,
80
+ colors=("blue", "cyan"),
81
+ model_name="CLIP",
82
  ):
83
  fig = go.Figure()
84
 
 
109
  return fig
110
 
111
  label = f"{model_name.split('-')[-0]} Image-level"
112
+ fig = _add_scatter_with_annotation(
113
+ fig, all_activation, model_name, colors[0], label
114
+ )
115
  if tile_activations is not None:
116
  label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
117
+ fig = _add_scatter_with_annotation(
118
+ fig, tile_activations, model_name, colors[1], label
119
+ )
120
 
121
  fig.update_layout(
122
  title="Activation Distribution",
 
124
  yaxis_title="Activation Value",
125
  template="plotly_white",
126
  )
127
+ fig.update_layout(
128
+ legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
129
+ )
130
 
131
  return fig
132
 
 
147
  tile_activations = activation[token_idx]
148
 
149
  fig = plot_activations(
150
+ all_activation,
151
+ tile_activations,
152
+ grid_x,
153
+ grid_y,
154
+ top_k=5,
155
+ model_name=model_name,
156
+ colors=colors,
157
  )
158
  return fig
159
 
160
 
161
+ def plot_activation_distribution(
162
+ evt: gr.EventData, selected_image: str, model_name: str
163
+ ):
164
  fig = make_subplots(
165
  rows=2,
166
  cols=1,
 
168
  subplot_titles=["CLIP Activation", f"{model_name} Activation"],
169
  )
170
 
171
+ fig_clip = get_activations(
172
+ evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
173
+ )
174
+ fig_maple = get_activations(
175
+ evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
176
+ )
177
 
178
  def _attach_fig(fig, sub_fig, row, col, yref):
179
  for trace in sub_fig.data:
 
211
  mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
212
  except Exception as e:
213
  print(sae_act.shape, slider_value)
214
+ mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
215
+ 0
216
+ ].numpy()
217
  mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
218
 
219
  base_opacity = 30
 
236
  os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
237
  ]
238
  top_images = [
239
+ (
240
+ Image.open(path)
241
+ if os.path.exists(path)
242
+ else Image.new("RGB", (256, 256), (255, 255, 255))
243
+ )
244
  for path in top_image_paths
245
  ]
246
  return top_images
 
269
 
270
 
271
  def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
272
+ rgba_overlay, top_images, act_values = show_activation_heatmap(
273
+ selected_image, slider_value, "CLIP", toggle_btn
274
+ )
275
  sleep(0.1)
276
+ return (
277
+ rgba_overlay,
278
+ top_images[0],
279
+ top_images[1],
280
+ top_images[2],
281
+ act_values[0],
282
+ act_values[1],
283
+ act_values[2],
284
+ )
285
 
286
 
287
  def show_activation_heatmap_maple(selected_image, slider_value, model_name):
 
300
  top_neurons = list(np.argsort(activations)[::-1][:top_k])
301
  for top_neuron in top_neurons:
302
  neuron_dict[top_neuron] = activations[top_neuron]
303
+ sorted_dict = dict(
304
+ sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
305
+ )
306
  return sorted_dict
307
 
308
  clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
309
+ maple_neuron_dict = _get_top_actvation(
310
+ selected_image, model_name, maple_neuron_dict
311
+ )
312
 
313
  radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
314
 
 
323
  clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
324
  maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
325
 
326
+ common_keys.sort(
327
+ key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
328
+ )
329
  clip_only_keys.sort(reverse=True)
330
  maple_only_keys.sort(reverse=True)
331
 
 
356
  tile_activations = all_activation[token_idx]
357
  _sort_and_save_top_k(tile_activations, neuron_dict)
358
 
359
+ sorted_dict = dict(
360
+ sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
361
+ )
362
  return sorted_dict
363
 
364
  clip_neuron_dict = {}
365
  maple_neuron_dict = {}
366
  clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
367
+ maple_neuron_dict = _get_top_actvation(
368
+ evt, selected_image, model_name, maple_neuron_dict
369
+ )
370
 
371
  clip_keys = list(clip_neuron_dict.keys())
372
  maple_keys = list(maple_neuron_dict.keys())
 
375
  clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
376
  maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
377
 
378
+ common_keys.sort(
379
+ key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
380
+ )
381
  clip_only_keys.sort(reverse=True)
382
  maple_only_keys.sort(reverse=True)
383
 
 
386
  out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
387
  out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
388
 
389
+ radio_choices = gr.Radio(
390
+ choices=out, label="Top activating SAE latent", value=out[0]
391
+ )
392
  sleep(0.1)
393
  return radio_choices
394
 
 
410
  return out
411
 
412
 
413
+ def update_all(selected_image, slider_value, toggle_btn, model_name):
414
+ (
415
+ seg_mask_display,
416
+ top_image_1,
417
+ top_image_2,
418
+ top_image_3,
419
+ act_value_1,
420
+ act_value_2,
421
+ act_value_3,
422
+ ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
423
+ seg_mask_display_maple = show_activation_heatmap_maple(
424
+ selected_image, slider_value, model_name
425
+ )
426
+ markdown_display, markdown_display_2 = update_markdown(slider_value)
427
+
428
+ return (
429
+ seg_mask_display,
430
+ seg_mask_display_maple,
431
+ top_image_1,
432
+ top_image_2,
433
+ top_image_3,
434
+ act_value_1,
435
+ act_value_2,
436
+ act_value_3,
437
+ markdown_display,
438
+ markdown_display_2,
439
+ )
440
+
441
+
442
  def load_all_data(image_root, pkl_root):
443
  image_files = glob(f"{image_root}/*")
444
  data_dict = {}
 
479
  with gr.Column():
480
  # Left View: Image selection and click handling
481
  gr.Markdown("## Select input image and patch on the image")
482
+ image_selector = gr.Dropdown(
483
+ choices=list(data_dict.keys()),
484
+ value=default_image_name,
485
+ label="Select Image",
486
+ )
487
+ image_display = gr.Image(
488
+ value=data_dict[default_image_name]["image"],
489
+ type="pil",
490
+ interactive=True,
491
+ )
492
 
493
  # Update image display when a new image is selected
494
  image_selector.change(
495
+ fn=lambda img_name: data_dict[img_name]["image"],
496
+ inputs=image_selector,
497
+ outputs=image_display,
498
+ )
499
+ image_display.select(
500
+ fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
501
  )
 
502
 
503
  with gr.Column():
504
  gr.Markdown("## SAE latent activations of CLIP and MaPLE")
505
  model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
506
  model_selector = gr.Dropdown(
507
+ choices=model_options,
508
+ value=model_options[0],
509
+ label="Select adapted model (MaPLe)",
510
+ )
511
+ init_plot = plot_activation_distribution(
512
+ None, default_image_name, model_options[0]
513
+ )
514
+ neuron_plot = gr.Plot(
515
+ label="Neuron Activation", value=init_plot, show_label=False
516
  )
 
 
517
 
518
  image_selector.change(
519
+ fn=plot_activation_distribution,
520
+ inputs=[image_selector, model_selector],
521
+ outputs=neuron_plot,
522
  )
523
  image_display.select(
524
+ fn=plot_activation_distribution,
525
+ inputs=[image_selector, model_selector],
526
+ outputs=neuron_plot,
527
+ )
528
+ model_selector.change(
529
+ fn=load_image, inputs=[image_selector], outputs=image_display
530
  )
 
531
  model_selector.change(
532
+ fn=plot_activation_distribution,
533
+ inputs=[image_selector, model_selector],
534
+ outputs=neuron_plot,
535
  )
536
 
537
  with gr.Row():
 
539
  radio_names = get_init_radio_options(default_image_name, model_options[0])
540
 
541
  feautre_idx = radio_names[0].split("-")[-1]
542
+ markdown_display = gr.Markdown(
543
+ f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
544
+ )
545
+ init_seg, init_tops, init_values = show_activation_heatmap(
546
+ default_image_name, radio_names[0], "CLIP"
547
+ )
548
 
549
  gr.Markdown("### Localize SAE latent activation using CLIP")
550
  seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
551
+ init_seg_maple, _, _ = show_activation_heatmap(
552
+ default_image_name, radio_names[0], model_options[0]
553
+ )
554
  gr.Markdown("### Localize SAE latent activation using MaPLE")
555
+ seg_mask_display_maple = gr.Image(
556
+ value=init_seg_maple, type="pil", show_label=False
557
+ )
558
 
559
  with gr.Column():
560
  gr.Markdown("## Top activating SAE latent index")
561
 
562
  radio_choices = gr.Radio(
563
+ choices=radio_names,
564
+ label="Top activating SAE latent",
565
+ interactive=True,
566
+ value=radio_names[0],
567
  )
568
  toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
569
 
570
+ markdown_display_2 = gr.Markdown(
571
+ f"## Top reference images for the selected SAE latent - {feautre_idx}"
572
+ )
573
 
574
  gr.Markdown("### ImageNet")
575
+ top_image_1 = gr.Image(
576
+ value=init_tops[0], type="pil", label="ImageNet", show_label=False
577
+ )
578
  act_value_1 = gr.Markdown(init_values[0])
579
 
580
  gr.Markdown("### ImageNet-Sketch")
581
+ top_image_2 = gr.Image(
582
+ value=init_tops[1],
583
+ type="pil",
584
+ label="ImageNet-Sketch",
585
+ show_label=False,
586
+ )
587
  act_value_2 = gr.Markdown(init_values[1])
588
 
589
  gr.Markdown("### Caltech101")
590
+ top_image_3 = gr.Image(
591
+ value=init_tops[2], type="pil", label="Caltech101", show_label=False
592
+ )
593
  act_value_3 = gr.Markdown(init_values[2])
594
 
595
  image_display.select(
596
+ fn=update_radio_options,
597
+ inputs=[image_selector, model_selector],
598
+ outputs=[radio_choices],
599
  )
600
 
601
  model_selector.change(
602
+ fn=update_radio_options,
603
+ inputs=[image_selector, model_selector],
604
+ outputs=[radio_choices],
605
  )
606
 
607
  image_selector.select(
608
+ fn=update_radio_options,
609
+ inputs=[image_selector, model_selector],
610
+ outputs=[radio_choices],
611
  )
612
 
613
  radio_choices.change(
614
+ fn=update_all,
615
+ inputs=[image_selector, radio_choices, toggle_btn, model_selector],
616
+ outputs=[
617
+ seg_mask_display,
618
+ seg_mask_display_maple,
619
+ top_image_1,
620
+ top_image_2,
621
+ top_image_3,
622
+ act_value_1,
623
+ act_value_2,
624
+ act_value_3,
625
+ markdown_display,
626
+ markdown_display_2,
627
+ ],
 
 
 
 
628
  )
629
 
 
 
 
 
 
 
 
630
  toggle_btn.change(
631
  fn=show_activation_heatmap_clip,
632
  inputs=[image_selector, radio_choices, toggle_btn],
633
+ outputs=[
634
+ seg_mask_display,
635
+ top_image_1,
636
+ top_image_2,
637
+ top_image_3,
638
+ act_value_1,
639
+ act_value_2,
640
+ act_value_3,
641
+ ],
642
  )
643
 
644
  # Launch the app