# SAE_Brain_Semantic_Interface/scripts/dynadiff_loader.py
# Author: Marlin Lee
# Commit: "Sync explorer: UI improvements, bug fixes, DynaDiff NSD index fix" (0a17b8d)
"""
DynaDiff in-process loader.
Loads the DynaDiff model and exposes a reconstruct() method that returns the
same dict format as the HTTP server's /reconstruct endpoint:
{
"baseline_img": "<base64 PNG>",
"steered_img": "<base64 PNG>",
"gt_img": "<base64 PNG> | None",
"beta_std": float,
}
Usage (in explorer_app.py):
from dynadiff_loader import DynaDiffLoader
loader = DynaDiffLoader(dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir)
loader.start() # begins background model load
loader.n_samples # None until ready
loader.is_ready # True when model is loaded
result = loader.reconstruct(sample_idx, steerings, seed)
"""
import base64
import io
import logging
import os
import threading
import numpy as np
# Module-level logging with a compact prefix so DynaDiff output is easy to
# spot among Bokeh/server logs.
# NOTE(review): logging.basicConfig at import time configures the root logger
# for the whole process — confirm this is intended for library-style use.
logging.basicConfig(
    level=logging.INFO,
    format='[DynaDiff %(levelname)s %(asctime)s] %(message)s',
    datefmt='%H:%M:%S',
)
log = logging.getLogger(__name__)

# Voxel count per fMRI sample; passed to the model as brain_n_in_channels
# and used as the steering-mask length when no threshold is applied.
N_VOXELS = 15724

# ── Process-level singleton ───────────────────────────────────────────────────
# Bokeh re-executes the app script per session, so DynaDiffLoader would be
# instantiated multiple times. We keep one loader alive for the whole process
# so the model is loaded exactly once and all sessions share it.
_singleton: "DynaDiffLoader | None" = None
_singleton_lock = threading.Lock()
def get_loader(dynadiff_dir, checkpoint, h5_path,
               nsd_thumb_dir=None, subject_idx=0) -> "DynaDiffLoader":
    """Return the process-level loader, creating and starting it if needed."""
    global _singleton
    with _singleton_lock:
        loader = _singleton
        if loader is None:
            # First call in this process: build the loader and kick off the
            # background model load before publishing it.
            loader = DynaDiffLoader(dynadiff_dir, checkpoint, h5_path,
                                    nsd_thumb_dir, subject_idx)
            loader.start()
            _singleton = loader
    return loader
def _img_to_b64(img_np):
    """(H, W, 3) float32 [0,1] → base64 PNG string."""
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    clipped = np.clip(img_np, 0, 1)
    with io.BytesIO() as buf:
        plt.imsave(buf, clipped, format='png')
        png_bytes = buf.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
class DynaDiffLoader:
    """Lazily-loaded, process-shared wrapper around the DynaDiff model.

    ``start()`` launches a daemon thread that performs the heavy imports and
    checkpoint loading; callers poll ``is_ready`` / ``status`` and then use
    ``reconstruct()``.  All mutable state is published under ``_lock`` so
    multiple Bokeh sessions can safely share one instance.
    """

    def __init__(self, dynadiff_dir, checkpoint, h5_path,
                 nsd_thumb_dir=None, subject_idx=0):
        """
        Parameters
        ----------
        dynadiff_dir : str
            Root of the DynaDiff checkout; relative paths resolve against it.
        checkpoint : str
            A ``.pth`` state-dict file or a DeepSpeed ZeRO checkpoint
            directory (absolute, or relative to ``dynadiff_dir``).
        h5_path : str
            HDF5 file with 'fmri', 'subject_idx' and 'image_idx' datasets.
        nsd_thumb_dir : str | None
            Optional directory of NSD thumbnails for ground-truth images.
        subject_idx : int
            Subject whose H5 rows are exposed as subject-local sample indices.
        """
        self.dynadiff_dir = os.path.abspath(dynadiff_dir)
        self.checkpoint = checkpoint
        self.h5_path = h5_path if os.path.isabs(h5_path) \
            else os.path.join(self.dynadiff_dir, h5_path)
        self.nsd_thumb_dir = nsd_thumb_dir
        self.subject_idx = subject_idx
        self._model = None                    # torch model, set by _load()
        self._cfg = None                      # exca ConfDict, set by _load()
        self._beta_std = None                 # mean per-voxel std of first fMRI rows
        self._subject_sample_indices = None   # H5 rows belonging to subject_idx
        self._nsd_to_sample = {}              # NSD image idx -> [sample_idx, ...]
        self._status = 'loading'              # 'loading' | 'ok' | 'error'
        self._error = ''
        self._lock = threading.Lock()

    # ── public properties ────────────────────────────────────────────────────
    @property
    def is_ready(self):
        """True once the background load finished successfully."""
        with self._lock:
            return self._status == 'ok'

    @property
    def status(self):
        """Return ``(status, error_message)``; error_message is '' unless failed."""
        with self._lock:
            return self._status, self._error

    @property
    def n_samples(self):
        """Number of samples for this subject, or None while still loading."""
        with self._lock:
            idx = self._subject_sample_indices
            return len(idx) if idx is not None else None

    def sample_idxs_for_nsd_img(self, nsd_img_idx):
        """Return the list of sample_idx values that correspond to a given NSD image index.

        Returns an empty list if the image has no trials for this subject or the
        mapping is not yet built (model still loading).
        """
        with self._lock:
            return list(self._nsd_to_sample.get(int(nsd_img_idx), []))

    def start(self):
        """Start background model loading thread (daemon, fire-and-forget)."""
        t = threading.Thread(target=self._load, daemon=True)
        t.start()

    # ── model loading ────────────────────────────────────────────────────────
    @staticmethod
    def _clean_state_dict(sd):
        """Strip the 'model.' wrapper prefix and drop eval-metric weights."""
        sd = {(k[6:] if k.startswith('model.') else k): v
              for k, v in sd.items()}
        drop = ('eval_fid', 'eval_inceptionlastconv',
                'eval_eff', 'eval_swav', 'eval_lpips')
        return {k: v for k, v in sd.items()
                if not any(k.startswith(p) for p in drop)}

    def _resolve_checkpoint(self):
        """Resolve ``self.checkpoint`` to an absolute path.

        Absolute paths are returned unchanged.  Relative paths are tried as a
        file under ``dynadiff_dir`` and as a directory under
        ``dynadiff_dir/training_checkpoints``; raises FileNotFoundError if
        neither exists.
        """
        ckpt = self.checkpoint
        if os.path.isabs(ckpt):
            return ckpt
        candidate_pth = os.path.join(self.dynadiff_dir, ckpt)
        candidate_ckpt = os.path.join(self.dynadiff_dir,
                                      'training_checkpoints', ckpt)
        if os.path.isfile(candidate_pth):
            return candidate_pth
        if os.path.isdir(candidate_ckpt):
            return candidate_ckpt
        raise FileNotFoundError(
            f'Checkpoint not found: tried {candidate_pth} '
            f'and {candidate_ckpt}')

    def _load(self):
        """Background worker: import dynadiff, load weights, build indices.

        On success publishes model/config/statistics under ``_lock`` and sets
        ``_status`` to 'ok'; on any exception records the message and sets
        ``_status`` to 'error'.
        """
        # Capture cwd BEFORE the try block: the finally clause restores it,
        # and binding it inside try (as the original code did) raises
        # NameError in finally whenever an early import fails.
        orig_dir = os.getcwd()
        try:
            import sys
            import torch
            import h5py

            # Inject dynadiff paths before any imports from those packages.
            dynadiff_diffusers = os.path.join(self.dynadiff_dir, 'diffusers', 'src')
            for p in [self.dynadiff_dir, dynadiff_diffusers]:
                if p not in sys.path:
                    sys.path.insert(0, p)

            # Pre-import torchvision so it is fully initialised before dynadiff's
            # diffusers fork pulls it in. Without this, torchvision.transforms can
            # end up in a partially-initialised state, causing
            # "cannot import name 'InterpolationMode' from partially initialized
            # module 'torchvision.transforms'".
            import torchvision.transforms  # noqa: F401
            import torchvision.transforms.functional  # noqa: F401

            # Bokeh's code_runner does os.chdir(original_cwd) in its finally
            # block after every session's app script, so we cannot rely on cwd
            # being stable across the slow imports below. Build the config
            # entirely from absolute paths so no cwd dependency exists.
            _vd_cache = os.path.join(self.dynadiff_dir, 'versatile_diffusion')
            _cache_dir = os.path.join(self.dynadiff_dir, 'cache')
            _local_infra = {'cluster': None, 'folder': _cache_dir}

            print('[DynaDiff] importing dynadiff modules...', flush=True)
            from exca import ConfDict
            print('[DynaDiff] exca imported', flush=True)

            _cfg_yaml = os.path.join(self.dynadiff_dir, 'config', 'config.yaml')
            with open(_cfg_yaml, 'r') as f:
                cfg = ConfDict.from_yaml(f)
            cfg['versatilediffusion_config.vd_cache_dir'] = _vd_cache
            cfg['seed'] = 42
            cfg['data.nsd_dataset_config.seed'] = 42
            cfg['data.nsd_dataset_config.averaged'] = False
            # NOTE(review): subject_ids is hard-coded to [0] rather than
            # self.subject_idx — presumably the checkpoint was trained on
            # subject 0; confirm before reusing with another subject.
            cfg['data.nsd_dataset_config.subject_ids'] = [0]
            cfg['infra'] = _local_infra
            cfg['data.nsd_dataset_config.infra'] = _local_infra
            cfg['image_generation_infra'] = _local_infra
            print('[DynaDiff] config loaded', flush=True)

            vd_cfg = cfg['versatilediffusion_config']
            from model.models import VersatileDiffusion, VersatileDiffusionConfig
            print('[DynaDiff] model.models imported', flush=True)
            vd_config = VersatileDiffusionConfig(**vd_cfg)
            print('[DynaDiff] VersatileDiffusionConfig built', flush=True)

            ckpt = self._resolve_checkpoint()
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model = VersatileDiffusion(config=vd_config,
                                       brain_n_in_channels=N_VOXELS,
                                       brain_temp_dim=6)
            if os.path.isfile(ckpt):
                # Plain .pth state dict.
                log.info(f'[DynaDiff] Loading state dict from {ckpt} ...')
                sd = torch.load(ckpt, map_location=device, weights_only=False)
                model.load_state_dict(self._clean_state_dict(sd), strict=False)
            elif os.path.isdir(ckpt):
                # DeepSpeed ZeRO checkpoint directory: consolidate shards to fp32.
                import deepspeed
                log.info(f'[DynaDiff] Consolidating ZeRO checkpoint from {ckpt} ...')
                sd = deepspeed.utils.zero_to_fp32 \
                    .get_fp32_state_dict_from_zero_checkpoint(
                        checkpoint_dir=ckpt, tag='checkpoint',
                        exclude_frozen_parameters=False)
                model.load_state_dict(self._clean_state_dict(sd), strict=False)
            else:
                raise FileNotFoundError(f'Checkpoint not found: {ckpt}')
            model.sanity_check_blurry = False
            model.to(device)
            model.eval()
            log.info(f'[DynaDiff] Model loaded on {device}')

            # Dataset statistics and subject index mapping — single H5 open.
            log.info(f'[DynaDiff] Computing beta_std from {self.h5_path} ...')
            with h5py.File(self.h5_path, 'r') as hf:
                n = min(300, hf['fmri'].shape[0])
                beta_std = float(np.array(hf['fmri'][:n]).std(axis=0).mean())
                log.info(f'[DynaDiff] beta_std = {beta_std:.5f}')
                log.info(f'[DynaDiff] Building sample index for subject {self.subject_idx} ...')
                all_subj = np.array(hf['subject_idx'][:], dtype=np.int64)
                all_imgidx = np.array(hf['image_idx'][:], dtype=np.int64)
            sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64)
            log.info(f'[DynaDiff] {len(sample_indices)} samples for subject {self.subject_idx}')

            # Build reverse map: NSD image index → list of sample_idx values.
            nsd_to_sample: dict[int, list[int]] = {}
            for sample_idx_val, h5_row in enumerate(sample_indices):
                nsd_img = int(all_imgidx[h5_row])
                nsd_to_sample.setdefault(nsd_img, []).append(sample_idx_val)

            with self._lock:
                self._model = model
                self._cfg = cfg
                self._beta_std = beta_std
                self._subject_sample_indices = sample_indices
                self._nsd_to_sample = nsd_to_sample
                self._status = 'ok'
            log.info('[DynaDiff] Ready.')
        except Exception as exc:
            log.exception('[DynaDiff] Model loading failed')
            with self._lock:
                self._status = 'error'
                self._error = str(exc)
        finally:
            # Restore cwd in case the dynadiff imports / model init changed it.
            os.chdir(orig_dir)

    # ── inference ────────────────────────────────────────────────────────────
    def reconstruct(self, sample_idx, steerings, seed=42):
        """Reconstruct baseline and steered images for one fMRI sample.

        Parameters
        ----------
        sample_idx : int
            Subject-local sample index, mapped to an H5 row via the index
            built in _load(); used directly as the row if the mapping is
            absent.
        steerings : list[tuple]
            ``(phi_voxel float32 ndarray, lam float, threshold float)``
            triples, applied to the fMRI betas in order.
        seed : int
            Diffusion sampling seed (shared by baseline and steered passes).

        Returns
        -------
        dict with 'baseline_img', 'steered_img', 'gt_img' (base64 PNG, gt may
        be None) and 'beta_std' (float).

        Raises
        ------
        RuntimeError if the model is not loaded yet; IndexError for an
        out-of-range sample_idx.
        """
        import torch
        with self._lock:
            model = self._model
            beta_std = self._beta_std
            indices = self._subject_sample_indices
        if model is None:
            raise RuntimeError('Model not loaded yet')
        # Map subject-local sample_idx → absolute h5 row.
        if indices is not None:
            if not (0 <= sample_idx < len(indices)):
                raise IndexError(
                    f'sample_idx {sample_idx} out of range '
                    f'(subject has {len(indices)} samples)')
            h5_row = int(indices[sample_idx])
        else:
            h5_row = sample_idx
        import h5py
        with h5py.File(self.h5_path, 'r') as hf:
            fmri = torch.from_numpy(
                np.array(hf['fmri'][h5_row], dtype=np.float32)).unsqueeze(0)
            img_idx = int(hf['image_idx'][h5_row])
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        # Apply steering perturbations sequentially to a copy of the betas.
        steered_fmri = fmri.clone()
        for phi_voxel, lam, threshold in steerings:
            steered_fmri = self._apply_steering(
                steered_fmri, phi_voxel, lam, beta_std, threshold, device)
        baseline = self._decode(model, fmri, device, dtype, seed)
        steered = self._decode(model, steered_fmri, device, dtype, seed)
        gt_img = self._load_gt_image(img_idx)
        return {
            'baseline_img': _img_to_b64(baseline),
            'steered_img': _img_to_b64(steered),
            'gt_img': _img_to_b64(gt_img) if gt_img is not None else None,
            'beta_std': float(beta_std),
        }

    @staticmethod
    def _apply_steering(fmri_tensor, phi_voxel, lam, beta_std, threshold, device):
        """Add a scaled steering direction to the fMRI betas.

        The direction is normalised so its largest-magnitude voxel equals
        ``beta_std``, then scaled by ``lam``.  With ``threshold < 1.0`` only
        the top ``threshold`` fraction of voxels (by |phi|) is perturbed.
        Returns a new tensor; ``lam == 0`` is a no-op clone.
        """
        import torch
        if lam == 0.0:
            return fmri_tensor.clone()
        steered = fmri_tensor.clone().to(device=device)
        phi_t = torch.from_numpy(phi_voxel).to(dtype=steered.dtype, device=device)
        phi_max = phi_t.abs().max().item()
        # Normalise so the peak voxel perturbation has magnitude beta_std.
        scale = (beta_std / phi_max) if phi_max > 1e-12 else 1.0
        if threshold < 1.0:
            cutoff = float(np.percentile(np.abs(phi_voxel), 100 * (1 - threshold)))
            mask = torch.from_numpy(np.abs(phi_voxel) >= cutoff).to(device)
        else:
            mask = torch.ones(N_VOXELS, dtype=torch.bool, device=device)
        perturbation = lam * scale * phi_t
        perturbation[~mask] = 0.0
        if steered.dim() == 3:
            # (batch, voxel, time) — assumed layout; same offset broadcast over
            # the trailing dim. TODO(review): confirm axis order with the model.
            steered[0, :, :] += perturbation.unsqueeze(-1)
        else:
            steered[0, :] += perturbation
        return steered

    @staticmethod
    def _decode(model, fmri_tensor, device, dtype, seed=42,
                guidance_scale=3.5, img2img_strength=0.85):
        """Run the model on one fMRI tensor → (H, W, 3) float image in [0, 1]."""
        # Local import + context manager replaces the original
        # @__import__('torch').no_grad() decorator hack, which forced torch to
        # be importable at class-definition time.
        import torch
        with torch.no_grad():
            encoding = model.get_condition(
                fmri_tensor.to(device=device, dtype=dtype),
                torch.tensor([0], device=device),
            )
            output = model.reconstruction_from_clipbrainimage(
                encoding, seed=seed, guidance_scale=guidance_scale,
                img2img_strength=img2img_strength)
        recon = output.image[0].cpu().float().permute(1, 2, 0).numpy()
        return np.clip(recon, 0, 1)

    def _load_gt_image(self, image_idx):
        """Load GT stimulus: thumbnail first, raw H5 fallback; None on failure."""
        if self.nsd_thumb_dir:
            thumb = os.path.join(self.nsd_thumb_dir, f'nsd_{image_idx:05d}.jpg')
            try:
                from PIL import Image as _PIL
                return np.array(_PIL.open(thumb).convert('RGB'),
                                dtype=np.float32) / 255.0
            except Exception as e:
                # Best-effort: fall through to the H5 fallback below.
                log.warning(f'[DynaDiff] thumb load failed ({thumb}): {e}')
        # H5 fallback — only works if train_unaveraged.h5 is present
        try:
            import h5py
            train_h5 = os.path.join(self.dynadiff_dir,
                                    'processed_nsd_data', 'train_unaveraged.h5')
            if not os.path.exists(train_h5):
                return None
            with h5py.File(train_h5, 'r') as hf:
                img = np.array(hf['images'][image_idx], dtype=np.float32)
            return np.clip(img, 0, 1)
        except Exception as e:
            log.warning(f'[DynaDiff] GT image load failed (idx={image_idx}): {e}')
            return None