AR-VLA model card

This model was developed by INSAIT and KU Leuven as an attempt to reproduce a Pi-0-Fast model pretrained on the Bridge dataset for the WidowX embodiment.

Code and model weights for AR-VLA models are free to use under the Gemma license.

This repo provides model weights fine-tuned for a WidowX setup with one external camera.

The weights work out of the box in SimplerEnv and on a real WidowX robot in a similar toy-kitchen scene.

Use with 🤗 Transformers

We provide a fully AutoModel-compatible implementation of AR-VLA that can be loaded through Transformers.

Environment setup

The current implementation requires the following additional dependencies: roma, timm, flash-attn.

Here is a snippet to set up a working environment for inference via uv:

  1. Install uv:

Bash
wget -qO- https://github.com/astral-sh/uv/releases/download/0.7.5/uv-installer.sh | sh

  2. Create a virtualenv and resolve the dependencies:

Bash
uv venv --python 3.10.12
source .venv/bin/activate
uv pip install --torch-backend=cu126 roma==1.5.0 numpy==2.2.4 torch==2.6.0 torchvision==0.21.0 transformers==4.47.0 timm==1.0.15
uv pip install --no-build-isolation setuptools psutil flash-attn==2.7.3
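
Once the install commands finish, it can be useful to confirm that the environment matches the pinned versions before loading the model. The following is a minimal sketch using only the Python standard library; the helper name `check_environment` is hypothetical and not part of AR-VLA, and the expected versions are copied from the commands above.

```python
from importlib import metadata

# Distribution names as published on PyPI; versions are the pins
# from the install commands above.
EXPECTED = {
    "roma": "1.5.0",
    "numpy": "2.2.4",
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "transformers": "4.47.0",
    "timm": "1.0.15",
    "flash-attn": "2.7.3",
}


def check_environment(expected=EXPECTED):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for name, pinned in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Local build tags such as "+cu126" are ignored when comparing.
        matches = installed is not None and installed.split("+")[0] == pinned
        report[name] = (installed, matches)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_environment().items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name:12s} installed={installed} [{status}]")
```

A mismatch does not necessarily mean inference will fail, only that the environment differs from the one described here.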

Example usage

The following simplified script demonstrates how to run inference and decode the output tokens directly into a usable action sequence using the built-in policy decoder.

Python
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_id = "you2who/paligemma-fast-bridge"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Load Model and Processor
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# 2. Prepare Inputs
image = Image.open("path/to/main_image.png").convert("RGB")
dataset_name = np.array(["bridge"])

batch = processor.preprocess_inputs(
    chat=["Pick up the cup.", ""],
    images={"main": [image]},
    ee_pose_translation=np.zeros((1, 1, 3), dtype=np.float32),
    ee_pose_rotation=np.array([[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32),
    gripper=np.zeros((1, 1), dtype=np.float32),
    joints=np.zeros((1, 1, 7), dtype=np.float32),
    dataset_name=dataset_name,
    inference_mode=True,
)

# Move batch to device
model_inputs = {
    k: v.to(device) if isinstance(v, torch.Tensor) else v 
    for k, v in batch.items() if k != "images"
}
model_inputs["images"] = {k: v.to(device) for k, v in batch["images"].items()}

# 3. Run Inference
with torch.inference_mode():
    output = model(**model_inputs)

# 4. Token Decode to Action Sequence
policy_decoder = getattr(processor, "vlarm_processor", processor)
num_steps = output.token_ids.shape[1] if hasattr(output, "token_ids") else output.token_logits.shape[1]
valid_mask = torch.ones((model_inputs["input_ids"].shape[0], num_steps), dtype=torch.bool)

# Decode outputs
control_plan = policy_decoder.policy_control_plan_from_model_output(
    model_output=output,
    dataset_name=dataset_name,
    valid_mask=valid_mask,
)

# Flatten the rotation matrix and concatenate the action dimensions
rot_flat = control_plan.rotmat.reshape(*control_plan.rotmat.shape[:-2], 9)
action_sequence = torch.cat([
    control_plan.translation_m, 
    rot_flat, 
    control_plan.gripper_prob
], dim=-1)

# Extract the valid actions for the first batch item
batch_mask = control_plan.valid_mask[0].detach().cpu().bool().squeeze()
valid_actions = action_sequence[0].detach().cpu()[batch_mask]

print("Decoded Action Sequence Shape:", valid_actions.shape)
print("Actions:\n", valid_actions.numpy())
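
The script above concatenates translation, flattened rotation matrix, and gripper probability into one row per step. As a sketch of how a downstream controller might unpack those rows, the hypothetical helpers below split the array back into its parts and check that the rotations are close to orthonormal. This assumes `gripper_prob` has a trailing dimension of 1, so that each row has 3 + 9 + 1 = 13 values; that layout is inferred from the concatenation order and is not confirmed by the model's documentation.

```python
import numpy as np

# Assumed layout of one action row, following the concatenation order above:
# 3 translation values, 9 row-major rotation-matrix entries, 1 gripper value.
TRANSLATION_DIM, ROTATION_DIM, GRIPPER_DIM = 3, 9, 1


def split_actions(actions):
    """Split an (N, 13) action array into translation, rotation and gripper."""
    actions = np.asarray(actions, dtype=np.float32)
    expected = TRANSLATION_DIM + ROTATION_DIM + GRIPPER_DIM
    if actions.ndim != 2 or actions.shape[-1] != expected:
        raise ValueError(f"expected shape (N, {expected}), got {actions.shape}")
    translation = actions[:, :3]                  # (N, 3)
    rotmat = actions[:, 3:12].reshape(-1, 3, 3)   # (N, 3, 3), row-major
    gripper = actions[:, 12]                      # (N,)
    return translation, rotmat, gripper


def rotation_error(rotmat):
    """Max absolute deviation of R @ R^T from the identity, per step."""
    eye = np.eye(3, dtype=rotmat.dtype)
    product = rotmat @ np.swapaxes(rotmat, -1, -2)
    return np.abs(product - eye).max(axis=(-2, -1))


# Dummy data standing in for `valid_actions.numpy()`: two steps,
# zero translation, identity rotation, gripper value 0.5.
demo = np.concatenate(
    [np.zeros((2, 3)), np.tile(np.eye(3).reshape(1, 9), (2, 1)), np.full((2, 1), 0.5)],
    axis=1,
)
translation, rotmat, gripper = split_actions(demo)
print(translation.shape, rotmat.shape, gripper.shape)
print("rotation error per step:", rotation_error(rotmat))
```

With real model output, `split_actions(valid_actions.numpy())` would take the place of the dummy array.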

Summary

Model type: Vision-Language-Action with autoregressive action generation

Model id: you2who/paligemma-fast-bridge

License: Gemma Terms of Use

Model size: 3B params

Tensor type: F32 (Safetensors)
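
As a rough guide to hardware requirements, the stated size and tensor type imply a lower bound on the memory needed for the weights alone. The sketch below is plain arithmetic, not a measurement: "3B" is a rounded figure, and activations, the KV cache, and framework overhead are not included. The function name is hypothetical.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Approximate memory needed to hold the weights alone, in GiB."""
    return num_params * bytes_per_param / 1024**3


params = 3e9  # "3B params" as listed above; the exact count is not stated
f32_gib = weight_memory_gib(params, 4)  # F32 uses 4 bytes per parameter
print(f"F32 weights: roughly {f32_gib:.1f} GiB")
```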