OpenVLA-Phi3V is a vision-language-action model obtained by finetuning the base Phi3V model on the same Open X-Embodiment robot mixture dataset as the original OpenVLA model.
Policy/Settings | Pick up Coke | Move near | Open/Close Drawer | Put in Drawer | Average Success Rate |
---|---|---|---|---|---|
(Visual Matching) OpenVLA-Phi3V | 56.7% | 53.3% | 38.4% | 15.7% | 41.0% |
(Visual Matching) OpenVLA-7B | 23.7% | 65.0% | 57.4% | 0.0% | 36.5% |
(Variant Aggregation) OpenVLA-Phi3V | 55.4% | 57.7% | 19.3% | 10.6% | 35.8% |
(Variant Aggregation) OpenVLA-7B | 61.3% | 55.8% | 24.9% | 1.0% | 35.8% |
Policy/Settings | Put Spoon | Put Carrot | Stack Block | Put Eggplant | Average Success Rate |
---|---|---|---|---|---|
OpenVLA-Phi3V | 12.5% | 0% | 0% | 8.3% | 5.2% |
OpenVLA-7B | 8.3% | 8.3% | 4.2% | 45.8% | 16.7% |
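The "Average Success Rate" column appears to be the unweighted mean of the four task columns. That is an inference from the numbers, not something the tables state. A minimal sketch that recomputes the averages from the per-task values copied out of the tables:

```python
# Per-task success rates (%) copied from the tables, paired with the reported average.
RESULTS = {
    "Visual Matching / OpenVLA-Phi3V": ([56.7, 53.3, 38.4, 15.7], 41.0),
    "Visual Matching / OpenVLA-7B": ([23.7, 65.0, 57.4, 0.0], 36.5),
    "Variant Aggregation / OpenVLA-Phi3V": ([55.4, 57.7, 19.3, 10.6], 35.8),
    "Variant Aggregation / OpenVLA-7B": ([61.3, 55.8, 24.9, 1.0], 35.8),
    "Second table / OpenVLA-Phi3V": ([12.5, 0.0, 0.0, 8.3], 5.2),
    "Second table / OpenVLA-7B": ([8.3, 8.3, 4.2, 45.8], 16.7),
}


def unweighted_mean(values):
    """Plain arithmetic mean; assumes every task counts equally."""
    return sum(values) / len(values)


def check_averages(results, tolerance=0.06):
    """Return {row name: (recomputed mean, reported mean, within tolerance?)}."""
    report = {}
    for name, (task_rates, reported) in results.items():
        recomputed = unweighted_mean(task_rates)
        report[name] = (recomputed, reported, abs(recomputed - reported) <= tolerance)
    return report


if __name__ == "__main__":
    for name, (recomputed, reported, ok) in check_averages(RESULTS).items():
        print(f"{name}: recomputed={recomputed:.2f} reported={reported} ok={ok}")
```

The tolerance allows for rounding to one decimal place in the reported averages.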
Here is sample inference code for OpenVLA-Phi3V.
# Load Processor & VLA
import json
import os

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Placeholders to fill in before running
model_path: str = None  # Local checkpoint directory (must contain dataset_statistics.json)
dataset_name: str = None  # Key of the target dataset in dataset_statistics.json

processor = AutoProcessor.from_pretrained(
    model_path, trust_remote_code=True, num_crops=1
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    _attn_implementation='flash_attention_2',
    use_cache=False,
).cuda()

# Load dataset statistics
dataset_stats_path = os.path.join(model_path, 'dataset_statistics.json')
with open(dataset_stats_path, 'r') as file:
    action_norm_stats = json.load(file)[dataset_name]['action']
model.prepare_action_inference(action_norm_stats, processor.tokenizer.vocab_size)

lang: str = None  # Task language instruction
### IMPORTANT: Make sure image is of size (336,336)
image: Image.Image = None  # Image observation

# Process the prompt & image
prompt_message = {
    'role': 'user',
    'content': f'<|image_1|>\nWhat action should the robot take to {lang}?',
}
prompt = processor.tokenizer.apply_chat_template(
    [prompt_message], tokenize=False, add_generation_prompt=True
)
inputs = processor(prompt, [image], return_tensors='pt').to('cuda')

# Get the action output from model
action = model.predict_action(**inputs)
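The sample code requires a 336×336 image observation. A minimal sketch of a guard that enforces this before calling the processor is shown below. The helper name `ensure_expected_size` is hypothetical and not part of the model's API. It relies only on the `size` attribute and `resize` method that `PIL.Image.Image` provides. A plain resize does not preserve aspect ratio, and whether that matches the preprocessing used during training is an assumption to verify against the repository.

```python
# Hypothetical helper (not part of the OpenVLA-Phi3V API): make sure the
# observation has the size the sample code asks for before it is processed.
EXPECTED_SIZE = (336, 336)  # (width, height), as required by the sample code


def ensure_expected_size(image, expected=EXPECTED_SIZE):
    """Return `image` unchanged if it already has the expected size,
    otherwise return a resized copy.

    `image` is expected to behave like a PIL.Image.Image, i.e. expose
    `.size` as (width, height) and `.resize((width, height))`.
    """
    if tuple(image.size) == tuple(expected):
        return image
    # Plain resize: this may distort the aspect ratio of non-square inputs.
    return image.resize(tuple(expected))
```

With Pillow this could be used as `image = ensure_expected_size(Image.open(path).convert('RGB'))`, where `path` is a placeholder for your own observation file.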
For more examples, including scripts for finetuning OpenVLA-Phi3V models on your own robot demonstration datasets, check out our repository.
If you find our code or models useful in your work, please cite our paper:
@misc{zheng2024tracevlavisualtraceprompting,
title={TraceVLA: Visual Trace Prompting Enhances Spatial-Temporal Awareness for Generalist Robotic Policies},
author={Ruijie Zheng and Yongyuan Liang and Shuaiyi Huang and Jianfeng Gao and Hal Daumé III and Andrey Kolobov and Furong Huang and Jianwei Yang},
year={2024},
eprint={2412.10345},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2412.10345},
}