"""
Lazy loaders for the CLIP model and GPU backbone+SAE runner.
Each is loaded at most once, on first use. Neither is required:
- CLIP is needed only for free-text feature search
- The GPU runner is needed only for live patch-activation inference
"""
import os
import sys
import numpy as np
import torch
from .args import args
from .state import _all_datasets
# ---------- CLIP ----------
_clip_handle = None # (model, processor, device) once loaded
def get_clip():
"""Return (model, processor, device), loading CLIP on first call."""
global _clip_handle
if _clip_handle is None:
        sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src')))
from clip_utils import load_clip
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"[CLIP] Loading {args.clip_model} on {dev} (first free-text query)...")
model, processor = load_clip(dev, model_name=args.clip_model)
_clip_handle = (model, processor, dev)
print("[CLIP] Ready.")
return _clip_handle
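# A sketch of how a free-text query would use the handle. It assumes load_clip
# returns a Hugging Face CLIPModel/CLIPProcessor pair, which this module does
# not itself guarantee:
#
#     model, processor, dev = get_clip()
#     batch = processor(text=["a red bicycle"], return_tensors="pt", padding=True)
#     with torch.inference_mode():
#         text_emb = model.get_text_features(**batch.to(dev))
#     text_emb = torch.nn.functional.normalize(text_emb, dim=-1)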
# ---------- GPU backbone + SAE ----------
_gpu_runner = None # (fwd_fn, sae, transform_fn, n_reg, extract_fn, backbone_name, device)
def get_gpu_runner():
"""Return the runner tuple, loading on first call. Returns None if unavailable."""
global _gpu_runner
if _gpu_runner is not None:
return _gpu_runner
if not args.sae_path:
return None
src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
sys.path.insert(0, src_dir)
from backbone_runners import load_batched_backbone
from precompute_utils import load_sae, extract_tokens
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    d_model = _all_datasets[0]['d_model']  # SAE output dim is fixed by the primary dataset
print(f"[GPU runner] Loading {args.backbone} layer {args.layer} + SAE on {dev} ...")
fwd, d_hidden, n_reg, tfm = load_batched_backbone(args.backbone, args.layer, dev)
sae = load_sae(args.sae_path, d_hidden, d_model, args.top_k, dev)
_gpu_runner = (fwd, sae, tfm, n_reg, extract_tokens, args.backbone, dev)
print("[GPU runner] Ready.")
return _gpu_runner
def run_gpu_inference(pil_img) -> np.ndarray | None:
"""Run pil_img through backbone→SAE; return (n_patches, d_sae) float32 or None."""
runner = get_gpu_runner()
if runner is None:
return None
    fwd, sae, tfm, n_reg, extract_fn, backbone_name, dev = runner
    tensor = tfm(pil_img).unsqueeze(0).to(dev)  # (1, C, H, W) on the runner's device
    with torch.inference_mode():
        hidden = fwd(tensor)
        # 'spatial' keeps only the patch tokens (CLS/register tokens dropped).
        tokens = extract_fn(hidden, backbone_name, 'spatial', n_reg)
        flat = tokens.reshape(-1, tokens.shape[-1])  # (n_patches, d_hidden)
        _, z, _ = sae(flat)  # sae() returns a 3-tuple; the middle element is the code matrix
print(f"[GPU runner] z shape={z.shape}, "
f"nonzero={int((z > 0).sum())}, max={float(z.max()):.4f}")
return z.cpu().float().numpy()
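# Turning the flat activations into a per-patch heatmap for one feature (a
# sketch; it assumes a square patch grid, which holds for typical ViTs on
# square crops, and `feature_id` is a placeholder):
#
#     acts = run_gpu_inference(pil_img)
#     if acts is not None:
#         side = int(np.sqrt(acts.shape[0]))     # e.g. 16 for 256 patches
#         heat = acts[:, feature_id].reshape(side, side)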
def warmup_gpu_runner():
"""Load the GPU runner in a background thread so the first patch request is fast."""
    if not args.sae_path:
        return  # no SAE configured; nothing to warm up
    import threading

    def _warmup():
        try:
            get_gpu_runner()
        except Exception as e:
            print(f"[GPU runner] Warmup failed: {e}")

    threading.Thread(target=_warmup, daemon=True).start()
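# Intended to be called once at server startup (a sketch; `serve_forever` is a
# placeholder for the app's real entrypoint):
#
#     warmup_gpu_runner()   # returns immediately; loading continues in the background
#     serve_forever()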