PRISM-JEPA for ARX-X5 Left-Arm Cube Manipulation

A JEPA visual world model and a PRISM action prior, both trained from scratch on the 201-episode Xia-2004/arx-left-cube dataset (63,619 frames, single view: cam_high only). At inference the two are combined by PRISM-MPPI. A closed-form product-of-Gaussians (PoG) fusion of the prior with MPPI's default sampling distribution initializes an MPPI loop that is suitable for receding-horizon control on a real robot.


1. Bundle

file                     size          role
lewm_arx.ckpt            72 MB         LeWM module (JEPA encoder + AR predictor), pickled
prior_head_arx.pt        2.4 MB        PRISM prior head (state dict + action StandardScaler)
arx_inference_demo.py    13 KB         Self-contained PRISM-MPPI inference class
jepa.py, module.py       ~10 KB        Model classes (needed to unpickle lewm_arx.ckpt)
prior_head.py            2.4 KB        PriorHead class
requirements.txt         ~0.5 KB       Pinned runtime dependencies
README.md                (this file)   Usage and deployment instructions

No stable_worldmodel or stable_pretraining runtime dependency.


2. Install & smoke test

# Download the bundle (set HF_TOKEN if the repo is private)
pip install huggingface_hub
python -c "from huggingface_hub import snapshot_download; \
    snapshot_download(repo_id='YuhaiW/prism-jepa-arx-cube', \
                      local_dir='./prism_arx')"
cd prism_arx/

# Install runtime deps
pip install -r requirements.txt

# Smoke test (loads one frame, runs one plan() call)
python arx_inference_demo.py --lewm-ckpt lewm_arx.ckpt \
                              --prior-ckpt prior_head_arx.pt

Expected: ~0.4 s/plan on RTX 5090, ~5–10 s on CPU.


3. Quick start

import numpy as np
from arx_inference_demo import PrismMPPIInference

planner = PrismMPPIInference(
    lewm_ckpt  = "lewm_arx.ckpt",
    prior_ckpt = "prior_head_arx.pt",
    device     = "cuda",
)

# current_image_from_cam_high, goal_image and robot are placeholders for your own camera, goal and robot interfaces
obs_uint8  = current_image_from_cam_high   # (224, 224, 3) uint8 RGB
goal_uint8 = goal_image                     # (224, 224, 3) uint8 RGB
actions    = planner.plan(obs_uint8, goal_uint8)
# actions.shape == (5, 5): 5 env-steps x 5 action dims, raw delta-EE units
for a in actions:
    robot.execute(a)

3.1 A/B comparison: PRISM-MPPI vs vanilla LeWM-MPPI

For a paper-grade real-robot comparison, instantiate two planners with identical world-model and planning hyperparameters and flip only the use_prism flag. Both share the same encoder, predictor, MPPI loop and action StandardScaler; the only difference is whether the product-of-Gaussians (PoG) fusion at MPPI initialization uses the prior head's (μ_p, σ_p).

planner_prism = PrismMPPIInference(
    lewm_ckpt  = "lewm_arx.ckpt",
    prior_ckpt = "prior_head_arx.pt",
    use_prism  = True,           # ← PRISM-MPPI (our method)
    device     = "cuda",
)

planner_vanilla = PrismMPPIInference(
    lewm_ckpt  = "lewm_arx.ckpt",
    prior_ckpt = "prior_head_arx.pt",
    use_prism  = False,          # ← vanilla LeWM-MPPI (no prior)
    device     = "cuda",
)

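# Alternate trials at the same scene to control for lighting / cube placement.
# N_TRIALS, run_episode, obs_stream and log are placeholders for your own harness.
for trial in range(N_TRIALS):
    method = planner_prism if (trial % 2 == 0) else planner_vanilla
    sr     = run_episode(method, obs_stream, goal_uint8)
    # Both planners are PrismMPPIInference instances, so the class name cannot
    # tell them apart; log the use_prism flag instead.
    log(method="prism" if method.use_prism else "vanilla",
        use_prism=method.use_prism,
        sr=sr)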

CLI smoke test for both modes:

python arx_inference_demo.py                          # PRISM-MPPI
python arx_inference_demo.py --no-prism               # vanilla LeWM-MPPI
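To turn the alternating trials into per-arm success rates, the logged records can be tallied as below. This is a minimal sketch that assumes each trial is logged as a dict with a boolean use_prism flag and a binary sr outcome (1 = success); summarize_ab is a hypothetical helper and not part of arx_inference_demo.py.

```python
from collections import defaultdict


def summarize_ab(records):
    """Aggregate per-trial records into a success rate per planner arm.

    Each record is assumed to look like {"use_prism": bool, "sr": 0 or 1}.
    Returns {"prism": (successes, trials, rate), "vanilla": (...)}.
    """
    counts = defaultdict(lambda: [0, 0])  # arm -> [successes, trials]
    for rec in records:
        arm = "prism" if rec["use_prism"] else "vanilla"
        counts[arm][0] += int(rec["sr"])
        counts[arm][1] += 1
    return {arm: (s, n, s / n) for arm, (s, n) in counts.items()}


# Four hypothetical trials, alternating arms as in the loop above.
trials = [
    {"use_prism": True, "sr": 1},
    {"use_prism": False, "sr": 0},
    {"use_prism": True, "sr": 1},
    {"use_prism": False, "sr": 1},
]
print(summarize_ab(trials))  # {'prism': (2, 2, 1.0), 'vanilla': (1, 2, 0.5)}
```

With the small trial counts typical of real-robot evaluation, reporting the raw (successes, trials) pair next to the rate keeps the comparison honest.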

4. ARX-X5 deployment guide

4.1 Action space (read first)

Each row of actions is one 30-Hz tick of delta-end-effector motion:

idx   meaning          training range   unit
0     δx               ±0.013           m
1     δy               ±0.013           m
2     δz               ±0.013           m
3     δyaw (wrist)     ±0.10            rad
4     gripper delta    ±0.36            rad

Actions are returned in raw env-action units (already denormalized). Translations (xyz) are expressed in the ARX-X5 base frame and yaw in the wrist-link frame, the same frames as the operator's teleop commands during data collection.

Important: the gripper dim has a weaker learning signal (linear-probe R² ≈ 0.11 vs 0.5+ for the position dims), so grasp-timing errors are the most likely failure mode.

4.2 Camera setup

  • Trained on cam_high only, the third-person front view in the standard ARX-X5 teleop rig. Do not feed the wrist camera to the model.
  • The preprocessed frame must be (224, 224, 3) uint8 RGB: center-crop to a square, then resize to 224×224. If you read frames with OpenCV, convert BGR → RGB.
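import cv2
import numpy as np

def preprocess(bgr_frame: np.ndarray) -> np.ndarray:
    """Center-crop a BGR frame to a square, resize to 224x224, return uint8 RGB."""
    h, w = bgr_frame.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    sq = bgr_frame[top:top + side, left:left + side]
    # Pass interpolation by keyword: cv2.resize's third positional argument is dst
    resized = cv2.resize(sq, (224, 224), interpolation=cv2.INTER_AREA)
    return cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)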

4.3 Goal image

Provide a 224×224 uint8 RGB image of the task's target state, ideally the final frame of a successful teleop demonstration, seen from the same cam_high angle and with the gripper in the same resting pose as in the training distribution. Goal images that drift from the training distribution (different lighting, cube color or gripper position) degrade the success rate (SR) rapidly.

4.4 Receding-horizon control loop

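import time
from arx_inference_demo import PrismMPPIInference

# preprocess is defined in §4.2 and safety_clamp in §4.5; camera, robot,
# load_goal_image and task_complete are placeholders for your own rig.
planner = PrismMPPIInference(lewm_ckpt="lewm_arx.ckpt",
                             prior_ckpt="prior_head_arx.pt",
                             device="cuda")
goal_uint8 = load_goal_image()

CONTROL_DT = 1.0 / 30.0        # 30-Hz action ticks
N_EXEC = 5                     # = A_block; execute the full plan, then replan
MAX_STEPS = 250                # ≈ 8.3 s of executed motion (planning time is extra)

step = 0
try:
    while step < MAX_STEPS:
        obs = preprocess(camera.read_cam_high())
        if task_complete(obs, goal_uint8):
            break
        actions = planner.plan(obs, goal_uint8)
        actions = safety_clamp(actions)        # see §4.5
        next_tick = time.perf_counter()
        for a in actions[:N_EXEC]:
            robot.send_delta_action(a)
            # Sleep until the next 30-Hz deadline so send latency does not accumulate
            next_tick += CONTROL_DT
            time.sleep(max(0.0, next_tick - time.perf_counter()))
            step += 1
            if step >= MAX_STEPS:
                break
finally:
    robot.move_to_home()   # also runs on exceptions and Ctrl-C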

4.5 Safety

At minimum, before running on real hardware:

import numpy as np

# Per-step magnitude clamp at 2× training-distribution std
# (order: δx, δy, δz, δyaw, gripper; units as in §4.1)
ACTION_CLAMP = np.array([0.026, 0.028, 0.028, 0.20, 0.72])
def safety_clamp(actions):
    return np.clip(actions, -ACTION_CLAMP, +ACTION_CLAMP)

# Workspace bounding box: check the projected EE pose after each integration step.
# X_MIN, X_MAX, Y_MIN, Y_MAX, Z_MIN, Z_MAX are rig-specific placeholders (base frame, m).
def in_workspace(pose):
    return (X_MIN <= pose[0] <= X_MAX and
            Y_MIN <= pose[1] <= Y_MAX and
            Z_MIN <= pose[2] <= Z_MAX)

In addition: keep the operator e-stop physically reachable, run the first 10 trials at 0.5× velocity scaling, and wrap plan() in a watchdog timeout (on GPU it should never take more than 1 s).
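The workspace check, the reduced-speed first trials and the watchdog can be combined into one guard around each replan. This is a minimal sketch under stated assumptions: the current EE position is available as a 3-vector in the ARX-X5 base frame; xyz deltas integrate by plain addition (consistent with the base-frame deltas of §4.1, but not checked against the ARX controller); the box bounds are placeholders; and velocity scaling is approximated by shrinking the pose deltas. guarded_plan, project_positions, plan_stays_in_workspace, WS_MIN and WS_MAX are hypothetical names, not part of the bundle.

```python
import concurrent.futures as cf

import numpy as np

# Hypothetical workspace bounds in the ARX-X5 base frame (m); measure your own rig.
WS_MIN = np.array([0.10, -0.30, 0.02])
WS_MAX = np.array([0.50, 0.30, 0.40])

_EXECUTOR = cf.ThreadPoolExecutor(max_workers=1)  # one planning call at a time


def project_positions(ee_xyz, actions):
    """Integrate the xyz deltas (columns 0-2): (T, 3) positions after each env-step."""
    return np.asarray(ee_xyz, dtype=float) + np.cumsum(actions[:, :3], axis=0)


def plan_stays_in_workspace(ee_xyz, actions, lo=WS_MIN, hi=WS_MAX):
    pos = project_positions(ee_xyz, actions)
    return bool(np.all((pos >= lo) & (pos <= hi)))


def guarded_plan(plan_fn, obs, goal, ee_xyz, timeout_s=1.0, scale=0.5):
    """Run plan_fn(obs, goal) under a timeout, scale the pose deltas, and
    reject plans that leave the box. Returns None when the caller should hold."""
    future = _EXECUTOR.submit(plan_fn, obs, goal)
    try:
        actions = np.asarray(future.result(timeout=timeout_s), dtype=float).copy()
    except cf.TimeoutError:
        return None  # the worker thread keeps running; Python cannot kill it
    actions[:, :4] *= scale  # crude velocity scaling: xyz + yaw, gripper untouched
    if not plan_stays_in_workspace(ee_xyz, actions):
        return None
    return actions
```

In the §4.4 loop one would call guarded_plan(planner.plan, obs, goal_uint8, ee_xyz) instead of planner.plan(...), where ee_xyz is read from your robot interface, skip execution when it returns None, and raise scale to 1.0 after the first supervised trials. Because the executor has a single worker, a call that timed out still blocks the next submission until it finishes, which errs on the side of holding position.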


5. Tuning knobs

knob                default   when to change
K                   128       Lower to 32 for 4× faster planning
n_iters             30        Lower to 10–15 for 2× faster planning
prior_sigma_scale   2.0       Raise to 5+ if the prior overrides good MPPI exploration
temperature         0.5       Lower (0.2) for sharper, more committed plans
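For intuition on the temperature knob: in the standard MPPI update, each of the K sampled action sequences a^{(k)} with cost S_k is weighted by a softmax of its negated cost, with the temperature λ in the denominator. We assume PrismMPPIInference uses this standard form; the cost S_k itself is not specified in this README.

```latex
w_k = \frac{\exp\!\left(-\frac{S_k - S_{\min}}{\lambda}\right)}
           {\sum_{j=1}^{K} \exp\!\left(-\frac{S_j - S_{\min}}{\lambda}\right)},
\qquad
\mu \leftarrow \sum_{k=1}^{K} w_k \, a^{(k)},
\qquad
S_{\min} = \min_{j} S_j
```

Lowering λ from 0.5 to 0.2 multiplies every exponent by 2.5, which concentrates the weight on the lowest-cost samples; that is why the plan becomes sharper and more committed.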

Common scenarios:

  • Planner too slow. K=32, n_iters=15 → ~25 ms/plan on RTX 5090.
  • Arm jitters between replans. Lower temperature to 0.2.
  • Plan doesn't move toward goal. First check the goal image is in-distribution. If yes, raise prior_sigma_scale to 5.0 to weaken the prior and let MPPI dominate.
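  • Disable the prior entirely (debug). Set prior_sigma_scale=1e4: as the prior's σ grows, the PoG fusion approaches MPPI's default N(0, I), so the planner asymptotically reduces to vanilla MPPI. For an exact vanilla baseline, use use_prism=False instead.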
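The asymptote follows from the product of two Gaussians. Per normalized action dimension, fusing the prior N(μ_p, (sσ_p)²) with MPPI's default N(0, 1) gives the closed form below, where s stands for prior_sigma_scale. That s multiplies σ_p is our reading of the knob's description, not something this README states explicitly.

```latex
\sigma_f^2 = \left(\frac{1}{s^2\sigma_p^2} + 1\right)^{-1}
           = \frac{s^2\sigma_p^2}{1 + s^2\sigma_p^2},
\qquad
\mu_f = \sigma_f^2 \, \frac{\mu_p}{s^2\sigma_p^2}
      = \frac{\mu_p}{1 + s^2\sigma_p^2}
```

As s → ∞, σ_f² → 1 and μ_f → 0, so the fused initialization collapses to N(0, I). At s = 1e4 the residual pull toward μ_p is negligible unless σ_p is vanishingly small.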

6. Architecture

  • Encoder: ViT-tiny (192-d CLS, 12 layers, patch_size=14) — LeWM default scale, identical to the simulation benchmark checkpoints.
  • Predictor: ARPredictor (causal transformer, FiLM-modulated on action embedding).
  • PRISM head: 3-layer MLP (512 hidden), input concat(z_t, z_g) ∈ ℝ^384, output (μ, σ) over H=5 × A_block=5 × A_raw=5 = 125 normalized action values.
  • Frameskip: 5 (one plan-step = 5 env-steps).
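The 125-value head output maps back to executable actions by reshaping to (H, A_block, A_raw), undoing the action StandardScaler, and keeping the first plan step. A minimal sketch, assuming a row-major (H, A_block, A_raw) layout and StandardScaler-style per-dimension mean and scale over the 5 raw action dims; the real layout and scaler shape inside prior_head_arx.pt may differ, and decode_first_block is a hypothetical name.

```python
import numpy as np

H, A_BLOCK, A_RAW = 5, 5, 5  # plan steps, env-steps per plan step (frameskip), action dims


def decode_first_block(flat_norm, mean, scale):
    """Map a flat normalized vector (H*A_BLOCK*A_RAW,) to the first plan
    step's raw actions, shape (A_BLOCK, A_RAW).

    Assumes row-major (H, A_BLOCK, A_RAW) layout and StandardScaler-style
    statistics over the A_RAW dims: raw = norm * scale + mean.
    """
    flat_norm = np.asarray(flat_norm, dtype=float)
    if flat_norm.shape != (H * A_BLOCK * A_RAW,):
        raise ValueError(f"expected {(H * A_BLOCK * A_RAW,)}, got {flat_norm.shape}")
    seq = flat_norm.reshape(H, A_BLOCK, A_RAW)
    raw = seq * np.asarray(scale) + np.asarray(mean)  # broadcasts over the last axis
    return raw[0]  # first plan step = first 5 env-steps


# Illustrative statistics borrowed from the §4.1 table, not the shipped scaler.
mean = np.zeros(A_RAW)
scale = np.array([0.013, 0.013, 0.013, 0.10, 0.36])
first = decode_first_block(np.ones(125), mean, scale)
print(first.shape)  # (5, 5)
```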

Algorithm at inference time: Encoder(o_t) → z_t and Encoder(o_g) → z_g; PriorHead(z_t, z_g) → (μ_p, σ_p); PoG-fuse (μ_p, σ_p) with MPPI's default N(0, I); then run n_iters=30 MPPI iterations of K=128 samples each, with σ held frozen across iterations (the PRISM-MPPI signature). The output is the first 5 env-step actions (the first plan step) of the optimized mean, denormalized.
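The inference algorithm above can be illustrated in NumPy on a toy problem. This is a sketch under stated assumptions, not the code in arx_inference_demo.py: a quadratic distance to a fixed target stands in for the world-model rollout cost; the PoG with N(0, 1) uses the two-Gaussian closed form; σ stays frozen at its fused value while only the mean is updated with standard MPPI softmax weights; and sigma_scale is assumed to play the role of prior_sigma_scale. prism_mppi and pog_fuse are hypothetical names.

```python
import numpy as np


def pog_fuse(mu_p, sigma_p):
    """Elementwise product of N(mu_p, sigma_p^2) and N(0, 1)."""
    var_p = sigma_p ** 2
    mu_f = mu_p / (1.0 + var_p)
    sigma_f = np.sqrt(var_p / (1.0 + var_p))
    return mu_f, sigma_f


def prism_mppi(cost_fn, mu_p, sigma_p, K=128, n_iters=30,
               temperature=0.5, sigma_scale=1.0, seed=0):
    """Toy PRISM-MPPI on normalized actions: PoG-initialized mean and std,
    mean-only MPPI updates, std held frozen across iterations."""
    rng = np.random.default_rng(seed)
    mu, sigma = pog_fuse(mu_p, sigma_scale * sigma_p)  # sigma is never updated
    for _ in range(n_iters):
        samples = mu + sigma * rng.standard_normal((K,) + mu.shape)
        costs = np.array([cost_fn(s) for s in samples])
        w = np.exp(-(costs - costs.min()) / temperature)
        w /= w.sum()
        mu = np.tensordot(w, samples, axes=1)  # weighted mean over the K samples
    return mu, sigma


# 1-D toy: the prior mean (0.4) is near, but not at, the target (0.5).
target = np.array([0.5])
mu, sigma = prism_mppi(lambda a: float(np.sum((a - target) ** 2)),
                       mu_p=np.array([0.4]), sigma_p=np.array([1.0]))
```

Freezing σ keeps the search radius at the fused value instead of letting it shrink around a possibly wrong prior mean, which is the behavior caveat 6 credits with protecting against a noisy prior.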


7. Caveats

  1. Distribution shift. Trained on a single ARX-X5 unit. Different camera intrinsics, lighting, or arm geometry will degrade SR. Fine-tuning on a small set of demos from the target unit is recommended.
  2. Single-view. cam_high only — wrist camera is not used.
  3. Action magnitudes are small. Training-distribution per-frame deltas are tight (Section 4.1). Outputs exceeding 2× std should be treated as a divergence signal.
  4. Gripper is weak. R² = 0.11 vs 0.5+ for position. Most likely failure mode is grasp timing.
  5. Data is small. 63k frames is a small dataset for training a ViT from scratch. The encoder is healthy (LOEO R² = 0.44, no collapse), but more demos would likely help.
  6. The prior head overfits. Validation is best at epoch 2 of 50, and the shipped checkpoint is early-stopped. PRISM-MPPI's PoG fusion and frozen-σ mechanism are what protect against the noisy prior at deployment.
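Caveat 3 can be turned into a runtime check. The sketch below flags every env-step whose raw action exceeds the 2× std thresholds of §4.5 in any dimension. Run it on planner.plan() output before safety_clamp, since clamping would hide the signal. flag_divergence is a hypothetical helper, and the thresholds are copied from ACTION_CLAMP.

```python
import numpy as np

# 2x training-distribution std per action dim, copied from ACTION_CLAMP in §4.5:
# [dx, dy, dz, dyaw, gripper]
TWO_STD = np.array([0.026, 0.028, 0.028, 0.20, 0.72])


def flag_divergence(actions, threshold=TWO_STD):
    """Return indices of env-steps (rows) where any |action| exceeds threshold."""
    actions = np.asarray(actions, dtype=float)
    over = np.abs(actions) > threshold  # (T, 5) boolean, compared per dim
    return np.flatnonzero(over.any(axis=1))


plan = np.zeros((5, 5))
plan[3, 4] = 0.9  # gripper delta well beyond 2x std
print(flag_divergence(plan))  # [3]
```

A non-empty result is a reasonable trigger to skip execution and replan, or to stop and send the arm home.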

8. License

apache-2.0 (matching the source dataset).
