Spaces:
Sleeping
Sleeping
| """ | |
| DynaDiff in-process loader. | |
| Loads the DynaDiff model and exposes a reconstruct() method that returns the | |
| same dict format as the HTTP server's /reconstruct endpoint: | |
| { | |
| "baseline_img": "<base64 PNG>", | |
| "steered_img": "<base64 PNG>", | |
| "gt_img": "<base64 PNG> | None", | |
| "beta_std": float, | |
| } | |
| Usage (in explorer_app.py): | |
| from dynadiff_loader import DynaDiffLoader | |
| loader = DynaDiffLoader(dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir) | |
| loader.start() # begins background model load | |
loader.n_samples()  # None until ready
loader.is_ready()   # True when model is loaded
| result = loader.reconstruct(sample_idx, steerings, seed) | |
| """ | |
| import base64 | |
| import io | |
| import logging | |
| import os | |
| import threading | |
| import numpy as np | |
# Configure process-wide logging once at import time so every Bokeh session
# that re-imports this module shares the same handler/format.
logging.basicConfig(
    level=logging.INFO,
    format='[DynaDiff %(levelname)s %(asctime)s] %(message)s',
    datefmt='%H:%M:%S',
)
log = logging.getLogger(__name__)

# Expected number of voxels per fMRI sample; used to size the model's input
# channels and the steering mask. Presumably the NSD ROI size — TODO confirm.
N_VOXELS = 15724
# -- Process-level singleton ---------------------------------------------------
# Bokeh re-executes the app script per session, so DynaDiffLoader would be
# instantiated multiple times. We keep one loader alive for the whole process
# so the model is loaded exactly once and all sessions share it.
_singleton: "DynaDiffLoader | None" = None
_singleton_lock = threading.Lock()  # guards lazy creation in get_loader()
def get_loader(dynadiff_dir, checkpoint, h5_path,
               nsd_thumb_dir=None, subject_idx=0) -> "DynaDiffLoader":
    """Return the process-level loader, creating and starting it on first call.

    Thread-safe: creation happens at most once, under _singleton_lock.
    Later calls ignore the arguments and return the existing loader.
    """
    global _singleton
    with _singleton_lock:
        if _singleton is not None:
            return _singleton
        loader = DynaDiffLoader(
            dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir, subject_idx)
        loader.start()
        _singleton = loader
        return _singleton
def _img_to_b64(img_np):
    """Encode an (H, W, 3) float32 image in [0, 1] as a base64 PNG string."""
    import matplotlib
    matplotlib.use('Agg')  # headless backend: must be set before pyplot import
    import matplotlib.pyplot as plt
    clipped = np.clip(img_np, 0, 1)
    buffer = io.BytesIO()
    plt.imsave(buffer, clipped, format='png')
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')
class DynaDiffLoader:
    """Loads the DynaDiff model in a background thread and runs reconstructions.

    Mutable state is guarded by self._lock because Bokeh sessions may query
    the loader from several threads while _load() is still running.
    """

    def __init__(self, dynadiff_dir, checkpoint, h5_path,
                 nsd_thumb_dir=None, subject_idx=0):
        """
        Args:
            dynadiff_dir: root of the DynaDiff checkout (made absolute here).
            checkpoint: .pth file or ZeRO checkpoint directory, absolute or
                relative to dynadiff_dir.
            h5_path: fMRI dataset h5 file, absolute or relative to dynadiff_dir.
            nsd_thumb_dir: optional directory of pre-rendered NSD thumbnails.
            subject_idx: subject whose trials are indexed by _load().
        """
        self.dynadiff_dir = os.path.abspath(dynadiff_dir)
        self.checkpoint = checkpoint
        # Resolve a relative h5 path against the dynadiff checkout.
        self.h5_path = h5_path if os.path.isabs(h5_path) \
            else os.path.join(self.dynadiff_dir, h5_path)
        self.nsd_thumb_dir = nsd_thumb_dir
        self.subject_idx = subject_idx
        # The fields below are populated by _load(); always read under _lock.
        self._model = None
        self._cfg = None
        self._beta_std = None
        self._subject_sample_indices = None
        self._nsd_to_sample = {}
        self._status = 'loading'  # 'loading' | 'ok' | 'error'
        self._error = ''
        self._lock = threading.Lock()
| # ββ public properties ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def is_ready(self): | |
| with self._lock: | |
| return self._status == 'ok' | |
| def status(self): | |
| with self._lock: | |
| return self._status, self._error | |
| def n_samples(self): | |
| with self._lock: | |
| idx = self._subject_sample_indices | |
| return len(idx) if idx is not None else None | |
| def sample_idxs_for_nsd_img(self, nsd_img_idx): | |
| """Return the list of sample_idx values that correspond to a given NSD image index. | |
| Returns an empty list if the image has no trials for this subject or the | |
| mapping is not yet built (model still loading). | |
| """ | |
| with self._lock: | |
| return list(self._nsd_to_sample.get(int(nsd_img_idx), [])) | |
    def start(self):
        """Start the background model-loading thread; returns immediately.

        The thread is a daemon so it never blocks interpreter shutdown.
        _load() reports its outcome via status()/is_ready().
        """
        t = threading.Thread(target=self._load, daemon=True)
        t.start()
| # ββ model loading ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _load(self): | |
| try: | |
| import sys | |
| import torch | |
| import h5py | |
| # Inject dynadiff paths before any imports from those packages | |
| dynadiff_diffusers = os.path.join(self.dynadiff_dir, 'diffusers', 'src') | |
| for p in [self.dynadiff_dir, dynadiff_diffusers]: | |
| if p not in sys.path: | |
| sys.path.insert(0, p) | |
| # Pre-import torchvision so it is fully initialised before dynadiff's | |
| # diffusers fork pulls it in. Without this, torchvision.transforms can | |
| # end up in a partially-initialised state, causing | |
| # "cannot import name 'InterpolationMode' from partially initialized | |
| # module 'torchvision.transforms'". | |
| import torchvision.transforms # noqa: F401 | |
| import torchvision.transforms.functional # noqa: F401 | |
| # Bokeh's code_runner does os.chdir(original_cwd) in its finally | |
| # block after every session's app script, so we cannot rely on cwd | |
| # being stable across the slow imports below. Build the config | |
| # entirely from absolute paths so no cwd dependency exists. | |
| orig_dir = os.getcwd() | |
| _vd_cache = os.path.join(self.dynadiff_dir, 'versatile_diffusion') | |
| _cache_dir = os.path.join(self.dynadiff_dir, 'cache') | |
| _local_infra = {'cluster': None, 'folder': _cache_dir} | |
| print('[DynaDiff] importing dynadiff modules...', flush=True) | |
| from exca import ConfDict | |
| print('[DynaDiff] exca imported', flush=True) | |
| _cfg_yaml = os.path.join(self.dynadiff_dir, 'config', 'config.yaml') | |
| with open(_cfg_yaml, 'r') as f: | |
| cfg = ConfDict.from_yaml(f) | |
| cfg['versatilediffusion_config.vd_cache_dir'] = _vd_cache | |
| cfg['seed'] = 42 | |
| cfg['data.nsd_dataset_config.seed'] = 42 | |
| cfg['data.nsd_dataset_config.averaged'] = False | |
| cfg['data.nsd_dataset_config.subject_ids'] = [0] | |
| cfg['infra'] = _local_infra | |
| cfg['data.nsd_dataset_config.infra'] = _local_infra | |
| cfg['image_generation_infra'] = _local_infra | |
| print('[DynaDiff] config loaded', flush=True) | |
| vd_cfg = cfg['versatilediffusion_config'] | |
| from model.models import VersatileDiffusion, VersatileDiffusionConfig | |
| print('[DynaDiff] model.models imported', flush=True) | |
| vd_config = VersatileDiffusionConfig(**vd_cfg) | |
| print('[DynaDiff] VersatileDiffusionConfig built', flush=True) | |
| # Resolve checkpoint | |
| ckpt = self.checkpoint | |
| if not os.path.isabs(ckpt): | |
| candidate_pth = os.path.join(self.dynadiff_dir, ckpt) | |
| candidate_ckpt = os.path.join(self.dynadiff_dir, | |
| 'training_checkpoints', ckpt) | |
| if os.path.isfile(candidate_pth): | |
| ckpt = candidate_pth | |
| elif os.path.isdir(candidate_ckpt): | |
| ckpt = candidate_ckpt | |
| else: | |
| raise FileNotFoundError( | |
| f'Checkpoint not found: tried {candidate_pth} ' | |
| f'and {candidate_ckpt}') | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| model_args = dict(config=vd_config, | |
| brain_n_in_channels=N_VOXELS, brain_temp_dim=6) | |
| model = VersatileDiffusion(**model_args) | |
| if os.path.isfile(ckpt): | |
| log.info(f'[DynaDiff] Loading state dict from {ckpt} ...') | |
| sd = torch.load(ckpt, map_location=device, weights_only=False) | |
| if any(k.startswith('model.') for k in sd): | |
| sd = {(k[6:] if k.startswith('model.') else k): v | |
| for k, v in sd.items()} | |
| drop = ('eval_fid', 'eval_inceptionlastconv', | |
| 'eval_eff', 'eval_swav', 'eval_lpips') | |
| sd = {k: v for k, v in sd.items() | |
| if not any(k.startswith(p) for p in drop)} | |
| model.load_state_dict(sd, strict=False) | |
| elif os.path.isdir(ckpt): | |
| import deepspeed | |
| log.info(f'[DynaDiff] Consolidating ZeRO checkpoint from {ckpt} ...') | |
| sd = deepspeed.utils.zero_to_fp32 \ | |
| .get_fp32_state_dict_from_zero_checkpoint( | |
| checkpoint_dir=ckpt, tag='checkpoint', | |
| exclude_frozen_parameters=False) | |
| sd = {(k[6:] if k.startswith('model.') else k): v | |
| for k, v in sd.items()} | |
| drop = ('eval_fid', 'eval_inceptionlastconv', | |
| 'eval_eff', 'eval_swav', 'eval_lpips') | |
| sd = {k: v for k, v in sd.items() | |
| if not any(k.startswith(p) for p in drop)} | |
| model.load_state_dict(sd, strict=False) | |
| else: | |
| raise FileNotFoundError(f'Checkpoint not found: {ckpt}') | |
| model.sanity_check_blurry = False | |
| model.to(device) | |
| model.eval() | |
| log.info(f'[DynaDiff] Model loaded on {device}') | |
| # Beta std | |
| log.info(f'[DynaDiff] Computing beta_std from {self.h5_path} ...') | |
| with h5py.File(self.h5_path, 'r') as hf: | |
| n = min(300, hf['fmri'].shape[0]) | |
| beta_std = float(np.array(hf['fmri'][:n]).std(axis=0).mean()) | |
| log.info(f'[DynaDiff] beta_std = {beta_std:.5f}') | |
| # Subject sample index mapping | |
| log.info(f'[DynaDiff] Building sample index for subject {self.subject_idx} ...') | |
| with h5py.File(self.h5_path, 'r') as hf: | |
| all_subj = np.array(hf['subject_idx'][:], dtype=np.int64) | |
| all_imgidx = np.array(hf['image_idx'][:], dtype=np.int64) | |
| sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64) | |
| log.info(f'[DynaDiff] {len(sample_indices)} samples for subject {self.subject_idx}') | |
| # Build reverse map: NSD image index β list of sample_idx values | |
| nsd_to_sample: dict[int, list[int]] = {} | |
| for sample_idx_val, h5_row in enumerate(sample_indices): | |
| nsd_img = int(all_imgidx[h5_row]) | |
| nsd_to_sample.setdefault(nsd_img, []).append(sample_idx_val) | |
| with self._lock: | |
| self._model = model | |
| self._cfg = cfg | |
| self._beta_std = beta_std | |
| self._subject_sample_indices = sample_indices | |
| self._nsd_to_sample = nsd_to_sample | |
| self._status = 'ok' | |
| log.info('[DynaDiff] Ready.') | |
| except Exception as exc: | |
| log.exception('[DynaDiff] Model loading failed') | |
| with self._lock: | |
| self._status = 'error' | |
| self._error = str(exc) | |
| finally: | |
| os.chdir(orig_dir) | |
    # -- inference -----------------------------------------------------------
    def reconstruct(self, sample_idx, steerings, seed=42):
        """
        Reconstruct one sample with and without steering perturbations.

        steerings: list of (phi_voxel np.ndarray float32, lam float, threshold float)
        Returns dict with baseline_img, steered_img, gt_img (base64 PNGs), beta_std.
        Raises RuntimeError if the model is not loaded yet, IndexError for an
        out-of-range sample_idx.
        """
        import torch
        # Snapshot shared state under the lock; inference itself runs unlocked.
        with self._lock:
            model = self._model
            beta_std = self._beta_std
            indices = self._subject_sample_indices
        if model is None:
            raise RuntimeError('Model not loaded yet')
        # Map sample_idx -> h5 row
        if indices is not None:
            if not (0 <= sample_idx < len(indices)):
                raise IndexError(
                    f'sample_idx {sample_idx} out of range '
                    f'(subject has {len(indices)} samples)')
            h5_row = int(indices[sample_idx])
        else:
            # No subject index map available: treat sample_idx as a raw h5 row.
            h5_row = sample_idx
        import h5py
        with h5py.File(self.h5_path, 'r') as hf:
            fmri = torch.from_numpy(
                np.array(hf['fmri'][h5_row], dtype=np.float32)).unsqueeze(0)
            img_idx = int(hf['image_idx'][h5_row])
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        # Apply steering perturbations sequentially to a copy of the input.
        steered_fmri = fmri.clone()
        for phi_voxel, lam, threshold in steerings:
            steered_fmri = self._apply_steering(
                steered_fmri, phi_voxel, lam, beta_std, threshold, device)
        # Same seed for both decodes so the only difference is the steering.
        baseline = self._decode(model, fmri, device, dtype, seed)
        steered = self._decode(model, steered_fmri, device, dtype, seed)
        gt_img = self._load_gt_image(img_idx)
        return {
            'baseline_img': _img_to_b64(baseline),
            'steered_img': _img_to_b64(steered),
            'gt_img': _img_to_b64(gt_img) if gt_img is not None else None,
            'beta_std': float(beta_std),
        }
| def _apply_steering(fmri_tensor, phi_voxel, lam, beta_std, threshold, device): | |
| import torch | |
| if lam == 0.0: | |
| return fmri_tensor.clone() | |
| steered = fmri_tensor.clone().to(device=device) | |
| phi_t = torch.from_numpy(phi_voxel).to(dtype=steered.dtype, device=device) | |
| phi_max = phi_t.abs().max().item() | |
| scale = (beta_std / phi_max) if phi_max > 1e-12 else 1.0 | |
| if threshold < 1.0: | |
| cutoff = float(np.percentile(np.abs(phi_voxel), 100 * (1 - threshold))) | |
| mask = torch.from_numpy(np.abs(phi_voxel) >= cutoff).to(device) | |
| else: | |
| mask = torch.ones(N_VOXELS, dtype=torch.bool, device=device) | |
| perturbation = lam * scale * phi_t | |
| perturbation[~mask] = 0.0 | |
| if steered.dim() == 3: | |
| steered[0, :, :] += perturbation.unsqueeze(-1) | |
| else: | |
| steered[0, :] += perturbation | |
| return steered | |
| def _decode(model, fmri_tensor, device, dtype, seed=42, | |
| guidance_scale=3.5, img2img_strength=0.85): | |
| encoding = model.get_condition( | |
| fmri_tensor.to(device=device, dtype=dtype), | |
| __import__('torch').tensor([0], device=device), | |
| ) | |
| output = model.reconstruction_from_clipbrainimage( | |
| encoding, seed=seed, guidance_scale=guidance_scale, | |
| img2img_strength=img2img_strength) | |
| recon = output.image[0].cpu().float().permute(1, 2, 0).numpy() | |
| return np.clip(recon, 0, 1) | |
| def _load_gt_image(self, image_idx): | |
| """Load GT stimulus: thumbnail first, raw H5 fallback.""" | |
| if self.nsd_thumb_dir: | |
| thumb = os.path.join(self.nsd_thumb_dir, f'nsd_{image_idx:05d}.jpg') | |
| try: | |
| from PIL import Image as _PIL | |
| return np.array(_PIL.open(thumb).convert('RGB'), | |
| dtype=np.float32) / 255.0 | |
| except Exception as e: | |
| log.warning(f'[DynaDiff] thumb load failed ({thumb}): {e}') | |
| # H5 fallback β only works if train_unaveraged.h5 is present | |
| try: | |
| import h5py | |
| train_h5 = os.path.join(self.dynadiff_dir, | |
| 'processed_nsd_data', 'train_unaveraged.h5') | |
| if not os.path.exists(train_h5): | |
| return None | |
| with h5py.File(train_h5, 'r') as hf: | |
| img = np.array(hf['images'][image_idx], dtype=np.float32) | |
| return np.clip(img, 0, 1) | |
| except Exception as e: | |
| log.warning(f'[DynaDiff] GT image load failed (idx={image_idx}): {e}') | |
| return None | |