Spaces:
Sleeping
Sleeping
Marlin Lee
Fix thread-safety race in brain surface interpolation; bake DINOv2 into Docker image
e311ea3 | """ | |
| Brain-alignment data: Phi matrices, voxel coordinates, DynaDiff loader. | |
| All data is loaded once at module import time from --phi-dir / --dynadiff-dir. | |
| Public flags HAS_PHI and HAS_DYNADIFF tell panels what is available. | |
| """ | |
| import base64 | |
| import io | |
| import os | |
| import sys | |
| import threading | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from .args import args | |
# ---------- Nilearn surface rendering (TribeV2-style) ----------
# Optional rendering stack.  When nibabel / nilearn / scipy cannot be
# imported, _NILEARN_AVAILABLE stays False, the surface renderer returns
# None, and the callers below fall back to plain matplotlib scatter plots.
_NILEARN_AVAILABLE = False
_fsavg5 = None        # cached nilearn fsaverage5 surface bundle
_fsavg5_tree = None   # cached KDTree over the voxel coordinates (built in _ensure_surf_interp)
_fsavg5_pials = None  # cached (pial_left_xyz, pial_right_xyz) vertex arrays
try:
    import nibabel as _nib
    from nilearn.datasets import fetch_surf_fsaverage as _fetch_surf_fsaverage
    from nilearn.plotting import plot_surf_stat_map as _plot_surf_stat_map
    from scipy.spatial import cKDTree as _cKDTree
    _NILEARN_AVAILABLE = True
except ImportError:
    # Deliberate best-effort: surface rendering is an optional feature.
    pass
_fsavg5_lock = threading.Lock()  # serialises the one-time fsaverage5 load


def _get_fsavg5():
    """Return the cached nilearn fsaverage5 bundle, loading it on first use.

    The first call fetches fsaverage5 and caches the left/right pial vertex
    coordinates in ``_fsavg5_pials``.  ``_fsavg5`` is assigned only AFTER
    ``_fsavg5_pials`` is populated, so any caller that sees a non-None
    ``_fsavg5`` can rely on the pial arrays being present.  Previously
    ``_fsavg5`` was set first, so a concurrent caller could return early and
    then iterate ``_fsavg5_pials`` while it was still None.  The load itself
    runs under a lock so concurrent first callers fetch only once; if the
    fetch raises, nothing is cached and the next call retries.
    """
    global _fsavg5, _fsavg5_pials, _fsavg5_tree
    fs = _fsavg5
    if fs is not None:
        return fs
    with _fsavg5_lock:
        if _fsavg5 is None:
            fs = _fetch_surf_fsaverage('fsaverage5')
            pial_left = _nib.load(fs['pial_left']).darrays[0].data
            pial_right = _nib.load(fs['pial_right']).darrays[0].data
            # Order matters: fill the dependent caches before the sentinel
            # that other threads test without taking the lock.
            _fsavg5_pials = (pial_left, pial_right)
            _fsavg5_tree = None
            _fsavg5 = fs
        return _fsavg5
# Cached interpolation plan: one (mask, idxs, weights, n_verts) tuple per
# hemisphere, left first.  Built once by _ensure_surf_interp.
_surf_interp = None
_surf_interp_lock = threading.Lock()


def _ensure_surf_interp(coords: np.ndarray, max_dist: float = 12.0, k: int = 5):
    """Build and cache the KDTree + IDW weights for voxel-to-surface projection.

    For every fsaverage5 pial vertex (both hemispheres) the ``k`` nearest
    points of ``coords`` are looked up.  Vertices whose nearest voxel is
    farther than ``max_dist`` are masked out.  The remaining vertices get
    inverse-distance weights normalised to sum to 1 per vertex; a zero
    distance is replaced by 1e-10 to avoid division by zero.

    The plan is built once per process.  The ``coords``, ``max_dist`` and
    ``k`` of the FIRST call are used; later calls return immediately even
    when they pass different values.

    ``k`` is clamped to ``[1, len(coords)]``.  ``k == 1`` would otherwise
    make ``cKDTree.query`` return 1-D arrays, and ``k > len(coords)`` would
    yield missing-neighbour indices equal to ``len(coords)``, which are out
    of range for the values array.
    """
    global _surf_interp, _fsavg5_tree
    if _surf_interp is not None:
        return
    with _surf_interp_lock:
        if _surf_interp is not None:
            return
        _get_fsavg5()
        tree = _cKDTree(coords)
        _fsavg5_tree = tree
        n_neighbours = max(1, min(int(k), tree.n))
        plan = []
        for pial_xyz in _fsavg5_pials:
            dists, idxs = tree.query(pial_xyz, k=n_neighbours, workers=-1)
            # Normalise to (n_verts, n_neighbours) regardless of k.
            dists = np.asarray(dists).reshape(len(pial_xyz), -1)
            idxs = np.asarray(idxs).reshape(len(pial_xyz), -1)
            mask = dists[:, 0] <= max_dist
            d = np.where(dists[mask] == 0, 1e-10, dists[mask])
            w = 1.0 / d
            w /= w.sum(axis=1, keepdims=True)
            plan.append((mask, idxs[mask], w, pial_xyz.shape[0]))
        # Publish only the fully built plan; readers check it without the lock.
        _surf_interp = plan
def _voxels_to_surface(values: np.ndarray, coords: np.ndarray,
                       max_dist: float = 12.0, k: int = 5):
    """Project per-voxel ``values`` onto the fsaverage5 pial surfaces.

    Relies on the cached inverse-distance-weighting plan from
    ``_ensure_surf_interp`` (built from ``coords``/``max_dist``/``k`` on first
    use).  Returns ``(tex_left, tex_right)``: one float32 array per
    hemisphere with one entry per pial vertex, NaN where no voxel lies
    within ``max_dist`` mm.
    """
    _ensure_surf_interp(coords, max_dist, k)

    def _project(mask, idxs, weights, n_verts):
        texture = np.full(n_verts, np.nan, dtype=np.float32)
        if mask.any():
            texture[mask] = np.sum(weights * values[idxs], axis=1)
        return texture

    per_hemi = [_project(*hemi_plan) for hemi_plan in _surf_interp]
    return per_hemi[0], per_hemi[1]
def _render_brain_surface_b64(values: np.ndarray, title: str = '',
                              compact: bool = False, cbar_label: str = '',
                              figsize=(12, 3.5), dpi=80) -> str | None:
    """Render voxel values on the fsaverage5 cortical surface as a base64 PNG.

    Values are projected onto pial vertices with KDTree IDW
    (``_voxels_to_surface``) and drawn with nilearn's ``plot_surf_stat_map``
    on the inflated fsaverage5 mesh.  The colour scale is a symmetric
    ``RdBu_r`` scale clipped at the 98th percentile of ``|values|``.

    Args:
        values: per-voxel values, aligned with ``_voxel_coords``.
        title: optional figure title.
        compact: True -> one left-hemisphere view at (0, -135) in a fixed
            3.5x2.8 figure (``figsize`` is ignored); False -> four views, two
            per hemisphere, in a ``figsize`` figure.
        cbar_label: optional colorbar label.
        figsize: figure size for the four-view layout.
        dpi: output resolution.

    Returns:
        Base64-encoded PNG, or None if nilearn is unavailable or voxel
        coordinates were not loaded.
    """
    if not _NILEARN_AVAILABLE or _voxel_coords is None:
        return None
    fs = _get_fsavg5()
    tex_l, tex_r = _voxels_to_surface(values, _voxel_coords)
    # NaN is truthy, so a bare `or 1e-6` fallback would let an all-NaN map
    # produce NaN colour limits.  Catch zero and non-finite values explicitly.
    vmax = float(np.nanpercentile(np.abs(values), 98))
    if not np.isfinite(vmax) or vmax == 0.0:
        vmax = 1e-6
    kwargs = dict(cmap='RdBu_r', colorbar=False, vmin=-vmax, vmax=vmax,
                  bg_on_data=True)
    fig = None
    # try/finally: pyplot keeps every figure it creates alive until closed,
    # so a failure in nilearn or savefig must not leak the figure.
    try:
        if compact:
            fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.8),
                                   subplot_kw={'projection': '3d'},
                                   facecolor='#f8f8f8')
            _plot_surf_stat_map(surf_mesh=fs['infl_left'], stat_map=tex_l,
                                bg_map=fs['sulc_left'], hemi='left',
                                view=(0, -135), axes=ax, figure=fig, **kwargs)
            ax.set_box_aspect(None, zoom=1.4)
        else:
            # (texture, inflated-mesh key, sulcal-depth key, hemi, view angles)
            views = [
                (tex_l, 'infl_left', 'sulc_left', 'left', (0, -135)),
                (tex_l, 'infl_left', 'sulc_left', 'left', (0, 0)),
                (tex_r, 'infl_right', 'sulc_right', 'right', (0, 180)),
                (tex_r, 'infl_right', 'sulc_right', 'right', (0, -45)),
            ]
            fig, axes = plt.subplots(
                1, 4, figsize=figsize, facecolor='#f8f8f8',
                subplot_kw={'projection': '3d'},
                gridspec_kw={'wspace': -0.1, 'hspace': 0},
            )
            for ax, (tex, infl_k, sulc_k, hemi, view) in zip(axes, views):
                _plot_surf_stat_map(surf_mesh=fs[infl_k], stat_map=tex,
                                    bg_map=fs[sulc_k], hemi=hemi, view=view,
                                    axes=ax, figure=fig, **kwargs)
                ax.set_box_aspect(None, zoom=1.4)
        # Panels are drawn with colorbar=False, so add one shared colorbar.
        # NOTE(review): the original indentation was lost; this assumes the
        # colorbar and title apply to both layouts -- confirm.
        sm = plt.cm.ScalarMappable(cmap='RdBu_r',
                                   norm=plt.Normalize(vmin=-vmax, vmax=vmax))
        sm.set_array([])
        cbar_ax = fig.add_axes([0.92, 0.2, 0.015, 0.6])
        cbar = fig.colorbar(sm, cax=cbar_ax)
        if cbar_label:
            cbar.set_label(cbar_label, fontsize=9)
        if title:
            fig.suptitle(title, fontsize=10)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight',
                    facecolor='#f8f8f8')
    finally:
        if fig is not None:
            plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')
# ---------- Phi (brain alignment) ----------
_phi_cv = None           # (C, V) concept-by-voxel (or concept-by-vertex) matrix, memory-mapped
_phi_c = None            # (C,) per-concept cortical leverage scores
_voxel_coords = None     # (V, 3) MNI voxel coordinates
_voxel_to_vertex = None  # (V,) per-voxel index into Phi's vertex axis (surface-space Phi only)
_N_VOXELS_DD = 15724     # Phi column count identifying voxel-space Phi
_N_VERTS_FSAVG = 37984   # Phi column count identifying surface-space Phi
| def _pick_best_file(candidates: list, model_key: str, search_dir: str) -> str | None: | |
| """Prefer model_key substring match; fall back to largest file.""" | |
| if not candidates: | |
| return None | |
| if model_key: | |
| matched = [f for f in candidates if model_key in f.lower()] | |
| if matched: | |
| return sorted(matched)[0] | |
| print(f"[Phi] WARNING: --phi-model '{model_key}' matched nothing in {candidates}; " | |
| "falling back to largest file") | |
| return max(candidates, key=lambda f: os.path.getsize(os.path.join(search_dir, f))) | |
# Load Phi artefacts from --phi-dir (if given).  This runs once at import
# time and populates the module-level Phi caches.
if args.phi_dir and os.path.isdir(args.phi_dir):
    _pdir = args.phi_dir
    _model_key = (args.phi_model or "").lower()
    # Phi_cv matrix.  Its column count tells voxel-space and surface-space
    # Phi apart.
    _phi_mat_files = [f for f in os.listdir(_pdir)
                      if f.lower().startswith('phi_cv') and f.endswith('.npy')]
    _phi_pick = _pick_best_file(_phi_mat_files, _model_key, _pdir)
    if _phi_pick:
        _phi_path = os.path.join(_pdir, _phi_pick)
        # Memory-mapped: rows are only read from disk when indexed.
        _phi_cv = np.load(_phi_path, mmap_mode='r')
        print(f"[Phi] Loaded {_phi_pick}: shape {_phi_cv.shape}, dtype {_phi_cv.dtype}")
        if _phi_cv.shape[1] == _N_VERTS_FSAVG:
            _v2v_path = os.path.join(_pdir, 'voxel_to_vertex_map.npy')
            if os.path.exists(_v2v_path):
                _voxel_to_vertex = np.load(_v2v_path)
                print(f"[Phi] Surface-space phi; loaded voxel_to_vertex_map: "
                      f"{_voxel_to_vertex.shape}")
            else:
                print("[Phi] WARNING: surface-space phi but voxel_to_vertex_map.npy not found")
        elif _phi_cv.shape[1] == _N_VOXELS_DD:
            print("[Phi] Voxel-space phi detected.")
        else:
            print(f"[Phi] WARNING: unexpected phi dimension {_phi_cv.shape[1]}")
    else:
        print(f"[Phi] WARNING: no Phi_cv_*.npy found in {_pdir}")
    # phi_c leverage scores (phi_c*.npy, excluding the phi_cv* matrices)
    _phi_c_files = [f for f in os.listdir(_pdir)
                    if f.lower().startswith('phi_c')
                    and not f.lower().startswith('phi_cv')
                    and f.endswith('.npy')]
    _phi_c_pick = _pick_best_file(_phi_c_files, _model_key, _pdir)
    if _phi_c_pick:
        _phi_c = np.load(os.path.join(_pdir, _phi_c_pick))
        print(f"[Phi] Leverage scores {_phi_c_pick}: shape {_phi_c.shape}, "
              f"range [{_phi_c.min():.4f}, {_phi_c.max():.4f}]")
    else:
        print(f"[Phi] No phi_c_*.npy found in {_pdir} — leverage scores unavailable")
    # Voxel coordinates
    _coords_path = os.path.join(_pdir, 'voxel_coords.npy')
    if os.path.exists(_coords_path):
        _voxel_coords = np.load(_coords_path)
        print(f"[Phi] Voxel coordinates: {_voxel_coords.shape}")
    else:
        print("[Phi] voxel_coords.npy not found — cortical scatter unavailable")
elif args.phi_dir:
    # A configured-but-missing directory used to be ignored silently.  It then
    # only showed up indirectly, e.g. as DynaDiff's "--phi-dir not set" warning.
    print(f"[Phi] WARNING: --phi-dir '{args.phi_dir}' is not a directory — "
          "Phi data unavailable")
HAS_PHI = _phi_cv is not None
# ---------- DynaDiff ----------
# Exactly one loader mode is chosen at import time: Modal HTTP endpoint if
# --dynadiff-modal-url is given, otherwise in-process from --dynadiff-dir.
# On any failure _dd_loader stays None and HAS_DYNADIFF stays False, so the
# helpers below (which test `_dd_loader is None`) agree with HAS_DYNADIFF.
_dd_loader = None
HAS_DYNADIFF = False
_scripts_dir = os.path.dirname(os.path.abspath(__file__)) + '/..'
if args.dynadiff_modal_url:
    # ── Modal HTTP mode (no local GPU needed) ────────────────────────────────
    if not HAS_PHI:
        print("[DynaDiff] WARNING: --phi-dir not set; steering panel requires Phi data. "
              "Disabling.")
    else:
        try:
            sys.path.insert(0, _scripts_dir)
            from dynadiff_loader import HTTPDynaDiffLoader
            _token = (args.dynadiff_modal_token
                      or os.environ.get("DYNADIFF_MODAL_TOKEN", ""))
            _dd_loader = HTTPDynaDiffLoader(
                url=args.dynadiff_modal_url,
                token=_token,
            )
            _dd_loader.start()
            HAS_DYNADIFF = True
            print(f"[DynaDiff] Modal endpoint: {args.dynadiff_modal_url}")
        except Exception as err:
            # start() may fail after the loader object was assigned.  Drop it
            # so get_dd_fmri / apply_steering_fmri do not use a dead loader.
            _dd_loader = None
            HAS_DYNADIFF = False
            print(f"[DynaDiff] WARNING: Could not init Modal loader ({err}). "
                  "Steering panel will be disabled.")
elif args.dynadiff_dir and os.path.isdir(args.dynadiff_dir):
    # ── In-process mode (original, requires local GPU + dynadiff repo) ───────
    if not HAS_PHI:
        print("[DynaDiff] WARNING: --phi-dir not set; steering panel requires Phi data. "
              "Disabling.")
    else:
        try:
            sys.path.insert(0, _scripts_dir)
            from dynadiff_loader import get_loader
            # A relative H5 path is resolved against the DynaDiff directory.
            _h5 = args.dynadiff_h5
            if not os.path.isabs(_h5):
                _h5 = os.path.join(args.dynadiff_dir, _h5)
            _dd_loader = get_loader(
                dynadiff_dir=args.dynadiff_dir,
                checkpoint=args.dynadiff_checkpoint,
                h5_path=_h5,
                nsd_thumb_dir=args.brain_thumbnails,
                subject_idx=0,
            )
            HAS_DYNADIFF = True
            print(f"[DynaDiff] In-process loader ready "
                  f"(checkpoint: {args.dynadiff_checkpoint})")
        except Exception as err:
            _dd_loader = None
            HAS_DYNADIFF = False
            print(f"[DynaDiff] WARNING: Could not start loader ({err}). "
                  "Steering panel will be disabled.")
elif args.dynadiff_dir:
    # Previously a mistyped directory disabled DynaDiff without any message.
    print(f"[DynaDiff] WARNING: --dynadiff-dir '{args.dynadiff_dir}' is not a directory. "
          "Steering panel will be disabled.")
| # ---------- Per-feature helpers ---------- | |
| def phi_cv_shape() -> tuple | None: | |
| """Return (_phi_cv.shape[0], _phi_cv.shape[1]) or None if not loaded.""" | |
| return _phi_cv.shape if _phi_cv is not None else None | |
| def phi_c_for_feat(feat: int) -> float | None: | |
| """Cortical leverage score for a feature, or None.""" | |
| if _phi_c is None or feat >= len(_phi_c): | |
| return None | |
| return float(_phi_c[feat]) | |
| def phi_voxel_row(feat: int) -> np.ndarray | None: | |
| """Return the phi row in voxel space (15724,) float32, or None.""" | |
| if _phi_cv is None or feat >= _phi_cv.shape[0]: | |
| return None | |
| row = np.array(_phi_cv[feat], dtype=np.float32) | |
| if _voxel_to_vertex is not None: | |
| return row[_voxel_to_vertex] | |
| return row | |
def phi_c_vals(indices) -> list:
    """phi_c leverage values for ``indices``, with 0.0 for unavailable entries.

    An entry is unavailable when leverage scores are not loaded or its index
    lies outside ``[0, len(phi_c))``.  Negative indices no longer wrap around
    to the end of the array.
    """
    if _phi_c is None:
        return [0.0] * len(indices)
    n_scores = len(_phi_c)
    return [float(_phi_c[i]) if 0 <= i < n_scores else 0.0 for i in indices]
| def feat_display_name(feat: int | None) -> str: | |
| """Best-effort display name for DynaDiff feature table.""" | |
| if feat is None: | |
| return 'unknown' | |
| from .state import active_ds | |
| ds = active_ds() | |
| return ds['feature_names'].get(feat) or f'feat {feat}' | |
def dynadiff_request(sample_idx: int, steerings: list, seed: int) -> dict:
    """Run a DynaDiff reconstruction of ``sample_idx`` with ``steerings``.

    Returns whatever the loader's ``reconstruct`` returns.

    Raises:
        RuntimeError: if no DynaDiff loader is configured, the model is still
            loading, or the model failed to load.
    """
    if _dd_loader is None:
        # Previously this surfaced as AttributeError on None.  Raise the
        # documented RuntimeError so callers' existing handling covers it.
        raise RuntimeError('DynaDiff is not configured')
    status, err = _dd_loader.status
    if status == 'loading':
        raise RuntimeError('DynaDiff model still loading — try again shortly')
    if status == 'error':
        raise RuntimeError(f'DynaDiff model load failed: {err}')
    return _dd_loader.reconstruct(sample_idx, steerings, seed)
| # ---------- Rendering helpers ---------- | |
def _render_phi_map_b64_compact(feat: int, figsize=(3.5, 2.8), dpi=70) -> str | None:
    """Small single-view image of feature ``feat``'s Phi map for a steering card.

    Returns the active dataset's cached image (``phi_map_cache``) when
    present.  Otherwise it renders the compact fsaverage5 surface view and
    falls back to an axial (x-y) scatter of the voxel coordinates when
    surface rendering is unavailable.  Returns None when there is no Phi row
    for ``feat`` or nothing can be drawn.
    """
    from .state import active_ds
    cached = active_ds().get('phi_map_cache', {}).get(feat)
    if cached is not None:
        return cached
    phi_vox = phi_voxel_row(feat)
    if phi_vox is None:
        return None
    b64 = _render_brain_surface_b64(phi_vox, compact=True, dpi=dpi)
    if b64 is not None:
        return b64
    # Fallback: axial scatter
    if _voxel_coords is None:
        return None
    # Symmetric colour limits.  NaN is truthy, so `or 1e-6` alone would not
    # catch NaN rows; nanmax also ignores isolated NaNs.
    vmax = float(np.nanmax(np.abs(phi_vox)))
    if not np.isfinite(vmax) or vmax == 0.0:
        vmax = 1e-6
    fig, ax = plt.subplots(1, 1, figsize=figsize, facecolor='#f8f8f8')
    try:
        ax.scatter(_voxel_coords[:, 0], _voxel_coords[:, 1],
                   c=phi_vox, cmap='RdBu_r', s=3, alpha=0.8,
                   vmin=-vmax, vmax=vmax, rasterized=True, marker='s')
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor('#f8f8f8')
        fig.tight_layout(pad=0.2)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight', facecolor='#f8f8f8')
    finally:
        # Release the pyplot figure even if drawing or saving fails.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')
def _render_cortical_profile_b64(feat: int) -> str | None:
    """Base64 PNG of feature ``feat``'s Phi map for the cortical-profile panel.

    Returns the active dataset's cached image (``cortical_profile_cache``)
    when present.  Otherwise it renders the four-view fsaverage5 surface
    figure, titled with the feature id and its phi_c score when available.
    It falls back to side-by-side axial (x-y) and coronal (x-z) scatters of
    the voxel coordinates when surface rendering is unavailable.  Returns
    None when there is no Phi row for ``feat`` or nothing can be drawn.
    """
    from .state import active_ds
    cached = active_ds().get('cortical_profile_cache', {}).get(feat)
    if cached is not None:
        return cached
    phi_vox = phi_voxel_row(feat)
    if phi_vox is None:
        return None
    phi_c_val = phi_c_for_feat(feat)
    phi_c_str = f' (φ_c = {phi_c_val:.4f})' if phi_c_val is not None else ''
    title_str = f'Cortical Profile — Feature {feat}{phi_c_str}'
    b64 = _render_brain_surface_b64(phi_vox, title=title_str,
                                    cbar_label='Φ weight', dpi=90)
    if b64 is not None:
        return b64
    # Fallback: 2-view axial/coronal scatter
    if _voxel_coords is None:
        return None
    # Symmetric colour limits.  NaN is truthy, so `or 1e-6` alone would not
    # catch NaN rows; nanmax also ignores isolated NaNs.
    vmax = float(np.nanmax(np.abs(phi_vox)))
    if not np.isfinite(vmax) or vmax == 0.0:
        vmax = 1e-6
    fig, axes = plt.subplots(1, 2, figsize=(10, 4.0), facecolor='#f8f8f8')
    try:
        for ax, (t, xi, yi) in zip(axes, [("Axial (x–y)", 0, 1),
                                          ("Coronal (x–z)", 0, 2)]):
            sc = ax.scatter(_voxel_coords[:, xi], _voxel_coords[:, yi],
                            c=phi_vox, cmap='RdBu_r', s=4, alpha=0.75,
                            vmin=-vmax, vmax=vmax, rasterized=True, marker='s')
            ax.set_title(t, fontsize=10)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor('#f8f8f8')
        fig.subplots_adjust(right=0.88, top=0.88)
        cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.65])
        fig.colorbar(sc, cax=cbar_ax).set_label('Φ weight', fontsize=9)
        fig.suptitle(title_str, fontsize=11)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=90, bbox_inches='tight', facecolor='#f8f8f8')
    finally:
        # Release the pyplot figure even if drawing or saving fails.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')
| def get_dd_fmri(sample_idx: int) -> np.ndarray | None: | |
| """Return raw fMRI (N_VOXELS,) for a DynaDiff sample index, or None.""" | |
| if _dd_loader is None: | |
| return None | |
| try: | |
| return _dd_loader.get_fmri(sample_idx) | |
| except Exception: | |
| return None | |
def apply_steering_fmri(fmri: np.ndarray, steerings: list) -> np.ndarray:
    """Return a steered copy of ``fmri``; the input array is not modified.

    Each steering is a ``(phi_voxel, lam, threshold)`` tuple:
      * ``phi_voxel`` - per-voxel Phi weights.  Entries that are None, or
        whose max ``|phi|`` is below 1e-12, are skipped.
      * ``lam`` - strength.  The perturbation is
        ``lam * beta_std / max|phi| * phi``.
      * ``threshold`` - fraction of voxels, largest ``|phi|`` first, that
        receive the perturbation; values >= 1.0 perturb every voxel.

    If no DynaDiff loader is configured, or its ``beta_std`` is None,
    ``fmri`` itself is returned unchanged.  Non-float inputs are promoted to
    float64; previously ``+=`` with a float perturbation raised a casting
    error for integer arrays.
    """
    if _dd_loader is None:
        return fmri
    beta_std = _dd_loader.beta_std
    if beta_std is None:
        return fmri
    result = np.array(fmri, copy=True)
    if not np.issubdtype(result.dtype, np.floating):
        result = result.astype(np.float64)
    for phi_voxel, lam, thr in steerings:
        if phi_voxel is None:
            continue
        abs_phi = np.abs(phi_voxel)
        phi_max = float(abs_phi.max())
        if phi_max < 1e-12:
            continue
        # Normalise Phi to unit peak, then scale to the loader's beta_std.
        scale = beta_std / phi_max
        if thr < 1.0:
            # Keep roughly the top `thr` fraction of voxels by |phi|.
            cutoff = float(np.percentile(abs_phi, 100.0 * (1.0 - thr)))
            mask = abs_phi >= cutoff
        else:
            mask = np.ones(len(phi_voxel), dtype=bool)
        perturb = lam * scale * phi_voxel
        perturb[~mask] = 0.0
        result += perturb
    return result
| def render_fmri_brain_compact_b64(fmri_voxels: np.ndarray, | |
| title: str = '') -> str | None: | |
| """Compact left-lateral surface view of fMRI voxel activity, returns base64 PNG.""" | |
| if fmri_voxels is None or _voxel_coords is None: | |
| return None | |
| while fmri_voxels.ndim > 1: | |
| fmri_voxels = fmri_voxels.mean(axis=-1) | |
| b64 = _render_brain_surface_b64(fmri_voxels, title=title, compact=True, dpi=70) | |
| if b64 is not None: | |
| return b64 | |
| # Fallback: axial scatter | |
| vmax = float(np.abs(fmri_voxels).max()) or 1e-6 | |
| fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.8), facecolor='#f8f8f8') | |
| ax.scatter(_voxel_coords[:, 0], _voxel_coords[:, 1], | |
| c=fmri_voxels, cmap='RdBu_r', s=3, alpha=0.8, | |
| vmin=-vmax, vmax=vmax, rasterized=True, marker='s') | |
| ax.set_aspect('equal'); ax.set_xticks([]); ax.set_yticks([]) | |
| ax.set_facecolor('#f8f8f8') | |
| if title: | |
| ax.set_title(title, fontsize=9) | |
| fig.tight_layout(pad=0.2) | |
| buf = io.BytesIO() | |
| fig.savefig(buf, format='png', dpi=70, bbox_inches='tight', facecolor='#f8f8f8') | |
| plt.close(fig) | |
| return base64.b64encode(buf.getvalue()).decode('utf-8') | |
def render_cortical_profile(feat: int) -> str:
    """HTML snippet (heading plus inline PNG) showing feature ``feat``'s Phi map.

    The image comes from ``_render_cortical_profile_b64``, which is a
    cached image, a surface rendering, or a scatter fallback.  Returns an
    empty string when no image can be produced.
    """
    img_b64 = _render_cortical_profile_b64(feat)
    if img_b64 is None:
        return ""
    heading = (
        '<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;'
        'padding-bottom:4px">Cortical Profile (Φ)</h3>'
    )
    image = (
        f'<img src="data:image/png;base64,{img_b64}" '
        'style="max-width:100%;border-radius:4px;border:1px solid #ddd"/>'
    )
    return heading + image