# VLAarchtests / tests/test_rgbd_forward_contract.py
# Last change by lsnu: "2026-03-25 runpod handoff update"
# Commit e7d8e79 (verified)
from train.trainer import build_policy
def test_rgbd_forward_contract(tiny_policy_config, tiny_trainer_config, tiny_batch):
    """Smoke-test the RGB-D forward contract of the ``elastic_reveal`` policy.

    Builds a tiny policy from the fixture-provided configs, runs a single
    forward call with planning and the equivariance probe enabled, and checks
    that the returned mapping exposes the expected entries with shapes that
    agree with the batch size and the decoder / planner configuration.

    The three arguments are pytest fixtures (injected by name). Each one is
    called here, so they appear to be factory fixtures.
    """
    cfg = tiny_policy_config()
    sample = tiny_batch(chunk_size=cfg.decoder.chunk_size)
    trainer_cfg = tiny_trainer_config(policy_type="elastic_reveal")
    model = build_policy(cfg, trainer_cfg)

    # Every batch entry below is forwarded to the policy under the same keyword name.
    forward_keys = (
        "images",
        "depths",
        "depth_valid",
        "camera_intrinsics",
        "camera_extrinsics",
        "proprio",
        "texts",
        "history_images",
        "history_depths",
        "history_depth_valid",
        "history_proprio",
        "history_actions",
    )
    inputs = {key: sample[key] for key in forward_keys}
    out = model(**inputs, plan=True, compute_equivariance_probe=True)

    # The leading dimension of the action output must match the input batch.
    batch_size = sample["images"].shape[0]
    assert out["action_mean"].shape[0] == batch_size

    # The RGB-D token streams must all be populated (not None).
    for token_key in ("depth_tokens", "geometry_tokens", "camera_tokens"):
        assert out[token_key] is not None

    # Candidate and planner dimensions must follow the configuration.
    assert out["proposal_candidates"].shape[1] == cfg.decoder.num_candidates
    top_k = cfg.planner.top_k
    assert out["planner_topk_indices"].shape[1] == top_k
    assert out["planned_rollout"]["target_belief_field"].shape[1] == top_k

    # The interaction-state mapping must report these named quantities.
    interaction = out["interaction_state"]
    for field in ("opening_quality", "gap_width", "hold_quality"):
        assert field in interaction

    # The probe and its target must have identical shapes.
    probe = out["equivariance_probe_action_mean"]
    target = out["equivariance_target_action_mean"]
    assert probe.shape == target.shape