PRISM-JEPA · red_cube (ARX-X5): LeWM world model + PRISM action prior
Complete deployable stack for goal-conditioned visuomotor planning on the ARX-X5 "red cube" task: the LeWM JEPA world model + the PRISM goal-conditioned action prior + self-contained PRISM-MPPI inference code.
⚠️ Status: read first
This is a research artifact / deployment hand-off package, NOT a validated policy.
- Trained on 201 teleop demos (Xia-2004/red_cube, 42,165 frames, 5-DoF).
- The world model converges cleanly and its forward model is accurate (rollout pred/id ≈ 0.25 < 0.5), but its MPPI cost surface is weak/flat on this small-real-robot, small-action data (CV ≈ 0.14 ≪ 0.30 "discriminative" threshold). Consequence: the planner's distinctive cost-rescoring is dormant, so in practice PRISM ≈ a goal-conditioned BC prior. In offline A/B it produces ~31% more expert-like actions than vanilla LeWM-MPPI (which wanders), but adds no measurable goal-progress in the world model's own latent metric (paired t-test p = 0.57).
- It has never been run on the real robot. Treat it as a starting point; add workspace/velocity safety limits and validate before any hardware run.
- Full analysis & how these numbers were obtained: project doc docs/30_red_cube_cv_investigation_and_prism.md.
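
The flat-cost-surface diagnosis can be illustrated with a small sketch. It assumes that "CV" means the coefficient of variation (standard deviation divided by mean) of the costs MPPI assigns to its sampled action sequences; this README does not spell out the definition, so the formula, the function names, and the example numbers are assumptions. Only the 0.30 threshold comes from the text.

```python
import statistics

DISCRIMINATIVE_CV = 0.30  # threshold quoted in the status notes


def cost_cv(costs):
    """Coefficient of variation (population std / |mean|) of a batch of costs."""
    mean = statistics.fmean(costs)
    if mean == 0:
        raise ValueError("CV is undefined when the mean cost is zero")
    return statistics.pstdev(costs) / abs(mean)


def is_discriminative(costs, threshold=DISCRIMINATIVE_CV):
    """True when the costs vary enough for rescoring to separate candidates."""
    return cost_cv(costs) >= threshold


# Made-up costs for illustration only; these are not measurements from the model.
flat_costs = [1.00, 1.10, 0.90, 1.05, 0.95]
spread_costs = [0.2, 1.0, 1.8, 0.5, 1.5]

print(is_discriminative(flat_costs), is_discriminative(spread_costs))
```

When the check fails, every candidate looks about equally good to the planner, which is consistent with the dormant cost-rescoring described in the status notes.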
Contents
| file | description |
|---|---|
| lewm_red_cube_epoch_100_object.ckpt | LeWM world model: pickled JEPA with ViT-tiny encoder + AR transformer predictor + action encoder (~18M params) |
| prior_head_red_cube.pt | PRISM goal-conditioned action prior: state_dict + config + action StandardScaler (mean/scale) |
| arx_inference_demo.py | self-contained PrismMPPIInference (PoG-fused PRISM-MPPI; use_prism=False → vanilla LeWM-MPPI) |
| jepa.py, module.py, prior_head.py | model classes required to unpickle the ckpt and run the prior |
Observation / action space
- Observation: single top-down RGB frame, 224×224×3 uint8 (RealSense camera_third).
- Goal: an RGB goal image, same format (the prior + cost are conditioned on it).
- Action: 5-DoF delta end-effector [dx, dy, dz, dyaw, d_gripper], raw units, one per control tick. plan() returns one plan-step = A_block = 5 ticks, i.e. shape (5, 5).
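
The shapes listed above are easy to get wrong at the camera or robot boundary, so a small guard can help. The following is a minimal sketch, not part of the bundled code: the helper names are hypothetical, and the checks only read the `.shape` and `.dtype` attributes, which works for NumPy arrays (a torch tensor reports its dtype as `torch.uint8` and would need converting first).

```python
from types import SimpleNamespace

IMAGE_SHAPE = (224, 224, 3)   # height x width x RGB, from the observation spec
ACTION_DIM = 5                # [dx, dy, dz, dyaw, d_gripper]
A_BLOCK = 5                   # ticks returned per plan() call


def check_image(img, name="obs"):
    """Raise ValueError unless img looks like a 224x224x3 uint8 array."""
    shape = tuple(getattr(img, "shape", ()))
    dtype = str(getattr(img, "dtype", ""))
    if shape != IMAGE_SHAPE:
        raise ValueError(f"{name}: expected shape {IMAGE_SHAPE}, got {shape}")
    if dtype != "uint8":
        raise ValueError(f"{name}: expected dtype uint8, got {dtype or 'unknown'}")


def check_action_block(actions):
    """Raise ValueError unless actions has shape (A_BLOCK, ACTION_DIM)."""
    shape = tuple(getattr(actions, "shape", ()))
    if shape != (A_BLOCK, ACTION_DIM):
        raise ValueError(f"actions: expected {(A_BLOCK, ACTION_DIM)}, got {shape}")


# Stand-in objects so the sketch runs without NumPy; real inputs would be ndarrays.
fake_frame = SimpleNamespace(shape=(224, 224, 3), dtype="uint8")
fake_block = SimpleNamespace(shape=(5, 5))
check_image(fake_frame)
check_action_block(fake_block)
```

Calling `check_image` on both the observation and the goal image before planning, and `check_action_block` on the planner output, turns silent shape mismatches into explicit errors.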
Dependencies
torch, numpy, einops, and transformers (the encoder inside the ckpt is a HuggingFace
ViT, needed at unpickle time). The three bundled .py files must be importable from the
working directory. (If unpickling complains about a missing class, also pip install stable-pretraining.)
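
Because the checkpoint is pickled, a missing package or module file only surfaces as an error at load time. A preflight check along the following lines can report the gap earlier. This is a sketch with hypothetical helper names; it uses only the standard library and the package and file names listed in this README.

```python
import importlib.util
import os

REQUIRED_PACKAGES = ("torch", "numpy", "einops", "transformers")
BUNDLED_FILES = ("jepa.py", "module.py", "prior_head.py")


def missing_packages(names=REQUIRED_PACKAGES):
    """Return the top-level packages that cannot be found on the import path."""
    return [n for n in names if importlib.util.find_spec(n) is None]


def missing_files(directory=".", names=BUNDLED_FILES):
    """Return the bundled module files that are absent from `directory`."""
    return [n for n in names if not os.path.isfile(os.path.join(directory, n))]


if __name__ == "__main__":
    print("missing packages:", missing_packages())
    print("missing bundled files:", missing_files())
```

Both functions return empty lists when everything is in place, so the result can gate the planner construction.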
Deploy: receding-horizon control loop

    from arx_inference_demo import PrismMPPIInference

    # load_goal_image, camera, robot and done are placeholders for your own I/O code.
    planner = PrismMPPIInference(
        lewm_ckpt="lewm_red_cube_epoch_100_object.ckpt",
        prior_ckpt="prior_head_red_cube.pt",
        use_prism=True,   # True = PRISM (prior fused into MPPI via PoG); False = vanilla LeWM-MPPI
        device="cuda",
    )

    goal_img = load_goal_image()                # (224,224,3) uint8, the task goal image
    while not done:
        obs = camera.read()                     # (224,224,3) uint8, top-down camera_third view
        actions = planner.plan(obs, goal_img)   # (5, 5) raw [dx,dy,dz,dyaw,d_gripper]
        for a in actions:                       # receding horizon: execute the block, then replan
            robot.execute(a)                    # (or execute fewer than 5 and replan more often)

plan() runs one full PRISM-MPPI optimization and returns the first A_block = 5 env-step
actions of the optimized plan, in raw action units (already de-normalized).
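
Since plan() returns raw, de-normalized deltas and the status notes call for safety limits before any hardware run, one simple option is to clip every tick before it reaches the robot. The sketch below is an assumption-laden illustration: the helper names are hypothetical and the limit values are placeholders, not numbers validated on the ARX-X5. It bounds per-tick deltas only; a workspace limit would additionally need the current end-effector pose.

```python
import math

# Hypothetical per-tick limits for [dx, dy, dz, dyaw, d_gripper].
# Placeholders only; choose real limits for your own setup and units.
MAX_ABS_DELTA = (0.01, 0.01, 0.01, 0.05, 0.1)


def clamp_action(action, limits=MAX_ABS_DELTA):
    """Clip each component of one 5-DoF delta action to [-limit, +limit]."""
    if len(action) != len(limits):
        raise ValueError(f"expected {len(limits)} components, got {len(action)}")
    values = [float(a) for a in action]
    if not all(math.isfinite(v) for v in values):
        raise ValueError("action contains NaN or infinity; refusing to execute")
    return [max(-lim, min(lim, v)) for v, lim in zip(values, limits)]


def clamp_block(actions, limits=MAX_ABS_DELTA):
    """Apply clamp_action to every tick of a planned block."""
    return [clamp_action(a, limits) for a in actions]
```

In the control loop this would wrap the planner output, for example `for a in clamp_block(actions): robot.execute(a)`. Rejecting non-finite values outright, rather than clipping them, keeps a numerical failure in the planner from turning into motion.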
Key hyperparameters (PrismMPPIInference constructor)
| arg | default | meaning |
|---|---|---|
| H | 5 | planning horizon (plan-steps) |
| A_block | 5 | env-steps (ticks) per plan-step ("frameskip") |
| K | 128 | MPPI samples per iteration |
| n_iters | 30 | MPPI refinement iterations |
| var_scale | 1.0 | initial planner sampling std |
| prior_sigma_scale | 2.0 | multiplier on the prior σ before PoG fusion (PRISM only) |
| temperature | 0.5 | MPPI softmax temperature |
| history_size | 3 | LeWM history-window length (must match training) |
H, A_block, A_raw, history_size must match the checkpoints; the constructor asserts that
the prior head's config agrees. Change them only if you retrain.
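
To make the roles of temperature, K and n_iters concrete, here is a sketch of the textbook MPPI weighting step: sampled sequences are weighted by a softmax over their negated costs. It is an assumption that the bundled planner follows this exact form (implementations often add elite selection or other details), and the function names are hypothetical. The budget arithmetic uses the table defaults and assumes one rollout per sample per iteration.

```python
import math


def mppi_weights(costs, temperature=0.5):
    """Softmax over negated costs: lower cost gets a larger weight."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    c_min = min(costs)  # subtract the minimum for numerical stability
    unnorm = [math.exp(-(c - c_min) / temperature) for c in costs]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def weighted_mean(samples, weights):
    """Weighted average of equally sized action vectors."""
    dim = len(samples[0])
    return [sum(w * s[i] for s, w in zip(samples, weights)) for i in range(dim)]


# Budget implied by the default hyperparameters.
H, A_BLOCK, K, N_ITERS = 5, 5, 128, 30
horizon_ticks = H * A_BLOCK        # 25 control ticks covered by one plan
rollouts_per_plan = K * N_ITERS    # 3840 sampled sequences per plan() call
```

With nearly equal costs the weights come out nearly uniform, so the weighted mean barely moves from the sampling mean. A lower temperature sharpens the weights, but it cannot create differences that the cost surface does not contain.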
PRISM vs vanilla (A/B)
Build a second planner with use_prism=False for a baseline (plain LeWM-MPPI, no prior,
same encoder/predictor/MPPI loop). On this task PRISM produces more expert-like actions;
vanilla tends to wander because the cost surface is flat.
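
The difference between the two planners comes down to whether the prior is fused into the sampling distribution. Assuming "PoG" stands for product of Gaussians, applied independently per action dimension, the fusion reduces to adding precisions. The sketch below shows that one-dimensional case; the function name and argument layout are hypothetical, and the bundled code may organize the computation differently.

```python
def pog_fuse(mu_plan, sigma_plan, mu_prior, sigma_prior, prior_sigma_scale=2.0):
    """Fuse two 1-D Gaussians by multiplying their densities.

    Returns (mean, std) of the normalized product. The prior's std is widened
    by prior_sigma_scale first, which weakens its pull on the result.
    """
    if sigma_plan <= 0 or sigma_prior <= 0 or prior_sigma_scale <= 0:
        raise ValueError("standard deviations and scale must be positive")
    s_prior = sigma_prior * prior_sigma_scale
    prec_plan = 1.0 / sigma_plan ** 2      # precision = 1 / variance
    prec_prior = 1.0 / s_prior ** 2
    var = 1.0 / (prec_plan + prec_prior)   # precisions add under a product
    mean = var * (prec_plan * mu_plan + prec_prior * mu_prior)
    return mean, var ** 0.5


# Equal stds and no widening: the fused mean sits halfway between the two.
print(pog_fuse(0.0, 1.0, 1.0, 1.0, prior_sigma_scale=1.0))
```

Under this reading, raising prior_sigma_scale moves the fused mean toward the planner's own distribution, and dropping the prior term entirely corresponds to the use_prism=False baseline.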
Provenance
Data: Xia-2004/red_cube (ARX-X5
left-arm teleop). Sibling of Xia-2004/arx-left-cube. World-model architecture is identical
to the sim LeWM (ViT-tiny, embed_dim 192, predictor depth 6 / heads 16). Part of the
PRISM-JEPA project (sister of Newt-PRISM, CoRL 2026).