| """ |
| Interactive SAE Feature Explorer - Bokeh Server App. |
| |
| Visualizes SAE features with: |
| - UMAP scatter plot of features (activation-based and dictionary-based) |
| - Click a feature to see its top-activating images with heatmap overlays |
| - Patch explorer: click patches of any image to find active SAE features |
| (uses live GPU inference via the backbone + SAE loaded from --sae-path) |
| - Feature naming: assign names to features, saved to JSON, searchable |
| - CLIP text search, Gemini auto-interp, DynaDiff brain steering panel |
| - Optional NSD brain MEI dataset (--brain-data) shown in the dataset dropdown |
| |
| Launch: see run_explorer.sh |
| """ |
|
|
| import argparse |
| import os |
| import io |
| import json |
| import base64 |
| import random |
| import threading |
| from collections import OrderedDict |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import matplotlib.colors as mcolors |
| from PIL import Image |
| import sys |
| sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'src')) |
| from clip_utils import load_clip, compute_text_embeddings |
|
|
| from bokeh.io import curdoc |
| from bokeh.layouts import column, row |
| from bokeh.events import MouseMove |
| from bokeh.models import ( |
| ColumnDataSource, HoverTool, Div, Select, TextInput, Button, |
| DataTable, TableColumn, NumberFormatter, NumberEditor, |
| Slider, Toggle, RadioButtonGroup, CustomJS, |
| ) |
| from bokeh.plotting import figure |
| from bokeh.palettes import Turbo256 |
| from bokeh.transform import linear_cmap |
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Command-line interface.  Parsed once at import time; the resulting ``args``
# namespace is read by the rest of this module.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# -- Primary dataset and images ---------------------------------------------
parser.add_argument("--data", type=str, required=True)
parser.add_argument(
    "--image-dir",
    type=str,
    required=True,
    help="Primary image directory used during precompute",
)
parser.add_argument(
    "--extra-image-dir",
    type=str,
    default=[],
    nargs="*",
    help="Additional image directories used during precompute",
)
parser.add_argument("--thumb-size", type=int, default=256)
parser.add_argument(
    "--inference-cache-size",
    type=int,
    default=64,
    help="Number of images to keep in the patch-activations LRU cache",
)
parser.add_argument(
    "--names-file",
    type=str,
    default=None,
    help="Path to JSON file for saving feature names "
         "(default: <data>_feature_names.json)",
)
parser.add_argument(
    "--primary-label",
    type=str,
    default="Primary",
    help="Display label for the primary --data file",
)

# -- Text search / auto-interp ----------------------------------------------
parser.add_argument(
    "--clip-model",
    type=str,
    default="openai/clip-vit-large-patch14",
    help="HuggingFace CLIP model ID for free-text search "
         "(only loaded on first out-of-vocab query)",
)
parser.add_argument(
    "--google-api-key",
    type=str,
    default=None,
    help="Google API key for Gemini auto-interp button "
         "(default: GOOGLE_API_KEY env var)",
)
parser.add_argument(
    "--sae-url",
    type=str,
    default=None,
    help="Download URL for the SAE weights — shown as a link in the summary panel",
)

# -- Brain-alignment (phi) data ---------------------------------------------
parser.add_argument(
    "--phi-dir",
    type=str,
    default=None,
    help="Directory containing Phi_cv_*.npy, phi_c_*.npy, voxel_coords.npy "
         "(brain-alignment data; enables cortical profile and brain leverage features)",
)
parser.add_argument(
    "--phi-model",
    type=str,
    default=None,
    help="Model name substring to match phi files (e.g. 'dinov3', 'dinov2', 'clip_encoder'). "
         "Default: pick largest Phi_cv_*.npy by file size.",
)

# -- DynaDiff brain steering ------------------------------------------------
parser.add_argument(
    "--dynadiff-dir",
    type=str,
    default=None,
    help="Path to the local dynadiff repo. "
         "When provided (with --phi-dir), enables the brain steering panel.",
)
parser.add_argument(
    "--dynadiff-checkpoint",
    type=str,
    default="dynadiff_padded_sub01.pth",
    help="Checkpoint filename or path (relative to --dynadiff-dir or absolute).",
)
parser.add_argument(
    "--dynadiff-h5",
    type=str,
    default="extracted_training_data/consolidated_sub01.h5",
    help="Path to fMRI H5 (relative to --dynadiff-dir or absolute).",
)

# -- Optional NSD brain dataset ---------------------------------------------
parser.add_argument(
    "--brain-data",
    type=str,
    default=None,
    help="Path to brain_meis.pt produced by precompute_nsd_meis.py. "
         "Adds 'NSD Brain (DINOv2 L11)' as a selectable dataset in the "
         "dataset dropdown, using NSD images and NSD-based UMAPs.",
)
parser.add_argument(
    "--brain-thumbnails",
    type=str,
    default=None,
    help="Directory containing NSD JPEG thumbnails (nsd_XXXXX.jpg). "
         "Required with --brain-data if image_paths are not absolute paths.",
)
parser.add_argument(
    "--brain-label",
    type=str,
    default="NSD Brain (DINOv2 L11)",
    help="Dataset label shown in the dropdown for --brain-data.",
)

# -- Live GPU inference (backbone + SAE) ------------------------------------
parser.add_argument(
    "--sae-path",
    type=str,
    default=None,
    help="Path to SAE state-dict .pth file. When provided the backbone + SAE "
         "are loaded on GPU so any image can be explored without pre-computed "
         "patch activations.",
)
parser.add_argument(
    "--backbone",
    type=str,
    default="dinov2",
    help="Backbone name matching the SAE (default: dinov2).",
)
parser.add_argument(
    "--layer",
    type=int,
    default=11,
    help="Backbone layer used during SAE training (default: 11).",
)
parser.add_argument(
    "--top-k",
    type=int,
    default=100,
    help="SAE top-k sparsity (default: 100).",
)

args = parser.parse_args()
|
|
|
|
| |
# Cached ``(model, aux, device)`` tuple; stays None until CLIP is first requested.
_clip_handle = None


def _get_clip():
    """Return the cached CLIP handle, loading the model on the first call.

    Returns:
        A 3-tuple made of the two objects returned by ``load_clip`` followed
        by the ``torch.device`` they were loaded on (``cuda:0`` when CUDA is
        available, otherwise ``cpu``).
    """
    global _clip_handle
    if _clip_handle is not None:
        return _clip_handle

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"[CLIP] Loading {args.clip_model} on {device} (first free-text query)...")
    model, aux = load_clip(device, model_name=args.clip_model)
    _clip_handle = (model, aux, device)
    print("[CLIP] Ready.")
    return _clip_handle
|
|
|
|
| |
| |
# Cached runner tuple built by _get_gpu_runner(); None until a successful load.
_gpu_runner = None


def _get_gpu_runner():
    """Build (once) and return the backbone + SAE inference runner.

    Returns:
        The cached tuple ``(forward_fn, sae, transform, n_reg, extract_tokens,
        backbone_name, device)``, or ``None`` when ``--sae-path`` was not given
        or no CUDA device is available.
    """
    global _gpu_runner
    if _gpu_runner is not None:
        return _gpu_runner
    if not args.sae_path:
        return None
    if not torch.cuda.is_available():
        print("[GPU runner] No CUDA device — on-the-fly inference disabled.")
        return None

    # The helper modules live in ../src relative to this file; import them
    # lazily so they are only required when live inference is actually used.
    src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))
    sys.path.insert(0, src_dir)
    from backbone_runners import load_batched_backbone
    from precompute_utils import load_sae, extract_tokens

    device = torch.device("cuda:0")
    print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {device} ...")
    forward_fn, d_hidden, n_reg, transform = load_batched_backbone(
        args.backbone, args.layer, device)
    # d_model is the module-level alias of the active dataset's feature count.
    sae = load_sae(args.sae_path, d_hidden, d_model, args.top_k, device)
    _gpu_runner = (forward_fn, sae, transform, n_reg, extract_tokens, args.backbone, device)
    print("[GPU runner] Ready.")
    return _gpu_runner
|
|
|
|
def _run_gpu_inference(pil_img):
    """Push one PIL image through the backbone and SAE.

    Args:
        pil_img: image accepted by the runner's preprocessing transform.

    Returns:
        The SAE activations as a float32 numpy array with one row per spatial
        token, or ``None`` when no GPU runner is available.
    """
    runner = _get_gpu_runner()
    if runner is None:
        return None

    forward_fn, sae, transform, n_reg, extract_tokens, backbone_name, device = runner
    batch = transform(pil_img).unsqueeze(0).to(device)
    with torch.inference_mode():
        hidden = forward_fn(batch)
        tokens = extract_tokens(hidden, backbone_name, 'spatial', n_reg)
        # Collapse every leading dimension so the SAE sees one token per row.
        flat = tokens.reshape(-1, tokens.shape[-1])
        _, z, _ = sae(flat)

    print(f"[GPU runner] z shape={z.shape}, nonzero={int((z>0).sum())}, max={float(z.max()):.4f}")
    return z.cpu().float().numpy()
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
# Brain-alignment arrays.  Each one stays None when --phi-dir is not given or
# the corresponding file is missing from that directory.
_phi_cv = None            # feature x (voxel | vertex) matrix, memory-mapped
_phi_c = None             # per-feature leverage scores
_voxel_coords = None      # voxel coordinates used for the scatter views
_voxel_to_vertex = None   # index map applied to surface-space phi rows


# Second-axis sizes used to tell voxel-space from surface-space phi matrices.
_N_VOXELS_DD = 15724
_N_VERTS_FSAVG = 37984


if args.phi_dir:
    _pdir = args.phi_dir
    _phi_model_key = (args.phi_model or "").lower()

    def _pick_phi_file(candidates, model_key):
        """Choose one file name from *candidates*.

        Prefers the alphabetically first name containing *model_key*
        (case-insensitive); otherwise returns the largest file on disk.
        Returns None when *candidates* is empty.
        """
        if not candidates:
            return None
        if model_key:
            hits = sorted(name for name in candidates if model_key in name.lower())
            if hits:
                return hits[0]
            print(f"[Phi] WARNING: --phi-model '{model_key}' matched no files in {candidates}; "
                  "falling back to largest file")

        def _size_on_disk(name):
            return os.path.getsize(os.path.join(_pdir, name))

        return max(candidates, key=_size_on_disk)

    # --- Phi matrix (Phi_cv_*.npy) ------------------------------------------
    _phi_mat_files = [
        name for name in os.listdir(_pdir)
        if name.endswith('.npy') and name.lower().startswith('phi_cv')
    ]
    _phi_mat_pick = _pick_phi_file(_phi_mat_files, _phi_model_key)
    if not _phi_mat_pick:
        print(f"[Phi] WARNING: no Phi_cv_*.npy found in {_pdir}")
    else:
        _phi_path = os.path.join(_pdir, _phi_mat_pick)
        # Memory-map: rows are read one at a time on demand.
        _phi_cv = np.load(_phi_path, mmap_mode='r')
        print(f"[Phi] Loaded {_phi_mat_pick}: shape {_phi_cv.shape}, dtype {_phi_cv.dtype}")
        if _phi_cv.shape[1] == _N_VOXELS_DD:
            print("[Phi] Voxel-space phi detected.")
        elif _phi_cv.shape[1] == _N_VERTS_FSAVG:
            _v2v_path = os.path.join(_pdir, 'voxel_to_vertex_map.npy')
            if not os.path.exists(_v2v_path):
                print("[Phi] WARNING: surface-space phi but voxel_to_vertex_map.npy not found")
            else:
                _voxel_to_vertex = np.load(_v2v_path)
                print(f"[Phi] Surface-space phi; loaded voxel_to_vertex_map: {_voxel_to_vertex.shape}")
        else:
            print(f"[Phi] WARNING: unexpected phi dimension {_phi_cv.shape[1]}")

    # --- Leverage scores (phi_c_*.npy, excluding the Phi_cv matrices) -------
    _phi_c_files = [
        name for name in os.listdir(_pdir)
        if name.endswith('.npy')
        and name.lower().startswith('phi_c')
        and not name.lower().startswith('phi_cv')
    ]
    _phi_c_pick = _pick_phi_file(_phi_c_files, _phi_model_key)
    if not _phi_c_pick:
        print(f"[Phi] No phi_c_*.npy found in {_pdir} — leverage scores unavailable")
    else:
        _phi_c = np.load(os.path.join(_pdir, _phi_c_pick))
        print(f"[Phi] Leverage scores {_phi_c_pick}: shape {_phi_c.shape}, "
              f"range [{_phi_c.min():.4f}, {_phi_c.max():.4f}]")

    # --- Voxel coordinates ---------------------------------------------------
    _coords_path = os.path.join(_pdir, 'voxel_coords.npy')
    if not os.path.exists(_coords_path):
        print("[Phi] voxel_coords.npy not found — cortical scatter unavailable")
    else:
        _voxel_coords = np.load(_coords_path)
        print(f"[Phi] Voxel coordinates: {_voxel_coords.shape}")


# True only when a Phi_cv matrix was actually loaded.
HAS_PHI = _phi_cv is not None
|
|
|
|
| |
| |
# DynaDiff steering backend.  Both names keep these defaults unless the loader
# below is created successfully.
_dd_loader = None
HAS_DYNADIFF = False


if args.dynadiff_dir and os.path.isdir(args.dynadiff_dir):
    if HAS_PHI:
        try:
            # dynadiff_loader is imported from the directory containing this file.
            sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
            from dynadiff_loader import get_loader

            # A relative H5 path is interpreted relative to --dynadiff-dir.
            _h5 = (args.dynadiff_h5 if os.path.isabs(args.dynadiff_h5)
                   else os.path.join(args.dynadiff_dir, args.dynadiff_h5))

            _dd_loader = get_loader(
                dynadiff_dir=args.dynadiff_dir,
                checkpoint=args.dynadiff_checkpoint,
                h5_path=_h5,
                nsd_thumb_dir=args.brain_thumbnails,
                subject_idx=0,
            )
            HAS_DYNADIFF = True
            print(f"[DynaDiff] In-process loader ready (checkpoint: {args.dynadiff_checkpoint})")
        except Exception as _dd_err:
            # Best effort: any failure leaves the steering panel disabled.
            print(f"[DynaDiff] WARNING: Could not start in-process loader ({_dd_err}). "
                  "Steering panel will be disabled.")
    else:
        print("[DynaDiff] WARNING: --phi-dir not set; steering panel requires Phi data. Disabling.")
|
|
|
|
| |
|
|
def _load_dataset_dict(path, label, sae_url=None):
    """Read one explorer_data.pt file into the unified dataset-entry dict.

    Args:
        path: location of the ``.pt`` file written by the precompute step.
        label: display label stored under ``'label'``.
        sae_url: optional URL stored under ``'sae_url'``.

    Returns:
        A dict holding the dataset tensors, UMAP coordinates (as numpy
        arrays), an empty per-dataset inference cache, any manual /
        auto-interp feature labels found in the JSON files next to *path*,
        and pre-computed heatmaps when a ``<stem>_heatmaps.pt`` sidecar exists.
    """
    def _read_int_keyed(json_path):
        # JSON object keys are strings; feature ids are ints everywhere else.
        with open(json_path) as handle:
            return {int(key): val for key, val in json.load(handle).items()}

    print(f"Loading [{label}] from {path} ...")
    d = torch.load(path, map_location='cpu', weights_only=False)
    stem = os.path.splitext(path)[0]

    # --names-file only overrides the location for the primary --data file.
    if path == args.data and args.names_file:
        names_file = args.names_file
    else:
        names_file = stem + '_feature_names.json'
    feat_names = _read_int_keyed(names_file) if os.path.exists(names_file) else {}

    auto_interp_file = stem + '_auto_interp.json'
    auto_interp = {}
    if os.path.exists(auto_interp_file):
        auto_interp = _read_int_keyed(auto_interp_file)
        print(f"  Loaded {len(auto_interp)} auto-interp labels from "
              f"{os.path.basename(auto_interp_file)}")

    # Without a dictionary-based UMAP, fall back to an all-NaN placeholder.
    if 'dict_umap_coords' in d:
        dict_umap = d['dict_umap_coords'].numpy()
    else:
        dict_umap = np.full((d['d_model'], 2), np.nan, dtype=np.float32)

    entry = {
        'label': label,
        'path': path,
        'image_paths': d['image_paths'],
        'd_model': d['d_model'],
        'n_images': d['n_images'],
        'patch_grid': d['patch_grid'],
        'image_size': d['image_size'],
        'backbone': d.get('backbone', 'dinov3'),
        'top_img_idx': d['top_img_idx'],
        'top_img_act': d['top_img_act'],
        # Mean-ranked tables fall back to the max-ranked ones when absent.
        'mean_img_idx': d.get('mean_img_idx', d['top_img_idx']),
        'mean_img_act': d.get('mean_img_act', d['top_img_act']),
        'nsd_top_img_idx': d.get('nsd_top_img_idx'),
        'nsd_top_img_act': d.get('nsd_top_img_act'),
        'nsd_mean_img_idx': d.get('nsd_mean_img_idx'),
        'nsd_mean_img_act': d.get('nsd_mean_img_act'),
        'feature_frequency': d['feature_frequency'],
        'feature_mean_act': d['feature_mean_act'],
        'umap_coords': d['umap_coords'].numpy(),
        'dict_umap_coords': dict_umap,
        'clip_embeds': d.get('clip_feature_embeds'),
        'nsd_clip_embeds': d.get('nsd_clip_feature_embeds'),
        'inference_cache': OrderedDict(),
        'names_file': names_file,
        'auto_interp_file': auto_interp_file,
        'feature_names': feat_names,
        'auto_interp_names': auto_interp,
    }

    # Optional sidecar with pre-computed heatmaps; missing keys become None.
    sidecar = stem + '_heatmaps.pt'
    hm = {}
    has_hm = 'no'
    if os.path.exists(sidecar):
        print(f"  Loading pre-computed heatmaps from {os.path.basename(sidecar)} ...")
        hm = torch.load(sidecar, map_location='cpu', weights_only=True)
        has_hm = 'yes (no GPU needed for heatmaps)'
    for hm_key in ('top_heatmaps', 'mean_heatmaps', 'nsd_top_heatmaps', 'nsd_mean_heatmaps'):
        entry[hm_key] = hm.get(hm_key)
    entry['heatmap_patch_grid'] = hm.get('patch_grid', d['patch_grid'])

    entry['sae_url'] = sae_url

    has_clip = 'yes' if entry.get('clip_embeds') is not None else 'no'
    print(f"  d={entry['d_model']}, n={entry['n_images']}, backbone={entry['backbone']}, clip={has_clip}, "
          f"heatmaps={has_hm}")
    return entry
|
|
|
|
# Every loaded dataset entry, in load order.
_all_datasets = []


class _S:
    """Namespace for mutable state shared across the module's functions.

    The attributes are plain class attributes, so any function can read or
    assign them (``_S.active = 1``) without a ``global`` statement and without
    wrapping values in one-element lists.
    """
    # Index into ``_all_datasets`` of the dataset currently shown.
    active: int = 0
    # Pending debounce handle returned by ``add_timeout_callback`` for the
    # HF upload, or None when no upload is scheduled.
    hf_push = None
    # The remaining attributes are not referenced in this part of the file;
    # their meaning should be confirmed against the callbacks that use them.
    render_token: int = 0
    search_filter = None
    color_by: str = "Log Frequency"
    patch_img = None
    patch_z = None
|
|
|
|
| |
# The primary --data file is always the first dataset (index 0).
_all_datasets.append(
    _load_dataset_dict(path=args.data, label=args.primary_label, sae_url=args.sae_url)
)
|
|
def _load_brain_dataset_dict(path, label, thumb_dir):
    """Load a brain_meis.pt file and return a dataset entry dict.

    The entry uses the same keys as the entries built for regular
    explorer_data.pt files.  Optional fields fall back to defaults when absent
    from the file (patch_grid 16, image_size 224, backbone 'dinov2', all-NaN
    UMAP coordinates).

    Args:
        path: location of the brain_meis.pt file.
        label: display label stored under ``'label'``.
        thumb_dir: directory used to resolve ``image_paths``.  When it is
            non-empty and the first stored path is not an existing absolute
            path, every stored path is replaced by ``thumb_dir/<basename>``.

    Returns:
        The entry dict, or ``None`` when the file itself cannot be loaded.

    Manual and auto-interp feature labels are read from
    ``<stem>_feature_names.json`` / ``<stem>_auto_interp.json`` when those
    files exist, so labels saved in a previous session are restored instead of
    being overwritten by the next save.
    """
    def _read_int_keyed(json_path):
        # JSON object keys are strings; feature ids are ints everywhere else.
        if not os.path.exists(json_path):
            return {}
        with open(json_path) as handle:
            return {int(key): val for key, val in json.load(handle).items()}

    print(f"[Brain] Loading NSD dataset from {path} ...")
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only point --brain-data at files from a trusted source.
        bd = torch.load(path, map_location='cpu', weights_only=False)
    except Exception as err:
        print(f"[Brain] WARNING: Failed to load NSD dataset: {err}")
        return None

    # Stored paths may be bare basenames or stale absolute paths; remap them
    # into thumb_dir when the first one does not resolve as-is.
    raw_paths = bd.get('image_paths', [])
    if raw_paths and thumb_dir and (
        not os.path.isabs(raw_paths[0]) or not os.path.exists(raw_paths[0])
    ):
        bd_paths = [os.path.join(thumb_dir, os.path.basename(p)) for p in raw_paths]
    else:
        bd_paths = raw_paths

    d_model = bd['d_model']
    nan2 = np.full((d_model, 2), np.nan, dtype=np.float32)
    stem = os.path.splitext(path)[0]
    names_file = stem + '_feature_names.json'
    auto_interp_file = stem + '_auto_interp.json'

    feat_names = _read_int_keyed(names_file)
    auto_interp = _read_int_keyed(auto_interp_file)
    if feat_names or auto_interp:
        print(f"[Brain] Restored {len(feat_names)} feature names and "
              f"{len(auto_interp)} auto-interp labels")

    entry = {
        'label': label,
        'path': path,
        'image_paths': bd_paths,
        'd_model': d_model,
        'n_images': bd.get('n_images', len(bd_paths)),
        'patch_grid': bd.get('patch_grid', 16),
        'image_size': bd.get('image_size', 224),
        'backbone': bd.get('backbone', 'dinov2'),
        'top_img_idx': bd['top_img_idx'],
        'top_img_act': bd['top_img_act'],
        # Mean-ranked tables fall back to the max-ranked ones when absent.
        'mean_img_idx': bd.get('mean_img_idx', bd['top_img_idx']),
        'mean_img_act': bd.get('mean_img_act', bd['top_img_act']),
        'top_heatmaps': None,
        'mean_heatmaps': None,
        'heatmap_patch_grid': bd.get('patch_grid', 16),
        'feature_frequency': bd['feature_frequency'],
        'feature_mean_act': bd['feature_mean_act'],
        'umap_coords': bd['umap_coords'].numpy() if 'umap_coords' in bd else nan2,
        'dict_umap_coords': bd['dict_umap_coords'].numpy() if 'dict_umap_coords' in bd else nan2,
        'clip_embeds': bd.get('clip_feature_embeds', None),
        'inference_cache': OrderedDict(),
        'names_file': names_file,
        'auto_interp_file': auto_interp_file,
        'feature_names': feat_names,
        'auto_interp_names': auto_interp,
        'sae_url': None,
    }

    # Optional sidecar with pre-computed heatmaps.
    sidecar = stem + '_heatmaps.pt'
    if os.path.exists(sidecar):
        print(f"[Brain] Loading heatmaps sidecar: {os.path.basename(sidecar)} ...")
        bhm = torch.load(sidecar, map_location='cpu', weights_only=False)
        entry['top_heatmaps'] = bhm.get('top_heatmaps')
        entry['mean_heatmaps'] = bhm.get('mean_heatmaps')
        entry['heatmap_patch_grid'] = bhm.get('patch_grid', bd.get('patch_grid', 16))

    print(f"[Brain] Added '{label}' dataset: "
          f"d_model={d_model}, n_images={entry['n_images']}, backbone={entry['backbone']}")
    return entry
|
|
|
|
| |
| |
# Register the optional NSD brain dataset after the primary one.
if args.brain_data:
    if os.path.exists(args.brain_data):
        _brain_entry = _load_brain_dataset_dict(
            args.brain_data, args.brain_label, args.brain_thumbnails or '')
        # The loader returns None when the file could not be read.
        if _brain_entry is not None:
            _all_datasets.append(_brain_entry)
    else:
        print(f"[Brain] WARNING: --brain-data file not found: {args.brain_data}")
|
|
|
|
# Filename lookup for the active dataset; rebuilt whenever the dataset changes.
_basename_to_idx = {}


def _build_basename_index(paths):
    """Map each path's basename and extension-less stem to its list position.

    Both ``'x.jpg'`` and ``'x'`` point at the index of ``.../x.jpg``.  When
    several paths share a key, the last one in *paths* wins.
    """
    index = {}
    for position, full_path in enumerate(paths):
        filename = os.path.basename(full_path)
        stem, _ext = os.path.splitext(filename)
        index.update(dict.fromkeys((filename, stem), position))
    return index
|
|
|
|
def _apply_dataset_globals(idx):
    """Rebind every module-level data alias to ``_all_datasets[idx]``.

    Other functions in this file read the active dataset through plain
    module-level names (``image_paths``, ``d_model``, ``top_img_idx``, ...).
    Switching datasets therefore means re-assigning all of those names here,
    together with the quantities derived from them (numpy views of the
    frequency / mean-activation tensors, live-feature masks, a backup of the
    UMAP coordinates and the list of features with non-zero frequency).

    Args:
        idx: position of the dataset in ``_all_datasets``.
    """
    global image_paths, d_model, n_images, patch_grid, image_size, heatmap_patch_grid
    global top_img_idx, top_img_act, mean_img_idx, mean_img_act
    global nsd_top_img_idx, nsd_top_img_act, nsd_mean_img_idx, nsd_mean_img_act, HAS_NSD_SUBSET
    global top_heatmaps, mean_heatmaps
    global nsd_top_heatmaps, nsd_mean_heatmaps
    global feature_frequency, feature_mean_act
    global umap_coords, dict_umap_coords
    global freq, mean_act, log_freq
    global live_mask, live_indices, dict_live_mask, dict_live_indices
    global umap_backup
    global _clip_embeds, _nsd_clip_embeds, HAS_CLIP
    global feature_names, _names_file, auto_interp_names, _auto_interp_file
    global _active_feats
    global _basename_to_idx

    ds = _all_datasets[idx]
    image_paths = ds['image_paths']
    _basename_to_idx = _build_basename_index(image_paths)
    d_model = ds['d_model']
    n_images = ds['n_images']
    patch_grid = ds['patch_grid']
    image_size = ds['image_size']
    top_img_idx = ds['top_img_idx']
    top_img_act = ds['top_img_act']
    mean_img_idx = ds['mean_img_idx']
    mean_img_act = ds['mean_img_act']
    # The NSD-subset keys are optional (absent from brain dataset entries).
    nsd_top_img_idx = ds.get('nsd_top_img_idx')
    nsd_top_img_act = ds.get('nsd_top_img_act')
    nsd_mean_img_idx = ds.get('nsd_mean_img_idx')
    nsd_mean_img_act = ds.get('nsd_mean_img_act')
    nsd_top_heatmaps = ds.get('nsd_top_heatmaps')
    nsd_mean_heatmaps = ds.get('nsd_mean_heatmaps')
    HAS_NSD_SUBSET = nsd_top_img_idx is not None
    top_heatmaps = ds.get('top_heatmaps')
    mean_heatmaps = ds.get('mean_heatmaps')
    heatmap_patch_grid = ds.get('heatmap_patch_grid', patch_grid)
    feature_frequency = ds['feature_frequency']
    feature_mean_act = ds['feature_mean_act']
    umap_coords = ds['umap_coords']
    dict_umap_coords = ds['dict_umap_coords']
    _clip_embeds = ds['clip_embeds']
    _nsd_clip_embeds = ds.get('nsd_clip_embeds')
    HAS_CLIP = _clip_embeds is not None
    feature_names = ds['feature_names']
    _names_file = ds['names_file']
    auto_interp_names = ds['auto_interp_names']
    _auto_interp_file = ds['auto_interp_file']

    # Derived arrays.
    freq = feature_frequency.numpy()
    mean_act = feature_mean_act.numpy()
    log_freq = np.log10(freq + 1)
    # A feature is "live" in a UMAP when its x coordinate is not NaN.
    live_mask = ~np.isnan(umap_coords[:, 0])
    live_indices = np.where(live_mask)[0]
    dict_live_mask = ~np.isnan(dict_umap_coords[:, 0])
    dict_live_indices = np.where(dict_live_mask)[0]
    umap_backup = dict(
        act_x=umap_coords[live_mask, 0].tolist(),
        act_y=umap_coords[live_mask, 1].tolist(),
        act_feat=live_indices.tolist(),
        dict_x=dict_umap_coords[dict_live_mask, 0].tolist(),
        dict_y=dict_umap_coords[dict_live_mask, 1].tolist(),
        dict_feat=dict_live_indices.tolist(),
    )

    # Indices of features that fired at least once, as plain Python ints.
    # Vectorised on the numpy view instead of one tensor .item() call per
    # feature, which is slow for large dictionaries on every dataset switch.
    _active_feats = np.flatnonzero(freq[:d_model] > 0).tolist()


# Start with the primary dataset active.
_apply_dataset_globals(0)
|
|
|
|
def _save_names():
    """Write the active dataset's manual feature names to its JSON file.

    Keys are written as strings in ascending feature order; afterwards an
    upload of the file is scheduled via ``_schedule_hf_push``.
    """
    with open(_names_file, 'w') as handle:
        payload = {str(feat): name for feat, name in sorted(feature_names.items())}
        handle.write(json.dumps(payload, indent=2))
    print(f"Saved {len(feature_names)} feature names to {_names_file}")
    _schedule_hf_push(_names_file)
|
|
|
|
def _save_auto_interp():
    """Write the active dataset's auto-interp labels to its JSON file.

    Keys are written as strings in ascending feature order; afterwards an
    upload of the file is scheduled via ``_schedule_hf_push``.
    """
    with open(_auto_interp_file, 'w') as handle:
        payload = {str(feat): text for feat, text in sorted(auto_interp_names.items())}
        handle.write(json.dumps(payload, indent=2))
    print(f"Saved {len(auto_interp_names)} auto-interp labels to {_auto_interp_file}")
    _schedule_hf_push(_auto_interp_file)
|
|
|
|
def _schedule_hf_push(names_file_path):
    """Debounce an upload of *names_file_path* to the HF dataset repo.

    Waits 2 s after the most recent call, then uploads from a daemon thread.
    No-op unless both ``HF_TOKEN`` and ``HF_DATASET_REPO`` are set in the
    environment (i.e. when running locally).

    Every path passed during the debounce window is remembered and uploaded
    when the timer fires.  Previously only the path from the *last* call was
    pushed, so saving two different files within 2 s (feature names followed
    by auto-interp labels, or files of two datasets) silently dropped the
    first upload.

    Args:
        names_file_path: local path of the JSON file to upload; it is stored
            in the repo under its basename.
    """
    hf_token = os.environ.get("HF_TOKEN")
    hf_repo = os.environ.get("HF_DATASET_REPO")
    if not (hf_token and hf_repo):
        return

    # Paths waiting for the next push.  Stored on _S so that successive calls
    # share it; created lazily because _S does not declare the attribute.
    pending = getattr(_S, 'hf_pending', None)
    if pending is None:
        pending = []
        _S.hf_pending = pending
    if names_file_path not in pending:
        pending.append(names_file_path)

    # Restart the debounce timer.
    if _S.hf_push is not None:
        try:
            curdoc().remove_timeout_callback(_S.hf_push)
        except Exception:
            # Best effort: the callback may already have run or been removed.
            # If it is still pending it will simply drain the queue early.
            pass

    def _push_thread(paths):
        try:
            from huggingface_hub import upload_file
        except Exception as e:
            print(f"  Warning: could not push feature names to HF: {e}")
            return
        for file_path in paths:
            try:
                upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=os.path.basename(file_path),
                    repo_id=hf_repo,
                    repo_type="dataset",
                    token=hf_token,
                    commit_message="Update feature names",
                )
                print(f"  Pushed {os.path.basename(file_path)} to HF dataset {hf_repo}")
            except Exception as e:
                print(f"  Warning: could not push feature names to HF: {e}")

    def _fire():
        _S.hf_push = None
        # Snapshot and clear so saves made during the upload queue a new push.
        paths = list(pending)
        pending.clear()
        if not paths:
            return
        threading.Thread(target=_push_thread, args=(paths,), daemon=True).start()

    _S.hf_push = curdoc().add_timeout_callback(_fire, 2000)
|
|
|
|
def _display_name(feat: int) -> str:
    """Return the table label for *feat*.

    A non-empty manual name wins; otherwise a non-empty auto-interp label is
    returned with an ``[auto] `` prefix; otherwise the empty string.
    """
    manual = feature_names.get(feat)
    if manual:
        return manual
    auto = auto_interp_names.get(feat)
    if not auto:
        return ""
    return f"[auto] {auto}"
|
|
|
|
|
|
|
|
def compute_patch_activations(img_idx):
    """Return the per-patch SAE activations for image *img_idx*.

    Looks in the active dataset's LRU cache first (refreshing the entry on a
    hit).  On a miss the image is loaded and run through live GPU inference;
    a successful result is cached and the oldest entry evicted once the cache
    exceeds ``--inference-cache-size``.

    Returns:
        The array produced by ``_run_gpu_inference``, or ``None`` when
        inference is unavailable or fails (failures are logged, not raised).
    """
    lru = _all_datasets[_S.active]['inference_cache']

    if img_idx in lru:
        lru.move_to_end(img_idx)
        return lru[img_idx]

    activations = None
    try:
        activations = _run_gpu_inference(load_image(img_idx))
    except Exception as _e:
        print(f"[GPU runner] inference failed for img {img_idx}: {_e}")

    # Misses that produced nothing are not cached, so they are retried later.
    if activations is None:
        return None

    lru[img_idx] = activations
    if len(lru) > args.inference_cache_size:
        lru.popitem(last=False)
    return activations
|
|
|
|
| |
def create_alpha_cmap(base='jet'):
    """Return a copy of colormap *base* whose alpha ramps linearly from 0 to 1.

    Low values become fully transparent and high values fully opaque, which
    lets the colormap be blended over an image as a heatmap overlay.

    Args:
        base: name of a registered Matplotlib colormap.
    """
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap is available on old and new versions alike.
    base_cmap = plt.get_cmap(base)
    colors = base_cmap(np.arange(base_cmap.N))
    colors[:, -1] = np.linspace(0.0, 1.0, base_cmap.N)
    return mcolors.LinearSegmentedColormap.from_list('alpha_cmap', colors)


# Default overlay colormap used by the heatmap renderers.
ALPHA_JET = create_alpha_cmap('jet')


# Default edge length (pixels) of rendered thumbnails, from --thumb-size.
THUMB = args.thumb_size
|
|
|
|
def _parse_img_label(value):
    """Turn a user-entered image label into an integer image index.

    Resolution order:
      1. filename lookup in ``_basename_to_idx`` — first with any extension
         removed, then verbatim (e.g. 'nsd_31215.jpg', 'nsd_31215',
         '000000204103.jpg');
      2. the whole label as an integer (e.g. '42');
      3. the part after the last underscore as an integer
         (e.g. 'n02655020_475' -> 475).

    The filename lookup comes first so that zero-padded names such as
    '000000204103' resolve to their dataset entry instead of being read as
    the raw index 204103.

    Raises:
        ValueError: when none of the three interpretations applies.
    """
    label = value.strip()

    for candidate in (os.path.splitext(label)[0], label):
        if candidate in _basename_to_idx:
            return _basename_to_idx[candidate]

    try:
        return int(label)
    except ValueError:
        pass

    _head, _sep, tail = label.rpartition('_')
    return int(tail)
|
|
|
|
def _resolve_img_path(stored_path):
    """Find an existing file for *stored_path*, or return None.

    Tried in order: the stored path itself when it is absolute and exists;
    its basename inside ``--image-dir`` and each ``--extra-image-dir``; and
    finally the stored path as given (relative to the working directory).
    """
    if os.path.isabs(stored_path) and os.path.exists(stored_path):
        return stored_path

    filename = os.path.basename(stored_path)
    search_dirs = [args.image_dir] + (args.extra_image_dir or [])
    for directory in search_dirs:
        if not directory:
            continue
        candidate = os.path.join(directory, filename)
        if os.path.exists(candidate):
            return candidate

    return stored_path if os.path.exists(stored_path) else None
|
|
|
|
def _load_image_by_path(path):
    """Open *path* as an RGB PIL image.

    The path is first resolved through ``_resolve_img_path``; when that finds
    nothing the path is opened as given, so a missing file surfaces as the
    usual error from ``Image.open``.

    Returns:
        A new, fully loaded ``PIL.Image.Image`` in mode ``"RGB"``.
    """
    resolved = _resolve_img_path(path) or path
    # Image.open is lazy and keeps the file handle on the source image.
    # convert() returns an independent, fully loaded image, so the source can
    # be closed deterministically instead of being left to garbage collection.
    with Image.open(resolved) as source:
        return source.convert("RGB")
|
|
|
|
def load_image(img_idx):
    """Return the RGB image stored at position *img_idx* of the active dataset."""
    stored_path = image_paths[img_idx]
    return _load_image_by_path(stored_path)
|
|
|
|
def render_heatmap_overlay(img_idx, heatmap_16x16, size=THUMB, cmap=ALPHA_JET, alpha=1.0):
    """Blend a patch heatmap over an image and return the result as a PIL image.

    Args:
        img_idx: index of the image in the active dataset.
        heatmap_16x16: 2-D patch heatmap (torch tensor or numpy array).
        size: edge length in pixels of the square output.
        cmap: colormap whose alpha channel controls the blend strength.
        alpha: extra multiplier applied to the colormap's alpha channel.
    """
    resized = load_image(img_idx).resize((size, size), Image.BILINEAR)
    base_rgb = np.array(resized).astype(np.float32) / 255.0

    grid = heatmap_16x16
    if isinstance(grid, torch.Tensor):
        grid = grid.numpy()
    grid = grid.astype(np.float32)
    upsampled = cv2.resize(grid, (size, size), interpolation=cv2.INTER_CUBIC)

    # Scale to a peak of 1; maps without a positive peak are left as they are.
    peak = upsampled.max()
    if peak > 0:
        scaled = upsampled / peak
    else:
        scaled = upsampled

    rgba = cmap(scaled)
    weight = rgba[:, :, 3:4] * alpha
    mixed = base_rgb * (1 - weight) + rgba[:, :, :3] * weight
    pixels = np.clip(mixed * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
def render_zoomed_overlay(img_idx, heatmap_16x16, size=THUMB, pg=None, alpha=None,
                          center='peak'):
    """Render the heatmap overlay cropped to the zoom window at the current slider level.

    At full zoom (slider >= pg) the whole image is returned.  At lower values
    the overlay is cropped to a square window of ``zoom_patches`` patches per
    side and upscaled to ``size``.

    Args:
        img_idx: index of the image in the active dataset.
        heatmap_16x16: patch heatmap (torch tensor or numpy array).
        size: edge length in pixels of the square output.
        pg: patches per side of the heatmap; defaults to ``heatmap_patch_grid``.
        alpha: overlay opacity; defaults to the heatmap-alpha slider value.
        center: ``'peak'`` centres the window on the argmax patch (good for
            max-ranked images); ``'centroid'`` centres it on the
            activation-weighted centroid (good for mean-ranked images where
            activation is diffuse).

    The window is shifted to stay inside the image rather than being clipped
    at the border.  Clipping produced a non-square crop near edges, which was
    then stretched to ``size`` x ``size`` and distorted the image.
    """
    if pg is None:
        pg = heatmap_patch_grid
    if alpha is None:
        alpha = heatmap_alpha_slider.value
    heatmap = heatmap_16x16.numpy() if isinstance(heatmap_16x16, torch.Tensor) else heatmap_16x16

    # Render at the dataset's native resolution so patch boundaries line up
    # with pixel coordinates before cropping.
    overlay = render_heatmap_overlay(img_idx, heatmap, size=image_size, alpha=alpha)

    zoom_patches = int(zoom_slider.value)
    if zoom_patches >= pg:
        return overlay.resize((size, size), Image.BILINEAR)

    # Pick the patch the window is centred on.
    if center == 'centroid':
        total = heatmap.sum()
        if total > 0:
            rows = np.arange(pg)
            cols = np.arange(pg)
            peak_row = int(np.average(rows, weights=heatmap.sum(axis=1)))
            peak_col = int(np.average(cols, weights=heatmap.sum(axis=0)))
        else:
            peak_row, peak_col = pg // 2, pg // 2
    else:
        peak_idx = np.argmax(heatmap)
        peak_row, peak_col = divmod(int(peak_idx), pg)

    patch_px = image_size // pg
    half = (max(1, zoom_patches) * patch_px) // 2
    side = max(2 * half, 1)          # never request an empty crop
    cy = peak_row * patch_px + patch_px // 2
    cx = peak_col * patch_px + patch_px // 2
    # Slide the fixed-size window back inside [0, image_size].
    max_origin = max(0, image_size - side)
    y0 = min(max(0, cy - half), max_origin)
    x0 = min(max(0, cx - half), max_origin)
    y1 = min(image_size, y0 + side)
    x1 = min(image_size, x0 + side)
    return overlay.crop((x0, y0, x1, y1)).resize((size, size), Image.BILINEAR)
|
|
|
|
def pil_to_data_url(img):
    """Encode *img* as a quality-85 JPEG and return it as a base64 ``data:`` URL."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="JPEG", quality=85)
        encoded = base64.b64encode(buffer.getvalue())
    return "data:image/jpeg;base64," + encoded.decode("utf-8")
|
|
|
|
| |
|
|
def _phi_c_for_feat(feat):
    """Return the leverage score of *feat* as a float.

    Returns None when no leverage scores were loaded or *feat* lies beyond
    the end of the score array.
    """
    if _phi_c is None:
        return None
    if feat >= len(_phi_c):
        return None
    return float(_phi_c[feat])
|
|
|
|
def _phi_voxel_row(feat):
    """Return the phi row of *feat* as a float32 array, or None.

    None is returned when no phi matrix is loaded or *feat* is past its last
    row.  When a voxel-to-vertex map is loaded, the row is re-indexed through
    it; otherwise the row is returned as stored.
    """
    if _phi_cv is None:
        return None
    if feat >= _phi_cv.shape[0]:
        return None
    # np.array copies the row into a regular in-memory float32 array.
    row = np.array(_phi_cv[feat], dtype=np.float32)
    if _voxel_to_vertex is None:
        return row
    return row[_voxel_to_vertex]
|
|
|
|
def _render_steering_preview(feats, lams, thresholds):
    """Render the net combined steering direction across all chosen features.

    Computes ``sum_i( lam_i * threshold_mask_i * phi_i / max|phi_i| )`` and
    draws it as two scatter views over the voxel coordinates.

    Args:
        feats: feature indices.
        lams: one steering strength per feature.
        thresholds: one value per feature; a value below 1.0 keeps only the
            voxels whose ``|phi|`` is at or above the ``100 * (1 - thr)``
            percentile.

    Returns:
        An HTML string with an inline PNG brain map, or "" when there is
        nothing to draw (no features, no voxel coordinates, or an all-zero
        combined map).
    """
    if not feats or _voxel_coords is None:
        return ""

    combined = np.zeros(_N_VOXELS_DD, dtype=np.float32)
    n_valid = 0
    for f, lam, thr in zip(feats, lams, thresholds):
        phi = _phi_voxel_row(f)
        if phi is None:
            continue
        phi_max = float(np.abs(phi).max())
        if phi_max < 1e-12:
            continue  # all-zero row: nothing to normalise
        norm_phi = phi / phi_max
        if thr < 1.0:
            cutoff = float(np.percentile(np.abs(phi), 100.0 * (1.0 - thr)))
            norm_phi = norm_phi * (np.abs(phi) >= cutoff)
        combined += lam * norm_phi
        n_valid += 1

    if n_valid == 0:
        return ""
    vmax = float(np.abs(combined).max())
    if vmax < 1e-12:
        return ""

    fig, axes = plt.subplots(1, 2, figsize=(8, 3.2), facecolor='#f8f8f8')
    try:
        for ax, (title, xi, yi) in zip(axes, [("Axial (x–y)", 0, 1), ("Coronal (x–z)", 0, 2)]):
            sc = ax.scatter(
                _voxel_coords[:, xi], _voxel_coords[:, yi],
                c=combined, cmap='RdBu_r', s=3, alpha=0.7,
                vmin=-vmax, vmax=vmax, rasterized=True, marker='s',
            )
            ax.set_title(title, fontsize=9)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor('#f8f8f8')
        fig.subplots_adjust(right=0.88, top=0.85)
        cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.65])
        cbar = fig.colorbar(sc, cax=cbar_ax)
        cbar.set_label('Δ fMRI (norm.)', fontsize=8)
        lbl = f'{n_valid} feature{"s" if n_valid != 1 else ""}'
        fig.suptitle(f'Net brain modification — {lbl}', fontsize=10)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=80, bbox_inches='tight', facecolor='#f8f8f8')
    finally:
        # pyplot keeps every figure alive until it is closed; close it even
        # when plotting or saving raises so a long-running server does not
        # accumulate leaked figures.
        plt.close(fig)

    b64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return (
        '<h4 style="margin:6px 0 3px 0;color:#333;font-size:12px">Net Brain Modification</h4>'
        f'<img src="data:image/png;base64,{b64}" '
        'style="max-width:100%;border-radius:4px;border:1px solid #ddd"/>'
    )
|
|
|
|
def _render_cortical_profile(feat):
    """Render two 2D scatter views of a feature's phi values as inline HTML.

    The colour scale is symmetric around zero and spans the largest absolute
    phi value of the row.  The leverage score, when available, is shown in
    the figure title.

    Args:
        feat: feature index.

    Returns:
        An HTML block with an inline PNG, or "" when phi data or voxel
        coordinates are unavailable for this feature.
    """
    phi_vox = _phi_voxel_row(feat)
    if phi_vox is None or _voxel_coords is None:
        return ""

    # Fall back to a tiny range so an all-zero row still gets a valid scale.
    vmax = float(np.abs(phi_vox).max()) or 1e-6
    phi_c_val = _phi_c_for_feat(feat)
    phi_c_str = f"φ_c = {phi_c_val:.4f}" if phi_c_val is not None else ""

    fig, axes = plt.subplots(1, 2, figsize=(10, 4.0), facecolor='#f8f8f8')
    try:
        view_pairs = [("Axial (x – y)", 0, 1), ("Coronal (x – z)", 0, 2)]
        for ax, (title, xi, yi) in zip(axes, view_pairs):
            sc = ax.scatter(
                _voxel_coords[:, xi], _voxel_coords[:, yi],
                c=phi_vox, cmap='RdBu_r', s=4, alpha=0.75,
                vmin=-vmax, vmax=vmax, rasterized=True, marker='s',
            )
            ax.set_title(title, fontsize=10)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor('#f8f8f8')

        fig.subplots_adjust(right=0.88, top=0.88)
        cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.65])
        cbar = fig.colorbar(sc, cax=cbar_ax)
        cbar.set_label('Φ weight', fontsize=9)
        fig.suptitle(
            f'Cortical Profile — Feature {feat}' + (f'  ({phi_c_str})' if phi_c_str else ''),
            fontsize=11,
        )

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=90, bbox_inches='tight', facecolor='#f8f8f8')
    finally:
        # pyplot keeps every figure alive until it is closed; close it even
        # when plotting or saving raises so a long-running server does not
        # accumulate leaked figures.
        plt.close(fig)

    b64 = base64.b64encode(buf.getvalue()).decode('utf-8')

    return (
        '<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;'
        'padding-bottom:4px">Cortical Profile (Φ)</h3>'
        f'<img src="data:image/png;base64,{b64}" '
        'style="max-width:100%;border-radius:4px;border:1px solid #ddd"/>'
    )
|
|
|
|
| def _status_html(state, msg): |
| """Return a styled HTML status banner.""" |
| styles = { |
| 'idle': 'background:#f5f5f5;border-left:4px solid #bbb;color:#666', |
| 'loading': 'background:#fff8e0;border-left:4px solid #f0a020;color:#7a5000', |
| 'ok': 'background:#e8f4e8;border-left:4px solid #2a8a2a;color:#1a5a1a', |
| 'dead': 'background:#fce8e8;border-left:4px solid #c03030;color:#8a1a1a', |
| } |
| style = styles.get(state, styles['idle']) |
| return f'<div style="{style};padding:7px 12px;border-radius:3px;font-size:13px">{msg}</div>' |
|
|
|
|
| |
|
|
def _dynadiff_request(sample_idx, steerings, seed):
    """Run one DynaDiff reconstruction through the shared ``_dd_loader``.

    Args:
        sample_idx: Sample index forwarded to ``_dd_loader.reconstruct``.
        steerings: list of (phi_voxel np.ndarray, lam float, threshold float).
        seed: Seed forwarded to ``_dd_loader.reconstruct``.

    Returns:
        Whatever ``_dd_loader.reconstruct`` returns (the original contract
        describes a dict with baseline_img, steered_img, gt_img, beta_std).

    Raises:
        RuntimeError: If the loader reports status 'loading' or 'error'.
    """
    state, detail = _dd_loader.status
    if state == 'error':
        raise RuntimeError(f'DynaDiff model load failed: {detail}')
    if state == 'loading':
        raise RuntimeError('DynaDiff model still loading — try again shortly')
    return _dd_loader.reconstruct(sample_idx, steerings, seed)
|
|
|
|
| def _make_steering_html(resps, concept_name): |
| """Build HTML showing GT | Baseline | Steered for one or more trials. |
| |
| resps: list of (trial_label, resp_dict) pairs. |
| """ |
| header = ( |
| f'<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;' |
| f'padding-bottom:4px">DynaDiff Steering — {concept_name}</h3>' |
| ) |
| rows_html = '' |
| for trial_label, resp in resps: |
| parts = [] |
| for label, key in [('GT', 'gt_img'), |
| ('Baseline', 'baseline_img'), |
| ('Steered', 'steered_img')]: |
| b64 = resp.get(key) |
| if b64 is None: |
| img_html = ('<div style="width:160px;height:160px;background:#eee;' |
| 'display:flex;align-items:center;justify-content:center;' |
| 'color:#999;font-size:12px">N/A</div>') |
| else: |
| img_html = (f'<img src="data:image/png;base64,{b64}" ' |
| 'style="width:160px;height:160px;object-fit:contain;' |
| 'border:1px solid #ddd;border-radius:4px"/>') |
| parts.append( |
| f'<div style="text-align:center;margin:0 4px">' |
| f'{img_html}' |
| f'<div style="font-size:11px;color:#555;margin-top:3px">{label}</div>' |
| f'</div>' |
| ) |
| trial_head = (f'<div style="font-size:11px;font-weight:bold;color:#777;' |
| f'margin:6px 0 3px 4px">{trial_label}</div>') |
| rows_html += (trial_head |
| + '<div style="display:flex;align-items:flex-end;margin-bottom:8px">' |
| + ''.join(parts) + '</div>') |
| return header + rows_html |
|
|
|
|
def make_image_grid_html(images_info, title):
    """Return an HTML heading plus a wrapping grid of captioned thumbnails.

    Args:
        images_info: Sequence of ``(image, caption)`` pairs. Captions are split
            on the literal ``<br>`` and each piece gets its own ``<div>``.
            A falsy value produces a "No examples available" note instead.
        title: Heading text.
    """
    if not images_info:
        return (f'<h3 style="margin:4px 0 6px 0;color:#444;border-bottom:2px solid #e8e8e8;'
                f'padding-bottom:4px">{title}</h3>'
                f'<p style="color:#aaa;font-style:italic;margin:4px 0">No examples available</p>')

    # Tiles are square and never wider than 224px.
    side = min(THUMB, 224)
    pieces = [
        f'<h3 style="margin:4px 0 8px 0;color:#333;border-bottom:2px solid #e0e0e0;'
        f'padding-bottom:4px">{title}</h3>',
        '<div style="display:flex;flex-wrap:wrap;gap:8px;padding:2px 0 10px 0">',
    ]
    for img, caption in images_info:
        src = pil_to_data_url(img)
        caption_rows = ''.join(f'<div>{line}</div>' for line in caption.split('<br>'))
        pieces.append(
            f'<div style="text-align:center;width:{side}px">'
            f'<img src="{src}" width="{side}" height="{side}"'
            f' style="border:1px solid #d0d0d0;border-radius:5px;display:block"/>'
            f'<div style="font-size:10px;color:#555;margin-top:3px;line-height:1.4">'
            f'{caption_rows}</div></div>'
        )
    pieces.append('</div>')
    return ''.join(pieces)
|
|
|
|
def make_compare_aggregations_html(top_infos, mean_infos, feat, n_each=6, model_label=None):
    """Build a side-by-side "Top Activation" / "Mean Activation" figure panel.

    Args:
        top_infos: ``(image, caption)`` pairs for the top-activation column
            (None is treated as empty).
        mean_infos: ``(image, caption)`` pairs for the mean-activation column
            (None is treated as empty).
        feat: Feature index shown in the panel title.
        n_each: Maximum number of images taken from each list.
        model_label: Optional prefix for the title; omitted when falsy.

    Returns:
        A self-contained HTML string with one two-column image grid per method.
    """
    tile = min(THUMB, 160)
    per_row = 2
    # Header strip spans exactly one grid row: tiles plus the 6px gaps between them.
    banner_w = per_row * tile + (per_row - 1) * 6

    def _tile_html(img, caption):
        src = pil_to_data_url(img)
        caption_html = '<br>'.join(caption.split('<br>'))
        return (
            f'<div style="text-align:center">'
            f'<img src="{src}" width="{tile}" height="{tile}"'
            f' style="border:1px solid #ccc;border-radius:3px;display:block"/>'
            f'<div style="font-size:9px;color:#555;margin-top:3px;line-height:1.35">'
            f'{caption_html}</div></div>'
        )

    title_prefix = f'{model_label} — ' if model_label else ''
    chunks = [
        '<div style="font-family:Arial,Helvetica,sans-serif;background:#ffffff;'
        'padding:16px 20px 14px 20px;display:inline-block">',
        f'<div style="font-size:13px;font-weight:bold;color:#222;margin-bottom:14px;'
        f'letter-spacing:0.1px">',
        title_prefix,
        f'Feature {feat}</div>',
        '<div style="display:flex;gap:24px;align-items:flex-start">',
    ]

    columns = (
        ("Top Activation", "#2563a8", top_infos),
        ("Mean Activation", "#1a7a4a", mean_infos),
    )
    for method_name, color, infos in columns:
        shown = (infos or [])[:n_each]
        chunks.append(
            f'<div style="display:inline-flex;flex-direction:column">'
            f'<div style="background:{color};color:#ffffff;font-size:13px;font-weight:bold;'
            f'text-align:center;padding:6px 0;border-radius:5px;margin-bottom:10px;'
            f'letter-spacing:0.4px;width:{banner_w}px;box-sizing:border-box">{method_name}</div>'
            f'<div style="display:grid;grid-template-columns:repeat({per_row},{tile}px);gap:6px">'
        )
        if not shown:
            chunks.append(
                '<div style="color:#aaa;font-style:italic;font-size:11px;padding:8px">No images</div>')
        for img, caption in shown:
            chunks.append(_tile_html(img, caption))
        # Close the grid and this method's column.
        chunks.append('</div></div>')

    # Close the flex row and the outer wrapper.
    chunks.append('</div></div>')
    return ''.join(chunks)
|
|
|
|
|
|
| |
| |
| |
|
|
| |
|
|
def _phi_c_vals(indices):
    """Look up ``_phi_c`` for each index in *indices* as plain floats.

    Yields 0.0 for every entry when ``_phi_c`` is None, and 0.0 for any index
    that is not below ``len(_phi_c)``.
    """
    if _phi_c is None:
        return [0.0] * len(indices)
    limit = len(_phi_c)
    values = []
    for idx in indices:
        values.append(float(_phi_c[idx]) if idx < limit else 0.0)
    return values
|
|
|
|
|
|
| def _make_point_alphas(n): |
| """Return uniform 0.6 alpha for all n UMAP points.""" |
| return [0.6] * n |
|
|
|
|
def _make_color_vals(indices):
    """Pick the per-point colour values matching the current ``_S.color_by``.

    "Mean Activation" reads ``mean_act``, "Brain Leverage (φ_c)" delegates to
    ``_phi_c_vals``, and anything else reads ``log_freq``.
    """
    mode = _S.color_by
    positions = np.array(indices, dtype=int)
    if mode == "Brain Leverage (φ_c)":
        return _phi_c_vals(indices)
    table = mean_act if mode == "Mean Activation" else log_freq
    return table[positions].tolist()
|
|
|
|
# Backing data for the UMAP scatter: one row per entry selected by live_mask.
# color_val starts out as log frequency and point_alpha as a uniform 0.6.
umap_source = ColumnDataSource(data=dict(
    x=umap_coords[live_mask, 0].tolist(),
    y=umap_coords[live_mask, 1].tolist(),
    feature_idx=live_indices.tolist(),
    frequency=freq[live_mask].tolist(),
    log_freq=log_freq[live_mask].tolist(),
    mean_act=mean_act[live_mask].tolist(),
    phi_c_val=_phi_c_vals(live_indices.tolist()),
    color_val=log_freq[live_mask].tolist(),
    point_alpha=_make_point_alphas(int(live_mask.sum())),
))

# Colour scale over the 'color_val' column, clipped to the 2nd–98th percentile
# of the initial log-frequency values. Falls back to [0, 1] when live_mask
# selects nothing, since np.percentile cannot run on an empty array.
_init_log_freq = log_freq[live_mask]
color_mapper = linear_cmap(
    field_name='color_val', palette=Turbo256,
    low=float(np.percentile(_init_log_freq, 2)) if live_mask.any() else 0,
    high=float(np.percentile(_init_log_freq, 98)) if live_mask.any() else 1,
)
|
|
| def _set_color_mapper_range(color_vals): |
| """Update color_mapper low/high to the 2nd–98th percentile of color_vals.""" |
| if not color_vals: |
| return |
| arr = np.array(color_vals) |
| lo, hi = float(np.percentile(arr, 2)), float(np.percentile(arr, 98)) |
| if lo == hi: |
| hi = lo + 1e-6 |
| color_mapper['transform'].low = lo |
| color_mapper['transform'].high = hi |
|
|
|
|
# Main UMAP scatter figure; the tap tool drives feature selection.
umap_fig = figure(
    title="UMAP of SAE Features (by activation pattern)",
    width=700, height=650,
    tools="pan,wheel_zoom,box_zoom,reset,tap",
    active_scroll="wheel_zoom",
)
# Per-point alpha comes from the 'point_alpha' column; selected points are
# drawn in opaque red with a white outline.
umap_scatter = umap_fig.scatter(
    'x', 'y', source=umap_source, size=4, alpha='point_alpha',
    color=color_mapper,
    selection_color="#FF2222", selection_alpha=1.0,
    selection_line_color="white", selection_line_width=1.5,
)

# Browser-side callback: grow the markers as the user zooms in. The first
# observed x-range span is remembered as the baseline; marker size scales with
# zoom^0.1 and is clamped to [3, 12], with selected markers at least 14.
_zoom_cb = CustomJS(args=dict(renderer=umap_scatter, x_range=umap_fig.x_range), code="""
    const span = x_range.end - x_range.start;
    if (window._umap_base_span === undefined) {
        window._umap_base_span = span;
    }
    const zoom = window._umap_base_span / span;
    const new_size = Math.min(12, Math.max(3, 3 * Math.pow(zoom, 0.1)));
    renderer.glyph.size = new_size;
    renderer.nonselection_glyph.size = new_size;
    renderer.selection_glyph.size = Math.max(14, new_size * 3);
""")
umap_fig.x_range.js_on_change('start', _zoom_cb)
umap_fig.x_range.js_on_change('end', _zoom_cb)

# The φ_c tooltip row is only added when HAS_PHI is set.
_phi_hover = [("Brain φ_c", "@phi_c_val{0.0000}")] if HAS_PHI else []
umap_fig.add_tools(HoverTool(tooltips=[
    ("Feature", "@feature_idx"),
    ("Frequency", "@frequency{0}"),
    ("Mean Act", "@mean_act{0.000}"),
] + _phi_hover))
|
|
|
|
| |
# Dataset dropdown: option values are the string indices into _all_datasets,
# option labels are each dataset's 'label'. Starts on dataset 0.
dataset_select = Select(
    title="Dataset:",
    value="0",
    options=[(str(i), ds['label']) for i, ds in enumerate(_all_datasets)],
    width=250,
)
|
|
|
|
def _on_dataset_switch(attr, old, new):
    """Bokeh ``on_change`` handler for ``dataset_select``: activate dataset *new*.

    Rebuilds the UMAP source from the newly active dataset, resets the search
    filter, summary, patch explorer and CLIP results, then either re-renders
    the previously shown feature (when both datasets report the same
    ``d_model`` and the index is in range) or clears every detail panel.

    Args:
        attr: Name of the changed property (unused).
        old: Previous select value; parsed with ``int``.
        new: New select value; parsed with ``int``.
    """
    idx = int(new)
    old_idx = int(old)

    # Capture the current feature and the old width before anything is rebound.
    _prev_feat_str = feature_input.value.strip()
    _old_d_model = _all_datasets[old_idx]['d_model']

    _S.active = idx
    # Presumably rebinds the module-level arrays (live_indices, freq, ...) read below.
    _apply_dataset_globals(idx)

    # Rebuild the UMAP scatter data for the new dataset.
    _feat_ids = live_indices.tolist()
    _color_vals = _make_color_vals(_feat_ids)
    _phi_c_list = _phi_c_vals(_feat_ids)
    umap_source.data = dict(
        x=umap_coords[live_mask, 0].tolist(),
        y=umap_coords[live_mask, 1].tolist(),
        feature_idx=_feat_ids,
        frequency=freq[live_mask].tolist(),
        log_freq=log_freq[live_mask].tolist(),
        mean_act=mean_act[live_mask].tolist(),
        phi_c_val=_phi_c_list,
        color_val=_color_vals,
        point_alpha=_make_point_alphas(len(_feat_ids)),
    )
    _set_color_mapper_range(_color_vals)
    umap_source.selected.indices = []
    umap_type_select.value = "Activation Pattern"
    umap_fig.title.text = f"UMAP — {_all_datasets[idx]['label']}"

    # Drop any active search filter and refresh the feature list ordering.
    _S.search_filter = None
    _apply_order(_get_sorted_order())

    summary_div.text = _make_summary_html()

    # The patch explorer needs --sae-path; otherwise show an explanatory note.
    can_explore = bool(args.sae_path)
    patch_fig.visible = can_explore
    patch_info_div.visible = can_explore
    if not can_explore:
        patch_info_div.text = (
            '<p style="color:#888;font-style:italic">Patch explorer unavailable: no --sae-path provided.</p>')
        patch_info_div.visible = True

    # CLIP results belong to the previous dataset; clear them.
    if HAS_CLIP:
        clip_result_div.text = ""
        clip_result_source.data = dict(
            feature_idx=[], clip_score=[], frequency=[], mean_act=[], phi_c_val=[], name=[])

    # Only carry the feature over when both datasets share the same d_model.
    _same_space = (_all_datasets[idx]['d_model'] == _old_d_model)
    _restore_feat = None
    if _same_space and _prev_feat_str:
        try:
            _restore_feat = int(_prev_feat_str)
        except ValueError:
            pass

    if _restore_feat is not None and 0 <= _restore_feat < d_model:
        feature_input.value = str(_restore_feat)
        update_feature_display(_restore_feat)
    else:
        feature_input.value = ""
        stats_div.text = "<h3>Select a feature to explore</h3>"
        brain_div.text = ""
        status_div.text = _status_html('idle', 'Model switched — select a feature to explore.')
        if HAS_DYNADIFF:
            _dd_output.text = ""
            _dd_status.text = ""
        # Clear all three image panels, including the comparison view, so no
        # images from the previous dataset stay on screen.
        for div in [top_heatmap_div, mean_heatmap_div, compare_agg_div]:
            div.text = ""
|
|
|
|
# Run _on_dataset_switch whenever the dataset dropdown value changes.
dataset_select.on_change('value', _on_dataset_switch)

# Detail-panel Divs; their .text is filled in by the callbacks in this file.
status_div = Div(
    text=_status_html('idle', 'Select a feature on the UMAP or from the list to begin.'),
    width=900,
)
stats_div = Div(text="<h3>Click a feature on the UMAP to explore it</h3>", width=900)
top_heatmap_div = Div(text="", width=900)
mean_heatmap_div = Div(text="", width=900)
compare_agg_div = Div(text="", width=1400)
brain_div = Div(text="", width=900)
|
|
| |
| |
|
|
| def _phi_cv_feat_name(feat): |
| """Best-effort display name for the feature.""" |
| if feat is None: |
| return 'unknown' |
| ds = _all_datasets[_S.active] if _all_datasets else None |
| if ds and feat in ds.get('feature_names', {}): |
| return ds['feature_names'][feat] |
| return f'feat {feat}' |
|
|
def _build_dynadiff_panel():
    """Build the DynaDiff brain-steering panel widgets and callbacks.

    Returns:
        A 5-tuple ``(panel_body, dd_output, dd_status, dd_feat_input,
        dd_sample_input)``. When ``HAS_DYNADIFF`` is False, ``panel_body`` is
        None, the two divs are 1-pixel stubs and both inputs are None, so
        callers must guard before using them.
    """
    if not HAS_DYNADIFF:
        return None, Div(text="", width=1), Div(text="", width=1), None, None

    # One row per steering feature: index, display name, strength and the
    # fraction of voxels kept (the "Brain%" column).
    dd_source = ColumnDataSource(data=dict(feat=[], name=[], lam=[], threshold=[]))

    dd_table = DataTable(
        source=dd_source,
        columns=[
            TableColumn(field='feat', title='#', width=55),
            TableColumn(field='name', title='Feature', width=190),
            TableColumn(field='lam', title='λ', width=60,
                        editor=NumberEditor(),
                        formatter=NumberFormatter(format='0.0')),
            TableColumn(field='threshold', title='Brain%', width=65,
                        editor=NumberEditor(),
                        formatter=NumberFormatter(format='0.00')),
        ],
        editable=True,
        width=460,
        height=130,
        index_position=None,
    )

    # Preview of the combined steering, refreshed on every table change.
    dd_steer_div = Div(text="", width=460)

    def _update_dd_preview():
        """Re-render the steering preview from the current table rows."""
        feats = list(dd_source.data['feat'])
        lams = list(dd_source.data['lam'])
        thrs = list(dd_source.data['threshold'])
        dd_steer_div.text = _render_steering_preview(feats, lams, thrs)

    dd_source.on_change('data', lambda attr, old, new: _update_dd_preview())

    # Controls for adding and removing steering features.
    dd_feat_input = TextInput(title="Feature index:", placeholder="e.g. 1234", width=120)
    dd_add_lam_input = TextInput(title="λ:", value="3.0", width=65)
    dd_add_thr_select = Select(
        title="Brain %:",
        options=[('0.05', '5%'), ('0.10', '10%'), ('0.25', '25%'), ('1.0', '100%')],
        value='0.10',
        width=90,
    )
    dd_feat_add_btn = Button(label="Add", button_type="success", width=55)
    dd_feat_remove_btn = Button(label="Remove selected", button_type="light", width=130)
    dd_feat_clear_btn = Button(label="Clear all", button_type="light", width=80)

    # Reconstruction controls and output areas.
    dd_sample_input = TextInput(title="Sample idx", value="0", width=180)
    dd_seed_input = TextInput(title="Seed:", value="42", width=70)
    dd_btn = Button(label="Steer & Reconstruct", button_type="primary", width=200)
    dd_status = Div(text="", width=460)
    dd_output = Div(text="", width=460)

    def _on_add_feat():
        """Validate the feature index and append a row to the steering table."""
        try:
            f = int(dd_feat_input.value.strip())
        except ValueError:
            dd_status.text = '<span style="color:#c00">Invalid feature index.</span>'
            return
        if _phi_cv is None or f < 0 or f >= _phi_cv.shape[0]:
            # Report the last valid index (n - 1), matching the other range
            # messages in this file.
            hi = _phi_cv.shape[0] - 1 if _phi_cv is not None else '?'
            dd_status.text = f'<span style="color:#c00">Feature {f} out of range (0–{hi}).</span>'
            return
        try:
            lam = float(dd_add_lam_input.value)
        except ValueError:
            # An unparseable strength falls back to 3.0.
            lam = 3.0
        threshold = float(dd_add_thr_select.value)
        # Assign a fresh dict so Bokeh sees a single 'data' change.
        new_data = {k: list(v) for k, v in dd_source.data.items()}
        new_data['feat'].append(f)
        new_data['name'].append(_phi_cv_feat_name(f))
        new_data['lam'].append(lam)
        new_data['threshold'].append(threshold)
        dd_source.data = new_data
        dd_status.text = ''

    def _on_remove_feat():
        """Drop the rows currently selected in the steering table."""
        # A set gives constant-time membership while filtering every column.
        sel = set(dd_source.selected.indices)
        if not sel:
            dd_status.text = '<span style="color:#888">Select a row first.</span>'
            return
        new_data = {k: [v for i, v in enumerate(vals) if i not in sel]
                    for k, vals in dd_source.data.items()}
        dd_source.data = new_data
        dd_source.selected.indices = []
        dd_status.text = ''

    def _on_clear_feats():
        """Remove every row from the steering table."""
        dd_source.data = dict(feat=[], name=[], lam=[], threshold=[])
        dd_status.text = ''

    dd_feat_add_btn.on_click(_on_add_feat)
    dd_feat_remove_btn.on_click(_on_remove_feat)
    dd_feat_clear_btn.on_click(_on_clear_feats)

    def _reconstruct_thread(sample_idxs, steerings, seed, feat_name, doc):
        """Worker-thread body: run every trial, then publish via the document.

        Widget updates are deferred with ``doc.add_next_tick_callback`` instead
        of being applied from this thread.
        """
        try:
            resps = []
            for i, sidx in enumerate(sample_idxs):
                trial_label = f'Trial {i+1} (sample {sidx})'
                resp = _dynadiff_request(sidx, steerings, seed)
                resps.append((trial_label, resp))
            html = _make_steering_html(resps, feat_name)

            def _apply(html=html):
                dd_output.text = html
                dd_status.text = ''
                dd_btn.disabled = False

            doc.add_next_tick_callback(_apply)
        except Exception as exc:
            # Thread boundary: report any failure in the UI and re-enable the button.
            msg = str(exc)

            def _show_err(msg=msg):
                dd_status.text = f'<span style="color:#c00">Error: {msg}</span>'
                dd_btn.disabled = False

            doc.add_next_tick_callback(_show_err)

    def _on_reconstruct():
        """Validate the inputs, then start the reconstruction on a daemon thread."""
        feats = list(dd_source.data['feat'])
        lams = list(dd_source.data['lam'])
        thrs = list(dd_source.data['threshold'])
        if not feats:
            dd_status.text = '<span style="color:#c00">Add at least one feature first.</span>'
            return
        # Features without a phi row are silently skipped.
        steerings = []
        for f, lam, thr in zip(feats, lams, thrs):
            phi = _phi_voxel_row(f)
            if phi is not None:
                steerings.append((phi, float(lam), float(thr)))
        if not steerings:
            dd_status.text = '<span style="color:#c00">No phi data for selected features.</span>'
            return
        _raw = dd_sample_input.value.strip()
        try:
            _parsed = _parse_img_label(_raw)
        except ValueError:
            dd_status.text = '<span style="color:#c00">Invalid sample index.</span>'
            return

        # Check the loader state up front so the user gets immediate feedback
        # instead of an error from the worker thread.
        _dd_cur_status, _dd_cur_err = _dd_loader.status
        if _dd_cur_status == 'loading':
            dd_status.text = ('<span style="color:#f0a020">'
                              'DynaDiff model still loading — try again shortly.</span>')
            return
        if _dd_cur_status == 'error':
            dd_status.text = (f'<span style="color:#c00">'
                              f'DynaDiff model load failed: {_dd_cur_err}</span>')
            return

        # An input containing '_' is treated as an image label whose trailing
        # number is an NSD image index, expanded to all of its trials;
        # otherwise the parsed value is a single sample index.
        if '_' in _raw:
            try:
                nsd_img_idx = int(_raw.rsplit('_', 1)[-1])
            except ValueError:
                dd_status.text = '<span style="color:#c00">Could not parse NSD image index.</span>'
                return
            sample_idxs = _dd_loader.sample_idxs_for_nsd_img(nsd_img_idx)
            if not sample_idxs:
                dd_status.text = (
                    f'<span style="color:#c00">NSD image {nsd_img_idx} has no trials '
                    f'for this subject.</span>')
                return
        else:
            sample_idxs = [_parsed]
        _n = _dd_loader.n_samples
        if _n is not None and any(not (0 <= s < _n) for s in sample_idxs):
            dd_status.text = f'<span style="color:#c00">sample_idx must be 0–{_n-1}.</span>'
            return
        try:
            seed = int(dd_seed_input.value)
        except ValueError:
            # An unparseable seed falls back to 42.
            seed = 42
        names = list(dd_source.data['name'])
        feat_name = ' + '.join(names) if names else 'unknown'
        dd_btn.disabled = True
        n_trials = len(sample_idxs)
        dd_status.text = (f'<i style="color:#888">Running DynaDiff reconstruction '
                          f'({n_trials} trial{"s" if n_trials > 1 else ""})…</i>')
        # Capture the document here; the worker thread needs it to schedule
        # its UI updates.
        doc = curdoc()
        threading.Thread(
            target=_reconstruct_thread,
            args=(sample_idxs, steerings, seed, feat_name, doc),
            daemon=True,
        ).start()

    dd_btn.on_click(_on_reconstruct)

    panel_body = column(
        row(dd_feat_input, dd_add_lam_input, dd_add_thr_select, dd_feat_add_btn),
        row(dd_feat_remove_btn, dd_feat_clear_btn),
        dd_table,
        dd_steer_div,
        row(dd_sample_input, dd_seed_input),
        row(dd_btn, dd_status),
        dd_output,
    )
    return panel_body, dd_output, dd_status, dd_feat_input, dd_sample_input
|
|
|
|
| |
| |
| |
# Build the steering panel once at import time. The panel body and both
# inputs are None when HAS_DYNADIFF is False.
_dd_panel_body, _dd_output, _dd_status, _dd_feat_input, _dd_sample_input = _build_dynadiff_panel()

# Free-text name for the currently displayed feature.
name_input = TextInput(
    title="Feature name (auto-saved):",
    placeholder="Enter a name for this feature...",
    width=420,
)

# Gemini labelling: the key comes from args.google_api_key or the
# GOOGLE_API_KEY environment variable; the button is disabled without one.
_gemini_api_key = args.google_api_key or os.environ.get("GOOGLE_API_KEY")
gemini_btn = Button(
    label="Label with Gemini",
    width=140,
    button_type="warning",
    disabled=(_gemini_api_key is None),
)
gemini_status_div = Div(text=(
    "<i style='color:#aaa'>No GOOGLE_API_KEY set</i>"
    if _gemini_api_key is None else ""
), width=300)

# Display controls; changing either slider re-renders the current feature.
zoom_slider = Slider(
    title="Zoom (patches)", value=16, start=1, end=16, step=1, width=220,
)

heatmap_alpha_slider = Slider(
    title="Heatmap opacity", value=1.0, start=0.0, end=1.0, step=0.05, width=220,
)

# Which image panel is visible (see _update_view_visibility).
_view_options = ["Top (max activation)", "Mean activation", "Compare aggregations"]

view_select = Select(
    title="Image ranking:",
    value="Top (max activation)",
    options=_view_options,
    width=250,
)

# active == 1 switches the image rankings to the NSD subset tensors when
# HAS_NSD_SUBSET is set.
nsd_subset_toggle = RadioButtonGroup(
    labels=["All images", "NSD sub01"],
    active=0,
    width=220,
)

# Number of images requested per ranking in update_feature_display.
N_DISPLAY = 6
|
|
|
|
def update_feature_display(feature_idx):
    """Show the detail panel for one feature.

    The stats header and name box are updated immediately. The image panels,
    cortical profile and DynaDiff pre-fill are rendered in a next-tick
    callback so the "loading" banner is set before the heavy work starts.

    Args:
        feature_idx: Feature index; converted with ``int``.
    """
    feat = int(feature_idx)
    # Each call gets a new token; a pending render whose token no longer
    # matches exits without touching the panels.
    _S.render_token += 1
    my_token = _S.render_token

    freq_val = feature_frequency[feat].item()
    mean_val = feature_mean_act[feat].item()
    dead = "DEAD FEATURE" if freq_val == 0 else ""

    # The manual and auto-interp names are each shown on their own line when set.
    feat_name = feature_names.get(feat, "")
    auto_name = auto_interp_names.get(feat, "")
    name_parts = []
    if feat_name:
        name_parts.append(
            f'<div style="color:#1a6faf;font-style:italic;margin:2px 0 3px 0">'
            f'🏷︎ {feat_name}'
            f'<span style="font-size:10px;color:#999;margin-left:6px">(manual)</span></div>'
        )
    if auto_name:
        name_parts.append(
            f'<div style="color:#5a9a5a;font-style:italic;margin:2px 0 3px 0">'
            f'🤖 {auto_name}'
            f'<span style="font-size:10px;color:#999;margin-left:6px">(auto-interp)</span></div>'
        )
    name_display = "".join(name_parts)

    # The φ_c chip is only appended when a value is available.
    phi_c_val = _phi_c_for_feat(feat)
    phi_chip = (f'  ·  <b>φ_c:</b> {phi_c_val:.4f}' if phi_c_val is not None else '')
    stats_div.text = (
        f'<h2 style="margin:4px 0">Feature {feat}'
        f'<span style="color:red;margin-left:8px">{dead}</span>'
        f'<span style="font-size:13px;font-weight:normal;color:#555;margin-left:14px">'
        f'<b>Freq:</b> {int(freq_val):,}  ·  '
        f'<b>Mean act:</b> {mean_val:.4f}'
        f'{phi_chip}</span></h2>'
        + name_display
    )
    name_input.value = feat_name

    # Dead features have no images: show the banner and the cortical profile,
    # clear the image panels and skip the deferred render.
    if freq_val == 0:
        status_div.text = _status_html(
            'dead', f'Feature {feat} is dead — it never activated on the precompute set.')
        brain_div.text = _render_cortical_profile(feat)
        for div in [top_heatmap_div, mean_heatmap_div, compare_agg_div]:
            div.text = ""
        return

    status_div.text = _status_html(
        'loading', f'⏳ Rendering heatmaps for feature {feat}...')

    def _render():
        """Deferred heavy render; skipped if a newer request has arrived."""
        if _S.render_token != my_token:
            return

        # Sentinel distinguishing "no image stored in this slot" from
        # "image file unavailable" (None).
        _SLOT_EMPTY = object()

        def _render_one(img_idx_tensor, act_tensor, ranking_idx, heatmap_tensor=None,
                        center='peak'):
            """Render one ranked image as (PIL image, caption).

            Returns _SLOT_EMPTY for a negative image index, None when the file
            cannot be read, and a gray placeholder with the error text for any
            other failure.
            """
            img_i = img_idx_tensor[feat, ranking_idx].item()
            if img_i < 0:
                return _SLOT_EMPTY
            try:
                # A heatmap is only used when one is stored and the patch grid
                # is larger than 1x1.
                if heatmap_tensor is not None and heatmap_patch_grid > 1:
                    hmap = heatmap_tensor[feat, ranking_idx].float().numpy()
                    hmap = hmap.reshape(heatmap_patch_grid, heatmap_patch_grid)
                else:
                    hmap = None

                # Caption: activation value plus the file stem of the image.
                img_label = os.path.splitext(os.path.basename(image_paths[img_i]))[0]
                act_val = float(act_tensor[feat, ranking_idx].item())
                caption = f"act={act_val:.4f} {img_label}"
                if hmap is None:
                    plain = load_image(img_i).resize((THUMB, THUMB), Image.BILINEAR)
                    return (plain, caption)
                img_out = render_zoomed_overlay(img_i, hmap, size=THUMB, center=center)
                return (img_out, caption)
            except (FileNotFoundError, OSError):
                return None
            except Exception as e:
                ph = Image.new("RGB", (THUMB, THUMB), "gray")
                return (ph, f"Error: {e}")

        def _collect(idx_tensor, act_tensor, hm_tensor, n, center='peak'):
            """Render up to n images, skipping unavailable files but stopping at empty slots."""
            results = []
            for j in range(min(n, idx_tensor.shape[1])):
                hm = _render_one(idx_tensor, act_tensor, j, hm_tensor, center=center)
                if hm is _SLOT_EMPTY:
                    break
                if hm is None:
                    continue
                results.append(hm)
            return results

        # Choose between the full-dataset tensors and the NSD-subset tensors.
        _use_nsd = nsd_subset_toggle.active == 1 and HAS_NSD_SUBSET
        _top_idx = nsd_top_img_idx if _use_nsd else top_img_idx
        _top_act = nsd_top_img_act if _use_nsd else top_img_act
        _mean_idx = nsd_mean_img_idx if _use_nsd else mean_img_idx
        _mean_act = nsd_mean_img_act if _use_nsd else mean_img_act
        _top_hm = nsd_top_heatmaps if _use_nsd else top_heatmaps
        _mean_hm = nsd_mean_heatmaps if _use_nsd else mean_heatmaps

        heatmap_infos = _collect(_top_idx, _top_act, _top_hm, N_DISPLAY)

        _subset_label = " [NSD sub01]" if _use_nsd else ""
        top_heatmap_div.text = make_image_grid_html(
            heatmap_infos, f"Top by Max Activation (feature {feat}){_subset_label}")

        # The mean-activation ranking passes center='centroid' to the overlay renderer.
        mean_hm_infos = _collect(_mean_idx, _mean_act, _mean_hm, N_DISPLAY, center='centroid')

        mean_heatmap_div.text = make_image_grid_html(
            mean_hm_infos, f"Top by Mean Activation (feature {feat}){_subset_label}")

        # The comparison panel reuses the images rendered above.
        compare_agg_div.text = make_compare_aggregations_html(
            heatmap_infos, mean_hm_infos, feat,
            model_label=_all_datasets[_S.active]['label'])

        brain_div.text = _render_cortical_profile(feat)

        # Pre-fill the DynaDiff panel with this feature and, in NSD mode, with
        # the file stem of the feature's top NSD image.
        if HAS_DYNADIFF:
            _dd_feat_input.value = str(feat)
            _use_nsd_dd = nsd_subset_toggle.active == 1 and HAS_NSD_SUBSET
            if _use_nsd_dd and _dd_sample_input is not None:
                _top_i = nsd_top_img_idx[feat, 0].item()
                if _top_i >= 0:
                    _dd_sample_input.value = os.path.splitext(
                        os.path.basename(image_paths[_top_i]))[0]
            _dd_status.text = (
                '<i style="color:#888">Feature pre-filled → click Add, then Steer & Reconstruct.</i>'
                if _phi_voxel_row(feat) is not None else
                '<span style="color:#c00">No phi data for this feature.</span>'
            )

        status_div.text = _status_html('ok', f'✓ Feature {feat} ready.')
        _update_view_visibility()

    curdoc().add_next_tick_callback(_render)
|
|
|
|
| |
def _update_view_visibility():
    """Show only the image panel that matches the current ``view_select`` value."""
    mode = view_select.value
    panels = (
        (top_heatmap_div, "Top (max activation)"),
        (mean_heatmap_div, "Mean activation"),
        (compare_agg_div, "Compare aggregations"),
    )
    for panel, label in panels:
        panel.visible = (mode == label)
|
|
# Keep panel visibility in sync with the dropdown, and apply it once at startup.
view_select.on_change('value', lambda attr, old, new: _update_view_visibility())
_update_view_visibility()
|
|
|
|
def _rerender_current_feature(attr, old, new):
    """Re-render the current feature when any display control changes.

    Does nothing when the feature box does not hold an integer or the value is
    outside ``[0, d_model)``. The arguments are the standard Bokeh
    ``on_change`` triple and are not used.
    """
    # Only the parse is guarded: a ValueError raised while rendering should
    # surface instead of being silently swallowed.
    try:
        feat = int(feature_input.value)
    except ValueError:
        return
    if 0 <= feat < d_model:
        update_feature_display(feat)
|
|
# Any of these display controls triggers a re-render of the current feature.
zoom_slider.on_change('value', _rerender_current_feature)
heatmap_alpha_slider.on_change('value', _rerender_current_feature)
nsd_subset_toggle.on_change('active', _rerender_current_feature)
|
|
|
|
| |
def _umap_alphas_for_selection(selected_pos):
    """Build the ``point_alpha`` column for the current UMAP points.

    With no selection (None) every point gets 0.6; otherwise the selected
    position gets 0.6 and all other points are dimmed to 0.2.
    """
    count = len(umap_source.data['feature_idx'])
    if selected_pos is None:
        return [0.6] * count
    alphas = []
    for pos in range(count):
        alphas.append(0.6 if pos == selected_pos else 0.2)
    return alphas
|
|
|
|
def on_umap_select(attr, old, new):
    """Handle a change of the UMAP selection indices.

    An empty selection restores uniform alphas. Otherwise the first selected
    point is highlighted, its feature index is copied into the feature box and
    the detail panel is rendered for it.
    """
    if not new:
        umap_source.data['point_alpha'] = _umap_alphas_for_selection(None)
        return
    position = new[0]
    umap_source.data['point_alpha'] = _umap_alphas_for_selection(position)
    chosen = umap_source.data['feature_idx'][position]
    feature_input.value = str(chosen)
    update_feature_display(chosen)
|
|
# Tapping a point on the UMAP changes the selection indices and runs on_umap_select.
umap_source.selected.on_change('indices', on_umap_select)

# Selector for which UMAP embedding is plotted.
_umap_type_options = ["Activation Pattern", "Dictionary Geometry"]

umap_type_select = Select(
    title="UMAP Type", value="Activation Pattern",
    options=_umap_type_options, width=220,
)

# Colouring options; the φ_c option is only offered when _phi_c is loaded.
_color_options = ["Log Frequency", "Mean Activation"]
if _phi_c is not None:
    _color_options.append("Brain Leverage (φ_c)")

umap_color_select = Select(
    title="Color by:", value="Log Frequency",
    options=_color_options, width=200,
)
|
|
|
|
def _apply_umap_color(color_by, feat_ids):
    """Switch the UMAP colouring mode and refresh the colour scale.

    Stores *color_by* in ``_S.color_by``, recomputes the ``color_val`` column
    for *feat_ids* and rescales ``color_mapper`` to the new values.
    """
    _S.color_by = color_by
    shades = _make_color_vals(feat_ids)
    umap_source.data['color_val'] = shades
    _set_color_mapper_range(shades)
|
|
|
|
def _on_umap_color_change(attr, old, new):
    """Recolour the points currently in ``umap_source`` using mode *new*."""
    _apply_umap_color(new, list(umap_source.data['feature_idx']))
|
|
|
|
# Recolour the UMAP whenever the "Color by" dropdown changes.
umap_color_select.on_change('value', _on_umap_color_change)
|
|
|
|
|
|
def on_umap_type_change(attr, old, new):
    """Swap the UMAP scatter between the two precomputed embeddings.

    "Activation Pattern" uses the ``act_*`` entries of ``umap_backup`` with
    ``live_mask``; any other value uses the ``dict_*`` entries with
    ``dict_live_mask``. The colour scale is rescaled to the new values.
    """
    if new == "Activation Pattern":
        x_key, y_key, feat_key = 'act_x', 'act_y', 'act_feat'
        mask = live_mask
        heading = "UMAP of SAE Features (by activation pattern)"
    else:
        x_key, y_key, feat_key = 'dict_x', 'dict_y', 'dict_feat'
        mask = dict_live_mask
        heading = "UMAP of SAE Features (by dictionary geometry)"

    feat_ids = umap_backup[feat_key]
    shades = _make_color_vals(feat_ids)
    leverage = _phi_c_vals(feat_ids)
    umap_source.data = dict(
        x=umap_backup[x_key],
        y=umap_backup[y_key],
        feature_idx=feat_ids,
        frequency=freq[mask].tolist(),
        log_freq=log_freq[mask].tolist(),
        mean_act=mean_act[mask].tolist(),
        phi_c_val=leverage,
        color_val=shades,
        point_alpha=_make_point_alphas(len(feat_ids)),
    )
    umap_fig.title.text = heading
    _set_color_mapper_range(shades)
|
|
# Swap embeddings whenever the "UMAP Type" dropdown changes.
umap_type_select.on_change('value', on_umap_type_change)

# Direct feature navigation: type an index and press Go, or jump to a random feature.
feature_input = TextInput(title="Feature Index:", value="", width=120)
go_button = Button(label="Go", width=60)
random_btn = Button(label="Random", width=70)
|
|
|
|
def _select_and_display(feat):
    """Render the detail panel for *feat* and highlight it on the UMAP.

    The UMAP selection is only changed when *feat* is one of the points
    currently plotted.
    """
    update_feature_display(feat)
    plotted = umap_source.data['feature_idx']
    if feat not in plotted:
        return
    umap_source.selected.indices = [plotted.index(feat)]
|
|
|
|
def on_go_click():
    """Display the feature whose index is typed into ``feature_input``.

    Non-integer input and indices outside ``[0, d_model)`` are reported in
    ``stats_div`` instead of raising.
    """
    # Guard only the parse: a ValueError raised while rendering the feature
    # must not be misreported to the user as "Please enter a valid integer".
    try:
        feat = int(feature_input.value)
    except ValueError:
        stats_div.text = "<h3>Please enter a valid integer</h3>"
        return

    if not 0 <= feat < d_model:
        stats_div.text = f"<h3>Feature {feat} out of range (0-{d_model-1})</h3>"
        return

    _select_and_display(feat)


go_button.on_click(on_go_click)
|
|
|
|
def _on_random():
    """Jump to a randomly chosen entry of ``_active_feats``.

    Does nothing when ``_active_feats`` is empty. The chosen index is also
    written into ``feature_input`` so the text box reflects the selection.
    """
    if _active_feats:
        picked = random.choice(_active_feats)
        feature_input.value = str(picked)
        _select_and_display(picked)


random_btn.on_click(_on_random)
|
|
|
|
| |
|
|
# Initial table contents: all features, ordered by descending `freq`.
_init_order = np.argsort(-freq)
feature_list_source = ColumnDataSource(data=dict(
    feature_idx=_init_order.tolist(),
    frequency=freq[_init_order].tolist(),
    mean_act=mean_act[_init_order].tolist(),
    phi_c_val=_phi_c_vals(_init_order.tolist()),
    name=[_display_name(int(i)) for i in _init_order],
))
|
|
def _phi_col():
    """Return the optional φ_c table column as a list.

    The result is a one-element list when ``HAS_PHI`` is true and an empty
    list otherwise, so it can be spliced into a column list with ``+``.
    """
    if HAS_PHI:
        phi_column = TableColumn(
            field="phi_c_val",
            title="φ_c",
            width=65,
            formatter=NumberFormatter(format="0.0000"),
        )
        return [phi_column]
    return []
|
|
# Main feature list. The φ_c column is included only when _phi_col() returns it.
feature_table = DataTable(
    source=feature_list_source,
    columns=[
        TableColumn(field="feature_idx", title="Feature", width=60),
        TableColumn(field="frequency", title="Freq", width=70,
                    formatter=NumberFormatter(format="0,0")),
        TableColumn(field="mean_act", title="Mean Act", width=80,
                    formatter=NumberFormatter(format="0.0000")),
    ] + _phi_col() + [
        TableColumn(field="name", title="Name", width=200),
    ],
    width=500, height=500, sortable=True, index_position=None,
)
|
|
| |
|
|
|
|
def _get_sorted_order():
    """Return feature indices by descending ``freq``, honouring the search filter.

    When ``_S.search_filter`` is ``None`` every feature is returned; otherwise
    only indices contained in the filter are kept, in the same order.
    """
    ranked = np.argsort(-freq)
    if _S.search_filter is None:
        return ranked
    keep = np.isin(ranked, list(_S.search_filter))
    return ranked[keep]
|
|
|
|
def _apply_order(order):
    """Replace the feature table's rows with the features listed in ``order``.

    ``order`` is an integer index array; it is used both as the row order and
    to index ``freq`` / ``mean_act``.
    """
    display_names = [_display_name(int(idx)) for idx in order]
    feature_list_source.data = {
        'feature_idx': order.tolist(),
        'frequency': freq[order].tolist(),
        'mean_act': mean_act[order].tolist(),
        'phi_c_val': _phi_c_vals(order.tolist()),
        'name': display_names,
    }
|
|
|
|
def _update_table_names():
    """Refresh the name column after saving or deleting a feature name.

    The rows currently shown in the table are re-applied in their existing
    order, so only the derived columns (including the name) are recomputed.
    """
    # Force an integer dtype. When the table is empty (e.g. a name search with
    # no matches) np.array([]) defaults to float64, and indexing freq/mean_act
    # with a float array raises IndexError.
    current_rows = np.asarray(feature_list_source.data['feature_idx'], dtype=np.intp)
    _apply_order(current_rows)
|
|
|
|
def _on_table_select(attr, old, new):
    """Show the feature in the first selected row of the feature table.

    Bokeh selection callback; ``new`` is the list of selected row indices.
    An empty selection is ignored.
    """
    if not new:
        return
    chosen = feature_list_source.data['feature_idx'][new[0]]
    feature_input.value = str(chosen)
    _select_and_display(chosen)


feature_list_source.selected.on_change('indices', _on_table_select)
|
|
|
|
| |
def on_name_change(attr, old, new):
    """Persist a manual name for the feature shown in ``feature_input``.

    Bokeh property-change callback for ``name_input``. A non-empty ``new``
    (after stripping) is stored in ``feature_names``; an empty one removes any
    existing entry. The names are then saved and the table refreshed.

    Nothing happens when ``feature_input`` does not hold a valid feature
    index.
    """
    try:
        feat = int(feature_input.value)
    except ValueError:
        return
    # Ignore indices outside the dictionary so a stray value typed into the
    # index box cannot create a name entry for a feature that does not exist.
    if not 0 <= feat < d_model:
        return

    name = new.strip()
    if name:
        feature_names[feat] = name
    elif feat in feature_names:
        del feature_names[feat]
    _save_names()
    _update_table_names()


name_input.on_change('value', on_name_change)
|
|
|
|
| |
# Maximum number of top-activating images sent to Gemini per labelling request.
_N_GEMINI_IMAGES = 6
# Model id passed to generate_content.
_GEMINI_MODEL = "gemini-2.5-flash"
# NOTE(review): not referenced anywhere in this part of the file — presumably a
# heatmap overlay opacity for Gemini inputs; confirm it is still used.
_GEMINI_HM_ALPHA = 0.25
|
|
def _gemini_label_thread(feat, mei_items, doc):
    """Run in a worker thread: call Gemini and push the result back to the doc.

    Args:
        feat: index of the SAE feature being labelled.
        mei_items: list of (path_str, heatmap_np_or_None) where heatmap is
            (H, W) float32. Only the paths are used here, and at most
            ``_N_GEMINI_IMAGES`` items are sent.
        doc: Bokeh document; every widget/model update is scheduled on it with
            ``add_next_tick_callback`` rather than applied from this thread.

    Any exception (including an empty model response) is caught and reported
    in ``gemini_status_div``; the button is re-enabled on every path.
    """
    from html import escape

    try:
        from google import genai
        from google.genai import types

        SYSTEM_PROMPT = (
            "You are labeling features of a Sparse Autoencoder (SAE) trained on a "
            "vision transformer. Each SAE feature is a sparse direction in activation "
            "space that fires strongly on certain visual patterns."
        )
        USER_PROMPT = (
            "The images below are the top maximally-activating images for one SAE feature. "
            "In 2–5 words, give a precise label for the visual concept this feature detects. "
            "Be specific — prefer 'dog snout close-up' over 'dog', or 'brick wall texture' "
            "over 'texture'. "
            "Reply with ONLY the label, no explanation, no punctuation at the end."
        )

        client = genai.Client(api_key=_gemini_api_key)
        parts = []
        for path, _heatmap in mei_items[:_N_GEMINI_IMAGES]:
            resolved = _resolve_img_path(path)
            if resolved is None:
                continue
            try:
                img = Image.open(resolved).convert("RGB").resize((224, 224), Image.BILINEAR)
                buf = io.BytesIO()
                img.save(buf, format="JPEG", quality=85)
                parts.append(types.Part.from_bytes(data=buf.getvalue(), mime_type="image/jpeg"))
            except Exception:
                # Best effort: an unreadable image is skipped, not fatal.
                continue

        if not parts:
            def _no_images():
                gemini_btn.disabled = False
                gemini_status_div.text = "<span style='color:#c00'>No images could be loaded.</span>"
            doc.add_next_tick_callback(_no_images)
            return

        parts.append(types.Part.from_text(text=USER_PROMPT))
        response = client.models.generate_content(
            model=_GEMINI_MODEL,
            contents=parts,
            config=types.GenerateContentConfig(system_instruction=SYSTEM_PROMPT),
        )
        # response.text can be None when the model returns no text part;
        # treat that, and a label that is empty after stripping, as a failure
        # instead of saving a blank name.
        label = (response.text or "").strip().strip(".,;:\"'")
        if not label:
            raise ValueError("Gemini returned an empty response")

        def _apply_label(feat=feat, label=label):
            auto_interp_names[feat] = label
            _save_auto_interp()
            _update_table_names()
            # Refreshing the detail panel is best effort, but log the failure
            # rather than hiding it.
            try:
                update_feature_display(feat)
            except Exception as exc:
                print(f" [Gemini] feat {feat}: display refresh failed: {exc}")
            gemini_btn.disabled = False
            # The label is model output rendered as HTML, so escape it.
            gemini_status_div.text = (
                f"<span style='color:#1a6faf'><b>Labeled:</b> {escape(label)}</span>"
            )
            print(f" [Gemini] feat {feat}: {label}")

        doc.add_next_tick_callback(_apply_label)

    except Exception as e:
        err = str(e)

        def _show_err(err=err):
            gemini_btn.disabled = False
            # Exception text may contain markup-significant characters.
            gemini_status_div.text = f"<span style='color:#c00'>Error: {escape(err[:120])}</span>"
            print(f" [Gemini] feat {feat} error: {err}")
        doc.add_next_tick_callback(_show_err)
|
|
|
|
def _on_gemini_click():
    """Start a background Gemini labelling request for the current feature.

    The feature index is read from ``feature_input``. Its stored top images
    (and heatmaps, when available) are collected, the button is disabled, and
    ``_gemini_label_thread`` is launched on a daemon thread. Invalid,
    out-of-range, dead, or image-less features are reported in
    ``gemini_status_div`` without starting a thread.
    """
    try:
        feat = int(feature_input.value)
    except ValueError:
        gemini_status_div.text = "<span style='color:#c00'>Select a feature first.</span>"
        return

    # Reject indices outside the dictionary: under tensor/array indexing a
    # negative value would wrap around to another feature and a too-large one
    # would raise IndexError.
    if not 0 <= feat < d_model:
        gemini_status_div.text = "<span style='color:#c00'>Feature index out of range.</span>"
        return

    if feature_frequency[feat].item() == 0:
        gemini_status_div.text = "<span style='color:#c00'>Dead feature — no images.</span>"
        return

    n_top_stored = top_img_idx.shape[1]
    mei_items = []
    for j in range(n_top_stored):
        idx = top_img_idx[feat, j].item()
        # Negative entries are skipped (they do not index an image).
        if idx >= 0:
            hm = None
            if top_heatmaps is not None:
                hm = top_heatmaps[feat, j].float().numpy().reshape(heatmap_patch_grid, heatmap_patch_grid)
            mei_items.append((image_paths[idx], hm))

    if not mei_items:
        gemini_status_div.text = "<span style='color:#c00'>No MEI paths found.</span>"
        return

    gemini_btn.disabled = True
    gemini_status_div.text = "<i style='color:#888'>Calling Gemini…</i>"

    # Capture the document now and hand it to the worker, which uses it to
    # schedule its UI updates.
    doc = curdoc()
    t = threading.Thread(
        target=_gemini_label_thread,
        args=(feat, mei_items, doc),
        daemon=True,
    )
    t.start()


# The button is only wired up when an API key is configured.
if _gemini_api_key:
    gemini_btn.on_click(_on_gemini_click)
|
|
|
|
| |
# Name search widgets: query box, Search / Clear buttons, and a status line.
search_input = TextInput(
    title="Search feature names:",
    placeholder="Type to search...",
    width=220,
)
search_btn = Button(label="Search", width=70, button_type="primary")
clear_search_btn = Button(label="Clear", width=60)
search_result_div = Div(text="", width=360)
|
|
|
|
def _do_search():
    """Filter the feature table to features whose name contains the query.

    The query is matched case-insensitively as a substring against both the
    manual names (``feature_names``) and the auto-interp names
    (``auto_interp_names``). An empty query clears the filter. The outcome is
    summarised in ``search_result_div``.
    """
    from html import escape

    query = search_input.value.strip().lower()
    if not query:
        _S.search_filter = None
        search_result_div.text = ""
        _apply_order(_get_sorted_order())
        return

    matches = {
        i for i, name in feature_names.items() if query in name.lower()
    } | {
        i for i, name in auto_interp_names.items() if query in name.lower()
    }
    _S.search_filter = matches
    _apply_order(_get_sorted_order())

    # The query is user-typed text rendered as HTML; escape it so characters
    # such as '<' or '&' are shown literally instead of being parsed as markup.
    shown_query = escape(query)
    if matches:
        search_result_div.text = (
            f'<span style="color:#1a6faf"><b>{len(matches)}</b> feature(s) matching '
            f'“{shown_query}”</span>'
        )
    else:
        search_result_div.text = (
            f'<span style="color:#c00">No features named “{shown_query}”</span>'
        )
|
|
|
|
def _do_clear_search():
    """Drop the name filter: empty the query box and status line, show all rows."""
    search_input.value = ""
    _S.search_filter = None
    search_result_div.text = ""

    unfiltered = _get_sorted_order()
    _apply_order(unfiltered)


search_btn.on_click(_do_search)
clear_search_btn.on_click(_do_clear_search)
|
|
|
|
| |
def _make_summary_html():
    """Build the HTML summary box for the currently active dataset.

    Reads the active entry of ``_all_datasets`` plus the module-level ``freq``,
    ``d_model``, ``n_images`` and ``patch_grid``. A download row for the SAE
    weights is included only when the dataset has a ``'sae_url'``.

    Returns:
        An HTML string suitable for a Bokeh ``Div``.
    """
    ds = _all_datasets[_S.active]
    # A feature counts as active when it fired at least once.
    n_truly_active = int((freq > 0).sum())
    n_dead = d_model - n_truly_active
    tok_label = f"{patch_grid}×{patch_grid} = {patch_grid**2} patches"
    backbone_label = ds.get('backbone', 'dinov3').upper()
    sae_url = ds.get('sae_url')
    dl_row = (f'<tr><td><b>SAE weights:</b></td>'
              f'<td><a href="{sae_url}" download style="color:#1a6faf">⬇ Download</a></td></tr>'
              if sae_url else '')
    return f"""
    <div style="background:#f0f4f8;padding:12px;border-radius:6px;margin-bottom:8px;">
    <h2 style="margin:0 0 8px 0">SAE Feature Explorer</h2>
    <table style="font-size:13px;">
    <tr><td><b>Active model:</b></td><td><b style="color:#1a6faf">{ds['label']}</b></td></tr>
    <tr><td><b>Backbone:</b></td><td>{backbone_label}</td></tr>
    <tr><td><b>Dictionary size:</b></td><td>{d_model:,}</td></tr>
    <tr><td><b>Active (fired ≥1):</b></td><td>{n_truly_active:,} ({100*n_truly_active/d_model:.1f}%)</td></tr>
    <tr><td><b>Dead:</b></td><td>{n_dead:,} ({100*n_dead/d_model:.1f}%)</td></tr>
    <tr><td><b>Images:</b></td><td>{n_images:,}</td></tr>
    <tr><td><b>Tokens/image:</b></td><td>{tok_label}</td></tr>
    {dl_row}
    </table>
    </div>"""


summary_div = Div(text=_make_summary_html(), width=700)
|
|
|
|
| |
| |
| |
|
|
# Width and height given to the patch figure; the background image is also
# resized to this many pixels per side before display.
_PATCH_FIG_PX = 400

# Row and column of every patch in row-major order, so the flat index of a
# patch is row * patch_grid + col.
_pr = [r for r in range(patch_grid) for _ in range(patch_grid)]
_pc = list(range(patch_grid)) * patch_grid

# One selectable rectangle per patch, centred on its grid cell. Row 0 gets the
# largest y, i.e. it sits at the top of the figure.
patch_grid_source = ColumnDataSource(data=dict(
    x=[c + 0.5 for c in _pc],
    y=[patch_grid - r - 0.5 for r in _pr],
    row=_pr,
    col=_pc,
))

# Background image for the figure; empty until an image is loaded.
patch_bg_source = ColumnDataSource(data=dict(
    image=[], x=[0], y=[0], dw=[patch_grid], dh=[patch_grid],
))

# The figure works in patch units (0..patch_grid on both axes) and starts hidden.
patch_fig = figure(
    width=_PATCH_FIG_PX, height=_PATCH_FIG_PX,
    x_range=(0, patch_grid), y_range=(0, patch_grid),
    tools=["tap", "reset"],
    title="Click or drag to paint patch selection",
    toolbar_location="above",
    visible=False,
)

# Drag-to-paint, done in the browser: document-level mouse listeners (installed
# once per page) track whether a button is held; while it is, every MouseMove
# over the figure adds the patch under the cursor to the selection. The row is
# computed as pg - 1 - floor(y) to match the top-down row layout used above.
_paint_js = CustomJS(args=dict(source=patch_grid_source, pg=patch_grid), code="""
    if (!window._patch_paint_init) {
        window._patch_paint_init = true;
        window._patch_btn_held = false;
        document.addEventListener('mousedown', () => { window._patch_btn_held = true; });
        document.addEventListener('mouseup', () => { window._patch_btn_held = false; });
    }
    if (!window._patch_btn_held) return;

    const x = cb_obj.x, y = cb_obj.y;
    if (x === null || y === null || x < 0 || x >= pg || y < 0 || y >= pg) return;

    const col = Math.floor(x);
    const row = pg - 1 - Math.floor(y);
    const flat_idx = row * pg + col;

    const sel = source.selected.indices.slice();
    if (sel.indexOf(flat_idx) === -1) {
        sel.push(flat_idx);
        source.selected.indices = sel;
    }
""")
patch_fig.js_on_event(MouseMove, _paint_js)
# Image first, grid rectangles second.
patch_fig.image_rgba(
    source=patch_bg_source,
    image='image', x='x', y='y', dw='dw', dh='dh',
)
# Unselected cells have a transparent fill and a faint white outline; selected
# cells are filled red.
patch_fig.rect(
    source=patch_grid_source,
    x='x', y='y', width=0.95, height=0.95,
    fill_color='yellow', fill_alpha=0.0,
    line_color='white', line_alpha=0.35, line_width=0.5,
    selection_fill_color='red', selection_fill_alpha=0.45,
    nonselection_fill_alpha=0.0, nonselection_line_alpha=0.35,
)
patch_fig.axis.visible = False
patch_fig.xgrid.visible = False
patch_fig.ygrid.visible = False
|
|
# Patch explorer controls.
patch_img_input = TextInput(title="Image Index:", value="0", width=120)
load_patch_btn = Button(label="Load Image", width=90, button_type="primary")
clear_patch_btn = Button(label="Clear", width=60)

# Features found for the currently selected patches. The table starts hidden
# and is made wider when the φ_c column is present.
patch_feat_source = ColumnDataSource(data=dict(
    feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[],
))
patch_feat_table = DataTable(
    source=patch_feat_source,
    columns=[
        TableColumn(field="feature_idx", title="Feature", width=65),
        TableColumn(field="patch_act", title="Patch Act", width=85,
                    formatter=NumberFormatter(format="0.0000")),
        TableColumn(field="frequency", title="Freq", width=65,
                    formatter=NumberFormatter(format="0,0")),
        TableColumn(field="mean_act", title="Mean Act", width=80,
                    formatter=NumberFormatter(format="0.0000")),
    ] + _phi_col(),
    width=310 + (65 if HAS_PHI else 0), height=350, index_position=None, sortable=False, visible=False,
)
# Status / instruction line for the patch explorer.
patch_info_div = Div(
    text="<i>Load an image, then click patches to find top features.</i>",
    width=310,
)
|
|
|
|
|
|
def _pil_to_bokeh_rgba(pil_img, size):
    """Convert a PIL image into a ``(size, size)`` uint32 array of packed RGBA.

    The image is resized to ``size`` x ``size`` with bilinear resampling and
    converted to RGBA; each pixel's four bytes are written into one uint32.
    The rows are returned in reverse order (presumably because the figure
    draws the first row at the bottom — confirm against the plot).
    """
    rgba = pil_img.resize((size, size), Image.BILINEAR).convert("RGBA")
    channels = np.array(rgba, dtype=np.uint8)

    packed = np.empty((size, size), dtype=np.uint32)
    # Byte-level view onto the same memory: writing the four channels here
    # fills in the packed uint32 values.
    as_bytes = packed.view(dtype=np.uint8).reshape((size, size, 4))
    as_bytes[:, :, :] = channels

    flipped = packed[::-1]
    return flipped.copy()
|
|
|
|
def _do_load_patch_image():
    """Load an image into the patch explorer and compute its patch activations.

    The index is parsed from ``patch_img_input``. The image is shown
    immediately; the activations are computed on a daemon thread and applied
    on the document's next tick. Parse, range, and image loading errors are
    reported in ``patch_info_div``.
    """
    from html import escape

    try:
        img_idx = _parse_img_label(patch_img_input.value)
    except ValueError:
        patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
        return
    if not (0 <= img_idx < n_images):
        patch_info_div.text = f"<b style='color:red'>Index out of range (0–{n_images - 1})</b>"
        return

    try:
        pil = load_image(img_idx)
        bokeh_arr = _pil_to_bokeh_rgba(pil, _PATCH_FIG_PX)
        patch_bg_source.data = dict(
            image=[bokeh_arr], x=[0], y=[0], dw=[patch_grid], dh=[patch_grid],
        )
    except Exception as e:
        patch_info_div.text = f"<b style='color:red'>Error loading image: {escape(str(e))}</b>"
        return

    # Commit the new image only once it is actually displayed, and drop the
    # previous image's activations and selection straight away. Otherwise
    # patches selected while the new activations are still being computed (or
    # after that computation fails) would be answered with the old image's
    # features.
    _S.patch_img = img_idx
    _S.patch_z = None
    patch_grid_source.selected.indices = []

    load_patch_btn.disabled = True
    patch_info_div.text = (
        "<span style='color:#1a6faf'>⏳ Computing patch activations"
        + (" (running GPU inference — first image may take ~10 s)…"
           if _gpu_runner is None and args.sae_path else "…")
        + "</span>"
    )

    # Capture the document now; the worker uses it to schedule its UI updates.
    _doc = curdoc()

    def _bg():
        try:
            z_np = compute_patch_activations(img_idx)
        except Exception as e:
            err = str(e)

            def _show_err(err=err):
                load_patch_btn.disabled = False
                patch_info_div.text = f"<b style='color:red'>Error: {escape(err)}</b>"
            _doc.add_next_tick_callback(_show_err)
            return

        def _apply(z_np=z_np):
            _S.patch_z = z_np
            load_patch_btn.disabled = False
            patch_fig.visible = True
            patch_grid_source.selected.indices = []
            patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[])

            if z_np is None:
                patch_feat_table.visible = False
                patch_info_div.text = (
                    f"<b style='color:#888'>GPU inference unavailable for image {img_idx}. "
                    f"Ensure --sae-path is set and the GPU runner loaded successfully.</b>"
                )
                return

            patch_feat_table.visible = True
            patch_info_div.text = (
                f"Image {img_idx} loaded. "
                f"Drag to select a region, or click individual patches."
            )

        _doc.add_next_tick_callback(_apply)

    threading.Thread(target=_bg, daemon=True).start()


load_patch_btn.on_click(_do_load_patch_image)
|
|
|
|
def _do_clear_patches():
    """Deselect every patch and empty the patch feature table."""
    empty_rows = {
        'feature_idx': [],
        'patch_act': [],
        'frequency': [],
        'mean_act': [],
        'phi_c_val': [],
    }
    patch_grid_source.selected.indices = []
    patch_feat_source.data = empty_rows
    patch_info_div.text = "<i>Selection cleared.</i>"


clear_patch_btn.on_click(_do_clear_patches)
|
|
|
|
def _get_top_features_for_patches(patch_indices, top_n=20):
    """Rank features by their summed activation over the selected patches.

    Uses the activations stored in ``_S.patch_z`` for the loaded image. At
    most ``top_n`` features are returned, and only those whose summed
    activation is strictly positive.

    Returns:
        ``(feats, acts, freqs, means)`` as parallel lists: feature indices,
        summed activations, and each feature's stored frequency and mean
        activation. All four are empty when no activations are loaded.
    """
    per_patch = _S.patch_z
    if per_patch is None:
        return [], [], [], []

    summed = per_patch[patch_indices].sum(axis=0)
    print(f"[patch] patch_indices={patch_indices}, z_np shape={per_patch.shape}, feat_sums max={summed.max():.4f}, nonzero={int((summed>0).sum())}")

    # Take the top_n by descending sum, then drop any that did not fire.
    ranked = np.argsort(-summed)[:top_n]
    ranked = ranked[summed[ranked] > 0]

    feats = ranked.tolist()
    acts = summed[ranked].tolist()
    freqs = [int(feature_frequency[f].item()) for f in feats]
    means = [float(feature_mean_act[f].item()) for f in feats]
    return feats, acts, freqs, means
|
|
|
|
def _on_patch_select(attr, old, new):
    """Update the patch feature table when the patch selection changes.

    Bokeh selection callback; ``new`` holds the selected glyph indices.
    Ignored until an image has been loaded. An empty selection empties the
    table.
    """
    if _S.patch_img is None:
        return

    if not new:
        patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[])
        patch_info_div.text = "<i>Selection cleared.</i>"
        return

    # Translate glyph indices into flat row-major patch indices via the
    # row/col columns stored on the grid source.
    grid = patch_grid_source.data
    patch_indices = [grid['row'][i] * patch_grid + grid['col'][i] for i in new]

    feats, acts, freqs, means = _get_top_features_for_patches(patch_indices)
    patch_feat_source.data = {
        'feature_idx': feats,
        'patch_act': acts,
        'frequency': freqs,
        'mean_act': means,
        'phi_c_val': _phi_c_vals(feats),
    }
    patch_info_div.text = (
        f"{len(new)} patch(es) selected → {len(feats)} feature(s) found. "
        f"Click a row below to explore the feature."
    )


patch_grid_source.selected.on_change('indices', _on_patch_select)
|
|
|
|
def _on_patch_feat_table_select(attr, old, new):
    """Show the feature in the first selected row of the patch feature table.

    Bokeh selection callback; an empty selection is ignored.
    """
    if new:
        chosen = patch_feat_source.data['feature_idx'][new[0]]
        feature_input.value = str(chosen)
        _select_and_display(chosen)


patch_feat_source.selected.on_change('indices', _on_patch_feat_table_select)
|
|
|
|
| |
def _build_clip_panel():
    """Build the CLIP text-search panel widgets and callbacks.

    Returns (panel, result_div, result_source).
    When HAS_CLIP is False, result_div and result_source are None and panel is a
    static placeholder Div.
    """
    from html import escape

    if not HAS_CLIP:
        panel = Div(
            text="<i style='color:#aaa'>CLIP text search unavailable — "
                 "run <code>scripts/add_clip_embeddings.py</code> to enable.</i>",
            width=470,
        )
        return panel, None, None

    clip_query_input = TextInput(
        title="Search features by text (CLIP):",
        placeholder="e.g. 'dog', 'red stripes', 'water'...",
        width=280,
    )
    clip_search_btn = Button(label="Search", width=70, button_type="primary")
    result_div = Div(text="", width=360)
    clip_top_k_input = TextInput(title="Top-K results:", value="20", width=70)

    result_source = ColumnDataSource(data=dict(
        feature_idx=[], clip_score=[], frequency=[], mean_act=[], phi_c_val=[], name=[],
    ))
    clip_result_table = DataTable(
        source=result_source,
        columns=[
            TableColumn(field="feature_idx", title="Feature", width=65),
            TableColumn(field="clip_score", title="CLIP score", width=85,
                        formatter=NumberFormatter(format="0.0000")),
            TableColumn(field="frequency", title="Freq", width=65,
                        formatter=NumberFormatter(format="0,0")),
            TableColumn(field="mean_act", title="Mean Act", width=80,
                        formatter=NumberFormatter(format="0.0000")),
        ] + _phi_col() + [
            TableColumn(field="name", title="Name", width=160),
        ],
        width=470 + (65 if HAS_PHI else 0), height=300, index_position=None, sortable=False,
    )

    def _do_clip_search():
        """Score every feature against the text query and list the top-K."""
        query = clip_query_input.value.strip()
        if not query:
            result_div.text = "<i>Enter a text query above.</i>"
            return
        # Fall back to 20 on unparsable input; never ask for fewer than 1.
        try:
            top_k = max(1, int(clip_top_k_input.value))
        except ValueError:
            top_k = 20

        # Use the NSD-specific embeddings when the NSD toggle is on and they
        # are available; otherwise the default embeddings.
        _use_nsd_embeds = nsd_subset_toggle.active == 1 and _nsd_clip_embeds is not None
        _active_embeds = _nsd_clip_embeds if _use_nsd_embeds else _clip_embeds
        result_div.text = "<i>Encoding query with CLIP…</i>"
        try:
            clip_m, clip_p, clip_dev = _get_clip()
            q_embed = compute_text_embeddings([query], clip_m, clip_p, clip_dev)
            scores_vec = (_active_embeds.float() @ q_embed.T).squeeze(-1)
        except Exception as exc:
            # Exception text is rendered as HTML, so escape it.
            result_div.text = f"<span style='color:#c00'>CLIP error: {escape(str(exc))}</span>"
            return

        # In NSD-subset mode, exclude features whose first NSD top-image slot
        # is negative by forcing their score to -inf.
        nsd_filtering = nsd_subset_toggle.active == 1 and HAS_NSD_SUBSET
        if nsd_filtering:
            nsd_mask = nsd_top_img_idx[:, 0] >= 0
            scores_vec = scores_vec.clone()
            scores_vec[~nsd_mask] = float('-inf')

        top_indices = torch.topk(scores_vec, k=min(top_k, len(scores_vec))).indices.tolist()
        # topk may still return masked-out entries when fewer than k remain.
        top_indices = [i for i in top_indices if scores_vec[i] > float('-inf')]
        result_source.data = dict(
            feature_idx=top_indices,
            clip_score=[float(scores_vec[i]) for i in top_indices],
            frequency=[int(feature_frequency[i].item()) for i in top_indices],
            mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
            phi_c_val=_phi_c_vals(top_indices),
            name=[_display_name(int(i)) for i in top_indices],
        )
        _subset_note = " [NSD sub01]" if nsd_filtering else ""
        # The query is user-typed text rendered as HTML; escape it so markup
        # characters are displayed literally.
        result_div.text = (
            f'<span style="color:#1a6faf"><b>{len(top_indices)}</b> features for '
            f'“{escape(query)}”{_subset_note}</span>'
        )

    clip_search_btn.on_click(_do_clip_search)

    def _on_result_select(attr, old, new):
        """Show the feature in the first selected result row."""
        if not new:
            return
        feat = result_source.data['feature_idx'][new[0]]
        feature_input.value = str(feat)
        _select_and_display(feat)

    result_source.selected.on_change('indices', _on_result_select)

    panel = column(
        row(clip_query_input, clip_top_k_input, clip_search_btn),
        result_div,
        clip_result_table,
    )
    return panel, result_div, result_source


clip_search_panel, clip_result_div, clip_result_source = _build_clip_panel()
|
|
|
|
|
|
|
|
| |
# Top control row: UMAP layout and colour selectors, plus direct feature navigation.
controls = row(umap_type_select, umap_color_select, feature_input, go_button, random_btn)

# Manual naming box with the Gemini auto-label button and its status line.
name_panel = column(
    name_input,
    row(gemini_btn, gemini_status_div),
)

# Name search controls with their status line underneath.
search_panel = column(
    row(search_input, search_btn, clear_search_btn),
    search_result_div,
)

# Search controls stacked above the feature table.
feature_list_panel = column(search_panel, feature_table)
|
|
|
|
def _make_collapsible(title, body, initially_open=False):
    """Return ``body`` stacked under a toggle button that shows or hides it.

    The button label is ``title`` prefixed with ▼ while open and ▶ while
    closed. Toggling happens in the browser through a ``CustomJS`` callback.
    ``body.visible`` is set to ``initially_open`` as a side effect.
    """
    arrow = "▼ " if initially_open else "▶ "
    toggle = Toggle(
        label=arrow + title,
        active=initially_open,
        button_type="light",
        width=500,
        height=30,
    )
    body.visible = initially_open

    sync_visibility = CustomJS(args=dict(body=body, btn=toggle, title=title), code="""
        body.visible = btn.active;
        btn.label = (btn.active ? '▼ ' : '▶ ') + title;
    """)
    toggle.js_on_click(sync_visibility)
    return column(toggle, body)
|
|
|
|
# Patch explorer: controls, the clickable image grid, status line, results table.
patch_explorer_panel = column(
    row(patch_img_input, load_patch_btn, clear_patch_btn),
    patch_fig,
    patch_info_div,
    patch_feat_table,
)

# Right-hand sections are collapsible and start closed (initially_open defaults to False).
summary_section = _make_collapsible("SAE Summary", summary_div)
patch_section = _make_collapsible("Patch Explorer", patch_explorer_panel)
clip_section = _make_collapsible("CLIP Text Search", clip_search_panel)

# The dataset selector is shown only when more than one dataset is available.
_ds_select_row = ([dataset_select] if len(_all_datasets) > 1 else [])
left_panel = column(*_ds_select_row, controls, umap_fig, feature_list_panel)

# Middle column: status and stats, naming, view controls, then the image panels.
middle_panel = column(
    status_div,
    stats_div,
    name_panel,
    row(view_select,
        column(Div(text="<b>Images:</b>", width=60, height=15, styles={"padding-top":"5px"}),
               nsd_subset_toggle),
        column(zoom_slider, heatmap_alpha_slider)),
    compare_agg_div,
    top_heatmap_div,
    mean_heatmap_div,
    brain_div,
)

# The DynaDiff section starts open when available; otherwise a 1px empty Div
# keeps the column() call below uniform.
dd_section = (
    _make_collapsible("DynaDiff Brain Steering", _dd_panel_body, initially_open=True)
    if HAS_DYNADIFF else Div(text="", width=1)
)

right_panel = column(summary_section, patch_section, clip_section, dd_section)

# Three-column page layout, attached to the current Bokeh document.
layout = row(left_panel, middle_panel, right_panel)
curdoc().add_root(layout)
curdoc().title = "SAE Feature Explorer"

print("Explorer app ready!")
|
|
| |
# When an SAE path was given, call _get_gpu_runner() on a daemon thread right
# away (presumably so the first patch-explorer request does not pay the load
# cost — confirm against _get_gpu_runner). A failure is only logged.
if args.sae_path:
    def _warmup_gpu():
        """Call the GPU runner getter once, logging instead of raising on failure."""
        try:
            _get_gpu_runner()
        except Exception as _e:
            print(f"[GPU runner] Warmup failed: {_e}")
    threading.Thread(target=_warmup_gpu, daemon=True).start()
|
|