Marlin Lee committed on
Commit
cf75c2d
·
1 Parent(s): cdaf9dc

Sync explorer_app.py and clip_utils.py from main repo

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +46 -5
scripts/explorer_app.py CHANGED
@@ -100,6 +100,11 @@ parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-pat
100
  parser.add_argument("--google-api-key", type=str, default=None,
101
  help="Google API key for Gemini auto-interp button "
102
  "(default: GOOGLE_API_KEY env var)")
 
 
 
 
 
103
  args = parser.parse_args()
104
 
105
 
@@ -120,7 +125,7 @@ def _get_clip():
120
 
121
  # ---------- Load all datasets into a unified list ----------
122
 
123
- def _load_dataset_dict(path, label):
124
  """Load one explorer_data.pt file and return a unified dataset dict."""
125
  print(f"Loading [{label}] from {path} ...")
126
  d = torch.load(path, map_location='cpu', weights_only=False)
@@ -203,6 +208,8 @@ def _load_dataset_dict(path, label):
203
  else:
204
  entry['patch_acts'] = None
205
 
 
 
206
  print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
207
  f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
208
  f"heatmaps={has_hm}, patch_acts={'yes' if entry['patch_acts'] else 'no'}")
@@ -213,14 +220,17 @@ _all_datasets = []
213
  _active = [0] # index of the currently displayed dataset
214
 
215
  # Primary dataset — always loaded eagerly
216
- _all_datasets.append(_load_dataset_dict(args.data, args.primary_label))
217
 
218
  # Compare datasets — stored as lazy placeholders; loaded on first access
219
  for _ci, _cpath in enumerate(args.compare_data):
220
  _clabel = (args.compare_labels[_ci]
221
  if args.compare_labels and _ci < len(args.compare_labels)
222
  else os.path.basename(_cpath))
223
- _all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True})
 
 
 
224
 
225
 
226
  def _ensure_loaded(idx):
@@ -228,7 +238,7 @@ def _ensure_loaded(idx):
228
  ds = _all_datasets[idx]
229
  if ds.get('_lazy', False):
230
  print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
231
- _all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'])
232
 
233
 
234
  def _apply_dataset_globals(idx):
@@ -1499,6 +1509,37 @@ def _make_summary_html():
1499
 
1500
  summary_div = Div(text=_make_summary_html(), width=700)
1501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1502
 
1503
  # ---------- Patch Explorer ----------
1504
  # Click patches of an image to find the top active SAE features for that region.
@@ -1876,7 +1917,7 @@ patch_explorer_panel = column(
1876
  patch_feat_table,
1877
  )
1878
 
1879
- summary_section = _make_collapsible("SAE Summary", summary_div)
1880
  patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
1881
  clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
1882
 
 
100
  parser.add_argument("--google-api-key", type=str, default=None,
101
  help="Google API key for Gemini auto-interp button "
102
  "(default: GOOGLE_API_KEY env var)")
103
+ parser.add_argument("--sae-path", type=str, default=None,
104
+ help="Path to SAE weights (.pth) for the primary dataset — "
105
+ "enables the Download SAE weights button in the summary panel")
106
+ parser.add_argument("--compare-sae-paths", type=str, nargs="*", default=[],
107
+ help="SAE weight paths for each --compare-data dataset (in order)")
108
  args = parser.parse_args()
109
 
110
 
 
125
 
126
  # ---------- Load all datasets into a unified list ----------
127
 
128
+ def _load_dataset_dict(path, label, sae_path=None):
129
  """Load one explorer_data.pt file and return a unified dataset dict."""
130
  print(f"Loading [{label}] from {path} ...")
131
  d = torch.load(path, map_location='cpu', weights_only=False)
 
208
  else:
209
  entry['patch_acts'] = None
210
 
211
+ entry['sae_path'] = sae_path
212
+
213
  print(f" d={entry['d_model']}, n={entry['n_images']}, token={entry['token_type']}, "
214
  f"backbone={entry['backbone']}, clip={'yes' if cs is not None else 'no'}, "
215
  f"heatmaps={has_hm}, patch_acts={'yes' if entry['patch_acts'] else 'no'}")
 
220
  _active = [0] # index of the currently displayed dataset
221
 
222
  # Primary dataset — always loaded eagerly
223
+ _all_datasets.append(_load_dataset_dict(args.data, args.primary_label, sae_path=args.sae_path))
224
 
225
  # Compare datasets — stored as lazy placeholders; loaded on first access
226
  for _ci, _cpath in enumerate(args.compare_data):
227
  _clabel = (args.compare_labels[_ci]
228
  if args.compare_labels and _ci < len(args.compare_labels)
229
  else os.path.basename(_cpath))
230
+ _csae = (args.compare_sae_paths[_ci]
231
+ if args.compare_sae_paths and _ci < len(args.compare_sae_paths)
232
+ else None)
233
+ _all_datasets.append({'label': _clabel, 'path': _cpath, '_lazy': True, 'sae_path': _csae})
234
 
235
 
236
  def _ensure_loaded(idx):
 
238
  ds = _all_datasets[idx]
239
  if ds.get('_lazy', False):
240
  print(f"[Lazy load] Loading '{ds['label']}' on first access ...")
241
+ _all_datasets[idx] = _load_dataset_dict(ds['path'], ds['label'], sae_path=ds.get('sae_path'))
242
 
243
 
244
  def _apply_dataset_globals(idx):
 
1509
 
1510
  summary_div = Div(text=_make_summary_html(), width=700)
1511
 
1512
+ # --- SAE weights download button ---
1513
+ _download_source = ColumnDataSource(data=dict(b64=[''], filename=['']))
1514
+ _download_source.js_on_change('data', CustomJS(args=dict(src=_download_source), code="""
1515
+ const b64 = src.data['b64'][0];
1516
+ const fname = src.data['filename'][0];
1517
+ if (!b64) return;
1518
+ const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0));
1519
+ const blob = new Blob([bytes], {type: 'application/octet-stream'});
1520
+ const url = URL.createObjectURL(blob);
1521
+ const a = document.createElement('a');
1522
+ a.href = url; a.download = fname; a.click();
1523
+ URL.revokeObjectURL(url);
1524
+ src.data = {b64: [''], filename: ['']};
1525
+ """))
1526
+
1527
+ sae_download_btn = Button(label="\u2b07 Download SAE weights", button_type="default", width=220)
1528
+
1529
+ def _on_sae_download():
1530
+ ds = _all_datasets[_active[0]]
1531
+ sae_path = ds.get('sae_path')
1532
+ if not sae_path or not os.path.exists(sae_path):
1533
+ status_div.text = "<b style='color:red'>No SAE path set for this model. Pass --sae-path.</b>"
1534
+ return
1535
+ status_div.text = f"<b>Reading {os.path.basename(sae_path)}…</b>"
1536
+ with open(sae_path, 'rb') as f:
1537
+ b64 = base64.b64encode(f.read()).decode('ascii')
1538
+ _download_source.data = dict(b64=[b64], filename=[os.path.basename(sae_path)])
1539
+ status_div.text = ""
1540
+
1541
+ sae_download_btn.on_click(lambda: _on_sae_download())
1542
+
1543
 
1544
  # ---------- Patch Explorer ----------
1545
  # Click patches of an image to find the top active SAE features for that region.
 
1917
  patch_feat_table,
1918
  )
1919
 
1920
+ summary_section = _make_collapsible("SAE Summary", column(summary_div, sae_download_btn))
1921
  patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
1922
  clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)
1923