hyesulim commited on
Commit
928847c
·
verified ·
1 Parent(s): ee02e77

revert changes

Browse files
Files changed (1) hide show
  1. app.py +29 -216
app.py CHANGED
@@ -372,233 +372,61 @@ def load_all_data(image_root, pkl_root):
372
  return data_dict, sae_data_dict
373
 
374
 
375
- # data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
376
- # default_image_name = "christmas-imagenet"
377
-
378
-
379
- # with gr.Blocks(
380
- # theme=gr.themes.Citrus(),
381
- # css="""
382
- # .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
383
- # .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
384
- # """,
385
- # ) as demo:
386
- # with gr.Row():
387
- # with gr.Column():
388
- # # Left View: Image selection and click handling
389
- # gr.Markdown("## Select input image and patch on the image")
390
- # image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
391
- # image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
392
-
393
- # # Update image display when a new image is selected
394
- # image_selector.change(
395
- # fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
396
- # )
397
- # image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
398
-
399
- # with gr.Column():
400
- # gr.Markdown("## SAE latent activations of CLIP and MaPLE")
401
- # model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
402
- # model_selector = gr.Dropdown(
403
- # choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
404
- # )
405
- # init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
406
- # neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
407
-
408
- # image_selector.change(
409
- # fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
410
- # )
411
- # image_display.select(
412
- # fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
413
- # )
414
- # model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
415
- # model_selector.change(
416
- # fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
417
- # )
418
-
419
- # with gr.Row():
420
- # with gr.Column():
421
- # radio_names = get_init_radio_options(default_image_name, model_options[0])
422
-
423
- # feautre_idx = radio_names[0].split("-")[-1]
424
- # markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
425
- # init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
426
-
427
- # gr.Markdown("### Localize SAE latent activation using CLIP")
428
- # seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
429
- # init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
430
- # gr.Markdown("### Localize SAE latent activation using MaPLE")
431
- # seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
432
-
433
- # with gr.Column():
434
- # gr.Markdown("## Top activating SAE latent index")
435
-
436
- # radio_choices = gr.Radio(
437
- # choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
438
- # )
439
- # toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
440
-
441
- # markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
442
-
443
- # gr.Markdown("### ImageNet")
444
- # top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
445
- # act_value_1 = gr.Markdown(init_values[0])
446
-
447
- # gr.Markdown("### ImageNet-Sketch")
448
- # top_image_2 = gr.Image(value=init_tops[1], type="pil", label="ImageNet-Sketch", show_label=False)
449
- # act_value_2 = gr.Markdown(init_values[1])
450
-
451
- # gr.Markdown("### Caltech101")
452
- # top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
453
- # act_value_3 = gr.Markdown(init_values[2])
454
-
455
- # image_display.select(
456
- # fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
457
- # )
458
-
459
- # model_selector.change(
460
- # fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
461
- # )
462
-
463
- # image_selector.select(
464
- # fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
465
- # )
466
-
467
- # radio_choices.change(
468
- # fn=update_markdown,
469
- # inputs=[radio_choices],
470
- # outputs=[markdown_display, markdown_display_2],
471
- # queue=True,
472
- # )
473
-
474
- # radio_choices.change(
475
- # fn=show_activation_heatmap_clip,
476
- # inputs=[image_selector, radio_choices, toggle_btn],
477
- # outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
478
- # queue=True,
479
- # )
480
-
481
- # radio_choices.change(
482
- # fn=show_activation_heatmap_maple,
483
- # inputs=[image_selector, radio_choices, model_selector],
484
- # outputs=[seg_mask_display_maple],
485
- # queue=True,
486
- # )
487
-
488
- # # toggle_btn.change(
489
- # # fn=get_top_images,
490
- # # inputs=[radio_choices, toggle_btn],
491
- # # outputs=[top_image_1, top_image_2, top_image_3],
492
- # # queue=True,
493
- # # )
494
-
495
- # toggle_btn.change(
496
- # fn=show_activation_heatmap_clip,
497
- # inputs=[image_selector, radio_choices, toggle_btn],
498
- # outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
499
- # queue=True,
500
- # )
501
-
502
- # # Launch the app
503
- # demo.launch()
504
-
505
- # Precompute all necessary data and store in caches before launching the Gradio app.
506
-
507
- # Load data once at startup
508
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
509
  default_image_name = "christmas-imagenet"
510
- model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
511
- default_model = model_options[0]
512
 
513
- # Precompute activation distributions for all images/models to avoid repeated I/O.
514
- activation_cache = {}
515
- for img_name in data_dict.keys():
516
- for mdl in ["CLIP"] + model_options:
517
- activation_cache[(img_name, mdl)] = get_activation_distribution(img_name, mdl)
518
-
519
- # Precompute initial radio options and top-neuron related info for default states.
520
- radio_names = get_init_radio_options(default_image_name, default_model)
521
- feautre_idx = radio_names[0].split("-")[-1]
522
-
523
- # Precompute initial figures and mask overlays so they don't need to be recomputed on load.
524
- init_plot = plot_activation_distribution(None, default_image_name, default_model)
525
- init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
526
- init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], default_model)
527
 
528
  with gr.Blocks(
529
  theme=gr.themes.Citrus(),
530
  css="""
531
  .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
532
  .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
533
- """
534
  ) as demo:
535
  with gr.Row():
536
  with gr.Column():
 
537
  gr.Markdown("## Select input image and patch on the image")
538
-
539
- # Instead of recomputing, just directly load from data_dict
540
  image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
541
  image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
542
 
543
- # When image changes, just display the corresponding image
544
- def update_image_display(img_name):
545
- return data_dict[img_name]["image"]
546
-
547
  image_selector.change(
548
- fn=update_image_display,
549
- inputs=image_selector,
550
- outputs=image_display
551
  )
552
-
553
- # Highlight selected grid cell
554
  image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
555
 
556
  with gr.Column():
557
  gr.Markdown("## SAE latent activations of CLIP and MaPLE")
558
-
559
  model_selector = gr.Dropdown(
560
- choices=model_options,
561
- value=default_model,
562
- label="Select adapted model (MaPLe)"
563
  )
564
-
565
  neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
566
 
567
- # Update plot based on image/model
568
- def update_plot(img_name, model_name):
569
- # Use precomputed activation distributions from activation_cache
570
- # to create the figure. If figure creation is expensive, consider caching plots as well.
571
- return plot_activation_distribution(None, img_name, model_name)
572
-
573
  image_selector.change(
574
- fn=update_plot,
575
- inputs=[image_selector, model_selector],
576
- outputs=neuron_plot
577
  )
578
  image_display.select(
579
- fn=update_plot,
580
- inputs=[image_selector, model_selector],
581
- outputs=neuron_plot
582
- )
583
- model_selector.change(
584
- fn=lambda img_name: data_dict[img_name]["image"],
585
- inputs=[image_selector],
586
- outputs=image_display
587
  )
 
588
  model_selector.change(
589
- fn=update_plot,
590
- inputs=[image_selector, model_selector],
591
- outputs=neuron_plot
592
  )
593
 
594
  with gr.Row():
595
  with gr.Column():
596
- # Use previously precomputed segmentation masks and tops instead of recomputing on load
 
 
597
  markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
 
598
 
599
  gr.Markdown("### Localize SAE latent activation using CLIP")
600
  seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
601
-
602
  gr.Markdown("### Localize SAE latent activation using MaPLE")
603
  seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
604
 
@@ -606,17 +434,12 @@ with gr.Blocks(
606
  gr.Markdown("## Top activating SAE latent index")
607
 
608
  radio_choices = gr.Radio(
609
- choices=radio_names,
610
- label="Top activating SAE latent",
611
- interactive=True,
612
- value=radio_names[0]
613
  )
614
-
615
  toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
616
 
617
  markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
618
 
619
- # Display precomputed top images and values
620
  gr.Markdown("### ImageNet")
621
  top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
622
  act_value_1 = gr.Markdown(init_values[0])
@@ -629,31 +452,18 @@ with gr.Blocks(
629
  top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
630
  act_value_3 = gr.Markdown(init_values[2])
631
 
632
- # Update radio choices when image/model changes.
633
- # If expensive, this could be cached as well.
634
- def on_image_or_model_change(img_name, model_name):
635
- return update_radio_options(None, img_name, model_name)
636
-
637
  image_display.select(
638
- fn=on_image_or_model_change,
639
- inputs=[image_selector, model_selector],
640
- outputs=[radio_choices],
641
- queue=True
642
  )
 
643
  model_selector.change(
644
- fn=on_image_or_model_change,
645
- inputs=[image_selector, model_selector],
646
- outputs=[radio_choices],
647
- queue=True
648
  )
 
649
  image_selector.select(
650
- fn=on_image_or_model_change,
651
- inputs=[image_selector, model_selector],
652
- outputs=[radio_choices],
653
- queue=True
654
  )
655
 
656
- # Update markdown titles dynamically based on selected radio choice
657
  radio_choices.change(
658
  fn=update_markdown,
659
  inputs=[radio_choices],
@@ -661,7 +471,6 @@ with gr.Blocks(
661
  queue=True,
662
  )
663
 
664
- # Show activation heatmap for CLIP
665
  radio_choices.change(
666
  fn=show_activation_heatmap_clip,
667
  inputs=[image_selector, radio_choices, toggle_btn],
@@ -669,7 +478,6 @@ with gr.Blocks(
669
  queue=True,
670
  )
671
 
672
- # Show activation heatmap for MaPLE
673
  radio_choices.change(
674
  fn=show_activation_heatmap_maple,
675
  inputs=[image_selector, radio_choices, model_selector],
@@ -677,7 +485,13 @@ with gr.Blocks(
677
  queue=True,
678
  )
679
 
680
- # Toggle segmentation mask
 
 
 
 
 
 
681
  toggle_btn.change(
682
  fn=show_activation_heatmap_clip,
683
  inputs=[image_selector, radio_choices, toggle_btn],
@@ -687,4 +501,3 @@ with gr.Blocks(
687
 
688
  # Launch the app
689
  demo.launch()
690
-
 
372
  return data_dict, sae_data_dict
373
 
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
376
  default_image_name = "christmas-imagenet"
 
 
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
 
379
  with gr.Blocks(
380
  theme=gr.themes.Citrus(),
381
  css="""
382
  .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
383
  .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
384
+ """,
385
  ) as demo:
386
  with gr.Row():
387
  with gr.Column():
388
+ # Left View: Image selection and click handling
389
  gr.Markdown("## Select input image and patch on the image")
 
 
390
  image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
391
  image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
392
 
393
+ # Update image display when a new image is selected
 
 
 
394
  image_selector.change(
395
+ fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
 
 
396
  )
 
 
397
  image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
398
 
399
  with gr.Column():
400
  gr.Markdown("## SAE latent activations of CLIP and MaPLE")
401
+ model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
402
  model_selector = gr.Dropdown(
403
+ choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
 
 
404
  )
405
+ init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
406
  neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
407
 
 
 
 
 
 
 
408
  image_selector.change(
409
+ fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
410
  )
411
  image_display.select(
412
+ fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
 
 
 
 
 
413
  )
414
+ model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
415
  model_selector.change(
416
+ fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
417
  )
418
 
419
  with gr.Row():
420
  with gr.Column():
421
+ radio_names = get_init_radio_options(default_image_name, model_options[0])
422
+
423
+ feautre_idx = radio_names[0].split("-")[-1]
424
  markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
425
+ init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
426
 
427
  gr.Markdown("### Localize SAE latent activation using CLIP")
428
  seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
429
+ init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
430
  gr.Markdown("### Localize SAE latent activation using MaPLE")
431
  seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
432
 
 
434
  gr.Markdown("## Top activating SAE latent index")
435
 
436
  radio_choices = gr.Radio(
437
+ choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
 
 
 
438
  )
 
439
  toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
440
 
441
  markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
442
 
 
443
  gr.Markdown("### ImageNet")
444
  top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
445
  act_value_1 = gr.Markdown(init_values[0])
 
452
  top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
453
  act_value_3 = gr.Markdown(init_values[2])
454
 
 
 
 
 
 
455
  image_display.select(
456
+ fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
 
457
  )
458
+
459
  model_selector.change(
460
+ fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
 
461
  )
462
+
463
  image_selector.select(
464
+ fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
 
465
  )
466
 
 
467
  radio_choices.change(
468
  fn=update_markdown,
469
  inputs=[radio_choices],
 
471
  queue=True,
472
  )
473
 
 
474
  radio_choices.change(
475
  fn=show_activation_heatmap_clip,
476
  inputs=[image_selector, radio_choices, toggle_btn],
 
478
  queue=True,
479
  )
480
 
 
481
  radio_choices.change(
482
  fn=show_activation_heatmap_maple,
483
  inputs=[image_selector, radio_choices, model_selector],
 
485
  queue=True,
486
  )
487
 
488
+ # toggle_btn.change(
489
+ # fn=get_top_images,
490
+ # inputs=[radio_choices, toggle_btn],
491
+ # outputs=[top_image_1, top_image_2, top_image_3],
492
+ # queue=True,
493
+ # )
494
+
495
  toggle_btn.change(
496
  fn=show_activation_heatmap_clip,
497
  inputs=[image_selector, radio_choices, toggle_btn],
 
501
 
502
  # Launch the app
503
  demo.launch()