hyesulim committed
Commit 4cd925e · verified · 1 parent: c08c98f

test: fixed typo

Files changed (1):
  1. app.py +12 -813
app.py CHANGED
@@ -155,805 +155,6 @@ def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
 
     return _CACHE['top_images'][cache_key]
 
-# Initialize data
-data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
-
-
-# def preload_activation(image_name):
-#     for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
-#         image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
-#         with gzip.open(image_file, "rb") as f:
-#             preloaded_data[model] = pickle.load(f)
-
-
-# def get_activation_distribution(image_name: str, model_type: str):
-#     activation = get_data(image_name, model_type)[0]
-
-#     noisy_features_indices = (
-#         (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
-#     )
-#     activation[:, noisy_features_indices] = 0
-
-#     return activation
-
-
-def get_grid_loc(evt, image):
-    # Get click coordinates
-    x, y = evt._data["index"][0], evt._data["index"][1]
-
-    cell_width = image.width // GRID_NUM
-    cell_height = image.height // GRID_NUM
-
-    grid_x = x // cell_width
-    grid_y = y // cell_height
-    return grid_x, grid_y, cell_width, cell_height
-
-
-def highlight_grid(evt: gr.EventData, image_name):
-    image = data_dict[image_name]["image"]
-    grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-
-    highlighted_image = image.copy()
-    draw = ImageDraw.Draw(highlighted_image)
-    box = [
-        grid_x * cell_width,
-        grid_y * cell_height,
-        (grid_x + 1) * cell_width,
-        (grid_y + 1) * cell_height,
-    ]
-    draw.rectangle(box, outline="red", width=3)
-
-    return highlighted_image
-
-
-def load_image(img_name):
-    return Image.open(data_dict[img_name]["image_path"]).resize(
-        (IMAGE_SIZE, IMAGE_SIZE)
-    )
-
-
-def plot_activations(
-    all_activation,
-    tile_activations=None,
-    grid_x=None,
-    grid_y=None,
-    top_k=5,
-    colors=("blue", "cyan"),
-    model_name="CLIP",
-):
-    fig = go.Figure()
-
-    def _add_scatter_with_annotation(fig, activations, model_name, color, label):
-        fig.add_trace(
-            go.Scatter(
-                x=np.arange(len(activations)),
-                y=activations,
-                mode="lines",
-                name=label,
-                line=dict(color=color, dash="solid"),
-                showlegend=True,
-            )
-        )
-        top_neurons = np.argsort(activations)[::-1][:top_k]
-        for idx in top_neurons:
-            fig.add_annotation(
-                x=idx,
-                y=activations[idx],
-                text=str(idx),
-                showarrow=True,
-                arrowhead=2,
-                ax=0,
-                ay=-15,
-                arrowcolor=color,
-                opacity=0.7,
-            )
-        return fig
-
-    label = f"{model_name.split('-')[-0]} Image-level"
-    fig = _add_scatter_with_annotation(
-        fig, all_activation, model_name, colors[0], label
-    )
-    if tile_activations is not None:
-        label = f"{model_name.split('-')[-0]} Tile ({grid_x}, {grid_y})"
-        fig = _add_scatter_with_annotation(
-            fig, tile_activations, model_name, colors[1], label
-        )
-
-    fig.update_layout(
-        title="Activation Distribution",
-        xaxis_title="SAE latent index",
-        yaxis_title="Activation Value",
-        template="plotly_white",
-    )
-    fig.update_layout(
-        legend=dict(orientation="h", yanchor="middle", y=0.5, xanchor="center", x=0.5)
-    )
-
-    return fig
-
-
-def get_activations(evt: gr.EventData, selected_image: str, model_name: str, colors):
-    activation = get_activation_distribution(selected_image, model_name)
-    all_activation = activation.mean(0)
-
-    tile_activations = None
-    grid_x = None
-    grid_y = None
-
-    if evt is not None:
-        if evt._data is not None:
-            image = data_dict[selected_image]["image"]
-            grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-            token_idx = grid_y * GRID_NUM + grid_x + 1
-            tile_activations = activation[token_idx]
-
-    fig = plot_activations(
-        all_activation,
-        tile_activations,
-        grid_x,
-        grid_y,
-        top_k=5,
-        model_name=model_name,
-        colors=colors,
-    )
-    return fig
-
-
-def plot_activation_distribution(
-    evt: gr.EventData, selected_image: str, model_name: str
-):
-    fig = make_subplots(
-        rows=2,
-        cols=1,
-        shared_xaxes=True,
-        subplot_titles=["CLIP Activation", f"{model_name} Activation"],
-    )
-
-    fig_clip = get_activations(
-        evt, selected_image, "CLIP", colors=("#00b4d8", "#90e0ef")
-    )
-    fig_maple = get_activations(
-        evt, selected_image, model_name, colors=("#ff5a5f", "#ffcad4")
-    )
-
-    def _attach_fig(fig, sub_fig, row, col, yref):
-        for trace in sub_fig.data:
-            fig.add_trace(trace, row=row, col=col)
-
-        for annotation in sub_fig.layout.annotations:
-            annotation.update(yref=yref)
-            fig.add_annotation(annotation)
-        return fig
-
-    fig = _attach_fig(fig, fig_clip, row=1, col=1, yref="y1")
-    fig = _attach_fig(fig, fig_maple, row=2, col=1, yref="y2")
-
-    fig.update_xaxes(title_text="SAE Latent Index", row=2, col=1)
-    fig.update_xaxes(title_text="SAE Latent Index", row=1, col=1)
-    fig.update_yaxes(title_text="Activation Value", row=1, col=1)
-    fig.update_yaxes(title_text="Activation Value", row=2, col=1)
-    fig.update_layout(
-        # height=500,
-        # title="Activation Distributions",
-        template="plotly_white",
-        showlegend=True,
-        legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
-        margin=dict(l=20, r=20, t=40, b=20),
-    )
-
-    return fig
-
-
-# def get_segmask(selected_image, slider_value, model_type):
-#     image = data_dict[selected_image]["image"]
-#     sae_act = get_data(selected_image, model_type)[0]
-#     temp = sae_act[:, slider_value]
-#     try:
-#         mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
-#     except Exception as e:
-#         print(sae_act.shape, slider_value)
-#     mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
-#         0
-#     ].numpy()
-#     mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
-
-#     base_opacity = 30
-#     image_array = np.array(image)[..., :3]
-#     rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
-#     rgba_overlay[..., :3] = image_array[..., :3]
-
-#     darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
-#     rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
-#     rgba_overlay[..., 3] = 255  # Fully opaque
-
-#     return rgba_overlay
-
-
-# def get_top_images(slider_value, toggle_btn):
-#     def _get_images(dataset_path):
-#         top_image_paths = [
-#             os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
-#             os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
-#             os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
-#         ]
-#         top_images = [
-#             (
-#                 Image.open(path)
-#                 if os.path.exists(path)
-#                 else Image.new("RGB", (256, 256), (255, 255, 255))
-#             )
-#             for path in top_image_paths
-#         ]
-#         return top_images
-
-#     if toggle_btn:
-#         top_images = _get_images("./data/top_images_masked")
-#     else:
-#         top_images = _get_images("./data/top_images")
-#     return top_images
-
-
-def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
-    slider_value = int(slider_value.split("-")[-1])
-    rgba_overlay = get_segmask(selected_image, slider_value, model_type)
-    top_images = get_top_images(slider_value, toggle_btn)
-
-    act_values = []
-    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-        act_value = sae_data_dict["mean_act_values"][dataset][slider_value, :5]
-        act_value = [str(round(value, 3)) for value in act_value]
-        act_value = " | ".join(act_value)
-        out = f"#### Activation values: {act_value}"
-        act_values.append(out)
-
-    return rgba_overlay, top_images, act_values
-
-
-def show_activation_heatmap_clip(selected_image, slider_value, toggle_btn):
-    rgba_overlay, top_images, act_values = show_activation_heatmap(
-        selected_image, slider_value, "CLIP", toggle_btn
-    )
-    sleep(0.1)
-    return (
-        rgba_overlay,
-        top_images[0],
-        top_images[1],
-        top_images[2],
-        act_values[0],
-        act_values[1],
-        act_values[2],
-    )
-
-
-def show_activation_heatmap_maple(selected_image, slider_value, model_name):
-    slider_value = int(slider_value.split("-")[-1])
-    rgba_overlay = get_segmask(selected_image, slider_value, model_name)
-    sleep(0.1)
-    return rgba_overlay
-
-
-def get_init_radio_options(selected_image, model_name):
-    clip_neuron_dict = {}
-    maple_neuron_dict = {}
-
-    def _get_top_actvation(selected_image, model_name, neuron_dict, top_k=5):
-        activations = get_activation_distribution(selected_image, model_name).mean(0)
-        top_neurons = list(np.argsort(activations)[::-1][:top_k])
-        for top_neuron in top_neurons:
-            neuron_dict[top_neuron] = activations[top_neuron]
-        sorted_dict = dict(
-            sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
-        )
-        return sorted_dict
-
-    clip_neuron_dict = _get_top_actvation(selected_image, "CLIP", clip_neuron_dict)
-    maple_neuron_dict = _get_top_actvation(
-        selected_image, model_name, maple_neuron_dict
-    )
-
-    radio_choices = get_radio_names(clip_neuron_dict, maple_neuron_dict)
-
-    return radio_choices
-
-
-def get_radio_names(clip_neuron_dict, maple_neuron_dict):
-    clip_keys = list(clip_neuron_dict.keys())
-    maple_keys = list(maple_neuron_dict.keys())
-
-    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
-    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
-
-    common_keys.sort(
-        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
-    )
-    clip_only_keys.sort(reverse=True)
-    maple_only_keys.sort(reverse=True)
-
-    out = []
-    out.extend([f"common-{i}" for i in common_keys[:5]])
-    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
-    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
-
-    return out
-
-
-def update_radio_options(evt: gr.EventData, selected_image, model_name):
-    def _sort_and_save_top_k(activations, neuron_dict, top_k=5):
-        top_neurons = list(np.argsort(activations)[::-1][:top_k])
-        for top_neuron in top_neurons:
-            neuron_dict[top_neuron] = activations[top_neuron]
-
-    def _get_top_actvation(evt, selected_image, model_name, neuron_dict):
-        all_activation = get_activation_distribution(selected_image, model_name)
-        image_activation = all_activation.mean(0)
-        _sort_and_save_top_k(image_activation, neuron_dict)
-
-        if evt is not None:
-            if evt._data is not None and isinstance(evt._data["index"], list):
-                image = data_dict[selected_image]["image"]
-                grid_x, grid_y, cell_width, cell_height = get_grid_loc(evt, image)
-                token_idx = grid_y * GRID_NUM + grid_x + 1
-                tile_activations = all_activation[token_idx]
-                _sort_and_save_top_k(tile_activations, neuron_dict)
-
-        sorted_dict = dict(
-            sorted(neuron_dict.items(), key=lambda item: item[1], reverse=True)
-        )
-        return sorted_dict
-
-    clip_neuron_dict = {}
-    maple_neuron_dict = {}
-    clip_neuron_dict = _get_top_actvation(evt, selected_image, "CLIP", clip_neuron_dict)
-    maple_neuron_dict = _get_top_actvation(
-        evt, selected_image, model_name, maple_neuron_dict
-    )
-
-    clip_keys = list(clip_neuron_dict.keys())
-    maple_keys = list(maple_neuron_dict.keys())
-
-    common_keys = list(set(clip_keys).intersection(set(maple_keys)))
-    clip_only_keys = list(set(clip_keys) - (set(maple_keys)))
-    maple_only_keys = list(set(maple_keys) - (set(clip_keys)))
-
-    common_keys.sort(
-        key=lambda x: max(clip_neuron_dict[x], maple_neuron_dict[x]), reverse=True
-    )
-    clip_only_keys.sort(reverse=True)
-    maple_only_keys.sort(reverse=True)
-
-    out = []
-    out.extend([f"common-{i}" for i in common_keys[:5]])
-    out.extend([f"CLIP-{i}" for i in clip_only_keys[:5]])
-    out.extend([f"MaPLE-{i}" for i in maple_only_keys[:5]])
-
-    radio_choices = gr.Radio(
-        choices=out, label="Top activating SAE latent", value=out[0]
-    )
-    sleep(0.1)
-    return radio_choices
-
-
-def update_markdown(option_value):
-    latent_idx = int(option_value.split("-")[-1])
-    out_1 = f"## Segmentation mask for the selected SAE latent - {latent_idx}"
-    out_2 = f"## Top reference images for the selected SAE latent - {latent_idx}"
-    return out_1, out_2
-
-
-def get_data(image_name, model_name):
-    pkl_root = "./data/out"
-    data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
-    with gzip.open(data_dir, "rb") as f:
-        data = pickle.load(f)
-    out = data
-
-    return out
-
-
-def update_all(selected_image, slider_value, toggle_btn, model_name):
-    (
-        seg_mask_display,
-        top_image_1,
-        top_image_2,
-        top_image_3,
-        act_value_1,
-        act_value_2,
-        act_value_3,
-    ) = show_activation_heatmap_clip(selected_image, slider_value, toggle_btn)
-    seg_mask_display_maple = show_activation_heatmap_maple(
-        selected_image, slider_value, model_name
-    )
-    markdown_display, markdown_display_2 = update_markdown(slider_value)
-
-    return (
-        seg_mask_display,
-        seg_mask_display_maple,
-        top_image_1,
-        top_image_2,
-        top_image_3,
-        act_value_1,
-        act_value_2,
-        act_value_3,
-        markdown_display,
-        markdown_display_2,
-    )
-
-
-def load_all_data(image_root, pkl_root):
-    image_files = glob(f"{image_root}/*")
-    data_dict = {}
-    for image_file in image_files:
-        image_name = os.path.basename(image_file).split(".")[0]
-        if image_file not in data_dict:
-            data_dict[image_name] = {
-                "image": Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE)),
-                "image_path": image_file,
-            }
-
-    sae_data_dict = {}
-    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
-        data = pickle.load(f)
-        sae_data_dict["mean_acts"] = data
-
-    sae_data_dict["mean_act_values"] = {}
-    for dataset in ["imagenet", "imagenet-sketch", "caltech101"]:
-        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
-            data = pickle.load(f)
-            sae_data_dict["mean_act_values"][dataset] = data
-
-    return data_dict, sae_data_dict
-
-
-# data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
-default_image_name = "christmas-imagenet"
-
-
-with gr.Blocks(
-    theme=gr.themes.Citrus(),
-    css="""
-    .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
-    .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
-    """,
-) as demo:
-    with gr.Row():
-        with gr.Column():
-            # Left View: Image selection and click handling
-            gr.Markdown("## Select input image and patch on the image")
-            image_selector = gr.Dropdown(
-                choices=list(data_dict.keys()),
-                value=default_image_name,
-                label="Select Image",
-            )
-            image_display = gr.Image(
-                value=data_dict[default_image_name]["image"],
-                type="pil",
-                interactive=True,
-            )
-
-            # Update image display when a new image is selected
-            image_selector.change(
-                fn=lambda img_name: data_dict[img_name]["image"],
-                inputs=image_selector,
-                outputs=image_display,
-            )
-            image_display.select(
-                fn=highlight_grid, inputs=[image_selector], outputs=[image_display]
-            )
-
-        with gr.Column():
-            gr.Markdown("## SAE latent activations of CLIP and MaPLE")
-            model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
-            model_selector = gr.Dropdown(
-                choices=model_options,
-                value=model_options[0],
-                label="Select adapted model (MaPLe)",
-            )
-            init_plot = plot_activation_distribution(
-                None, default_image_name, model_options[0]
-            )
-            neuron_plot = gr.Plot(
-                label="Neuron Activation", value=init_plot, show_label=False
-            )
-
-            image_selector.change(
-                fn=plot_activation_distribution,
-                inputs=[image_selector, model_selector],
-                outputs=neuron_plot,
-            )
-            image_display.select(
-                fn=plot_activation_distribution,
-                inputs=[image_selector, model_selector],
-                outputs=neuron_plot,
-            )
-            model_selector.change(
-                fn=load_image, inputs=[image_selector], outputs=image_display
-            )
-            model_selector.change(
-                fn=plot_activation_distribution,
-                inputs=[image_selector, model_selector],
-                outputs=neuron_plot,
-            )
-
-    with gr.Row():
-        with gr.Column():
-            radio_names = get_init_radio_options(default_image_name, model_options[0])
-
-            feautre_idx = radio_names[0].split("-")[-1]
-            markdown_display = gr.Markdown(
-                f"## Segmentation mask for the selected SAE latent - {feautre_idx}"
-            )
-            init_seg, init_tops, init_values = show_activation_heatmap(
-                default_image_name, radio_names[0], "CLIP"
-            )
-
-            gr.Markdown("### Localize SAE latent activation using CLIP")
-            seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
-            init_seg_maple, _, _ = show_activation_heatmap(
-                default_image_name, radio_names[0], model_options[0]
-            )
-            gr.Markdown("### Localize SAE latent activation using MaPLE")
-            seg_mask_display_maple = gr.Image(
-                value=init_seg_maple, type="pil", show_label=False
-            )
-
-        with gr.Column():
-            gr.Markdown("## Top activating SAE latent index")
-
-            radio_choices = gr.Radio(
-                choices=radio_names,
-                label="Top activating SAE latent",
-                interactive=True,
-                value=radio_names[0],
-            )
-            toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
-
-            markdown_display_2 = gr.Markdown(
-                f"## Top reference images for the selected SAE latent - {feautre_idx}"
-            )
-
-            gr.Markdown("### ImageNet")
-            top_image_1 = gr.Image(
-                value=init_tops[0], type="pil", label="ImageNet", show_label=False
-            )
-            act_value_1 = gr.Markdown(init_values[0])
-
-            gr.Markdown("### ImageNet-Sketch")
-            top_image_2 = gr.Image(
-                value=init_tops[1],
-                type="pil",
-                label="ImageNet-Sketch",
-                show_label=False,
-            )
-            act_value_2 = gr.Markdown(init_values[1])
-
-            gr.Markdown("### Caltech101")
-            top_image_3 = gr.Image(
-                value=init_tops[2], type="pil", label="Caltech101", show_label=False
-            )
-            act_value_3 = gr.Markdown(init_values[2])
-
-    image_display.select(
-        fn=update_radio_options,
-        inputs=[image_selector, model_selector],
-        outputs=[radio_choices],
-    )
-
-    model_selector.change(
-        fn=update_radio_options,
-        inputs=[image_selector, model_selector],
-        outputs=[radio_choices],
-    )
-
-    image_selector.select(
-        fn=update_radio_options,
-        inputs=[image_selector, model_selector],
-        outputs=[radio_choices],
-    )
-
-    radio_choices.change(
-        fn=update_all,
-        inputs=[image_selector, radio_choices, toggle_btn, model_selector],
-        outputs=[
-            seg_mask_display,
-            seg_mask_display_maple,
-            top_image_1,
-            top_image_2,
-            top_image_3,
-            act_value_1,
-            act_value_2,
-            act_value_3,
-            markdown_display,
-            markdown_display_2,
-        ],
-    )
-
-    toggle_btn.change(
-        fn=show_activation_heatmap_clip,
-        inputs=[image_selector, radio_choices, toggle_btn],
-        outputs=[
-            seg_mask_display,
-            top_image_1,
-            top_image_2,
-            top_image_3,
-            act_value_1,
-            act_value_2,
-            act_value_3,
-        ],
-    )
-
-# Launch the app
-# demo.queue()
-# demo.launch()
-
-
-if __name__ == "__main__":
-    demo.queue()  # Enable queuing for better handling of concurrent users
-    demo.launch(
-        server_name="0.0.0.0",  # Allow external access
-        server_port=7860,
-        share=False,  # Set to True if you want to create a public URL
-        show_error=True,
-        # Optimize concurrency
-        max_threads=8,  # Adjust based on your CPU cores
-    )
-import gzip
-import os
-import pickle
-from glob import glob
-from time import sleep
-
-from functools import lru_cache
-import concurrent.futures
-from typing import Dict, Tuple, List
-
-import gradio as gr
-import numpy as np
-import plotly.graph_objects as go
-import torch
-from PIL import Image, ImageDraw
-from plotly.subplots import make_subplots
-
-IMAGE_SIZE = 400
-DATASET_LIST = ["imagenet", "oxford_flowers", "ucf101", "caltech101", "dtd", "eurosat"]
-GRID_NUM = 14
-pkl_root = "./data/out"
-preloaded_data = {}
-
-
-# Global cache for data
-_CACHE = {
-    'data_dict': {},
-    'sae_data_dict': {},
-    'model_data': {},
-    'segmasks': {},
-    'top_images': {}
-}
-
-def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
-    """Load all data with optimized parallel processing."""
-    # Load images in parallel
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        image_files = glob(f"{image_root}/*")
-        future_to_file = {
-            executor.submit(_load_image_file, image_file): image_file
-            for image_file in image_files
-        }
-
-        for future in concurrent.futures.as_completed(future_to_file):
-            image_file = future_to_file[future]
-            image_name = os.path.basename(image_file).split(".")[0]
-            result = future.result()
-            if result is not None:
-                _CACHE['data_dict'][image_name] = result
-
-    # Load SAE data
-    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
-        _CACHE['sae_data_dict']["mean_acts"] = pickle.load(f)
-
-    # Load mean act values in parallel
-    datasets = ["imagenet", "imagenet-sketch", "caltech101"]
-    _CACHE['sae_data_dict']["mean_act_values"] = {}
-
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-        future_to_dataset = {
-            executor.submit(_load_mean_act_values, dataset): dataset
-            for dataset in datasets
-        }
-
-        for future in concurrent.futures.as_completed(future_to_dataset):
-            dataset = future_to_dataset[future]
-            result = future.result()
-            if result is not None:
-                _CACHE['sae_data_dict']["mean_act_values"][dataset] = result
-
-    return _CACHE['data_dict'], _CACHE['sae_data_dict']
-
-def _load_image_file(image_file: str) -> Dict:
-    """Helper function to load a single image file."""
-    try:
-        image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
-        return {
-            "image": image,
-            "image_path": image_file,
-        }
-    except Exception as e:
-        print(f"Error loading {image_file}: {e}")
-        return None
-
-def _load_mean_act_values(dataset: str) -> np.ndarray:
-    """Helper function to load mean act values for a dataset."""
-    try:
-        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
-            return pickle.load(f)
-    except Exception as e:
-        print(f"Error loading mean act values for {dataset}: {e}")
-        return None
-
-@lru_cache(maxsize=1024)
-def get_data(image_name: str, model_name: str) -> np.ndarray:
-    """Cached function to get model data."""
-    cache_key = f"{model_name}_{image_name}"
-    if cache_key not in _CACHE['model_data']:
-        data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
-        with gzip.open(data_dir, "rb") as f:
-            _CACHE['model_data'][cache_key] = pickle.load(f)
-    return _CACHE['model_data'][cache_key]
-
-@lru_cache(maxsize=1024)
-def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
-    """Cached function to get activation distribution."""
-    activation = get_data(image_name, model_type)[0]
-    noisy_features_indices = (
-        (_CACHE['sae_data_dict']["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
-    )
-    activation[:, noisy_features_indices] = 0
-    return activation
-
-@lru_cache(maxsize=1024)
-def get_segmask(selected_image: str, slider_value: int, model_type: str) -> np.ndarray:
-    """Cached function to get segmentation mask."""
-    cache_key = f"{selected_image}_{slider_value}_{model_type}"
-    if cache_key not in _CACHE['segmasks']:
-        image = _CACHE['data_dict'][selected_image]["image"]
-        sae_act = get_data(selected_image, model_type)[0]
-        temp = sae_act[:, slider_value]
-
-        mask = torch.Tensor(temp[1:].reshape(14, 14)).view(1, 1, 14, 14)
-        mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
-        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
-
-        base_opacity = 30
-        image_array = np.array(image)[..., :3]
-        rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
-        rgba_overlay[..., :3] = image_array[..., :3]
-
-        darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
-        rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
-        rgba_overlay[..., 3] = 255
-
-        _CACHE['segmasks'][cache_key] = rgba_overlay
-
-    return _CACHE['segmasks'][cache_key]
-
-@lru_cache(maxsize=1024)
-def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
-    """Cached function to get top images."""
-    cache_key = f"{slider_value}_{toggle_btn}"
-    if cache_key not in _CACHE['top_images']:
-        dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"
-        paths = [
-            os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
-            for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
-        ]
-
-        _CACHE['top_images'][cache_key] = [
-            Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
-            for path in paths
-        ]
-
-    return _CACHE['top_images'][cache_key]
-
 
 # def preload_activation(image_name):
 #     for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
@@ -1413,9 +614,6 @@ def preload_all_model_data():
         except Exception as e:
             print(f"Error preloading {cache_key}: {e}")
 
-# Add to initialization
-preload_all_model_data()
-
 def precompute_activations():
     """Precompute and cache common activation patterns"""
     print("Precomputing activations...")
@@ -1425,11 +623,7 @@ def precompute_activations():
         cache_key = f"activation_{model_name}_{image_name}"
        _CACHE['precomputed_activations'][cache_key] = activation.mean(0)
 
-# Add to _CACHE initialization
-_CACHE['precomputed_activations'] = {}
 
-# Add to initialization
-precompute_activations()
 
 def precompute_segmasks():
     """Precompute common segmentation masks"""
@@ -1444,13 +638,6 @@ def precompute_segmasks():
         except Exception as e:
             print(f"Error precomputing mask {cache_key}: {e}")
 
-# Add to initialization
-precompute_segmasks()
-
-
-data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
-default_image_name = "christmas-imagenet"
-
 
 with gr.Blocks(
     theme=gr.themes.Citrus(),
@@ -1672,6 +859,18 @@ if __name__ == "__main__":
     import threading
     start_memory_monitor()
 
+
+    # Add to initialization
+    preload_all_model_data()
+
+    _CACHE['precomputed_activations'] = {}
+    precompute_activations()
+    precompute_segmasks()
+
+    data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
+    default_image_name = "christmas-imagenet"
+
+
     # Launch the app with memory-optimized settings
     demo.queue(max_size=min(20, int(total_ram_gb)))  # Scale queue with RAM
     demo.launch(
 