RoboVLMs model card
Introduction
This repository contains models pre-trained with RoboVLMs, a unified framework for easily building vision-language-action models (VLAs) from vision-language models (VLMs).
We open-source three pre-trained model checkpoints and their configs:
- kosmos_ph_calvin_abcd: RoboKosMos (KosMos + Policy Head) trained on the CALVIN dataset (split ABCD).
- kosmos_ph_calvin_abc: RoboKosMos (KosMos + Policy Head) trained on the CALVIN dataset (split ABC).
- kosmos_ph_oxe-pretrain: RoboKosMos (KosMos + Policy Head) trained on the OXE-magic-soup dataset.
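The usage example in this card loads `configs/kosmos_ph_calvin_abcd.json` and `checkpoints/kosmos_ph_calvin_abcd.pt`. A small helper can pick any of the three checkpoints by name. This is a minimal sketch that assumes the other checkpoints follow the same `configs/<name>.json` / `checkpoints/<name>.pt` layout; `checkpoint_paths`, `load_config` and the `root` parameter are hypothetical names introduced here, not part of RoboVLMs.

```python
import json
from pathlib import Path

# Checkpoint names released with this model card.
CHECKPOINTS = (
    "kosmos_ph_calvin_abcd",
    "kosmos_ph_calvin_abc",
    "kosmos_ph_oxe-pretrain",
)


def checkpoint_paths(name: str, root: str = ".") -> tuple[Path, Path]:
    """Return (config_path, checkpoint_path) for a released checkpoint.

    Hypothetical helper: assumes every checkpoint uses the
    configs/<name>.json and checkpoints/<name>.pt layout shown
    for kosmos_ph_calvin_abcd in the usage example.
    """
    if name not in CHECKPOINTS:
        raise ValueError(f"unknown checkpoint: {name!r}")
    base = Path(root)
    return base / "configs" / f"{name}.json", base / "checkpoints" / f"{name}.pt"


def load_config(name: str, root: str = ".") -> dict:
    """Read the JSON config and point model_load_path at the checkpoint,
    as the usage example does before calling BaseTrainer.from_checkpoint."""
    config_path, ckpt_path = checkpoint_paths(name, root)
    with open(config_path, "r") as f:
        configs = json.load(f)
    configs["model_load_path"] = str(ckpt_path)
    return configs
```

For instance, `load_config("kosmos_ph_calvin_abc")` would return a config dict ready to pass to `BaseTrainer.from_checkpoint`, provided the files exist at the assumed locations.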
Usage
The model predicts actions from vision and language inputs. RoboVLMs supports several VLA structures, multi-view input, and various backbones. Taking kosmos_ph_calvin_abcd as an example:
```python
import json
import functools

import torch
from PIL import Image

from robovlms.train.base_trainer import BaseTrainer
from robovlms.data.data_utils import (
    preprocess_image,
    get_text_function,
    unnoramalize_action,
)

# Load the config and point it at the downloaded checkpoint
with open("configs/kosmos_ph_calvin_abcd.json", "r") as f:
    configs = json.load(f)
pretrained_path = "checkpoints/kosmos_ph_calvin_abcd.pt"
configs["model_load_path"] = pretrained_path

model = BaseTrainer.from_checkpoint(configs)

# Image and text preprocessors matching the model's backbone
image_fn = functools.partial(
    preprocess_image,
    image_processor=model.model.image_processor,
    model_type=configs["model"],
)
text_fn = get_text_function(model.model.tokenizer, configs["model"])

prompt = "Task: pick up the bottle on the table"
text_tensor, attention_mask = text_fn([prompt])

# MAX_STEPS, get_from_side_camera and get_from_wrist_camera are
# placeholders for your own environment / robot interface
for step in range(MAX_STEPS):
    input_dict = {}

    # Static (side) camera view
    image: Image.Image = get_from_side_camera(...)
    image = image_fn([image]).unsqueeze(0)
    input_dict["rgb"] = image
    input_dict["text"] = text_tensor
    input_dict["text_mask"] = attention_mask

    # If a wrist camera is available, add it as a second view
    wrist_image: Image.Image = get_from_wrist_camera(...)
    wrist_image = image_fn([wrist_image]).unsqueeze(0)
    input_dict["hand_rgb"] = wrist_image

    action = model.inference_step(input_dict)["action"]

    # Un-normalize / reproject the action if necessary
    if isinstance(action, tuple):
        action = (
            unnoramalize_action(
                action[0], configs["norm_min"], configs["norm_max"]
            ),
            action[1],
        )
    else:
        action = unnoramalize_action(
            action, configs["norm_min"], configs["norm_max"]
        )
```
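The last step rescales the predicted action using `norm_min` and `norm_max` from the config, and when the model returns a tuple only the first element is rescaled. The sketch below illustrates that logic in pure Python. It assumes, without confirmation from this card, that the model's output lies in [-1, 1], that the mapping back is linear min-max scaling, and that `norm_min`/`norm_max` are scalars; the names `unnormalize` and `postprocess` are hypothetical and this is not the implementation of `unnoramalize_action`. The second tuple element is presumably a separately predicted gripper output, which is passed through unchanged.

```python
from typing import Sequence


def unnormalize(action: Sequence[float], norm_min: float, norm_max: float) -> list[float]:
    """Linearly map each value from [-1, 1] back to [norm_min, norm_max].

    Assumed convention: -1 -> norm_min, +1 -> norm_max.
    """
    scale = norm_max - norm_min
    return [(a + 1.0) / 2.0 * scale + norm_min for a in action]


def postprocess(output, norm_min: float, norm_max: float):
    """Mirror the branching in the usage example: if the model returns a
    tuple, rescale only its first element and keep the second as-is."""
    if isinstance(output, tuple):
        return (unnormalize(output[0], norm_min, norm_max), output[1])
    return unnormalize(output, norm_min, norm_max)


print(postprocess(([-1.0, 0.0, 1.0], 1), -0.5, 0.5))  # ([-0.5, 0.0, 0.5], 1)
```

Keeping the tuple branch separate lets a discrete output (such as a gripper open/close signal) bypass the continuous rescaling.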