PRISM-JEPA for ARX-X5 Left-Arm Cube Manipulation
A JEPA visual world model plus a PRISM action prior, trained from scratch on the 201-episode `Xia-2004/arx-left-cube` dataset (63,619 frames, single-view `cam_high` only). At inference the two are combined by PRISM-MPPI, which fuses the prior into MPPI's initial sampling distribution in closed form. The resulting loop is suitable for receding-horizon control on a real robot.
1. Bundle
| file | size | role |
|---|---|---|
| `lewm_arx.ckpt` | 72 MB | LeWM module (JEPA encoder + AR predictor), pickled |
| `prior_head_arx.pt` | 2.4 MB | PRISM prior head (state dict + action `StandardScaler`) |
| `arx_inference_demo.py` | 13 KB | Self-contained PRISM-MPPI inference class |
| `jepa.py`, `module.py` | ~10 KB | Model classes (needed to unpickle `lewm_arx.ckpt`) |
| `prior_head.py` | 2.4 KB | `PriorHead` class |
| `requirements.txt` | ~0.5 KB | Pinned runtime dependencies |
| `README.md` (this file) | - | Usage and deployment instructions |
Neither `stable_worldmodel` nor `stable_pretraining` is needed at runtime.
2. Install & smoke test
```bash
# Download the bundle (set HF_TOKEN if the repo is private)
pip install huggingface_hub
python -c "from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='YuhaiW/prism-jepa-arx-cube', \
local_dir='./prism_arx')"
cd prism_arx/

# Install runtime deps
pip install -r requirements.txt

# Smoke test (loads one frame, runs one plan() call)
python arx_inference_demo.py --lewm-ckpt lewm_arx.ckpt \
    --prior-ckpt prior_head_arx.pt
```
Expected planning time: ~0.4 s per `plan()` call on an RTX 5090, ~5–10 s on CPU.
3. Quick start
```python
from arx_inference_demo import PrismMPPIInference

planner = PrismMPPIInference(
    lewm_ckpt="lewm_arx.ckpt",
    prior_ckpt="prior_head_arx.pt",
    device="cuda",
)

# Placeholders: supply your own camera frame, goal image and robot interface.
obs_uint8 = current_image_from_cam_high   # (224, 224, 3) uint8 RGB
goal_uint8 = goal_image                   # (224, 224, 3) uint8 RGB

actions = planner.plan(obs_uint8, goal_uint8)
# actions.shape == (5, 5): 5 env-steps x 5 action dims, raw delta-EE units
for a in actions:
    robot.execute(a)   # on real hardware, apply the §4.5 safety clamp first
```
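`plan()` expects both images as (224, 224, 3) uint8 RGB and returns a (5, 5) array. A small input/output contract check can catch shape or dtype mistakes before anything reaches the robot. This is a minimal sketch assuming only NumPy; `check_image` and `check_actions` are hypothetical helpers, not part of `arx_inference_demo.py`, and they cannot detect a BGR/RGB mix-up, only shape, dtype and non-finite values:

```python
import numpy as np

IMG_SHAPE = (224, 224, 3)   # cam_high frame after preprocessing
ACTION_SHAPE = (5, 5)       # 5 env-steps x 5 action dims

def check_image(img, name: str = "image") -> None:
    """Raise ValueError unless img is a (224, 224, 3) uint8 NumPy array."""
    if not isinstance(img, np.ndarray):
        raise ValueError(f"{name}: expected np.ndarray, got {type(img).__name__}")
    if img.shape != IMG_SHAPE:
        raise ValueError(f"{name}: expected shape {IMG_SHAPE}, got {img.shape}")
    if img.dtype != np.uint8:
        raise ValueError(f"{name}: expected uint8, got {img.dtype}")

def check_actions(actions) -> None:
    """Raise ValueError unless actions is a finite (5, 5) array."""
    actions = np.asarray(actions)
    if actions.shape != ACTION_SHAPE:
        raise ValueError(f"actions: expected shape {ACTION_SHAPE}, got {actions.shape}")
    if not np.all(np.isfinite(actions)):
        raise ValueError("actions: contains NaN or inf")
```

Call `check_image(obs_uint8, "obs")` and `check_image(goal_uint8, "goal")` before `planner.plan(...)`, and `check_actions(actions)` before the execution loop.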
3.1 A/B comparison: PRISM-MPPI vs vanilla LeWM-MPPI
For a paper-grade real-robot comparison, instantiate two planners with identical world-model and planning hyperparameters and flip only the `use_prism` flag. They share the same encoder, predictor, MPPI loop, and action `StandardScaler`; the only difference is whether the product-of-Gaussians (PoG) fusion at MPPI initialization uses the prior head's (μ_p, σ_p).
```python
from arx_inference_demo import PrismMPPIInference

planner_prism = PrismMPPIInference(
    lewm_ckpt="lewm_arx.ckpt",
    prior_ckpt="prior_head_arx.pt",
    use_prism=True,    # ← PRISM-MPPI (our method)
    device="cuda",
)
planner_vanilla = PrismMPPIInference(
    lewm_ckpt="lewm_arx.ckpt",
    prior_ckpt="prior_head_arx.pt",
    use_prism=False,   # ← vanilla LeWM-MPPI (no prior)
    device="cuda",
)

# Alternate methods trial by trial in the same scene to control for lighting / cube placement.
# Placeholders: N_TRIALS, run_episode, obs_stream, goal_uint8 and log come from your own harness.
for trial in range(N_TRIALS):
    method = planner_prism if trial % 2 == 0 else planner_vanilla
    sr = run_episode(method, obs_stream, goal_uint8)
    # Both planners share one class, so log the use_prism flag to tell them apart.
    log(trial=trial, use_prism=method.use_prism, sr=sr)
```
CLI smoke test for both modes:

```bash
python arx_inference_demo.py              # PRISM-MPPI
python arx_inference_demo.py --no-prism   # vanilla LeWM-MPPI
```
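The alternating schedule and the per-arm success rates can be computed without touching hardware. A minimal sketch, assuming each trial's outcome is logged as a `(use_prism, success)` pair; this log format and both function names are hypothetical:

```python
from collections import defaultdict

def ab_schedule(n_trials: int) -> list[bool]:
    """Alternate arms trial by trial: even trials PRISM-MPPI (True), odd trials vanilla (False)."""
    return [trial % 2 == 0 for trial in range(n_trials)]

def success_rates(records: list[tuple[bool, bool]]) -> dict[str, tuple[float, int]]:
    """Map each arm to (success rate, number of trials) from (use_prism, success) pairs."""
    counts = defaultdict(lambda: [0, 0])  # arm -> [successes, trials]
    for use_prism, success in records:
        arm = "prism" if use_prism else "vanilla"
        counts[arm][0] += int(success)
        counts[arm][1] += 1
    return {arm: (s / n, n) for arm, (s, n) in counts.items()}
```

An even trial count keeps both success rates estimated from the same number of scenes.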
4. ARX-X5 deployment guide
4.1 Action space (read first)
Each row of `actions` is one 30 Hz control tick of delta end-effector motion:
| idx | meaning | training range (±) | unit |
|---|---|---|---|
| 0 | δx | ±0.013 | m |
| 1 | δy | ±0.013 | m |
| 2 | δz | ±0.013 | m |
| 3 | δyaw (wrist) | ±0.10 | rad |
| 4 | gripper delta | ±0.36 | rad |
Actions are returned in raw env-action units (already denormalized). The position deltas (xyz) are expressed in the ARX-X5 base frame and the yaw delta in the wrist-link frame, matching the operator's teleop commands during data collection.
Important: the gripper dimension has a weaker learning signal (linear-probe R² ≈ 0.11 vs. 0.5+ for the position dimensions), so grasp-timing errors are the most likely failure mode.
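To keep the index-to-meaning mapping above in one place, a row can be unpacked into named fields before it is mapped onto the robot's command interface. A minimal sketch; the `DeltaEEAction` class is hypothetical and not shipped in the bundle:

```python
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class DeltaEEAction:
    dx: float      # m, ARX-X5 base frame
    dy: float      # m, ARX-X5 base frame
    dz: float      # m, ARX-X5 base frame
    dyaw: float    # rad, wrist-link frame
    dgrip: float   # rad, gripper delta

    @classmethod
    def from_row(cls, row) -> "DeltaEEAction":
        """Build from one row of plan() output, in the column order of the table above."""
        row = np.asarray(row, dtype=float)
        if row.shape != (5,):
            raise ValueError(f"expected 5 action dims, got shape {row.shape}")
        return cls(*(float(v) for v in row))
```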
4.2 Camera setup
- Trained on `cam_high` only: the third-person front view in the standard ARX-X5 teleop rig. Do not feed the wrist camera to the model.
- The preprocessed frame must be (224, 224, 3) uint8 RGB. Center-crop to a square aspect ratio, then resize to 224×224. If you read frames with OpenCV, convert BGR → RGB.
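```python
import cv2
import numpy as np

def preprocess(bgr_frame: np.ndarray) -> np.ndarray:
    """Center-crop an OpenCV BGR frame to a square, resize to 224x224, return uint8 RGB."""
    h, w = bgr_frame.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    sq = bgr_frame[top:top + side, left:left + side]
    # Pass interpolation by keyword: cv2.resize's third positional argument is dst.
    resized = cv2.resize(sq, (224, 224), interpolation=cv2.INTER_AREA)
    return cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
```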
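The crop keeps the central square of the frame. For a 640×480 stream, for example, that means dropping 80 px from each side before resizing. A NumPy-only version of the crop and the BGR → RGB swap, with resizing left to OpenCV as above, makes the geometry easy to unit-test. The function name is illustrative:

```python
import numpy as np

def center_crop_bgr_to_rgb(bgr_frame: np.ndarray) -> np.ndarray:
    """Return the central square crop of an HxWx3 BGR frame with channels reordered to RGB."""
    h, w = bgr_frame.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    square = bgr_frame[top:top + side, left:left + side]
    # Reversing the last axis turns B,G,R into R,G,B; copy to get a contiguous array.
    return np.ascontiguousarray(square[..., ::-1])
```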
4.3 Goal image
Provide a 224×224 uint8 RGB image of the task's target state, ideally the final frame of a successful teleop demonstration, viewed from the same `cam_high` angle, with the gripper in the same resting pose as in the training distribution. Goal images that drift from the training distribution (different lighting, cube color, or gripper position) degrade the success rate (SR) rapidly.
4.4 Receding-horizon control loop
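```python
import time
import numpy as np  # used by safety_clamp (§4.5)
from arx_inference_demo import PrismMPPIInference

# Placeholders: camera, robot, load_goal_image() and task_complete() come from your
# own stack; preprocess() is defined in §4.2 and safety_clamp() in §4.5.
planner = PrismMPPIInference(lewm_ckpt="lewm_arx.ckpt",
                             prior_ckpt="prior_head_arx.pt",
                             device="cuda")
goal_uint8 = load_goal_image()

CONTROL_DT = 1.0 / 30.0   # 30 Hz control tick
N_EXEC = 5                # = A_block; execute the full plan, then replan
MAX_STEPS = 250           # ≈ 8 s safety cap at 30 Hz

step = 0
try:
    while step < MAX_STEPS:
        obs = preprocess(camera.read_cam_high())
        if task_complete(obs, goal_uint8):
            break
        actions = planner.plan(obs, goal_uint8)
        actions = safety_clamp(actions)           # see §4.5
        next_tick = time.monotonic()
        for a in actions[:N_EXEC]:
            robot.send_delta_action(a)
            step += 1
            # Sleep to the next 30 Hz deadline so command-send time does not add drift.
            next_tick += CONTROL_DT
            time.sleep(max(0.0, next_tick - time.monotonic()))
            if step >= MAX_STEPS:
                break
finally:
    # Return home on normal exit, on the step cap, and on exceptions / Ctrl-C.
    robot.move_to_home()
```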
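Before connecting hardware, the loop's bookkeeping (replanning every `N_EXEC` ticks, stopping at `MAX_STEPS` or on task completion) can be exercised with stub callables. A minimal sketch; `run_receding_horizon` and its arguments are hypothetical, and timing is reduced to a fixed `dt` sleep (0 by default) so the dry run finishes instantly:

```python
import time
from typing import Callable, Sequence

def run_receding_horizon(plan: Callable[[object], Sequence],
                         get_obs: Callable[[], object],
                         send: Callable[[object], None],
                         done: Callable[[object], bool],
                         n_exec: int = 5,
                         max_steps: int = 250,
                         dt: float = 0.0) -> tuple[int, int]:
    """Run plan-execute-replan until done() or max_steps; return (steps sent, plan calls)."""
    step, n_plans = 0, 0
    while step < max_steps:
        obs = get_obs()
        if done(obs):
            break
        actions = plan(obs)
        n_plans += 1
        for a in actions[:n_exec]:
            send(a)
            step += 1
            if dt:
                time.sleep(dt)
            if step >= max_steps:
                break
    return step, n_plans
```

On hardware, `plan` would wrap `planner.plan(obs, goal_uint8)` followed by `safety_clamp`, and `send` would wrap `robot.send_delta_action`.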
4.5 Safety
At minimum, before running on real hardware:
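```python
import numpy as np

# Per-step magnitude clamp at 2× training-distribution std
# (column order as in §4.1: δx, δy, δz, δyaw, gripper)
ACTION_CLAMP = np.array([0.026, 0.028, 0.028, 0.20, 0.72])

def safety_clamp(actions):
    """Clip every action dimension element-wise to ±ACTION_CLAMP."""
    return np.clip(actions, -ACTION_CLAMP, ACTION_CLAMP)

# Workspace bounding box: check the projected EE pose after each integration step.
# X_MIN ... Z_MAX are placeholders for your cell's limits (ARX-X5 base frame, m).
def in_workspace(pose):
    return (X_MIN <= pose[0] <= X_MAX and
            Y_MIN <= pose[1] <= Y_MAX and
            Z_MIN <= pose[2] <= Z_MAX)
```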
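Because the xyz deltas are expressed in the base frame (§4.1), the projected EE position after each step of a plan is the running sum of the position deltas. A minimal sketch of that check, assuming the current EE position is read from the robot's state and the box limits are supplied by you; the function name and the limits in any example are hypothetical:

```python
import numpy as np

def first_workspace_violation(ee_xyz, actions, lo, hi):
    """Index of the first plan step whose projected EE position leaves [lo, hi], else None.

    ee_xyz: current EE position (3,), base frame, metres.
    actions: (N, 5) delta-EE plan; columns 0-2 are base-frame position deltas.
    lo, hi: (3,) lower / upper corners of the workspace box.
    """
    deltas = np.asarray(actions, dtype=float)[:, :3]
    positions = np.asarray(ee_xyz, dtype=float) + np.cumsum(deltas, axis=0)
    inside = np.all((positions >= lo) & (positions <= hi), axis=1)
    outside = np.flatnonzero(~inside)
    return int(outside[0]) if outside.size else None
```

If it returns an index, truncate the plan there (or discard it and replan) rather than sending the offending step.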
In addition:
- Keep an operator e-stop physically within reach.
- Run the first 10 trials at 0.5× velocity scaling.
- Put a watchdog timeout around `plan()`; on GPU it should never take more than 1 s.
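One way to put a watchdog around `plan()` in Python is to run it in a worker thread and wait with a timeout. A minimal sketch using the standard-library `concurrent.futures`; `plan_with_watchdog` is a hypothetical helper. A timed-out call is abandoned, not killed, so the caller should treat `None` as "stop the arm":

```python
import concurrent.futures as cf

_executor = cf.ThreadPoolExecutor(max_workers=1)

def plan_with_watchdog(plan_fn, *args, timeout_s: float = 1.0):
    """Return plan_fn(*args), or None if it does not finish within timeout_s seconds."""
    future = _executor.submit(plan_fn, *args)
    try:
        return future.result(timeout=timeout_s)
    except cf.TimeoutError:
        # The worker thread keeps running; Python cannot preempt it.
        return None
```

Usage: `actions = plan_with_watchdog(planner.plan, obs, goal_uint8)`. With a single worker, calls submitted after a timeout wait behind the abandoned one, so a timeout should end the run (e-stop or `robot.move_to_home()`), not trigger an immediate retry.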
5. Tuning knobs
| knob | default | when to change |
|---|---|---|
| `K` | 128 | Lower to 32 for ~4× faster planning |
| `n_iters` | 30 | Lower to 10–15 for ~2–3× faster planning |
| `prior_sigma_scale` | 2.0 | Raise to 5+ if the prior overrides useful MPPI exploration |
| `temperature` | 0.5 | Lower (e.g. 0.2) for sharper, more committed plans |
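A rough way to compare `K` / `n_iters` settings is the number of sampled rollouts per `plan()` call, K × n_iters. This sketch assumes planning time scales roughly linearly with that count, which is an assumption: fixed costs such as encoding the two images do not shrink.

```python
def rollouts_per_plan(K: int = 128, n_iters: int = 30) -> int:
    """Number of sampled action sequences evaluated in one plan() call."""
    return K * n_iters

default = rollouts_per_plan()
for name, K, n_iters in [("K=32", 32, 30), ("n_iters=15", 128, 15), ("K=32, n_iters=15", 32, 15)]:
    budget = rollouts_per_plan(K, n_iters)
    print(f"{name:>18}: {budget:5d} rollouts, {default / budget:.0f}x fewer than default")
```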
Common scenarios:
- Planner too slow. `K=32, n_iters=15` → ~25 ms/plan on RTX 5090.
- Arm jitters between replans. Lower `temperature` to 0.2.
- Plan doesn't move toward the goal. First check that the goal image is in-distribution. If it is, raise `prior_sigma_scale` to 5.0 to weaken the prior and let MPPI dominate.
- Disable the prior entirely (debug). Set `prior_sigma_scale=1e4`: the PoG fusion then returns (almost exactly) MPPI's default N(0, I), so the planner asymptotically reduces to vanilla MPPI. `use_prism=False` (§3.1) removes the prior outright.
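The claim that a very large `prior_sigma_scale` reduces the planner to vanilla MPPI follows from product-of-Gaussians algebra. Per normalized action dimension, assuming the prior's standard deviation is multiplied by `prior_sigma_scale` = s before fusion (an assumption about the implementation) and MPPI's default is N(0, 1):

$$
\tilde\sigma = s\,\sigma_p,\qquad
\sigma_f^2 = \Big(\frac{1}{\tilde\sigma^2} + 1\Big)^{-1} = \frac{\tilde\sigma^2}{1+\tilde\sigma^2},\qquad
\mu_f = \sigma_f^2\Big(\frac{\mu_p}{\tilde\sigma^2} + \frac{0}{1}\Big) = \frac{\mu_p}{1+\tilde\sigma^2}.
$$

As s → ∞, σ_f² → 1 and μ_f → 0, i.e. the fused distribution becomes N(0, 1), which is exactly vanilla MPPI's initialization. Conversely, with the default s = 2 and a confident prior (σ_p = 0.1, so σ̃² = 0.04), μ_f = μ_p / 1.04 and σ_f ≈ 0.196, so MPPI starts almost on the prior mean with a narrow search.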
6. Architecture
- Encoder: ViT-tiny (192-d CLS, 12 layers, patch_size=14) — LeWM default scale, identical to the simulation benchmark checkpoints.
- Predictor: ARPredictor (causal transformer, FiLM-modulated on action embedding).
- PRISM head: 3-layer MLP (512 hidden), input concat(z_t, z_g) ∈ ℝ^384, output (μ, σ) over H=5 × A_block=5 × A_raw=5 = 125 normalized actions.
- Frameskip: 5 (one plan-step = 5 env-steps).
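The head's input and output shapes can be traced with a NumPy forward pass. This sketch rests on assumptions not stated in this README: "3-layer" is read as three linear layers (biases omitted) with ReLU between them, the 250 outputs are split into μ and a pre-activation for σ, and σ is made positive with softplus. The real `PriorHead` in `prior_head.py` may differ. Weights are random, for shape-checking only:

```python
import numpy as np

D_Z, HIDDEN = 192, 512          # CLS dim per image, MLP width
H, A_BLOCK, A_RAW = 5, 5, 5     # plan steps, env-steps per plan step, action dims
N_OUT = H * A_BLOCK * A_RAW     # 125 normalized actions

rng = np.random.default_rng(0)
W = [rng.normal(0, 0.02, (2 * D_Z, HIDDEN)),
     rng.normal(0, 0.02, (HIDDEN, HIDDEN)),
     rng.normal(0, 0.02, (HIDDEN, 2 * N_OUT))]

def prior_head(z_t: np.ndarray, z_g: np.ndarray):
    """Map (z_t, z_g) to (mu, sigma), each shaped (H, A_BLOCK, A_RAW)."""
    x = np.concatenate([z_t, z_g], axis=-1)   # (384,)
    x = np.maximum(x @ W[0], 0.0)             # (512,)
    x = np.maximum(x @ W[1], 0.0)             # (512,)
    out = x @ W[2]                            # (250,)
    mu, s = out[:N_OUT], out[N_OUT:]
    sigma = np.logaddexp(0.0, s)              # softplus keeps sigma > 0
    return mu.reshape(H, A_BLOCK, A_RAW), sigma.reshape(H, A_BLOCK, A_RAW)
```

The first plan step of a (5, 5, 5) mean is one (A_BLOCK, A_RAW) = (5, 5) slice, which matches the shape `plan()` returns.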
Algorithm at inference time: Encoder(o_t) → z_t, Encoder(o_g) → z_g, PriorHead(z_t, z_g) → (μ_p, σ_p); then PoG-fuse (μ_p, σ_p) with MPPI's default N(0, I) and run n_iters=30 MPPI iterations of K=128 samples each, with σ held frozen across iterations (the PRISM-MPPI signature). The output is the first 5 env-step actions of the optimized mean, denormalized.
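The inference recipe can be sketched end to end on a toy problem. In this NumPy sketch the world model and latent goal cost are replaced by a caller-supplied cost function, the prior is given directly as (μ_p, σ_p), `sigma_scale` is assumed to multiply σ_p before fusion, and the update is the standard MPPI softmax-weighted mean with `temperature`. Only the mean moves; σ stays frozen at its fused value. It is an illustration, not the bundle's implementation:

```python
import numpy as np

def pog_fuse(mu_p, sigma_p, sigma_scale=2.0):
    """Fuse N(mu_p, (sigma_scale*sigma_p)^2) with N(0, 1) per dimension (product of Gaussians)."""
    var_p = (sigma_scale * np.asarray(sigma_p, dtype=float)) ** 2
    return np.asarray(mu_p, dtype=float) / (1.0 + var_p), np.sqrt(var_p / (1.0 + var_p))

def prism_mppi(cost_fn, mu_p, sigma_p, K=128, n_iters=30, temperature=0.5,
               sigma_scale=2.0, seed=0):
    """MPPI with a PoG-initialized mean and a frozen sigma; returns the optimized mean."""
    rng = np.random.default_rng(seed)
    mu, sigma = pog_fuse(mu_p, sigma_p, sigma_scale)
    for _ in range(n_iters):
        samples = mu + sigma * rng.standard_normal((K,) + mu.shape)
        costs = np.array([cost_fn(s) for s in samples])
        weights = np.exp(-(costs - costs.min()) / temperature)   # softmax over -cost / T
        weights /= weights.sum()
        mu = np.tensordot(weights, samples, axes=1)   # sigma is NOT updated
    return mu
```

In the real planner the cost of a sample would come from rolling the predictor forward from z_t and comparing against z_g; here any function of an action array works, e.g. a quadratic distance to a target sequence.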
7. Caveats
- Distribution shift. Trained on a single ARX-X5 unit. Different camera intrinsics, lighting, or arm geometry will degrade SR. Fine-tuning on a small set of demos from the target unit is recommended.
- Single-view. `cam_high` only; the wrist camera is not used.
- Action magnitudes are small. Training-distribution per-frame deltas are tight (§4.1). Outputs exceeding 2× std should be treated as a divergence signal.
- Gripper is weak. Linear-probe R² = 0.11 vs. 0.5+ for position; the most likely failure mode is grasp timing.
- Data is small. 63k frames is small for a from-scratch ViT. The encoder is healthy (LOEO R² = 0.44, no collapse), but more demos would likely help.
- The prior head overfits. The best validation result came at epoch 2 of 50; the shipped checkpoint is early-stopped. At deployment, PRISM-MPPI's PoG fusion and frozen-σ mechanism are what protect against the noisy prior.
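Treating outputs beyond 2× std as a divergence signal means inspecting the raw plan before it is clipped, since clipping alone would hide the event. A minimal sketch; `DIVERGENCE_LIMIT` copies the per-dimension 2× std values used for the §4.5 safety clamp, and the function name is hypothetical:

```python
import numpy as np

# 2x training-distribution std per dim: dx, dy, dz (m), dyaw (rad), gripper (rad)
DIVERGENCE_LIMIT = np.array([0.026, 0.028, 0.028, 0.20, 0.72])

def divergent_steps(actions) -> np.ndarray:
    """Indices of plan rows where any dimension exceeds the 2x-std limit in magnitude."""
    over = np.abs(np.asarray(actions)) > DIVERGENCE_LIMIT
    return np.flatnonzero(over.any(axis=1))
```

A non-empty result can be logged and used to end the episode instead of silently executing the clamped plan.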
8. License
apache-2.0 (matching the source dataset).