test: add lru cache
app.py (CHANGED)
@@ -4,6 +4,10 @@ import pickle
 from glob import glob
 from time import sleep
 
+from functools import lru_cache
+import concurrent.futures
+from typing import Dict, Tuple, List
+
 import gradio as gr
 import numpy as np
 import plotly.graph_objects as go
@@ -18,23 +22,160 @@ pkl_root = "./data/out"
 preloaded_data = {}
 
 
-def preload_activation(image_name):
-    for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
-        image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
-        with gzip.open(image_file, "rb") as f:
-            preloaded_data[model] = pickle.load(f)
-
+# Global cache for data
+_CACHE = {
+    'data_dict': {},
+    'sae_data_dict': {},
+    'model_data': {},
+    'segmasks': {},
+    'top_images': {}
+}
+
+def load_all_data(image_root: str, pkl_root: str) -> Tuple[Dict, Dict]:
+    """Load all data with optimized parallel processing."""
+    # Load images in parallel
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        image_files = glob(f"{image_root}/*")
+        future_to_file = {
+            executor.submit(_load_image_file, image_file): image_file
+            for image_file in image_files
+        }
+
+        for future in concurrent.futures.as_completed(future_to_file):
+            image_file = future_to_file[future]
+            image_name = os.path.basename(image_file).split(".")[0]
+            result = future.result()
+            if result is not None:
+                _CACHE['data_dict'][image_name] = result
+
+    # Load SAE data
+    with open("./data/sae_data/mean_acts.pkl", "rb") as f:
+        _CACHE['sae_data_dict']["mean_acts"] = pickle.load(f)
+
+    # Load mean act values in parallel
+    datasets = ["imagenet", "imagenet-sketch", "caltech101"]
+    _CACHE['sae_data_dict']["mean_act_values"] = {}
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future_to_dataset = {
+            executor.submit(_load_mean_act_values, dataset): dataset
+            for dataset in datasets
+        }
+
+        for future in concurrent.futures.as_completed(future_to_dataset):
+            dataset = future_to_dataset[future]
+            result = future.result()
+            if result is not None:
+                _CACHE['sae_data_dict']["mean_act_values"][dataset] = result
+
+    return _CACHE['data_dict'], _CACHE['sae_data_dict']
+
+def _load_image_file(image_file: str) -> Dict:
+    """Helper function to load a single image file."""
+    try:
+        image = Image.open(image_file).resize((IMAGE_SIZE, IMAGE_SIZE))
+        return {
+            "image": image,
+            "image_path": image_file,
+        }
+    except Exception as e:
+        print(f"Error loading {image_file}: {e}")
+        return None
 
-def get_activation_distribution(image_name: str, model_type: str):
+def _load_mean_act_values(dataset: str) -> np.ndarray:
+    """Helper function to load mean act values for a dataset."""
+    try:
+        with gzip.open(f"./data/sae_data/mean_act_values_{dataset}.pkl.gz", "rb") as f:
+            return pickle.load(f)
+    except Exception as e:
+        print(f"Error loading mean act values for {dataset}: {e}")
+        return None
+
+@lru_cache(maxsize=1024)
+def get_data(image_name: str, model_name: str) -> np.ndarray:
+    """Cached function to get model data."""
+    cache_key = f"{model_name}_{image_name}"
+    if cache_key not in _CACHE['model_data']:
+        data_dir = f"{pkl_root}/{model_name}/{image_name}.pkl.gz"
+        with gzip.open(data_dir, "rb") as f:
+            _CACHE['model_data'][cache_key] = pickle.load(f)
+    return _CACHE['model_data'][cache_key]
+
+@lru_cache(maxsize=1024)
+def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
+    """Cached function to get activation distribution."""
     activation = get_data(image_name, model_type)[0]
-
     noisy_features_indices = (
-        (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
+        (_CACHE['sae_data_dict']["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
     )
     activation[:, noisy_features_indices] = 0
-
     return activation
 
+@lru_cache(maxsize=1024)
+def get_segmask(selected_image: str, slider_value: int, model_type: str) -> np.ndarray:
+    """Cached function to get segmentation mask."""
+    cache_key = f"{selected_image}_{slider_value}_{model_type}"
+    if cache_key not in _CACHE['segmasks']:
+        image = _CACHE['data_dict'][selected_image]["image"]
+        sae_act = get_data(selected_image, model_type)[0]
+        temp = sae_act[:, slider_value]
+
+        mask = torch.Tensor(temp[1:].reshape(14, 14)).view(1, 1, 14, 14)
+        mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][0].numpy()
+        mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+        base_opacity = 30
+        image_array = np.array(image)[..., :3]
+        rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+        rgba_overlay[..., :3] = image_array[..., :3]
+
+        darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
+        rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+        rgba_overlay[..., 3] = 255
+
+        _CACHE['segmasks'][cache_key] = rgba_overlay
+
+    return _CACHE['segmasks'][cache_key]
+
+@lru_cache(maxsize=1024)
+def get_top_images(slider_value: int, toggle_btn: bool) -> List[Image.Image]:
+    """Cached function to get top images."""
+    cache_key = f"{slider_value}_{toggle_btn}"
+    if cache_key not in _CACHE['top_images']:
+        dataset_path = "./data/top_images_masked" if toggle_btn else "./data/top_images"
+        paths = [
+            os.path.join(dataset_path, dataset, f"{slider_value}.jpg")
+            for dataset in ["imagenet", "imagenet-sketch", "caltech101"]
+        ]
+
+        _CACHE['top_images'][cache_key] = [
+            Image.open(path) if os.path.exists(path) else Image.new("RGB", (256, 256), (255, 255, 255))
+            for path in paths
+        ]
+
+    return _CACHE['top_images'][cache_key]
+
+# Initialize data
+data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
+
+
+# def preload_activation(image_name):
+#     for model in ["CLIP"] + [f"MaPLE-{ds}" for ds in DATASET_LIST]:
+#         image_file = f"{pkl_root}/{model}/{image_name}.pkl.gz"
+#         with gzip.open(image_file, "rb") as f:
+#             preloaded_data[model] = pickle.load(f)
+
+
+# def get_activation_distribution(image_name: str, model_type: str):
+#     activation = get_data(image_name, model_type)[0]
+
+#     noisy_features_indices = (
+#         (sae_data_dict["mean_acts"]["imagenet"] > 0.1).nonzero()[0].tolist()
+#     )
+#     activation[:, noisy_features_indices] = 0
+
+#     return activation
+
 
 def get_grid_loc(evt, image):
     # Get click coordinates
@@ -203,53 +344,53 @@ def plot_activation_distribution(
     return fig
 
 
-def get_segmask(selected_image, slider_value, model_type):
-    image = data_dict[selected_image]["image"]
-    sae_act = get_data(selected_image, model_type)[0]
-    temp = sae_act[:, slider_value]
-    try:
-        mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
-    except Exception as e:
-        print(sae_act.shape, slider_value)
-    mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
-        0
-    ].numpy()
-    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
-
-    base_opacity = 30
-    image_array = np.array(image)[..., :3]
-    rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
-    rgba_overlay[..., :3] = image_array[..., :3]
-
-    darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
-    rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
-    rgba_overlay[..., 3] = 255  # Fully opaque
-
-    return rgba_overlay
-
-
-def get_top_images(slider_value, toggle_btn):
-    def _get_images(dataset_path):
-        top_image_paths = [
-            os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
-            os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
-            os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
-        ]
-        top_images = [
-            (
-                Image.open(path)
-                if os.path.exists(path)
-                else Image.new("RGB", (256, 256), (255, 255, 255))
-            )
-            for path in top_image_paths
-        ]
-        return top_images
-
-    if toggle_btn:
-        top_images = _get_images("./data/top_images_masked")
-    else:
-        top_images = _get_images("./data/top_images")
-    return top_images
+# def get_segmask(selected_image, slider_value, model_type):
+#     image = data_dict[selected_image]["image"]
+#     sae_act = get_data(selected_image, model_type)[0]
+#     temp = sae_act[:, slider_value]
+#     try:
+#         mask = torch.Tensor(temp[1:,].reshape(14, 14)).view(1, 1, 14, 14)
+#     except Exception as e:
+#         print(sae_act.shape, slider_value)
+#     mask = torch.nn.functional.interpolate(mask, (image.height, image.width))[0][
+#         0
+#     ].numpy()
+#     mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-10)
+
+#     base_opacity = 30
+#     image_array = np.array(image)[..., :3]
+#     rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
+#     rgba_overlay[..., :3] = image_array[..., :3]
+
+#     darkened_image = (image_array[..., :3] * (base_opacity / 255)).astype(np.uint8)
+#     rgba_overlay[mask == 0, :3] = darkened_image[mask == 0]
+#     rgba_overlay[..., 3] = 255  # Fully opaque
+
+#     return rgba_overlay
+
+
+# def get_top_images(slider_value, toggle_btn):
+#     def _get_images(dataset_path):
+#         top_image_paths = [
+#             os.path.join(dataset_path, "imagenet", f"{slider_value}.jpg"),
+#             os.path.join(dataset_path, "imagenet-sketch", f"{slider_value}.jpg"),
+#             os.path.join(dataset_path, "caltech101", f"{slider_value}.jpg"),
+#         ]
+#         top_images = [
+#             (
+#                 Image.open(path)
+#                 if os.path.exists(path)
+#                 else Image.new("RGB", (256, 256), (255, 255, 255))
+#             )
+#             for path in top_image_paths
+#         ]
+#         return top_images
+
+#     if toggle_btn:
+#         top_images = _get_images("./data/top_images_masked")
+#     else:
+#         top_images = _get_images("./data/top_images")
+#     return top_images
 
 
 def show_activation_heatmap(selected_image, slider_value, model_type, toggle_btn=False):
@@ -464,7 +605,7 @@ def load_all_data(image_root, pkl_root):
     return data_dict, sae_data_dict
 
 
-data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
+# data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
 default_image_name = "christmas-imagenet"
 
 
@@ -643,4 +784,16 @@ with gr.Blocks(
 
 # Launch the app
 # demo.queue()
-demo.launch()
+# demo.launch()
+
+
+if __name__ == "__main__":
+    demo.queue()  # Enable queuing for better handling of concurrent users
+    demo.launch(
+        server_name="0.0.0.0",  # Allow external access
+        server_port=7860,
+        share=False,  # Set to True if you want to create a public URL
+        show_error=True,
+        # Optimize concurrency
+        max_threads=8,  # Adjust based on your CPU cores
+    )