LeWM + PRISM-MPPI for Franka PushT (v3: 411-ep dataset, 200-ep target-clean prior)

An action-conditioned latent world model (LeWM) paired with a PRISM action prior, both trained on real Franka FR3 PushT teleoperation data. Supports three planning modes: PRISM-MPPI (PoG fusion, default), warm-start, and vanilla MPPI, selectable at deploy time for A/B comparison.

What's new vs v2

| Component | v2 | v3 |
|---|---|---|
| Training dataset | 304 eps / 72 k frames | 411 eps / 94 k frames |
| WM val_pred_loss | 0.0045 | 0.0046 (similar) |
| CV @ H=5 | 0.180 | 0.195 |
| pred/id @ H=1 | 0.465 | 0.402 (better short-h fidelity) |
| PRISM action prior | ❌ not included | ✅ prior_head_pusht_fr3_v3.pt (H=3, val MSE 0.348) |
| PRISM-MPPI mode | ❌ vanilla only | ✅ 3 modes: pog / warm_start / none |
| Default H (plan-steps) | 3 | 3 (matches both the prior and the WM's deploy sweet spot) |

Model summary

| | World Model | PRISM Prior Head |
|---|---|---|
| Architecture | ViT-tiny encoder + 6-layer AR-Transformer predictor + Embedder action encoder | 3-layer MLP, hidden=512, β-NLL loss |
| Parameters | 18.03 M | 0.51 M |
| Input | (224, 224, 3) RGB obs + goal | (z_t, z_g) ∈ ℝ^(2 × 192) |
| Output | next-latent prediction | (μ, σ) over action sequence (H=3, A_block=5, action_dim=2): 15 ticks = 1.5 s |
| Training data | All 411 eps (mixed: 200 target + 211 random terminations) | eps 0-199 only (target-completion subset) |
| Goal supervision | n/a (predictor) | HER hindsight, episode last frame (sim convention, Andrychowicz et al. 2017) |
| Final val_pred_loss / val MSE | 0.0046 | 0.348 (σ ≈ 0.53 ≈ √MSE, well-calibrated) |
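
The table names a β-NLL loss for the prior head without spelling it out. As a minimal sketch, assuming the commonly used β-NLL formulation (Seitzer et al., 2022), the per-element loss is the Gaussian negative log-likelihood scaled by σ^(2β), where the scaling factor is treated as a constant during backpropagation. The value β = 0.5 and the function name `beta_nll` are assumptions for illustration; the card does not state them.

```python
import math

def beta_nll(mu, sigma, target, beta=0.5):
    """Forward value of a per-element beta-NLL for a prediction N(mu, sigma^2).

    During training the weight sigma^(2*beta) would be detached from the
    graph (stop-gradient); this sketch only evaluates the loss value.
    beta=0.5 is an assumed default, not a value taken from the model card.
    """
    var = sigma ** 2
    # Gaussian NLL with the constant 0.5*log(2*pi) term dropped.
    nll = 0.5 * (math.log(var) + (target - mu) ** 2 / var)
    weight = var ** beta
    return weight * nll

# beta = 0 recovers the plain Gaussian NLL; with beta = 1 the gradient
# with respect to mu matches that of a squared-error loss.
print(beta_nll(mu=0.0, sigma=0.53, target=0.4))
```

The weighting counteracts the tendency of plain NLL to under-train the mean wherever the predicted σ is large, which matters when σ is later used for fusion.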

Why H=3 (not the sim convention H=5)?

The v3 WM's per-step rollout fidelity (pred/id) is best at short horizons (0.40 @ H=1, 0.186 @ H=5, 0.331 @ H=25; see docs/34). H=3 keeps both the WM rollout and the prior's action sequence within the high-fidelity envelope. Empirically, training the prior at H=3 (vs sim's H=5) gives −8.7 % val MSE (0.348 vs 0.381) and a tighter σ (0.53 vs 0.58); the full ablation is in docs/35 §11.
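
The horizon and improvement figures quoted above follow from simple arithmetic. The sketch below reproduces them from the numbers in this card (10 Hz tick, A_block=5, val MSE 0.348 vs 0.381); the helper name `horizon_seconds` is illustrative only.

```python
TICK_HZ = 10   # control frequency stated in this card
A_BLOCK = 5    # raw actions per plan step

def horizon_seconds(h, a_block=A_BLOCK, tick_hz=TICK_HZ):
    """Wall-clock span covered by h plan steps."""
    return h * a_block / tick_hz

for h in (1, 3, 5):
    print(f"H={h}: {h * A_BLOCK} ticks = {horizon_seconds(h)} s")

# Relative change in val MSE when training the prior at H=3 instead of H=5.
rel_change = (0.348 - 0.381) / 0.381
print(f"val MSE change: {100 * rel_change:.1f} %")
```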

Why train the prior on only the first 200 episodes?

Eps 200-410 were collected with no fixed-target pushing: the operator stopped T at arbitrary positions, making the episode-last-frame z_g a noisy supervision signal. Training on the full 411 eps with HER endframe yields a broken prior (val MSE 1.63, fails HARD GATE). Restricting to the target-clean subset 0-199 recovers a useful prior (val MSE 0.38, well-calibrated σ). See docs/35 §9 for the full ablation.
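
The subset restriction and end-frame relabeling described above can be sketched as follows. This is an illustration under assumptions, not the project's training code: the episode layout (a dict with `"latents"` and `"actions"` lists) and the function name `her_endframe_samples` are hypothetical, while H=3, A_block=5, and the 200-episode cutoff come from this card.

```python
H, A_BLOCK = 3, 5          # from the card: 15 raw actions per prior target
CHUNK = H * A_BLOCK
N_TARGET_EPS = 200         # eps 0-199 form the target-clean subset

def her_endframe_samples(episodes):
    """Yield (z_t, z_goal, action_chunk) training triples.

    `episodes` is a hypothetical list of dicts, each holding
    "latents" (T+1 encoded frames) and "actions" (T raw actions).
    The goal is relabeled in hindsight as the episode's last frame.
    """
    for ep_idx, ep in enumerate(episodes):
        if ep_idx >= N_TARGET_EPS:
            continue  # skip episodes whose final frame is not a real target
        z_goal = ep["latents"][-1]
        actions = ep["actions"]
        for t in range(len(actions) - CHUNK + 1):
            yield ep["latents"][t], z_goal, actions[t:t + CHUNK]

# Toy data: 201 episodes of 20 actions each, with integers standing in for latents.
toy = [{"latents": list(range(21)), "actions": list(range(20))} for _ in range(201)]
samples = list(her_endframe_samples(toy))
print(len(samples))
```

Episodes past the cutoff are dropped rather than relabeled because their last frame does not depict a completed target, which is the failure mode the paragraph above describes.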

Plan-worthiness diagnostics (WM only, measured on the train distribution)

| Metric @ H=5 | Value | Interpretation |
|---|---|---|
| CV | 0.195 ± 0.010 | Borderline, below the 0.30 plan-worthy threshold; PRISM prior helps |
| GT_rank | 36.6 % ± 2.1 | Direction correct ("weak-align" tier) |
| pred/id @ H=1 | 0.402 ± 0.016 | Good single-step action conditioning |
| pred/id @ H=5 | 0.186 ± 0.006 | Reasonable 2.5 s rollout fidelity |
| pred/id @ H=25 | 0.331 ± 0.006 | Long-horizon drift; use H ≤ 5 |

See docs/34 for the full data-scaling analysis (v1 → v2 → v3 monotone CV climb).

Three planning modes

| Mode | Description | When to use |
|---|---|---|
| pog (default) | PRISM-MPPI: Product-of-Gaussians fusion of prior (μ, σ) into MPPI init | Main deploy mode |
| warm_start | Use prior mean only, keep planner default σ | A/B test (isolates σ-fusion's contribution) |
| none | Vanilla LeWM-MPPI (no prior) | Paper-grade A/B baseline, or no-prior fallback |

All three modes share the same LeWM, MPPI loop, K=300, n_iters=30, and action scaler, so the comparison is apples-to-apples.
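
How the three modes differ in the sampling distribution they hand to MPPI can be sketched with the standard product-of-Gaussians identity (precisions add, means are precision-weighted). This is a sketch under assumptions rather than the bundled implementation: the function name `fuse_init` and the default `sigma_0=1.0` are hypothetical, and the planner's own init is taken to be μ=0 with a fixed σ.

```python
def fuse_init(mu_prior, sigma_prior, mode="pog", mu_0=0.0, sigma_0=1.0):
    """Return the (mu, sigma) used to draw the first MPPI samples.

    Uses only arithmetic operators, so it works on floats and also
    elementwise on arrays of matching shape. sigma_0=1.0 is an assumed
    planner default, not a value taken from the model card.
    """
    if mode == "none":          # vanilla MPPI: ignore the prior entirely
        return mu_0, sigma_0
    if mode == "warm_start":    # prior mean, planner's default sigma
        return mu_prior, sigma_0
    if mode == "pog":           # product of two Gaussians
        prec_p = 1.0 / sigma_prior ** 2
        prec_0 = 1.0 / sigma_0 ** 2
        var_f = 1.0 / (prec_p + prec_0)
        mu_f = var_f * (prec_p * mu_prior + prec_0 * mu_0)
        return mu_f, var_f ** 0.5
    raise ValueError(f"unknown injection mode: {mode!r}")

for mode in ("pog", "warm_start", "none"):
    print(mode, fuse_init(mu_prior=0.02, sigma_prior=0.53, mode=mode))
```

Under this identity the fused σ is always smaller than either input σ, so a confident prior narrows the search while an uncertain one leaves the planner's default spread nearly unchanged.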

Quick start

pip install torch torchvision numpy einops transformers huggingface_hub
from huggingface_hub import snapshot_download
import numpy as np

# Download the bundle (WM + prior + inference code)
local = snapshot_download("YuhaiW/lewm-pusht-fr3-v3")

import sys; sys.path.insert(0, local)
from pusht_lewm_inference import PushtLewmInference

# ── PRISM-MPPI deploy (recommended default) ─────────────────────────────
planner = PushtLewmInference(
    lewm_ckpt      = f"{local}/lewm_pusht_fr3_v3.ckpt",
    prior_ckpt     = f"{local}/prior_head_pusht_fr3_v3.pt",
    injection_mode = "pog",         # "pog" | "warm_start" | "none"
    device         = "cuda",
)

import time

# In the robot control loop (10 Hz). `done`, `camera_rgb`, `goal_rgb`, and
# `robot` are placeholders for your own robot and camera interfaces.
while not done:
    obs_uint8  = camera_rgb()             # (224, 224, 3) uint8
    goal_uint8 = goal_rgb()               # (224, 224, 3) uint8
    actions    = planner.plan(obs_uint8, goal_uint8)
                                          # (5, 2) float32: Δxy in meters for the next 0.5 s
    for a in actions:
        robot.send_delta_target(a)
        time.sleep(0.1)                   # 10 Hz tick

To A/B test against vanilla MPPI on the same WM:

vanilla = PushtLewmInference(
    lewm_ckpt      = f"{local}/lewm_pusht_fr3_v3.ckpt",
    prior_ckpt     = f"{local}/prior_head_pusht_fr3_v3.pt",   # loaded for scaler
    injection_mode = "none",                                  # disable PoG fusion
)

Robot expectations

| | |
|---|---|
| Robot | Franka FR3 (or compatible) with Cartesian impedance control |
| Action interpretation | Δ-target XY in meters (per tick) |
| Control frequency | 10 Hz |
| Camera | Top-down RGB at 224 × 224 |
| Goal image | Single RGB showing the desired final scene |
| Z, rotation, gripper | NOT controlled (XY-only by design; lock in your controller) |
| Teleop style assumed | "Decisive" pushes: operator commits and pushes in one smooth motion |

What's in the bundle

lewm_pusht_fr3_v3.ckpt                 # 72 MB: world model (pickled JEPA object)
prior_head_pusht_fr3_v3.pt             # 2 MB: PRISM prior head + StandardScaler + meta
action_scaler.json                     # 0.5 KB: fallback scaler when no prior_ckpt
pusht_lewm_inference.py                # standalone PRISM-MPPI planner (3 modes)
jepa.py, module.py                     # required for LeWM ckpt deserialization
prior_head.py                          # required for prior ckpt deserialization
requirements.txt                       # minimal deps
README.md                              # this file

Architecture overview

   obs (224²)                goal (224²)
        │                          │
        ▼                          ▼
   ViT-tiny ↑                 ViT-tiny ↑       ← shared weights
        │                          │
   z_t ∈ ℝ^192             z_g ∈ ℝ^192

   ─── PRISM Prior Head (optional, "pog" / "warm_start" modes) ──
                  │                  │
                  ▼                  ▼
        concat(z_t, z_g) → MLP → (μ_p, σ_p)  ∈ ℝ^(H × A_block × A_raw)
        │                                       │
        ▼                                       ▼
   ─── PoG fusion with MPPI init μ=0, σ=var_scale ──
                  │
                  ▼
        N(μ_fused, σ_fused) sampled K=300 times
                  │
   ─── MPPI iterations (LeWM AR rollout cost vs z_g) ──
                  │
                  ▼
        optimized action sequence (H × A_block, A_raw)
                  │
                  ▼
        first A_block actions → robot
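
The sampling and reweighting stage of the diagram can be illustrated with a toy MPPI loop. Everything below is a stand-in sketched under assumptions: additive 2-D dynamics replace the LeWM latent rollout, the cost is squared distance to the goal rather than a latent-space cost, only the mean is updated, and `temperature` is a hypothetical value. K=300, n_iters=30, and the 15-step sequence length come from this card.

```python
import math
import random

def rollout_cost(start, goal, actions):
    """Toy stand-in for the world-model rollout: integrate 2-D deltas."""
    x, y = start
    for ax, ay in actions:
        x, y = x + ax, y + ay
    return (x - goal[0]) ** 2 + (y - goal[1]) ** 2

def mppi(start, goal, mu, sigma, k=300, n_iters=30, temperature=0.01, seed=0):
    """Refine the mean action sequence with exponentially weighted samples."""
    rng = random.Random(seed)
    mu = [list(step) for step in mu]
    for _ in range(n_iters):
        # Draw K candidate sequences around the current mean.
        samples = [
            [(rng.gauss(m[0], sigma), rng.gauss(m[1], sigma)) for m in mu]
            for _ in range(k)
        ]
        costs = [rollout_cost(start, goal, s) for s in samples]
        c_min = min(costs)  # subtract the minimum for numerical stability
        weights = [math.exp(-(c - c_min) / temperature) for c in costs]
        w_sum = sum(weights)
        # New mean is the weight-averaged sample, per step and per dimension.
        mu = [
            [sum(w * s[t][d] for w, s in zip(weights, samples)) / w_sum
             for d in range(2)]
            for t in range(len(mu))
        ]
    return mu

start, goal = (0.0, 0.0), (0.10, -0.05)
mu0 = [(0.0, 0.0)] * 15          # H x A_block = 3 x 5 steps, zero-mean init
plan = mppi(start, goal, mu0, sigma=0.01)
print(rollout_cost(start, goal, mu0), rollout_cost(start, goal, plan))
```

In the real pipeline the initial `mu` and `sigma` would come from the fusion step and only the first A_block actions of the result would be executed before replanning.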

Caveats and limitations

  1. PRISM prior trained on 200 eps (target subset). The full 411-ep dataset is heterogeneous (eps 200-410 have random T-final-positions). Training the prior on the clean subset gives a usable signal (val MSE 0.38) but still ~3× worse than sim/red-cube counterparts. The next-best improvement would be collecting future datasets with an explicit goal_pixels field (one printed target image per session).
  2. WM cost surface is borderline (CV @ H=5 = 0.195 < 0.30). The PRISM prior is expected to help bridge the gap; PRISM-MPPI's cost-rescoring step tolerates the borderline cost surface in a way that vanilla MPPI cannot.
  3. Trained on top-down RGB only. Other camera angles are OOD.
  4. 2-D XY action space. Z, rotation, gripper are not controlled.
  5. 10 Hz tick. Faster/slower control loops mismatch the action scaler.
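
Because caveat 5 ties the action scaler to a 10 Hz tick, it can help to schedule sends against a monotonic clock instead of sleeping a fixed 0.1 s after each send, which would stretch the period by however long the send itself takes. The helper below is a minimal sketch; `run_block` and its `send` callback are hypothetical names, with `send` standing in for something like `robot.send_delta_target`.

```python
import time

TICK_S = 0.1  # 10 Hz tick from this card

def run_block(actions, send, tick_s=TICK_S, clock=time.monotonic, sleep=time.sleep):
    """Send each action on a fixed schedule so per-send latency does not accumulate."""
    t_next = clock()
    for a in actions:
        send(a)
        t_next += tick_s
        delay = t_next - clock()
        if delay > 0:        # if we are already late, skip sleeping and catch up
            sleep(delay)

# Example with a dummy sender that just records what it was given.
sent = []
run_block([(0.001, 0.0)] * 5, sent.append, tick_s=0.01)
print(len(sent))
```

The `clock` and `sleep` parameters exist so the timing logic can be exercised with fakes in a test, without touching real hardware.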

Provenance

  • WM trained: 2026-06-03 (RTX 5090, ~4 h 53 min on all 411 eps)
  • Prior trained: 2026-06-03 (RTX 5090, ~30 s on first 200 eps + sim-aligned HER)
  • Dataset snapshot: Rongxuan-Zhou/pusht_lewm_fr3 sha 1b5dd5db801ef405b43d51dd5b9a3210d8d79ce6
  • Project: PRISM-JEPA
  • Companions:

Citation

@misc{prism-jepa-pusht-fr3-v3,
  title  = {LeWM + PRISM-MPPI for Franka PushT (v3: 411-ep)},
  author = {Wang, Yuhai and Zhou, Rongxuan and collaborators},
  year   = {2026},
  url    = {https://huggingface.co/YuhaiW/lewm-pusht-fr3-v3}
}

If you cite the PRISM action prior mechanism, also cite Andrychowicz et al. (2017) for hindsight experience replay, on which our prior training is based.

License

Apache 2.0.
