RoboVLMs model card

Introduction

This repo contains pre-trained models built with RoboVLMs, a unified framework for easily building vision-language-action models (VLAs) from vision-language models (VLMs).

We open-source three pre-trained model checkpoints and their configs:

  • kosmos_ph_calvin_abcd: RoboKosMos (KosMos + Policy Head) trained on the CALVIN dataset (split ABCD).
  • kosmos_ph_calvin_abc: RoboKosMos (KosMos + Policy Head) trained on the CALVIN dataset (split ABC).
  • kosmos_ph_oxe-pretrain: RoboKosMos (KosMos + Policy Head) trained on the OXE-magic-soup dataset.
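
If the files are hosted on the Hugging Face Hub, they can be fetched programmatically. The sketch below uses huggingface_hub's hf_hub_download; the repo id is a placeholder (this card does not state it), and the file names mirror the paths used in the usage example below.

from huggingface_hub import hf_hub_download

REPO_ID = "<org>/RoboVLMs"  # placeholder: replace with the actual repo id

config_path = hf_hub_download(repo_id=REPO_ID, filename="configs/kosmos_ph_calvin_abcd.json")
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="checkpoints/kosmos_ph_calvin_abcd.pt")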

Usage

The models predict robot actions from vision and language inputs. RoboVLMs supports several VLA architectures, multi-view inputs, and various VLM backbones. Taking kosmos_ph_calvin_abcd as an example:

import torch
import json, functools
from PIL import Image
from robovlms.train.base_trainer import BaseTrainer
from robovlms.data.data_utils import preprocess_image
from robovlms.data.data_utils import get_text_function
from robovlms.data.data_utils import unnoramalize_action

configs = json.load(open('configs/kosmos_ph_calvin_abcd.json', 'r'))
pretrained_path = 'checkpoints/kosmos_ph_calvin_abcd.pt'
configs['model_load_path'] = pretrained_path

# Build the VLA (VLM backbone + policy head) from the config; the pre-trained
# weights are loaded from model_load_path.
model = BaseTrainer.from_checkpoint(configs)

image_fn = functools.partial(
    preprocess_image,
    image_processor=model.model.image_processor,
    model_type=configs["model"],
)
text_fn = get_text_function(model.model.tokenizer, configs["model"])
prompt = "Task: pick up the bottle on the table"
text_tensor, attention_mask = text_fn([prompt])

MAX_STEPS = 100  # maximum number of control steps; set for your environment
input_dict = {}
for step in range(MAX_STEPS):

    # Preprocess the current static-camera frame; unsqueeze adds the batch dimension.
    image: Image.Image = get_from_side_camera(...)
    image = image_fn([image]).unsqueeze(0)

    input_dict["rgb"] = image
    input_dict["text"] = text_tensor
    input_dict["text_mask"] = attention_mask

    # if a wrist camera is available
    wrist_image: Image.Image = get_from_wrist_camera(...)
    wrist_image = image_fn([wrist_image]).unsqueeze(0)
    input_dict["hand_rgb"] = wrist_image

    action = model.inference_step(input_dict)["action"]

    # Unnormalize / reproject the action if necessary. When the policy returns
    # a tuple, only the first element (the arm action) is unnormalized; the
    # second (the gripper command) is used as-is.
    if isinstance(action, tuple):
        action = (
            unnoramalize_action(
                action[0], configs["norm_min"], configs["norm_max"]
            ),
            action[1],
        )
    else:
        action = unnoramalize_action(
            action, configs["norm_min"], configs["norm_max"]
        )
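
For reference, the unnormalization step typically applies a per-dimension min-max mapping from the normalized range back to the robot's native action range. The sketch below shows the standard formula, assuming actions were normalized to [-1, 1] during training; it illustrates the idea and is not necessarily the exact implementation of unnoramalize_action in robovlms.

import numpy as np

def minmax_unnormalize(action, norm_min, norm_max):
    # Map each action dimension from [-1, 1] back to [norm_min, norm_max].
    action = np.asarray(action, dtype=np.float32)
    norm_min = np.asarray(norm_min, dtype=np.float32)
    norm_max = np.asarray(norm_max, dtype=np.float32)
    return (action + 1.0) / 2.0 * (norm_max - norm_min) + norm_min

The resulting action (e.g., a 7-dim end-effector delta plus a gripper command on CALVIN) can then be executed on the robot or in the simulator.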

Evaluation
