Spaces:
Sleeping
Sleeping
test: try fixing variable error
Browse filesChanged gr.EventData to gr.SelectData for event handling
Updated event data access from evt._data["index"] to evt.index
Added error handling around array indexing
Made event parameters Optional where appropriate
Added try-except blocks around event processing code
app.py
CHANGED
@@ -120,28 +120,37 @@ def get_data(image_name: str, model_name: str) -> np.ndarray:
|
|
120 |
def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
|
121 |
"""Get activation distribution with memory optimization."""
|
122 |
try:
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
mean_acts = _CACHE.get('sae_data_dict', "mean_acts", {}).get("imagenet", np.array([]))
|
125 |
|
126 |
-
if mean_acts.size > 0:
|
127 |
-
noisy_features_indices = (mean_acts > 0.1)
|
128 |
-
activation
|
|
|
129 |
|
130 |
return activation
|
131 |
except Exception as e:
|
132 |
print(f"Error getting activation distribution: {e}")
|
133 |
return np.array([])
|
134 |
|
135 |
-
def get_grid_loc(evt: gr.
|
136 |
"""Get grid location from click event."""
|
137 |
-
x, y = evt.
|
138 |
cell_width = image.width // GRID_NUM
|
139 |
cell_height = image.height // GRID_NUM
|
140 |
grid_x = x // cell_width
|
141 |
grid_y = y // cell_height
|
142 |
return grid_x, grid_y, cell_width, cell_height
|
143 |
|
144 |
-
def highlight_grid(evt: gr.
|
145 |
"""Highlight selected grid cell."""
|
146 |
image = _CACHE.get('data_dict', image_name, {}).get("image")
|
147 |
if not image:
|
@@ -538,7 +547,6 @@ def start_memory_monitor(interval: int = 300):
|
|
538 |
data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
|
539 |
default_image_name = "christmas-imagenet"
|
540 |
|
541 |
-
|
542 |
# Create the Gradio interface
|
543 |
with gr.Blocks(
|
544 |
theme=gr.themes.Citrus(),
|
|
|
120 |
def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
|
121 |
"""Get activation distribution with memory optimization."""
|
122 |
try:
|
123 |
+
data = get_data(image_name, model_type)
|
124 |
+
if isinstance(data, (list, tuple)):
|
125 |
+
activation = data[0]
|
126 |
+
else:
|
127 |
+
activation = data
|
128 |
+
|
129 |
+
if not isinstance(activation, np.ndarray):
|
130 |
+
activation = np.array(activation)
|
131 |
+
|
132 |
mean_acts = _CACHE.get('sae_data_dict', "mean_acts", {}).get("imagenet", np.array([]))
|
133 |
|
134 |
+
if mean_acts.size > 0 and activation.size > 0:
|
135 |
+
noisy_features_indices = np.where(mean_acts > 0.1)[0]
|
136 |
+
if activation.ndim >= 2:
|
137 |
+
activation[:, noisy_features_indices] = 0
|
138 |
|
139 |
return activation
|
140 |
except Exception as e:
|
141 |
print(f"Error getting activation distribution: {e}")
|
142 |
return np.array([])
|
143 |
|
144 |
+
def get_grid_loc(evt: gr.SelectData, image: Image.Image) -> Tuple[int, int, int, int]:
|
145 |
"""Get grid location from click event."""
|
146 |
+
x, y = evt.index[0], evt.index[1]
|
147 |
cell_width = image.width // GRID_NUM
|
148 |
cell_height = image.height // GRID_NUM
|
149 |
grid_x = x // cell_width
|
150 |
grid_y = y // cell_height
|
151 |
return grid_x, grid_y, cell_width, cell_height
|
152 |
|
153 |
+
def highlight_grid(evt: gr.SelectData, image_name: str) -> Image.Image:
|
154 |
"""Highlight selected grid cell."""
|
155 |
image = _CACHE.get('data_dict', image_name, {}).get("image")
|
156 |
if not image:
|
|
|
547 |
data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
|
548 |
default_image_name = "christmas-imagenet"
|
549 |
|
|
|
550 |
# Create the Gradio interface
|
551 |
with gr.Blocks(
|
552 |
theme=gr.themes.Citrus(),
|