PRISM-JEPA — OGBench Cube (sim)

JEPA world model + PRISM action prior for the OGBench cube-single manipulation task. These are the exact weights used to produce the headline PRISM-MPPI number in the paper.

Project page: yuhaiw.github.io/PRISM_web
Sister repo for PushT: YuhaiW/prism-jepa-pusht
Code: YuhaiW/prism-jepa


Headline result (mean ± std over 3 seeds {0, 1, 42}, K = 128)

| | Vanilla MPPI | BC-only | PRISM-MPPI (s = 1) |
|---|---|---|---|
| Cube SR (%) | 44.0 | 66.0 | 79.3 ± 6.1 |

The sigma scale s is the only PRISM-specific hyperparameter and is set to 1 here; see paper §4.4 for the sigma-scale sweep.


Bundle

| File | Size | Role |
|---|---|---|
| lewm_object.ckpt | ~72 MB | Pickled LeWM (frozen JEPA encoder + AR predictor) |
| prior_head_cube.pt | ~2 MB | PRISM prior head (3-layer MLP, β-NLL β=0.5, σ-floor 0.05) |
| jepa.py, module.py | ~10 KB | Model classes (needed to unpickle the LeWM ckpt) |
| prior_head.py | ~3 KB | PriorHead class |
| requirements.txt | <1 KB | Pinned runtime dependencies |
| README.md | | This file |

Reproduce the paper result

# 1. Clone the eval/training code
git clone https://github.com/YuhaiW/prism-jepa.git
cd prism-jepa
uv venv --python=3.10 && source .venv/bin/activate
uv pip install "stable-worldmodel[train]"  # quoted so shells like zsh do not expand the brackets
uv pip install opencv-python pygame mujoco pymunk scikit-image hdf5plugin
export STABLEWM_HOME=$PWD/.stable-wm

# 2. Pull the weights from this repo
uv pip install -U huggingface_hub  # uv venvs ship without pip; a recent version provides the `hf` CLI
hf download YuhaiW/prism-jepa-cube --local-dir ./hf_cube
mkdir -p $STABLEWM_HOME/cube
mv hf_cube/lewm_object.ckpt $STABLEWM_HOME/cube/
mv hf_cube/prior_head_cube.pt .

# 3. Run PRISM-MPPI (paper main result)
python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
    +head.injection_mode=pog +head.sigma_scale=1.0 \
    +head.ckpt=prior_head_cube.pt \
    solver.num_samples=128 eval.num_eval=50 seed=0
# repeat with seed=1, seed=42 to reproduce the mean (~79%)

The eval also requires the OGBench cube dataset, which supplies the normalization statistics used at eval time. Get cube_single_expert.h5 from the upstream LeWM collection quentinll/lewm and place it under $STABLEWM_HOME/ogbench/ before running step 3.

Vanilla MPPI baseline (no prior)

python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
    +head.injection_mode=none solver.num_samples=128 eval.num_eval=50 seed=0

Training recipe

The world model was trained from scratch on OGBench cube-single-expert following the upstream LeWM recipe (python train.py data=cube). The prior head was then trained with the world model frozen:

python train_prior_head.py task=cube epochs=50 batch_size=512

The head is trained with a β-NLL loss (β = 0.5), σ floored at 0.05, the AdamW optimizer, and a cosine learning-rate schedule. Training takes about 30 minutes on a single RTX 5090.
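For readers unfamiliar with β-NLL, here is a minimal NumPy sketch of the loss value, assuming the commonly used formulation in which the Gaussian negative log-likelihood is weighted by σ^(2β) and the weight is excluded from gradients. The function name `beta_nll` and the hard-clamp form of the σ floor are assumptions for illustration; the actual loss in train_prior_head.py may differ.

```python
import numpy as np

def beta_nll(mu, sigma, target, beta=0.5, sigma_floor=0.05):
    """Evaluate a beta-NLL loss for a diagonal Gaussian prediction.

    NumPy has no autograd, so this only computes the loss value. In an
    autograd framework the weight `var ** beta` would be wrapped in a
    stop-gradient (for example `.detach()` in PyTorch).
    """
    # Assumption: the sigma floor is a hard clamp on the predicted std.
    sigma = np.maximum(sigma, sigma_floor)
    var = sigma ** 2
    # Gaussian NLL per element, dropping the constant 0.5 * log(2 * pi).
    nll = 0.5 * (np.log(var) + (target - mu) ** 2 / var)
    # beta = 0 recovers the plain NLL; beta = 1 weights errors like an MSE.
    weight = var ** beta
    return float(np.mean(weight * nll))

# Hypothetical batch: 4 samples, 2 action dimensions.
mu = np.zeros((4, 2))
target = np.ones((4, 2))
loss = beta_nll(mu, np.full((4, 2), 0.3), target)
```

The floor keeps the 1/σ² term bounded when the head becomes overconfident, and the σ^(2β) weight reduces the tendency of plain NLL to under-train the mean on hard samples.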


How PRISM-MPPI works (one paragraph)

A standard MPPI planner samples action sequences from N(0, σ_π²) and scores them by ‖ẑ_{t+H} − z_g‖² in JEPA latent space. PRISM trains a lightweight prior head g_φ(z_t, z_g) → (μ_p, σ_p) from offline demonstrations, then fuses it with the planner's default sampling distribution at the initial step via the closed-form Product-of-Gaussians:

σ_init² = ((s·σ_p)⁻² + σ_π⁻²)⁻¹
μ_init  = σ_init² · μ_p / (s·σ_p)²
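The two fusion equations above can be checked numerically with a short NumPy sketch. The function name `pog_fuse` and the example values are hypothetical; only the formulas come from the text.

```python
import numpy as np

def pog_fuse(mu_p, sigma_p, sigma_pi, s=1.0):
    """Fuse the prior N(mu_p, (s*sigma_p)^2) with the planner's N(0, sigma_pi^2)."""
    prec_prior = 1.0 / (s * sigma_p) ** 2    # precision of the scaled prior
    prec_planner = 1.0 / sigma_pi ** 2       # precision of the default sampler
    var_init = 1.0 / (prec_prior + prec_planner)
    # The planner mean is zero, so only the prior contributes to the fused mean.
    mu_init = var_init * mu_p * prec_prior
    return mu_init, np.sqrt(var_init)

# Hypothetical per-dimension prior output for a 2-D action.
mu_p = np.array([0.4, -0.2])
sigma_p = np.array([1.0, 1.0])
mu_init, sigma_init = pog_fuse(mu_p, sigma_p, sigma_pi=1.0, s=1.0)
```

When the prior and the planner have equal widths, the fused mean is half the prior mean and the variance is halved. As s·σ_p shrinks relative to σ_π, the fused distribution approaches the prior; as it grows, the fused distribution falls back to the planner's default.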

The MPPI cost stays purely visual (embedding MSE to goal) — no reward, no Q-shortcut. PRISM only re-shapes where samples are drawn from, not how they are scored, which is why the eval-time goal mismatch that hurts pure BC-style policies does not hurt PRISM-MPPI.
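To make the "where, not how" distinction concrete, the following toy sketch runs one MPPI update with two different sampling distributions and the same cost function. Everything here is a stand-in for illustration: `rollout` is a hypothetical additive latent dynamics rather than the JEPA predictor, and the horizon, temperature `lam`, and prior values are made up.

```python
import numpy as np

def rollout(z0, actions):
    """Hypothetical stand-in for the world model: z_{t+1} = z_t + a_t."""
    return z0 + actions.sum(axis=1)               # (K, D) final latents

def mppi_step(z0, z_goal, mu, sigma, K=128, lam=1.0, rng=None):
    """One MPPI update: sample around (mu, sigma), score by distance to the goal."""
    rng = np.random.default_rng(0) if rng is None else rng
    actions = mu + sigma * rng.standard_normal((K,) + mu.shape)    # (K, H, D)
    # The cost is identical no matter which sampler produced the actions.
    cost = np.sum((rollout(z0, actions) - z_goal) ** 2, axis=-1)   # (K,)
    w = np.exp(-(cost - cost.min()) / lam)        # shift by the min for stability
    w /= w.sum()
    plan = np.tensordot(w, actions, axes=1)       # (H, D) weighted mean sequence
    return plan, cost

H, D = 4, 2
z0, z_goal = np.zeros(D), np.array([2.0, -1.0])
sigma_pi = np.ones((H, D))

# Vanilla MPPI: sample from N(0, sigma_pi^2).
plan_vanilla, cost_vanilla = mppi_step(z0, z_goal, np.zeros((H, D)), sigma_pi)

# PRISM-style: fuse a hypothetical prior (s = 1) into the initial sampler.
mu_p = np.tile(z_goal / H, (H, 1))
sigma_p = np.full((H, D), 0.5)
var_init = 1.0 / (1.0 / sigma_p ** 2 + 1.0 / sigma_pi ** 2)
mu_init = var_init * mu_p / sigma_p ** 2
plan_prism, cost_prism = mppi_step(z0, z_goal, mu_init, np.sqrt(var_init))
```

In this toy setup the hypothetical prior points toward the goal, so the fused sampler should concentrate samples in a lower-cost region while `mppi_step` scores them exactly as before. A real run would replace `rollout` with the frozen world model and take `mu_p`, `sigma_p` from the prior head.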


Citation

BibTeX TBA — paper under review.

License

MIT. World-model code vendored from LeWM retains its upstream MIT copyright.
