RLT Stage-1 RL Token Encoder (MolmoAct2 / YAM stack-cube)

Backup of the RL Token (RLT) Stage-1 encoder for the frozen MolmoAct2-BimanualYAM stack-cube fine-tune. It is a faithful PyTorch port of openpi's pi0_rl.py (Xu et al. 2025). A learned <rl> query compresses the VLA's prefix hidden states, shape (M=690, 2560), into a single z_rl token; a causal autoregressive (AR) decoder then reconstructs the prefix from it (per-token squared-L2 loss, stop-gradient targets, α=0 / frozen VLA). z_rl is the state input for the downstream SAC actor-critic.
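
Written out, the reconstruction objective described above has roughly the following form. This is a sketch in notation chosen here, not copied from pi0_rl.py: h_{1:M} are the frozen prefix hidden states, e_φ is the encoder read out at the <rl> query, d_φ is the causal decoder, and sg is the stop-gradient. The 1/M normalization is an assumption; the implementation may sum or average differently.

```latex
z_{\mathrm{rl}} = e_\phi\big(\operatorname{sg}(h_{1:M}),\ \langle\mathrm{rl}\rangle\big),
\qquad
\mathcal{L}_{\mathrm{rec}}(\phi) = \frac{1}{M}\sum_{i=1}^{M}
\Big\lVert d_\phi\big(z_{\mathrm{rl}},\ \operatorname{sg}(h_{1:i-1})\big)_i
- \operatorname{sg}(h_i) \Big\rVert_2^2,
\qquad M = 690,\ h_i \in \mathbb{R}^{2560}
```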

Chosen encoder

checkpoints/rl_token_encoder_ctxdrop09_best.pt (load ["ema"]). Trained with the openpi/paper settings (AdamW, lr 5e-5, 1k warmup, grad-clip 1.0, EMA 0.999, 10k steps) plus context_dropout=0.9, which zeroes 90% of the decoder's teacher-forced context. This fixes the AR leak: under the bare α=0 reconstruction objective the decoder can rely on its teacher-forced context and ignore the token, which leaves z_rl diffuse.
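
The training knobs above can be sketched as follows. This is a minimal illustration, not the code in train_encoder.py: NumPy stands in for PyTorch, the function names are hypothetical, and three details are assumptions (dropout masks whole context tokens rather than individual features, kept tokens are not rescaled, and the learning rate stays constant after warmup).

```python
import numpy as np

def context_dropout(context, p, rng):
    """Zero whole teacher-forced context tokens with probability p.

    context: (batch, tokens, dim). Per-token masking without rescaling
    is an assumption; the real implementation may differ.
    """
    keep = rng.random(context.shape[:2]) >= p        # (batch, tokens) keep mask
    return context * keep[..., None], keep

def ema_update(ema, params, decay=0.999):
    """Update a dict of EMA weights in place from the current weights."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value
    return ema

def warmup_lr(step, base_lr=5e-5, warmup_steps=1000):
    """Linear warmup to base_lr, then constant (post-warmup schedule assumed)."""
    return base_lr * min(1.0, (step + 1) / warmup_steps)

rng = np.random.default_rng(0)
# Small feature dim for the demo; the real prefix has dim 2560.
ctx = rng.standard_normal((4, 690, 32)).astype(np.float32)
dropped, keep = context_dropout(ctx, p=0.9, rng=rng)
print(f"fraction of context tokens kept: {keep.mean():.3f}")
```

With most of the context zeroed, the decoder cannot reconstruct a token from its neighbours alone, so the reconstruction signal has to pass through z_rl.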

Validation

metric                     baseline (α=0)   dropout-0.9 (chosen)
PCA top-10 var             15%              28%
temporal smoothness (↓)    0.72             0.69
success-vs-failure LogReg CV acc: 99.2% (silhouette 0.13)

z_rl cleanly separates success (44 teleop demos) from failure (7 baseline rollouts, SR≈0) in t-SNE; see outputs/gate_success_fail.png. Caveat: the success and failure samples come from different sessions, so part of the 99% reflects domain shift rather than pure task semantics. Treat it as an optimistic upper bound.
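
For readers who want to reproduce numbers of this kind, here is one way the first two metrics could be computed. This is a sketch under assumptions, not the code in tsne_gate.py: the PCA fraction is the standard explained-variance ratio, while the smoothness definition (mean consecutive-step distance divided by mean pairwise distance within an episode) is a hypothetical stand-in, since the exact formula is not stated here.

```python
import numpy as np

def pca_topk_variance(z, k=10):
    """Fraction of total variance captured by the top-k principal components."""
    centered = z - z.mean(axis=0, keepdims=True)
    s = np.linalg.svd(centered, compute_uv=False)    # singular values, descending
    var = s ** 2
    return float(var[:k].sum() / var.sum())

def temporal_smoothness(z_episode):
    """Mean consecutive-step distance over mean pairwise distance (lower = smoother).

    Hypothetical definition, chosen only for illustration.
    """
    step = np.linalg.norm(np.diff(z_episode, axis=0), axis=1).mean()
    pair = np.linalg.norm(z_episode[:, None] - z_episode[None, :], axis=-1)
    n = len(z_episode)
    mean_pair = pair.sum() / (n * (n - 1))           # diagonal is zero, so exclude it
    return float(step / mean_pair)

rng = np.random.default_rng(0)
# Synthetic stand-ins for z_rl: a rank-3 embedding and isotropic noise.
low_rank = rng.standard_normal((300, 3)) @ rng.standard_normal((3, 64))
isotropic = rng.standard_normal((300, 64))
# Synthetic episodes: a random walk (smooth) and i.i.d. noise (not smooth).
walk = np.cumsum(rng.standard_normal((200, 16)), axis=0)
noise = rng.standard_normal((200, 16))

print(pca_topk_variance(low_rank), pca_topk_variance(isotropic))
print(temporal_smoothness(walk), temporal_smoothness(noise))
```

A higher top-10 variance fraction means the embedding concentrates in fewer directions; a lower smoothness ratio means consecutive frames sit closer together than arbitrary frame pairs.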

Data

Trained on 9,668 (690,2560) prefix shards from the 44 atharva-pantheon/yam-stack-cube demos (~1.3 h teleop @ 10 Hz). Matches the RL Token paper's "small per-task demo set" (1–10 h).
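
As a rough sizing aid, the raw payload of the prefix dataset follows directly from the shard count and shape quoted above. The storage dtype is not stated in this card, so both rows below are assumptions, and the helper name is hypothetical.

```python
def prefix_dataset_bytes(num_shards, tokens, dim, bytes_per_element):
    """Raw tensor payload in bytes, ignoring file-format overhead."""
    return num_shards * tokens * dim * bytes_per_element

NUM_SHARDS, TOKENS, DIM = 9_668, 690, 2560   # figures quoted above

# Storage dtype is an assumption; the card does not say which one is used.
for dtype_name, nbytes in (("16-bit (bf16/fp16)", 2), ("32-bit (fp32)", 4)):
    total = prefix_dataset_bytes(NUM_SHARDS, TOKENS, DIM, nbytes)
    per_shard = TOKENS * DIM * nbytes
    print(f"{dtype_name}: {total / 1e9:.1f} GB total, {per_shard / 1e6:.1f} MB per shard")
```

That works out to roughly 34 GB at 16 bits per element and 68 GB at 32 bits, which is worth knowing before re-running collect_prefix.py on a small disk.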

Files

  • code/rl_token_encoder.py (model), train_encoder.py, collect_prefix.py (demo→prefix collector), collect_fail_replay.py (karma-rollout→prefix), tsne_gate.py, gate_success_fail.py.
  • checkpoints/ctxdrop09_best/final (chosen), nodrop_best/final (baseline), ctxdrop05_best.
  • plots/tsne_final.png (phase structure), gate_success_fail.png (success/fail), others.

Use (Phase-4 actor-critic)

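import torch
from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig

ae = RLTokenAutoencoder(RLTokenConfig(dim=2560))
ckpt = torch.load("rl_token_encoder_ctxdrop09_best.pt", map_location="cpu")
ae.load_state_dict(ckpt["ema"])      # use the EMA weights
ae.eval()
# prefix: (b, M, 2560) prefix hidden states from the frozen VLA; mask: its prefix mask
with torch.no_grad():                # inference only, no gradients needed
    z_rl = ae.encode(prefix, mask)   # (b, M, 2560) -> (b, 2560); SAC state x = (z_rl, proprio)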

Gotcha: validate z_rl with tsne_gate.py / gate_success_fail.py, NOT with a first-token ablation. The first prefix token is a constant special id (151645), which makes that test vacuous.
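
The SAC state x = (z_rl, proprio) mentioned in the usage snippet can be assembled with a plain concatenation. The sketch below uses NumPy and placeholder arrays; build_sac_state is a hypothetical helper, and the proprioception width of 14 is an illustrative guess, not a value taken from this repository.

```python
import numpy as np

def build_sac_state(z_rl, proprio):
    """Concatenate the RL token and proprioception along the feature axis."""
    if z_rl.shape[0] != proprio.shape[0]:
        raise ValueError("z_rl and proprio must share the batch dimension")
    return np.concatenate([z_rl, proprio], axis=-1)

batch, proprio_dim = 8, 14                   # proprio_dim is hypothetical
z_rl = np.zeros((batch, 2560), dtype=np.float32)      # stands in for ae.encode(...)
proprio = np.zeros((batch, proprio_dim), dtype=np.float32)
x = build_sac_state(z_rl, proprio)
print(x.shape)
```

Because z_rl has 2560 features and proprioception only a handful, it may be worth normalizing the two parts separately before feeding the actor and critic; whether the Phase-4 code does so is not stated here.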
