hyesulim committed on
Commit
c6fcf0b
·
verified ·
1 Parent(s): 0025d00

test: add lru cache

Browse files
Files changed (1) hide show
  1. app.py +212 -59
app.py CHANGED
@@ -4,6 +4,10 @@ import pickle
4
  from glob import glob
5
  from time import sleep
6
 
 
 
 
 
7
  import gradio as gr
8
  import numpy as np
9
  import plotly.graph_objects as go
@@ -18,23 +22,160 @@ pkl_root = "./data/out"
18
  preloaded_data = {}
19
 
20
 
21
- def preload_activation(image_name):
22
- for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
23
- image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
24
- with gzip.open(image_file, "rb") as f:
25
- preloaded_data[model] = pickle.load(f)
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def get_activation_distribution(image_name: str, model_type: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  activation = get_data(image_name, model_type)[0]
30
-
31
  noisy_features_indices = (
32
- (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
33
  )
34
  activation[:, noisy_features_indices] = 0
35
-
36
  return activation
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def get_grid_loc(evt, image):
40
  # Get click coordinates
@@ -203,53 +344,53 @@ def plot_activation_distribution(
203
  return fig
204
 
205
 
206
- def get_segmask(selected_image, slider_value, model_type):
207
- image = data_dict[selected_image]["image"]
208
- sae_act = get_data(selected_image, model_type)[0]
209
- temp = sae_act[:, slider_value]
210
- try:
211
- mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
212
- except Exception as e:
213
- print(sae_act.shape, slider_value)
214
- mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
215
- 0
216
- ].numpy()
217
- mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
218
-
219
- base_opacity = 30
220
- image_array = np.array(image)[..., :3]
221
- rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
222
- rgba_overlay[..., :3] = image_array[..., :3]
223
-
224
- darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
225
- rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
226
- rgba_overlay[..., 3] = 255 # Fully opaque
227
-
228
- return rgba_overlay
229
-
230
-
231
- def get_top_images(slider_value, toggle_btn):
232
- def _get_images(dataset_path):
233
- top_image_paths = [
234
- os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
235
- os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
236
- os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
237
- ]
238
- top_images = [
239
- (
240
- Image.open(path)
241
- if os.path.exists(path)
242
- else Image.new("RGB", (256, 256), (255, 255, 255))
243
- )
244
- for path in top_image_paths
245
- ]
246
- return top_images
247
-
248
- if toggle_btn:
249
- top_images = _get_images("./data/top_images_masked")
250
- else:
251
- top_images = _get_images("./data/top_images")
252
- return top_images
253
 
254
 
255
  def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
@@ -464,7 +605,7 @@ def load_all_data(image_root, pkl_root):
464
  return data_dict, sae_data_dict
465
 
466
 
467
- data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
468
  default_image_name = "christmas-imagenet"
469
 
470
 
@@ -643,4 +784,16 @@ with gr.Blocks(
643
 
644
  # Launch the app
645
  # demo.queue()
646
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from glob import glob
5
  from time import sleep
6
 
7
+ from functools import lru_cache
8
+ import concurrent.futures
9
+ from typing import Dict, Tuple, List
10
+
11
  import gradio as gr
12
  import numpy as np
13
  import plotly.graph_objects as go
 
22
  preloaded_data = {}
23
 
24
 
25
+ # Global cache for data
26
+ _CACHE = {
27
+ 'data_dict': {},
28
+ 'sae_data_dict': {},
29
+ 'model_data': {},
30
+ 'segmasks': {},
31
+ 'top_images': {}
32
+ }
33
+
34
def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
    """Fill the module-level cache with the demo images and SAE statistics.

    Args:
        image_root: Directory whose direct children are loaded as images.
        pkl_root: Unused here; kept so existing callers keep working.

    Returns:
        The ``(data_dict, sae_data_dict)`` pair stored in ``_CACHE``.
        ``data_dict`` maps the image file's base name (text before the first
        ".") to the dict produced by ``_load_image_file``; ``sae_data_dict``
        holds ``"mean_acts"`` and per-dataset ``"mean_act_values"``.

    Raises:
        OSError: If ``./data/sae_data/mean_acts.pkl`` cannot be opened.
    """
    image_files = glob(f"{image_root}/*")
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # executor.map yields results in submission order, so entries are
        # inserted in glob order no matter which worker finishes first.
        # (as_completed made the dict order vary from launch to launch.)
        loaded = executor.map(_load_image_file, image_files)
        for image_file, result in zip(image_files, loaded):
            if result is None:
                # The helper already reported the failure; skip the file.
                continue
            image_name = os.path.basename(image_file).split(".")[0]
            _CACHE['data_dict'][image_name] = result

    # NOTE(review): pickle executes arbitrary code on load; these files must
    # come from a trusted source.
    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
        _CACHE['sae_data_dict']["mean_acts"] = pickle.load(f)

    datasets = ["imagenet", "imagenet-sketch", "caltech101"]
    _CACHE['sae_data_dict']["mean_act_values"] = {}

    with concurrent.futures.ThreadPoolExecutor() as executor:
        loaded = executor.map(_load_mean_act_values, datasets)
        for dataset, result in zip(datasets, loaded):
            if result is not None:
                _CACHE['sae_data_dict']["mean_act_values"][dataset] = result

    return _CACHE['data_dict'], _CACHE['sae_data_dict']
72
+
73
+ def _load_image_file(image_file: str) -> Dict:
74
+ """Helper function to load a single image file."""
75
+ try:
76
+ image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
77
+ return {
78
+ "image": image,
79
+ "image_path": image_file,
80
+ }
81
+ except Exception as e:
82
+ print(f"Error loading {image_file}: {e}")
83
+ return None
84
 
85
+ def _load_mean_act_values(dataset: str) -> np.ndarray:
86
+ """Helper function to load mean act values for a dataset."""
87
+ try:
88
+ with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
89
+ return pickle.load(f)
90
+ except Exception as e:
91
+ print(f"Error loading mean act values for {dataset}: {e}")
92
+ return None
93
+
94
@lru_cache(maxsize=1024)
def get_data(image_name: str, model_name: str) -> np.ndarray:
    """Return the unpickled payload for ``image_name`` under ``model_name``.

    The payload is read from ``{pkl_root}/{model_name}/{image_name}.pkl.gz``
    on first use and kept in ``_CACHE['model_data']`` afterwards.
    """
    store = _CACHE['model_data']
    key = f"{model_name}_{image_name}"
    try:
        return store[key]
    except KeyError:
        pass  # first request for this pair: fall through and load from disk

    with gzip.open(f"{pkl_root}/{model_name}/{image_name}.pkl.gz", "rb") as archive:
        store[key] = pickle.load(archive)
    return store[key]
103
+
104
@lru_cache(maxsize=1024)
def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
    """Return the activations for an image with noisy features zeroed out.

    A feature counts as noisy when its ImageNet mean activation in
    ``_CACHE['sae_data_dict']["mean_acts"]`` is above 0.1; those columns are
    set to 0 in the returned array.

    The result is memoised, so every caller receives the same array object
    and must treat it as read-only.
    """
    # Work on a copy: get_data() returns its cached object, and zeroing that
    # in place would change what other readers of get_data() (for example
    # get_segmask) see, depending on call order.
    # NOTE(review): assumes the first element is array-like — confirm.
    activation = np.array(get_data(image_name, model_type)[0], copy=True)

    noisy_features_indices = (
        (_CACHE['sae_data_dict']["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
    )
    activation[:, noisy_features_indices] = 0

    return activation
113
 
114
@lru_cache(maxsize=1024)
def get_segmask(selected_image: str, slider_value: int, model_type: str) -> np.ndarray:
    """Build an RGBA overlay highlighting where one feature fires on an image.

    Args:
        selected_image: Key of the image in ``_CACHE['data_dict']``.
        slider_value: Column index of the feature in the activation matrix.
        model_type: Model whose activations are read through ``get_data``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 4)``. Pixels where the
        normalised mask is exactly 0 are darkened; alpha is 255 everywhere.
        The result is memoised, so callers must not modify it in place.
    """
    # Memoisation is left to lru_cache alone. The previous extra copy in
    # _CACHE['segmasks'] was never evicted, so memory grew with every
    # distinct (image, feature, model) combination requested.
    image = _CACHE['data_dict'][selected_image]["image"]
    sae_act = get_data(selected_image, model_type)[0]
    feature_column = sae_act[:, slider_value]

    # Every row but the first is laid out as a 14x14 grid
    # (presumably the first row is a class token — TODO confirm).
    mask = torch.Tensor(feature_column[1:].reshape(14, 14)).view(1, 1, 14, 14)
    mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
    # Min-max normalise; the epsilon guards against a constant mask.
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)

    base_opacity = 30
    image_array = np.array(image)[..., :3]
    rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
    rgba_overlay[..., :3] = image_array[..., :3]

    darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
    inactive = mask == 0
    rgba_overlay[inactive, :3] = darkened_image[inactive]
    rgba_overlay[..., 3] = 255  # fully opaque

    return rgba_overlay
139
+
140
@lru_cache(maxsize=1024)
def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
    """Return the top-activating example image of a feature for each dataset.

    Args:
        slider_value: Feature index; images are looked up as
            ``<root>/<dataset>/<slider_value>.jpg``.
        toggle_btn: When true, read from ``./data/top_images_masked``,
            otherwise from ``./data/top_images``.

    Returns:
        Three images in the order imagenet, imagenet-sketch, caltech101. A
        white 256x256 placeholder stands in for any missing file. The list is
        memoised, so callers must not modify it in place.
    """
    # Memoisation is left to lru_cache alone; the previous extra copy in
    # _CACHE['top_images'] was never evicted.
    dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"

    def _read(path: str):
        if not os.path.exists(path):
            return Image.new("RGB", (256, 256), (255, 255, 255))
        image = Image.open(path)
        # Decode now instead of lazily. A cached lazy image keeps its file
        # open and would be decoded later by whichever request thread
        # touches it first.
        image.load()
        return image

    return [
        _read(os.path.join(dataset_path, dataset, f"{slider_value}.jpg"))
        for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
    ]
157
+
158
+ # Initialize data
159
+ data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
160
+
161
+
162
+ # def preload_activation(image_name):
163
+ # for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
164
+ # image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
165
+ # with gzip.open(image_file, "rb") as f:
166
+ # preloaded_data[model] = pickle.load(f)
167
+
168
+
169
+ # def get_activation_distribution(image_name: str, model_type: str):
170
+ # activation = get_data(image_name, model_type)[0]
171
+
172
+ # noisy_features_indices = (
173
+ # (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
174
+ # )
175
+ # activation[:, noisy_features_indices] = 0
176
+
177
+ # return activation
178
+
179
 
180
  def get_grid_loc(evt, image):
181
  # Get click coordinates
 
344
  return fig
345
 
346
 
347
+ # def get_segmask(selected_image, slider_value, model_type):
348
+ # image = data_dict[selected_image]["image"]
349
+ # sae_act = get_data(selected_image, model_type)[0]
350
+ # temp = sae_act[:, slider_value]
351
+ # try:
352
+ # mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
353
+ # except Exception as e:
354
+ # print(sae_act.shape, slider_value)
355
+ # mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
356
+ # 0
357
+ # ].numpy()
358
+ # mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
359
+
360
+ # base_opacity = 30
361
+ # image_array = np.array(image)[..., :3]
362
+ # rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
363
+ # rgba_overlay[..., :3] = image_array[..., :3]
364
+
365
+ # darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
366
+ # rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
367
+ # rgba_overlay[..., 3] = 255 # Fully opaque
368
+
369
+ # return rgba_overlay
370
+
371
+
372
+ # def get_top_images(slider_value, toggle_btn):
373
+ # def _get_images(dataset_path):
374
+ # top_image_paths = [
375
+ # os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
376
+ # os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
377
+ # os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
378
+ # ]
379
+ # top_images = [
380
+ # (
381
+ # Image.open(path)
382
+ # if os.path.exists(path)
383
+ # else Image.new("RGB", (256, 256), (255, 255, 255))
384
+ # )
385
+ # for path in top_image_paths
386
+ # ]
387
+ # return top_images
388
+
389
+ # if toggle_btn:
390
+ # top_images = _get_images("./data/top_images_masked")
391
+ # else:
392
+ # top_images = _get_images("./data/top_images")
393
+ # return top_images
394
 
395
 
396
  def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
 
605
  return data_dict, sae_data_dict
606
 
607
 
608
+ # data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
609
  default_image_name = "christmas-imagenet"
610
 
611
 
 
784
 
785
  # Launch the app
786
  # demo.queue()
787
+ # demo.launch()
788
+
789
+
790
if __name__ == "__main__":
    # Turn on request queuing before the server starts so that concurrent
    # users are handled through the queue.
    demo.queue()

    launch_options = dict(
        server_name="0.0.0.0",  # listen on all interfaces (external access)
        server_port=7860,
        share=False,  # switch to True to create a public URL
        show_error=True,
        max_threads=8,  # tune to the number of CPU cores
    )
    demo.launch(**launch_options)