hyesulim commited on
Commit
9fc25a3
·
verified ·
1 Parent(s): 230f749

test: try fixing variable error

Browse files

Changed gr.EventData to gr.SelectData for event handling
Updated event data access from evt._data["index"] to evt.index
Added error handling around array indexing
Made event parameters Optional where appropriate
Added try-except blocks around event processing code

Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -120,28 +120,37 @@ def get_data(image_name: str, model_name: str) -> np.ndarray:
120
  def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
121
  """Get activation distribution with memory optimization."""
122
  try:
123
- activation = get_data(image_name, model_type)[0]
 
 
 
 
 
 
 
 
124
  mean_acts = _CACHE.get('sae_data_dict', "mean_acts", {}).get("imagenet", np.array([]))
125
 
126
- if mean_acts.size > 0:
127
- noisy_features_indices = (mean_acts > 0.1).nonzero()[0]
128
- activation[:, noisy_features_indices] = 0
 
129
 
130
  return activation
131
  except Exception as e:
132
  print(f"Error getting activation distribution: {e}")
133
  return np.array([])
134
 
135
- def get_grid_loc(evt: gr.EventData, image: Image.Image) -> Tuple[int, int, int, int]:
136
  """Get grid location from click event."""
137
- x, y = evt._data["index"][0], evt._data["index"][1]
138
  cell_width = image.width // GRID_NUM
139
  cell_height = image.height // GRID_NUM
140
  grid_x = x // cell_width
141
  grid_y = y // cell_height
142
  return grid_x, grid_y, cell_width, cell_height
143
 
144
- def highlight_grid(evt: gr.EventData, image_name: str) -> Image.Image:
145
  """Highlight selected grid cell."""
146
  image = _CACHE.get('data_dict', image_name, {}).get("image")
147
  if not image:
@@ -538,7 +547,6 @@ def start_memory_monitor(interval: int = 300):
538
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
539
  default_image_name = "christmas-imagenet"
540
 
541
-
542
  # Create the Gradio interface
543
  with gr.Blocks(
544
  theme=gr.themes.Citrus(),
 
120
  def get_activation_distribution(image_name: str, model_type: str) -> np.ndarray:
121
  """Get activation distribution with memory optimization."""
122
  try:
123
+ data = get_data(image_name, model_type)
124
+ if isinstance(data, (list, tuple)):
125
+ activation = data[0]
126
+ else:
127
+ activation = data
128
+
129
+ if not isinstance(activation, np.ndarray):
130
+ activation = np.array(activation)
131
+
132
  mean_acts = _CACHE.get('sae_data_dict', "mean_acts", {}).get("imagenet", np.array([]))
133
 
134
+ if mean_acts.size > 0 and activation.size > 0:
135
+ noisy_features_indices = np.where(mean_acts > 0.1)[0]
136
+ if activation.ndim >= 2:
137
+ activation[:, noisy_features_indices] = 0
138
 
139
  return activation
140
  except Exception as e:
141
  print(f"Error getting activation distribution: {e}")
142
  return np.array([])
143
 
144
+ def get_grid_loc(evt: gr.SelectData, image: Image.Image) -> Tuple[int, int, int, int]:
145
  """Get grid location from click event."""
146
+ x, y = evt.index[0], evt.index[1]
147
  cell_width = image.width // GRID_NUM
148
  cell_height = image.height // GRID_NUM
149
  grid_x = x // cell_width
150
  grid_y = y // cell_height
151
  return grid_x, grid_y, cell_width, cell_height
152
 
153
+ def highlight_grid(evt: gr.SelectData, image_name: str) -> Image.Image:
154
  """Highlight selected grid cell."""
155
  image = _CACHE.get('data_dict', image_name, {}).get("image")
156
  if not image:
 
547
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=PKL_ROOT)
548
  default_image_name = "christmas-imagenet"
549
 
 
550
  # Create the Gradio interface
551
  with gr.Blocks(
552
  theme=gr.themes.Citrus(),