PRISM-JEPA · red_cube (ARX-X5): LeWM world model + PRISM action prior

Complete deployable stack for goal-conditioned visuomotor planning on the ARX-X5 "red cube" task: the LeWM JEPA world model + the PRISM goal-conditioned action prior + self-contained PRISM-MPPI inference code.

⚠️ Status: read first

This is a research artifact / deployment hand-off package, NOT a validated policy.

  • Trained on 201 teleop demos (Xia-2004/red_cube, 42,165 frames, 5-DoF).
  • The world model converges cleanly and its forward model is accurate (rollout pred/id ≈ 0.25 < 0.5). However, its MPPI cost surface is weak and nearly flat on this small real-robot dataset with small actions (CV ≈ 0.14 ≪ the 0.30 "discriminative" threshold). As a result, the planner's distinctive cost-rescoring is dormant, and in practice PRISM behaves like a goal-conditioned BC prior. In an offline A/B it produces ~31% more expert-like actions than vanilla LeWM-MPPI (which wanders), but it adds no measurable goal progress in the world model's own latent metric (paired t-test, p = 0.57).
  • Never run on the real robot. Treat as a starting point; add workspace/velocity safety limits and validate before any hardware run.
  • Full analysis & how these numbers were obtained: project doc docs/30_red_cube_cv_investigation_and_prism.md.
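The CV figure above is not defined in this card. The following is a minimal sketch that assumes CV means the coefficient of variation (std / |mean|) of the K per-sample MPPI costs within one planning call, averaged over calls. The exact definition used in docs/30_red_cube_cv_investigation_and_prism.md may differ, and the names cost_cv and mean_cost_cv are hypothetical. The demo costs are synthetic and are not data from this model.

```python
import numpy as np

DISCRIMINATIVE_CV = 0.30  # threshold quoted in this card


def cost_cv(costs, eps: float = 1e-8) -> float:
    """Coefficient of variation of one batch of K sampled-trajectory costs.

    Assumption: CV = std(costs) / |mean(costs)|. A small CV means the K
    candidates score almost the same, so MPPI has little signal to rank them.
    """
    costs = np.asarray(costs, dtype=np.float64)
    return float(costs.std() / (abs(costs.mean()) + eps))


def mean_cost_cv(cost_batches) -> float:
    """Average the per-call CV over several planning calls."""
    return float(np.mean([cost_cv(c) for c in cost_batches]))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic illustration only: K = 128 costs clustered tightly around 1.0,
    # i.e. a nearly flat cost surface.
    flat = [1.0 + 0.05 * rng.standard_normal(128) for _ in range(10)]
    cv = mean_cost_cv(flat)
    print(f"CV = {cv:.3f}, discriminative: {cv >= DISCRIMINATIVE_CV}")
```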

Contents

file                                   description
lewm_red_cube_epoch_100_object.ckpt    LeWM world model: pickled JEPA (ViT-tiny encoder + AR transformer predictor + action encoder, ~18M params)
prior_head_red_cube.pt                 PRISM goal-conditioned action prior: state_dict + config + action StandardScaler (mean/scale)
arx_inference_demo.py                  self-contained PrismMPPIInference (PoG-fused PRISM-MPPI; use_prism=False → vanilla LeWM-MPPI)
jepa.py, module.py, prior_head.py      model classes required to unpickle the ckpt and run the prior
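The prior head stores an action StandardScaler as mean/scale. This minimal sketch shows how such a scaler maps between normalized and raw actions, following scikit-learn's StandardScaler convention z = (x - mean) / scale. The checkpoint keys in the commented-out loading lines ("scaler", "mean", "scale") are hypothetical, so inspect the real keys first. The demo statistics are made up.

```python
import numpy as np


def normalize(actions, mean, scale):
    """Raw actions -> standardized space: z = (x - mean) / scale (per dimension)."""
    return (np.asarray(actions, dtype=np.float64) - mean) / scale


def denormalize(z, mean, scale):
    """Standardized actions -> raw units: x = z * scale + mean (per dimension)."""
    return np.asarray(z, dtype=np.float64) * scale + mean


# Loading sketch (hypothetical keys; print(ckpt.keys()) to find the real ones):
#   import torch
#   ckpt = torch.load("prior_head_red_cube.pt", map_location="cpu",
#                     weights_only=False)  # weights_only=False only for trusted files
#   mean = np.asarray(ckpt["scaler"]["mean"]); scale = np.asarray(ckpt["scaler"]["scale"])

if __name__ == "__main__":
    # Made-up per-dimension statistics for [dx, dy, dz, dyaw, d_gripper].
    mean = np.array([0.0, 0.001, -0.002, 0.0, 0.05])
    scale = np.array([0.01, 0.01, 0.008, 0.02, 0.3])
    raw = np.array([[0.01, 0.0, 0.0, 0.02, 0.35]])
    z = normalize(raw, mean, scale)
    print(np.allclose(denormalize(z, mean, scale), raw))  # round trip
```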

Observation / action space

  • Observation: single top-down RGB frame, 224×224×3 uint8 (RealSense camera_third).
  • Goal: an RGB goal image in the same format; both the prior and the MPPI cost are conditioned on it.
  • Action: 5-DoF delta end-effector [dx, dy, dz, dyaw, d_gripper] in raw units, one per control tick. plan() returns one plan-step, i.e. A_block = 5 ticks, with shape (5, 5) = (ticks, action dims).
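plan() expects exactly 224×224×3 uint8 frames and returns a (5, 5) block, so a pair of small guards can catch shape or dtype mistakes before they reach the planner or the robot. This is a minimal sketch: the helper names are hypothetical, and the field names simply mirror the [dx, dy, dz, dyaw, d_gripper] order given above. Resizing or cropping the raw RealSense image to 224×224 is not shown, because this card does not specify the preprocessing used in training.

```python
import numpy as np

ACTION_FIELDS = ("dx", "dy", "dz", "dyaw", "d_gripper")  # order stated in this card


def check_obs(frame: np.ndarray) -> np.ndarray:
    """Raise if a frame is not a 224x224x3 uint8 RGB image; return it unchanged."""
    if frame.shape != (224, 224, 3):
        raise ValueError(f"expected (224, 224, 3), got {frame.shape}")
    if frame.dtype != np.uint8:
        raise ValueError(f"expected uint8, got {frame.dtype}")
    return frame


def split_plan(actions, a_block: int = 5):
    """Turn an (a_block, 5) plan into a list of per-tick dicts keyed by action field.

    If plan() returns a torch tensor, convert it first,
    e.g. actions.detach().cpu().numpy().
    """
    actions = np.asarray(actions, dtype=np.float64)
    if actions.shape != (a_block, len(ACTION_FIELDS)):
        raise ValueError(f"expected ({a_block}, 5), got {actions.shape}")
    return [dict(zip(ACTION_FIELDS, map(float, row))) for row in actions]
```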

Dependencies

torch, numpy, einops, and transformers (the encoder inside the ckpt is a HuggingFace ViT, which must be importable at unpickle time). The three bundled .py files must be importable from the working directory. If unpickling complains about a missing class, also run pip install stable-pretraining.
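A quick preflight before unpickling the checkpoint can report every missing package in one pass. This minimal sketch uses importlib.util.find_spec from the standard library. The module lists mirror the dependencies named above, and the optional import name stable_pretraining (for the stable-pretraining pip package) is an assumption.

```python
import importlib.util

REQUIRED = ["torch", "numpy", "einops", "transformers"]
BUNDLED = ["jepa", "module", "prior_head", "arx_inference_demo"]  # the bundled .py files
OPTIONAL = ["stable_pretraining"]  # assumed import name of the stable-pretraining package


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found on sys.path."""
    return [n for n in names if importlib.util.find_spec(n) is None]


def preflight():
    """Print and return which dependency groups have missing modules."""
    report = {
        "required": missing_modules(REQUIRED),
        "bundled": missing_modules(BUNDLED),
        "optional": missing_modules(OPTIONAL),
    }
    for group, missing in report.items():
        status = "ok" if not missing else "missing: " + ", ".join(missing)
        print(f"{group:9s} {status}")
    return report


if __name__ == "__main__":
    preflight()
```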

Deploy: receding-horizon control loop

from arx_inference_demo import PrismMPPIInference

planner = PrismMPPIInference(
    lewm_ckpt  = "lewm_red_cube_epoch_100_object.ckpt",
    prior_ckpt = "prior_head_red_cube.pt",
    use_prism  = True,     # True = PRISM (prior ⊗ MPPI via PoG fusion); False = vanilla LeWM-MPPI
    device     = "cuda",
)

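# load_goal_image(), camera, robot and done are placeholders for your own goal-image
# loader, camera driver, robot interface and stop condition (not provided here).
goal_img = load_goal_image()                   # (224,224,3) uint8: the task goal image
while not done:
    obs     = camera.read()                    # (224,224,3) uint8, top-down camera_third view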
    actions = planner.plan(obs, goal_img)      # (5, 5) raw [dx,dy,dz,dyaw,d_gripper]
    for a in actions:                          # receding horizon: execute the block, then replan
        robot.execute(a)                       # (or execute fewer than 5 and replan more often)

plan() runs one full PRISM-MPPI optimization and returns the first A_block = 5 env-step actions of the optimized plan, in raw action units (already de-normalized).
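The loop above executes the whole 5-tick block. Its comment also suggests executing fewer ticks and replanning more often, and the status notes ask for safety limits. This minimal sketch combines both and assumes plan() returns something convertible to a (5, 5) NumPy array. StubPlanner, MAX_DELTA, clip_action and run_episode are hypothetical placeholders, not part of arx_inference_demo.py. The clip values are arbitrary and must be replaced with limits appropriate for the ARX-X5. Clipping per-tick deltas only caps velocity; a real workspace limit also needs the current end-effector pose, which this sketch does not track.

```python
import numpy as np

# Hypothetical per-tick limits for [dx, dy, dz, dyaw, d_gripper]; NOT tuned for the ARX-X5.
MAX_DELTA = np.array([0.01, 0.01, 0.01, 0.05, 0.2])


def clip_action(a: np.ndarray) -> np.ndarray:
    """Clamp each action dimension to [-MAX_DELTA, +MAX_DELTA] (a crude velocity cap)."""
    return np.clip(a, -MAX_DELTA, MAX_DELTA)


class StubPlanner:
    """Stand-in for PrismMPPIInference: returns a random (A_block, 5) plan."""

    def __init__(self, a_block: int = 5, seed: int = 0):
        self.a_block = a_block
        self.rng = np.random.default_rng(seed)

    def plan(self, obs, goal):
        return 0.02 * self.rng.standard_normal((self.a_block, 5))


def run_episode(planner, read_obs, execute, goal, n_ticks=20, execute_n=2):
    """Receding horizon: plan, run only the first execute_n ticks, then replan.

    Returns the list of clipped actions that were sent to execute().
    """
    if execute_n < 1:
        raise ValueError("execute_n must be >= 1")
    sent = []
    while len(sent) < n_ticks:
        actions = np.asarray(planner.plan(read_obs(), goal))
        for a in actions[:execute_n]:
            if len(sent) == n_ticks:
                break
            a = clip_action(a)
            execute(a)
            sent.append(a)
    return sent


if __name__ == "__main__":
    frame = np.zeros((224, 224, 3), dtype=np.uint8)
    log = run_episode(StubPlanner(), lambda: frame, lambda a: None, frame)
    print(len(log), "ticks sent")
```

Smaller execute_n makes the controller react faster to new observations, at the cost of one full MPPI optimization (K × n_iters predictor rollouts) per replan.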

Key hyperparameters (PrismMPPIInference constructor)

arg                 default  meaning
H                   5        planning horizon (plan-steps)
A_block             5        env-steps (ticks) per plan-step ("frameskip")
K                   128      MPPI samples per iteration
n_iters             30       MPPI refinement iterations
var_scale           1.0      initial planner sampling std
prior_sigma_scale   2.0      multiplier on the prior σ before PoG fusion (PRISM only)
temperature         0.5      MPPI softmax temperature
history_size        3        LeWM history-window length (must match training)

H, A_block, A_raw and history_size must match the checkpoints: the constructor asserts that the prior head's config agrees with them. Change them only if you retrain.
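PRISM fuses the prior into MPPI through a product of Gaussians (PoG), after first multiplying the prior σ by prior_sigma_scale. The card does not spell out the fusion rule. This is a minimal sketch of the standard element-wise product of two diagonal Gaussians, assuming the planner's sampling distribution N(μ_plan, σ_plan²) is combined with the prior's N(μ_prior, (s·σ_prior)²). Whether PrismMPPIInference uses exactly this form, and in which action space, is an assumption. The function name pog_fuse is hypothetical.

```python
import numpy as np


def pog_fuse(mu_plan, sigma_plan, mu_prior, sigma_prior, prior_sigma_scale=2.0):
    """Element-wise product of two diagonal Gaussians; returns (mu, sigma).

    The prior std is inflated by prior_sigma_scale first, so a larger scale
    gives the prior less pull on the fused mean.
    """
    mu_plan = np.asarray(mu_plan, dtype=np.float64)
    sigma_plan = np.asarray(sigma_plan, dtype=np.float64)
    mu_prior = np.asarray(mu_prior, dtype=np.float64)
    sigma_prior = prior_sigma_scale * np.asarray(sigma_prior, dtype=np.float64)
    prec_plan, prec_prior = sigma_plan**-2, sigma_prior**-2     # precisions = 1 / variance
    prec = prec_plan + prec_prior                               # precisions add under a product
    mu = (prec_plan * mu_plan + prec_prior * mu_prior) / prec   # precision-weighted mean
    return mu, prec**-0.5


if __name__ == "__main__":
    # One 5-dim action: planner N(0, 1) (var_scale = 1.0), prior N(1, 0.5^2).
    mu, sigma = pog_fuse(np.zeros(5), np.ones(5), np.ones(5), np.full(5, 0.5))
    print(mu, sigma)
```

With the default prior_sigma_scale = 2.0, a prior σ of 0.5 becomes 1.0, which matches the planner's σ. In that case the fused mean lands halfway between the two means, and the fused σ shrinks to 1/√2.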

PRISM vs vanilla (A/B)

Build a second planner with use_prism=False for a baseline (plain LeWM-MPPI, no prior, same encoder/predictor/MPPI loop). On this task PRISM produces more expert-like actions; vanilla tends to wander because the cost surface is flat.
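The MPPI softmax step shows why a flat cost surface makes vanilla MPPI wander. With weights w_k ∝ exp(-c_k / temperature), nearly equal costs give nearly uniform weights, so the updated mean is roughly the average of random samples. Under PRISM, sampling is centered on the prior instead, so even near-uniform weights average around expert-like actions. This minimal sketch implements that weighting step with an effective-sample-size diagnostic, using the card's K = 128 and temperature = 0.5. The exact cost normalization in arx_inference_demo.py may differ, the function names are hypothetical, and the cost arrays are synthetic.

```python
import numpy as np


def mppi_weights(costs, temperature=0.5):
    """Softmax over negative costs: w_k proportional to exp(-c_k / temperature).

    The minimum cost is subtracted first for numerical stability; this does
    not change the normalized weights.
    """
    c = np.asarray(costs, dtype=np.float64)
    w = np.exp(-(c - c.min()) / temperature)
    return w / w.sum()


def effective_sample_size(w):
    """ESS = 1 / sum(w^2): K for uniform weights, 1 if one sample gets all the weight."""
    return float(1.0 / np.sum(np.square(w)))


def mppi_update(samples, costs, temperature=0.5):
    """Weighted mean of the K sampled action sequences (axis 0 indexes samples)."""
    w = mppi_weights(costs, temperature)
    return np.tensordot(w, samples, axes=1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    K = 128
    samples = rng.standard_normal((K, 5, 5))      # synthetic (K, ticks, action dims)
    flat = 1.0 + 0.01 * rng.standard_normal(K)     # nearly flat cost surface
    sharp = 1.0 + 2.0 * rng.standard_normal(K)     # strongly discriminative costs
    for name, c in [("flat", flat), ("sharp", sharp)]:
        w = mppi_weights(c)
        print(name, "ESS =", round(effective_sample_size(w), 1),
              "update shape:", mppi_update(samples, c).shape)
```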

Provenance

Data: Xia-2004/red_cube (ARX-X5 left-arm teleop). Sibling of Xia-2004/arx-left-cube. World-model architecture is identical to the sim LeWM (ViT-tiny, embed_dim 192, predictor depth 6 / heads 16) β€” part of the PRISM-JEPA project (sister of Newt-PRISM, CoRL 2026).
