hyesulim committed
Commit 45d706c · verified · 1 Parent(s): bf85231

fix: activation plot bug
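
A minimal sketch of the wiring pattern the diff below moves to (assuming Gradio is installed; the callbacks and components here are illustrative stand-ins, not the app's own code): instead of one update_all callback fanning out to every output, each output group gets its own .change handler on the radio component.

import gradio as gr

# Hypothetical stand-in callbacks, only to illustrate the event wiring.
def update_markdown(choice: str) -> str:
    return f"## Selected SAE latent - {choice.split('-')[-1]}"

def update_heatmap(choice: str) -> str:
    return f"(heatmap placeholder for latent {choice.split('-')[-1]})"

with gr.Blocks() as demo:
    radio = gr.Radio(choices=["CLIP-0", "MaPLE-1"], value="CLIP-0", label="Top activating SAE latent")
    md = gr.Markdown()
    heat = gr.Textbox(show_label=False)
    # One handler per output group, mirroring the separate radio_choices.change calls in app.py.
    radio.change(fn=update_markdown, inputs=[radio], outputs=[md], queue=True)
    radio.change(fn=update_heatmap, inputs=[radio], outputs=[heat], queue=True)

demo.launch()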

Files changed (1)
  1. app.py +72 -215
app.py CHANGED
@@ -28,9 +28,7 @@ def preload_activation(image_name):
 def get_activation_distribution(image_name: str, model_type: str):
     activation = get_data(image_name, model_type)[0]
 
-    noisy_features_indices = (
-        (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
-    )
+    noisy_features_indices = (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
     activation[:, noisy_features_indices] = 0
 
     return activation
@@ -54,31 +52,18 @@ def highlight_grid(evt: gr.EventData, image_name):
 
     highlighted_image = image.copy()
     draw = ImageDraw.Draw(highlighted_image)
-    box = [
-        grid_x * cell_width,
-        grid_y * cell_height,
-        (grid_x + 1) * cell_width,
-        (grid_y + 1) * cell_height,
-    ]
+    box = [grid_x * cell_width, grid_y * cell_height, (grid_x + 1) * cell_width, (grid_y + 1) * cell_height]
     draw.rectangle(box, outline="red", width=3)
 
     return highlighted_image
 
 
 def load_image(img_name):
-    return Image.open(data_dict[img_name]["image_path"]).resize(
-        (IMAGE_SIZE, IMAGE_SIZE)
-    )
+    return Image.open(data_dict[img_name]["image_path"]).resize((IMAGE_SIZE, IMAGE_SIZE))
 
 
 def plot_activations(
-    all_activation,
-    tile_activations=None,
-    grid_x=None,
-    grid_y=None,
-    top_k=5,
-    colors=("blue", "cyan"),
-    model_name="CLIP",
+    all_activation, tile_activations=None, grid_x=None, grid_y=None, top_k=5, colors=("blue", "cyan"), model_name="CLIP"
 ):
     fig = go.Figure()
 
@@ -109,14 +94,10 @@ def plot_activations(
         return fig
 
     label = f"{model_name.split('-')[-0]} Image-level"
-    fig = _add_scatter_with_annotation(
-        fig, all_activation, model_name, colors[0], label
-    )
+    fig = _add_scatter_with_annotation(fig, all_activation, model_name, colors[0], label)
     if tile_activations is not None:
         label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
-        fig = _add_scatter_with_annotation(
-            fig, tile_activations, model_name, colors[1], label
-        )
+        fig = _add_scatter_with_annotation(fig, tile_activations, model_name, colors[1], label)
 
     fig.update_layout(
         title="Activation Distribution",
@@ -124,9 +105,7 @@ def plot_activations(
         yaxis_title="Activation Value",
         template="plotly_white",
     )
-    fig.update_layout(
-        legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
-    )
+    fig.update_layout(legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5))
 
     return fig
 
@@ -147,20 +126,12 @@ def get_activations(evt: gr.EventData, selected_image: str, model_name: str, col
         tile_activations = activation[token_idx]
 
     fig = plot_activations(
-        all_activation,
-        tile_activations,
-        grid_x,
-        grid_y,
-        top_k=5,
-        model_name=model_name,
-        colors=colors,
+        all_activation, tile_activations, grid_x, grid_y, top_k=5, model_name=model_name, colors=colors
     )
     return fig
 
 
-def plot_activation_distribution(
-    evt: gr.EventData, selected_image: str, model_name: str
-):
+def plot_activation_distribution(evt: gr.EventData, selected_image: str, model_name: str):
     fig = make_subplots(
         rows=2,
         cols=1,
@@ -168,12 +139,8 @@ def plot_activation_distribution(
         subplot_titles=["CLIP Activation", f"{model_name} Activation"],
     )
 
-    fig_clip = get_activations(
-        evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
-    )
-    fig_maple = get_activations(
-        evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
-    )
+    fig_clip = get_activations(evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef"))
+    fig_maple = get_activations(evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4"))
 
     def _attach_fig(fig, sub_fig, row, col, yref):
         for trace in sub_fig.data:
@@ -211,9 +178,7 @@ def get_segmask(selected_image, slider_value, model_type):
         mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
     except Exception as e:
         print(sae_act.shape, slider_value)
-    mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
-        0
-    ].numpy()
+    mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
     mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
 
     base_opacity = 30
@@ -236,11 +201,7 @@ def get_top_images(slider_value, toggle_btn):
         os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
     ]
     top_images = [
-        (
-            Image.open(path)
-            if os.path.exists(path)
-            else Image.new("RGB", (256, 256), (255, 255, 255))
-        )
+        Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
         for path in top_image_paths
     ]
     return top_images
@@ -269,19 +230,9 @@ def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn
 
 
 def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
-    rgba_overlay, top_images, act_values = show_activation_heatmap(
-        selected_image, slider_value, "CLIP", toggle_btn
-    )
+    rgba_overlay, top_images, act_values = show_activation_heatmap(selected_image, slider_value, "CLIP", toggle_btn)
     sleep(0.1)
-    return (
-        rgba_overlay,
-        top_images[0],
-        top_images[1],
-        top_images[2],
-        act_values[0],
-        act_values[1],
-        act_values[2],
-    )
+    return (rgba_overlay, top_images[0], top_images[1], top_images[2], act_values[0], act_values[1], act_values[2])
 
 
 def show_activation_heatmap_maple(selected_image, slider_value, model_name):
@@ -300,15 +251,11 @@ def get_init_radio_options(selected_image, model_name):
         top_neurons = list(np.argsort(activations)[::-1][:top_k])
         for top_neuron in top_neurons:
             neuron_dict[top_neuron] = activations[top_neuron]
-        sorted_dict = dict(
-            sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
-        )
+        sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
         return sorted_dict
 
     clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
-    maple_neuron_dict = _get_top_actvation(
-        selected_image, model_name, maple_neuron_dict
-    )
+    maple_neuron_dict = _get_top_actvation(selected_image, model_name, maple_neuron_dict)
 
     radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
 
@@ -323,9 +270,7 @@ def get_radio_names(clip_neuron_dict, maple_neuron_dict):
     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
 
-    common_keys.sort(
-        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
-    )
+    common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
     clip_only_keys.sort(reverse=True)
     maple_only_keys.sort(reverse=True)
 
@@ -356,17 +301,13 @@ def update_radio_options(evt: gr.EventData, selected_image, model_name):
         tile_activations = all_activation[token_idx]
         _sort_and_save_top_k(tile_activations, neuron_dict)
 
-        sorted_dict = dict(
-            sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
-        )
+        sorted_dict = dict(sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True))
         return sorted_dict
 
     clip_neuron_dict = {}
     maple_neuron_dict = {}
     clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
-    maple_neuron_dict = _get_top_actvation(
-        evt, selected_image, model_name, maple_neuron_dict
-    )
+    maple_neuron_dict = _get_top_actvation(evt, selected_image, model_name, maple_neuron_dict)
 
     clip_keys = list(clip_neuron_dict.keys())
     maple_keys = list(maple_neuron_dict.keys())
@@ -375,9 +316,7 @@ def update_radio_options(evt: gr.EventData, selected_image, model_name):
     clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
     maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
 
-    common_keys.sort(
-        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
-    )
+    common_keys.sort(key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True)
     clip_only_keys.sort(reverse=True)
     maple_only_keys.sort(reverse=True)
 
@@ -386,9 +325,7 @@ def update_radio_options(evt: gr.EventData, selected_image, model_name):
     out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
     out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
 
-    radio_choices = gr.Radio(
-        choices=out, label="Top activating SAE latent", value=out[0]
-    )
+    radio_choices = gr.Radio(choices=out, label="Top activating SAE latent", value=out[0])
     sleep(0.1)
     return radio_choices
 
@@ -410,35 +347,6 @@ def get_data(image_name, model_name):
     return out
 
 
-def update_all(selected_image, slider_value, toggle_btn, model_name):
-    (
-        seg_mask_display,
-        top_image_1,
-        top_image_2,
-        top_image_3,
-        act_value_1,
-        act_value_2,
-        act_value_3,
-    ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
-    seg_mask_display_maple = show_activation_heatmap_maple(
-        selected_image, slider_value, model_name
-    )
-    markdown_display, markdown_display_2 = update_markdown(slider_value)
-
-    return (
-        seg_mask_display,
-        seg_mask_display_maple,
-        top_image_1,
-        top_image_2,
-        top_image_3,
-        act_value_1,
-        act_value_2,
-        act_value_3,
-        markdown_display,
-        markdown_display_2,
-    )
-
-
 def load_all_data(image_root, pkl_root):
     image_files = glob(f"{image_root}/*")
     data_dict = {}
@@ -479,59 +387,33 @@ with gr.Blocks(
         with gr.Column():
             # Left View: Image selection and click handling
            gr.Markdown("## Select input image and patch on the image")
-            image_selector = gr.Dropdown(
-                choices=list(data_dict.keys()),
-                value=default_image_name,
-                label="Select Image",
-            )
-            image_display = gr.Image(
-                value=data_dict[default_image_name]["image"],
-                type="pil",
-                interactive=True,
-            )
+            image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
+            image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
 
             # Update image display when a new image is selected
             image_selector.change(
-                fn=lambda img_name: data_dict[img_name]["image"],
-                inputs=image_selector,
-                outputs=image_display,
-            )
-            image_display.select(
-                fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
+                fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
             )
+            image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
 
         with gr.Column():
             gr.Markdown("## SAE latent activations of CLIP and MaPLE")
             model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
             model_selector = gr.Dropdown(
-                choices=model_options,
-                value=model_options[0],
-                label="Select adapted model (MaPLe)",
-            )
-            init_plot = plot_activation_distribution(
-                None, default_image_name, model_options[0]
-            )
-            neuron_plot = gr.Plot(
-                label="Neuron Activation", value=init_plot, show_label=False
+                choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
             )
+            init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
+            neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
 
             image_selector.change(
-                fn=plot_activation_distribution,
-                inputs=[image_selector, model_selector],
-                outputs=neuron_plot,
+                fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
             )
             image_display.select(
-                fn=plot_activation_distribution,
-                inputs=[image_selector, model_selector],
-                outputs=neuron_plot,
-            )
-            model_selector.change(
-                fn=load_image, inputs=[image_selector], outputs=image_display
+                fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
             )
+            model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
             model_selector.change(
-                fn=plot_activation_distribution,
-                inputs=[image_selector, model_selector],
-                outputs=neuron_plot,
+                fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
             )
 
     with gr.Row():
@@ -539,108 +421,83 @@
             radio_names = get_init_radio_options(default_image_name, model_options[0])
 
             feautre_idx = radio_names[0].split("-")[-1]
-            markdown_display = gr.Markdown(
-                f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
-            )
-            init_seg, init_tops, init_values = show_activation_heatmap(
-                default_image_name, radio_names[0], "CLIP"
-            )
+            markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
+            init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
 
             gr.Markdown("### Localize SAE latent activation using CLIP")
             seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
-            init_seg_maple, _, _ = show_activation_heatmap(
-                default_image_name, radio_names[0], model_options[0]
-            )
+            init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
             gr.Markdown("### Localize SAE latent activation using MaPLE")
-            seg_mask_display_maple = gr.Image(
-                value=init_seg_maple, type="pil", show_label=False
-            )
+            seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
 
         with gr.Column():
             gr.Markdown("## Top activating SAE latent index")
 
             radio_choices = gr.Radio(
-                choices=radio_names,
-                label="Top activating SAE latent",
-                interactive=True,
-                value=radio_names[0],
+                choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
             )
             toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
 
-            markdown_display_2 = gr.Markdown(
-                f"## Top reference images for the selected SAE latent - {feautre_idx}"
-            )
+            markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
 
             gr.Markdown("### ImageNet")
-            top_image_1 = gr.Image(
-                value=init_tops[0], type="pil", label="ImageNet", show_label=False
-            )
+            top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
             act_value_1 = gr.Markdown(init_values[0])
 
             gr.Markdown("### ImageNet-Sketch")
-            top_image_2 = gr.Image(
-                value=init_tops[1],
-                type="pil",
-                label="ImageNet-Sketch",
-                show_label=False,
-            )
+            top_image_2 = gr.Image(value=init_tops[1], type="pil", label="ImageNet-Sketch", show_label=False)
             act_value_2 = gr.Markdown(init_values[1])
 
             gr.Markdown("### Caltech101")
-            top_image_3 = gr.Image(
-                value=init_tops[2], type="pil", label="Caltech101", show_label=False
-            )
+            top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
             act_value_3 = gr.Markdown(init_values[2])
 
     image_display.select(
-        fn=update_radio_options,
-        inputs=[image_selector, model_selector],
-        outputs=[radio_choices],
+        fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
     )
 
     model_selector.change(
-        fn=update_radio_options,
-        inputs=[image_selector, model_selector],
-        outputs=[radio_choices],
+        fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
    )
 
     image_selector.select(
-        fn=update_radio_options,
-        inputs=[image_selector, model_selector],
-        outputs=[radio_choices],
+        fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
     )
 
     radio_choices.change(
-        fn=update_all,
-        inputs=[image_selector, radio_choices, toggle_btn, model_selector],
-        outputs=[
-            seg_mask_display,
-            seg_mask_display_maple,
-            top_image_1,
-            top_image_2,
-            top_image_3,
-            act_value_1,
-            act_value_2,
-            act_value_3,
-            markdown_display,
-            markdown_display_2,
-        ],
+        fn=update_markdown,
+        inputs=[radio_choices],
+        outputs=[markdown_display, markdown_display_2],
+        queue=True,
+    )
+
+    radio_choices.change(
+        fn=show_activation_heatmap_clip,
+        inputs=[image_selector, radio_choices, toggle_btn],
+        outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
+        queue=True,
+    )
+
+    radio_choices.change(
+        fn=show_activation_heatmap_maple,
+        inputs=[image_selector, radio_choices, model_selector],
+        outputs=[seg_mask_display_maple],
+        queue=True,
     )
 
+    # toggle_btn.change(
+    #     fn=get_top_images,
+    #     inputs=[radio_choices, toggle_btn],
+    #     outputs=[top_image_1, top_image_2, top_image_3],
+    #     queue=True,
+    # )
+
     toggle_btn.change(
         fn=show_activation_heatmap_clip,
         inputs=[image_selector, radio_choices, toggle_btn],
-        outputs=[
-            seg_mask_display,
-            top_image_1,
-            top_image_2,
-            top_image_3,
-            act_value_1,
-            act_value_2,
-            act_value_3,
-        ],
+        outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
+        queue=True,
     )
 
     # Launch the app
-    # demo.queue()
     demo.launch()