Add files using upload-large-folder tool
- capvector-oft/scripts/extern/convert_prismatic_weights_to_hf.py +237 -0
- capvector-oft/training_scripts/training.sh +36 -0
- capvector-oft/vla-scripts/extern/convert_openvla_weights_to_hf.py +272 -0
- capvector-oft/vla-scripts/extern/verify_openvla.py +89 -0
- capvector-oft/vla-scripts/finetune.py +1152 -0
- capvector-oft/vla-scripts/finetune_regular_loss.py +1790 -0
- capvector-oft/vla-scripts/merge_lora_weights_and_save.py +73 -0
- capvector-pi05/.dockerignore +3 -0
- capvector-pi05/.gitignore +169 -0
- capvector-pi05/.gitmodules +6 -0
- capvector-pi05/.pre-commit-config.yaml +16 -0
- capvector-pi05/.python-version +1 -0
- capvector-pi05/LICENSE +201 -0
- capvector-pi05/README.md +128 -0
- capvector-pi05/capvector/apply_param_diff.py +135 -0
- capvector-pi05/capvector/compute_param_diff.py +142 -0
- capvector-pi05/docs/docker.md +25 -0
- capvector-pi05/docs/norm_stats.md +69 -0
- capvector-pi05/docs/remote_inference.md +71 -0
- capvector-pi05/examples/aloha_real/Dockerfile +70 -0
- capvector-pi05/examples/aloha_real/README.md +126 -0
- capvector-pi05/examples/aloha_real/compose.yml +66 -0
- capvector-pi05/examples/aloha_real/constants.py +71 -0
- capvector-pi05/examples/aloha_real/convert_aloha_data_to_lerobot.py +263 -0
- capvector-pi05/examples/aloha_real/env.py +57 -0
- capvector-pi05/examples/aloha_real/main.py +51 -0
- capvector-pi05/examples/aloha_real/real_env.py +176 -0
- capvector-pi05/examples/aloha_real/requirements.in +18 -0
- capvector-pi05/examples/aloha_real/requirements.txt +156 -0
- capvector-pi05/examples/aloha_real/robot_utils.py +275 -0
- capvector-pi05/examples/aloha_real/video_display.py +36 -0
- capvector-pi05/examples/aloha_sim/Dockerfile +41 -0
- capvector-pi05/examples/aloha_sim/README.md +36 -0
- capvector-pi05/examples/aloha_sim/compose.yml +42 -0
- capvector-pi05/examples/aloha_sim/env.py +56 -0
- capvector-pi05/examples/aloha_sim/main.py +55 -0
- capvector-pi05/examples/aloha_sim/requirements.in +8 -0
- capvector-pi05/examples/aloha_sim/requirements.txt +132 -0
- capvector-pi05/examples/aloha_sim/saver.py +40 -0
- capvector-pi05/examples/convert_jax_model_to_pytorch.py +587 -0
- capvector-pi05/examples/droid/README.md +84 -0
- capvector-pi05/examples/droid/README_train.md +106 -0
- capvector-pi05/examples/droid/compute_droid_nonidle_ranges.py +103 -0
- capvector-pi05/examples/droid/convert_droid_data_to_lerobot.py +477 -0
- capvector-pi05/examples/droid/main.py +246 -0
- capvector-pi05/examples/inference.ipynb +137 -0
- capvector-pi05/examples/libero/compose.yml +54 -0
- capvector-pi05/examples/libero/convert_libero_data_to_lerobot.py +104 -0
- capvector-pi05/examples/policy_records.ipynb +134 -0
- capvector-pi05/pyproject.toml +142 -0
capvector-oft/scripts/extern/convert_prismatic_weights_to_hf.py
ADDED
@@ -0,0 +1,237 @@
"""
convert_prismatic_weights_to_hf.py

Utility script for converting full Prismatic VLM weights (from this repository, in the default "Prismatic" format) to
the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers`
via `trust_remote_code = True`.

Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the
line, with first-class support.
"""

import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Union

import draccus
import timm
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from timm.models.vision_transformer import LayerScale
from transformers import AutoTokenizer

from prismatic.extern.hf.configuration_prismatic import PrismaticConfig
from prismatic.extern.hf.modeling_prismatic import PrismaticForConditionalGeneration
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor


@dataclass
class HFConvertConfig:
    # fmt: off
    prismatic_model_path_or_id: Union[str, Path] = (    # Path to Pretrained VLM (on disk or HF Hub)
        "siglip-224px+7b"
        # "prism-dinosiglip-224px+7b"
    )
    output_hf_model_local_path: Path = Path(            # Path to Local Path to save HF model
        "hf-convert/prismatic-siglip-224px-7b"
    )
    output_hf_model_hub_path: str = (                   # Path to HF Hub Path for "final" HF model
        "TRI-ML/prismatic-siglip-224px-7b"              # => huggingface.co/TRI-ML/prismatic-{...}
    )

    # HF Hub Credentials (required for Gated Models like LLaMa-2)
    hf_token: Union[str, Path] = Path(".hf_token")      # Environment variable or Path to HF Token

    def __post_init__(self) -> None:
        self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token

    # fmt: on


# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma


# === Conversion Constants ===
PROJECTOR_KEY_MAPPING = {
    "projector.0.weight": "projector.fc1.weight",
    "projector.0.bias": "projector.fc1.bias",
    "projector.2.weight": "projector.fc2.weight",
    "projector.2.bias": "projector.fc2.bias",
    "projector.4.weight": "projector.fc3.weight",
    "projector.4.bias": "projector.fc3.bias",
}


def remap_state_dicts_for_hf(
    projector_state_dict: Dict[str, torch.Tensor],
    llm_backbone_state_dict: Dict[str, torch.Tensor],
    vision_backbone_state_dicts: List[Dict[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion."""
    hf_state_dict = {}

    # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING`
    for key, value in projector_state_dict.items():
        hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value

    # Iterate through LLM Backbone =>> replace `llm.` with `language_model.`
    for key, value in llm_backbone_state_dict.items():
        hf_state_dict[key.replace("llm.", "language_model.")] = value

    # Iterate through Vision Backbone =>> add "vision_backbone." prefix
    assert len(vision_backbone_state_dicts) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
    for idx, vision_backbone_state_dict in enumerate(vision_backbone_state_dicts):
        prefix = "vision_backbone.featurizer" if idx == 0 else "vision_backbone.fused_featurizer"
        for key, value in vision_backbone_state_dict.items():
            hf_state_dict[f"{prefix}.{key}"] = value

    return hf_state_dict


@draccus.wrap()
def convert_prismatic_weights_to_hf(cfg: HFConvertConfig) -> None:
    print(f"[*] Converting Prismatic Model `{cfg.prismatic_model_path_or_id}` to HF Transformers Format")
    torch.set_default_dtype(torch.bfloat16)

    # Get `config.json` and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py`
    if os.path.isdir(cfg.prismatic_model_path_or_id):
        print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.prismatic_model_path_or_id))}`")
        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"

        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
    else:
        print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.prismatic_model_path_or_id}`")
        config_json = hf_hub_download("TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/config.json")
        checkpoint_pt = hf_hub_download(
            "TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/checkpoints/latest-checkpoint.pt"
        )

    # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
    with open(config_json, "r") as f:
        prismatic_config = json.load(f)["model"]

    # Create HF PrismaticConfig (`transformers.PretrainedConfig`)
    hf_config = PrismaticConfig(
        vision_backbone_id=prismatic_config["vision_backbone_id"],
        llm_backbone_id=prismatic_config["llm_backbone_id"],
        arch_specifier=prismatic_config["arch_specifier"],
        image_resize_strategy=prismatic_config["image_resize_strategy"],
        llm_max_length=prismatic_config["llm_max_length"],
        torch_dtype=torch.bfloat16,
    )

    # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
    #   TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
    print("[*] Instantiating and Patching Tokenizer, LLM Config")
    tokenizer = AutoTokenizer.from_pretrained(
        hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
    )
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    tokenizer.init_kwargs.pop("add_prefix_space", None)  # Pop to prevent unnecessary warning on reload...
    assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
    assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"

    # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
    hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
    hf_config.text_config.pad_token_id = hf_config.pad_token_id
    hf_config.text_config.torch_dtype = torch.bfloat16
    assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"

    # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform`
    #   =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py`
    print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor")
    timm_vision_backbones, input_sizes, interpolations, means, stds = [], [], [], [], []
    for idx, timm_model_id in enumerate(hf_config.timm_model_ids):
        timm_vision_backbone = timm.create_model(
            timm_model_id,
            pretrained=True,
            num_classes=0,
            img_size=hf_config.image_sizes[idx],
            act_layer=hf_config.timm_override_act_layers[idx],
        )
        timm_vision_backbones.append(timm_vision_backbone)

        # Get Per-Backbone Image Processing
        data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone)
        input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx]))
        interpolations.append(data_cfg["interpolation"])
        means.append(data_cfg["mean"])
        stds.append(data_cfg["std"])

        # Patch `LayerScale` because of HF annoying `fix_key` overwrite...
        for module in timm_vision_backbone.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

    # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`)
    hf_image_processor = PrismaticImageProcessor(
        use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
        image_resize_strategy=hf_config.image_resize_strategy,
        input_sizes=input_sizes,
        interpolations=interpolations,
        means=means,
        stds=stds,
    )

    # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor)
    print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor")
    hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer)

    # Load Prismatic Model State Dictionary (in preparation for conversion)
    print("[*] Loading Prismatic VLM State Dictionary from Checkpoint")
    model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
    assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?"
    assert ("projector" in model_state_dict) and ("llm_backbone" in model_state_dict), "Missing keys!"

    # Convert
    print("[*] Running Conversion")
    converted_state_dict = remap_state_dicts_for_hf(
        model_state_dict["projector"],
        model_state_dict["llm_backbone"],
        vision_backbone_state_dicts=[vb.state_dict() for vb in timm_vision_backbones],
    )

    # Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM
    print("[*] Building (Randomly Initialized) Model =>> PrismaticForConditionalGeneration")
    hf_model = PrismaticForConditionalGeneration(hf_config)
    hf_model.load_state_dict(converted_state_dict, strict=True, assign=True)

    # Cast Model to BF16 before Saving
    hf_model.to(torch.bfloat16)

    # Save Pretrained Versions to Local Path
    print("[*] Saving Model & Processor to Local Path")
    hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB")
    hf_image_processor.save_pretrained(cfg.output_hf_model_local_path)
    hf_processor.save_pretrained(cfg.output_hf_model_local_path)

    # Register AutoClasses
    PrismaticConfig.register_for_auto_class()
    PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor")
    PrismaticProcessor.register_for_auto_class("AutoProcessor")
    PrismaticForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")

    # Push to Hub
    print("[*] Pushing Model & Processor to HF Hub")
    hf_config.push_to_hub(cfg.output_hf_model_hub_path)
    hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB")
    hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
    hf_processor.push_to_hub(cfg.output_hf_model_hub_path)


if __name__ == "__main__":
    convert_prismatic_weights_to_hf()
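A minimal usage sketch for the script above, assuming the default paths in `HFConvertConfig`: the entry point is wrapped with `draccus`, so every config field can be overridden from the command line, and the converted checkpoint then loads through the registered AutoClasses with `trust_remote_code=True`. The paths here are illustrative placeholders, not checkpoints shipped with this commit.

# Hypothetical invocation (field names mirror HFConvertConfig above)
#   python scripts/extern/convert_prismatic_weights_to_hf.py \
#       --prismatic_model_path_or_id "siglip-224px+7b" \
#       --output_hf_model_local_path "hf-convert/prismatic-siglip-224px-7b"

# Loading the converted checkpoint afterwards (assumed output path from the run above)
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

local_path = "hf-convert/prismatic-siglip-224px-7b"
processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
vlm = AutoModelForVision2Seq.from_pretrained(local_path, torch_dtype=torch.bfloat16, trust_remote_code=True)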
capvector-oft/training_scripts/training.sh
ADDED
@@ -0,0 +1,36 @@
VERSION="v0"
TASK="10"  # spatial / object / goal / 10 / 90
VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_spatial_v0.4.2"
DATA_ROOT_DIR="data/libero_openvla"
RUN_ROOT_DIR="experiments/training_results"
REGULARIZATION_LORA_VECTOR_PATH="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors"
WANDB_ENTITY="YOUR_WANDB_ENTITY"
WANDB_PROJECT="YOUR_WANDB_PROJECT"
EVAL_LOG_PATH="experiments/eval_logs/${VERSION}_output.log"

torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune_regular_loss.py \
  --vla_path "$VLA_PATH" \
  --data_root_dir "$DATA_ROOT_DIR" \
  --dataset_name libero_${TASK}_no_noops \
  --run_root_dir "$RUN_ROOT_DIR" \
  --use_l1_regression True \
  --use_diffusion False \
  --use_film False \
  --num_images_in_input 2 \
  --use_proprio True \
  --batch_size 8 \
  --learning_rate 5e-4 \
  --scheduler CosineAnnealingLR \
  --max_steps 150100 \
  --save_freq 150000 \
  --save_latest_checkpoint_only True \
  --merge_lora_during_training True \
  --regularization_lora_vector_path "$REGULARIZATION_LORA_VECTOR_PATH" \
  --regularization_weight 1e-4 \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$WANDB_ENTITY" \
  --wandb_project "$WANDB_PROJECT" \
  --run_id_override "$VERSION"

python experiments/robot/libero/run_libero_eval.py --pretrained_checkpoint "$RUN_ROOT_DIR/$VERSION" --task_suite_name libero_${TASK} > "$EVAL_LOG_PATH" 2>&1
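If `--merge_lora_during_training` is disabled (the finetune config notes that merging can be very slow on some machines), the adapter can be merged offline afterwards; the repository also ships `merge_lora_weights_and_save.py` for that purpose. Below is a minimal sketch with `peft`, assuming the adapter was saved under the run directory used above; the exact directory layout is an assumption and may differ from what the training script writes.

import torch
from peft import PeftModel
from transformers import AutoModelForVision2Seq

base_vla = "checkpoints/initialized_pt_vla/initailized_openvla_with_SF_spatial_v0.4.2"  # base model from VLA_PATH above
adapter_dir = "experiments/training_results/v0"  # assumed location of the saved LoRA adapter

# Load the frozen base VLA, attach the LoRA adapter, then fold the low-rank deltas into the base weights
base = AutoModelForVision2Seq.from_pretrained(base_vla, torch_dtype=torch.bfloat16, trust_remote_code=True)
merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
merged.save_pretrained(f"{adapter_dir}/merged")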
capvector-oft/vla-scripts/extern/convert_openvla_weights_to_hf.py
ADDED
@@ -0,0 +1,272 @@
"""
convert_openvla_weights_to_hf.py

Utility script for converting full OpenVLA VLA weights (from this repository, in the default "Prismatic" format) to
the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers`
via `trust_remote_code = True`.

Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the
line, with first-class support.

Usage:
    python vla-scripts/extern/convert_openvla_weights_to_hf.py \
        --openvla_model_path_or_id <PATH TO PRISMATIC TRAINING RUN DIR> \
        --output_hf_model_local_path <OUTPUT DIR FOR CONVERTED CHECKPOINT>
"""

import json
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Union

import draccus
import timm
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from timm.models.vision_transformer import LayerScale
from transformers import AutoTokenizer

from prismatic.conf import ModelConfig
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor


@dataclass
class HFConvertConfig:
    # fmt: off
    openvla_model_path_or_id: Union[str, Path] = (       # Path to Pretrained VLA (on disk or HF Hub)
        "runs/prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7"
    )
    output_hf_model_local_path: Path = Path(             # Path to Local Path to save HF model
        "hf-convert/openvla-7b"
    )
    output_hf_model_hub_path: str = "openvla/openvla-7b" # (Optional) Path to HF Hub Path to push
                                                         #   model to

    # HF Hub Credentials (required for Gated Models like LLaMa-2)
    hf_token: Union[str, Path] = Path(".hf_token")       # Environment variable or Path to HF Token

    def __post_init__(self) -> None:
        self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token

    # fmt: on


# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma


# === Conversion Constants ===
PROJECTOR_KEY_MAPPING = {
    "projector.0.weight": "projector.fc1.weight",
    "projector.0.bias": "projector.fc1.bias",
    "projector.2.weight": "projector.fc2.weight",
    "projector.2.bias": "projector.fc2.bias",
    "projector.4.weight": "projector.fc3.weight",
    "projector.4.bias": "projector.fc3.bias",
}


def remap_state_dicts_for_hf(
    prismatic_vision_backbone_state_dict: Dict[str, torch.Tensor],
    projector_state_dict: Dict[str, torch.Tensor],
    llm_backbone_state_dict: Dict[str, torch.Tensor],
    use_fused_vision_backbone: bool = False,
) -> Dict[str, torch.Tensor]:
    """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion."""
    hf_state_dict = {}

    # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING`
    for key, value in projector_state_dict.items():
        hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value

    # Iterate through LLM Backbone =>> replace `llm.` with `language_model.`
    for key, value in llm_backbone_state_dict.items():
        hf_state_dict[key.replace("llm.", "language_model.")] = value

    # Iterate through Vision Backbone =>> add "vision_backbone." prefix
    if not use_fused_vision_backbone:
        for key, value in prismatic_vision_backbone_state_dict.items():
            hf_state_dict[key.replace("featurizer.", "vision_backbone.featurizer.")] = value
    else:
        # Note =>> Assumes that backbones are always DINO + SigLIP...
        for key, value in prismatic_vision_backbone_state_dict.items():
            if key.startswith("dino_featurizer"):
                if key.endswith(".gamma"):
                    # Handle `LayerScale gamma` =>> DINOv2 only!
                    key = key.replace(".gamma", ".scale_factor")
                hf_state_dict[key.replace("dino_featurizer.", "vision_backbone.featurizer.")] = value
            elif key.startswith("siglip_featurizer"):
                hf_state_dict[key.replace("siglip_featurizer.", "vision_backbone.fused_featurizer.")] = value

    return hf_state_dict


@draccus.wrap()
def convert_openvla_weights_to_hf(cfg: HFConvertConfig) -> None:
    print(f"[*] Converting OpenVLA Model `{cfg.openvla_model_path_or_id}` to HF Transformers Format")
    torch.set_default_dtype(torch.bfloat16)

    # Get `config.json`, 'dataset_statistics.json' and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py`
    if os.path.isdir(cfg.openvla_model_path_or_id):
        print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.openvla_model_path_or_id))}`")
        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
        dataset_statistics_json = run_dir / "dataset_statistics.json"

        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
    else:
        print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.openvla_model_path_or_id}`")
        config_json = hf_hub_download("openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/config.json")
        checkpoint_pt = hf_hub_download(
            "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/checkpoints/latest-checkpoint.pt"
        )
        dataset_statistics_json = hf_hub_download(
            "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/dataset_statistics.json"
        )

    # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
    with open(config_json, "r") as f:
        vla_cfg = json.load(f)["vla"]
        prismatic_config = ModelConfig.get_choice_class(vla_cfg["base_vlm"])().__dict__

    # Load Normalization Statistics
    with open(dataset_statistics_json, "r") as f:
        norm_stats = json.load(f)

    # Create HF OpenVLAConfig (`transformers.PretrainedConfig`)
    hf_config = OpenVLAConfig(
        vision_backbone_id=prismatic_config["vision_backbone_id"],
        llm_backbone_id=prismatic_config["llm_backbone_id"],
        arch_specifier=prismatic_config["arch_specifier"],
        image_resize_strategy=prismatic_config["image_resize_strategy"],
        llm_max_length=prismatic_config["llm_max_length"],
        torch_dtype=torch.bfloat16,
        norm_stats=norm_stats,
    )

    # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
    #   TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
    print("[*] Instantiating and Patching Tokenizer, LLM Config")
    tokenizer = AutoTokenizer.from_pretrained(
        hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
    )
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    tokenizer.init_kwargs.pop("add_prefix_space", None)  # Pop to prevent unnecessary warning on reload...
    assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
    assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"

    # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
    hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
    hf_config.text_config.pad_token_id = hf_config.pad_token_id
    hf_config.text_config.torch_dtype = torch.bfloat16
    assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"

    # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform`
    #   =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py`
    print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor")
    input_sizes, interpolations, means, stds = [], [], [], []
    for idx, timm_model_id in enumerate(hf_config.timm_model_ids):
        timm_vision_backbone = timm.create_model(
            timm_model_id,
            pretrained=True,
            num_classes=0,
            img_size=hf_config.image_sizes[idx],
            act_layer=hf_config.timm_override_act_layers[idx],
        )

        # Get Per-Backbone Image Processing
        data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone)
        input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx]))
        interpolations.append(data_cfg["interpolation"])
        means.append(data_cfg["mean"])
        stds.append(data_cfg["std"])

        # Patch `LayerScale` because of HF annoying `fix_key` overwrite...
        for module in timm_vision_backbone.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

    # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`)
    hf_image_processor = PrismaticImageProcessor(
        use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
        image_resize_strategy=hf_config.image_resize_strategy,
        input_sizes=input_sizes,
        interpolations=interpolations,
        means=means,
        stds=stds,
    )

    # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor)
    print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor")
    hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer)

    # Load Prismatic Model State Dictionary (in preparation for conversion)
    print("[*] Loading Prismatic VLM State Dictionary from Checkpoint")
    model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
    assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?"
    assert all([k in model_state_dict for k in ["vision_backbone", "projector", "llm_backbone"]]), "Missing keys!"

    # Convert
    print("[*] Running Conversion")
    converted_state_dict = remap_state_dicts_for_hf(
        model_state_dict["vision_backbone"],
        model_state_dict["projector"],
        model_state_dict["llm_backbone"],
        use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
    )

    # Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM
    print("[*] Building (Randomly Initialized) Model =>> OpenVLAForActionPrediction")
    hf_model = OpenVLAForActionPrediction(hf_config)
    hf_model.load_state_dict(converted_state_dict, strict=True, assign=True)

    # Cast Model to BF16 before Saving
    hf_model.to(torch.bfloat16)

    # Save Pretrained Versions to Local Path
    print("[*] Saving Model & Processor to Local Path")
    hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB")
    hf_image_processor.save_pretrained(cfg.output_hf_model_local_path)
    hf_processor.save_pretrained(cfg.output_hf_model_local_path)

    # Copy `dataset_statistics.json` File to Converted Checkpoint Directory
    output_dataset_statistics_json = cfg.output_hf_model_local_path / "dataset_statistics.json"
    shutil.copyfile(dataset_statistics_json, output_dataset_statistics_json)

    print(f"[*] Saving Complete! Saved converted checkpoint to: {cfg.output_hf_model_local_path}")

    #####################################################################################
    # Optional: Push Model to Hugging Face Hub
    #####################################################################################

    # # Register AutoClasses
    # OpenVLAConfig.register_for_auto_class()
    # PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor")
    # PrismaticProcessor.register_for_auto_class("AutoProcessor")
    # OpenVLAForActionPrediction.register_for_auto_class("AutoModelForVision2Seq")

    # # Push to HF Hub
    # print("[*] Pushing Model & Processor to HF Hub")
    # hf_config.push_to_hub(cfg.output_hf_model_hub_path)
    # hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB")
    # hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
    # hf_processor.push_to_hub(cfg.output_hf_model_hub_path)


if __name__ == "__main__":
    convert_openvla_weights_to_hf()
capvector-oft/vla-scripts/extern/verify_openvla.py
ADDED
@@ -0,0 +1,89 @@
"""
verify_openvla.py

Given an HF-exported OpenVLA model, attempt to load via AutoClasses, and verify forward() and predict_action().
"""

import time

import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# === Verification Arguments
MODEL_PATH = "openvla/openvla-7b"
SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)
INSTRUCTION = "put spoon on towel"


def get_openvla_prompt(instruction: str) -> str:
    if "v01" in MODEL_PATH:
        return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT:"
    else:
        return f"In: What action should the robot take to {instruction.lower()}?\nOut:"


@torch.inference_mode()
def verify_openvla() -> None:
    print(f"[*] Verifying OpenVLAForActionPrediction using Model `{MODEL_PATH}`")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load Processor & VLA
    print("[*] Instantiating Processor and Pretrained OpenVLA")
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

    # === BFLOAT16 + FLASH-ATTN MODE ===
    print("[*] Loading in BF16 with Flash-Attention Enabled")
    vla = AutoModelForVision2Seq.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device)

    # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] ===
    # print("[*] Loading in 8-Bit Quantization Mode")
    # vla = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] ===
    # print("[*] Loading in 4-Bit Quantization Mode")
    # vla = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    print("[*] Iterating with Randomly Generated Images")
    for _ in range(100):
        prompt = get_openvla_prompt(INSTRUCTION)
        image = Image.fromarray(np.asarray(np.random.rand(256, 256, 3) * 255, dtype=np.uint8))

        # === BFLOAT16 MODE ===
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

        # === 8-BIT/4-BIT QUANTIZATION MODE ===
        # inputs = processor(prompt, image).to(device, dtype=torch.float16)

        # Run OpenVLA Inference
        start_time = time.time()
        action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
        print(f"\t=>> Time: {time.time() - start_time:.4f} || Action: {action}")


if __name__ == "__main__":
    verify_openvla()
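The loader above assumes `flash-attn` is installed; where it is not available, the same checkpoint can be loaded with PyTorch's built-in SDPA attention as a fallback sketch, otherwise identical to the call in the script.

import torch
from transformers import AutoModelForVision2Seq

# Same load as in verify_openvla(), but with the SDPA attention backend instead of flash_attention_2
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)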
capvector-oft/vla-scripts/finetune.py
ADDED
@@ -0,0 +1,1152 @@
| 1 |
+
"""
|
| 2 |
+
finetune.py
|
| 3 |
+
|
| 4 |
+
Fine-tunes OpenVLA via LoRA.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
from collections import deque
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Dict, Optional, Tuple, Type
|
| 13 |
+
|
| 14 |
+
import draccus
|
| 15 |
+
import torch
|
| 16 |
+
import torch.distributed as dist
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import tqdm
|
| 19 |
+
from accelerate import PartialState
|
| 20 |
+
from huggingface_hub import HfApi, snapshot_download
|
| 21 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 23 |
+
from torch.optim import AdamW
|
| 24 |
+
from torch.optim.lr_scheduler import MultiStepLR
|
| 25 |
+
from torch.utils.data import DataLoader
|
| 26 |
+
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
|
| 27 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 28 |
+
|
| 29 |
+
import wandb
|
| 30 |
+
os.environ["WANDB_MODE"]="offline"
|
| 31 |
+
|
| 32 |
+
from experiments.robot.openvla_utils import (
|
| 33 |
+
check_model_logic_mismatch,
|
| 34 |
+
model_is_on_hf_hub,
|
| 35 |
+
update_auto_map,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| 39 |
+
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| 40 |
+
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| 41 |
+
from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead
|
| 42 |
+
from prismatic.models.backbones.llm.prompting import PurePromptBuilder
|
| 43 |
+
from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone
|
| 44 |
+
from prismatic.models.projectors import (
|
| 45 |
+
NoisyActionProjector,
|
| 46 |
+
ProprioProjector,
|
| 47 |
+
)
|
| 48 |
+
from prismatic.training.train_utils import (
|
| 49 |
+
compute_actions_l1_loss,
|
| 50 |
+
compute_token_accuracy,
|
| 51 |
+
get_current_action_mask,
|
| 52 |
+
get_next_actions_mask,
|
| 53 |
+
)
|
| 54 |
+
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
|
| 55 |
+
from prismatic.vla.action_tokenizer import ActionTokenizer
|
| 56 |
+
from prismatic.vla.constants import (
|
| 57 |
+
ACTION_DIM,
|
| 58 |
+
ACTION_PROPRIO_NORMALIZATION_TYPE,
|
| 59 |
+
NUM_ACTIONS_CHUNK,
|
| 60 |
+
PROPRIO_DIM,
|
| 61 |
+
)
|
| 62 |
+
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
|
| 63 |
+
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
|
| 64 |
+
|
| 65 |
+
# Sane Defaults
|
| 66 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
import debugpy
|
| 70 |
+
try:
|
| 71 |
+
debugpy.listen(("localhost", 9501))
|
| 72 |
+
print("Waiting for debugger attach")
|
| 73 |
+
debugpy.wait_for_client()
|
| 74 |
+
except Exception as e:
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass
|
| 79 |
+
class FinetuneConfig:
|
| 80 |
+
# fmt: off
|
| 81 |
+
vla_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally)
|
| 82 |
+
|
| 83 |
+
# Dataset
|
| 84 |
+
data_root_dir: Path = Path("datasets/rlds") # Directory containing RLDS datasets
|
| 85 |
+
dataset_name: str = "aloha_scoop_x_into_bowl" # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`)
|
| 86 |
+
run_root_dir: Path = Path("runs") # Path to directory to store logs & checkpoints
|
| 87 |
+
shuffle_buffer_size: int = 100_000 # Dataloader shuffle buffer size (can reduce if OOM errors occur)
|
| 88 |
+
|
| 89 |
+
# Algorithm and architecture
|
| 90 |
+
use_l1_regression: bool = True # If True, trains continuous action head with L1 regression objective
|
| 91 |
+
use_diffusion: bool = False # If True, trains continuous action head with diffusion modeling objective (DDIM)
|
| 92 |
+
num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training
|
| 93 |
+
use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
|
| 94 |
+
num_images_in_input: int = 1 # Number of images in the VLA input (default: 1)
|
| 95 |
+
use_proprio: bool = False # If True, includes robot proprioceptive state in input
|
| 96 |
+
|
| 97 |
+
# Training configuration
|
| 98 |
+
batch_size: int = 8 # Batch size per device (total batch size = batch_size * num GPUs)
|
| 99 |
+
learning_rate: float = 5e-4 # Learning rate
|
| 100 |
+
lr_warmup_steps: int = 0 # Number of steps to warm up learning rate (from 10% to 100%)
|
| 101 |
+
num_steps_before_decay: int = 100_000 # Number of steps before LR decays by 10x
|
| 102 |
+
grad_accumulation_steps: int = 1 # Number of gradient accumulation steps
|
| 103 |
+
max_steps: int = 200_000 # Max number of training steps
|
| 104 |
+
use_val_set: bool = False # If True, uses validation set and log validation metrics
|
| 105 |
+
val_freq: int = 10_000 # (When `use_val_set==True`) Validation set logging frequency in steps
|
| 106 |
+
val_time_limit: int = 180 # (When `use_val_set==True`) Time limit for computing validation metrics
|
| 107 |
+
save_freq: int = 10_000 # Checkpoint saving frequency in steps
|
| 108 |
+
save_latest_checkpoint_only: bool = False # If True, saves only 1 checkpoint, overwriting latest checkpoint
|
| 109 |
+
# (If False, saves all checkpoints)
|
| 110 |
+
resume: bool = False # If True, resumes from checkpoint
|
| 111 |
+
resume_step: Optional[int] = None # (When `resume==True`) Step number that we are resuming from
|
| 112 |
+
image_aug: bool = True # If True, trains with image augmentations (HIGHLY RECOMMENDED)
|
| 113 |
+
diffusion_sample_freq: int = 50 # (When `use_diffusion==True`) Frequency for sampling in steps
|
| 114 |
+
|
| 115 |
+
# LoRA
|
| 116 |
+
use_lora: bool = True # If True, uses LoRA fine-tuning
|
| 117 |
+
lora_rank: int = 32 # Rank of LoRA weight matrix
|
| 118 |
+
lora_dropout: float = 0.0 # Dropout applied to LoRA weights
|
| 119 |
+
merge_lora_during_training: bool = True # If True, merges LoRA weights and saves result during training
|
| 120 |
+
# Note: Merging can be very slow on some machines. If so, set to
|
| 121 |
+
# False and merge final checkpoint offline!
|
| 122 |
+
|
| 123 |
+
# Logging
|
| 124 |
+
wandb_entity: str = "your-wandb-entity" # Name of WandB entity
|
| 125 |
+
wandb_project: str = "your-wandb-project" # Name of WandB project
|
| 126 |
+
run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
|
| 127 |
+
run_id_override: Optional[str] = None # Optional string to override the run ID with
|
| 128 |
+
wandb_log_freq: int = 10 # WandB logging frequency in steps
|
| 129 |
+
|
| 130 |
+
# fmt: on
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def remove_ddp_in_checkpoint(state_dict) -> dict:
|
| 134 |
+
"""
|
| 135 |
+
Removes the 'module.' prefix from parameter names in a PyTorch model state dictionary that was saved using
|
| 136 |
+
DistributedDataParallel (DDP).
|
| 137 |
+
|
| 138 |
+
When a model is trained using PyTorch's DistributedDataParallel, the saved state dictionary contains parameters
|
| 139 |
+
prefixed with 'module.'. This function removes these prefixes to make the state dictionary compatible when
|
| 140 |
+
loading into models that are not yet wrapped in DDP.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
state_dict (dict): PyTorch model state dictionary.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
dict: A new state dictionary with the same contents but with 'module.' prefixes removed from parameter names.
|
| 147 |
+
Parameters without the 'module.' prefix remain unchanged.
|
| 148 |
+
"""
|
| 149 |
+
new_state_dict = {}
|
| 150 |
+
for k, v in state_dict.items():
|
| 151 |
+
if k[:7] == "module.":
|
| 152 |
+
new_state_dict[k[7:]] = v
|
| 153 |
+
else:
|
| 154 |
+
new_state_dict[k] = v
|
| 155 |
+
return new_state_dict
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_run_id(cfg) -> str:
|
| 159 |
+
"""
|
| 160 |
+
Generates or retrieves an identifier string for an experiment run.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
cfg (FinetuneConfig): Training configuration.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
str: Experiment run ID.
|
| 167 |
+
"""
|
| 168 |
+
if cfg.run_id_override is not None:
|
| 169 |
+
# Override the run ID with the user-provided ID
|
| 170 |
+
run_id = cfg.run_id_override
|
| 171 |
+
elif cfg.resume:
|
| 172 |
+
# Override run ID with the previous resumed run's ID
|
| 173 |
+
run_id = cfg.vla_path.split("/")[-1]
|
| 174 |
+
# Remove the "--XXX_chkpt" suffix from the run ID if it exists
|
| 175 |
+
if "chkpt" in run_id.split("--")[-1]:
|
| 176 |
+
run_id = "--".join(run_id.split("--")[:-1])
|
| 177 |
+
else:
|
| 178 |
+
run_id = (
|
| 179 |
+
f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}"
|
| 180 |
+
f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
|
| 181 |
+
f"+lr-{cfg.learning_rate}"
|
| 182 |
+
)
|
| 183 |
+
if cfg.use_lora:
|
| 184 |
+
run_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
|
| 185 |
+
if cfg.image_aug:
|
| 186 |
+
run_id += "--image_aug"
|
| 187 |
+
if cfg.run_id_note is not None:
|
| 188 |
+
run_id += f"--{cfg.run_id_note}"
|
| 189 |
+
return run_id
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
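With the non-resume branch above, the run ID concatenates model name, dataset, effective batch size, and learning rate, plus LoRA and augmentation tags. Using hypothetical values (vla_path="openvla/openvla-7b", dataset_name="aloha_task", batch_size=8, grad_accumulation_steps=1, learning_rate=5e-4, lora_rank=32, lora_dropout=0.0, image_aug=True), the result would be:

```python
# Hypothetical run ID produced by get_run_id (all values are illustrative only)
run_id = "openvla-7b+aloha_task+b8+lr-0.0005+lora-r32+dropout-0.0--image_aug"
```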
def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict:
|
| 193 |
+
"""
|
| 194 |
+
Loads a checkpoint for a given module.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
module_name (str): Name of model component to load checkpoint for.
|
| 198 |
+
path (str): Path to checkpoint directory.
|
| 199 |
+
step (int): Gradient step number of saved checkpoint.
|
| 200 |
+
device (str): String specifying how to remap storage locations (default = "cpu").
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
dict: PyTorch model state dictionary.
|
| 204 |
+
"""
|
| 205 |
+
checkpoint_path = os.path.join(path, f"{module_name}--{step}_checkpoint.pt")
|
| 206 |
+
print(f"Loading checkpoint: {checkpoint_path}")
|
| 207 |
+
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
|
| 208 |
+
return remove_ddp_in_checkpoint(state_dict)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP:
|
| 212 |
+
"""
|
| 213 |
+
Wrap a module with DistributedDataParallel.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
module (nn.Module): PyTorch module.
|
| 217 |
+
        device_id (int): Device ID.
|
| 218 |
+
find_unused (bool): Whether to detect parameters without gradients in distributed training.
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
DistributedDataParallel: PyTorch module wrapped with DDP.
|
| 222 |
+
"""
|
| 223 |
+
return DDP(module, device_ids=[device_id], find_unused_parameters=find_unused, gradient_as_bucket_view=True)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def count_parameters(module: nn.Module, name: str) -> None:
|
| 227 |
+
"""
|
| 228 |
+
Counts and prints the number of trainable parameters in a module.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
module (nn.Module): PyTorch module.
|
| 232 |
+
        name (str): Name of model component.
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
None.
|
| 236 |
+
"""
|
| 237 |
+
num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
| 238 |
+
print(f"# trainable params in {name}: {num_params}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def init_module(
|
| 242 |
+
module_class: Type[nn.Module],
|
| 243 |
+
module_name: str,
|
| 244 |
+
cfg: FinetuneConfig,
|
| 245 |
+
device_id: int,
|
| 246 |
+
module_args: dict,
|
| 247 |
+
to_bf16: bool = False,
|
| 248 |
+
find_unused_params: bool = False,
|
| 249 |
+
) -> DDP:
|
| 250 |
+
"""
|
| 251 |
+
Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
module_class (Type[nn.Module]): Class of PyTorch module to initialize.
|
| 255 |
+
module_name (str): Name of model component to load checkpoint for.
|
| 256 |
+
cfg (FinetuneConfig): Training configuration.
|
| 257 |
+
        device_id (int): Device ID.
|
| 258 |
+
module_args (dict): Args for initializing the module.
|
| 259 |
+
to_bf16 (bool): Whether to convert to torch.bfloat16 data type.
|
| 260 |
+
find_unused_params (bool): Whether to detect parameters without gradients in distributed training.
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
DistributedDataParallel: PyTorch module wrapped with DDP.
|
| 264 |
+
"""
|
| 265 |
+
module = module_class(**module_args)
|
| 266 |
+
count_parameters(module, module_name)
|
| 267 |
+
|
| 268 |
+
if cfg.resume:
|
| 269 |
+
state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step)
|
| 270 |
+
module.load_state_dict(state_dict)
|
| 271 |
+
|
| 272 |
+
if to_bf16:
|
| 273 |
+
module = module.to(torch.bfloat16)
|
| 274 |
+
module = module.to(device_id)
|
| 275 |
+
|
| 276 |
+
return wrap_ddp(module, device_id, find_unused_params)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def run_forward_pass(
|
| 280 |
+
vla,
|
| 281 |
+
action_head,
|
| 282 |
+
noisy_action_projector,
|
| 283 |
+
proprio_projector,
|
| 284 |
+
batch,
|
| 285 |
+
action_tokenizer,
|
| 286 |
+
device_id,
|
| 287 |
+
use_l1_regression,
|
| 288 |
+
use_diffusion,
|
| 289 |
+
use_proprio,
|
| 290 |
+
use_film,
|
| 291 |
+
num_patches,
|
| 292 |
+
compute_diffusion_l1=False,
|
| 293 |
+
num_diffusion_steps_train=None,
|
| 294 |
+
) -> Tuple[torch.Tensor, Dict[str, float]]:
|
| 295 |
+
"""
|
| 296 |
+
Compute model forward pass and metrics for both training and validation.
|
| 297 |
+
|
| 298 |
+
Args:
|
| 299 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 300 |
+
action_head (nn.Module): Action head module.
|
| 301 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 302 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 303 |
+
batch (dict): Input batch.
|
| 304 |
+
action_tokenizer (ActionTokenizer): Action tokenizer.
|
| 305 |
+
device_id (str): Device ID.
|
| 306 |
+
use_l1_regression (bool): Whether to use L1 regression.
|
| 307 |
+
use_diffusion (bool): Whether to use diffusion.
|
| 308 |
+
use_proprio (bool): Whether to use proprioceptive state as input.
|
| 309 |
+
use_film (bool): Whether to use FiLM for better language following.
|
| 310 |
+
num_patches (int): Number of vision patches.
|
| 311 |
+
compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every
|
| 312 |
+
diffusion_sample_freq steps during training; do it every batch for validation)
|
| 313 |
+
num_diffusion_steps_train (int): Number of diffusion steps for training (only used for diffusion).
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
tuple: (loss, metrics_dict)
|
| 317 |
+
loss: The loss tensor with gradient for backpropagation.
|
| 318 |
+
metrics_dict: Dictionary of computed metrics (detached values for logging).
|
| 319 |
+
"""
|
| 320 |
+
metrics = {}
|
| 321 |
+
|
| 322 |
+
# Get ground-truth action labels
|
| 323 |
+
ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16)
|
| 324 |
+
|
| 325 |
+
# [Only for diffusion] Sample noisy actions used as input for noise predictor network
|
| 326 |
+
if use_diffusion:
|
| 327 |
+
noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions)
|
| 328 |
+
noise, noisy_actions, diffusion_timestep_embeddings = (
|
| 329 |
+
noisy_dict["noise"],
|
| 330 |
+
noisy_dict["noisy_actions"],
|
| 331 |
+
noisy_dict["diffusion_timestep_embeddings"],
|
| 332 |
+
)
|
| 333 |
+
else:
|
| 334 |
+
noise, noisy_actions, diffusion_timestep_embeddings = None, None, None
|
| 335 |
+
|
| 336 |
+
# VLA forward pass
|
| 337 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 338 |
+
output: CausalLMOutputWithPast = vla(
|
| 339 |
+
input_ids=batch["input_ids"].to(device_id),
|
| 340 |
+
attention_mask=batch["attention_mask"].to(device_id),
|
| 341 |
+
pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
|
| 342 |
+
labels=batch["labels"],
|
| 343 |
+
output_hidden_states=True,
|
| 344 |
+
proprio=batch["proprio"] if use_proprio else None,
|
| 345 |
+
proprio_projector=proprio_projector if use_proprio else None,
|
| 346 |
+
noisy_actions=noisy_actions if use_diffusion else None,
|
| 347 |
+
noisy_action_projector=noisy_action_projector if use_diffusion else None,
|
| 348 |
+
diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None,
|
| 349 |
+
use_film=use_film,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
# Get action masks needed for logging
|
| 353 |
+
ground_truth_token_ids = batch["labels"][:, 1:].to(device_id)
|
| 354 |
+
current_action_mask = get_current_action_mask(ground_truth_token_ids)
|
| 355 |
+
next_actions_mask = get_next_actions_mask(ground_truth_token_ids)
|
| 356 |
+
|
| 357 |
+
# Compute metrics for discrete action representation (next-token prediction)
|
| 358 |
+
if not (use_l1_regression or use_diffusion):
|
| 359 |
+
loss = output.loss
|
| 360 |
+
predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2)
|
| 361 |
+
curr_action_accuracy = compute_token_accuracy(
|
| 362 |
+
predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
|
| 363 |
+
)
|
| 364 |
+
curr_action_l1_loss = compute_actions_l1_loss(
|
| 365 |
+
action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
|
| 366 |
+
)
|
| 367 |
+
next_actions_accuracy = compute_token_accuracy(
|
| 368 |
+
predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
|
| 369 |
+
)
|
| 370 |
+
next_actions_l1_loss = compute_actions_l1_loss(
|
| 371 |
+
action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
|
| 372 |
+
)
|
| 373 |
+
metrics.update(
|
| 374 |
+
{
|
| 375 |
+
"loss_value": loss.item(), # Detached value for logging
|
| 376 |
+
"curr_action_accuracy": curr_action_accuracy.item(),
|
| 377 |
+
"curr_action_l1_loss": curr_action_l1_loss.item(),
|
| 378 |
+
"next_actions_accuracy": next_actions_accuracy.item(),
|
| 379 |
+
"next_actions_l1_loss": next_actions_l1_loss.item(),
|
| 380 |
+
}
|
| 381 |
+
)
|
| 382 |
+
# Compute metrics for continuous action representations (L1 regression | diffusion)
|
| 383 |
+
else:
|
| 384 |
+
# Get last layer hidden states
|
| 385 |
+
last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
|
| 386 |
+
# Get hidden states for text portion of prompt+response (after the vision patches)
|
| 387 |
+
text_hidden_states = last_hidden_states[:, num_patches:-1]
|
| 388 |
+
# Get hidden states for action portion of response
|
| 389 |
+
batch_size = batch["input_ids"].shape[0]
|
| 390 |
+
actions_hidden_states = (
|
| 391 |
+
text_hidden_states[current_action_mask | next_actions_mask]
|
| 392 |
+
.reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1)
|
| 393 |
+
.to(torch.bfloat16)
|
| 394 |
+
) # (B, act_chunk_len, D)
|
| 395 |
+
|
| 396 |
+
if use_l1_regression:
|
| 397 |
+
# Predict action
|
| 398 |
+
predicted_actions = action_head.module.predict_action(actions_hidden_states)
|
| 399 |
+
# Get full L1 loss
|
| 400 |
+
loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions)
|
| 401 |
+
|
| 402 |
+
if use_diffusion:
|
| 403 |
+
# Predict noise
|
| 404 |
+
noise_pred = action_head.module.predict_noise(actions_hidden_states)
|
| 405 |
+
# Get diffusion noise prediction MSE loss
|
| 406 |
+
noise_pred = noise_pred.reshape(noise.shape)
|
| 407 |
+
loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean")
|
| 408 |
+
|
| 409 |
+
# Only sample actions and compute L1 losses if specified
|
| 410 |
+
if compute_diffusion_l1:
|
| 411 |
+
with torch.no_grad():
|
| 412 |
+
predicted_actions = run_diffusion_sampling(
|
| 413 |
+
vla=vla,
|
| 414 |
+
action_head=action_head,
|
| 415 |
+
noisy_action_projector=noisy_action_projector,
|
| 416 |
+
proprio_projector=proprio_projector,
|
| 417 |
+
batch=batch,
|
| 418 |
+
batch_size=batch_size,
|
| 419 |
+
num_patches=num_patches,
|
| 420 |
+
actions_shape=ground_truth_actions.shape,
|
| 421 |
+
device_id=device_id,
|
| 422 |
+
current_action_mask=current_action_mask,
|
| 423 |
+
next_actions_mask=next_actions_mask,
|
| 424 |
+
use_proprio=use_proprio,
|
| 425 |
+
use_film=use_film,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
metrics.update(
|
| 429 |
+
{
|
| 430 |
+
"loss_value": loss.item(), # Detached value for logging
|
| 431 |
+
}
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Get detailed L1 losses for logging
|
| 435 |
+
should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1)
|
| 436 |
+
if should_log_l1_loss:
|
| 437 |
+
ground_truth_curr_action = ground_truth_actions[:, 0]
|
| 438 |
+
predicted_curr_action = predicted_actions[:, 0]
|
| 439 |
+
ground_truth_next_actions = ground_truth_actions[:, 1:]
|
| 440 |
+
predicted_next_actions = predicted_actions[:, 1:]
|
| 441 |
+
curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action)
|
| 442 |
+
next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions)
|
| 443 |
+
metrics.update(
|
| 444 |
+
{
|
| 445 |
+
"curr_action_l1_loss": curr_action_l1_loss.item(),
|
| 446 |
+
"next_actions_l1_loss": next_actions_l1_loss.item(),
|
| 447 |
+
}
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Return both the loss tensor (with gradients) and the metrics dictionary (with detached values)
|
| 451 |
+
return loss, metrics
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
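The mask-based gathering above relies on boolean indexing flattening the selected positions across the batch before the reshape recovers the per-example chunk. A toy sketch of the same pattern with made-up sizes (in the real code the masks mark NUM_ACTIONS_CHUNK * ACTION_DIM positions per example):

```python
# Toy illustration of selecting action-token hidden states with a boolean mask, then reshaping
# (sizes here are made up and purely illustrative)
import torch

B, T, D, act_len = 2, 3, 4, 2
text_hidden_states = torch.randn(B, T, D)
action_mask = torch.tensor([[False, True, True],
                            [False, True, True]])  # same number of True entries per row
actions_hidden_states = text_hidden_states[action_mask].reshape(B, act_len, -1)
assert actions_hidden_states.shape == (B, act_len, D)
```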
def run_diffusion_sampling(
|
| 455 |
+
vla,
|
| 456 |
+
action_head,
|
| 457 |
+
noisy_action_projector,
|
| 458 |
+
proprio_projector,
|
| 459 |
+
batch,
|
| 460 |
+
batch_size,
|
| 461 |
+
num_patches,
|
| 462 |
+
actions_shape,
|
| 463 |
+
device_id,
|
| 464 |
+
current_action_mask,
|
| 465 |
+
next_actions_mask,
|
| 466 |
+
use_proprio,
|
| 467 |
+
use_film,
|
| 468 |
+
) -> torch.Tensor:
|
| 469 |
+
"""
|
| 470 |
+
Run diffusion sampling (reverse diffusion) to generate actions.
|
| 471 |
+
|
| 472 |
+
Args:
|
| 473 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 474 |
+
action_head (nn.Module): Action head module.
|
| 475 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 476 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 477 |
+
batch (dict): Input batch.
|
| 478 |
+
batch_size (int): Batch size.
|
| 479 |
+
num_patches (int): Number of vision patches.
|
| 480 |
+
actions_shape (tuple): Shape of ground-truth actions.
|
| 481 |
+
device_id (str): Device ID.
|
| 482 |
+
current_action_mask (torch.Tensor): Mask for current action.
|
| 483 |
+
next_actions_mask (torch.Tensor): Mask for next actions.
|
| 484 |
+
use_proprio (bool): Whether to use proprioceptive state as input.
|
| 485 |
+
use_film (bool): Whether to use FiLM for better language following.
|
| 486 |
+
|
| 487 |
+
Returns:
|
| 488 |
+
torch.Tensor: Predicted actions.
|
| 489 |
+
"""
|
| 490 |
+
# Sample random noisy action, used as the starting point for reverse diffusion
|
| 491 |
+
noise = torch.randn(
|
| 492 |
+
size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM),
|
| 493 |
+
device=device_id,
|
| 494 |
+
dtype=torch.bfloat16,
|
| 495 |
+
) # (B, chunk_len, action_dim)
|
| 496 |
+
|
| 497 |
+
# Set diffusion timestep values
|
| 498 |
+
action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train)
|
| 499 |
+
|
| 500 |
+
# Reverse diffusion: Iteratively denoise to generate action, conditioned on observation
|
| 501 |
+
curr_noisy_actions = noise
|
| 502 |
+
for t in action_head.module.noise_scheduler.timesteps:
|
| 503 |
+
# Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action embedding,
|
| 504 |
+
# and diffusion timestep embedding)
|
| 505 |
+
timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id)
|
| 506 |
+
diffusion_timestep_embeddings = (
|
| 507 |
+
action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
|
| 508 |
+
) # (B, llm_dim)
|
| 509 |
+
diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
|
| 510 |
+
|
| 511 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 512 |
+
output = vla(
|
| 513 |
+
input_ids=batch["input_ids"].to(device_id),
|
| 514 |
+
attention_mask=batch["attention_mask"].to(device_id),
|
| 515 |
+
pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
|
| 516 |
+
labels=batch["labels"],
|
| 517 |
+
output_hidden_states=True,
|
| 518 |
+
proprio=batch["proprio"] if use_proprio else None,
|
| 519 |
+
proprio_projector=proprio_projector if use_proprio else None,
|
| 520 |
+
noisy_actions=curr_noisy_actions,
|
| 521 |
+
noisy_action_projector=noisy_action_projector,
|
| 522 |
+
diffusion_timestep_embeddings=diffusion_timestep_embeddings,
|
| 523 |
+
use_film=use_film,
|
| 524 |
+
)
|
| 525 |
+
# Get last layer hidden states
|
| 526 |
+
last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
|
| 527 |
+
# Get hidden states for text portion of prompt+response (after the vision patches)
|
| 528 |
+
text_hidden_states = last_hidden_states[:, num_patches:-1]
|
| 529 |
+
# Get hidden states for action portion of response
|
| 530 |
+
actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape(
|
| 531 |
+
batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1
|
| 532 |
+
) # (B, act_chunk_len, D)
|
| 533 |
+
actions_hidden_states = actions_hidden_states.to(torch.bfloat16)
|
| 534 |
+
# Predict noise
|
| 535 |
+
noise_pred = action_head.module.predict_noise(actions_hidden_states)
|
| 536 |
+
|
| 537 |
+
# Compute the action at the previous diffusion timestep: x_t -> x_{t-1}
|
| 538 |
+
curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
|
| 539 |
+
|
| 540 |
+
return curr_noisy_actions.reshape(actions_shape)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
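The loop above follows the standard reverse-diffusion recipe of `diffusers`-style schedulers: set the timestep grid, predict the noise at each step, and call `scheduler.step(...)` to move from x_t to x_{t-1}. A minimal self-contained sketch of that recipe, assuming a DDIM scheduler and a stand-in noise predictor (neither is necessarily what this script's action head uses):

```python
# Minimal reverse-diffusion sketch with a placeholder noise predictor (illustrative only)
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=100)
scheduler.set_timesteps(50)

sample = torch.randn(1, 8, 7)                        # stand-in for (B, chunk_len, action_dim)
noise_predictor = lambda x, t: torch.zeros_like(x)   # placeholder for the learned noise head

for t in scheduler.timesteps:
    noise_pred = noise_predictor(sample, t)
    sample = scheduler.step(noise_pred, t, sample).prev_sample  # x_t -> x_{t-1}
```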
def compute_smoothened_metrics(metrics_deques) -> dict:
|
| 544 |
+
"""
|
| 545 |
+
Compute smoothened metrics from recent deques.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
metrics_deques (dict): Dictionary of deques containing recent metrics.
|
| 549 |
+
|
| 550 |
+
Returns:
|
| 551 |
+
dict: Dictionary of smoothened metrics.
|
| 552 |
+
"""
|
| 553 |
+
smoothened_metrics = {}
|
| 554 |
+
for name, deque in metrics_deques.items():
|
| 555 |
+
if deque and len(deque) > 0:
|
| 556 |
+
smoothened_metrics[name] = sum(deque) / len(deque)
|
| 557 |
+
return smoothened_metrics
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None:
|
| 561 |
+
"""
|
| 562 |
+
Log metrics to Weights & Biases.
|
| 563 |
+
|
| 564 |
+
Args:
|
| 565 |
+
metrics (dict): Dictionary of metrics to log
|
| 566 |
+
prefix (str): Prefix for metric names
|
| 567 |
+
step (int): Training step
|
| 568 |
+
        wandb_entity: WandB logging handle (the `wandb` module is passed in at the call sites)
|
| 569 |
+
|
| 570 |
+
Returns:
|
| 571 |
+
None.
|
| 572 |
+
"""
|
| 573 |
+
log_dict = {}
|
| 574 |
+
for name, value in metrics.items():
|
| 575 |
+
# Map loss_value to Loss for better readability in W&B
|
| 576 |
+
if name == "loss_value":
|
| 577 |
+
log_dict[f"{prefix}/Loss"] = value
|
| 578 |
+
# Keep other metrics as is
|
| 579 |
+
else:
|
| 580 |
+
log_dict[f"{prefix}/{name.replace('_', ' ').title()}"] = value
|
| 581 |
+
wandb_entity.log(log_dict, step=step)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
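For example, with prefix="VLA Train" the mapping above logs `loss_value` as `VLA Train/Loss` and `curr_action_l1_loss` as `VLA Train/Curr Action L1 Loss`.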
def save_training_checkpoint(
|
| 585 |
+
cfg,
|
| 586 |
+
run_dir,
|
| 587 |
+
log_step,
|
| 588 |
+
vla,
|
| 589 |
+
processor,
|
| 590 |
+
proprio_projector,
|
| 591 |
+
noisy_action_projector,
|
| 592 |
+
action_head,
|
| 593 |
+
train_dataset,
|
| 594 |
+
distributed_state,
|
| 595 |
+
) -> None:
|
| 596 |
+
"""
|
| 597 |
+
Save all training checkpoints including model components, LoRA adapter, and dataset statistics.
|
| 598 |
+
|
| 599 |
+
Args:
|
| 600 |
+
cfg (FinetuneConfig): Training configuration.
|
| 601 |
+
run_dir (Path): Experiment run directory path.
|
| 602 |
+
log_step (int): Current logging step.
|
| 603 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 604 |
+
processor (PrismaticProcessor): OpenVLA inputs processor.
|
| 605 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 606 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 607 |
+
action_head (nn.Module): Action head module.
|
| 608 |
+
train_dataset (RLDSDataset): Training dataset.
|
| 609 |
+
distributed_state (PartialState): Distributed training state.
|
| 610 |
+
|
| 611 |
+
Returns:
|
| 612 |
+
None.
|
| 613 |
+
"""
|
| 614 |
+
# Determine checkpoint paths and naming
|
| 615 |
+
if cfg.save_latest_checkpoint_only:
|
| 616 |
+
checkpoint_dir = run_dir
|
| 617 |
+
checkpoint_name_suffix = "latest_checkpoint.pt"
|
| 618 |
+
else:
|
| 619 |
+
checkpoint_dir = Path(str(run_dir) + f"--{log_step}_chkpt")
|
| 620 |
+
checkpoint_name_suffix = f"{log_step}_checkpoint.pt"
|
| 621 |
+
|
| 622 |
+
adapter_dir = checkpoint_dir / "lora_adapter"
|
| 623 |
+
|
| 624 |
+
# Create directories and save dataset statistics (main process only)
|
| 625 |
+
if distributed_state.is_main_process:
|
| 626 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 627 |
+
os.makedirs(adapter_dir, exist_ok=True)
|
| 628 |
+
save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir)
|
| 629 |
+
print(f"Saving Model Checkpoint for Step {log_step}")
|
| 630 |
+
|
| 631 |
+
# Wait for directories to be created
|
| 632 |
+
dist.barrier()
|
| 633 |
+
|
| 634 |
+
# Save model components (main process only)
|
| 635 |
+
if distributed_state.is_main_process:
|
| 636 |
+
# Save processor and LoRA adapter
|
| 637 |
+
processor.save_pretrained(checkpoint_dir)
|
| 638 |
+
vla.module.save_pretrained(adapter_dir)
|
| 639 |
+
|
| 640 |
+
# Save other components
|
| 641 |
+
if cfg.use_proprio and proprio_projector is not None:
|
| 642 |
+
torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}")
|
| 643 |
+
|
| 644 |
+
if cfg.use_diffusion and noisy_action_projector is not None:
|
| 645 |
+
torch.save(
|
| 646 |
+
noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}"
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None:
|
| 650 |
+
torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}")
|
| 651 |
+
|
| 652 |
+
if cfg.use_film:
|
| 653 |
+
# To be safe, just save the entire vision backbone (not just FiLM components)
|
| 654 |
+
torch.save(
|
| 655 |
+
vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}"
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
# Wait for model components to be saved
|
| 659 |
+
dist.barrier()
|
| 660 |
+
|
| 661 |
+
# Merge LoRA weights into base model and save resulting model checkpoint
|
| 662 |
+
# Note: Can be very slow on some devices; if so, we recommend merging offline
|
| 663 |
+
if cfg.use_lora and cfg.merge_lora_during_training:
|
| 664 |
+
base_vla = AutoModelForVision2Seq.from_pretrained(
|
| 665 |
+
cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
|
| 666 |
+
)
|
| 667 |
+
merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir)
|
| 668 |
+
merged_vla = merged_vla.merge_and_unload()
|
| 669 |
+
|
| 670 |
+
if distributed_state.is_main_process:
|
| 671 |
+
merged_vla.save_pretrained(checkpoint_dir)
|
| 672 |
+
print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}")
|
| 673 |
+
|
| 674 |
+
# Wait for merged model to be saved
|
| 675 |
+
dist.barrier()
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
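If `merge_lora_during_training` is disabled (see the note in the config above), the same merge can be performed offline after training. A minimal sketch mirroring the merge done in `save_training_checkpoint`, assuming a checkpoint directory laid out as saved above (the paths are placeholders):

```python
# Minimal offline LoRA-merge sketch (paths are placeholders; mirrors the in-training merge)
import torch
from peft import PeftModel
from transformers import AutoModelForVision2Seq

base_vla = AutoModelForVision2Seq.from_pretrained(
    "path/to/base-vla", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
)
merged_vla = PeftModel.from_pretrained(base_vla, "path/to/checkpoint_dir/lora_adapter").merge_and_unload()
merged_vla.save_pretrained("path/to/checkpoint_dir")
```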
def run_validation(
|
| 679 |
+
vla,
|
| 680 |
+
action_head,
|
| 681 |
+
noisy_action_projector,
|
| 682 |
+
proprio_projector,
|
| 683 |
+
val_dataloader,
|
| 684 |
+
action_tokenizer,
|
| 685 |
+
device_id,
|
| 686 |
+
cfg,
|
| 687 |
+
num_patches,
|
| 688 |
+
log_step,
|
| 689 |
+
distributed_state,
|
| 690 |
+
val_time_limit,
|
| 691 |
+
) -> None:
|
| 692 |
+
"""
|
| 693 |
+
Compute validation set metrics for logging.
|
| 694 |
+
|
| 695 |
+
Args:
|
| 696 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 697 |
+
action_head (nn.Module): Action head module.
|
| 698 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 699 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 700 |
+
val_dataloader (DataLoader): Validation data loader.
|
| 701 |
+
action_tokenizer (ActionTokenizer): Action tokenizer.
|
| 702 |
+
device_id (str): Device ID.
|
| 703 |
+
cfg (FinetuneConfig): Training configuration.
|
| 704 |
+
num_patches (int): Number of vision patches.
|
| 705 |
+
log_step (int): Current logging step.
|
| 706 |
+
distributed_state (PartialState): Distributed training state.
|
| 707 |
+
val_time_limit (int): Time limit for computing validation metrics.
|
| 708 |
+
|
| 709 |
+
Returns:
|
| 710 |
+
None.
|
| 711 |
+
"""
|
| 712 |
+
val_start_time = time.time()
|
| 713 |
+
vla.eval()
|
| 714 |
+
val_batches_count = 0
|
| 715 |
+
|
| 716 |
+
# List to store validation metrics
|
| 717 |
+
all_val_metrics = []
|
| 718 |
+
|
| 719 |
+
with torch.no_grad():
|
| 720 |
+
for batch in val_dataloader:
|
| 721 |
+
# Always compute L1 loss for validation, even for diffusion
|
| 722 |
+
_, metrics = run_forward_pass(
|
| 723 |
+
vla=vla,
|
| 724 |
+
action_head=action_head,
|
| 725 |
+
noisy_action_projector=noisy_action_projector,
|
| 726 |
+
proprio_projector=proprio_projector,
|
| 727 |
+
batch=batch,
|
| 728 |
+
action_tokenizer=action_tokenizer,
|
| 729 |
+
device_id=device_id,
|
| 730 |
+
use_l1_regression=cfg.use_l1_regression,
|
| 731 |
+
use_diffusion=cfg.use_diffusion,
|
| 732 |
+
use_proprio=cfg.use_proprio,
|
| 733 |
+
use_film=cfg.use_film,
|
| 734 |
+
num_patches=num_patches,
|
| 735 |
+
compute_diffusion_l1=True,
|
| 736 |
+
num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
# Add the loss value to the metrics
|
| 740 |
+
metrics["loss"] = metrics["loss_value"]
|
| 741 |
+
all_val_metrics.append(metrics)
|
| 742 |
+
val_batches_count += 1
|
| 743 |
+
|
| 744 |
+
# Cut testing on validation set short if it exceeds time limit
|
| 745 |
+
if time.time() - val_start_time > val_time_limit:
|
| 746 |
+
break
|
| 747 |
+
|
| 748 |
+
# Compute average validation metrics
|
| 749 |
+
avg_val_metrics = {}
|
| 750 |
+
for metric_name in all_val_metrics[0].keys():
|
| 751 |
+
values = [metrics[metric_name] for metrics in all_val_metrics if metric_name in metrics]
|
| 752 |
+
if values:
|
| 753 |
+
avg_val_metrics[metric_name] = sum(values) / len(values)
|
| 754 |
+
|
| 755 |
+
# Add batch count to metrics
|
| 756 |
+
avg_val_metrics["val_batches_count"] = val_batches_count
|
| 757 |
+
|
| 758 |
+
# Log validation metrics to W&B
|
| 759 |
+
if distributed_state.is_main_process:
|
| 760 |
+
log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb)
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
@draccus.wrap()
|
| 764 |
+
def finetune(cfg: FinetuneConfig) -> None:
|
| 765 |
+
"""
|
| 766 |
+
Fine-tunes base VLA on demonstration dataset via LoRA.
|
| 767 |
+
|
| 768 |
+
Allows toggling different action representations (discrete vs. continuous), different learning objectives
|
| 769 |
+
(next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs,
|
| 770 |
+
such as additional camera images and robot proprioceptive state. Assumes parallel action generation with
|
| 771 |
+
action chunking.
|
| 772 |
+
|
| 773 |
+
Args:
|
| 774 |
+
cfg (FinetuneConfig): Training configuration.
|
| 775 |
+
|
| 776 |
+
Returns:
|
| 777 |
+
None.
|
| 778 |
+
"""
|
| 779 |
+
assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!"
|
| 780 |
+
assert not (cfg.use_l1_regression and cfg.use_diffusion), (
|
| 781 |
+
"Cannot do both L1 regression and diffusion. Please pick one of them!"
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# Trim trailing forward slash ('/') in VLA path if it exists
|
| 785 |
+
cfg.vla_path = cfg.vla_path.rstrip("/")
|
| 786 |
+
print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")
|
| 787 |
+
|
| 788 |
+
# Get experiment run ID
|
| 789 |
+
run_id = get_run_id(cfg)
|
| 790 |
+
|
| 791 |
+
# Create experiment run directory
|
| 792 |
+
run_dir = cfg.run_root_dir / run_id
|
| 793 |
+
os.makedirs(run_dir, exist_ok=True)
|
| 794 |
+
|
| 795 |
+
# GPU setup
|
| 796 |
+
distributed_state = PartialState()
|
| 797 |
+
device_id = distributed_state.local_process_index
|
| 798 |
+
torch.cuda.set_device(device_id)
|
| 799 |
+
torch.cuda.empty_cache()
|
| 800 |
+
|
| 801 |
+
# Initialize wandb logging
|
| 802 |
+
if distributed_state.is_main_process:
|
| 803 |
+
wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id)
|
| 804 |
+
|
| 805 |
+
# Print detected constants
|
| 806 |
+
print(
|
| 807 |
+
"Detected constants:\n"
|
| 808 |
+
f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n"
|
| 809 |
+
f"\tACTION_DIM: {ACTION_DIM}\n"
|
| 810 |
+
f"\tPROPRIO_DIM: {PROPRIO_DIM}\n"
|
| 811 |
+
f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}"
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
# Two options:
|
| 815 |
+
# (1) Base model is on Hugging Face Hub
|
| 816 |
+
# - Then download it and record the path to the download directory
|
| 817 |
+
# (2) Base model is stored locally
|
| 818 |
+
# - Then register model config in HF Auto Classes
|
| 819 |
+
# In both cases, we want to check whether any changes have been made to
|
| 820 |
+
# the `modeling_prismatic.py` file in this codebase; if so, we will copy
|
| 821 |
+
# the file to the downloaded or locally stored checkpoint directory so
|
| 822 |
+
# that the user's changes to the VLA class logic go into effect
|
| 823 |
+
if model_is_on_hf_hub(cfg.vla_path):
|
| 824 |
+
# Download model directly from Hugging Face Hub
|
| 825 |
+
vla_download_path = snapshot_download(repo_id=cfg.vla_path)
|
| 826 |
+
# Overwrite VLA path
|
| 827 |
+
cfg.vla_path = vla_download_path
|
| 828 |
+
else:
|
| 829 |
+
# Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
|
| 830 |
+
AutoConfig.register("openvla", OpenVLAConfig)
|
| 831 |
+
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
|
| 832 |
+
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
|
| 833 |
+
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
|
| 834 |
+
|
| 835 |
+
# Update config.json and sync model files
|
| 836 |
+
if distributed_state.is_main_process:
|
| 837 |
+
update_auto_map(cfg.vla_path)
|
| 838 |
+
check_model_logic_mismatch(cfg.vla_path)
|
| 839 |
+
|
| 840 |
+
# Wait for model files to be synced
|
| 841 |
+
dist.barrier()
|
| 842 |
+
|
| 843 |
+
# Load processor and VLA
|
| 844 |
+
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
|
| 845 |
+
vla = AutoModelForVision2Seq.from_pretrained(
|
| 846 |
+
cfg.vla_path,
|
| 847 |
+
torch_dtype=torch.bfloat16,
|
| 848 |
+
low_cpu_mem_usage=True,
|
| 849 |
+
trust_remote_code=True,
|
| 850 |
+
).to(device_id)
|
| 851 |
+
|
| 852 |
+
# Set number of images in VLA input
|
| 853 |
+
vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
|
| 854 |
+
|
| 855 |
+
# LoRA setup
|
| 856 |
+
if cfg.use_lora:
|
| 857 |
+
lora_config = LoraConfig(
|
| 858 |
+
r=cfg.lora_rank,
|
| 859 |
+
lora_alpha=min(cfg.lora_rank, 16),
|
| 860 |
+
lora_dropout=cfg.lora_dropout,
|
| 861 |
+
target_modules="all-linear",
|
| 862 |
+
init_lora_weights="gaussian",
|
| 863 |
+
)
|
| 864 |
+
vla = get_peft_model(vla, lora_config)
|
| 865 |
+
vla.print_trainable_parameters()
|
| 866 |
+
|
| 867 |
+
# FiLM setup
|
| 868 |
+
if cfg.use_film:
|
| 869 |
+
count_parameters(vla.vision_backbone, "vla.vision_backbone (original)")
|
| 870 |
+
# Wrap vision backbone with FiLM wrapper
|
| 871 |
+
# Important: For this, must specify `vla.model.vision_backbone` instead of just `vla.vision_backbone`, since the
|
| 872 |
+
# latter would cause the new wrapped backbone to be saved as a new attribute of `vla` instead of overwriting the
|
| 873 |
+
# original one (due to the LoRA wrapper)
|
| 874 |
+
vla.model.vision_backbone = FiLMedPrismaticVisionBackbone(
|
| 875 |
+
vision_backbone=vla.model.vision_backbone,
|
| 876 |
+
llm_dim=vla.llm_dim,
|
| 877 |
+
)
|
| 878 |
+
count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)")
|
| 879 |
+
if cfg.resume:
|
| 880 |
+
state_dict = load_checkpoint("vision_backbone", cfg.vla_path, cfg.resume_step)
|
| 881 |
+
vla.model.vision_backbone.load_state_dict(state_dict)
|
| 882 |
+
vla.model.vision_backbone = vla.model.vision_backbone.to(device_id)
|
| 883 |
+
|
| 884 |
+
# Wrap VLA with DDP
|
| 885 |
+
vla = wrap_ddp(vla, device_id, find_unused=True)
|
| 886 |
+
|
| 887 |
+
# If applicable, instantiate proprio projector
|
| 888 |
+
if cfg.use_proprio:
|
| 889 |
+
proprio_projector = init_module(
|
| 890 |
+
ProprioProjector,
|
| 891 |
+
"proprio_projector",
|
| 892 |
+
cfg,
|
| 893 |
+
device_id,
|
| 894 |
+
{"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM},
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
# If applicable, instantiate continuous action head for L1 regression
|
| 898 |
+
if cfg.use_l1_regression:
|
| 899 |
+
action_head = init_module(
|
| 900 |
+
L1RegressionActionHead,
|
| 901 |
+
"action_head",
|
| 902 |
+
cfg,
|
| 903 |
+
device_id,
|
| 904 |
+
{"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM},
|
| 905 |
+
to_bf16=True,
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
# If applicable, instantiate diffusion action head and noisy action projector
|
| 909 |
+
if cfg.use_diffusion:
|
| 910 |
+
action_head = init_module(
|
| 911 |
+
DiffusionActionHead,
|
| 912 |
+
"action_head",
|
| 913 |
+
cfg,
|
| 914 |
+
device_id,
|
| 915 |
+
{
|
| 916 |
+
"input_dim": vla.module.llm_dim,
|
| 917 |
+
"hidden_dim": vla.module.llm_dim,
|
| 918 |
+
"action_dim": ACTION_DIM,
|
| 919 |
+
"num_diffusion_steps_train": cfg.num_diffusion_steps_train,
|
| 920 |
+
},
|
| 921 |
+
to_bf16=True,
|
| 922 |
+
)
|
| 923 |
+
noisy_action_projector = init_module(
|
| 924 |
+
NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim}
|
| 925 |
+
)
|
| 926 |
+
|
| 927 |
+
# Get number of vision patches
|
| 928 |
+
NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input()
|
| 929 |
+
# If we have proprio inputs, a single proprio embedding is appended to the end of the vision patch embeddings
|
| 930 |
+
if cfg.use_proprio:
|
| 931 |
+
NUM_PATCHES += 1
|
| 932 |
+
# For diffusion, a single diffusion timestep embedding is appended to the end of the vision patch embeddings
|
| 933 |
+
if cfg.use_diffusion:
|
| 934 |
+
NUM_PATCHES += 1
|
| 935 |
+
|
| 936 |
+
# Instantiate optimizer
|
| 937 |
+
trainable_params = [param for param in vla.parameters() if param.requires_grad]
|
| 938 |
+
if cfg.use_l1_regression or cfg.use_diffusion:
|
| 939 |
+
trainable_params += [param for param in action_head.parameters() if param.requires_grad]
|
| 940 |
+
if cfg.use_diffusion:
|
| 941 |
+
trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad]
|
| 942 |
+
if cfg.use_proprio:
|
| 943 |
+
trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad]
|
| 944 |
+
print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}")
|
| 945 |
+
optimizer = AdamW(trainable_params, lr=cfg.learning_rate)
|
| 946 |
+
|
| 947 |
+
# Record original learning rate
|
| 948 |
+
original_lr = optimizer.param_groups[0]["lr"]
|
| 949 |
+
|
| 950 |
+
# Create learning rate scheduler
|
| 951 |
+
scheduler = MultiStepLR(
|
| 952 |
+
optimizer,
|
| 953 |
+
milestones=[cfg.num_steps_before_decay], # Number of steps after which LR will change
|
| 954 |
+
gamma=0.1, # Multiplicative factor of learning rate decay
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
# Create Action Tokenizer
|
| 958 |
+
action_tokenizer = ActionTokenizer(processor.tokenizer)
|
| 959 |
+
|
| 960 |
+
# Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default.
|
| 961 |
+
# =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block.
|
| 962 |
+
# =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using
|
| 963 |
+
# your own Dataset, make sure to add the appropriate logic to the training loop!
|
| 964 |
+
#
|
| 965 |
+
# ---
|
| 966 |
+
# from prismatic.vla.datasets import DummyDataset
|
| 967 |
+
#
|
| 968 |
+
# train_dataset = DummyDataset(
|
| 969 |
+
# action_tokenizer,
|
| 970 |
+
# processor.tokenizer,
|
| 971 |
+
# image_transform=processor.image_processor.apply_transform,
|
| 972 |
+
# prompt_builder_fn=PurePromptBuilder,
|
| 973 |
+
# )
|
| 974 |
+
# ---
|
| 975 |
+
|
| 976 |
+
# We assume that the model takes as input one third-person camera image and 1 or 2 optional wrist camera image(s)
|
| 977 |
+
use_wrist_image = cfg.num_images_in_input > 1
|
| 978 |
+
|
| 979 |
+
# Create training and optional validation datasets
|
| 980 |
+
batch_transform = RLDSBatchTransform(
|
| 981 |
+
action_tokenizer,
|
| 982 |
+
processor.tokenizer,
|
| 983 |
+
image_transform=processor.image_processor.apply_transform,
|
| 984 |
+
prompt_builder_fn=PurePromptBuilder,
|
| 985 |
+
use_wrist_image=use_wrist_image,
|
| 986 |
+
use_proprio=cfg.use_proprio,
|
| 987 |
+
)
|
| 988 |
+
train_dataset = RLDSDataset(
|
| 989 |
+
cfg.data_root_dir,
|
| 990 |
+
cfg.dataset_name,
|
| 991 |
+
batch_transform,
|
| 992 |
+
resize_resolution=tuple(vla.module.config.image_sizes),
|
| 993 |
+
shuffle_buffer_size=cfg.shuffle_buffer_size,
|
| 994 |
+
image_aug=cfg.image_aug,
|
| 995 |
+
)
|
| 996 |
+
if cfg.use_val_set:
|
| 997 |
+
val_dataset = RLDSDataset(
|
| 998 |
+
cfg.data_root_dir,
|
| 999 |
+
cfg.dataset_name,
|
| 1000 |
+
batch_transform,
|
| 1001 |
+
resize_resolution=tuple(vla.module.config.image_sizes),
|
| 1002 |
+
shuffle_buffer_size=cfg.shuffle_buffer_size // 10,
|
| 1003 |
+
image_aug=cfg.image_aug,
|
| 1004 |
+
train=False,
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# [Important] Save dataset statistics so that we can unnormalize actions during inference
|
| 1008 |
+
if distributed_state.is_main_process:
|
| 1009 |
+
save_dataset_statistics(train_dataset.dataset_statistics, run_dir)
|
| 1010 |
+
|
| 1011 |
+
# Create collator and dataloader
|
| 1012 |
+
collator = PaddedCollatorForActionPrediction(
|
| 1013 |
+
processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
|
| 1014 |
+
)
|
| 1015 |
+
dataloader = DataLoader(
|
| 1016 |
+
train_dataset,
|
| 1017 |
+
batch_size=cfg.batch_size,
|
| 1018 |
+
sampler=None,
|
| 1019 |
+
collate_fn=collator,
|
| 1020 |
+
num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
|
| 1021 |
+
)
|
| 1022 |
+
if cfg.use_val_set:
|
| 1023 |
+
val_batch_size = cfg.batch_size
|
| 1024 |
+
val_dataloader = DataLoader(
|
| 1025 |
+
val_dataset,
|
| 1026 |
+
batch_size=val_batch_size,
|
| 1027 |
+
sampler=None,
|
| 1028 |
+
collate_fn=collator,
|
| 1029 |
+
num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
|
| 1030 |
+
)
|
| 1031 |
+
|
| 1032 |
+
# Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation)
|
| 1033 |
+
recent_metrics = {
|
| 1034 |
+
"loss_value": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1035 |
+
"curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1036 |
+
"curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1037 |
+
"next_actions_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1038 |
+
"next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1039 |
+
}
|
| 1040 |
+
|
| 1041 |
+
# Start training
|
| 1042 |
+
with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress:
|
| 1043 |
+
vla.train()
|
| 1044 |
+
optimizer.zero_grad()
|
| 1045 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 1046 |
+
# Compute training metrics and loss
|
| 1047 |
+
compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0
|
| 1048 |
+
loss, metrics = run_forward_pass(
|
| 1049 |
+
vla=vla,
|
| 1050 |
+
action_head=action_head,
|
| 1051 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1052 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1053 |
+
batch=batch,
|
| 1054 |
+
action_tokenizer=action_tokenizer,
|
| 1055 |
+
device_id=device_id,
|
| 1056 |
+
use_l1_regression=cfg.use_l1_regression,
|
| 1057 |
+
use_diffusion=cfg.use_diffusion,
|
| 1058 |
+
use_proprio=cfg.use_proprio,
|
| 1059 |
+
use_film=cfg.use_film,
|
| 1060 |
+
num_patches=NUM_PATCHES,
|
| 1061 |
+
compute_diffusion_l1=compute_diffusion_l1,
|
| 1062 |
+
num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
# Normalize loss to account for gradient accumulation
|
| 1066 |
+
normalized_loss = loss / cfg.grad_accumulation_steps
|
| 1067 |
+
|
| 1068 |
+
# Backward pass
|
| 1069 |
+
normalized_loss.backward()
|
| 1070 |
+
|
| 1071 |
+
# Store recent train metrics
|
| 1072 |
+
for metric_name, value in metrics.items():
|
| 1073 |
+
if metric_name in recent_metrics:
|
| 1074 |
+
recent_metrics[metric_name].append(value)
|
| 1075 |
+
|
| 1076 |
+
# Compute gradient step index
|
| 1077 |
+
gradient_step_idx = batch_idx // cfg.grad_accumulation_steps
|
| 1078 |
+
|
| 1079 |
+
# Compute smoothened train metrics
|
| 1080 |
+
smoothened_metrics = compute_smoothened_metrics(recent_metrics)
|
| 1081 |
+
|
| 1082 |
+
# Push Metrics to W&B (every wandb_log_freq gradient steps)
|
| 1083 |
+
log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx
|
| 1084 |
+
if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0:
|
| 1085 |
+
log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb)
|
| 1086 |
+
|
| 1087 |
+
# [If applicable] Linearly warm up learning rate from 10% to 100% of original
|
| 1088 |
+
if cfg.lr_warmup_steps > 0:
|
| 1089 |
+
lr_progress = min((gradient_step_idx + 1) / cfg.lr_warmup_steps, 1.0) # Cap at 1.0
|
| 1090 |
+
current_lr = original_lr * (0.1 + 0.9 * lr_progress)
|
| 1091 |
+
for param_group in optimizer.param_groups:
|
| 1092 |
+
param_group["lr"] = current_lr
|
| 1093 |
+
|
| 1094 |
+
if distributed_state.is_main_process and gradient_step_idx % cfg.wandb_log_freq == 0:
|
| 1095 |
+
# Log the learning rate
|
| 1096 |
+
# Make sure to do this AFTER any learning rate modifications (e.g., warmup/decay)
|
| 1097 |
+
wandb.log(
|
| 1098 |
+
{
|
| 1099 |
+
"VLA Train/Learning Rate": scheduler.get_last_lr()[0],
|
| 1100 |
+
},
|
| 1101 |
+
step=log_step,
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
# Optimizer and LR scheduler step
|
| 1105 |
+
if (batch_idx + 1) % cfg.grad_accumulation_steps == 0:
|
| 1106 |
+
optimizer.step()
|
| 1107 |
+
scheduler.step()
|
| 1108 |
+
optimizer.zero_grad()
|
| 1109 |
+
progress.update()
|
| 1110 |
+
|
| 1111 |
+
# Save model checkpoint: either keep latest checkpoint only or all checkpoints
|
| 1112 |
+
if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
|
| 1113 |
+
save_training_checkpoint(
|
| 1114 |
+
cfg=cfg,
|
| 1115 |
+
run_dir=run_dir,
|
| 1116 |
+
log_step=log_step,
|
| 1117 |
+
vla=vla,
|
| 1118 |
+
processor=processor,
|
| 1119 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1120 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1121 |
+
action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None,
|
| 1122 |
+
train_dataset=train_dataset,
|
| 1123 |
+
distributed_state=distributed_state,
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
# Test model on validation set
|
| 1127 |
+
if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0:
|
| 1128 |
+
run_validation(
|
| 1129 |
+
vla=vla,
|
| 1130 |
+
action_head=action_head,
|
| 1131 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1132 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1133 |
+
val_dataloader=val_dataloader,
|
| 1134 |
+
action_tokenizer=action_tokenizer,
|
| 1135 |
+
device_id=device_id,
|
| 1136 |
+
cfg=cfg,
|
| 1137 |
+
num_patches=NUM_PATCHES,
|
| 1138 |
+
log_step=log_step,
|
| 1139 |
+
distributed_state=distributed_state,
|
| 1140 |
+
val_time_limit=cfg.val_time_limit,
|
| 1141 |
+
)
|
| 1142 |
+
# Set model back to training mode after validation
|
| 1143 |
+
vla.train()
|
| 1144 |
+
|
| 1145 |
+
# Stop training when max_steps is reached
|
| 1146 |
+
if log_step == cfg.max_steps:
|
| 1147 |
+
print(f"Max step {cfg.max_steps} reached! Stopping training...")
|
| 1148 |
+
break
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
if __name__ == "__main__":
|
| 1152 |
+
finetune()
|
capvector-oft/vla-scripts/finetune_regular_loss.py
ADDED
|
@@ -0,0 +1,1790 @@
| 1 |
+
# This is for the CapVector experiment: stop gradient propagation along the direction of the newly added vector
|
| 2 |
+
"""
|
| 3 |
+
finetune_regular_loss.py
|
| 4 |
+
|
| 5 |
+
Fine-tunes OpenVLA via LoRA.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import ctypes
|
| 10 |
+
|
| 11 |
+
lib_path = "/share/miniconda3/lib/libstdc++.so.6"
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
ctypes.CDLL(lib_path)
|
| 15 |
+
print(f"Successfully preloaded {lib_path}")
|
| 16 |
+
except Exception as e:
|
| 17 |
+
print(f"Failed to preload {lib_path}: {e}")
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import time
|
| 21 |
+
from collections import deque
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Dict, Optional, Tuple, Type
|
| 25 |
+
|
| 26 |
+
import draccus
|
| 27 |
+
import torch
|
| 28 |
+
import torch.distributed as dist
|
| 29 |
+
import torch.nn as nn
|
| 30 |
+
import tqdm
|
| 31 |
+
import numpy as np
|
| 32 |
+
from accelerate import PartialState
|
| 33 |
+
from huggingface_hub import HfApi, snapshot_download
|
| 34 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
| 35 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 36 |
+
from torch.optim import AdamW
|
| 37 |
+
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
|
| 38 |
+
from torch.utils.data import DataLoader
|
| 39 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 40 |
+
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
|
| 41 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 42 |
+
|
| 43 |
+
import wandb
|
| 44 |
+
os.environ["WANDB_MODE"]="offline"
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
from safetensors import safe_open
|
| 48 |
+
SAFETENSORS_AVAILABLE = True
|
| 49 |
+
except ImportError:
|
| 50 |
+
SAFETENSORS_AVAILABLE = False
|
| 51 |
+
print("Warning: safetensors not available, will try torch.load instead")
|
| 52 |
+
|
| 53 |
+
from experiments.robot.openvla_utils import (
|
| 54 |
+
check_model_logic_mismatch,
|
| 55 |
+
model_is_on_hf_hub,
|
| 56 |
+
update_auto_map,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| 60 |
+
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| 61 |
+
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| 62 |
+
from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead
|
| 63 |
+
from prismatic.models.backbones.llm.prompting import PurePromptBuilder
|
| 64 |
+
from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone
|
| 65 |
+
from prismatic.models.ema_model import EMAModel
|
| 66 |
+
from prismatic.models.projectors import (
|
| 67 |
+
NoisyActionProjector,
|
| 68 |
+
ProprioProjector,
|
| 69 |
+
)
|
| 70 |
+
from prismatic.training.train_utils import (
|
| 71 |
+
compute_actions_l1_loss,
|
| 72 |
+
compute_token_accuracy,
|
| 73 |
+
get_current_action_mask,
|
| 74 |
+
get_next_actions_mask,
|
| 75 |
+
)
|
| 76 |
+
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
|
| 77 |
+
from prismatic.vla.action_tokenizer import ActionTokenizer
|
| 78 |
+
from prismatic.vla.constants import (
|
| 79 |
+
ACTION_DIM,
|
| 80 |
+
ACTION_PROPRIO_NORMALIZATION_TYPE,
|
| 81 |
+
NUM_ACTIONS_CHUNK,
|
| 82 |
+
PROPRIO_DIM,
|
| 83 |
+
)
|
| 84 |
+
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
|
| 85 |
+
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
|
| 86 |
+
|
| 87 |
+
# Sane Defaults
|
| 88 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 89 |
+
|
| 90 |
+
#wx: stop gradient in the feature vector direction
|
| 91 |
+
|
| 92 |
+
EPS = 1e-12
|
| 93 |
+
|
| 94 |
+
def register_orthogonal_grad_hook(model, vector_W, debug=False):
|
| 95 |
+
name_to_param = dict(model.named_parameters())
|
| 96 |
+
|
| 97 |
+
hooked_A = 0
|
| 98 |
+
hooked_B = 0
|
| 99 |
+
hooked_direct = 0
|
| 100 |
+
|
| 101 |
+
missed = 0
|
| 102 |
+
missed_name = []
|
| 103 |
+
|
| 104 |
+
direct_missed = 0
|
| 105 |
+
direct_missed_name = []
|
| 106 |
+
|
| 107 |
+
printed = {"A": False, "B": False, "D": False}
|
| 108 |
+
|
| 109 |
+
def proj_out(g2, v2):
|
| 110 |
+
vn2 = (v2 * v2).sum().detach()
|
| 111 |
+
if vn2.item() <= EPS:
|
| 112 |
+
return g2
|
| 113 |
+
gv = (g2 * v2).sum()
|
| 114 |
+
return g2 - (gv / (vn2 + EPS)) * v2
|
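# proj_out removes the component of g2 that lies along v2:
#   g_orth = g - (<g, v> / (||v||^2 + EPS)) * v
# so the returned gradient is (numerically) orthogonal to the capability-vector direction;
# if v is (near) zero, the gradient is returned unchanged.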
| 115 |
+
|
| 116 |
+
for w_name, vW in vector_W.items():
|
| 117 |
+
if "vision_backbone" in w_name:
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
prefix = "base_model.model."
|
| 121 |
+
A_name = prefix + w_name.replace(".weight", ".lora_A.default.weight")
|
| 122 |
+
B_name = prefix + w_name.replace(".weight", ".lora_B.default.weight")
|
| 123 |
+
|
| 124 |
+
# ===== 1) First, try to hook the LoRA factors =====
|
| 125 |
+
if A_name in name_to_param and B_name in name_to_param:
|
| 126 |
+
A = name_to_param[A_name]
|
| 127 |
+
B = name_to_param[B_name]
|
| 128 |
+
|
| 129 |
+
# If neither factor is trainable, skip the hook
|
| 130 |
+
if (not A.requires_grad) and (not B.requires_grad):
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
# Pin vW to the LoRA params' device/dtype
|
| 134 |
+
vW = vW.to(device=A.device, dtype=A.dtype)
|
| 135 |
+
vW2 = vW.reshape(vW.shape[0], -1) if vW.ndim != 2 else vW # [out, in_flat]
|
| 136 |
+
|
| 137 |
+
# ---- hook A: use the current B to compute vA = B^T vW on the fly ----
|
| 138 |
+
if A.requires_grad:
|
| 139 |
+
def hook_A(g, A_ref=A, B_ref=B, vW2_ref=vW2):
|
| 140 |
+
if g is None:
|
| 141 |
+
return None
|
| 142 |
+
g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
|
| 143 |
+
|
| 144 |
+
B_mat = B_ref.detach()
|
| 145 |
+
B2 = B_mat.reshape(B_mat.shape[0], -1) if B_mat.ndim != 2 else B_mat # [out, r]
|
| 146 |
+
|
| 147 |
+
if B2.shape[0] != vW2_ref.shape[0]:
|
| 148 |
+
return g
|
| 149 |
+
|
| 150 |
+
vA = torch.matmul(B2.transpose(0, 1), vW2_ref) # [r, in_flat]
|
| 151 |
+
|
| 152 |
+
if debug and not printed["A"]:
|
| 153 |
+
print(f"[hook fired] A: ||B||={B2.norm().item():.4e}, ||vA||={vA.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
|
| 154 |
+
printed["A"] = True
|
| 155 |
+
|
| 156 |
+
g2_new = proj_out(g2, vA)
|
| 157 |
+
return g2_new.view_as(g)
|
| 158 |
+
|
| 159 |
+
A.register_hook(hook_A)
|
| 160 |
+
hooked_A += 1
|
| 161 |
+
|
| 162 |
+
# ---- hook B: use the current A to compute vB = vW A^T on the fly ----
|
| 163 |
+
if B.requires_grad:
|
| 164 |
+
def hook_B(g, A_ref=A, B_ref=B, vW2_ref=vW2):
|
| 165 |
+
if g is None:
|
| 166 |
+
return None
|
| 167 |
+
g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
|
| 168 |
+
|
| 169 |
+
A_mat = A_ref.detach()
|
| 170 |
+
A2 = A_mat.reshape(A_mat.shape[0], -1) if A_mat.ndim != 2 else A_mat # [r, in_flat]
|
| 171 |
+
|
| 172 |
+
if A2.shape[1] != vW2_ref.shape[1]:
|
| 173 |
+
return g
|
| 174 |
+
|
| 175 |
+
vB = torch.matmul(vW2_ref, A2.transpose(0, 1)) # [out, r]
|
| 176 |
+
|
| 177 |
+
if debug and not printed["B"]:
|
| 178 |
+
print(f"[hook fired] B: ||A||={A2.norm().item():.4e}, ||vB||={vB.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
|
| 179 |
+
printed["B"] = True
|
| 180 |
+
|
| 181 |
+
g2_new = proj_out(g2, vB)
|
| 182 |
+
return g2_new.view_as(g)
|
| 183 |
+
|
| 184 |
+
B.register_hook(hook_B)
|
| 185 |
+
hooked_B += 1
|
| 186 |
+
|
| 187 |
+
# This weight was handled by the LoRA branch; move on
|
| 188 |
+
continue
|
| 189 |
+
|
| 190 |
+
# ===== 2) No LoRA factors found: fall back to hooking the parameter directly (e.g. layernorm) =====
|
| 191 |
+
missed += 1
|
| 192 |
+
missed_name.append(w_name)
|
| 193 |
+
|
| 194 |
+
# Try to align with the non-LoRA parameter name
|
| 195 |
+
# In most cases this is: base_model.model.<w_name>
|
| 196 |
+
direct_name = prefix + w_name
|
| 197 |
+
|
| 198 |
+
# Some vector entries may not carry the base_model.model prefix, while your model's parameter names may use a different prefix
|
| 199 |
+
# As a "try once more" fallback: if direct_name is not found, retry the name without the prefix (e.g. cases missing the language_model/ prefix)
|
| 200 |
+
# (Add further matching rules here to fit your own project layout)
|
| 201 |
+
if direct_name not in name_to_param:
|
| 202 |
+
# Try once more: if w_name already contains base_model.model, do not prepend the prefix
|
| 203 |
+
if w_name in name_to_param:
|
| 204 |
+
direct_name = w_name
|
| 205 |
+
else:
|
| 206 |
+
direct_missed += 1
|
| 207 |
+
direct_missed_name.append(w_name)
|
| 208 |
+
continue
|
| 209 |
+
|
| 210 |
+
P = name_to_param[direct_name]
|
| 211 |
+
if not P.requires_grad:
|
| 212 |
+
# Found but frozen: do not hook it, and do not count it as direct_missed
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
vP = vector_W[w_name].to(device=P.device, dtype=P.dtype)
|
| 216 |
+
vP2 = vP.reshape(vP.shape[0], -1) if vP.ndim != 2 else vP
|
| 217 |
+
|
| 218 |
+
def hook_direct(g, v_ref=vP2, name_ref=direct_name):  # bind the name at definition time for the debug print
|
| 219 |
+
if g is None:
|
| 220 |
+
return None
|
| 221 |
+
g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
|
| 222 |
+
|
| 223 |
+
# Leave the gradient untouched on shape mismatch (avoid resize errors inside the hook)
|
| 224 |
+
if g2.shape != v_ref.shape:
|
| 225 |
+
return g
|
| 226 |
+
|
| 227 |
+
if debug and not printed["D"]:
|
| 228 |
+
print(f"[hook fired] Direct: param={direct_name}, ||v||={v_ref.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
|
| 229 |
+
printed["D"] = True
|
| 230 |
+
|
| 231 |
+
g2_new = proj_out(g2, v_ref)
|
| 232 |
+
return g2_new.view_as(g)
|
| 233 |
+
|
| 234 |
+
P.register_hook(hook_direct)
|
| 235 |
+
hooked_direct += 1
|
| 236 |
+
|
| 237 |
+
print(
|
| 238 |
+
f"[hook summary] hooked lora_A: {hooked_A}, lora_B: {hooked_B}, direct: {hooked_direct}, "
|
| 239 |
+
f"missed(lora-not-found): {missed}, direct_missed: {direct_missed}"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# To inspect the concrete miss lists:
|
| 243 |
+
# print("[missed lora-not-found names]")
|
| 244 |
+
# for n in missed_name: print(" -", n)
|
| 245 |
+
# print("[direct_missed names]")
|
| 246 |
+
# for n in direct_missed_name: print(" -", n)
|
| 247 |
+
|
| 248 |
+
# import pdb; pdb.set_trace()
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
# def register_orthogonal_grad_hook(model, vector_W, debug=False):
|
| 252 |
+
# name_to_param = dict(model.named_parameters())
|
| 253 |
+
|
| 254 |
+
# hooked_A = 0
|
| 255 |
+
# hooked_B = 0
|
| 256 |
+
# missed = 0
|
| 257 |
+
|
| 258 |
+
# printed = {"A": False, "B": False}  # print only once
|
| 259 |
+
|
| 260 |
+
# for w_name, vW in vector_W.items():
|
| 261 |
+
# if "vision_backbone" in w_name:
|
| 262 |
+
# continue
|
| 263 |
+
# # import pdb; pdb.set_trace()
|
| 264 |
+
# prefix = "base_model.model."
|
| 265 |
+
# A_name = prefix + w_name.replace(".weight", ".lora_A.default.weight")
|
| 266 |
+
# B_name = prefix + w_name.replace(".weight", ".lora_B.default.weight")
|
| 267 |
+
|
| 268 |
+
# if A_name not in name_to_param or B_name not in name_to_param:
|
| 269 |
+
# missed += 1
|
| 270 |
+
# continue
|
| 271 |
+
|
| 272 |
+
# A = name_to_param[A_name]
|
| 273 |
+
# B = name_to_param[B_name]
|
| 274 |
+
|
| 275 |
+
# if (not A.requires_grad) and (not B.requires_grad):
|
| 276 |
+
# continue
|
| 277 |
+
|
| 278 |
+
# vW = vW.to(device=A.device, dtype=A.dtype)
|
| 279 |
+
|
| 280 |
+
# with torch.no_grad():
|
| 281 |
+
# # A_mat = A.detach().view(1, -1) # (1, in)
|
| 282 |
+
# # B_mat = B.detach().view(-1, 1) # (out,1)
|
| 283 |
+
|
| 284 |
+
# # vA = torch.matmul(B_mat.T, vW) # (1,in)
|
| 285 |
+
# # vB = torch.matmul(vW, A_mat.T) # (out,1)
|
| 286 |
+
# B_mat = B.detach()
|
| 287 |
+
# A_mat = A.detach()
|
| 288 |
+
# # import pdb; pdb.set_trace()
|
| 289 |
+
|
| 290 |
+
# # Flatten vW to 2-D: [out, in_flat]
|
| 291 |
+
# if vW.ndim != 2:
|
| 292 |
+
# vW2 = vW.reshape(vW.shape[0], -1)
|
| 293 |
+
# else:
|
| 294 |
+
# vW2 = vW
|
| 295 |
+
|
| 296 |
+
# # A may not be strictly 2-D (it usually is, but reshape to be safe); in practice both A and B are 2-D
|
| 297 |
+
# if A_mat.ndim != 2:
|
| 298 |
+
# A2 = A_mat.reshape(A_mat.shape[0], -1) # [r, in_flat]
|
| 299 |
+
# else:
|
| 300 |
+
# A2 = A_mat
|
| 301 |
+
|
| 302 |
+
# # B is usually 2-D [out, r]
|
| 303 |
+
# if B_mat.ndim != 2:
|
| 304 |
+
# B2 = B_mat.reshape(B_mat.shape[0], -1) # [out, r]
|
| 305 |
+
# else:
|
| 306 |
+
# B2 = B_mat
|
| 307 |
+
|
| 308 |
+
# # Shape check: skip this w_name on mismatch (avoid further errors)
|
| 310 |
+
# # Requires: the out dim of B2 [out, r] matches the out dim of vW2 [out, in_flat]
|
| 311 |
+
# # Requires: the in_flat dim of A2 [r, in_flat] matches the in_flat dim of vW2 [out, in_flat]
|
| 311 |
+
# if B2.shape[0] != vW2.shape[0] or A2.shape[1] != vW2.shape[1] or A2.shape[0] != B2.shape[1]:
|
| 312 |
+
# missed += 1
|
| 313 |
+
# continue
|
| 314 |
+
|
| 315 |
+
# vA = torch.matmul(B2.transpose(0, 1), vW2) # [r, in_flat]
|
| 316 |
+
# vB = torch.matmul(vW2, A2.transpose(0, 1)) # [out, r]
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# # hook A
|
| 320 |
+
# if A.requires_grad:
|
| 321 |
+
# vA_norm2 = (vA * vA).sum().detach()
|
| 322 |
+
# if vA_norm2.item() > EPS:
|
| 323 |
+
# def make_hook_A(v, vn2):
|
| 324 |
+
# def hook(g):
|
| 325 |
+
# if debug and not printed["A"]:
|
| 326 |
+
# print(f"[hook fired] lora_A grad norm: {g.norm().item():.4e}")
|
| 327 |
+
# printed["A"] = True
|
| 328 |
+
# gv = (g * v).sum()
|
| 329 |
+
# proj = (gv / (vn2 + EPS)) * v
|
| 330 |
+
# return g - proj
|
| 331 |
+
# return hook
|
| 332 |
+
|
| 333 |
+
# A.register_hook(make_hook_A(vA, vA_norm2))
|
| 334 |
+
# hooked_A += 1
|
| 335 |
+
|
| 336 |
+
# # hook B
|
| 337 |
+
# if B.requires_grad:
|
| 338 |
+
# vB_norm2 = (vB * vB).sum().detach()
|
| 339 |
+
# if vB_norm2.item() > EPS:
|
| 340 |
+
# def make_hook_B(v, vn2):
|
| 341 |
+
# def hook(g):
|
| 342 |
+
# if debug and not printed["B"]:
|
| 343 |
+
# print(f"[hook fired] lora_B grad norm: {g.norm().item():.4e}")
|
| 344 |
+
# printed["B"] = True
|
| 345 |
+
# gv = (g * v).sum()
|
| 346 |
+
# proj = (gv / (vn2 + EPS)) * v
|
| 347 |
+
# return g - proj
|
| 348 |
+
# return hook
|
| 349 |
+
|
| 350 |
+
# B.register_hook(make_hook_B(vB, vB_norm2))
|
| 351 |
+
# hooked_B += 1
|
| 352 |
+
|
| 353 |
+
# print(f"[hook summary] hooked lora_A: {hooked_A}, hooked lora_B: {hooked_B}, missed: {missed}")
|
| 354 |
+
# import pdb; pdb.set_trace()
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
# Usage:
|
| 359 |
+
# vector_sd = torch.load("your_vector.pth")["state_dict"] or similar
|
| 360 |
+
# register_orthogonal_grad_hook(model, vector_sd)
|
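# A slightly fuller usage sketch (illustrative only; the checkpoint path and key names below are
# placeholders, and `vla` is assumed to be the LoRA-wrapped policy created further down in this script):
#
#   vector_W = torch.load("capability_vector.pt", map_location="cpu")
#   if isinstance(vector_W, dict) and "state_dict" in vector_W:
#       vector_W = vector_W["state_dict"]
#   # keys are expected to look like "language_model.model.layers.0.self_attn.q_proj.weight"
#   register_orthogonal_grad_hook(vla, vector_W, debug=True)
#   # every subsequent loss.backward() then projects the LoRA (or direct) gradients onto the
#   # subspace orthogonal to the supplied vector before the optimizer step.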
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# import debugpy
|
| 364 |
+
# try:
|
| 365 |
+
# debugpy.listen(("localhost", 9501))
|
| 366 |
+
# print("Waiting for debugger attach")
|
| 367 |
+
# debugpy.wait_for_client()
|
| 368 |
+
# except Exception as e:
|
| 369 |
+
# pass
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@dataclass
|
| 373 |
+
class FinetuneConfig:
|
| 374 |
+
# fmt: off
|
| 375 |
+
vla_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally)
|
| 376 |
+
|
| 377 |
+
# Dataset
|
| 378 |
+
data_root_dir: Path = Path("datasets/rlds") # Directory containing RLDS datasets
|
| 379 |
+
dataset_name: str = "aloha_scoop_x_into_bowl" # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`)
|
| 380 |
+
run_root_dir: Path = Path("runs") # Path to directory to store logs & checkpoints
|
| 381 |
+
shuffle_buffer_size: int = 100_000 # Dataloader shuffle buffer size (can reduce if OOM errors occur)
|
| 382 |
+
|
| 383 |
+
# Algorithm and architecture
|
| 384 |
+
use_l1_regression: bool = True # If True, trains continuous action head with L1 regression objective
|
| 385 |
+
use_diffusion: bool = False # If True, trains continuous action head with diffusion modeling objective (DDIM)
|
| 386 |
+
num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training
|
| 387 |
+
use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
|
| 388 |
+
num_images_in_input: int = 1 # Number of images in the VLA input (default: 1)
|
| 389 |
+
use_proprio: bool = False # If True, includes robot proprioceptive state in input
|
| 390 |
+
|
| 391 |
+
# Training configuration
|
| 392 |
+
batch_size: int = 8 # Batch size per device (total batch size = batch_size * num GPUs)
|
| 393 |
+
learning_rate: float = 5e-4 # Learning rate
|
| 394 |
+
lr_warmup_steps: int = 0 # Number of steps to warm up learning rate (from 10% to 100%)
|
| 395 |
+
num_steps_before_decay: int = 100_000 # Number of steps before LR decays by 10x
|
| 396 |
+
grad_accumulation_steps: int = 1 # Number of gradient accumulation steps
|
| 397 |
+
max_steps: int = 200_000 # Max number of training steps
|
| 398 |
+
use_val_set: bool = False # If True, uses validation set and log validation metrics
|
| 399 |
+
val_freq: int = 10_000 # (When `use_val_set==True`) Validation set logging frequency in steps
|
| 400 |
+
val_time_limit: int = 180 # (When `use_val_set==True`) Time limit for computing validation metrics
|
| 401 |
+
save_freq: int = 10_000 # Checkpoint saving frequency in steps
|
| 402 |
+
save_latest_checkpoint_only: bool = False # If True, saves only 1 checkpoint, overwriting latest checkpoint
|
| 403 |
+
# (If False, saves all checkpoints)
|
| 404 |
+
scheduler: str = 'MultiStepLR' # "MultiStepLR" or "CosineAnnealingLR" or "WarmupCosineLR"
|
| 405 |
+
resume: bool = False # If True, resumes from checkpoint
|
| 406 |
+
resume_step: Optional[int] = None # (When `resume==True`) Step number that we are resuming from
|
| 407 |
+
image_aug: bool = True # If True, trains with image augmentations (HIGHLY RECOMMENDED)
|
| 408 |
+
diffusion_sample_freq: int = 50 # (When `use_diffusion==True`) Frequency for sampling in steps
|
| 409 |
+
|
| 410 |
+
# LoRA
|
| 411 |
+
use_lora: bool = True # If True, uses LoRA fine-tuning
|
| 412 |
+
lora_rank: int = 32 # Rank of LoRA weight matrix
|
| 413 |
+
lora_dropout: float = 0.0 # Dropout applied to LoRA weights
|
| 414 |
+
merge_lora_during_training: bool = True # If True, merges LoRA weights and saves result during training
|
| 415 |
+
# Note: Merging can be very slow on some machines. If so, set to
|
| 416 |
+
# False and merge final checkpoint offline!
|
| 417 |
+
|
| 418 |
+
# Regularization
|
| 419 |
+
regularization_lora_vector_path: Optional[str] = None     # Path to regularization vector
|
| 420 |
+
regularization_weight: float = 1e-3 # Weight of regularization loss
|
| 421 |
+
|
| 422 |
+
# Logging
|
| 423 |
+
wandb_entity: str = "your-wandb-entity" # Name of WandB entity
|
| 424 |
+
wandb_project: str = "your-wandb-project" # Name of WandB project
|
| 425 |
+
run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
|
| 426 |
+
run_id_override: Optional[str] = None # Optional string to override the run ID with
|
| 427 |
+
wandb_log_freq: int = 10 # WandB logging frequency in steps
|
| 428 |
+
|
| 429 |
+
# EMA
|
| 430 |
+
use_ema: bool = False # If True, maintains an EMA copy of the model
|
| 431 |
+
inv_gamma: float = 1 # EMA inverse gamma parameter
|
| 432 |
+
|
| 433 |
+
# fmt: on
|
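# Example invocation (a sketch, not a prescribed command; adjust GPU count, dataset name, and
# paths to your setup — every value below is a placeholder):
#
#   torchrun --standalone --nnodes 1 --nproc-per-node 8 finetune.py \
#     --vla_path openvla/openvla-7b \
#     --data_root_dir datasets/rlds \
#     --dataset_name aloha_scoop_x_into_bowl \
#     --use_l1_regression True \
#     --batch_size 8 --learning_rate 5e-4 --max_steps 200000 \
#     --regularization_lora_vector_path /path/to/capability_vector.safetensors
#
# draccus maps each `--flag value` onto the corresponding FinetuneConfig field.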
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def remove_ddp_in_checkpoint(state_dict) -> dict:
|
| 437 |
+
"""
|
| 438 |
+
Removes the 'module.' prefix from parameter names in a PyTorch model state dictionary that was saved using
|
| 439 |
+
DistributedDataParallel (DDP).
|
| 440 |
+
|
| 441 |
+
When a model is trained using PyTorch's DistributedDataParallel, the saved state dictionary contains parameters
|
| 442 |
+
prefixed with 'module.'. This function removes these prefixes to make the state dictionary compatible when
|
| 443 |
+
loading into models that are not yet wrapped in DDP.
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
state_dict (dict): PyTorch model state dictionary.
|
| 447 |
+
|
| 448 |
+
Returns:
|
| 449 |
+
dict: A new state dictionary with the same contents but with 'module.' prefixes removed from parameter names.
|
| 450 |
+
Parameters without the 'module.' prefix remain unchanged.
|
| 451 |
+
"""
|
| 452 |
+
new_state_dict = {}
|
| 453 |
+
for k, v in state_dict.items():
|
| 454 |
+
if k[:7] == "module.":
|
| 455 |
+
new_state_dict[k[7:]] = v
|
| 456 |
+
else:
|
| 457 |
+
new_state_dict[k] = v
|
| 458 |
+
return new_state_dict
|
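# For example, {"module.lora_A.weight": w, "bias": b} becomes {"lora_A.weight": w, "bias": b}.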
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def get_run_id(cfg) -> str:
|
| 462 |
+
"""
|
| 463 |
+
Generates or retrieves an identifier string for an experiment run.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
cfg (FinetuneConfig): Training configuration.
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
str: Experiment run ID.
|
| 470 |
+
"""
|
| 471 |
+
if cfg.run_id_override is not None:
|
| 472 |
+
# Override the run ID with the user-provided ID
|
| 473 |
+
run_id = cfg.run_id_override
|
| 474 |
+
elif cfg.resume:
|
| 475 |
+
# Override run ID with the previous resumed run's ID
|
| 476 |
+
run_id = cfg.vla_path.split("/")[-1]
|
| 477 |
+
# Remove the "--XXX_chkpt" suffix from the run ID if it exists
|
| 478 |
+
if "chkpt" in run_id.split("--")[-1]:
|
| 479 |
+
run_id = "--".join(run_id.split("--")[:-1])
|
| 480 |
+
else:
|
| 481 |
+
run_id = (
|
| 482 |
+
f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}"
|
| 483 |
+
f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
|
| 484 |
+
f"+lr-{cfg.learning_rate}"
|
| 485 |
+
)
|
| 486 |
+
if cfg.use_lora:
|
| 487 |
+
run_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
|
| 488 |
+
if cfg.image_aug:
|
| 489 |
+
run_id += "--image_aug"
|
| 490 |
+
if cfg.run_id_note is not None:
|
| 491 |
+
run_id += f"--{cfg.run_id_note}"
|
| 492 |
+
return run_id
|
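# A typical generated run ID (illustrative, using the default config values) looks like:
#   "openvla-7b+aloha_scoop_x_into_bowl+b8+lr-0.0005+lora-r32+dropout-0.0--image_aug"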
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict:
|
| 496 |
+
"""
|
| 497 |
+
Loads a checkpoint for a given module.
|
| 498 |
+
|
| 499 |
+
Args:
|
| 500 |
+
module_name (str): Name of model component to load checkpoint for.
|
| 501 |
+
path (str): Path to checkpoint directory.
|
| 502 |
+
step (int): Gradient step number of saved checkpoint.
|
| 503 |
+
device (str): String specifying how to remap storage locations (default = "cpu").
|
| 504 |
+
|
| 505 |
+
Returns:
|
| 506 |
+
dict: PyTorch model state dictionary.
|
| 507 |
+
"""
|
| 508 |
+
checkpoint_path = os.path.join(path, f"{module_name}--{step}_checkpoint.pt")
|
| 509 |
+
print(f"Loading checkpoint: {checkpoint_path}")
|
| 510 |
+
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
|
| 511 |
+
return remove_ddp_in_checkpoint(state_dict)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP:
|
| 515 |
+
"""
|
| 516 |
+
Wrap a module with DistributedDataParallel.
|
| 517 |
+
|
| 518 |
+
Args:
|
| 519 |
+
module (nn.Module): PyTorch module.
|
| 520 |
+
device_id (str): Device ID.
|
| 521 |
+
find_unused (bool): Whether to detect parameters without gradients in distributed training.
|
| 522 |
+
|
| 523 |
+
Returns:
|
| 524 |
+
DistributedDataParallel: PyTorch module wrapped with DDP.
|
| 525 |
+
"""
|
| 526 |
+
return DDP(module, device_ids=[device_id], find_unused_parameters=find_unused, gradient_as_bucket_view=True)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def count_parameters(module: nn.Module, name: str) -> None:
|
| 530 |
+
"""
|
| 531 |
+
Counts and prints the number of trainable parameters in a module.
|
| 532 |
+
|
| 533 |
+
Args:
|
| 534 |
+
module (nn.Module): PyTorch module.
|
| 535 |
+
name (str): Name of model component.
|
| 536 |
+
|
| 537 |
+
Returns:
|
| 538 |
+
None.
|
| 539 |
+
"""
|
| 540 |
+
num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
| 541 |
+
print(f"# trainable params in {name}: {num_params}")
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def init_module(
|
| 545 |
+
module_class: Type[nn.Module],
|
| 546 |
+
module_name: str,
|
| 547 |
+
cfg: FinetuneConfig,
|
| 548 |
+
device_id: int,
|
| 549 |
+
module_args: dict,
|
| 550 |
+
to_bf16: bool = False,
|
| 551 |
+
find_unused_params: bool = False,
|
| 552 |
+
) -> DDP:
|
| 553 |
+
"""
|
| 554 |
+
Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP.
|
| 555 |
+
|
| 556 |
+
Args:
|
| 557 |
+
module_class (Type[nn.Module]): Class of PyTorch module to initialize.
|
| 558 |
+
module_name (str): Name of model component to load checkpoint for.
|
| 559 |
+
cfg (FinetuneConfig): Training configuration.
|
| 560 |
+
device_id (str): Device ID.
|
| 561 |
+
module_args (dict): Args for initializing the module.
|
| 562 |
+
to_bf16 (bool): Whether to convert to torch.bfloat16 data type.
|
| 563 |
+
find_unused_params (bool): Whether to detect parameters without gradients in distributed training.
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
DistributedDataParallel: PyTorch module wrapped with DDP.
|
| 567 |
+
"""
|
| 568 |
+
module = module_class(**module_args)
|
| 569 |
+
count_parameters(module, module_name)
|
| 570 |
+
|
| 571 |
+
if cfg.resume:
|
| 572 |
+
state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step)
|
| 573 |
+
module.load_state_dict(state_dict)
|
| 574 |
+
|
| 575 |
+
if to_bf16:
|
| 576 |
+
module = module.to(torch.bfloat16)
|
| 577 |
+
module = module.to(device_id)
|
| 578 |
+
|
| 579 |
+
return wrap_ddp(module, device_id, find_unused_params)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def run_forward_pass(
|
| 583 |
+
vla,
|
| 584 |
+
action_head,
|
| 585 |
+
noisy_action_projector,
|
| 586 |
+
proprio_projector,
|
| 587 |
+
batch,
|
| 588 |
+
action_tokenizer,
|
| 589 |
+
device_id,
|
| 590 |
+
use_l1_regression,
|
| 591 |
+
use_diffusion,
|
| 592 |
+
use_proprio,
|
| 593 |
+
use_film,
|
| 594 |
+
num_patches,
|
| 595 |
+
compute_diffusion_l1=False,
|
| 596 |
+
num_diffusion_steps_train=None,
|
| 597 |
+
) -> Tuple[torch.Tensor, Dict[str, float]]:
|
| 598 |
+
"""
|
| 599 |
+
Compute model forward pass and metrics for both training and validation.
|
| 600 |
+
|
| 601 |
+
Args:
|
| 602 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 603 |
+
action_head (nn.Module): Action head module.
|
| 604 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 605 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 606 |
+
batch (dict): Input batch.
|
| 607 |
+
action_tokenizer (ActionTokenizer): Action tokenizer.
|
| 608 |
+
device_id (str): Device ID.
|
| 609 |
+
use_l1_regression (bool): Whether to use L1 regression.
|
| 610 |
+
use_diffusion (bool): Whether to use diffusion.
|
| 611 |
+
use_proprio (bool): Whether to use proprioceptive state as input.
|
| 612 |
+
use_film (bool): Whether to use FiLM for better language following.
|
| 613 |
+
num_patches (int): Number of vision patches.
|
| 614 |
+
compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every
|
| 615 |
+
diffusion_sample_freq steps during training; do it every batch for validation)
|
| 616 |
+
num_diffusion_steps_train (int): Number of diffusion steps for training (only used for diffusion).
|
| 617 |
+
|
| 618 |
+
Returns:
|
| 619 |
+
tuple: (loss, metrics_dict)
|
| 620 |
+
loss: The loss tensor with gradient for backpropagation.
|
| 621 |
+
metrics_dict: Dictionary of computed metrics (detached values for logging).
|
| 622 |
+
"""
|
| 623 |
+
metrics = {}
|
| 624 |
+
|
| 625 |
+
# Get ground-truth action labels
|
| 626 |
+
ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16)
|
| 627 |
+
|
| 628 |
+
# [Only for diffusion] Sample noisy actions used as input for noise predictor network
|
| 629 |
+
if use_diffusion:
|
| 630 |
+
noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions)
|
| 631 |
+
noise, noisy_actions, diffusion_timestep_embeddings = (
|
| 632 |
+
noisy_dict["noise"],
|
| 633 |
+
noisy_dict["noisy_actions"],
|
| 634 |
+
noisy_dict["diffusion_timestep_embeddings"],
|
| 635 |
+
)
|
| 636 |
+
else:
|
| 637 |
+
noise, noisy_actions, diffusion_timestep_embeddings = None, None, None
|
| 638 |
+
|
| 639 |
+
# VLA forward pass
|
| 640 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 641 |
+
output: CausalLMOutputWithPast = vla(
|
| 642 |
+
input_ids=batch["input_ids"].to(device_id),
|
| 643 |
+
attention_mask=batch["attention_mask"].to(device_id),
|
| 644 |
+
pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
|
| 645 |
+
labels=batch["labels"],
|
| 646 |
+
output_hidden_states=True,
|
| 647 |
+
proprio=batch["proprio"] if use_proprio else None,
|
| 648 |
+
proprio_projector=proprio_projector if use_proprio else None,
|
| 649 |
+
noisy_actions=noisy_actions if use_diffusion else None,
|
| 650 |
+
noisy_action_projector=noisy_action_projector if use_diffusion else None,
|
| 651 |
+
diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None,
|
| 652 |
+
use_film=use_film,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
# Get action masks needed for logging
|
| 656 |
+
ground_truth_token_ids = batch["labels"][:, 1:].to(device_id)
|
| 657 |
+
current_action_mask = get_current_action_mask(ground_truth_token_ids)
|
| 658 |
+
next_actions_mask = get_next_actions_mask(ground_truth_token_ids)
|
| 659 |
+
|
| 660 |
+
# Compute metrics for discrete action representation (next-token prediction)
|
| 661 |
+
if not (use_l1_regression or use_diffusion):
|
| 662 |
+
loss = output.loss
|
| 663 |
+
predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2)
|
| 664 |
+
curr_action_accuracy = compute_token_accuracy(
|
| 665 |
+
predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
|
| 666 |
+
)
|
| 667 |
+
curr_action_l1_loss = compute_actions_l1_loss(
|
| 668 |
+
action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
|
| 669 |
+
)
|
| 670 |
+
next_actions_accuracy = compute_token_accuracy(
|
| 671 |
+
predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
|
| 672 |
+
)
|
| 673 |
+
next_actions_l1_loss = compute_actions_l1_loss(
|
| 674 |
+
action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
|
| 675 |
+
)
|
| 676 |
+
metrics.update(
|
| 677 |
+
{
|
| 678 |
+
"loss_value": loss.item(), # Detached value for logging
|
| 679 |
+
"curr_action_accuracy": curr_action_accuracy.item(),
|
| 680 |
+
"curr_action_l1_loss": curr_action_l1_loss.item(),
|
| 681 |
+
"next_actions_accuracy": next_actions_accuracy.item(),
|
| 682 |
+
"next_actions_l1_loss": next_actions_l1_loss.item(),
|
| 683 |
+
}
|
| 684 |
+
)
|
| 685 |
+
# Compute metrics for continuous action representations (L1 regression | diffusion)
|
| 686 |
+
else:
|
| 687 |
+
# Get last layer hidden states
|
| 688 |
+
last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
|
| 689 |
+
# Get hidden states for text portion of prompt+response (after the vision patches)
|
| 690 |
+
text_hidden_states = last_hidden_states[:, num_patches:-1]
|
| 691 |
+
# Get hidden states for action portion of response
|
| 692 |
+
batch_size = batch["input_ids"].shape[0]
|
| 693 |
+
actions_hidden_states = (
|
| 694 |
+
text_hidden_states[current_action_mask | next_actions_mask]
|
| 695 |
+
.reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1)
|
| 696 |
+
.to(torch.bfloat16)
|
| 697 |
+
) # (B, act_chunk_len, D)
|
| 698 |
+
|
| 699 |
+
if use_l1_regression:
|
| 700 |
+
# Predict action
|
| 701 |
+
predicted_actions = action_head.module.predict_action(actions_hidden_states)
|
| 702 |
+
# Get full L1 loss
|
| 703 |
+
loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions)
|
| 704 |
+
|
| 705 |
+
if use_diffusion:
|
| 706 |
+
# Predict noise
|
| 707 |
+
noise_pred = action_head.module.predict_noise(actions_hidden_states)
|
| 708 |
+
# Get diffusion noise prediction MSE loss
|
| 709 |
+
noise_pred = noise_pred.reshape(noise.shape)
|
| 710 |
+
loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean")
|
| 711 |
+
|
| 712 |
+
# Only sample actions and compute L1 losses if specified
|
| 713 |
+
if compute_diffusion_l1:
|
| 714 |
+
with torch.no_grad():
|
| 715 |
+
predicted_actions = run_diffusion_sampling(
|
| 716 |
+
vla=vla,
|
| 717 |
+
action_head=action_head,
|
| 718 |
+
noisy_action_projector=noisy_action_projector,
|
| 719 |
+
proprio_projector=proprio_projector,
|
| 720 |
+
batch=batch,
|
| 721 |
+
batch_size=batch_size,
|
| 722 |
+
num_patches=num_patches,
|
| 723 |
+
actions_shape=ground_truth_actions.shape,
|
| 724 |
+
device_id=device_id,
|
| 725 |
+
current_action_mask=current_action_mask,
|
| 726 |
+
next_actions_mask=next_actions_mask,
|
| 727 |
+
use_proprio=use_proprio,
|
| 728 |
+
use_film=use_film,
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
metrics.update(
|
| 732 |
+
{
|
| 733 |
+
"loss_value": loss.item(), # Detached value for logging
|
| 734 |
+
}
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
# Get detailed L1 losses for logging
|
| 738 |
+
should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1)
|
| 739 |
+
if should_log_l1_loss:
|
| 740 |
+
ground_truth_curr_action = ground_truth_actions[:, 0]
|
| 741 |
+
predicted_curr_action = predicted_actions[:, 0]
|
| 742 |
+
ground_truth_next_actions = ground_truth_actions[:, 1:]
|
| 743 |
+
predicted_next_actions = predicted_actions[:, 1:]
|
| 744 |
+
curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action)
|
| 745 |
+
next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions)
|
| 746 |
+
metrics.update(
|
| 747 |
+
{
|
| 748 |
+
"curr_action_l1_loss": curr_action_l1_loss.item(),
|
| 749 |
+
"next_actions_l1_loss": next_actions_l1_loss.item(),
|
| 750 |
+
}
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
# Return both the loss tensor (with gradients) and the metrics dictionary (with detached values)
|
| 754 |
+
return loss, metrics
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def run_diffusion_sampling(
|
| 758 |
+
vla,
|
| 759 |
+
action_head,
|
| 760 |
+
noisy_action_projector,
|
| 761 |
+
proprio_projector,
|
| 762 |
+
batch,
|
| 763 |
+
batch_size,
|
| 764 |
+
num_patches,
|
| 765 |
+
actions_shape,
|
| 766 |
+
device_id,
|
| 767 |
+
current_action_mask,
|
| 768 |
+
next_actions_mask,
|
| 769 |
+
use_proprio,
|
| 770 |
+
use_film,
|
| 771 |
+
) -> torch.Tensor:
|
| 772 |
+
"""
|
| 773 |
+
Run diffusion sampling (reverse diffusion) to generate actions.
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 777 |
+
action_head (nn.Module): Action head module.
|
| 778 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 779 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 780 |
+
batch (dict): Input batch.
|
| 781 |
+
batch_size (int): Batch size.
|
| 782 |
+
num_patches (int): Number of vision patches.
|
| 783 |
+
actions_shape (tuple): Shape of ground-truth actions.
|
| 784 |
+
device_id (str): Device ID.
|
| 785 |
+
current_action_mask (torch.Tensor): Mask for current action.
|
| 786 |
+
next_actions_mask (torch.Tensor): Mask for next actions.
|
| 787 |
+
use_proprio (bool): Whether to use proprioceptive state as input.
|
| 788 |
+
use_film (bool): Whether to use FiLM for better language following.
|
| 789 |
+
|
| 790 |
+
Returns:
|
| 791 |
+
torch.Tensor: Predicted actions.
|
| 792 |
+
"""
|
| 793 |
+
# Sample random noisy action, used as the starting point for reverse diffusion
|
| 794 |
+
noise = torch.randn(
|
| 795 |
+
size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM),
|
| 796 |
+
device=device_id,
|
| 797 |
+
dtype=torch.bfloat16,
|
| 798 |
+
) # (B, chunk_len, action_dim)
|
| 799 |
+
|
| 800 |
+
# Set diffusion timestep values
|
| 801 |
+
action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train)
|
| 802 |
+
|
| 803 |
+
# Reverse diffusion: Iteratively denoise to generate action, conditioned on observation
|
| 804 |
+
curr_noisy_actions = noise
|
| 805 |
+
for t in action_head.module.noise_scheduler.timesteps:
|
| 806 |
+
# Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action embedding,
|
| 807 |
+
# and diffusion timestep embedding)
|
| 808 |
+
timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id)
|
| 809 |
+
diffusion_timestep_embeddings = (
|
| 810 |
+
action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
|
| 811 |
+
) # (B, llm_dim)
|
| 812 |
+
diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
|
| 813 |
+
|
| 814 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 815 |
+
output = vla(
|
| 816 |
+
input_ids=batch["input_ids"].to(device_id),
|
| 817 |
+
attention_mask=batch["attention_mask"].to(device_id),
|
| 818 |
+
pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
|
| 819 |
+
labels=batch["labels"],
|
| 820 |
+
output_hidden_states=True,
|
| 821 |
+
proprio=batch["proprio"] if use_proprio else None,
|
| 822 |
+
proprio_projector=proprio_projector if use_proprio else None,
|
| 823 |
+
noisy_actions=curr_noisy_actions,
|
| 824 |
+
noisy_action_projector=noisy_action_projector,
|
| 825 |
+
diffusion_timestep_embeddings=diffusion_timestep_embeddings,
|
| 826 |
+
use_film=use_film,
|
| 827 |
+
)
|
| 828 |
+
# Get last layer hidden states
|
| 829 |
+
last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
|
| 830 |
+
# Get hidden states for text portion of prompt+response (after the vision patches)
|
| 831 |
+
text_hidden_states = last_hidden_states[:, num_patches:-1]
|
| 832 |
+
# Get hidden states for action portion of response
|
| 833 |
+
actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape(
|
| 834 |
+
batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1
|
| 835 |
+
) # (B, act_chunk_len, D)
|
| 836 |
+
actions_hidden_states = actions_hidden_states.to(torch.bfloat16)
|
| 837 |
+
# Predict noise
|
| 838 |
+
noise_pred = action_head.module.predict_noise(actions_hidden_states)
|
| 839 |
+
|
| 840 |
+
# Compute the action at the previous diffusion timestep: x_t -> x_{t-1}
|
| 841 |
+
curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
|
| 842 |
+
|
| 843 |
+
return curr_noisy_actions.reshape(actions_shape)
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
def compute_smoothened_metrics(metrics_deques) -> dict:
|
| 847 |
+
"""
|
| 848 |
+
Compute smoothened metrics from recent deques.
|
| 849 |
+
|
| 850 |
+
Args:
|
| 851 |
+
metrics_deques (dict): Dictionary of deques containing recent metrics.
|
| 852 |
+
|
| 853 |
+
Returns:
|
| 854 |
+
dict: Dictionary of smoothened metrics.
|
| 855 |
+
"""
|
| 856 |
+
smoothened_metrics = {}
|
| 857 |
+
for name, deque in metrics_deques.items():
|
| 858 |
+
if deque and len(deque) > 0:
|
| 859 |
+
smoothened_metrics[name] = sum(deque) / len(deque)
|
| 860 |
+
return smoothened_metrics
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def compute_diff_regularization_loss(model, diff_params_dict, regularization_weight=1.0):
|
| 864 |
+
"""
|
| 865 |
+
Computes a regularization loss between the model's parameters and the same-named parameters loaded from diff_path, to discourage updates in the direction of the diff_path parameters.
|
| 866 |
+
Following the orthogonality-loss formulation, it penalizes similarity via the inner product between parameters.
|
| 867 |
+
|
| 868 |
+
Args:
|
| 869 |
+
model: The model (possibly DDP-wrapped).
|
| 870 |
+
diff_params_dict: Parameter dict loaded from diff_path.
|
| 871 |
+
regularization_weight: Regularization weight.
|
| 872 |
+
|
| 873 |
+
Returns:
|
| 874 |
+
regularization_loss: The regularization loss value.
|
| 875 |
+
"""
|
| 876 |
+
orthogonal_loss = 0.
|
| 877 |
+
matched_count = 0
|
| 878 |
+
|
| 879 |
+
# Get the underlying module (in case the model is DDP-wrapped)
|
| 880 |
+
model_module = model.module if hasattr(model, 'module') else model
|
| 881 |
+
|
| 882 |
+
for name, param in model_module.named_parameters():
|
| 883 |
+
if "lora" in name:
|
| 884 |
+
if not param.requires_grad:
|
| 885 |
+
continue
|
| 886 |
+
|
| 887 |
+
# Try to match the same-named parameter in diff_params_dict
|
| 888 |
+
# Naming differences that may need handling:
|
| 889 |
+
# 1. diff_path entries may lack the "base_model.model." prefix
|
| 890 |
+
# 2. the PEFT model names carry an extra ".default" after .lora_A/.lora_B that diff_path names lack
|
| 891 |
+
# For example: the model has "xxx.lora_A.default.weight"
|
| 892 |
+
# while the diff has "xxx.lora_A.weight"
|
| 893 |
+
matched_diff_param = None
|
| 894 |
+
|
| 895 |
+
# First try an exact name match
|
| 896 |
+
if name in diff_params_dict:
|
| 897 |
+
# import pdb; pdb.set_trace()
|
| 898 |
+
matched_diff_param = diff_params_dict[name]
|
| 899 |
+
else:
|
| 900 |
+
# import pdb; pdb.set_trace()
|
| 901 |
+
# Handle the ".default" naming difference: strip ".default" after .lora_A / .lora_B before looking up the diff name
|
| 902 |
+
# Following O-LoRA, only the lora_A parameters are constrained
|
| 903 |
+
if ".lora_A." in name:
|
| 904 |
+
name_with_default = name.replace(".lora_A.default.", ".lora_A.")
|
| 905 |
+
if name_with_default in diff_params_dict:
|
| 906 |
+
matched_diff_param = diff_params_dict[name_with_default]
|
| 907 |
+
# elif ".lora_B." in name:
|
| 908 |
+
# name_with_default = name.replace(".lora_B.default.", ".lora_B.")
|
| 909 |
+
# if name_with_default in diff_params_dict:
|
| 910 |
+
# matched_diff_param = diff_params_dict[name_with_default]
|
| 911 |
+
|
| 912 |
+
if matched_diff_param is not None:
|
| 913 |
+
# print(f"匹配到参数: {name}")
|
| 914 |
+
# 确保参数在同一个设备上
|
| 915 |
+
diff_param = matched_diff_param.to(device=param.device, dtype=param.dtype)
|
| 916 |
+
|
| 917 |
+
# Check that the shapes match
|
| 918 |
+
if param.shape == diff_param.shape:
|
| 919 |
+
# Clone the tensors to avoid DDP's duplicate gradient-marking issue
|
| 920 |
+
# This creates new tensors that keep the gradient connection without triggering DDP's duplicate marking
|
| 921 |
+
param_safe = param.clone()
|
| 922 |
+
diff_param_safe = diff_param.detach().clone()
|
| 923 |
+
|
| 924 |
+
# Flatten, so multi-dimensional LoRA params inside the vision model are handled too
|
| 925 |
+
param_flat = param_safe.reshape(-1) # [N]
|
| 926 |
+
diff_param_flat = diff_param_safe.reshape(-1) # [N]
|
| 927 |
+
inner_product = torch.abs((param_flat * diff_param_flat).sum())
|
| 928 |
+
orthogonal_loss += inner_product
|
| 929 |
+
matched_count += 1
|
| 930 |
+
# print(f"匹配到参数: {name} 的正则化loss: {inner_product}")
|
| 931 |
+
|
| 932 |
+
# print(f"正则化loss: {orthogonal_loss}")
|
| 933 |
+
if matched_count > 0:
|
| 934 |
+
orthogonal_loss = orthogonal_loss * regularization_weight
|
| 935 |
+
else:
|
| 936 |
+
# If no parameters matched, return 0 (with requires_grad so that backward() does not error)
|
| 937 |
+
# Its gradient is 0, so training is unaffected
|
| 938 |
+
device = next(model_module.parameters()).device
|
| 939 |
+
orthogonal_loss = torch.tensor(0.0, device=device, requires_grad=True)
|
| 940 |
+
|
| 941 |
+
return orthogonal_loss
|
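# In other words, with A_i the trainable lora_A matrices and A_i' the matching entries from
# diff_params_dict, the penalty is
#   L_reg = regularization_weight * sum_i | vec(A_i) . vec(A_i') |
# which discourages the new LoRA update from pointing along the previously learned direction.
# A zero tensor with requires_grad=True is returned when nothing matches, so that calling
# .backward() on (task_loss + L_reg) never fails.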
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def load_diff_params(diff_path, device="cpu"):
|
| 945 |
+
"""
|
| 946 |
+
Loads parameters from a safetensors or pth file.
|
| 947 |
+
|
| 948 |
+
Args:
|
| 949 |
+
diff_path: Path to the parameter file.
|
| 950 |
+
device: Device to load the tensors onto.
|
| 951 |
+
|
| 952 |
+
Returns:
|
| 953 |
+
diff_params_dict: Dictionary of parameters.
|
| 954 |
+
"""
|
| 955 |
+
diff_params_dict = {}
|
| 956 |
+
|
| 957 |
+
if diff_path.endswith('.safetensors'):
|
| 958 |
+
if not SAFETENSORS_AVAILABLE:
|
| 959 |
+
raise ImportError("safetensors library is required to load .safetensors files")
|
| 960 |
+
|
| 961 |
+
with safe_open(diff_path, framework="pt", device=device) as f:
|
| 962 |
+
for key in f.keys():
|
| 963 |
+
diff_params_dict[key] = f.get_tensor(key)
|
| 964 |
+
else:
|
| 965 |
+
# Otherwise assume a pth or other torch-serialized format
|
| 966 |
+
loaded = torch.load(diff_path, map_location=device)
|
| 967 |
+
if isinstance(loaded, dict):
|
| 968 |
+
if "state_dict" in loaded:
|
| 969 |
+
diff_params_dict = loaded["state_dict"]
|
| 970 |
+
else:
|
| 971 |
+
diff_params_dict = loaded
|
| 972 |
+
else:
|
| 973 |
+
diff_params_dict = loaded
|
| 974 |
+
|
| 975 |
+
return diff_params_dict
|
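# Usage sketch (the adapter path below is a placeholder; any PEFT adapter_model.safetensors
# file or torch checkpoint that stores a "state_dict" entry should work):
#
#   diff_params = load_diff_params("runs/prev_task/lora_adapter/adapter_model.safetensors")
#   reg_loss = compute_diff_regularization_loss(vla, diff_params, regularization_weight=1e-3)
#   total_loss = task_loss + reg_loss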
| 976 |
+
|
| 977 |
+
|
| 978 |
+
def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None:
|
| 979 |
+
"""
|
| 980 |
+
Log metrics to Weights & Biases.
|
| 981 |
+
|
| 982 |
+
Args:
|
| 983 |
+
metrics (dict): Dictionary of metrics to log
|
| 984 |
+
prefix (str): Prefix for metric names
|
| 985 |
+
step (int): Training step
|
| 986 |
+
wandb_entity (str): W&B entity instance
|
| 987 |
+
|
| 988 |
+
Returns:
|
| 989 |
+
None.
|
| 990 |
+
"""
|
| 991 |
+
log_dict = {}
|
| 992 |
+
for name, value in metrics.items():
|
| 993 |
+
# Map loss_value to Loss for better readability in W&B
|
| 994 |
+
if name == "loss_value":
|
| 995 |
+
log_dict[f"{prefix}/Loss"] = value
|
| 996 |
+
# Keep other metrics as is
|
| 997 |
+
else:
|
| 998 |
+
log_dict[f"{prefix}/{name.replace('_', ' ').title()}"] = value
|
| 999 |
+
wandb_entity.log(log_dict, step=step)
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
def save_training_checkpoint(
|
| 1003 |
+
cfg,
|
| 1004 |
+
run_dir,
|
| 1005 |
+
log_step,
|
| 1006 |
+
vla,
|
| 1007 |
+
processor,
|
| 1008 |
+
proprio_projector,
|
| 1009 |
+
noisy_action_projector,
|
| 1010 |
+
action_head,
|
| 1011 |
+
train_dataset,
|
| 1012 |
+
distributed_state,
|
| 1013 |
+
) -> None:
|
| 1014 |
+
"""
|
| 1015 |
+
Save all training checkpoints including model components, LoRA adapter, and dataset statistics.
|
| 1016 |
+
|
| 1017 |
+
Args:
|
| 1018 |
+
cfg (FinetuneConfig): Training configuration.
|
| 1019 |
+
run_dir (Path): Experiment run directory path.
|
| 1020 |
+
log_step (int): Current logging step.
|
| 1021 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 1022 |
+
processor (PrismaticProcessor): OpenVLA inputs processor.
|
| 1023 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 1024 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 1025 |
+
action_head (nn.Module): Action head module.
|
| 1026 |
+
train_dataset (RLDSDataset): Training dataset.
|
| 1027 |
+
distributed_state (PartialState): Distributed training state.
|
| 1028 |
+
|
| 1029 |
+
Returns:
|
| 1030 |
+
None.
|
| 1031 |
+
"""
|
| 1032 |
+
# Determine checkpoint paths and naming
|
| 1033 |
+
if cfg.save_latest_checkpoint_only:
|
| 1034 |
+
checkpoint_dir = run_dir
|
| 1035 |
+
checkpoint_name_suffix = "latest_checkpoint.pt"
|
| 1036 |
+
else:
|
| 1037 |
+
checkpoint_dir = run_dir / f"{log_step}_chkpt"
|
| 1038 |
+
checkpoint_name_suffix = f"{log_step}_checkpoint.pt"
|
| 1039 |
+
|
| 1040 |
+
adapter_dir = checkpoint_dir / "lora_adapter"
|
| 1041 |
+
|
| 1042 |
+
# Create directories and save dataset statistics (main process only)
|
| 1043 |
+
if distributed_state.is_main_process:
|
| 1044 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 1045 |
+
os.makedirs(adapter_dir, exist_ok=True)
|
| 1046 |
+
save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir)
|
| 1047 |
+
print(f"Saving Model Checkpoint for Step {log_step}")
|
| 1048 |
+
|
| 1049 |
+
# Wait for directories to be created
|
| 1050 |
+
dist.barrier()
|
| 1051 |
+
|
| 1052 |
+
# Save model components (main process only)
|
| 1053 |
+
if distributed_state.is_main_process:
|
| 1054 |
+
# Save processor and LoRA adapter
|
| 1055 |
+
processor.save_pretrained(checkpoint_dir)
|
| 1056 |
+
vla.module.save_pretrained(adapter_dir)
|
| 1057 |
+
|
| 1058 |
+
# Save other components
|
| 1059 |
+
if cfg.use_proprio and proprio_projector is not None:
|
| 1060 |
+
torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}")
|
| 1061 |
+
|
| 1062 |
+
if cfg.use_diffusion and noisy_action_projector is not None:
|
| 1063 |
+
torch.save(
|
| 1064 |
+
noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}"
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None:
|
| 1068 |
+
torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}")
|
| 1069 |
+
|
| 1070 |
+
if cfg.use_film:
|
| 1071 |
+
# To be safe, just save the entire vision backbone (not just FiLM components)
|
| 1072 |
+
torch.save(
|
| 1073 |
+
vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}"
|
| 1074 |
+
)
|
| 1075 |
+
|
| 1076 |
+
# Wait for model components to be saved
|
| 1077 |
+
dist.barrier()
|
| 1078 |
+
|
| 1079 |
+
# Merge LoRA weights into base model and save resulting model checkpoint
|
| 1080 |
+
# Note: Can be very slow on some devices; if so, we recommend merging offline
|
| 1081 |
+
if cfg.use_lora and cfg.merge_lora_during_training:
|
| 1082 |
+
base_vla = AutoModelForVision2Seq.from_pretrained(
|
| 1083 |
+
cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
|
| 1084 |
+
)
|
| 1085 |
+
merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir)
|
| 1086 |
+
merged_vla = merged_vla.merge_and_unload()
|
| 1087 |
+
|
| 1088 |
+
if distributed_state.is_main_process:
|
| 1089 |
+
merged_vla.save_pretrained(checkpoint_dir)
|
| 1090 |
+
print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}")
|
| 1091 |
+
|
| 1092 |
+
# Wait for merged model to be saved
|
| 1093 |
+
dist.barrier()
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
def run_validation(
|
| 1097 |
+
vla,
|
| 1098 |
+
action_head,
|
| 1099 |
+
noisy_action_projector,
|
| 1100 |
+
proprio_projector,
|
| 1101 |
+
val_dataloader,
|
| 1102 |
+
action_tokenizer,
|
| 1103 |
+
device_id,
|
| 1104 |
+
cfg,
|
| 1105 |
+
num_patches,
|
| 1106 |
+
log_step,
|
| 1107 |
+
distributed_state,
|
| 1108 |
+
val_time_limit,
|
| 1109 |
+
) -> None:
|
| 1110 |
+
"""
|
| 1111 |
+
Compute validation set metrics for logging.
|
| 1112 |
+
|
| 1113 |
+
Args:
|
| 1114 |
+
vla (OpenVLAForActionPrediction): Vision-language-action policy.
|
| 1115 |
+
action_head (nn.Module): Action head module.
|
| 1116 |
+
noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
|
| 1117 |
+
proprio_projector (nn.Module): Proprioceptive state projector module.
|
| 1118 |
+
val_dataloader (DataLoader): Validation data loader.
|
| 1119 |
+
action_tokenizer (ActionTokenizer): Action tokenizer.
|
| 1120 |
+
device_id (str): Device ID.
|
| 1121 |
+
cfg (FinetuneConfig): Training configuration.
|
| 1122 |
+
num_patches (int): Number of vision patches.
|
| 1123 |
+
log_step (int): Current logging step.
|
| 1124 |
+
distributed_state (PartialState): Distributed training state.
|
| 1125 |
+
val_time_limit (int): Time limit for computing validation metrics.
|
| 1126 |
+
|
| 1127 |
+
Returns:
|
| 1128 |
+
None.
|
| 1129 |
+
"""
|
| 1130 |
+
val_start_time = time.time()
|
| 1131 |
+
vla.eval()
|
| 1132 |
+
val_batches_count = 0
|
| 1133 |
+
|
| 1134 |
+
# List to store validation metrics
|
| 1135 |
+
all_val_metrics = []
|
| 1136 |
+
|
| 1137 |
+
with torch.no_grad():
|
| 1138 |
+
for batch in val_dataloader:
|
| 1139 |
+
# Always compute L1 loss for validation, even for diffusion
|
| 1140 |
+
_, metrics = run_forward_pass(
|
| 1141 |
+
vla=vla,
|
| 1142 |
+
action_head=action_head,
|
| 1143 |
+
noisy_action_projector=noisy_action_projector,
|
| 1144 |
+
proprio_projector=proprio_projector,
|
| 1145 |
+
batch=batch,
|
| 1146 |
+
action_tokenizer=action_tokenizer,
|
| 1147 |
+
device_id=device_id,
|
| 1148 |
+
use_l1_regression=cfg.use_l1_regression,
|
| 1149 |
+
use_diffusion=cfg.use_diffusion,
|
| 1150 |
+
use_proprio=cfg.use_proprio,
|
| 1151 |
+
use_film=cfg.use_film,
|
| 1152 |
+
num_patches=num_patches,
|
| 1153 |
+
compute_diffusion_l1=True,
|
| 1154 |
+
num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
|
| 1155 |
+
)
|
| 1156 |
+
|
| 1157 |
+
# Add the loss value to the metrics
|
| 1158 |
+
metrics["loss"] = metrics["loss_value"]
|
| 1159 |
+
all_val_metrics.append(metrics)
|
| 1160 |
+
val_batches_count += 1
|
| 1161 |
+
|
| 1162 |
+
# Cut testing on validation set short if it exceeds time limit
|
| 1163 |
+
if time.time() - val_start_time > val_time_limit:
|
| 1164 |
+
break
|
| 1165 |
+
|
| 1166 |
+
# Compute average validation metrics
|
| 1167 |
+
avg_val_metrics = {}
|
| 1168 |
+
for metric_name in all_val_metrics[0].keys():
|
| 1169 |
+
values = [metrics[metric_name] for metrics in all_val_metrics if metric_name in metrics]
|
| 1170 |
+
if values:
|
| 1171 |
+
avg_val_metrics[metric_name] = sum(values) / len(values)
|
| 1172 |
+
|
| 1173 |
+
# Add batch count to metrics
|
| 1174 |
+
avg_val_metrics["val_batches_count"] = val_batches_count
|
| 1175 |
+
|
| 1176 |
+
# Log validation metrics to W&B
|
| 1177 |
+
if distributed_state.is_main_process:
|
| 1178 |
+
log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb)
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
@draccus.wrap()
|
| 1182 |
+
def finetune(cfg: FinetuneConfig) -> None:
|
| 1183 |
+
"""
|
| 1184 |
+
Fine-tunes base VLA on demonstration dataset via LoRA.
|
| 1185 |
+
|
| 1186 |
+
Allows toggling different action representations (discrete vs. continuous), different learning objectives
|
| 1187 |
+
(next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs,
|
| 1188 |
+
such as additional camera images and robot proprioceptive state. Assumes parallel action generation with
|
| 1189 |
+
action chunking.
|
| 1190 |
+
|
| 1191 |
+
Args:
|
| 1192 |
+
cfg (FinetuneConfig): Training configuration.
|
| 1193 |
+
|
| 1194 |
+
Returns:
|
| 1195 |
+
None.
|
| 1196 |
+
"""
|
| 1197 |
+
assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!"
|
| 1198 |
+
assert not (cfg.use_l1_regression and cfg.use_diffusion), (
|
| 1199 |
+
"Cannot do both L1 regression and diffusion. Please pick one of them!"
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
# Trim trailing forward slash ('/') in VLA path if it exists
|
| 1203 |
+
cfg.vla_path = cfg.vla_path.rstrip("/")
|
| 1204 |
+
print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")
|
| 1205 |
+
|
| 1206 |
+
# Get experiment run ID
|
| 1207 |
+
run_id = get_run_id(cfg)
|
| 1208 |
+
|
| 1209 |
+
# Create experiment run directory
|
| 1210 |
+
run_dir = cfg.run_root_dir / run_id
|
| 1211 |
+
os.makedirs(run_dir, exist_ok=True)
|
| 1212 |
+
|
| 1213 |
+
# GPU setup
|
| 1214 |
+
distributed_state = PartialState()
|
| 1215 |
+
device_id = distributed_state.local_process_index
|
| 1216 |
+
torch.cuda.set_device(device_id)
|
| 1217 |
+
torch.cuda.empty_cache()
|
| 1218 |
+
|
| 1219 |
+
# Initialize wandb logging
|
| 1220 |
+
if distributed_state.is_main_process:
|
| 1221 |
+
wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id, id=run_id)
|
| 1222 |
+
|
| 1223 |
+
# Print detected constants
|
| 1224 |
+
print(
|
| 1225 |
+
"Detected constants:\n"
|
| 1226 |
+
f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n"
|
| 1227 |
+
f"\tACTION_DIM: {ACTION_DIM}\n"
|
| 1228 |
+
f"\tPROPRIO_DIM: {PROPRIO_DIM}\n"
|
| 1229 |
+
f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}"
|
| 1230 |
+
)
|
| 1231 |
+
|
| 1232 |
+
# Two options:
|
| 1233 |
+
# (1) Base model is on Hugging Face Hub
|
| 1234 |
+
# - Then download it and record the path to the download directory
|
| 1235 |
+
# (2) Base model is stored locally
|
| 1236 |
+
# - Then register model config in HF Auto Classes
|
| 1237 |
+
# In both cases, we want to check whether any changes have been made to
|
| 1238 |
+
# the `modeling_prismatic.py` file in this codebase; if so, we will copy
|
| 1239 |
+
# the file to the downloaded or locally stored checkpoint directory so
|
| 1240 |
+
# that the user's changes to the VLA class logic go into effect
|
| 1241 |
+
if model_is_on_hf_hub(cfg.vla_path):
|
| 1242 |
+
# Download model directly from Hugging Face Hub
|
| 1243 |
+
vla_download_path = snapshot_download(repo_id=cfg.vla_path)
|
| 1244 |
+
# Overwrite VLA path
|
| 1245 |
+
cfg.vla_path = vla_download_path
|
| 1246 |
+
else:
|
| 1247 |
+
# Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
|
| 1248 |
+
AutoConfig.register("openvla", OpenVLAConfig)
|
| 1249 |
+
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
|
| 1250 |
+
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
|
| 1251 |
+
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
|
| 1252 |
+
|
| 1253 |
+
# Update config.json and sync model files
|
| 1254 |
+
if distributed_state.is_main_process:
|
| 1255 |
+
update_auto_map(cfg.vla_path)
|
| 1256 |
+
check_model_logic_mismatch(cfg.vla_path)
|
| 1257 |
+
|
| 1258 |
+
# Wait for model files to be synced
|
| 1259 |
+
dist.barrier()
|
| 1260 |
+
|
| 1261 |
+
# Load processor and VLA
|
| 1262 |
+
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
|
| 1263 |
+
vla = AutoModelForVision2Seq.from_pretrained(
|
| 1264 |
+
cfg.vla_path,
|
| 1265 |
+
torch_dtype=torch.bfloat16,
|
| 1266 |
+
low_cpu_mem_usage=True,
|
| 1267 |
+
trust_remote_code=True,
|
| 1268 |
+
).to(device_id)
|
| 1269 |
+
|
| 1270 |
+
# Set number of images in VLA input
|
| 1271 |
+
vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
|
| 1272 |
+
|
| 1273 |
+
# LoRA setup
|
| 1274 |
+
if cfg.use_lora:
|
| 1275 |
+
lora_config = LoraConfig(
|
| 1276 |
+
r=cfg.lora_rank,
|
| 1277 |
+
lora_alpha=min(cfg.lora_rank, 16),
|
| 1278 |
+
lora_dropout=cfg.lora_dropout,
|
| 1279 |
+
target_modules="all-linear",
|
| 1280 |
+
init_lora_weights="gaussian",
|
| 1281 |
+
)
|
| 1282 |
+
vla = get_peft_model(vla, lora_config)
|
| 1283 |
+
vla.print_trainable_parameters()
|
| 1284 |
+
|
| 1285 |
+
# FiLM setup
|
| 1286 |
+
if cfg.use_film:
|
| 1287 |
+
count_parameters(vla.vision_backbone, "vla.vision_backbone (original)")
|
| 1288 |
+
# Wrap vision backbone with FiLM wrapper
|
| 1289 |
+
# Important: For this, must specify `vla.model.vision_backbone` instead of just `vla.vision_backbone`, since the
|
| 1290 |
+
# latter would cause the new wrapped backbone to be saved as a new attribute of `vla` instead of overwriting the
|
| 1291 |
+
# original one (due to the LoRA wrapper)
|
| 1292 |
+
vla.model.vision_backbone = FiLMedPrismaticVisionBackbone(
|
| 1293 |
+
vision_backbone=vla.model.vision_backbone,
|
| 1294 |
+
llm_dim=vla.llm_dim,
|
| 1295 |
+
)
|
| 1296 |
+
count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)")
|
| 1297 |
+
if cfg.resume:
|
| 1298 |
+
state_dict = load_checkpoint("vision_backbone", cfg.vla_path, cfg.resume_step)
|
| 1299 |
+
vla.model.vision_backbone.load_state_dict(state_dict)
|
| 1300 |
+
vla.model.vision_backbone = vla.model.vision_backbone.to(device_id)
|
| 1301 |
+
|
| 1302 |
+
# Wrap VLA with DDP
|
| 1303 |
+
vla = wrap_ddp(vla, device_id, find_unused=False)
|
| 1304 |
+
|
| 1305 |
+
# vla._set_static_graph()
|
| 1306 |
+
|
| 1307 |
+
# If applicable, instantiate proprio projector
|
| 1308 |
+
if cfg.use_proprio:
|
| 1309 |
+
proprio_projector = init_module(
|
| 1310 |
+
ProprioProjector,
|
| 1311 |
+
"proprio_projector",
|
| 1312 |
+
cfg,
|
| 1313 |
+
device_id,
|
| 1314 |
+
{"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM},
|
| 1315 |
+
)
|
| 1316 |
+
else:
|
| 1317 |
+
proprio_projector = None
|
| 1318 |
+
|
| 1319 |
+
# If applicable, instantiate continuous action head for L1 regression
|
| 1320 |
+
if cfg.use_l1_regression:
|
| 1321 |
+
action_head = init_module(
|
| 1322 |
+
L1RegressionActionHead,
|
| 1323 |
+
"action_head",
|
| 1324 |
+
cfg,
|
| 1325 |
+
device_id,
|
| 1326 |
+
{"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM},
|
| 1327 |
+
to_bf16=True,
|
| 1328 |
+
)
|
| 1329 |
+
else:
|
| 1330 |
+
action_head = None
|
| 1331 |
+
|
| 1332 |
+
# If applicable, instantiate diffusion action head and noisy action projector
|
| 1333 |
+
if cfg.use_diffusion:
|
| 1334 |
+
action_head = init_module(
|
| 1335 |
+
DiffusionActionHead,
|
| 1336 |
+
"action_head",
|
| 1337 |
+
cfg,
|
| 1338 |
+
device_id,
|
| 1339 |
+
{
|
| 1340 |
+
"input_dim": vla.module.llm_dim,
|
| 1341 |
+
"hidden_dim": vla.module.llm_dim,
|
| 1342 |
+
"action_dim": ACTION_DIM,
|
| 1343 |
+
"num_diffusion_steps_train": cfg.num_diffusion_steps_train,
|
| 1344 |
+
},
|
| 1345 |
+
to_bf16=True,
|
| 1346 |
+
)
|
| 1347 |
+
noisy_action_projector = init_module(
|
| 1348 |
+
NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim}
|
| 1349 |
+
)
|
| 1350 |
+
else:
|
| 1351 |
+
noisy_action_projector = None
|
| 1352 |
+
|
| 1353 |
+
# EMA
|
| 1354 |
+
if cfg.use_ema:
|
| 1355 |
+
ema_vla = EMAModel(vla,
|
| 1356 |
+
action_head,
|
| 1357 |
+
proprio_projector,
|
| 1358 |
+
noisy_action_projector,
|
| 1359 |
+
inv_gamma=cfg.inv_gamma
|
| 1360 |
+
)
|
| 1361 |
+
|
| 1362 |
+
# Get number of vision patches
|
| 1363 |
+
NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input()
|
| 1364 |
+
# If we have proprio inputs, a single proprio embedding is appended to the end of the vision patch embeddings
|
| 1365 |
+
if cfg.use_proprio:
|
| 1366 |
+
NUM_PATCHES += 1
|
| 1367 |
+
# For diffusion, a single diffusion timestep embedding is appended to the end of the vision patch embeddings
|
| 1368 |
+
if cfg.use_diffusion:
|
| 1369 |
+
NUM_PATCHES += 1
|
| 1370 |
+
|
| 1371 |
+
diff_path = cfg.regularization_lora_vector_path  # <- change this to your own path
|
| 1372 |
+
|
| 1373 |
+
# Load diff parameters for regularization
|
| 1374 |
+
diff_params_dict = {}
|
| 1375 |
+
if diff_path and os.path.exists(diff_path):
|
| 1376 |
+
print(f"Loading diff parameters from {diff_path}")
|
| 1377 |
+
diff_params_dict = load_diff_params(diff_path, device="cpu")
|
| 1378 |
+
print(f"Loaded {len(diff_params_dict)} parameters from diff_path")
|
| 1379 |
+
else:
|
| 1380 |
+
print(f"Warning: diff_path {diff_path} does not exist, skipping regularization loss")
|
| 1381 |
+
|
| 1382 |
+
# Regularization weight (you can make this configurable via cfg if needed)
|
| 1383 |
+
regularization_weight = cfg.regularization_weight  # adjust this weight as needed
|
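For reference, `compute_diff_regularization_loss` (used further below) is imported from elsewhere in this repo. A minimal sketch of such a loss, assuming it is an L2 penalty pulling the trainable parameters toward the corresponding entries of `diff_params_dict`, could look like the following; the signature and key-matching scheme here are assumptions for illustration, not the repo's actual implementation:

```python
import torch


def compute_diff_regularization_loss(model, diff_params_dict, regularization_weight=1.0):
    """Hypothetical sketch: L2 penalty between trainable params and the loaded diff targets."""
    module = model.module if hasattr(model, "module") else model  # unwrap DDP if present
    reg_loss = None
    for name, param in module.named_parameters():
        if not param.requires_grad or name not in diff_params_dict:
            continue
        target = diff_params_dict[name].to(device=param.device, dtype=param.dtype)
        term = torch.sum((param - target) ** 2)
        reg_loss = term if reg_loss is None else reg_loss + term
    if reg_loss is None:
        # No overlapping keys: return a zero that is safe to add to the main loss.
        return torch.tensor(0.0)
    return regularization_weight * reg_loss
```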
| 1384 |
+
|
| 1385 |
+
# Instantiate optimizer
|
| 1386 |
+
trainable_params = [param for param in vla.parameters() if param.requires_grad]
|
| 1387 |
+
if cfg.use_l1_regression or cfg.use_diffusion:
|
| 1388 |
+
trainable_params += [param for param in action_head.parameters() if param.requires_grad]
|
| 1389 |
+
if cfg.use_diffusion:
|
| 1390 |
+
trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad]
|
| 1391 |
+
if cfg.use_proprio:
|
| 1392 |
+
trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad]
|
| 1393 |
+
print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}")
|
| 1394 |
+
optimizer = AdamW(trainable_params, lr=cfg.learning_rate)
|
| 1395 |
+
|
| 1396 |
+
# Record original learning rate
|
| 1397 |
+
original_lr = optimizer.param_groups[0]["lr"]
|
| 1398 |
+
|
| 1399 |
+
# Create learning rate scheduler
|
| 1400 |
+
if cfg.scheduler == 'MultiStepLR':
|
| 1401 |
+
scheduler = MultiStepLR(
|
| 1402 |
+
optimizer,
|
| 1403 |
+
milestones=[cfg.num_steps_before_decay], # Number of steps after which LR will change
|
| 1404 |
+
gamma=0.1, # Multiplicative factor of learning rate decay
|
| 1405 |
+
)
|
| 1406 |
+
elif cfg.scheduler == 'CosineAnnealingLR':
|
| 1407 |
+
scheduler = CosineAnnealingLR(
|
| 1408 |
+
optimizer,
|
| 1409 |
+
T_max=cfg.max_steps, # Total number of steps for the cosine annealing
|
| 1410 |
+
eta_min=cfg.learning_rate * 1e-3,
|
| 1411 |
+
)
|
| 1412 |
+
elif cfg.scheduler == 'WarmupCosineLR':
|
| 1413 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 1414 |
+
optimizer,
|
| 1415 |
+
num_warmup_steps=500,
|
| 1416 |
+
num_training_steps=cfg.max_steps,
|
| 1417 |
+
)
|
| 1418 |
+
else:
|
| 1419 |
+
raise ValueError(f"Unsupported scheduler type: {cfg.scheduler}")
|
| 1420 |
+
|
| 1421 |
+
# Create Action Tokenizer
|
| 1422 |
+
action_tokenizer = ActionTokenizer(processor.tokenizer)
|
| 1423 |
+
|
| 1424 |
+
# Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default.
|
| 1425 |
+
# =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block.
|
| 1426 |
+
# =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using
|
| 1427 |
+
# your own Dataset, make sure to add the appropriate logic to the training loop!
|
| 1428 |
+
#
|
| 1429 |
+
# ---
|
| 1430 |
+
# from prismatic.vla.datasets import DummyDataset
|
| 1431 |
+
#
|
| 1432 |
+
# train_dataset = DummyDataset(
|
| 1433 |
+
# action_tokenizer,
|
| 1434 |
+
# processor.tokenizer,
|
| 1435 |
+
# image_transform=processor.image_processor.apply_transform,
|
| 1436 |
+
# prompt_builder_fn=PurePromptBuilder,
|
| 1437 |
+
# )
|
| 1438 |
+
# ---
|
| 1439 |
+
|
| 1440 |
+
# We assume that the model takes as input one third-person camera image and 1 or 2 optional wrist camera image(s)
|
| 1441 |
+
use_wrist_image = cfg.num_images_in_input > 1
|
| 1442 |
+
|
| 1443 |
+
# Create training and optional validation datasets
|
| 1444 |
+
batch_transform = RLDSBatchTransform(
|
| 1445 |
+
action_tokenizer,
|
| 1446 |
+
processor.tokenizer,
|
| 1447 |
+
image_transform=processor.image_processor.apply_transform,
|
| 1448 |
+
prompt_builder_fn=PurePromptBuilder,
|
| 1449 |
+
use_wrist_image=use_wrist_image,
|
| 1450 |
+
use_proprio=cfg.use_proprio,
|
| 1451 |
+
)
|
| 1452 |
+
train_dataset = RLDSDataset(
|
| 1453 |
+
cfg.data_root_dir,
|
| 1454 |
+
cfg.dataset_name,
|
| 1455 |
+
batch_transform,
|
| 1456 |
+
resize_resolution=tuple(vla.module.config.image_sizes),
|
| 1457 |
+
shuffle_buffer_size=cfg.shuffle_buffer_size,
|
| 1458 |
+
image_aug=cfg.image_aug,
|
| 1459 |
+
)
|
| 1460 |
+
if cfg.use_val_set:
|
| 1461 |
+
val_dataset = RLDSDataset(
|
| 1462 |
+
cfg.data_root_dir,
|
| 1463 |
+
cfg.dataset_name,
|
| 1464 |
+
batch_transform,
|
| 1465 |
+
resize_resolution=tuple(vla.module.config.image_sizes),
|
| 1466 |
+
shuffle_buffer_size=cfg.shuffle_buffer_size // 10,
|
| 1467 |
+
image_aug=cfg.image_aug,
|
| 1468 |
+
train=False,
|
| 1469 |
+
)
|
| 1470 |
+
|
| 1471 |
+
# [Important] Save dataset statistics so that we can unnormalize actions during inference
|
| 1472 |
+
if distributed_state.is_main_process:
|
| 1473 |
+
save_dataset_statistics(train_dataset.dataset_statistics, run_dir)
|
| 1474 |
+
|
| 1475 |
+
# Create collator and dataloader
|
| 1476 |
+
collator = PaddedCollatorForActionPrediction(
|
| 1477 |
+
processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
|
| 1478 |
+
)
|
| 1479 |
+
dataloader = DataLoader(
|
| 1480 |
+
train_dataset,
|
| 1481 |
+
batch_size=cfg.batch_size,
|
| 1482 |
+
sampler=None,
|
| 1483 |
+
collate_fn=collator,
|
| 1484 |
+
num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
|
| 1485 |
+
)
|
| 1486 |
+
if cfg.use_val_set:
|
| 1487 |
+
val_batch_size = cfg.batch_size
|
| 1488 |
+
val_dataloader = DataLoader(
|
| 1489 |
+
val_dataset,
|
| 1490 |
+
batch_size=val_batch_size,
|
| 1491 |
+
sampler=None,
|
| 1492 |
+
collate_fn=collator,
|
| 1493 |
+
num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
|
| 1494 |
+
)
|
| 1495 |
+
|
| 1496 |
+
# Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation)
|
| 1497 |
+
recent_metrics = {
|
| 1498 |
+
"loss_value": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1499 |
+
"curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1500 |
+
"curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1501 |
+
"next_actions_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1502 |
+
"next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1503 |
+
"regularization_loss": deque(maxlen=cfg.grad_accumulation_steps),
|
| 1504 |
+
}
|
| 1505 |
+
|
| 1506 |
+
# Start training
|
| 1507 |
+
with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress:
|
| 1508 |
+
vla.train()
|
| 1509 |
+
optimizer.zero_grad()
|
| 1510 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 1511 |
+
# Compute training metrics and loss
|
| 1512 |
+
compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0
|
| 1513 |
+
loss, metrics = run_forward_pass(
|
| 1514 |
+
vla=vla,
|
| 1515 |
+
action_head=action_head,
|
| 1516 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1517 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1518 |
+
batch=batch,
|
| 1519 |
+
action_tokenizer=action_tokenizer,
|
| 1520 |
+
device_id=device_id,
|
| 1521 |
+
use_l1_regression=cfg.use_l1_regression,
|
| 1522 |
+
use_diffusion=cfg.use_diffusion,
|
| 1523 |
+
use_proprio=cfg.use_proprio,
|
| 1524 |
+
use_film=cfg.use_film,
|
| 1525 |
+
num_patches=NUM_PATCHES,
|
| 1526 |
+
compute_diffusion_l1=compute_diffusion_l1,
|
| 1527 |
+
num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
|
| 1528 |
+
)
|
| 1529 |
+
|
| 1530 |
+
# Add regularization loss if diff_params_dict is available
|
| 1531 |
+
if diff_params_dict:
|
| 1532 |
+
########################### Regularization Loss ##########################
|
| 1533 |
+
regularization_loss = compute_diff_regularization_loss(
|
| 1534 |
+
vla, diff_params_dict, regularization_weight=regularization_weight
|
| 1535 |
+
)
|
| 1536 |
+
# print(f"正则化loss: {regularization_loss}")
|
| 1537 |
+
# print(f"主loss: {loss}")
|
| 1538 |
+
# 这两行是用于梯度检查的
|
| 1539 |
+
# 保存主loss用于梯度检查
|
| 1540 |
+
# main_loss = loss.clone()
|
| 1541 |
+
# reg_loss = regularization_loss.clone()
|
| 1542 |
+
# print('loss:', loss)
|
| 1543 |
+
# print('regularization_loss:', regularization_loss)
|
| 1544 |
+
|
| 1545 |
+
# with vla.no_sync():
|
| 1546 |
+
# regularization_loss.backward()
|
| 1547 |
+
|
| 1548 |
+
# model_module = vla.module if hasattr(vla, 'module') else vla
|
| 1549 |
+
# reg_grads = {}
|
| 1550 |
+
# for name, param in model_module.named_parameters():
|
| 1551 |
+
# if "lora_A" in name and param.requires_grad and param.grad is not None:
|
| 1552 |
+
|
| 1553 |
+
# reg_grads[name] = param.grad.clone()
|
| 1554 |
+
|
| 1555 |
+
|
| 1556 |
+
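# NOTE: the zero-valued dummy term below touches every trainable parameter; presumably this keeps
# DDP gradient synchronization consistent across ranks when some parameters receive no gradient
# from the losses above (the VLA is wrapped with find_unused=False).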
dummy_loss = 0.0
|
| 1557 |
+
for p in vla.parameters():
|
| 1558 |
+
if p.requires_grad:
|
| 1559 |
+
dummy_loss = dummy_loss + p.sum() * 0.0
|
| 1560 |
+
|
| 1561 |
+
print('action loss:', loss)
|
| 1562 |
+
print('regularization_loss:', regularization_loss)
|
| 1563 |
+
print('dummy_loss:', dummy_loss)
|
| 1564 |
+
|
| 1565 |
+
loss = loss + regularization_loss + dummy_loss
|
| 1566 |
+
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
loss.backward()
|
| 1570 |
+
# main_grads = {}
|
| 1571 |
+
# for name, param in model_module.named_parameters():
|
| 1572 |
+
# if "lora_A" in name and param.requires_grad and param.grad is not None:
|
| 1573 |
+
|
| 1574 |
+
# main_grads[name] = param.grad.clone()
|
| 1575 |
+
|
| 1576 |
+
# print('################################################')
|
| 1577 |
+
# for name in main_grads.keys():
|
| 1578 |
+
# if name in reg_grads:
|
| 1579 |
+
# main_grad_norm = main_grads[name].norm().item()
|
| 1580 |
+
# reg_grad_norm = reg_grads[name].norm().item()
|
| 1581 |
+
# combined_grad_norm = (main_grads[name] + reg_grads[name]).norm().item()
|
| 1582 |
+
# print(f" {name}:")
|
| 1583 |
+
# print(f" 主loss梯度norm: {main_grad_norm:.6f}")
|
| 1584 |
+
# print(f" 正则化loss梯度norm: {reg_grad_norm:.6f}")
|
| 1585 |
+
# print(f" 合并梯度norm: {combined_grad_norm:.6f}")
|
| 1586 |
+
|
| 1587 |
+
|
| 1588 |
+
# print('################################################')
|
| 1589 |
+
# # Log regularization loss
|
| 1590 |
+
# metrics["regularization_loss"] = regularization_loss.item()
|
| 1591 |
+
# #############################################################################
|
| 1592 |
+
|
| 1593 |
+
# # The if-block below is for gradient checking
# # Inspect the gradients of the two losses separately (before backward)
|
| 1595 |
+
# if diff_params_dict and batch_idx % cfg.wandb_log_freq == 0:
|
| 1596 |
+
# # Get model parameters for gradient inspection
|
| 1597 |
+
# model_module = vla.module if hasattr(vla, 'module') else vla
|
| 1598 |
+
|
| 1599 |
+
# # Zero the gradients first
|
| 1600 |
+
# optimizer.zero_grad()
|
| 1601 |
+
|
| 1602 |
+
# # Backward through the main loss only
|
| 1603 |
+
# main_loss_normalized = main_loss / cfg.grad_accumulation_steps
|
| 1604 |
+
# main_loss_normalized.backward(retain_graph=True)
|
| 1605 |
+
|
| 1606 |
+
# # Save the gradients of the main loss
|
| 1607 |
+
# main_grads = {}
|
| 1608 |
+
# for name, param in model_module.named_parameters():
|
| 1609 |
+
# if "lora_A" in name and param.requires_grad and param.grad is not None:
|
| 1610 |
+
|
| 1611 |
+
# main_grads[name] = param.grad.clone()
|
| 1612 |
+
|
| 1613 |
+
# # Zero the gradients, then backward through the regularization loss only
|
| 1614 |
+
# optimizer.zero_grad()
|
| 1615 |
+
# reg_loss_normalized = reg_loss / cfg.grad_accumulation_steps
|
| 1616 |
+
# reg_loss_normalized.backward(retain_graph=True)
|
| 1617 |
+
|
| 1618 |
+
# # Save the gradients of the regularization loss
|
| 1619 |
+
# reg_grads = {}
|
| 1620 |
+
# for name, param in model_module.named_parameters():
|
| 1621 |
+
# if "lora_A" in name and param.requires_grad and param.grad is not None:
|
| 1622 |
+
# reg_grads[name] = param.grad.clone()
|
| 1623 |
+
|
| 1624 |
+
# # Print gradient info
# print(f"\n[Gradient check] Step {batch_idx // cfg.grad_accumulation_steps}")
|
| 1626 |
+
# sample_count = 0
|
| 1627 |
+
# for name in main_grads.keys():
|
| 1628 |
+
# if name in reg_grads:
|
| 1629 |
+
# main_grad_norm = main_grads[name].norm().item()
|
| 1630 |
+
# reg_grad_norm = reg_grads[name].norm().item()
|
| 1631 |
+
# combined_grad_norm = (main_grads[name] + reg_grads[name]).norm().item()
|
| 1632 |
+
# print(f" {name}:")
|
| 1633 |
+
# print(f" 主loss梯度norm: {main_grad_norm:.6f}")
|
| 1634 |
+
# print(f" 正则化loss梯度norm: {reg_grad_norm:.6f}")
|
| 1635 |
+
# print(f" 合并梯度norm: {combined_grad_norm:.6f}")
|
| 1636 |
+
# sample_count += 1
|
| 1637 |
+
# if sample_count >= 3:  # Only inspect the first 3 parameters as an example
|
| 1638 |
+
# break
|
| 1639 |
+
# print()
|
| 1640 |
+
|
| 1641 |
+
# # Zero the gradients to prepare for the normal backward pass
|
| 1642 |
+
# optimizer.zero_grad()
|
| 1643 |
+
|
| 1644 |
+
# # Normalize loss to account for gradient accumulation
|
| 1645 |
+
# normalized_loss = loss / cfg.grad_accumulation_steps
|
| 1646 |
+
|
| 1647 |
+
# # Backward pass
|
| 1648 |
+
# normalized_loss.backward()
|
| 1649 |
+
|
| 1650 |
+
# Store recent train metrics
|
| 1651 |
+
for metric_name, value in metrics.items():
|
| 1652 |
+
if metric_name in recent_metrics:
|
| 1653 |
+
recent_metrics[metric_name].append(value)
|
| 1654 |
+
|
| 1655 |
+
# Compute gradient step index
|
| 1656 |
+
gradient_step_idx = batch_idx // cfg.grad_accumulation_steps
|
| 1657 |
+
|
| 1658 |
+
# Compute smoothened train metrics
|
| 1659 |
+
smoothened_metrics = compute_smoothened_metrics(recent_metrics)
|
| 1660 |
+
|
| 1661 |
+
# Push Metrics to W&B (every wandb_log_freq gradient steps)
|
| 1662 |
+
log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx
|
| 1663 |
+
if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0:
|
| 1664 |
+
log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb)
|
| 1665 |
+
|
| 1666 |
+
# [If applicable] Linearly warm up learning rate from 10% to 100% of original
|
| 1667 |
+
if cfg.lr_warmup_steps > 0:
|
| 1668 |
+
lr_progress = min((gradient_step_idx + 1) / cfg.lr_warmup_steps, 1.0) # Cap at 1.0
|
| 1669 |
+
current_lr = original_lr * (0.1 + 0.9 * lr_progress)
|
| 1670 |
+
for param_group in optimizer.param_groups:
|
| 1671 |
+
param_group["lr"] = current_lr
|
| 1672 |
+
|
| 1673 |
+
# Optimizer and LR scheduler step
|
| 1674 |
+
if (batch_idx + 1) % cfg.grad_accumulation_steps == 0:
|
| 1675 |
+
optimizer.step()
|
| 1676 |
+
scheduler.step()
|
| 1677 |
+
optimizer.zero_grad()
|
| 1678 |
+
progress.update()
|
| 1679 |
+
if cfg.use_ema:
|
| 1680 |
+
ema_vla.step(vla, action_head, proprio_projector, noisy_action_projector)
|
| 1681 |
+
|
| 1682 |
+
if distributed_state.is_main_process and gradient_step_idx % cfg.wandb_log_freq == 0:
|
| 1683 |
+
# Log the learning rate
|
| 1684 |
+
# Make sure to do this AFTER any learning rate modifications (e.g., warmup/decay)
|
| 1685 |
+
wandb.log(
|
| 1686 |
+
{
|
| 1687 |
+
"VLA Train/Learning Rate": scheduler.get_last_lr()[0],
|
| 1688 |
+
},
|
| 1689 |
+
step=log_step,
|
| 1690 |
+
)
|
| 1691 |
+
|
| 1692 |
+
if cfg.use_ema:
|
| 1693 |
+
# Log the EMA decay value
|
| 1694 |
+
wandb.log(
|
| 1695 |
+
{
|
| 1696 |
+
"VLA Train/EMA Decay": ema_vla.decay,
|
| 1697 |
+
},
|
| 1698 |
+
step=log_step,
|
| 1699 |
+
)
|
| 1700 |
+
# Log the EMA eval loss
|
| 1701 |
+
ema_vla.apply_shadow(vla, action_head, proprio_projector, noisy_action_projector)
|
| 1702 |
+
with torch.no_grad():
|
| 1703 |
+
vla.eval()
|
| 1704 |
+
action_head.eval() if action_head else None
|
| 1705 |
+
_, ema_metrics = run_forward_pass(
|
| 1706 |
+
vla=vla,
|
| 1707 |
+
action_head=action_head,
|
| 1708 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1709 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1710 |
+
batch=batch,
|
| 1711 |
+
action_tokenizer=action_tokenizer,
|
| 1712 |
+
device_id=device_id,
|
| 1713 |
+
use_l1_regression=cfg.use_l1_regression,
|
| 1714 |
+
use_diffusion=cfg.use_diffusion,
|
| 1715 |
+
use_proprio=cfg.use_proprio,
|
| 1716 |
+
use_film=cfg.use_film,
|
| 1717 |
+
num_patches=NUM_PATCHES,
|
| 1718 |
+
compute_diffusion_l1=compute_diffusion_l1,
|
| 1719 |
+
num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
|
| 1720 |
+
)
|
| 1721 |
+
ema_loss = ema_metrics['loss_value']
|
| 1722 |
+
vla.train()
|
| 1723 |
+
action_head.train() if action_head else None
|
| 1724 |
+
ema_vla.restore(vla, action_head, proprio_projector, noisy_action_projector)
|
| 1725 |
+
wandb.log(
|
| 1726 |
+
{
|
| 1727 |
+
"VLA Train/EMA Loss": ema_loss,
|
| 1728 |
+
},
|
| 1729 |
+
step=log_step,
|
| 1730 |
+
)
|
| 1731 |
+
|
| 1732 |
+
# Save model checkpoint: either keep latest checkpoint only or all checkpoints
|
| 1733 |
+
if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
|
| 1734 |
+
save_training_checkpoint(
|
| 1735 |
+
cfg=cfg,
|
| 1736 |
+
run_dir=run_dir,
|
| 1737 |
+
log_step=log_step,
|
| 1738 |
+
vla=vla,
|
| 1739 |
+
processor=processor,
|
| 1740 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1741 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1742 |
+
action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None,
|
| 1743 |
+
train_dataset=train_dataset,
|
| 1744 |
+
distributed_state=distributed_state,
|
| 1745 |
+
)
|
| 1746 |
+
|
| 1747 |
+
if cfg.use_ema:
|
| 1748 |
+
# Also save EMA model checkpoint
|
| 1749 |
+
ema_vla.apply_shadow(vla, action_head, proprio_projector, noisy_action_projector)
|
| 1750 |
+
save_training_checkpoint(
|
| 1751 |
+
cfg=cfg,
|
| 1752 |
+
run_dir=run_dir / "ema_model",
|
| 1753 |
+
log_step=log_step,
|
| 1754 |
+
vla=vla,
|
| 1755 |
+
processor=processor,
|
| 1756 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1757 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1758 |
+
action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None,
|
| 1759 |
+
train_dataset=train_dataset,
|
| 1760 |
+
distributed_state=distributed_state,
|
| 1761 |
+
)
|
| 1762 |
+
ema_vla.restore(vla, action_head, proprio_projector, noisy_action_projector)
|
| 1763 |
+
|
| 1764 |
+
# Test model on validation set
|
| 1765 |
+
if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0:
|
| 1766 |
+
run_validation(
|
| 1767 |
+
vla=vla,
|
| 1768 |
+
action_head=action_head,
|
| 1769 |
+
noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
|
| 1770 |
+
proprio_projector=proprio_projector if cfg.use_proprio else None,
|
| 1771 |
+
val_dataloader=val_dataloader,
|
| 1772 |
+
action_tokenizer=action_tokenizer,
|
| 1773 |
+
device_id=device_id,
|
| 1774 |
+
cfg=cfg,
|
| 1775 |
+
num_patches=NUM_PATCHES,
|
| 1776 |
+
log_step=log_step,
|
| 1777 |
+
distributed_state=distributed_state,
|
| 1778 |
+
val_time_limit=cfg.val_time_limit,
|
| 1779 |
+
)
|
| 1780 |
+
# Set model back to training mode after validation
|
| 1781 |
+
vla.train()
|
| 1782 |
+
|
| 1783 |
+
# Stop training when max_steps is reached
|
| 1784 |
+
if log_step == cfg.max_steps:
|
| 1785 |
+
print(f"Max step {cfg.max_steps} reached! Stopping training...")
|
| 1786 |
+
break
|
| 1787 |
+
|
| 1788 |
+
|
| 1789 |
+
if __name__ == "__main__":
|
| 1790 |
+
finetune()
|
capvector-oft/vla-scripts/merge_lora_weights_and_save.py
ADDED
|
@@ -0,0 +1,73 @@
|
| 1 |
+
"""
|
| 2 |
+
Loads a checkpoint that only has a LoRA adapter (no merged model) and merges the adapter
|
| 3 |
+
into the base OpenVLA model. Saves the final checkpoint in the same directory.
|
| 4 |
+
|
| 5 |
+
Make sure to specify the correct base checkpoint when running this script. For example,
|
| 6 |
+
- if you fine-tuned the default OpenVLA-7B model without modifications, then pass `--base_checkpoint="openvla/openvla-7b"`
|
| 7 |
+
- if you fine-tuned a different model or resumed fine-tuning from a different checkpoint, then specify that base checkpoint
|
| 8 |
+
- if you fine-tuned the default OpenVLA-7B model with modifications to `modeling_prismatic.py` (OpenVLA class definition),
|
| 9 |
+
then the base checkpoint path should point to the checkpoint containing the modifications
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python vla-scripts/merge_lora_weights_and_save.py \
|
| 13 |
+
--base_checkpoint openvla/openvla-7b \
|
| 14 |
+
--lora_finetuned_checkpoint_dir /PATH/TO/CHECKPOINT/DIR/
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import time
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Union
|
| 22 |
+
|
| 23 |
+
import draccus
|
| 24 |
+
import torch
|
| 25 |
+
from peft import PeftModel
|
| 26 |
+
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
|
| 27 |
+
|
| 28 |
+
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| 29 |
+
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| 30 |
+
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class ConvertConfig:
|
| 35 |
+
# fmt: off
|
| 36 |
+
|
| 37 |
+
base_checkpoint: Union[str, Path] = "" # Base model checkpoint path/dir (either openvla/openvla-7b or whichever model you fine-tuned / resumed training from)
|
| 38 |
+
lora_finetuned_checkpoint_dir: Union[str, Path] = "" # Checkpoint directory containing the LoRA adapter
|
| 39 |
+
|
| 40 |
+
# fmt: on
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@draccus.wrap()
|
| 44 |
+
def main(cfg: ConvertConfig) -> None:
|
| 45 |
+
# Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
|
| 46 |
+
AutoConfig.register("openvla", OpenVLAConfig)
|
| 47 |
+
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
|
| 48 |
+
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
|
| 49 |
+
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
|
| 50 |
+
|
| 51 |
+
# Load Model using HF AutoClasses
|
| 52 |
+
print(f"Loading base model: {cfg.base_checkpoint}")
|
| 53 |
+
vla = AutoModelForVision2Seq.from_pretrained(
|
| 54 |
+
cfg.base_checkpoint,
|
| 55 |
+
torch_dtype=torch.bfloat16,
|
| 56 |
+
low_cpu_mem_usage=True,
|
| 57 |
+
trust_remote_code=True,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Load LoRA weights and merge into base model, then save final checkpoint
|
| 61 |
+
print("Merging LoRA weights into base model...")
|
| 62 |
+
start_time = time.time()
|
| 63 |
+
merged_vla = PeftModel.from_pretrained(vla, os.path.join(cfg.lora_finetuned_checkpoint_dir, "lora_adapter")).to(
|
| 64 |
+
"cuda"
|
| 65 |
+
)
|
| 66 |
+
merged_vla = merged_vla.merge_and_unload()
|
| 67 |
+
merged_vla.save_pretrained(cfg.lora_finetuned_checkpoint_dir)
|
| 68 |
+
print(f"\nMerging complete! Time elapsed (sec): {time.time() - start_time}")
|
| 69 |
+
print(f"\nSaved merged model checkpoint at:\n{cfg.lora_finetuned_checkpoint_dir}")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
main()
|
capvector-pi05/.dockerignore
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
.venv
|
| 2 |
+
checkpoints
|
| 3 |
+
data
|
capvector-pi05/.gitignore
ADDED
|
@@ -0,0 +1,169 @@
|
| 1 |
+
# Data directories.
|
| 2 |
+
assets/
|
| 3 |
+
checkpoints/
|
| 4 |
+
data/
|
| 5 |
+
wandb/
|
| 6 |
+
|
| 7 |
+
# Byte-compiled / optimized / DLL files
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.py[cod]
|
| 10 |
+
*$py.class
|
| 11 |
+
|
| 12 |
+
# C extensions
|
| 13 |
+
*.so
|
| 14 |
+
|
| 15 |
+
# Distribution / packaging
|
| 16 |
+
.Python
|
| 17 |
+
build/
|
| 18 |
+
develop-eggs/
|
| 19 |
+
dist/
|
| 20 |
+
downloads/
|
| 21 |
+
eggs/
|
| 22 |
+
.eggs/
|
| 23 |
+
lib/
|
| 24 |
+
lib64/
|
| 25 |
+
parts/
|
| 26 |
+
sdist/
|
| 27 |
+
var/
|
| 28 |
+
wheels/
|
| 29 |
+
share/python-wheels/
|
| 30 |
+
*.egg-info/
|
| 31 |
+
.installed.cfg
|
| 32 |
+
*.egg
|
| 33 |
+
MANIFEST
|
| 34 |
+
|
| 35 |
+
# PyInstaller
|
| 36 |
+
# Usually these files are written by a python script from a template
|
| 37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 38 |
+
*.manifest
|
| 39 |
+
*.spec
|
| 40 |
+
|
| 41 |
+
# Installer logs
|
| 42 |
+
pip-log.txt
|
| 43 |
+
pip-delete-this-directory.txt
|
| 44 |
+
|
| 45 |
+
# Unit test / coverage reports
|
| 46 |
+
htmlcov/
|
| 47 |
+
.tox/
|
| 48 |
+
.nox/
|
| 49 |
+
.coverage
|
| 50 |
+
.coverage.*
|
| 51 |
+
.cache
|
| 52 |
+
nosetests.xml
|
| 53 |
+
coverage.xml
|
| 54 |
+
*.cover
|
| 55 |
+
*.py,cover
|
| 56 |
+
.hypothesis/
|
| 57 |
+
.pytest_cache/
|
| 58 |
+
cover/
|
| 59 |
+
|
| 60 |
+
# Translations
|
| 61 |
+
*.mo
|
| 62 |
+
*.pot
|
| 63 |
+
|
| 64 |
+
# Django stuff:
|
| 65 |
+
*.log
|
| 66 |
+
local_settings.py
|
| 67 |
+
db.sqlite3
|
| 68 |
+
db.sqlite3-journal
|
| 69 |
+
|
| 70 |
+
# Flask stuff:
|
| 71 |
+
instance/
|
| 72 |
+
.webassets-cache
|
| 73 |
+
|
| 74 |
+
# Scrapy stuff:
|
| 75 |
+
.scrapy
|
| 76 |
+
|
| 77 |
+
# Sphinx documentation
|
| 78 |
+
docs/_build/
|
| 79 |
+
|
| 80 |
+
# PyBuilder
|
| 81 |
+
.pybuilder/
|
| 82 |
+
target/
|
| 83 |
+
|
| 84 |
+
# Jupyter Notebook
|
| 85 |
+
.ipynb_checkpoints
|
| 86 |
+
|
| 87 |
+
# IPython
|
| 88 |
+
profile_default/
|
| 89 |
+
ipython_config.py
|
| 90 |
+
|
| 91 |
+
# pyenv
|
| 92 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 93 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 94 |
+
# .python-version
|
| 95 |
+
|
| 96 |
+
# pipenv
|
| 97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 100 |
+
# install all needed dependencies.
|
| 101 |
+
#Pipfile.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
.idea/
|
| 169 |
+
.vscode/
|
capvector-pi05/.gitmodules
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
[submodule "third_party/aloha"]
|
| 2 |
+
path = third_party/aloha
|
| 3 |
+
url = https://github.com/Physical-Intelligence/aloha.git
|
| 4 |
+
[submodule "third_party/libero"]
|
| 5 |
+
path = third_party/libero
|
| 6 |
+
url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
|
capvector-pi05/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
| 1 |
+
exclude: third_party/
|
| 2 |
+
|
| 3 |
+
repos:
|
| 4 |
+
- repo: https://github.com/astral-sh/uv-pre-commit
|
| 5 |
+
# uv version.
|
| 6 |
+
rev: 0.5.14
|
| 7 |
+
hooks:
|
| 8 |
+
- id: uv-lock
|
| 9 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 10 |
+
# Ruff version.
|
| 11 |
+
rev: v0.8.6
|
| 12 |
+
hooks:
|
| 13 |
+
# Run the linter.
|
| 14 |
+
- id: ruff
|
| 15 |
+
args: [--fix]
|
| 16 |
+
- id: ruff-format
|
capvector-pi05/.python-version
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
3.11
|
capvector-pi05/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
capvector-pi05/README.md
ADDED
|
@@ -0,0 +1,128 @@
|
| 1 |
+
## 1. Environment Setup
|
| 2 |
+
We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment:
|
| 3 |
+
|
| 4 |
+
```bash
|
| 5 |
+
GIT_LFS_SKIP_SMUDGE=1 uv sync
|
| 6 |
+
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
|
| 7 |
+
cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
|
| 8 |
+
source .venv/bin/activate
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## 2. Data Preparation
|
| 15 |
+
Here we take real-world Aloha data as an example; more details on the simulation data can be found in the [official openpi repo](https://github.com/Physical-Intelligence/openpi/).
|
| 16 |
+
|
| 17 |
+
First, you need to collect the task-specific raw data with your own robot, and save it in the `.hdf5` format.
|
| 18 |
+
|
| 19 |
+
Then, convert the data to LeRobot dataset format.
|
| 20 |
+
```bash
|
| 21 |
+
uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
| 22 |
+
# By default, the converted data is stored in ~/.cache/huggingface/lerobot/<org>/<dataset-name>/
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
## 3. Obtain the capability vectors and merge it to obtain $\theta_{meta}$
|
| 27 |
+
|
| 28 |
+
First, define your task-specific config in [config.py](src/openpi/training/config.py). We provide an example config for our real-world task [here](src/openpi/training/config.py#L776-L808).
|
| 29 |
+
|
| 30 |
+
Then, convert a JAX model checkpoint to PyTorch format:
|
| 31 |
+
```bash
|
| 32 |
+
uv run examples/convert_jax_model_to_pytorch.py \
|
| 33 |
+
--checkpoint_dir gs://openpi-assets/checkpoints/pi05_base \
|
| 34 |
+
--config_name <config_name> \
|
| 35 |
+
--output_path checkpoints/pytorch_pi05_base
|
| 36 |
+
# This command will automatically download pi05_base checkpoint to ~/.cache/openpi/openpi-assets/checkpoints/pi05_base/
|
| 37 |
+
# Otherwise you can download it manually and modify the --checkpoint_dir
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
> ⭐ If you don't use the regularization strategy, you can download the [capability-merged meta model](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/merged_model) we provide, place it at `./checkpoints/vector_init/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial/`, and jump directly to the [Training step](#4-training).
|
| 41 |
+
|
| 42 |
+
Then, the capability vectors are obtained by simple parameter arithmetic between two models fine-tuned with different strategies. We therefore need to prepare these two trained models, *e.g.*, [Pi0.5 on LIBERO-Spatial](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/pi05_baseline_30000step_spatial) and [Pi0.5-SF on LIBERO-Spatial](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/pi05_spatialforcing_30000step_spatial). The directory structure is as below:
|
| 43 |
+
```
|
| 44 |
+
capvector-pi05
|
| 45 |
+
├── checkpoints
|
| 46 |
+
· ├── pi05-LIBEROspatial
|
| 47 |
+
│ ├── model.safetensors
|
| 48 |
+
│ └── ...
|
| 49 |
+
├── pi05SF-LIBEROspatial
|
| 50 |
+
│ ├── model.safetensors
|
| 51 |
+
│ └── ...
|
| 52 |
+
├── diff
|
| 53 |
+
├── vector_init
|
| 54 |
+
·
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
Next, conduct parameter arithmetic between these two models:
|
| 58 |
+
```bash
|
| 59 |
+
CONFIG=pi05_capvector_aloha_place_block && \
|
| 60 |
+
EXT=pi05SF-LIBEROspatial && \
|
| 61 |
+
DOWN=pi05-LIBEROspatial && \
|
| 62 |
+
uv run capvector/compute_param_diff.py \
|
| 63 |
+
--config $CONFIG \
|
| 64 |
+
--a.dir checkpoints/$EXT \
|
| 65 |
+
--b.dir checkpoints/$DOWN \
|
| 66 |
+
--out checkpoints/diff/${EXT}_minus_${DOWN}.pth \
|
| 67 |
+
--strict-keys \
|
| 68 |
+
--dtype fp32
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Finally, merge these diff parameters to obtain $\theta_{meta}$:
|
| 72 |
+
```bash
|
| 73 |
+
DIFF=pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial && \
|
| 74 |
+
uv run capvector/apply_param_diff.py \
|
| 75 |
+
--base-safetensors checkpoints/pytorch_pi05_base/model.safetensors \
|
| 76 |
+
--diff-pth checkpoints/diff/${DIFF}.pth \
|
| 77 |
+
--out-safetensors checkpoints/vector_init/${DIFF}/model.safetensors \
|
| 78 |
+
--scale 1.0 \
|
| 79 |
+
--no-strict-keys \
|
| 80 |
+
--dtype fp32 \
|
| 81 |
+
--device cpu
|
| 82 |
+
```
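Putting the two steps together, the merged weights are $\theta_{meta} = \theta_{base} + s \cdot (\theta_{A} - \theta_{B})$, where $\theta_{A}$ and $\theta_{B}$ are the checkpoints passed as `--a.dir` and `--b.dir`, $\theta_{base}$ is the converted `pytorch_pi05_base` checkpoint, and $s$ is the `--scale` argument (1.0 above). This follows from the flag semantics in `compute_param_diff.py` and `apply_param_diff.py` (final = base + scale * diff).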
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
## 4. Training
|
| 86 |
+
First, you need to compute the normalization statistics for the training data.
|
| 87 |
+
```bash
|
| 88 |
+
uv run scripts/compute_norm_stats.py --config-name <config_name>
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Finally, launch training using one of these modes:
|
| 92 |
+
```bash
|
| 93 |
+
# Single GPU training:
|
| 94 |
+
uv run scripts/train_regular_loss_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
|
| 95 |
+
# Example:
|
| 96 |
+
uv run scripts/train_regular_loss_pytorch.py pi05_capvector_aloha_place_block --exp_name pytorch_test
|
| 97 |
+
uv run scripts/train_regular_loss_pytorch.py pi05_capvector_aloha_place_block --exp_name pytorch_test --overwrite # Overwrite existing checkpoints
|
| 98 |
+
|
| 99 |
+
# Multi-GPU training (single node):
|
| 100 |
+
uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_regular_loss_pytorch.py <config_name> --exp_name <run_name>
|
| 101 |
+
|
| 102 |
+
# Multi-Node Training:
|
| 103 |
+
uv run torchrun \
|
| 104 |
+
--nnodes=<num_nodes> \
|
| 105 |
+
--nproc_per_node=<gpus_per_node> \
|
| 106 |
+
--node_rank=<rank_of_node> \
|
| 107 |
+
--master_addr=<master_ip> \
|
| 108 |
+
--master_port=<port> \
|
| 109 |
+
scripts/train_regular_loss_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
## 5. Inference
|
| 114 |
+
Real-world inference is executed in a server-client fashion.
|
| 115 |
+
|
| 116 |
+
First, launch a model server (we use the checkpoint for iteration 20,000 for this example, modify as needed):
|
| 117 |
+
```bash
|
| 118 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=<config_name> --policy.dir=checkpoints/<config_name>/<run_name>/20000
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
This will spin up a server that listens on port 8000 and waits for observations to be sent to it.
|
| 122 |
+
|
| 123 |
+
Then, we can run a client robot script that queries the server.
|
| 124 |
+
|
| 125 |
+
You need to write a client script for your own robot. A simple [client example](examples/simple_client/main.py) is shown below:
|
| 126 |
+
```bash
|
| 127 |
+
uv run examples/simple_client/main.py --env ALOHA
|
| 128 |
+
```
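If you prefer to call the policy server directly from Python, a minimal client sketch using the `openpi_client` package might look as follows. The observation keys, image shapes, and prompt below are illustrative placeholders, not the exact format your policy expects; use whatever keys your policy config's input transforms require:

```python
import numpy as np
from openpi_client import websocket_client_policy

# Connect to the policy server started by serve_policy.py (default port 8000).
policy = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

# Build one observation. NOTE: the keys below are placeholders for an ALOHA-style setup;
# the exact keys and shapes depend on your robot and policy config.
observation = {
    "state": np.zeros(14, dtype=np.float32),
    "images": {
        "cam_high": np.zeros((3, 224, 224), dtype=np.uint8),
        "cam_left_wrist": np.zeros((3, 224, 224), dtype=np.uint8),
        "cam_right_wrist": np.zeros((3, 224, 224), dtype=np.uint8),
    },
    "prompt": "place the block into the box",
}

# The server returns a chunk of future actions; execute them on your robot.
result = policy.infer(observation)
action_chunk = result["actions"]
print(action_chunk.shape)
```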
|
capvector-pi05/capvector/apply_param_diff.py
ADDED
|
@@ -0,0 +1,135 @@
|
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import tyro
|
| 7 |
+
from safetensors.torch import load_file, save_file
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclasses.dataclass
|
| 11 |
+
class Args:
|
| 12 |
+
# Base pretrained weights in safetensors
|
| 13 |
+
base_safetensors: str
|
| 14 |
+
|
| 15 |
+
# Diff checkpoint in .pth (either {"state_dict": ...} or raw state_dict)
|
| 16 |
+
diff_pth: str
|
| 17 |
+
|
| 18 |
+
# Output safetensors path
|
| 19 |
+
out_safetensors: str = "model_merged.safetensors"
|
| 20 |
+
|
| 21 |
+
# final = base + scale * diff
|
| 22 |
+
scale: float = 1.0
|
| 23 |
+
|
| 24 |
+
# whether keys must match exactly
|
| 25 |
+
strict_keys: bool = True # use --strict-keys / --no-strict-keys
|
| 26 |
+
|
| 27 |
+
# arithmetic dtype
|
| 28 |
+
dtype: str = "fp32" # fp32/fp16/bf16
|
| 29 |
+
|
| 30 |
+
# compute device
|
| 31 |
+
device: str = "cpu" # cpu/cuda
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def cast(t: torch.Tensor, dtype: str) -> torch.Tensor:
|
| 35 |
+
if dtype == "fp32":
|
| 36 |
+
return t.float()
|
| 37 |
+
if dtype == "fp16":
|
| 38 |
+
return t.half()
|
| 39 |
+
if dtype == "bf16":
|
| 40 |
+
return t.bfloat16()
|
| 41 |
+
raise ValueError(f"Unknown dtype: {dtype}")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_diff_state_dict(path: str) -> dict[str, torch.Tensor]:
|
| 45 |
+
obj = torch.load(path, map_location="cpu")
|
| 46 |
+
if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict):
|
| 47 |
+
sd = obj["state_dict"]
|
| 48 |
+
elif isinstance(obj, dict):
|
| 49 |
+
sd = obj
|
| 50 |
+
else:
|
| 51 |
+
raise RuntimeError(f"Unexpected diff format: {type(obj)}")
|
| 52 |
+
|
| 53 |
+
for k, v in sd.items():
|
| 54 |
+
if not isinstance(v, torch.Tensor):
|
| 55 |
+
raise RuntimeError(f"Diff contains non-tensor at key={k}: {type(v)}")
|
| 56 |
+
return sd
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def main(args: Args) -> None:
|
| 60 |
+
logging.info("Loading base safetensors: %s", args.base_safetensors)
|
| 61 |
+
base_sd = load_file(args.base_safetensors, device="cpu") # dict[str, Tensor]
|
| 62 |
+
|
| 63 |
+
logging.info("Loading diff pth: %s", args.diff_pth)
|
| 64 |
+
diff_sd = load_diff_state_dict(args.diff_pth)
|
| 65 |
+
|
| 66 |
+
keys_base = set(base_sd.keys())
|
| 67 |
+
keys_diff = set(diff_sd.keys())
|
| 68 |
+
|
| 69 |
+
if args.strict_keys:
|
| 70 |
+
if keys_base != keys_diff:
|
| 71 |
+
only_base = sorted(list(keys_base - keys_diff))[:30]
|
| 72 |
+
only_diff = sorted(list(keys_diff - keys_base))[:30]
|
| 73 |
+
raise RuntimeError(
|
| 74 |
+
"Keys mismatch between base safetensors and diff.\n"
|
| 75 |
+
f"Only in base (up to 30): {only_base}\n"
|
| 76 |
+
f"Only in diff (up to 30): {only_diff}\n"
|
| 77 |
+
"Use --no-strict-keys to apply on intersection only."
|
| 78 |
+
)
|
| 79 |
+
keys_apply = keys_base
|
| 80 |
+
else:
|
| 81 |
+
keys_apply = keys_base & keys_diff
|
| 82 |
+
logging.warning("Non-strict mode: applying on intersection keys: %d", len(keys_apply))
|
| 83 |
+
|
| 84 |
+
dev = torch.device(args.device)
|
| 85 |
+
|
| 86 |
+
merged_sd: dict[str, torch.Tensor] = {}
|
| 87 |
+
applied_float = 0
|
| 88 |
+
skipped_nonfloat = 0
|
| 89 |
+
skipped_missing = 0
|
| 90 |
+
|
| 91 |
+
for k, base_t_cpu in base_sd.items():
|
| 92 |
+
base_t = base_t_cpu # already on cpu
|
| 93 |
+
|
| 94 |
+
if k not in keys_apply:
|
| 95 |
+
merged_sd[k] = base_t
|
| 96 |
+
skipped_missing += 1
|
| 97 |
+
continue
|
| 98 |
+
|
| 99 |
+
diff_t_cpu = diff_sd[k]
|
| 100 |
+
|
| 101 |
+
if base_t.shape != diff_t_cpu.shape:
|
| 102 |
+
raise RuntimeError(f"Shape mismatch at key={k}: base {base_t.shape} vs diff {diff_t_cpu.shape}")
|
| 103 |
+
|
| 104 |
+
# only add for floating-point tensors
|
| 105 |
+
if base_t.is_floating_point() and diff_t_cpu.is_floating_point():
|
| 106 |
+
a = cast(base_t.to(dev), args.dtype)
|
| 107 |
+
d = cast(diff_t_cpu.to(dev), args.dtype)
|
| 108 |
+
out = a + args.scale * d
|
| 109 |
+
merged_sd[k] = out.to(base_t.dtype).detach().cpu()
|
| 110 |
+
applied_float += 1
|
| 111 |
+
else:
|
| 112 |
+
merged_sd[k] = base_t
|
| 113 |
+
skipped_nonfloat += 1
|
| 114 |
+
|
| 115 |
+
out_path = Path(args.out_safetensors)
|
| 116 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 117 |
+
|
| 118 |
+
# safetensors requires all tensors to be on the CPU
|
| 119 |
+
for k, v in merged_sd.items():
|
| 120 |
+
if v.device.type != "cpu":
|
| 121 |
+
merged_sd[k] = v.cpu()
|
| 122 |
+
|
| 123 |
+
logging.info(
|
| 124 |
+
"Done. applied_float=%d, skipped_nonfloat=%d, skipped_missing=%d",
|
| 125 |
+
applied_float,
|
| 126 |
+
skipped_nonfloat,
|
| 127 |
+
skipped_missing,
|
| 128 |
+
)
|
| 129 |
+
logging.info("Saving merged safetensors to: %s", str(out_path))
|
| 130 |
+
save_file(merged_sd, str(out_path))
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
|
| 134 |
+
logging.basicConfig(level=logging.INFO, force=True)
|
| 135 |
+
main(tyro.cli(Args))
|
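For reference, a hypothetical invocation of `apply_param_diff.py` via its `tyro` CLI could look like the sketch below. The checkpoint paths are placeholders and the dashed flag names assume tyro's default field-to-flag mapping; this is not a command taken from the repository's docs.

```bash
# Merge a parameter diff into a base checkpoint: final = base + scale * diff.
# All paths below are placeholders.
uv run capvector/apply_param_diff.py \
    --base-safetensors /path/to/base/model.safetensors \
    --diff-pth checkpoints/diff/a_minus_b.pth \
    --out-safetensors checkpoints/merged/model_merged.safetensors \
    --scale 0.5 \
    --no-strict-keys
```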
capvector-pi05/capvector/compute_param_diff.py
ADDED
|
@@ -0,0 +1,142 @@
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import tyro
|
| 8 |
+
|
| 9 |
+
from openpi.training import config as _config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclasses.dataclass
|
| 13 |
+
class CkptSpec:
|
| 14 |
+
dir: str
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclasses.dataclass
|
| 18 |
+
class Args:
|
| 19 |
+
config: str
|
| 20 |
+
a: CkptSpec
|
| 21 |
+
b: CkptSpec
|
| 22 |
+
out: str = "checkpoints/diff/a_minus_b.pth"
|
| 23 |
+
only_vlm: bool = False
|
| 24 |
+
strict_keys: bool = False
|
| 25 |
+
dtype: str = "fp32"
|
| 26 |
+
device: str = "cpu"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _extract_state_dict(obj: Any) -> dict[str, torch.Tensor]:
|
| 30 |
+
"""
|
| 31 |
+
Best-effort extraction of a torch state_dict from a Policy or Module-like object.
|
| 32 |
+
"""
|
| 33 |
+
# Case 1: policy itself has state_dict()
|
| 34 |
+
if hasattr(obj, "state_dict") and callable(obj.state_dict):
|
| 35 |
+
sd = obj.state_dict()
|
| 36 |
+
if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
|
| 37 |
+
return sd
|
| 38 |
+
|
| 39 |
+
# Case 2: common attributes that hold torch.nn.Module
|
| 40 |
+
for attr in ["model", "_model", "module", "net", "_net", "policy", "_policy"]:
|
| 41 |
+
if hasattr(obj, attr):
|
| 42 |
+
m = getattr(obj, attr)
|
| 43 |
+
if hasattr(m, "state_dict") and callable(m.state_dict):
|
| 44 |
+
sd = m.state_dict()
|
| 45 |
+
if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
|
| 46 |
+
return sd
|
| 47 |
+
|
| 48 |
+
raise RuntimeError(
|
| 49 |
+
"Cannot extract state_dict. "
|
| 50 |
+
"Please inspect Policy object and update attribute list in _extract_state_dict()."
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _cast_tensor(t: torch.Tensor, dtype: str) -> torch.Tensor:
|
| 55 |
+
if dtype == "fp32":
|
| 56 |
+
return t.float()
|
| 57 |
+
if dtype == "fp16":
|
| 58 |
+
return t.half()
|
| 59 |
+
if dtype == "bf16":
|
| 60 |
+
return t.bfloat16()
|
| 61 |
+
raise ValueError(f"Unknown dtype: {dtype}")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_model(config_name: str, spec: CkptSpec):
|
| 65 |
+
cfg = _config.get_config(config_name)
|
| 66 |
+
weight_path = Path(spec.dir) / "model.safetensors"
|
| 67 |
+
if not weight_path.exists():
|
| 68 |
+
raise FileNotFoundError(f"Missing model.safetensors in checkpoint directory: {spec.dir}")
|
| 69 |
+
return cfg.model.load_pytorch(cfg, str(weight_path))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def main(args: Args) -> None:
|
| 73 |
+
logging.info("Loading A model from %s with config %s", args.a.dir, args.config)
|
| 74 |
+
model_a = load_model(args.config, args.a)
|
| 75 |
+
logging.info("Loading B model from %s with config %s", args.b.dir, args.config)
|
| 76 |
+
model_b = load_model(args.config, args.b)
|
| 77 |
+
|
| 78 |
+
sd_a = _extract_state_dict(model_a)
|
| 79 |
+
sd_b = _extract_state_dict(model_b)
|
| 80 |
+
|
| 81 |
+
keys_a = set(sd_a.keys())
|
| 82 |
+
keys_b = set(sd_b.keys())
|
| 83 |
+
|
| 84 |
+
if args.strict_keys:
|
| 85 |
+
if keys_a != keys_b:
|
| 86 |
+
only_a = sorted(list(keys_a - keys_b))[:20]
|
| 87 |
+
only_b = sorted(list(keys_b - keys_a))[:20]
|
| 88 |
+
raise RuntimeError(
|
| 89 |
+
f"State dict keys mismatch.\n"
|
| 90 |
+
f"Only in A (show up to 20): {only_a}\n"
|
| 91 |
+
f"Only in B (show up to 20): {only_b}\n"
|
| 92 |
+
f"Set --strict-keys False to subtract intersection only."
|
| 93 |
+
)
|
| 94 |
+
keys = sorted(keys_a)
|
| 95 |
+
else:
|
| 96 |
+
keys = sorted(list(keys_a & keys_b))
|
| 97 |
+
logging.warning("Non-strict mode: subtracting only intersection keys: %d", len(keys))
|
| 98 |
+
|
| 99 |
+
device = torch.device(args.device)
|
| 100 |
+
diff: dict[str, torch.Tensor] = {}
|
| 101 |
+
|
| 102 |
+
if args.only_vlm:
|
| 103 |
+
ZERO_PREFIXES = [
|
| 104 |
+
"paligemma_with_expert.gemma_expert.",
|
| 105 |
+
"action_in_proj.",
|
| 106 |
+
"action_out_proj.",
|
| 107 |
+
"action_time_mlp_in",
|
| 108 |
+
"action_time_mlp_oout",
|
| 109 |
+
]
|
| 110 |
+
else:
|
| 111 |
+
ZERO_PREFIXES = []
|
| 112 |
+
|
| 113 |
+
for k in keys:
|
| 114 |
+
ta = sd_a[k].to(device)
|
| 115 |
+
tb = sd_b[k].to(device)
|
| 116 |
+
|
| 117 |
+
if ta.shape != tb.shape:
|
| 118 |
+
raise RuntimeError(f"Shape mismatch at key={k}: {ta.shape} vs {tb.shape}")
|
| 119 |
+
|
| 120 |
+
zero_this = any(k.startswith(p) for p in ZERO_PREFIXES)
|
| 121 |
+
|
| 122 |
+
if zero_this:
|
| 123 |
+
out = torch.zeros_like(ta)
|
| 124 |
+
else:
|
| 125 |
+
if ta.is_floating_point():
|
| 126 |
+
out = _cast_tensor(ta, args.dtype) - _cast_tensor(tb, args.dtype)
|
| 127 |
+
else:
|
| 128 |
+
out = ta
|
| 129 |
+
|
| 130 |
+
diff[k] = out.detach().cpu()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
out_path = Path(args.out)
|
| 135 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 136 |
+
torch.save({"state_dict": diff, "a": dataclasses.asdict(args.a), "b": dataclasses.asdict(args.b)}, out_path)
|
| 137 |
+
logging.info("Saved diff checkpoint to: %s", str(out_path))
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
logging.basicConfig(level=logging.INFO, force=True)
|
| 142 |
+
main(tyro.cli(Args))
|
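Similarly, a hypothetical invocation of `compute_param_diff.py` might look like the following; the config name and checkpoint directories are placeholders, and the nested `--a.dir` / `--b.dir` flags assume tyro's default handling of nested dataclasses.

```bash
# Compute a parameter diff (A - B) between two checkpoints trained with the same config.
# Config name and checkpoint directories are placeholders.
uv run capvector/compute_param_diff.py \
    --config <your_config_name> \
    --a.dir /path/to/checkpoint_a \
    --b.dir /path/to/checkpoint_b \
    --out checkpoints/diff/a_minus_b.pth \
    --only-vlm
```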
capvector-pi05/docs/docker.md
ADDED
|
@@ -0,0 +1,25 @@
| 1 |
+
### Docker Setup
|
| 2 |
+
|
| 3 |
+
All of the examples in this repo provide instructions for running them both directly and with Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
|
| 4 |
+
|
| 5 |
+
- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
|
| 6 |
+
- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
|
| 7 |
+
- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
|
| 8 |
+
- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
|
| 9 |
+
- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
If you are starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`, as shown below.
|
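For example, on a fresh Ubuntu 22.04 host the two convenience scripts can be run back to back (a minimal sketch; depending on your setup they may need to be run with elevated privileges):

```bash
# Install Docker (rootless mode) and the NVIDIA container toolkit on Ubuntu 22.04.
bash scripts/docker/install_docker_ubuntu22.sh
bash scripts/docker/install_nvidia_container_toolkit.sh
```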
| 13 |
+
|
| 14 |
+
Build the Docker image and start the container with the following command:
|
| 15 |
+
```bash
|
| 16 |
+
docker compose -f scripts/docker/compose.yml up --build
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
To build and run the Docker image for a specific example, use the following command:
|
| 20 |
+
```bash
|
| 21 |
+
docker compose -f examples/<example_name>/compose.yml up --build
|
| 22 |
+
```
|
| 23 |
+
where `<example_name>` is the name of the example you want to run.
|
| 24 |
+
|
| 25 |
+
During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
|
capvector-pi05/docs/norm_stats.md
ADDED
|
@@ -0,0 +1,69 @@
| 1 |
+
# Normalization statistics
|
| 2 |
+
|
| 3 |
+
Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
|
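As a rough illustration of what this normalization amounts to, the sketch below z-scores each dimension with statistics computed over the training data. It is a minimal example only; the exact transform and the on-disk format of the statistics used by openpi may differ.

```python
import numpy as np

# Minimal sketch of per-dimension z-score normalization. The real pipeline stores
# these statistics alongside the checkpoint and applies the inverse transform to
# the model's action outputs at inference time.
def compute_norm_stats(data: np.ndarray) -> dict:
    # data: (num_samples, dim) array of states or actions from the training set.
    return {"mean": data.mean(axis=0), "std": data.std(axis=0)}

def normalize(x: np.ndarray, stats: dict, eps: float = 1e-6) -> np.ndarray:
    return (x - stats["mean"]) / (stats["std"] + eps)

def unnormalize(x: np.ndarray, stats: dict, eps: float = 1e-6) -> np.ndarray:
    return x * (stats["std"] + eps) + stats["mean"]
```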
| 4 |
+
|
| 5 |
+
## Reloading normalization statistics
|
| 6 |
+
|
| 7 |
+
When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
|
| 8 |
+
|
| 9 |
+
**If your target robot matches one of the robots covered by these pre-training statistics, consider reloading the corresponding normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, as shown below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
|
| 10 |
+
|
| 11 |
+
```python
|
| 12 |
+
TrainConfig(
|
| 13 |
+
...
|
| 14 |
+
data=LeRobotAlohaDataConfig(
|
| 15 |
+
...
|
| 16 |
+
assets=AssetsConfig(
|
| 17 |
+
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
|
| 18 |
+
asset_id="trossen",
|
| 19 |
+
),
|
| 20 |
+
),
|
| 21 |
+
)
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
|
| 25 |
+
|
| 26 |
+
**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
|
| 27 |
+
|
| 28 |
+
**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend always trying both options, reloading the pre-training statistics and training with a fresh set of statistics computed on your new dataset (see the [main README](../README.md) for instructions on how to compute new statistics), and picking the one that works better for your task.
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
## Provided Pre-training Normalization Statistics
|
| 32 |
+
|
| 33 |
+
Below is a list of all the pre-training normalization statistics we provide. We provide them for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
|
| 34 |
+
| Robot | Description | Asset ID |
|
| 35 |
+
|-------|-------------|----------|
|
| 36 |
+
| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
|
| 37 |
+
| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
|
| 38 |
+
| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
|
| 39 |
+
| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
|
| 40 |
+
| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
|
| 41 |
+
| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
|
| 42 |
+
| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
|
| 43 |
+
| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
|
| 44 |
+
| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
## Pi0 Model Action Space Definitions
|
| 48 |
+
|
| 49 |
+
Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
|
| 50 |
+
```
|
| 51 |
+
"dim_0:dim_5": "left arm joint angles",
|
| 52 |
+
"dim_6": "left arm gripper position",
|
| 53 |
+
"dim_7:dim_12": "right arm joint angles (for bi-manual only)",
|
| 54 |
+
"dim_13": "right arm gripper position (for bi-manual only)",
|
| 55 |
+
|
| 56 |
+
# For mobile robots:
|
| 57 |
+
"dim_14:dim_15": "x-y base velocity (for mobile robots only)",
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
|
| 61 |
+
|
| 62 |
+
For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
|
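As a concrete illustration of the bi-manual layout above, a single 14-dimensional action vector could be unpacked as in the hypothetical helper below; it is not part of openpi and only mirrors the dimension table.

```python
import numpy as np

# Hypothetical helper that splits a bi-manual action vector according to the
# dimension layout documented above (6 joints + 1 gripper per arm).
def split_bimanual_action(action: np.ndarray) -> dict:
    assert action.shape[-1] >= 14
    return {
        "left_arm_joints": action[..., 0:6],
        "left_gripper": action[..., 6:7],
        "right_arm_joints": action[..., 7:13],
        "right_gripper": action[..., 13:14],
    }
```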
| 63 |
+
|
| 64 |
+
General info for Pi robots:
|
| 65 |
+
- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
|
| 66 |
+
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
|
| 67 |
+
- Control frequencies are 20 Hz for the UR5e and Franka arms, and 50 Hz for the ARX and Trossen (ALOHA) arms.
|
| 68 |
+
|
| 69 |
+
For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz.
|
capvector-pi05/docs/remote_inference.md
ADDED
|
@@ -0,0 +1,71 @@
| 1 |
+
|
| 2 |
+
# Running openpi models remotely
|
| 3 |
+
|
| 4 |
+
We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
|
| 5 |
+
|
| 6 |
+
## Starting a remote policy server
|
| 7 |
+
|
| 8 |
+
To start a remote policy server, you can simply run the following command:
|
| 9 |
+
|
| 10 |
+
```bash
|
| 11 |
+
uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
|
| 21 |
+
|
| 22 |
+
## Querying the remote policy server from your robot code
|
| 23 |
+
|
| 24 |
+
We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
|
| 25 |
+
|
| 26 |
+
First, install the `openpi-client` package in your robot environment:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
cd $OPENPI_ROOT/packages/openpi-client
|
| 30 |
+
pip install -e .
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
from openpi_client import image_tools
|
| 37 |
+
from openpi_client import websocket_client_policy
|
| 38 |
+
|
| 39 |
+
# Outside of episode loop, initialize the policy client.
|
| 40 |
+
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
|
| 41 |
+
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
|
| 42 |
+
|
| 43 |
+
for step in range(num_steps):
|
| 44 |
+
# Inside the episode loop, construct the observation.
|
| 45 |
+
# Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
|
| 46 |
+
# We provide utilities for resizing images + uint8 conversion so you match the training routines.
|
| 47 |
+
# The typical resize_size for pre-trained pi0 models is 224.
|
| 48 |
+
# Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
|
| 49 |
+
observation = {
|
| 50 |
+
"observation/image": image_tools.convert_to_uint8(
|
| 51 |
+
image_tools.resize_with_pad(img, 224, 224)
|
| 52 |
+
),
|
| 53 |
+
"observation/wrist_image": image_tools.convert_to_uint8(
|
| 54 |
+
image_tools.resize_with_pad(wrist_img, 224, 224)
|
| 55 |
+
),
|
| 56 |
+
"observation/state": state,
|
| 57 |
+
"prompt": task_instruction,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# Call the policy server with the current observation.
|
| 61 |
+
# This returns an action chunk of shape (action_horizon, action_dim).
|
| 62 |
+
# Note that you typically only need to call the policy every N steps and execute steps
|
| 63 |
+
# from the predicted action chunk open-loop in the remaining steps.
|
| 64 |
+
action_chunk = client.infer(observation)["actions"]
|
| 65 |
+
|
| 66 |
+
# Execute the actions in the environment.
|
| 67 |
+
...
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](../examples/simple_client/main.py).
|
capvector-pi05/examples/aloha_real/Dockerfile
ADDED
|
@@ -0,0 +1,70 @@
| 1 |
+
# Dockerfile for the Aloha real environment.
|
| 2 |
+
|
| 3 |
+
# Build the container:
|
| 4 |
+
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
|
| 5 |
+
|
| 6 |
+
# Run the container:
|
| 7 |
+
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
|
| 8 |
+
|
| 9 |
+
FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
|
| 10 |
+
SHELL ["/bin/bash", "-c"]
|
| 11 |
+
|
| 12 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 13 |
+
RUN apt-get update && \
|
| 14 |
+
apt-get install -y --no-install-recommends \
|
| 15 |
+
cmake \
|
| 16 |
+
curl \
|
| 17 |
+
libffi-dev \
|
| 18 |
+
python3-rosdep \
|
| 19 |
+
python3-rosinstall \
|
| 20 |
+
python3-rosinstall-generator \
|
| 21 |
+
whiptail \
|
| 22 |
+
git \
|
| 23 |
+
wget \
|
| 24 |
+
openssh-client \
|
| 25 |
+
ros-noetic-cv-bridge \
|
| 26 |
+
ros-noetic-usb-cam \
|
| 27 |
+
ros-noetic-realsense2-camera \
|
| 28 |
+
keyboard-configuration
|
| 29 |
+
|
| 30 |
+
WORKDIR /root
|
| 31 |
+
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
|
| 32 |
+
RUN chmod +x xsarm_amd64_install.sh
|
| 33 |
+
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
|
| 34 |
+
|
| 35 |
+
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
|
| 36 |
+
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
|
| 37 |
+
|
| 38 |
+
# Install python 3.10 because this ROS image comes with 3.8
|
| 39 |
+
RUN mkdir /python && \
|
| 40 |
+
cd /python && \
|
| 41 |
+
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
|
| 42 |
+
tar -zxvf Python-3.10.14.tgz && \
|
| 43 |
+
cd Python-3.10.14 && \
|
| 44 |
+
ls -lhR && \
|
| 45 |
+
./configure --enable-optimizations && \
|
| 46 |
+
make install && \
|
| 47 |
+
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
| 48 |
+
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
| 49 |
+
cd ~ && rm -rf /python && \
|
| 50 |
+
rm -rf /var/lib/apt/lists/*
|
| 51 |
+
|
| 52 |
+
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
|
| 53 |
+
ENV UV_HTTP_TIMEOUT=120
|
| 54 |
+
ENV UV_LINK_MODE=copy
|
| 55 |
+
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
|
| 56 |
+
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
| 57 |
+
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
| 58 |
+
|
| 59 |
+
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Create an entrypoint script to run the setup commands, followed by the command passed in.
|
| 63 |
+
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
|
| 64 |
+
#!/bin/bash
|
| 65 |
+
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
|
| 66 |
+
EOF
|
| 67 |
+
RUN chmod +x /usr/local/bin/entrypoint.sh
|
| 68 |
+
|
| 69 |
+
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
| 70 |
+
CMD ["python3", "/app/examples/aloha_real/main.py"]
|
capvector-pi05/examples/aloha_real/README.md
ADDED
|
@@ -0,0 +1,126 @@
| 1 |
+
# Run Aloha (Real Robot)
|
| 2 |
+
|
| 3 |
+
This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
|
| 8 |
+
|
| 9 |
+
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
|
| 10 |
+
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
|
| 11 |
+
|
| 12 |
+
## With Docker
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
|
| 16 |
+
docker compose -f examples/aloha_real/compose.yml up --build
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Without Docker
|
| 20 |
+
|
| 21 |
+
Terminal window 1:
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
# Create virtual environment
|
| 25 |
+
uv venv --python 3.10 examples/aloha_real/.venv
|
| 26 |
+
source examples/aloha_real/.venv/bin/activate
|
| 27 |
+
uv pip sync examples/aloha_real/requirements.txt
|
| 28 |
+
uv pip install -e packages/openpi-client
|
| 29 |
+
|
| 30 |
+
# Run the robot
|
| 31 |
+
python -m examples.aloha_real.main
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
Terminal window 2:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
roslaunch aloha ros_nodes.launch
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Terminal window 3:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## **ALOHA Checkpoint Guide**
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
|
| 50 |
+
|
| 51 |
+
While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot.
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
### **Toast Task**
|
| 57 |
+
|
| 58 |
+
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
|
| 59 |
+
|
| 60 |
+
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
|
| 61 |
+
- **Prompt**: "take the toast out of the toaster"
|
| 62 |
+
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
|
| 63 |
+
- **Object Distribution**:
|
| 64 |
+
- Works on both real toast and rubber fake toast
|
| 65 |
+
- Compatible with standard 2-slice toasters
|
| 66 |
+
- Works with plates of varying colors
|
| 67 |
+
|
| 68 |
+
### **Scene Setup Guidelines**
|
| 69 |
+
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
|
| 70 |
+
|
| 71 |
+
- The toaster should be positioned in the top-left quadrant of the workspace.
|
| 72 |
+
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
|
| 73 |
+
- The plate should be placed roughly in the lower-center of the workspace.
|
| 74 |
+
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
### **Towel Task**
|
| 78 |
+
|
| 79 |
+
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
|
| 80 |
+
|
| 81 |
+
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
|
| 82 |
+
- **Prompt**: "fold the towel"
|
| 83 |
+
- **Object Distribution**:
|
| 84 |
+
- Works on towels of varying solid colors
|
| 85 |
+
- Performance is worse on heavily textured or striped towels
|
| 86 |
+
|
| 87 |
+
### **Scene Setup Guidelines**
|
| 88 |
+
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
|
| 89 |
+
|
| 90 |
+
- The towel should be flattened and roughly centered on the table.
|
| 91 |
+
- Choose a towel that does not blend in with the table surface.
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
### **Tupperware Task**
|
| 95 |
+
|
| 96 |
+
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
|
| 97 |
+
|
| 98 |
+
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
|
| 99 |
+
- **Prompt**: "open the tupperware and put the food on the plate"
|
| 100 |
+
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
|
| 101 |
+
- **Object Distribution**:
|
| 102 |
+
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
|
| 103 |
+
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
|
| 104 |
+
- The policy has seen plates of varying solid colors.
|
| 105 |
+
|
| 106 |
+
### **Scene Setup Guidelines**
|
| 107 |
+
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
|
| 108 |
+
|
| 109 |
+
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
|
| 110 |
+
- Positioning:
|
| 111 |
+
- Tupperware should be on the left.
|
| 112 |
+
- Plate should be on the right or bottom.
|
| 113 |
+
- The tupperware flap should point toward the plate.
|
| 114 |
+
|
| 115 |
+
## Training on your own Aloha dataset
|
| 116 |
+
|
| 117 |
+
1. Convert the dataset to the LeRobot dataset v2.0 format.
|
| 118 |
+
|
| 119 |
+
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
2. Define a training config that uses the custom dataset.
|
| 123 |
+
|
| 124 |
+
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
|
| 125 |
+
|
| 126 |
+
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
|
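As a rough sketch of how the recommendation above translates into a training config, the snippet below reuses the `trossen` statistics from the `pi0_base` checkpoint. The config name and dataset repo id are placeholders, and the exact field names should be checked against `src/openpi/training/config.py` (see the `pi0_aloha_pen_uncap` config for a complete, real example).

```python
# Hypothetical fine-tuning config that reloads the ALOHA ("trossen") norm stats.
TrainConfig(
    name="pi0_my_aloha_task",  # placeholder config name
    data=LeRobotAlohaDataConfig(
        repo_id="<org>/<your-aloha-dataset>",  # placeholder dataset id
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
    ),
)
```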
capvector-pi05/examples/aloha_real/compose.yml
ADDED
|
@@ -0,0 +1,66 @@
| 1 |
+
# Run with:
|
| 2 |
+
# docker compose -f examples/aloha_real/compose.yml up --build
|
| 3 |
+
services:
|
| 4 |
+
runtime:
|
| 5 |
+
image: aloha_real
|
| 6 |
+
depends_on:
|
| 7 |
+
- aloha_ros_nodes
|
| 8 |
+
- ros_master
|
| 9 |
+
- openpi_server
|
| 10 |
+
build:
|
| 11 |
+
context: ../..
|
| 12 |
+
dockerfile: examples/aloha_real/Dockerfile
|
| 13 |
+
init: true
|
| 14 |
+
tty: true
|
| 15 |
+
network_mode: host
|
| 16 |
+
privileged: true
|
| 17 |
+
volumes:
|
| 18 |
+
- $PWD:/app
|
| 19 |
+
- ../../data:/data
|
| 20 |
+
|
| 21 |
+
aloha_ros_nodes:
|
| 22 |
+
image: aloha_real
|
| 23 |
+
depends_on:
|
| 24 |
+
- ros_master
|
| 25 |
+
build:
|
| 26 |
+
context: ../..
|
| 27 |
+
dockerfile: examples/aloha_real/Dockerfile
|
| 28 |
+
init: true
|
| 29 |
+
tty: true
|
| 30 |
+
network_mode: host
|
| 31 |
+
privileged: true
|
| 32 |
+
volumes:
|
| 33 |
+
- /dev:/dev
|
| 34 |
+
command: roslaunch --wait aloha ros_nodes.launch
|
| 35 |
+
|
| 36 |
+
ros_master:
|
| 37 |
+
image: ros:noetic-robot
|
| 38 |
+
network_mode: host
|
| 39 |
+
privileged: true
|
| 40 |
+
command:
|
| 41 |
+
- roscore
|
| 42 |
+
|
| 43 |
+
openpi_server:
|
| 44 |
+
image: openpi_server
|
| 45 |
+
build:
|
| 46 |
+
context: ../..
|
| 47 |
+
dockerfile: scripts/docker/serve_policy.Dockerfile
|
| 48 |
+
init: true
|
| 49 |
+
tty: true
|
| 50 |
+
network_mode: host
|
| 51 |
+
volumes:
|
| 52 |
+
- $PWD:/app
|
| 53 |
+
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
| 54 |
+
environment:
|
| 55 |
+
- SERVER_ARGS
|
| 56 |
+
- OPENPI_DATA_HOME=/openpi_assets
|
| 57 |
+
- IS_DOCKER=true
|
| 58 |
+
|
| 59 |
+
# Comment out this block if not running on a machine with GPUs.
|
| 60 |
+
deploy:
|
| 61 |
+
resources:
|
| 62 |
+
reservations:
|
| 63 |
+
devices:
|
| 64 |
+
- driver: nvidia
|
| 65 |
+
count: 1
|
| 66 |
+
capabilities: [gpu]
|
capvector-pi05/examples/aloha_real/constants.py
ADDED
|
@@ -0,0 +1,71 @@
| 1 |
+
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
| 2 |
+
# ruff: noqa
|
| 3 |
+
|
| 4 |
+
### Task parameters
|
| 5 |
+
|
| 6 |
+
### ALOHA fixed constants
|
| 7 |
+
DT = 0.001
|
| 8 |
+
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
| 9 |
+
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
| 10 |
+
|
| 11 |
+
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
| 12 |
+
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
| 13 |
+
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
| 14 |
+
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
| 15 |
+
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
| 16 |
+
|
| 17 |
+
# Gripper joint limits (qpos[6])
|
| 18 |
+
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
| 19 |
+
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
| 20 |
+
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
| 21 |
+
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
| 22 |
+
|
| 23 |
+
############################ Helper functions ############################
|
| 24 |
+
|
| 25 |
+
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
| 26 |
+
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
| 27 |
+
)
|
| 28 |
+
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
| 29 |
+
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
| 30 |
+
)
|
| 31 |
+
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
| 32 |
+
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
| 33 |
+
)
|
| 34 |
+
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
| 35 |
+
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
| 36 |
+
)
|
| 37 |
+
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
| 38 |
+
|
| 39 |
+
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
|
| 40 |
+
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
|
| 41 |
+
)
|
| 42 |
+
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
|
| 43 |
+
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
|
| 44 |
+
)
|
| 45 |
+
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
| 46 |
+
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
| 47 |
+
)
|
| 48 |
+
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
| 49 |
+
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
| 50 |
+
)
|
| 51 |
+
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
| 52 |
+
|
| 53 |
+
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
| 54 |
+
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
| 55 |
+
|
| 56 |
+
MASTER_POS2JOINT = (
|
| 57 |
+
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
| 58 |
+
+ MASTER_GRIPPER_JOINT_CLOSE
|
| 59 |
+
)
|
| 60 |
+
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
|
| 61 |
+
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
| 62 |
+
)
|
| 63 |
+
PUPPET_POS2JOINT = (
|
| 64 |
+
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
| 65 |
+
+ PUPPET_GRIPPER_JOINT_CLOSE
|
| 66 |
+
)
|
| 67 |
+
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
|
| 68 |
+
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
capvector-pi05/examples/aloha_real/convert_aloha_data_to_lerobot.py
ADDED
|
@@ -0,0 +1,263 @@
| 1 |
+
"""
|
| 2 |
+
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
|
| 3 |
+
|
| 4 |
+
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import dataclasses
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import shutil
|
| 10 |
+
from typing import Literal
|
| 11 |
+
|
| 12 |
+
import h5py
|
| 13 |
+
from lerobot.common.constants import HF_LEROBOT_HOME
|
| 14 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import tqdm
|
| 18 |
+
import tyro
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclasses.dataclass(frozen=True)
|
| 22 |
+
class DatasetConfig:
|
| 23 |
+
use_videos: bool = True
|
| 24 |
+
tolerance_s: float = 0.0001
|
| 25 |
+
image_writer_processes: int = 10
|
| 26 |
+
image_writer_threads: int = 5
|
| 27 |
+
video_backend: str | None = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
DEFAULT_DATASET_CONFIG = DatasetConfig()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def create_empty_dataset(
|
| 34 |
+
repo_id: str,
|
| 35 |
+
robot_type: str,
|
| 36 |
+
cameras: list[str],
|
| 37 |
+
mode: Literal["video", "image"] = "video",
|
| 38 |
+
*,
|
| 39 |
+
has_velocity: bool = False,
|
| 40 |
+
has_effort: bool = False,
|
| 41 |
+
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
| 42 |
+
) -> LeRobotDataset:
|
| 43 |
+
motors = [
|
| 44 |
+
"right_waist",
|
| 45 |
+
"right_shoulder",
|
| 46 |
+
"right_elbow",
|
| 47 |
+
"right_forearm_roll",
|
| 48 |
+
"right_wrist_angle",
|
| 49 |
+
"right_wrist_rotate",
|
| 50 |
+
"right_gripper",
|
| 51 |
+
"left_waist",
|
| 52 |
+
"left_shoulder",
|
| 53 |
+
"left_elbow",
|
| 54 |
+
"left_forearm_roll",
|
| 55 |
+
"left_wrist_angle",
|
| 56 |
+
"left_wrist_rotate",
|
| 57 |
+
"left_gripper",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
features = {
|
| 61 |
+
"observation.state": {
|
| 62 |
+
"dtype": "float32",
|
| 63 |
+
"shape": (len(motors),),
|
| 64 |
+
"names": [
|
| 65 |
+
motors,
|
| 66 |
+
],
|
| 67 |
+
},
|
| 68 |
+
"action": {
|
| 69 |
+
"dtype": "float32",
|
| 70 |
+
"shape": (len(motors),),
|
| 71 |
+
"names": [
|
| 72 |
+
motors,
|
| 73 |
+
],
|
| 74 |
+
},
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
if has_velocity:
|
| 78 |
+
features["observation.velocity"] = {
|
| 79 |
+
"dtype": "float32",
|
| 80 |
+
"shape": (len(motors),),
|
| 81 |
+
"names": [
|
| 82 |
+
motors,
|
| 83 |
+
],
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
if has_effort:
|
| 87 |
+
features["observation.effort"] = {
|
| 88 |
+
"dtype": "float32",
|
| 89 |
+
"shape": (len(motors),),
|
| 90 |
+
"names": [
|
| 91 |
+
motors,
|
| 92 |
+
],
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
for cam in cameras:
|
| 96 |
+
features[f"observation.images.{cam}"] = {
|
| 97 |
+
"dtype": mode,
|
| 98 |
+
"shape": (3, 480, 640),
|
| 99 |
+
"names": [
|
| 100 |
+
"channels",
|
| 101 |
+
"height",
|
| 102 |
+
"width",
|
| 103 |
+
],
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
if Path(HF_LEROBOT_HOME / repo_id).exists():
|
| 107 |
+
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
| 108 |
+
|
| 109 |
+
return LeRobotDataset.create(
|
| 110 |
+
repo_id=repo_id,
|
| 111 |
+
fps=50,
|
| 112 |
+
robot_type=robot_type,
|
| 113 |
+
features=features,
|
| 114 |
+
use_videos=dataset_config.use_videos,
|
| 115 |
+
tolerance_s=dataset_config.tolerance_s,
|
| 116 |
+
image_writer_processes=dataset_config.image_writer_processes,
|
| 117 |
+
image_writer_threads=dataset_config.image_writer_threads,
|
| 118 |
+
video_backend=dataset_config.video_backend,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_cameras(hdf5_files: list[Path]) -> list[str]:
|
| 123 |
+
with h5py.File(hdf5_files[0], "r") as ep:
|
| 124 |
+
# ignore depth channel, not currently handled
|
| 125 |
+
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def has_velocity(hdf5_files: list[Path]) -> bool:
|
| 129 |
+
with h5py.File(hdf5_files[0], "r") as ep:
|
| 130 |
+
return "/observations/qvel" in ep
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def has_effort(hdf5_files: list[Path]) -> bool:
|
| 134 |
+
with h5py.File(hdf5_files[0], "r") as ep:
|
| 135 |
+
return "/observations/effort" in ep
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
|
| 139 |
+
imgs_per_cam = {}
|
| 140 |
+
for camera in cameras:
|
| 141 |
+
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
|
| 142 |
+
|
| 143 |
+
if uncompressed:
|
| 144 |
+
# load all images in RAM
|
| 145 |
+
imgs_array = ep[f"/observations/images/{camera}"][:]
|
| 146 |
+
else:
|
| 147 |
+
import cv2
|
| 148 |
+
|
| 149 |
+
# load one compressed image after the other in RAM and uncompress
|
| 150 |
+
imgs_array = []
|
| 151 |
+
for data in ep[f"/observations/images/{camera}"]:
|
| 152 |
+
imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
|
| 153 |
+
imgs_array = np.array(imgs_array)
|
| 154 |
+
|
| 155 |
+
imgs_per_cam[camera] = imgs_array
|
| 156 |
+
return imgs_per_cam
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def load_raw_episode_data(
|
| 160 |
+
ep_path: Path,
|
| 161 |
+
cameras: list[str],
|
| 162 |
+
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
| 163 |
+
with h5py.File(ep_path, "r") as ep:
|
| 164 |
+
state = torch.from_numpy(ep["/observations/qpos"][:])
|
| 165 |
+
action = torch.from_numpy(ep["/action"][:])
|
| 166 |
+
|
| 167 |
+
velocity = None
|
| 168 |
+
if "/observations/qvel" in ep:
|
| 169 |
+
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
| 170 |
+
|
| 171 |
+
effort = None
|
| 172 |
+
if "/observations/effort" in ep:
|
| 173 |
+
effort = torch.from_numpy(ep["/observations/effort"][:])
|
| 174 |
+
|
| 175 |
+
imgs_per_cam = load_raw_images_per_camera(ep, cameras)
|
| 176 |
+
|
| 177 |
+
return imgs_per_cam, state, action, velocity, effort
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def populate_dataset(
|
| 181 |
+
dataset: LeRobotDataset,
|
| 182 |
+
hdf5_files: list[Path],
|
| 183 |
+
cameras: list[str],
|
| 184 |
+
task: str,
|
| 185 |
+
episodes: list[int] | None = None,
|
| 186 |
+
) -> LeRobotDataset:
|
| 187 |
+
if episodes is None:
|
| 188 |
+
episodes = range(len(hdf5_files))
|
| 189 |
+
|
| 190 |
+
for ep_idx in tqdm.tqdm(episodes):
|
| 191 |
+
ep_path = hdf5_files[ep_idx]
|
| 192 |
+
|
| 193 |
+
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path, cameras)
|
| 194 |
+
num_frames = state.shape[0]
|
| 195 |
+
|
| 196 |
+
for i in range(num_frames):
|
| 197 |
+
frame = {
|
| 198 |
+
"observation.state": state[i],
|
| 199 |
+
"action": action[i],
|
| 200 |
+
"task": task,
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
for camera, img_array in imgs_per_cam.items():
|
| 204 |
+
frame[f"observation.images.{camera}"] = img_array[i]
|
| 205 |
+
|
| 206 |
+
if velocity is not None:
|
| 207 |
+
frame["observation.velocity"] = velocity[i]
|
| 208 |
+
if effort is not None:
|
| 209 |
+
frame["observation.effort"] = effort[i]
|
| 210 |
+
|
| 211 |
+
dataset.add_frame(frame)
|
| 212 |
+
|
| 213 |
+
dataset.save_episode()
|
| 214 |
+
|
| 215 |
+
return dataset
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def port_aloha(
|
| 219 |
+
raw_dir: Path,
|
| 220 |
+
repo_id: str,
|
| 221 |
+
task: str = "DEBUG",
|
| 222 |
+
*,
|
| 223 |
+
episodes: list[int] | None = None,
|
| 224 |
+
push_to_hub: bool = False,
|
| 225 |
+
is_mobile: bool = False,
|
| 226 |
+
mode: Literal["video", "image"] = "image",
|
| 227 |
+
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
| 228 |
+
):
|
| 229 |
+
if (HF_LEROBOT_HOME / repo_id).exists():
|
| 230 |
+
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
|
| 231 |
+
|
| 232 |
+
if not raw_dir.exists():
|
| 233 |
+
raise ValueError(f"Raw directory {raw_dir} does not exist. Please provide a valid path to the raw data.")
|
| 234 |
+
|
| 235 |
+
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
| 236 |
+
|
| 237 |
+
# Get camera names from the first episode
|
| 238 |
+
cameras = get_cameras(hdf5_files)
|
| 239 |
+
print(f"Detected cameras: {cameras}")
|
| 240 |
+
|
| 241 |
+
dataset = create_empty_dataset(
|
| 242 |
+
repo_id,
|
| 243 |
+
robot_type="mobile_aloha" if is_mobile else "aloha",
|
| 244 |
+
cameras=cameras,
|
| 245 |
+
mode=mode,
|
| 246 |
+
has_effort=has_effort(hdf5_files),
|
| 247 |
+
has_velocity=has_velocity(hdf5_files),
|
| 248 |
+
dataset_config=dataset_config,
|
| 249 |
+
)
|
| 250 |
+
dataset = populate_dataset(
|
| 251 |
+
dataset,
|
| 252 |
+
hdf5_files,
|
| 253 |
+
cameras=cameras,
|
| 254 |
+
task=task,
|
| 255 |
+
episodes=episodes,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
if push_to_hub:
|
| 259 |
+
dataset.push_to_hub()
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
if __name__ == "__main__":
|
| 263 |
+
tyro.cli(port_aloha)
|
capvector-pi05/examples/aloha_real/env.py
ADDED
|
@@ -0,0 +1,57 @@
| 1 |
+
from typing import List, Optional # noqa: UP035
|
| 2 |
+
|
| 3 |
+
import einops
|
| 4 |
+
from openpi_client import image_tools
|
| 5 |
+
from openpi_client.runtime import environment as _environment
|
| 6 |
+
from typing_extensions import override
|
| 7 |
+
|
| 8 |
+
from examples.aloha_real import real_env as _real_env
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AlohaRealEnvironment(_environment.Environment):
|
| 12 |
+
"""An environment for an Aloha robot on real hardware."""
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
|
| 17 |
+
render_height: int = 224,
|
| 18 |
+
render_width: int = 224,
|
| 19 |
+
) -> None:
|
| 20 |
+
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
|
| 21 |
+
self._render_height = render_height
|
| 22 |
+
self._render_width = render_width
|
| 23 |
+
|
| 24 |
+
self._ts = None
|
| 25 |
+
|
| 26 |
+
@override
|
| 27 |
+
def reset(self) -> None:
|
| 28 |
+
self._ts = self._env.reset()
|
| 29 |
+
|
| 30 |
+
@override
|
| 31 |
+
def is_episode_complete(self) -> bool:
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
@override
|
| 35 |
+
def get_observation(self) -> dict:
|
| 36 |
+
if self._ts is None:
|
| 37 |
+
raise RuntimeError("Timestep is not set. Call reset() first.")
|
| 38 |
+
|
| 39 |
+
obs = self._ts.observation
|
| 40 |
+
for k in list(obs["images"].keys()):
|
| 41 |
+
if "_depth" in k:
|
| 42 |
+
del obs["images"][k]
|
| 43 |
+
|
| 44 |
+
for cam_name in obs["images"]:
|
| 45 |
+
img = image_tools.convert_to_uint8(
|
| 46 |
+
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
|
| 47 |
+
)
|
| 48 |
+
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
|
| 49 |
+
|
| 50 |
+
return {
|
| 51 |
+
"state": obs["qpos"],
|
| 52 |
+
"images": obs["images"],
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
@override
|
| 56 |
+
def apply_action(self, action: dict) -> None:
|
| 57 |
+
self._ts = self._env.step(action["actions"])
|
capvector-pi05/examples/aloha_real/main.py
ADDED
|
@@ -0,0 +1,51 @@
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from openpi_client import action_chunk_broker
|
| 5 |
+
from openpi_client import websocket_client_policy as _websocket_client_policy
|
| 6 |
+
from openpi_client.runtime import runtime as _runtime
|
| 7 |
+
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
| 8 |
+
import tyro
|
| 9 |
+
|
| 10 |
+
from examples.aloha_real import env as _env
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclasses.dataclass
|
| 14 |
+
class Args:
|
| 15 |
+
host: str = "0.0.0.0"
|
| 16 |
+
port: int = 8000
|
| 17 |
+
|
| 18 |
+
action_horizon: int = 25
|
| 19 |
+
|
| 20 |
+
num_episodes: int = 1
|
| 21 |
+
max_episode_steps: int = 1000
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main(args: Args) -> None:
|
| 25 |
+
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
|
| 26 |
+
host=args.host,
|
| 27 |
+
port=args.port,
|
| 28 |
+
)
|
| 29 |
+
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
|
| 30 |
+
|
| 31 |
+
metadata = ws_client_policy.get_server_metadata()
|
| 32 |
+
runtime = _runtime.Runtime(
|
| 33 |
+
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
|
| 34 |
+
agent=_policy_agent.PolicyAgent(
|
| 35 |
+
policy=action_chunk_broker.ActionChunkBroker(
|
| 36 |
+
policy=ws_client_policy,
|
| 37 |
+
action_horizon=args.action_horizon,
|
| 38 |
+
)
|
| 39 |
+
),
|
| 40 |
+
subscribers=[],
|
| 41 |
+
max_hz=50,
|
| 42 |
+
num_episodes=args.num_episodes,
|
| 43 |
+
max_episode_steps=args.max_episode_steps,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
runtime.run()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
logging.basicConfig(level=logging.INFO, force=True)
|
| 51 |
+
tyro.cli(main)
|
capvector-pi05/examples/aloha_real/real_env.py
ADDED
|
@@ -0,0 +1,176 @@
| 1 |
+
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
| 2 |
+
# ruff: noqa
|
| 3 |
+
import collections
|
| 4 |
+
import time
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
import dm_env
|
| 7 |
+
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
| 8 |
+
from interbotix_xs_msgs.msg import JointSingleCommand
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from examples.aloha_real import constants
|
| 12 |
+
from examples.aloha_real import robot_utils
|
| 13 |
+
|
| 14 |
+
# This is the reset position that is used by the standard Aloha runtime.
|
| 15 |
+
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RealEnv:
|
| 19 |
+
"""
|
| 20 |
+
Environment for real robot bi-manual manipulation
|
| 21 |
+
Action space: [left_arm_qpos (6), # absolute joint position
|
| 22 |
+
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
| 23 |
+
right_arm_qpos (6), # absolute joint position
|
| 24 |
+
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
| 25 |
+
|
| 26 |
+
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
| 27 |
+
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
| 28 |
+
right_arm_qpos (6), # absolute joint position
|
| 29 |
+
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
| 30 |
+
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
| 31 |
+
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
| 32 |
+
right_arm_qvel (6), # absolute joint velocity (rad)
|
| 33 |
+
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
| 34 |
+
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
| 35 |
+
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
| 36 |
+
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
| 37 |
+
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
|
| 41 |
+
# reset_position = START_ARM_POSE[:6]
|
| 42 |
+
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
|
| 43 |
+
|
| 44 |
+
self.puppet_bot_left = InterbotixManipulatorXS(
|
| 45 |
+
robot_model="vx300s",
|
| 46 |
+
group_name="arm",
|
| 47 |
+
gripper_name="gripper",
|
| 48 |
+
robot_name="puppet_left",
|
| 49 |
+
init_node=init_node,
|
| 50 |
+
)
|
| 51 |
+
self.puppet_bot_right = InterbotixManipulatorXS(
|
| 52 |
+
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
|
| 53 |
+
)
|
| 54 |
+
if setup_robots:
|
| 55 |
+
self.setup_robots()
|
| 56 |
+
|
| 57 |
+
self.recorder_left = robot_utils.Recorder("left", init_node=False)
|
| 58 |
+
self.recorder_right = robot_utils.Recorder("right", init_node=False)
|
| 59 |
+
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
|
| 60 |
+
self.gripper_command = JointSingleCommand(name="gripper")
|
| 61 |
+
|
| 62 |
+
def setup_robots(self):
|
| 63 |
+
robot_utils.setup_puppet_bot(self.puppet_bot_left)
|
| 64 |
+
robot_utils.setup_puppet_bot(self.puppet_bot_right)
|
| 65 |
+
|
| 66 |
+
def get_qpos(self):
|
| 67 |
+
left_qpos_raw = self.recorder_left.qpos
|
| 68 |
+
right_qpos_raw = self.recorder_right.qpos
|
| 69 |
+
left_arm_qpos = left_qpos_raw[:6]
|
| 70 |
+
right_arm_qpos = right_qpos_raw[:6]
|
| 71 |
+
left_gripper_qpos = [
|
| 72 |
+
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
|
| 73 |
+
] # this is position not joint
|
| 74 |
+
right_gripper_qpos = [
|
| 75 |
+
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
|
| 76 |
+
] # this is position not joint
|
| 77 |
+
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
| 78 |
+
|
| 79 |
+
def get_qvel(self):
|
| 80 |
+
left_qvel_raw = self.recorder_left.qvel
|
| 81 |
+
right_qvel_raw = self.recorder_right.qvel
|
| 82 |
+
left_arm_qvel = left_qvel_raw[:6]
|
| 83 |
+
right_arm_qvel = right_qvel_raw[:6]
|
| 84 |
+
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
|
| 85 |
+
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
|
| 86 |
+
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
| 87 |
+
|
| 88 |
+
def get_effort(self):
|
| 89 |
+
left_effort_raw = self.recorder_left.effort
|
| 90 |
+
right_effort_raw = self.recorder_right.effort
|
| 91 |
+
left_robot_effort = left_effort_raw[:7]
|
| 92 |
+
right_robot_effort = right_effort_raw[:7]
|
| 93 |
+
return np.concatenate([left_robot_effort, right_robot_effort])
|
| 94 |
+
|
| 95 |
+
def get_images(self):
|
| 96 |
+
return self.image_recorder.get_images()
|
| 97 |
+
|
| 98 |
+
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
|
| 99 |
+
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
|
| 100 |
+
self.gripper_command.cmd = left_gripper_desired_joint
|
| 101 |
+
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
|
| 102 |
+
|
| 103 |
+
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
|
| 104 |
+
right_gripper_desired_pos_normalized
|
| 105 |
+
)
|
| 106 |
+
self.gripper_command.cmd = right_gripper_desired_joint
|
| 107 |
+
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
|
| 108 |
+
|
| 109 |
+
def _reset_joints(self):
|
| 110 |
+
robot_utils.move_arms(
|
| 111 |
+
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def _reset_gripper(self):
|
| 115 |
+
"""Set to position mode and do position resets: first close then open. Then change back to PWM mode
|
| 116 |
+
|
| 117 |
+
NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
|
| 118 |
+
was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
|
| 119 |
+
increase the frequency of motor faults.
|
| 120 |
+
"""
|
| 121 |
+
robot_utils.move_grippers(
|
| 122 |
+
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
|
| 123 |
+
)
|
| 124 |
+
robot_utils.move_grippers(
|
| 125 |
+
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def get_observation(self):
|
| 129 |
+
obs = collections.OrderedDict()
|
| 130 |
+
obs["qpos"] = self.get_qpos()
|
| 131 |
+
obs["qvel"] = self.get_qvel()
|
| 132 |
+
obs["effort"] = self.get_effort()
|
| 133 |
+
obs["images"] = self.get_images()
|
| 134 |
+
return obs
|
| 135 |
+
|
| 136 |
+
def get_reward(self):
|
| 137 |
+
return 0
|
| 138 |
+
|
| 139 |
+
def reset(self, *, fake=False):
|
| 140 |
+
if not fake:
|
| 141 |
+
# Reboot puppet robot gripper motors
|
| 142 |
+
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
| 143 |
+
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
| 144 |
+
self._reset_joints()
|
| 145 |
+
self._reset_gripper()
|
| 146 |
+
return dm_env.TimeStep(
|
| 147 |
+
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def step(self, action):
|
| 151 |
+
state_len = int(len(action) / 2)
|
| 152 |
+
left_action = action[:state_len]
|
| 153 |
+
right_action = action[state_len:]
|
| 154 |
+
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
| 155 |
+
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
| 156 |
+
self.set_gripper_pose(left_action[-1], right_action[-1])
|
| 157 |
+
time.sleep(constants.DT)
|
| 158 |
+
return dm_env.TimeStep(
|
| 159 |
+
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_action(master_bot_left, master_bot_right):
|
| 164 |
+
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
| 165 |
+
# Arm actions
|
| 166 |
+
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
| 167 |
+
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
|
| 168 |
+
# Gripper actions
|
| 169 |
+
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
| 170 |
+
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
| 171 |
+
|
| 172 |
+
return action
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
|
| 176 |
+
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
|
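A short usage sketch of the interface above (it only illustrates the 14-dimensional action layout described in the `RealEnv` docstring; actually running it requires the Aloha hardware, ROS, and the Interbotix stack):

```python
import numpy as np

from examples.aloha_real import real_env as _real_env

env = _real_env.make_real_env(init_node=True)
ts = env.reset()

# Action layout: [left arm qpos (6), left gripper (1), right arm qpos (6), right gripper (1)].
# Holding the current joint positions is the simplest valid action.
action = np.asarray(ts.observation["qpos"], dtype=np.float64)
ts = env.step(action)
print(list(ts.observation["images"].keys()))
```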
capvector-pi05/examples/aloha_real/requirements.in
ADDED
|
@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy>=1.22.4,<2.0.0
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets
capvector-pi05/examples/aloha_real/requirements.txt
ADDED
|
@@ -0,0 +1,156 @@
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
|
| 3 |
+
absl-py==2.1.0
|
| 4 |
+
# via
|
| 5 |
+
# dm-control
|
| 6 |
+
# dm-env
|
| 7 |
+
# labmaze
|
| 8 |
+
# mujoco
|
| 9 |
+
catkin-pkg==1.0.0
|
| 10 |
+
# via rospkg
|
| 11 |
+
certifi==2024.8.30
|
| 12 |
+
# via requests
|
| 13 |
+
charset-normalizer==3.4.0
|
| 14 |
+
# via requests
|
| 15 |
+
contourpy==1.1.1
|
| 16 |
+
# via matplotlib
|
| 17 |
+
cycler==0.12.1
|
| 18 |
+
# via matplotlib
|
| 19 |
+
distro==1.9.0
|
| 20 |
+
# via rospkg
|
| 21 |
+
dm-control==1.0.23
|
| 22 |
+
# via -r examples/aloha_real/requirements.in
|
| 23 |
+
dm-env==1.6
|
| 24 |
+
# via dm-control
|
| 25 |
+
dm-tree==0.1.8
|
| 26 |
+
# via
|
| 27 |
+
# dm-control
|
| 28 |
+
# dm-env
|
| 29 |
+
docstring-parser==0.16
|
| 30 |
+
# via tyro
|
| 31 |
+
docutils==0.20.1
|
| 32 |
+
# via catkin-pkg
|
| 33 |
+
einops==0.8.0
|
| 34 |
+
# via -r examples/aloha_real/requirements.in
|
| 35 |
+
etils==1.3.0
|
| 36 |
+
# via mujoco
|
| 37 |
+
fonttools==4.55.2
|
| 38 |
+
# via matplotlib
|
| 39 |
+
glfw==2.8.0
|
| 40 |
+
# via
|
| 41 |
+
# dm-control
|
| 42 |
+
# mujoco
|
| 43 |
+
h5py==3.11.0
|
| 44 |
+
# via -r examples/aloha_real/requirements.in
|
| 45 |
+
idna==3.10
|
| 46 |
+
# via requests
|
| 47 |
+
importlib-resources==6.4.5
|
| 48 |
+
# via etils
|
| 49 |
+
kiwisolver==1.4.7
|
| 50 |
+
# via matplotlib
|
| 51 |
+
labmaze==1.0.6
|
| 52 |
+
# via dm-control
|
| 53 |
+
lxml==5.3.0
|
| 54 |
+
# via dm-control
|
| 55 |
+
markdown-it-py==3.0.0
|
| 56 |
+
# via rich
|
| 57 |
+
matplotlib==3.7.5
|
| 58 |
+
# via -r examples/aloha_real/requirements.in
|
| 59 |
+
mdurl==0.1.2
|
| 60 |
+
# via markdown-it-py
|
| 61 |
+
modern-robotics==1.1.1
|
| 62 |
+
# via -r examples/aloha_real/requirements.in
|
| 63 |
+
msgpack==1.1.0
|
| 64 |
+
# via -r examples/aloha_real/requirements.in
|
| 65 |
+
mujoco==3.2.3
|
| 66 |
+
# via dm-control
|
| 67 |
+
numpy==1.24.4
|
| 68 |
+
# via
|
| 69 |
+
# -r examples/aloha_real/requirements.in
|
| 70 |
+
# contourpy
|
| 71 |
+
# dm-control
|
| 72 |
+
# dm-env
|
| 73 |
+
# h5py
|
| 74 |
+
# labmaze
|
| 75 |
+
# matplotlib
|
| 76 |
+
# modern-robotics
|
| 77 |
+
# mujoco
|
| 78 |
+
# opencv-python
|
| 79 |
+
# pyquaternion
|
| 80 |
+
# scipy
|
| 81 |
+
opencv-python==4.10.0.84
|
| 82 |
+
# via -r examples/aloha_real/requirements.in
|
| 83 |
+
packaging==24.2
|
| 84 |
+
# via
|
| 85 |
+
# -r examples/aloha_real/requirements.in
|
| 86 |
+
# matplotlib
|
| 87 |
+
pexpect==4.9.0
|
| 88 |
+
# via -r examples/aloha_real/requirements.in
|
| 89 |
+
pillow==10.4.0
|
| 90 |
+
# via
|
| 91 |
+
# -r examples/aloha_real/requirements.in
|
| 92 |
+
# matplotlib
|
| 93 |
+
protobuf==5.29.1
|
| 94 |
+
# via dm-control
|
| 95 |
+
ptyprocess==0.7.0
|
| 96 |
+
# via pexpect
|
| 97 |
+
pygments==2.18.0
|
| 98 |
+
# via rich
|
| 99 |
+
pyopengl==3.1.7
|
| 100 |
+
# via
|
| 101 |
+
# dm-control
|
| 102 |
+
# mujoco
|
| 103 |
+
pyparsing==3.1.4
|
| 104 |
+
# via
|
| 105 |
+
# catkin-pkg
|
| 106 |
+
# dm-control
|
| 107 |
+
# matplotlib
|
| 108 |
+
pyquaternion==0.9.9
|
| 109 |
+
# via -r examples/aloha_real/requirements.in
|
| 110 |
+
pyrealsense2==2.55.1.6486
|
| 111 |
+
# via -r examples/aloha_real/requirements.in
|
| 112 |
+
python-dateutil==2.9.0.post0
|
| 113 |
+
# via
|
| 114 |
+
# catkin-pkg
|
| 115 |
+
# matplotlib
|
| 116 |
+
pyyaml==6.0.2
|
| 117 |
+
# via
|
| 118 |
+
# -r examples/aloha_real/requirements.in
|
| 119 |
+
# rospkg
|
| 120 |
+
requests==2.32.3
|
| 121 |
+
# via
|
| 122 |
+
# -r examples/aloha_real/requirements.in
|
| 123 |
+
# dm-control
|
| 124 |
+
rich==13.9.4
|
| 125 |
+
# via tyro
|
| 126 |
+
rospkg==1.5.1
|
| 127 |
+
# via -r examples/aloha_real/requirements.in
|
| 128 |
+
scipy==1.10.1
|
| 129 |
+
# via dm-control
|
| 130 |
+
setuptools==75.3.0
|
| 131 |
+
# via
|
| 132 |
+
# catkin-pkg
|
| 133 |
+
# dm-control
|
| 134 |
+
# labmaze
|
| 135 |
+
shtab==1.7.1
|
| 136 |
+
# via tyro
|
| 137 |
+
six==1.17.0
|
| 138 |
+
# via python-dateutil
|
| 139 |
+
tqdm==4.67.1
|
| 140 |
+
# via dm-control
|
| 141 |
+
typeguard==4.4.0
|
| 142 |
+
# via tyro
|
| 143 |
+
typing-extensions==4.12.2
|
| 144 |
+
# via
|
| 145 |
+
# etils
|
| 146 |
+
# rich
|
| 147 |
+
# typeguard
|
| 148 |
+
# tyro
|
| 149 |
+
tyro==0.9.2
|
| 150 |
+
# via -r examples/aloha_real/requirements.in
|
| 151 |
+
urllib3==2.2.3
|
| 152 |
+
# via requests
|
| 153 |
+
websockets==14.1
|
| 154 |
+
# via -r examples/aloha_real/requirements.in
|
| 155 |
+
zipp==3.20.2
|
| 156 |
+
# via etils
|
capvector-pi05/examples/aloha_real/robot_utils.py
ADDED
|
@@ -0,0 +1,275 @@
| 1 |
+
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
| 2 |
+
# ruff: noqa
|
| 3 |
+
from collections import deque
|
| 4 |
+
import datetime
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
from aloha.msg import RGBGrayscaleImage
|
| 9 |
+
from cv_bridge import CvBridge
|
| 10 |
+
from interbotix_xs_msgs.msg import JointGroupCommand
|
| 11 |
+
from interbotix_xs_msgs.msg import JointSingleCommand
|
| 12 |
+
import numpy as np
|
| 13 |
+
import rospy
|
| 14 |
+
from sensor_msgs.msg import JointState
|
| 15 |
+
|
| 16 |
+
from examples.aloha_real import constants
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ImageRecorder:
|
| 20 |
+
def __init__(self, init_node=True, is_debug=False):
|
| 21 |
+
self.is_debug = is_debug
|
| 22 |
+
self.bridge = CvBridge()
|
| 23 |
+
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
| 24 |
+
|
| 25 |
+
if init_node:
|
| 26 |
+
rospy.init_node("image_recorder", anonymous=True)
|
| 27 |
+
for cam_name in self.camera_names:
|
| 28 |
+
setattr(self, f"{cam_name}_rgb_image", None)
|
| 29 |
+
setattr(self, f"{cam_name}_depth_image", None)
|
| 30 |
+
setattr(self, f"{cam_name}_timestamp", 0.0)
|
| 31 |
+
if cam_name == "cam_high":
|
| 32 |
+
callback_func = self.image_cb_cam_high
|
| 33 |
+
elif cam_name == "cam_low":
|
| 34 |
+
callback_func = self.image_cb_cam_low
|
| 35 |
+
elif cam_name == "cam_left_wrist":
|
| 36 |
+
callback_func = self.image_cb_cam_left_wrist
|
| 37 |
+
elif cam_name == "cam_right_wrist":
|
| 38 |
+
callback_func = self.image_cb_cam_right_wrist
|
| 39 |
+
else:
|
| 40 |
+
raise NotImplementedError
|
| 41 |
+
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
|
| 42 |
+
if self.is_debug:
|
| 43 |
+
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
|
| 44 |
+
|
| 45 |
+
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
|
| 46 |
+
time.sleep(0.5)
|
| 47 |
+
|
| 48 |
+
def image_cb(self, cam_name, data):
|
| 49 |
+
setattr(
|
| 50 |
+
self,
|
| 51 |
+
f"{cam_name}_rgb_image",
|
| 52 |
+
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
|
| 53 |
+
)
|
| 54 |
+
# setattr(
|
| 55 |
+
# self,
|
| 56 |
+
# f"{cam_name}_depth_image",
|
| 57 |
+
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
|
| 58 |
+
# )
|
| 59 |
+
setattr(
|
| 60 |
+
self,
|
| 61 |
+
f"{cam_name}_timestamp",
|
| 62 |
+
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
|
| 63 |
+
)
|
| 64 |
+
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
|
| 65 |
+
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
|
| 66 |
+
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
|
| 67 |
+
if self.is_debug:
|
| 68 |
+
getattr(self, f"{cam_name}_timestamps").append(
|
| 69 |
+
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def image_cb_cam_high(self, data):
|
| 73 |
+
cam_name = "cam_high"
|
| 74 |
+
return self.image_cb(cam_name, data)
|
| 75 |
+
|
| 76 |
+
def image_cb_cam_low(self, data):
|
| 77 |
+
cam_name = "cam_low"
|
| 78 |
+
return self.image_cb(cam_name, data)
|
| 79 |
+
|
| 80 |
+
def image_cb_cam_left_wrist(self, data):
|
| 81 |
+
cam_name = "cam_left_wrist"
|
| 82 |
+
return self.image_cb(cam_name, data)
|
| 83 |
+
|
| 84 |
+
def image_cb_cam_right_wrist(self, data):
|
| 85 |
+
cam_name = "cam_right_wrist"
|
| 86 |
+
return self.image_cb(cam_name, data)
|
| 87 |
+
|
| 88 |
+
def get_images(self):
|
| 89 |
+
image_dict = {}
|
| 90 |
+
for cam_name in self.camera_names:
|
| 91 |
+
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
|
| 92 |
+
time.sleep(0.00001)
|
| 93 |
+
rgb_image = getattr(self, f"{cam_name}_rgb_image")
|
| 94 |
+
depth_image = getattr(self, f"{cam_name}_depth_image")
|
| 95 |
+
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
|
| 96 |
+
image_dict[cam_name] = rgb_image
|
| 97 |
+
image_dict[f"{cam_name}_depth"] = depth_image
|
| 98 |
+
return image_dict
|
| 99 |
+
|
| 100 |
+
def print_diagnostics(self):
|
| 101 |
+
def dt_helper(l):
|
| 102 |
+
l = np.array(l)
|
| 103 |
+
diff = l[1:] - l[:-1]
|
| 104 |
+
return np.mean(diff)
|
| 105 |
+
|
| 106 |
+
for cam_name in self.camera_names:
|
| 107 |
+
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
|
| 108 |
+
print(f"{cam_name} {image_freq=:.2f}")
|
| 109 |
+
print()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Recorder:
|
| 113 |
+
def __init__(self, side, init_node=True, is_debug=False):
|
| 114 |
+
self.secs = None
|
| 115 |
+
self.nsecs = None
|
| 116 |
+
self.qpos = None
|
| 117 |
+
self.effort = None
|
| 118 |
+
self.arm_command = None
|
| 119 |
+
self.gripper_command = None
|
| 120 |
+
self.is_debug = is_debug
|
| 121 |
+
|
| 122 |
+
if init_node:
|
| 123 |
+
rospy.init_node("recorder", anonymous=True)
|
| 124 |
+
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
| 125 |
+
rospy.Subscriber(
|
| 126 |
+
f"/puppet_{side}/commands/joint_group",
|
| 127 |
+
JointGroupCommand,
|
| 128 |
+
self.puppet_arm_commands_cb,
|
| 129 |
+
)
|
| 130 |
+
rospy.Subscriber(
|
| 131 |
+
f"/puppet_{side}/commands/joint_single",
|
| 132 |
+
JointSingleCommand,
|
| 133 |
+
self.puppet_gripper_commands_cb,
|
| 134 |
+
)
|
| 135 |
+
if self.is_debug:
|
| 136 |
+
self.joint_timestamps = deque(maxlen=50)
|
| 137 |
+
self.arm_command_timestamps = deque(maxlen=50)
|
| 138 |
+
self.gripper_command_timestamps = deque(maxlen=50)
|
| 139 |
+
time.sleep(0.1)
|
| 140 |
+
|
| 141 |
+
def puppet_state_cb(self, data):
|
| 142 |
+
self.qpos = data.position
|
| 143 |
+
self.qvel = data.velocity
|
| 144 |
+
self.effort = data.effort
|
| 145 |
+
self.data = data
|
| 146 |
+
if self.is_debug:
|
| 147 |
+
self.joint_timestamps.append(time.time())
|
| 148 |
+
|
| 149 |
+
def puppet_arm_commands_cb(self, data):
|
| 150 |
+
self.arm_command = data.cmd
|
| 151 |
+
if self.is_debug:
|
| 152 |
+
self.arm_command_timestamps.append(time.time())
|
| 153 |
+
|
| 154 |
+
def puppet_gripper_commands_cb(self, data):
|
| 155 |
+
self.gripper_command = data.cmd
|
| 156 |
+
if self.is_debug:
|
| 157 |
+
self.gripper_command_timestamps.append(time.time())
|
| 158 |
+
|
| 159 |
+
def print_diagnostics(self):
|
| 160 |
+
def dt_helper(l):
|
| 161 |
+
l = np.array(l)
|
| 162 |
+
diff = l[1:] - l[:-1]
|
| 163 |
+
return np.mean(diff)
|
| 164 |
+
|
| 165 |
+
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
| 166 |
+
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
| 167 |
+
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
| 168 |
+
|
| 169 |
+
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_arm_joint_positions(bot):
|
| 173 |
+
return bot.arm.core.joint_states.position[:6]
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_arm_gripper_positions(bot):
|
| 177 |
+
return bot.gripper.core.joint_states.position[6]
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def move_arms(bot_list, target_pose_list, move_time=1):
|
| 181 |
+
num_steps = int(move_time / constants.DT)
|
| 182 |
+
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
| 183 |
+
traj_list = [
|
| 184 |
+
np.linspace(curr_pose, target_pose, num_steps)
|
| 185 |
+
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
| 186 |
+
]
|
| 187 |
+
for t in range(num_steps):
|
| 188 |
+
for bot_id, bot in enumerate(bot_list):
|
| 189 |
+
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
| 190 |
+
time.sleep(constants.DT)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def move_grippers(bot_list, target_pose_list, move_time):
|
| 194 |
+
print(f"Moving grippers to {target_pose_list=}")
|
| 195 |
+
gripper_command = JointSingleCommand(name="gripper")
|
| 196 |
+
num_steps = int(move_time / constants.DT)
|
| 197 |
+
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
| 198 |
+
traj_list = [
|
| 199 |
+
np.linspace(curr_pose, target_pose, num_steps)
|
| 200 |
+
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
| 201 |
+
]
|
| 202 |
+
|
| 203 |
+
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
| 204 |
+
for t in range(num_steps):
|
| 205 |
+
d = {}
|
| 206 |
+
for bot_id, bot in enumerate(bot_list):
|
| 207 |
+
gripper_command.cmd = traj_list[bot_id][t]
|
| 208 |
+
bot.gripper.core.pub_single.publish(gripper_command)
|
| 209 |
+
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
| 210 |
+
f.write(json.dumps(d) + "\n")
|
| 211 |
+
time.sleep(constants.DT)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def setup_puppet_bot(bot):
|
| 215 |
+
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
| 216 |
+
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
| 217 |
+
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
| 218 |
+
torque_on(bot)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def setup_master_bot(bot):
|
| 222 |
+
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
| 223 |
+
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
| 224 |
+
torque_off(bot)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def set_standard_pid_gains(bot):
|
| 228 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
| 229 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def set_low_pid_gains(bot):
|
| 233 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
| 234 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def torque_off(bot):
|
| 238 |
+
bot.dxl.robot_torque_enable("group", "arm", False)
|
| 239 |
+
bot.dxl.robot_torque_enable("single", "gripper", False)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def torque_on(bot):
|
| 243 |
+
bot.dxl.robot_torque_enable("group", "arm", True)
|
| 244 |
+
bot.dxl.robot_torque_enable("single", "gripper", True)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# for DAgger
|
| 248 |
+
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
| 249 |
+
print("\nSyncing!")
|
| 250 |
+
|
| 251 |
+
# activate master arms
|
| 252 |
+
torque_on(master_bot_left)
|
| 253 |
+
torque_on(master_bot_right)
|
| 254 |
+
|
| 255 |
+
# get puppet arm positions
|
| 256 |
+
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
| 257 |
+
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
| 258 |
+
|
| 259 |
+
# get puppet gripper positions
|
| 260 |
+
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
| 261 |
+
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
| 262 |
+
|
| 263 |
+
# move master arms to puppet positions
|
| 264 |
+
move_arms(
|
| 265 |
+
[master_bot_left, master_bot_right],
|
| 266 |
+
[puppet_left_qpos, puppet_right_qpos],
|
| 267 |
+
move_time=1,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# move master grippers to puppet positions
|
| 271 |
+
move_grippers(
|
| 272 |
+
[master_bot_left, master_bot_right],
|
| 273 |
+
[puppet_left_gripper, puppet_right_gripper],
|
| 274 |
+
move_time=1,
|
| 275 |
+
)
|
capvector-pi05/examples/aloha_real/video_display.py
ADDED
|
@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoDisplay(_subscriber.Subscriber):
    """Displays video frames."""

    def __init__(self) -> None:
        self._ax: plt.Axes | None = None
        self._plt_img: plt.Image | None = None

    @override
    def on_episode_start(self) -> None:
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        assert self._ax is not None

        im = observation["image"][0]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]

        if self._plt_img is None:
            self._plt_img = self._ax.imshow(im)
        else:
            self._plt_img.set_data(im)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        plt.ioff()
        plt.close()
capvector-pi05/examples/aloha_sim/Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
# Dockerfile for the Aloha simulation environment.

# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash

FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

RUN apt-get update && \
    apt-get install -y \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev
ENV MUJOCO_GL=egl

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
capvector-pi05/examples/aloha_sim/README.md
ADDED
|
@@ -0,0 +1,36 @@
# Run Aloha Sim

## With Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client

# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```

Note: If you are seeing EGL errors, you may need to install the following dependencies:

```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```
capvector-pi05/examples/aloha_sim/compose.yml
ADDED
|
@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
  runtime:
    image: aloha_sim
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_sim/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
capvector-pi05/examples/aloha_sim/env.py
ADDED
|
@@ -0,0 +1,56 @@
import gym_aloha  # noqa: F401
import gymnasium
import numpy as np
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override


class AlohaSimEnvironment(_environment.Environment):
    """An environment for an Aloha robot in simulation."""

    def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        self._gym = gymnasium.make(task, obs_type=obs_type)

        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def is_episode_complete(self) -> bool:
        return self._done

    @override
    def get_observation(self) -> dict:
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = terminated or truncated
        self._episode_reward = max(self._episode_reward, reward)

    def _convert_observation(self, gym_obs: dict) -> dict:
        img = gym_obs["pixels"]["top"]
        img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
        # Convert axis order from [H, W, C] --> [C, H, W]
        img = np.transpose(img, (2, 0, 1))

        return {
            "state": gym_obs["agent_pos"],
            "images": {"cam_high": img},
        }
capvector-pi05/examples/aloha_sim/main.py
ADDED
|
@@ -0,0 +1,55 @@
import dataclasses
import logging
import pathlib

import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro


@dataclasses.dataclass
class Args:
    out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")

    task: str = "gym_aloha/AlohaTransferCube-v0"
    seed: int = 0

    action_horizon: int = 10

    host: str = "0.0.0.0"
    port: int = 8000

    display: bool = False


def main(args: Args) -> None:
    runtime = _runtime.Runtime(
        environment=_env.AlohaSimEnvironment(
            task=args.task,
            seed=args.seed,
        ),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=_websocket_client_policy.WebsocketClientPolicy(
                    host=args.host,
                    port=args.port,
                ),
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[
            _saver.VideoSaver(args.out_dir),
        ],
        max_hz=50,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)
capvector-pi05/examples/aloha_sim/requirements.in
ADDED
|
@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy>=1.22.4,<2.0.0
typing-extensions
tyro
websockets
capvector-pi05/examples/aloha_sim/requirements.txt
ADDED
|
@@ -0,0 +1,132 @@
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
|
| 3 |
+
absl-py==2.1.0
|
| 4 |
+
# via
|
| 5 |
+
# dm-control
|
| 6 |
+
# dm-env
|
| 7 |
+
# labmaze
|
| 8 |
+
# mujoco
|
| 9 |
+
certifi==2024.8.30
|
| 10 |
+
# via requests
|
| 11 |
+
charset-normalizer==3.4.0
|
| 12 |
+
# via requests
|
| 13 |
+
cloudpickle==3.1.0
|
| 14 |
+
# via gymnasium
|
| 15 |
+
contourpy==1.3.1
|
| 16 |
+
# via matplotlib
|
| 17 |
+
cycler==0.12.1
|
| 18 |
+
# via matplotlib
|
| 19 |
+
dm-control==1.0.14
|
| 20 |
+
# via gym-aloha
|
| 21 |
+
dm-env==1.6
|
| 22 |
+
# via dm-control
|
| 23 |
+
dm-tree==0.1.8
|
| 24 |
+
# via
|
| 25 |
+
# dm-control
|
| 26 |
+
# dm-env
|
| 27 |
+
docstring-parser==0.16
|
| 28 |
+
# via tyro
|
| 29 |
+
farama-notifications==0.0.4
|
| 30 |
+
# via gymnasium
|
| 31 |
+
fonttools==4.55.2
|
| 32 |
+
# via matplotlib
|
| 33 |
+
glfw==2.8.0
|
| 34 |
+
# via
|
| 35 |
+
# dm-control
|
| 36 |
+
# mujoco
|
| 37 |
+
gym-aloha==0.1.1
|
| 38 |
+
# via -r examples/aloha_sim/requirements.in
|
| 39 |
+
gymnasium==1.0.0
|
| 40 |
+
# via gym-aloha
|
| 41 |
+
idna==3.10
|
| 42 |
+
# via requests
|
| 43 |
+
imageio==2.36.1
|
| 44 |
+
# via
|
| 45 |
+
# -r examples/aloha_sim/requirements.in
|
| 46 |
+
# gym-aloha
|
| 47 |
+
imageio-ffmpeg==0.5.1
|
| 48 |
+
# via imageio
|
| 49 |
+
kiwisolver==1.4.7
|
| 50 |
+
# via matplotlib
|
| 51 |
+
labmaze==1.0.6
|
| 52 |
+
# via dm-control
|
| 53 |
+
lxml==5.3.0
|
| 54 |
+
# via dm-control
|
| 55 |
+
markdown-it-py==3.0.0
|
| 56 |
+
# via rich
|
| 57 |
+
matplotlib==3.9.3
|
| 58 |
+
# via -r examples/aloha_sim/requirements.in
|
| 59 |
+
mdurl==0.1.2
|
| 60 |
+
# via markdown-it-py
|
| 61 |
+
msgpack==1.1.0
|
| 62 |
+
# via -r examples/aloha_sim/requirements.in
|
| 63 |
+
mujoco==2.3.7
|
| 64 |
+
# via
|
| 65 |
+
# dm-control
|
| 66 |
+
# gym-aloha
|
| 67 |
+
numpy==1.26.4
|
| 68 |
+
# via
|
| 69 |
+
# -r examples/aloha_sim/requirements.in
|
| 70 |
+
# contourpy
|
| 71 |
+
# dm-control
|
| 72 |
+
# dm-env
|
| 73 |
+
# gymnasium
|
| 74 |
+
# imageio
|
| 75 |
+
# labmaze
|
| 76 |
+
# matplotlib
|
| 77 |
+
# mujoco
|
| 78 |
+
# scipy
|
| 79 |
+
packaging==24.2
|
| 80 |
+
# via matplotlib
|
| 81 |
+
pillow==11.0.0
|
| 82 |
+
# via
|
| 83 |
+
# imageio
|
| 84 |
+
# matplotlib
|
| 85 |
+
protobuf==5.29.1
|
| 86 |
+
# via dm-control
|
| 87 |
+
psutil==6.1.0
|
| 88 |
+
# via imageio
|
| 89 |
+
pygments==2.18.0
|
| 90 |
+
# via rich
|
| 91 |
+
pyopengl==3.1.7
|
| 92 |
+
# via
|
| 93 |
+
# dm-control
|
| 94 |
+
# mujoco
|
| 95 |
+
pyparsing==3.2.0
|
| 96 |
+
# via
|
| 97 |
+
# dm-control
|
| 98 |
+
# matplotlib
|
| 99 |
+
python-dateutil==2.9.0.post0
|
| 100 |
+
# via matplotlib
|
| 101 |
+
requests==2.32.3
|
| 102 |
+
# via dm-control
|
| 103 |
+
rich==13.9.4
|
| 104 |
+
# via tyro
|
| 105 |
+
scipy==1.14.1
|
| 106 |
+
# via dm-control
|
| 107 |
+
setuptools==75.6.0
|
| 108 |
+
# via
|
| 109 |
+
# dm-control
|
| 110 |
+
# imageio-ffmpeg
|
| 111 |
+
# labmaze
|
| 112 |
+
shtab==1.7.1
|
| 113 |
+
# via tyro
|
| 114 |
+
six==1.17.0
|
| 115 |
+
# via python-dateutil
|
| 116 |
+
tqdm==4.67.1
|
| 117 |
+
# via dm-control
|
| 118 |
+
typeguard==4.4.1
|
| 119 |
+
# via tyro
|
| 120 |
+
typing-extensions==4.12.2
|
| 121 |
+
# via
|
| 122 |
+
# -r examples/aloha_sim/requirements.in
|
| 123 |
+
# gymnasium
|
| 124 |
+
# rich
|
| 125 |
+
# typeguard
|
| 126 |
+
# tyro
|
| 127 |
+
tyro==0.9.2
|
| 128 |
+
# via -r examples/aloha_sim/requirements.in
|
| 129 |
+
urllib3==2.2.3
|
| 130 |
+
# via requests
|
| 131 |
+
websockets==14.1
|
| 132 |
+
# via -r examples/aloha_sim/requirements.in
|
capvector-pi05/examples/aloha_sim/saver.py
ADDED
|
@@ -0,0 +1,40 @@
import logging
import pathlib

import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoSaver(_subscriber.Subscriber):
    """Saves episode data."""

    def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
        out_dir.mkdir(parents=True, exist_ok=True)
        self._out_dir = out_dir
        self._images: list[np.ndarray] = []
        self._subsample = subsample

    @override
    def on_episode_start(self) -> None:
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        im = observation["images"]["cam_high"]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]
        self._images.append(im)

    @override
    def on_episode_end(self) -> None:
        existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
        next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
        out_path = self._out_dir / f"out_{next_idx}.mp4"

        logging.info(f"Saving video to {out_path}")
        imageio.mimwrite(
            out_path,
            [np.asarray(x) for x in self._images[:: self._subsample]],
            fps=50 // max(1, self._subsample),
        )
capvector-pi05/examples/convert_jax_model_to_pytorch.py
ADDED
|
@@ -0,0 +1,587 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
|
| 4 |
+
|
| 5 |
+
This script loads a JAX model checkpoint using orbax and can either:
|
| 6 |
+
1. Print out all the parameter keys in a hierarchical structure for inspection
|
| 7 |
+
2. Convert the JAX model to PyTorch format using our PI0Pytorch model
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
# Just inspect keys:
|
| 11 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 12 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 13 |
+
|
| 14 |
+
# Convert to PyTorch:
|
| 15 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 16 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 17 |
+
|
| 18 |
+
Example:
|
| 19 |
+
# pi0_droid
|
| 20 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
|
| 21 |
+
|
| 22 |
+
# pi0_aloha_sim
|
| 23 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
| 24 |
+
|
| 25 |
+
# pi05_droid
|
| 26 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import json
|
| 30 |
+
import os
|
| 31 |
+
import pathlib
|
| 32 |
+
import shutil
|
| 33 |
+
from typing import Literal
|
| 34 |
+
|
| 35 |
+
from flax.nnx import traversals
|
| 36 |
+
import numpy as np
|
| 37 |
+
import orbax.checkpoint as ocp
|
| 38 |
+
import safetensors
|
| 39 |
+
import torch
|
| 40 |
+
import tyro
|
| 41 |
+
|
| 42 |
+
import openpi.models.gemma
|
| 43 |
+
import openpi.models.model
|
| 44 |
+
import openpi.models.pi0_config
|
| 45 |
+
import openpi.models_pytorch.pi0_pytorch
|
| 46 |
+
from openpi.training import utils
|
| 47 |
+
import openpi.training.config as _config
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def slice_paligemma_state_dict(state_dict, config):
|
| 51 |
+
"""Convert PaliGemma JAX parameters to PyTorch format."""
|
| 52 |
+
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
| 53 |
+
|
| 54 |
+
# patch embeddings
|
| 55 |
+
jax_key = f"img/embedding/kernel{suffix}"
|
| 56 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
|
| 57 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
|
| 58 |
+
|
| 59 |
+
jax_key = f"img/embedding/bias{suffix}"
|
| 60 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
|
| 61 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 62 |
+
|
| 63 |
+
# positional embeddings
|
| 64 |
+
jax_key = f"img/pos_embedding{suffix}"
|
| 65 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
|
| 66 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
|
| 67 |
+
|
| 68 |
+
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
| 69 |
+
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
| 70 |
+
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
| 71 |
+
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
| 72 |
+
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
| 73 |
+
|
| 74 |
+
encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
| 75 |
+
encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
| 76 |
+
encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
| 77 |
+
encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
| 78 |
+
|
| 79 |
+
encoderblock_attention_0_key_kernel = state_dict.pop(
|
| 80 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
|
| 81 |
+
)
|
| 82 |
+
encoderblock_attention_0_key_bias = state_dict.pop(
|
| 83 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
|
| 84 |
+
)
|
| 85 |
+
encoderblock_attention_0_value_kernel = state_dict.pop(
|
| 86 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
|
| 87 |
+
)
|
| 88 |
+
encoderblock_attention_0_value_bias = state_dict.pop(
|
| 89 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
|
| 90 |
+
)
|
| 91 |
+
encoderblock_attention_0_query_kernel = state_dict.pop(
|
| 92 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
|
| 93 |
+
)
|
| 94 |
+
encoderblock_attention_0_query_bias = state_dict.pop(
|
| 95 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
|
| 96 |
+
)
|
| 97 |
+
encoderblock_attention_0_out_kernel = state_dict.pop(
|
| 98 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
|
| 99 |
+
)
|
| 100 |
+
encoderblock_attention_0_out_bias = state_dict.pop(
|
| 101 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
for i in range(config.vision_config.num_hidden_layers):
|
| 105 |
+
state_dict[
|
| 106 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
|
| 107 |
+
] = encoderblock_layernorm0_scale[i].transpose()
|
| 108 |
+
state_dict[
|
| 109 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
|
| 110 |
+
] = encoderblock_layernorm0_bias[i]
|
| 111 |
+
state_dict[
|
| 112 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
|
| 113 |
+
] = encoderblock_layernorm1_scale[i].transpose()
|
| 114 |
+
state_dict[
|
| 115 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
|
| 116 |
+
] = encoderblock_layernorm1_bias[i]
|
| 117 |
+
state_dict[
|
| 118 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
|
| 119 |
+
] = encoderblock_mlp_dense0_kernel[i].transpose()
|
| 120 |
+
state_dict[
|
| 121 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
|
| 122 |
+
] = encoderblock_mlp_dense0_bias[i]
|
| 123 |
+
state_dict[
|
| 124 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
|
| 125 |
+
] = encoderblock_mlp_dense1_kernel[i].transpose()
|
| 126 |
+
state_dict[
|
| 127 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
|
| 128 |
+
] = encoderblock_mlp_dense1_bias[i]
|
| 129 |
+
state_dict[
|
| 130 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
|
| 131 |
+
] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 132 |
+
state_dict[
|
| 133 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
|
| 134 |
+
] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 135 |
+
state_dict[
|
| 136 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
|
| 137 |
+
] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 138 |
+
state_dict[
|
| 139 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
|
| 140 |
+
] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 141 |
+
state_dict[
|
| 142 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
|
| 143 |
+
] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 144 |
+
state_dict[
|
| 145 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
|
| 146 |
+
] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 147 |
+
state_dict[
|
| 148 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
|
| 149 |
+
] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 150 |
+
state_dict[
|
| 151 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
|
| 152 |
+
] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 153 |
+
|
| 154 |
+
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
|
| 155 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
|
| 156 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 157 |
+
|
| 158 |
+
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
|
| 159 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
|
| 160 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 161 |
+
|
| 162 |
+
# multimodal projector
|
| 163 |
+
jax_key = f"img/head/kernel{suffix}"
|
| 164 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
|
| 165 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 166 |
+
|
| 167 |
+
jax_key = f"img/head/bias{suffix}"
|
| 168 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
|
| 169 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 170 |
+
|
| 171 |
+
# text decoder (gemma)
|
| 172 |
+
jax_key = f"llm/embedder/input_embedding{suffix}"
|
| 173 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
| 174 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 175 |
+
|
| 176 |
+
# pop the einsum attention + mlp representations
|
| 177 |
+
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
| 178 |
+
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
| 179 |
+
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
| 180 |
+
|
| 181 |
+
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
| 182 |
+
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
| 183 |
+
|
| 184 |
+
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
| 185 |
+
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
| 186 |
+
|
| 187 |
+
for i in range(config.text_config.num_hidden_layers):
|
| 188 |
+
q_proj_weight_reshaped = (
|
| 189 |
+
llm_attention_q_einsum[i]
|
| 190 |
+
.transpose(0, 2, 1)
|
| 191 |
+
.reshape(
|
| 192 |
+
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
| 193 |
+
)
|
| 194 |
+
)
|
| 195 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
|
| 196 |
+
q_proj_weight_reshaped
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 200 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
|
| 201 |
+
k_proj_weight_reshaped
|
| 202 |
+
)
|
| 203 |
+
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 204 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
|
| 205 |
+
v_proj_weight_reshaped
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
o_proj_weight_reshaped = (
|
| 209 |
+
llm_attention_attn_vec_einsum[i]
|
| 210 |
+
.transpose(2, 0, 1)
|
| 211 |
+
.reshape(
|
| 212 |
+
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
| 213 |
+
)
|
| 214 |
+
)
|
| 215 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
|
| 216 |
+
o_proj_weight_reshaped
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 220 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
|
| 221 |
+
gate_proj_weight.transpose()
|
| 222 |
+
)
|
| 223 |
+
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 224 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
|
| 225 |
+
up_proj_weight.transpose()
|
| 226 |
+
)
|
| 227 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
|
| 228 |
+
llm_mlp_linear[i].transpose()
|
| 229 |
+
)
|
| 230 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
|
| 231 |
+
llm_input_layernorm[i]
|
| 232 |
+
)
|
| 233 |
+
state_dict[
|
| 234 |
+
f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
|
| 235 |
+
] = llm_post_attention_layernorm[i]
|
| 236 |
+
|
| 237 |
+
jax_key = f"llm/final_norm/scale{suffix}"
|
| 238 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
|
| 239 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 240 |
+
|
| 241 |
+
expert_dict = {}
|
| 242 |
+
final_state_dict = {}
|
| 243 |
+
|
| 244 |
+
# Expert-related keys to extract (including pi05 Dense layer parameters)
|
| 245 |
+
expert_keys = [
|
| 246 |
+
f"llm/final_norm_1/scale{suffix}",
|
| 247 |
+
f"llm/final_norm_1/Dense_0/bias{suffix}",
|
| 248 |
+
f"llm/final_norm_1/Dense_0/kernel{suffix}",
|
| 249 |
+
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
| 250 |
+
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
| 251 |
+
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
| 252 |
+
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
| 253 |
+
f"llm/layers/mlp_1/linear{suffix}",
|
| 254 |
+
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
| 255 |
+
f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
|
| 256 |
+
f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
|
| 257 |
+
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
| 258 |
+
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
|
| 259 |
+
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
|
| 260 |
+
]
|
| 261 |
+
|
| 262 |
+
for key, value in state_dict.items():
|
| 263 |
+
if key not in expert_keys:
|
| 264 |
+
final_state_dict[key] = torch.from_numpy(value)
|
| 265 |
+
else:
|
| 266 |
+
expert_dict[key] = value
|
| 267 |
+
|
| 268 |
+
return final_state_dict, expert_dict
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
|
| 272 |
+
"""Convert Gemma JAX parameters to PyTorch format."""
|
| 273 |
+
# Add missing attributes to config if they don't exist
|
| 274 |
+
if not hasattr(config, "vocab_size"):
|
| 275 |
+
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
|
| 276 |
+
if not hasattr(config, "hidden_size"):
|
| 277 |
+
config.hidden_size = config.width
|
| 278 |
+
if not hasattr(config, "num_hidden_layers"):
|
| 279 |
+
config.num_hidden_layers = config.depth
|
| 280 |
+
if not hasattr(config, "num_attention_heads"):
|
| 281 |
+
config.num_attention_heads = config.num_heads
|
| 282 |
+
|
| 283 |
+
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
| 284 |
+
|
| 285 |
+
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
| 286 |
+
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
| 287 |
+
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
| 288 |
+
|
| 289 |
+
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
| 290 |
+
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
| 291 |
+
|
| 292 |
+
# Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
|
| 293 |
+
if "pi05" in checkpoint_dir:
|
| 294 |
+
# Pi05 with adaptive normalization
|
| 295 |
+
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 296 |
+
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 297 |
+
llm_input_layernorm_kernel = state_dict.pop(
|
| 298 |
+
f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
|
| 299 |
+
)
|
| 300 |
+
llm_post_attention_layernorm_kernel = state_dict.pop(
|
| 301 |
+
f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
|
| 302 |
+
)
|
| 303 |
+
else:
|
| 304 |
+
# Regular pi0 with standard RMSNorm
|
| 305 |
+
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
| 306 |
+
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
| 307 |
+
|
| 308 |
+
for i in range(config.num_hidden_layers):
|
| 309 |
+
q_proj_weight_reshaped = (
|
| 310 |
+
llm_attention_q_einsum[i]
|
| 311 |
+
.transpose(0, 2, 1)
|
| 312 |
+
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
| 313 |
+
)
|
| 314 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
|
| 315 |
+
q_proj_weight_reshaped
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 319 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
|
| 320 |
+
k_proj_weight_reshaped
|
| 321 |
+
)
|
| 322 |
+
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 323 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
|
| 324 |
+
v_proj_weight_reshaped
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
o_proj_weight_reshaped = (
|
| 328 |
+
llm_attention_attn_vec_einsum[i]
|
| 329 |
+
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
| 330 |
+
.transpose(1, 0)
|
| 331 |
+
)
|
| 332 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
|
| 333 |
+
o_proj_weight_reshaped
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 337 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
|
| 338 |
+
gate_proj_weight.transpose()
|
| 339 |
+
)
|
| 340 |
+
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 341 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
|
| 342 |
+
up_proj_weight.transpose()
|
| 343 |
+
)
|
| 344 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
|
| 345 |
+
i
|
| 346 |
+
].transpose()
|
| 347 |
+
|
| 348 |
+
if "pi05" in checkpoint_dir:
|
| 349 |
+
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
| 350 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
|
| 351 |
+
llm_input_layernorm_bias[i]
|
| 352 |
+
)
|
| 353 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
|
| 354 |
+
llm_post_attention_layernorm_bias[i]
|
| 355 |
+
)
|
| 356 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
|
| 357 |
+
llm_input_layernorm_kernel[i].transpose()
|
| 358 |
+
)
|
| 359 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
|
| 360 |
+
llm_post_attention_layernorm_kernel[i].transpose()
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
# Regular pi0 with standard RMSNorm
|
| 364 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
|
| 365 |
+
llm_input_layernorm[i]
|
| 366 |
+
)
|
| 367 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
|
| 368 |
+
llm_post_attention_layernorm[i]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
# Handle final norm layer
|
| 372 |
+
if "pi05" in checkpoint_dir:
|
| 373 |
+
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
| 374 |
+
final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 375 |
+
final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
|
| 376 |
+
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
|
| 377 |
+
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
|
| 378 |
+
else:
|
| 379 |
+
# Regular pi0 with standard RMSNorm
|
| 380 |
+
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
|
| 381 |
+
f"llm/final_norm_{num_expert}/scale{suffix}"
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
# state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
|
| 385 |
+
|
| 386 |
+
final_state_dict = {}
|
| 387 |
+
for key, value in state_dict.items():
|
| 388 |
+
if not isinstance(value, torch.Tensor):
|
| 389 |
+
final_state_dict[key] = torch.from_numpy(value)
|
| 390 |
+
else:
|
| 391 |
+
final_state_dict[key] = value
|
| 392 |
+
|
| 393 |
+
return final_state_dict
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
|
| 397 |
+
"""Load and process params by restoring via JAX model loader first.
|
| 398 |
+
This respects dtype conversions that occur during model restore.
|
| 399 |
+
"""
|
| 400 |
+
# Use repository restore utility to load a pure dict of params (value suffix removed)
|
| 401 |
+
params = openpi.models.model.restore_params(
|
| 402 |
+
f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def load_jax_model_and_print_keys(checkpoint_dir: str):
|
| 409 |
+
"""
|
| 410 |
+
Load JAX model from checkpoint and print all parameter keys.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
checkpoint_dir: Path to the checkpoint directory
|
| 414 |
+
"""
|
| 415 |
+
checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
|
| 416 |
+
# Initialize checkpointer
|
| 417 |
+
checkpointer = ocp.PyTreeCheckpointer()
|
| 418 |
+
metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
|
| 419 |
+
print(utils.array_tree_to_info(metadata))
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def convert_pi0_checkpoint(
|
| 423 |
+
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
|
| 424 |
+
):
|
| 425 |
+
"""
|
| 426 |
+
Convert PI0 JAX checkpoint to PyTorch format.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
checkpoint_dir: Path to the JAX checkpoint
|
| 430 |
+
precision: Model precision (float32, bfloat16, float16)
|
| 431 |
+
output_path: Path to save the converted PyTorch model
|
| 432 |
+
model_config: Model config
|
| 433 |
+
"""
|
| 434 |
+
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
|
| 435 |
+
print(f"Model config: {model_config}")
|
| 436 |
+
|
| 437 |
+
# Break down orbax ckpts by restoring via JAX to respect dtype
|
| 438 |
+
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
|
| 439 |
+
|
| 440 |
+
# Process projection params
|
| 441 |
+
if model_config.pi05:
|
| 442 |
+
keys = [
|
| 443 |
+
"action_in_proj",
|
| 444 |
+
"action_out_proj",
|
| 445 |
+
"time_mlp_in",
|
| 446 |
+
"time_mlp_out",
|
| 447 |
+
]
|
| 448 |
+
else:
|
| 449 |
+
keys = [
|
| 450 |
+
"state_proj",
|
| 451 |
+
"action_in_proj",
|
| 452 |
+
"action_out_proj",
|
| 453 |
+
"action_time_mlp_in",
|
| 454 |
+
"action_time_mlp_out",
|
| 455 |
+
]
|
| 456 |
+
|
| 457 |
+
projection_params = {}
|
| 458 |
+
for key in keys:
|
| 459 |
+
kernel_params = initial_params["projection_params"][key]["kernel"]
|
| 460 |
+
bias_params = initial_params["projection_params"][key]["bias"]
|
| 461 |
+
if isinstance(kernel_params, dict):
|
| 462 |
+
weight = kernel_params["value"]
|
| 463 |
+
bias = bias_params["value"]
|
| 464 |
+
else:
|
| 465 |
+
weight = kernel_params
|
| 466 |
+
bias = bias_params
|
| 467 |
+
|
| 468 |
+
pytorch_weight_key = f"{key}.weight"
|
| 469 |
+
pytorch_bias_key = f"{key}.bias"
|
| 470 |
+
|
| 471 |
+
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
|
| 472 |
+
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
|
| 473 |
+
|
| 474 |
+
# Create configs based on checkpoint path
|
| 475 |
+
# All models use the same PaliGemma config structure
|
| 476 |
+
class PaliGemmaConfig:
|
| 477 |
+
def __init__(self):
|
| 478 |
+
self.vision_config = type(
|
| 479 |
+
"obj",
|
| 480 |
+
(object,),
|
| 481 |
+
{
|
| 482 |
+
"hidden_size": 1152,
|
| 483 |
+
"num_hidden_layers": 27,
|
| 484 |
+
"num_attention_heads": 16,
|
| 485 |
+
"intermediate_size": 4304,
|
| 486 |
+
"patch_size": 14,
|
| 487 |
+
"projection_dim": 2048,
|
| 488 |
+
},
|
| 489 |
+
)()
|
| 490 |
+
self.text_config = type(
|
| 491 |
+
"obj",
|
| 492 |
+
(object,),
|
| 493 |
+
{
|
| 494 |
+
"hidden_size": 2048,
|
| 495 |
+
"num_hidden_layers": 18,
|
| 496 |
+
"num_attention_heads": 8,
|
| 497 |
+
"head_dim": 256,
|
| 498 |
+
"intermediate_size": 16384,
|
| 499 |
+
},
|
| 500 |
+
)()
|
| 501 |
+
|
| 502 |
+
paligemma_config = PaliGemmaConfig()
|
| 503 |
+
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
|
| 504 |
+
|
| 505 |
+
# Process PaliGemma weights
|
| 506 |
+
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
|
| 507 |
+
|
| 508 |
+
# Process Gemma weights from expert_params
|
| 509 |
+
gemma_params = slice_gemma_state_dict(
|
| 510 |
+
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Instantiate model
|
| 514 |
+
pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
|
| 515 |
+
|
| 516 |
+
# Combine all parameters (no prefix needed for our model structure)
|
| 517 |
+
all_params = {**paligemma_params, **gemma_params, **projection_params}
|
| 518 |
+
|
| 519 |
+
# Load state dict
|
| 520 |
+
pi0_model.load_state_dict(all_params, strict=False)
|
| 521 |
+
|
| 522 |
+
if precision == "float32":
|
| 523 |
+
pi0_model = pi0_model.to(torch.float32)
|
| 524 |
+
elif precision == "bfloat16":
|
| 525 |
+
pi0_model = pi0_model.to(torch.bfloat16)
|
| 526 |
+
else:
|
| 527 |
+
raise ValueError(f"Invalid precision: {precision}")
|
| 528 |
+
|
| 529 |
+
# Save the converted model using safetensors
|
| 530 |
+
os.makedirs(output_path, exist_ok=True)
|
| 531 |
+
|
| 532 |
+
# Save model weights as SafeTensors using save_model to handle tied weights
|
| 533 |
+
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
|
| 534 |
+
|
| 535 |
+
# Copy assets folder if it exists
|
| 536 |
+
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
|
| 537 |
+
if assets_source.exists():
|
| 538 |
+
assets_dest = pathlib.Path(output_path) / "assets"
|
| 539 |
+
if assets_dest.exists():
|
| 540 |
+
shutil.rmtree(assets_dest)
|
| 541 |
+
shutil.copytree(assets_source, assets_dest)
|
| 542 |
+
|
| 543 |
+
# Save config as JSON for reference
|
| 544 |
+
config_dict = {
|
| 545 |
+
"action_dim": model_config.action_dim,
|
| 546 |
+
"action_horizon": model_config.action_horizon,
|
| 547 |
+
"paligemma_variant": model_config.paligemma_variant,
|
| 548 |
+
"action_expert_variant": model_config.action_expert_variant,
|
| 549 |
+
"precision": precision,
|
| 550 |
+
}
|
| 551 |
+
with open(os.path.join(output_path, "config.json"), "w") as f:
|
| 552 |
+
json.dump(config_dict, f, indent=2)
|
| 553 |
+
|
| 554 |
+
print("Model conversion completed successfully!")
|
| 555 |
+
print(f"Model saved to {output_path}")
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def main(
|
| 559 |
+
checkpoint_dir: str,
|
| 560 |
+
config_name: str,
|
| 561 |
+
output_path: str | None = None,
|
| 562 |
+
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
|
| 563 |
+
*,
|
| 564 |
+
inspect_only: bool = False,
|
| 565 |
+
):
|
| 566 |
+
"""Load JAX model and optionally convert to PyTorch.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
checkpoint_dir: Path to the JAX checkpoint directory
|
| 570 |
+
output_path: Path to save converted PyTorch model (required for conversion)
|
| 571 |
+
precision: Precision for model conversion
|
| 572 |
+
inspect_only: Only inspect parameter keys, don't convert
|
| 573 |
+
"""
|
| 574 |
+
model_config = _config.get_config(config_name).model
|
| 575 |
+
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
|
| 576 |
+
raise ValueError(f"Config {config_name} is not a Pi0Config")
|
| 577 |
+
if inspect_only:
|
| 578 |
+
load_jax_model_and_print_keys(checkpoint_dir)
|
| 579 |
+
else:
|
| 580 |
+
if not output_path:
|
| 581 |
+
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
|
| 582 |
+
return
|
| 583 |
+
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
if __name__ == "__main__":
|
| 587 |
+
tyro.cli(main)
|
capvector-pi05/examples/droid/README.md
ADDED
# DROID Policies in openpi

We offer instructions for:
- [Running inference for our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-other-policies)
- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)

## Running DROID Inference

This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.


### Step 1: Start a policy server

Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.

1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
2. Start the OpenPI server via the following command:

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
```

You can also run the equivalent command below:

```bash
uv run scripts/serve_policy.py --env=DROID
```
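
Once the server is up, you can sanity-check it from any machine with Python before moving to the robot. Below is a minimal sketch that queries the server with the `openpi-client` package installed in Step 2; it assumes the `WebsocketClientPolicy` interface from that package, and the observation keys, image sizes, and host address shown are illustrative placeholders rather than the exact schema your `main.py` will send.

```python
# Minimal sketch: query a running openpi policy server from Python.
# Assumptions: openpi-client is installed and the server from Step 1 is reachable;
# the observation keys below are illustrative, not the exact DROID schema.
import numpy as np
from openpi_client import websocket_client_policy

policy = websocket_client_policy.WebsocketClientPolicy(host="192.168.1.10", port=8000)

# Dummy observation with placeholder images and robot state.
observation = {
    "observation/exterior_image_1_left": np.zeros((224, 224, 3), dtype=np.uint8),
    "observation/wrist_image_left": np.zeros((224, 224, 3), dtype=np.uint8),
    "observation/joint_position": np.zeros(7, dtype=np.float32),
    "observation/gripper_position": np.zeros(1, dtype=np.float32),
    "prompt": "pick up the marker",
}

result = policy.infer(observation)
print(result["actions"].shape)  # one chunk of actions, e.g. (horizon, action_dim)
```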

### Step 2: Run the DROID robot

1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
2. On the control laptop, activate your DROID conda environment.
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
7. Run the `main.py` file. Make sure to point the remote host and port arguments at the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify which external camera the policy should use (we only input one external camera); choose from ["left", "right"].

```bash
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
```

The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!

## Troubleshooting

| Issue | Solution |
|-------|----------|
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, so make sure you are feeding the desired camera to the policy); use `ZED_Explore` to verify this. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |


## Running Other Policies

We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.

```
# Trained from pi0-FAST, using FAST tokenizer
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid

# Trained from pi0, using flow matching
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid

# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid

# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid

# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid

# Trained from PaliGemma, using FSQ tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid

# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
```

You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
capvector-pi05/examples/droid/README_train.md
ADDED
# Training on DROID

Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline
(with small differences in data loading and the action space used). For a tutorial on how to fine-tune a model on a smaller, custom dataset collected on the DROID platform, see below.

In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough
for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset.

## Install

We need a few additional dependencies for RLDS data loading. Run:
```bash
uv sync --group rlds
```

## Download DROID dataset

You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
```
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
```

Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).

You will need 1.8TB of disk storage to download the DROID RLDS dataset.

## Run

First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
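
As a point of reference, the snippet below sketches what this change amounts to; it is a simplified, hypothetical stand-in for the real config classes (only the `rlds_data_dir` field name is taken from the instructions above), not openpi's actual `TrainConfig` definition.

```python
# Hypothetical, simplified sketch of pointing a DROID training config at your download.
# Only the `rlds_data_dir` field name comes from the instructions above; the class
# itself is a placeholder, not the real TrainConfig in src/openpi/training/config.py.
import dataclasses


@dataclasses.dataclass(frozen=True)
class DroidTrainConfigSketch:
    name: str = "pi05_full_droid_finetune"
    rlds_data_dir: str = "<your_download_path>/droid"  # directory that contains droid/1.0.1


config = DroidTrainConfigSketch(rlds_data_dir="/mnt/data/droid")
print(config.rlds_data_dir)
```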

Then, compute normalization statistics (this will take ~10 minutes):
```bash
uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
```

Run training:
```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
```

**Note**: The original pi0.5-DROID model was trained with joint velocity actions.
Joint velocity actions are not compatible with simulated evaluation environments (they are much harder to simulate).
Thus, we do not recommend training with joint velocity actions and instead use joint position actions here.


## Compute Requirements

Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, batch size 256, approx. 1 epoch).
If you start from PaliGemma instead of pi0 initialization, plan for ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).

We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.


## Data Filtering

Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean", and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection; we will not go into too much detail here). Appropriate filtering of these idle transitions improves policy performance.

By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py). The sketch below shows the shape of the resulting index file.
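
To make the filter-dict format concrete, here is a small sketch based on how [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) writes its output: a JSON object mapping an episode key to the half-open `[start, end)` frame ranges that remain after idle filtering. The file path below is a placeholder.

```python
# Sketch of reading an idle-filter index file of the kind produced by
# compute_droid_nonidle_ranges.py. The path is a placeholder.
import json
from pathlib import Path

filter_dict_path = Path("<path_to_filter_dict>")  # same value you pass via filter_dict_path in config.py
keep_ranges_map = json.loads(filter_dict_path.read_text())

# Keys are "<recording_folderpath>--<file_path>"; values are lists of [start, end) ranges.
key = next(iter(keep_ranges_map))
ranges = keep_ranges_map[key]

# Expand the ranges into the individual frame indices that may be sampled for this episode.
sampleable_frames = [t for start, end in ranges for t in range(start, end)]
print(key, len(sampleable_frames), "sampleable frames")
```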

**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.

## RoboArena

Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)

If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).


# Fine-Tuning on Custom DROID Datasets

Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. Like for other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.

Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above).


## Step 1: Converting your custom DROID dataset to LeRobot

We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
```

We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
```

For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).

Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations):
```
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
```

## Step 2: Run fine-tuning with your custom dataset

Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
You can easily modify the config to work with other base models or with your own custom DROID dataset in `config.py` (search for `pi05_droid_finetune`).

To launch training:
```
uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
```

Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
capvector-pi05/examples/droid/compute_droid_nonidle_ranges.py
ADDED
"""
|
| 2 |
+
Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
|
| 3 |
+
that should be sampled during training (all others are filtered out).
|
| 4 |
+
|
| 5 |
+
Filtering logic:
|
| 6 |
+
We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
|
| 7 |
+
(default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
|
| 8 |
+
this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
|
| 9 |
+
ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last
|
| 10 |
+
filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).
|
| 11 |
+
|
| 12 |
+
This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
|
| 13 |
+
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
|
| 14 |
+
"""

import json
import os
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Set to the GPU you want to use, or leave empty for CPU

builder = tfds.builder_from_directory(
    # path to the `droid` directory (not its parent)
    builder_dir="<path_to_droid_dataset_tfds_files>",
)
ds = builder.as_dataset(split="train", shuffle_files=False)
ds = ds.apply(tf.data.experimental.ignore_errors())  # skip episodes that fail to decode

keep_ranges_path = "<path_to_where_to_save_the_json>"

min_idle_len = 7  # If more than this number of consecutive idle frames, filter all of them out
min_non_idle_len = 16  # If fewer than this number of consecutive non-idle frames, filter all of them out
filter_last_n_in_ranges = 10  # When using a filter dict, remove this many frames from the end of each range

keep_ranges_map = {}
if Path(keep_ranges_path).exists():
    with Path(keep_ranges_path).open("r") as f:
        keep_ranges_map = json.load(f)
    print(f"Resuming from {len(keep_ranges_map)} episodes already processed")

for ep_idx, ep in enumerate(tqdm(ds)):
    recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
    file_path = ep["episode_metadata"]["file_path"].numpy().decode()

    key = f"{recording_folderpath}--{file_path}"
    if key in keep_ranges_map:
        continue

    joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
    joint_velocities = np.array(joint_velocities)

    # A step is "idle" if the joint velocity barely changed relative to the previous step.
    is_idle_array = np.hstack(
        [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
    )

    # Find what steps go from non-idle to idle and vice-versa
    is_idle_padded = np.concatenate(
        [[False], is_idle_array, [False]]
    )  # Pad with False so idle segments at the very start or end are detected as well

    is_idle_diff = np.diff(is_idle_padded.astype(int))
    is_idle_true_starts = np.where(is_idle_diff == 1)[0]  # +1 transitions --> going from non-idle to idle (idle segment starts)
    is_idle_true_ends = np.where(is_idle_diff == -1)[0]  # -1 transitions --> going from idle to non-idle (idle segment ends)

    # Find which steps correspond to idle segments of length at least min_idle_len
    true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
    is_idle_true_starts = is_idle_true_starts[true_segment_masks]
    is_idle_true_ends = is_idle_true_ends[true_segment_masks]

    keep_mask = np.ones(len(joint_velocities), dtype=bool)
    for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
        keep_mask[start:end] = False

    # Get all non-idle ranges of length at least min_non_idle_len
    # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
    keep_padded = np.concatenate([[False], keep_mask, [False]])

    keep_diff = np.diff(keep_padded.astype(int))
    keep_true_starts = np.where(keep_diff == 1)[0]  # +1 transitions --> going from filter out to keep
    keep_true_ends = np.where(keep_diff == -1)[0]  # -1 transitions --> going from keep to filter out

    # Find which steps correspond to non-idle segments of length at least min_non_idle_len
    true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
    keep_true_starts = keep_true_starts[true_segment_masks]
    keep_true_ends = keep_true_ends[true_segment_masks]

    # Add mapping from episode unique ID key to list of non-idle ranges to keep
    keep_ranges_map[key] = []
    for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
        keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))

    # Periodically checkpoint the (partial) result so long runs can be resumed.
    if ep_idx % 1000 == 0:
        with Path(keep_ranges_path).open("w") as f:
            json.dump(keep_ranges_map, f)

print("Done!")
with Path(keep_ranges_path).open("w") as f:
    json.dump(keep_ranges_map, f)
capvector-pi05/examples/droid/convert_droid_data_to_lerobot.py
ADDED
"""
|
| 2 |
+
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
|
| 6 |
+
|
| 7 |
+
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
|
| 8 |
+
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
|
| 9 |
+
|
| 10 |
+
The resulting dataset will get saved to the $LEROBOT_HOME directory.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
import copy
|
| 15 |
+
import glob
|
| 16 |
+
import json
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import shutil
|
| 19 |
+
|
| 20 |
+
import cv2
|
| 21 |
+
import h5py
|
| 22 |
+
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
|
| 23 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 24 |
+
import numpy as np
|
| 25 |
+
from PIL import Image
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
import tyro
|
| 28 |
+
|
| 29 |
+
REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def resize_image(image, size):
|
| 33 |
+
image = Image.fromarray(image)
|
| 34 |
+
    return np.array(image.resize(size, resample=Image.BICUBIC))


def main(data_dir: str, *, push_to_hub: bool = False):
    # Clean up any existing dataset in the output directory
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)
    data_dir = Path(data_dir)

    # Create LeRobot dataset, define features to store
    # We will follow the DROID data naming conventions here.
    # LeRobot assumes that dtype of image data is `image`
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=15,  # DROID data is typically recorded at 15fps
        features={
            # We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
            "exterior_image_1_left": {
                "dtype": "image",
                "shape": (180, 320, 3),  # This is the resolution used in the DROID RLDS dataset
                "names": ["height", "width", "channel"],
            },
            "exterior_image_2_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "joint_position": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["joint_position"],
            },
            "gripper_position": {
                "dtype": "float32",
                "shape": (1,),
                "names": ["gripper_position"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),  # We will use joint *velocity* actions here (7D) + gripper position (1D)
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Load language annotations
    # Note: we load the DROID language annotations for this example, but you can manually define them for your own data
    with (data_dir / "aggregated-annotations-030724.json").open() as f:
        language_annotations = json.load(f)

    # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
    # We assume the following directory structure:
    # RAW_DROID_PATH/
    #   - <...>/
    #     - recordings/
    #       - MP4/
    #         - <camera_id>.mp4  # single-view video of left stereo pair camera
    #     - trajectory.hdf5
    #   - <...>/
    episode_paths = list(data_dir.glob("**/trajectory.h5"))
    print(f"Found {len(episode_paths)} episodes for conversion")

    # Loop over each episode and write it to the LeRobot dataset
    for episode_path in tqdm(episode_paths, desc="Converting episodes"):
        # Load raw data
        recording_folderpath = episode_path.parent / "recordings" / "MP4"
        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))

        # To load the language instruction, we need to parse out the episode_id from the metadata file
        # Again, you can modify this step for your own data, to load your own language instructions
        metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
        episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
        language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
            "language_instruction1"
        ]
        print(f"Converting episode with language instruction: {language_instruction}")

        # Write to LeRobot dataset
        for step in trajectory:
            camera_type_dict = step["observation"]["camera_type"]
            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
            dataset.add_frame(
                {
                    # Note: need to flip BGR --> RGB for loaded images
                    "exterior_image_1_left": resize_image(
                        step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
                    ),
                    "exterior_image_2_left": resize_image(
                        step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
                    ),
                    "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
                    "joint_position": np.asarray(
                        step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
                    ),
                    "gripper_position": np.asarray(
                        step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
                    ),
                    # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
                    "actions": np.concatenate(
                        [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
                    ),
                    "task": language_instruction,
                }
            )
        dataset.save_episode()

    # Optionally push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            tags=["libero", "panda", "rlds"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


##########################################################################################################
################ The rest of this file contains functions to parse the raw DROID data ####################
################ You don't need to worry about understanding this part ###################################
################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
##########################################################################################################


camera_type_dict = {
    "hand_camera_id": 0,
    "varied_camera_1_id": 1,
    "varied_camera_2_id": 1,
}

camera_type_to_string_dict = {
    0: "hand_camera",
    1: "varied_camera",
    2: "fixed_camera",
}


def get_camera_type(cam_id):
    if cam_id not in camera_type_dict:
        return None
    type_int = camera_type_dict[cam_id]
    return camera_type_to_string_dict[type_int]


class MP4Reader:
    def __init__(self, filepath, serial_number):
        # Save Parameters #
        self.serial_number = serial_number
        self._index = 0

        # Open Video Reader #
        self._mp4_reader = cv2.VideoCapture(filepath)
        if not self._mp4_reader.isOpened():
            raise RuntimeError("Corrupted MP4 File")

    def set_reading_parameters(
        self,
        image=True,  # noqa: FBT002
        concatenate_images=False,  # noqa: FBT002
        resolution=(0, 0),
        resize_func=None,
    ):
        # Save Parameters #
        self.image = image
        self.concatenate_images = concatenate_images
        self.resolution = resolution
        self.resize_func = cv2.resize
        self.skip_reading = not image
        if self.skip_reading:
            return

    def get_frame_resolution(self):
        width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
        height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
        return (width, height)

    def get_frame_count(self):
        if self.skip_reading:
            return 0
        return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))

    def set_frame_index(self, index):
        if self.skip_reading:
            return

        if index < self._index:
            self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
            self._index = index

        while self._index < index:
            self.read_camera(ignore_data=True)

    def _process_frame(self, frame):
        frame = copy.deepcopy(frame)
        if self.resolution == (0, 0):
            return frame
        return self.resize_func(frame, self.resolution)

    def read_camera(self, ignore_data=False, correct_timestamp=None):  # noqa: FBT002
        # Skip if Read Unnecessary #
        if self.skip_reading:
            return {}

        # Read Camera #
        success, frame = self._mp4_reader.read()

        self._index += 1
        if not success:
            return None
        if ignore_data:
            return None

        # Return Data #
        data_dict = {}

        if self.concatenate_images or "stereo" not in self.serial_number:
            data_dict["image"] = {self.serial_number: self._process_frame(frame)}
        else:
            single_width = frame.shape[1] // 2
            data_dict["image"] = {
                self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
                self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
            }

        return data_dict

    def disable_camera(self):
        if hasattr(self, "_mp4_reader"):
            self._mp4_reader.release()


class RecordedMultiCameraWrapper:
    def __init__(self, recording_folderpath, camera_kwargs={}):  # noqa: B006
        # Save Camera Info #
        self.camera_kwargs = camera_kwargs

        # Open Camera Readers #
        mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
        all_filepaths = mp4_filepaths

        self.camera_dict = {}
        for f in all_filepaths:
            serial_number = f.split("/")[-1][:-4]
            cam_type = get_camera_type(serial_number)
            camera_kwargs.get(cam_type, {})

            if f.endswith(".mp4"):
                Reader = MP4Reader  # noqa: N806
            else:
                raise ValueError

            self.camera_dict[serial_number] = Reader(f, serial_number)

    def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):  # noqa: B006
        full_obs_dict = defaultdict(dict)

        # Read Cameras In Randomized Order #
        all_cam_ids = list(self.camera_dict.keys())
        # random.shuffle(all_cam_ids)

        for cam_id in all_cam_ids:
            if "stereo" in cam_id:
                continue
            try:
                cam_type = camera_type_dict[cam_id]
            except KeyError:
                print(f"{self.camera_dict} -- {camera_type_dict}")
                raise ValueError(f"Camera type {cam_id} not found in camera_type_dict")  # noqa: B904
            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
            self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)

            timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
            if index is not None:
                self.camera_dict[cam_id].set_frame_index(index)

            data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)

            # Process Returned Data #
            if data_dict is None:
                return None
            for key in data_dict:
                full_obs_dict[key].update(data_dict[key])

        return full_obs_dict


def get_hdf5_length(hdf5_file, keys_to_ignore=[]):  # noqa: B006
    length = None

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            curr_length = len(curr_data)
        else:
            raise ValueError

        if length is None:
            length = curr_length
        assert curr_length == length

    return length


def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):  # noqa: B006
    data_dict = {}

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            data_dict[key] = curr_data[index]
        else:
            raise ValueError

    return data_dict


class TrajectoryReader:
    def __init__(self, filepath, read_images=True):  # noqa: FBT002
        self._hdf5_file = h5py.File(filepath, "r")
        is_video_folder = "observations/videos" in self._hdf5_file
        self._read_images = read_images and is_video_folder
        self._length = get_hdf5_length(self._hdf5_file)
        self._video_readers = {}
        self._index = 0

    def length(self):
        return self._length

    def read_timestep(self, index=None, keys_to_ignore=[]):  # noqa: B006
        # Make Sure We Read Within Range #
        if index is None:
            index = self._index
        else:
            assert not self._read_images
            self._index = index
        assert index < self._length

        # Load Low Dimensional Data #
        keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)

        # Increment Read Index #
        self._index += 1

        # Return Timestep #
        return timestep

    def close(self):
        self._hdf5_file.close()


def load_trajectory(
    filepath=None,
    read_cameras=True,  # noqa: FBT002
    recording_folderpath=None,
    camera_kwargs={},  # noqa: B006
    remove_skipped_steps=False,  # noqa: FBT002
    num_samples_per_traj=None,
    num_samples_per_traj_coeff=1.5,
):
    read_recording_folderpath = read_cameras and (recording_folderpath is not None)

    traj_reader = TrajectoryReader(filepath)
    if read_recording_folderpath:
        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)

    horizon = traj_reader.length()
    timestep_list = []

    # Choose Timesteps To Save #
    if num_samples_per_traj:
        num_to_save = num_samples_per_traj
        if remove_skipped_steps:
            num_to_save = int(num_to_save * num_samples_per_traj_coeff)
        max_size = min(num_to_save, horizon)
        indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
    else:
        indices_to_save = np.arange(horizon)

    # Iterate Over Trajectory #
    for i in indices_to_save:
        # Get HDF5 Data #
        timestep = traj_reader.read_timestep(index=i)

        # If Applicable, Get Recorded Data #
        if read_recording_folderpath:
            timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
            camera_type_dict = {
                k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
            }
            camera_obs = camera_reader.read_cameras(
                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
            )
            camera_failed = camera_obs is None

            # Add Data To Timestep If Successful #
            if camera_failed:
                break
            timestep["observation"].update(camera_obs)

        # Filter Steps #
        step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
        delete_skipped_step = step_skipped and remove_skipped_steps

        # Save Filtered Timesteps #
        if delete_skipped_step:
            del timestep
        else:
            timestep_list.append(timestep)

    # Remove Extra Transitions #
    timestep_list = np.array(timestep_list)
    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
        ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
        timestep_list = timestep_list[ind_to_keep]

    # Close Readers #
    traj_reader.close()

    # Return Data #
    return timestep_list


if __name__ == "__main__":
    tyro.cli(main)
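The per-frame preprocessing above (BGR to RGB flip plus bicubic resize to 320x180) can be exercised in isolation. The sketch below is illustrative only; `fake_bgr_frame` and `preprocess_frame` are made-up stand-ins, not part of the conversion script.

# Minimal sketch of the per-frame preprocessing used above (BGR -> RGB flip,
# then bicubic resize to 320x180). `fake_bgr_frame` is a placeholder for a
# decoded MP4 frame.
import numpy as np
from PIL import Image

def preprocess_frame(bgr_frame: np.ndarray, size=(320, 180)) -> np.ndarray:
    rgb = bgr_frame[..., ::-1]  # OpenCV decodes frames as BGR; LeRobot expects RGB
    return np.array(Image.fromarray(rgb).resize(size, resample=Image.BICUBIC))

fake_bgr_frame = np.zeros((720, 1280, 3), dtype=np.uint8)
print(preprocess_frame(fake_bgr_frame).shape)  # (180, 320, 3)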
capvector-pi05/examples/droid/main.py
ADDED
@@ -0,0 +1,246 @@
# ruff: noqa

import contextlib
import dataclasses
import datetime
import faulthandler
import os
import signal
import time
from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from droid.robot_env import RobotEnv
import tqdm
import tyro

faulthandler.enable()

# DROID data collection frequency -- we slow down execution to match this frequency
DROID_CONTROL_FREQUENCY = 15


@dataclasses.dataclass
class Args:
    # Hardware parameters
    left_camera_id: str = "<your_camera_id>"  # e.g., "24259877"
    right_camera_id: str = "<your_camera_id>"  # e.g., "24514023"
    wrist_camera_id: str = "<your_camera_id>"  # e.g., "13062452"

    # Policy parameters
    external_camera: str | None = (
        None  # which external camera should be fed to the policy, choose from ["left", "right"]
    )

    # Rollout parameters
    max_timesteps: int = 600
    # How many actions to execute from a predicted action chunk before querying policy server again
    # 8 is usually a good default (equals 0.5 seconds of action execution).
    open_loop_horizon: int = 8

    # Remote server parameters
    remote_host: str = "0.0.0.0"  # point this to the IP address of the policy server, e.g., "192.168.1.100"
    remote_port: int = (
        8000  # point this to the port of the policy server, default server port for openpi servers is 8000
    )


# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
# waiting for a new action chunk, it will raise an exception and the server connection dies.
# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
    interrupted = False
    original_handler = signal.getsignal(signal.SIGINT)

    def handler(signum, frame):
        nonlocal interrupted
        interrupted = True

    signal.signal(signal.SIGINT, handler)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, original_handler)
        if interrupted:
            raise KeyboardInterrupt


def main(args: Args):
    # Make sure external camera is specified by user -- we only use one external camera for the policy
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"

    # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    # Connect to the policy server
    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    df = pd.DataFrame(columns=["success", "duration", "video_filename"])

    while True:
        instruction = input("Enter instruction: ")

        # Rollout parameters
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        # Prepare to save video of rollout
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            start_time = time.time()
            try:
                # Get the current observation
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    # Save the first observation to disk
                    save_to_disk=t_step == 0,
                )

                video.append(curr_obs[f"{args.external_camera}_image"])

                # Send websocket request to policy server if it's time to predict a new chunk
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    # We resize images on the robot laptop to minimize the amount of data sent to the policy server
                    # and improve latency.
                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(
                            curr_obs[f"{args.external_camera}_image"], 224, 224
                        ),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
                    # Ctrl+C will be handled after the server call is complete
                    with prevent_keyboard_interrupt():
                        # this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                    assert pred_action_chunk.shape == (10, 8)

                # Select current action to execute from chunk
                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize gripper action
                if action[-1].item() > 0.5:
                    # action[-1] = 1.0
                    action = np.concatenate([action[:-1], np.ones((1,))])
                else:
                    # action[-1] = 0.0
                    action = np.concatenate([action[:-1], np.zeros((1,))])

                # clip all dimensions of action to [-1, 1]
                action = np.clip(action, -1, 1)

                env.step(action)

                # Sleep to match DROID data collection frequency
                elapsed_time = time.time() - start_time
                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
            except KeyboardInterrupt:
                break

        video = np.stack(video)
        save_filename = "video_" + timestamp
        ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")

        success: str | float | None = None
        while not isinstance(success, float):
            success = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
            )
            if success == "y":
                success = 1.0
            elif success == "n":
                success = 0.0

            success = float(success) / 100
            if not (0 <= success <= 1):
                print(f"Success must be a number in [0, 100] but got: {success * 100}")

        df = df.append(
            {
                "success": success,
                "duration": t_step,
                "video_filename": save_filename,
            },
            ignore_index=True,
        )

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    df.to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")


def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
    image_observations = obs_dict["image"]
    left_image, right_image, wrist_image = None, None, None
    for key in image_observations:
        # Note the "left" below refers to the left camera in the stereo pair.
        # The model is only trained on left stereo cams, so we only feed those.
        if args.left_camera_id in key and "left" in key:
            left_image = image_observations[key]
        elif args.right_camera_id in key and "left" in key:
            right_image = image_observations[key]
        elif args.wrist_camera_id in key and "left" in key:
            wrist_image = image_observations[key]

    # Drop the alpha dimension
    left_image = left_image[..., :3]
    right_image = right_image[..., :3]
    wrist_image = wrist_image[..., :3]

    # Convert to RGB
    left_image = left_image[..., ::-1]
    right_image = right_image[..., ::-1]
    wrist_image = wrist_image[..., ::-1]

    # In addition to image observations, also capture the proprioceptive state
    robot_state = obs_dict["robot_state"]
    cartesian_position = np.array(robot_state["cartesian_position"])
    joint_position = np.array(robot_state["joint_positions"])
    gripper_position = np.array([robot_state["gripper_position"]])

    # Save the images to disk so that they can be viewed live while the robot is running
    # Create one combined image to make live viewing easy
    if save_to_disk:
        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
        combined_image = Image.fromarray(combined_image)
        combined_image.save("robot_camera_views.png")

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }


if __name__ == "__main__":
    args: Args = tyro.cli(Args)
    main(args)
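The rollout loop above executes a fixed number of actions from each predicted chunk before asking the server for a new one; at the 15 Hz DROID control rate, an open_loop_horizon of 8 corresponds to roughly half a second of execution per server round-trip. A minimal sketch of that chunking logic, with a dummy policy standing in for the websocket client, follows.

# Sketch of the open-loop chunking pattern used above. `dummy_policy_infer`
# is a placeholder for the websocket policy call.
import numpy as np

def dummy_policy_infer() -> np.ndarray:
    return np.zeros((10, 8))  # 10-step chunk: 7 joint velocities + 1 gripper position per step

open_loop_horizon = 8
actions_from_chunk_completed = 0
pred_action_chunk = None

for t_step in range(24):
    # Re-query only when the previous chunk has been (partially) consumed.
    if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= open_loop_horizon:
        actions_from_chunk_completed = 0
        pred_action_chunk = dummy_policy_infer()  # one server round-trip per chunk
    action = pred_action_chunk[actions_from_chunk_completed]
    actions_from_chunk_completed += 1
    # env.step(action) would go here on the real robot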
capvector-pi05/examples/inference.ipynb
ADDED
@@ -0,0 +1,137 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "\n",
    "import jax\n",
    "\n",
    "from openpi.models import model as _model\n",
    "from openpi.policies import droid_policy\n",
    "from openpi.policies import policy_config as _policy_config\n",
    "from openpi.shared import download\n",
    "from openpi.training import config as _config\n",
    "from openpi.training import data_loader as _data_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy inference\n",
    "\n",
    "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_fast_droid\")\n",
    "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
    "\n",
    "# Create a trained policy.\n",
    "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
    "\n",
    "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
    "example = droid_policy.make_droid_example()\n",
    "result = policy.infer(example)\n",
    "\n",
    "# Delete the policy to free up memory.\n",
    "del policy\n",
    "\n",
    "print(\"Actions shape:\", result[\"actions\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Working with a live model\n",
    "\n",
    "\n",
    "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_aloha_sim\")\n",
    "\n",
    "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
    "key = jax.random.key(0)\n",
    "\n",
    "# Create a model from the checkpoint.\n",
    "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
    "\n",
    "# We can create fake observations and actions to test the model.\n",
    "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
    "\n",
    "# Sample actions from the model.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the batch size to reduce memory usage.\n",
    "config = dataclasses.replace(config, batch_size=2)\n",
    "\n",
    "# Load a single batch of data. This is the same data that will be used during training.\n",
    "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
    "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
    "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
    "obs, act = next(iter(loader))\n",
    "\n",
    "# Sample actions from the model.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "\n",
    "# Delete the model to free up memory.\n",
    "del model\n",
    "\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
capvector-pi05/examples/libero/compose.yml
ADDED
@@ -0,0 +1,54 @@
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
  runtime:
    image: libero
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/libero/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data
      - /tmp/.X11-unix:/tmp/.X11-unix:ro
    environment:
      - CLIENT_ARGS
      - DISPLAY=$DISPLAY
      - MUJOCO_GL=${MUJOCO_GL:-egl}
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
capvector-pi05/examples/libero/convert_libero_data_to_lerobot.py
ADDED
@@ -0,0 +1,104 @@
"""
Minimal example script for converting a dataset to LeRobot format.

We use the Libero dataset (stored in RLDS) for this example, but it can be easily
modified for any other data you have saved in a custom format.

Usage:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data

If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub

Note: to run the script, you need to install tensorflow_datasets:
`uv pip install tensorflow tensorflow_datasets`

You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
Running this conversion script will take approximately 30 minutes.
"""

import shutil

from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds
import tyro

REPO_NAME = "your_hf_username/libero"  # Name of the output dataset, also used for the Hugging Face Hub
RAW_DATASET_NAMES = [
    "libero_10_no_noops",
    "libero_goal_no_noops",
    "libero_object_no_noops",
    "libero_spatial_no_noops",
]  # For simplicity we will combine multiple Libero datasets into one training dataset


def main(data_dir: str, *, push_to_hub: bool = False):
    # Clean up any existing dataset in the output directory
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)

    # Create LeRobot dataset, define features to store
    # OpenPi assumes that proprio is stored in `state` and actions in `action`
    # LeRobot assumes that dtype of image data is `image`
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=10,
        features={
            "image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "state": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["state"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Loop over raw Libero datasets and write episodes to the LeRobot dataset
    # You can modify this for your own data format
    for raw_dataset_name in RAW_DATASET_NAMES:
        raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
        for episode in raw_dataset:
            for step in episode["steps"].as_numpy_iterator():
                dataset.add_frame(
                    {
                        "image": step["observation"]["image"],
                        "wrist_image": step["observation"]["wrist_image"],
                        "state": step["observation"]["state"],
                        "actions": step["action"],
                        "task": step["language_instruction"].decode(),
                    }
                )
            dataset.save_episode()

    # Optionally push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            tags=["libero", "panda", "rlds"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


if __name__ == "__main__":
    tyro.cli(main)
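The create/add_frame/save_episode pattern above generalizes directly to custom (non-RLDS) data. The sketch below is a hedged illustration of that adaptation: `my_episodes` and its field names are invented placeholders, and only the LeRobotDataset calls mirror the script above.

# Hedged sketch: feeding a custom in-memory data source into the same
# LeRobotDataset created above. `my_episodes` is an invented placeholder.
import numpy as np

my_episodes = [
    {
        "frames": [
            {
                "image": np.zeros((256, 256, 3), dtype=np.uint8),
                "state": np.zeros(8, dtype=np.float32),
                "action": np.zeros(7, dtype=np.float32),
            }
            for _ in range(5)
        ],
        "task": "pick up the bowl",
    }
]

for episode in my_episodes:
    for frame in episode["frames"]:
        dataset.add_frame(  # `dataset` from LeRobotDataset.create(...) as in the script above
            {
                "image": frame["image"],
                "wrist_image": frame["image"],  # placeholder: reuse the same view
                "state": frame["state"],
                "actions": frame["action"],
                "task": episode["task"],
            }
        )
    dataset.save_episode()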
capvector-pi05/examples/policy_records.ipynb
ADDED
@@ -0,0 +1,134 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "record_path = pathlib.Path(\"../policy_records\")\n",
    "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
    "\n",
    "records = []\n",
    "for i in range(num_steps):\n",
    "    record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
    "    records.append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"length of records\", len(records))\n",
    "print(\"keys in records\", records[0].keys())\n",
    "\n",
    "for k in records[0]:\n",
    "    print(f\"{k} shape: {records[0][k].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "\n",
    "def get_image(step: int, idx: int = 0):\n",
    "    img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
    "    return img[idx].transpose(1, 2, 0)\n",
    "\n",
    "\n",
    "def show_image(step: int, idx_lst: list[int]):\n",
    "    imgs = [get_image(step, idx) for idx in idx_lst]\n",
    "    return Image.fromarray(np.hstack(imgs))\n",
    "\n",
    "\n",
    "for i in range(2):\n",
    "    display(show_image(i, [0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def get_axis(name, axis):\n",
    "    return np.array([record[name][axis] for record in records])\n",
    "\n",
    "\n",
    "# qpos is [..., 14] of type float:\n",
    "# 0-5: left arm joint angles\n",
    "# 6: left arm gripper\n",
    "# 7-12: right arm joint angles\n",
    "# 13: right arm gripper\n",
    "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
    "\n",
    "\n",
    "def make_data():\n",
    "    cur_dim = 0\n",
    "    in_data = {}\n",
    "    out_data = {}\n",
    "    for name, dim_size in names:\n",
    "        for i in range(dim_size):\n",
    "            in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
    "            out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
    "            cur_dim += 1\n",
    "    return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
    "\n",
    "\n",
    "in_data, out_data = make_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in in_data.columns:\n",
    "    data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
    "    data.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
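The notebook above plots commanded versus observed qpos per dimension. A hedged follow-up, assuming the same `records` list with "inputs/qpos" and "outputs/qpos" entries, is to summarize the gap numerically rather than visually:

# Hedged sketch (not part of the notebook): mean absolute input/output qpos gap
# per dimension, computed from the same `records` structure loaded above.
import numpy as np

in_qpos = np.stack([r["inputs/qpos"] for r in records])
out_qpos = np.stack([r["outputs/qpos"] for r in records])
mean_abs_err = np.abs(out_qpos - in_qpos).mean(axis=0)
for dim, err in enumerate(mean_abs_err.ravel()):
    print(f"dim {dim}: mean |out - in| = {err:.4f}")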
capvector-pi05/pyproject.toml
ADDED
@@ -0,0 +1,142 @@
[project]
name = "openpi"
version = "0.1.0"
description = "Physical Intelligence open source repo"
readme = "README.md"
requires-python = ">=3.11"
license = { file = "LICENSE" }
dependencies = [
    "augmax>=0.3.4",
    "dm-tree>=0.1.8",
    "einops>=0.8.0",
    "equinox>=0.11.8",
    "flatbuffers>=24.3.25",
    "flax==0.10.2",
    "fsspec[gcs]>=2024.6.0",
    "gym-aloha>=0.1.1",
    "imageio>=2.36.1",
    "jax[cuda12]==0.5.3",
    "jaxtyping==0.2.36",
    "lerobot",
    "ml_collections==1.0.0",
    "numpy>=1.22.4,<2.0.0",
    "numpydantic>=1.6.6",
    "opencv-python>=4.10.0.84",
    "openpi-client",
    "orbax-checkpoint==0.11.13",
    "pillow>=11.0.0",
    "sentencepiece>=0.2.0",
    "torch==2.7.1",
    "tqdm-loggable>=0.2",
    "typing-extensions>=4.12.2",
    "tyro>=0.9.5",
    "wandb>=0.19.1",
    "filelock>=3.16.1",
    "beartype==0.19.0",
    "treescope>=0.1.7",
    "transformers==4.53.2",
    "rich>=14.0.0",
    "polars>=1.30.0",
    "gradio==5.17.1",
    "viser==0.2.23",
    "hydra-core",
    "onnxruntime",
    "safetensors",
]


[project.urls]
Repository = "https://github.com/Physical-Intelligence/openpi"

[dependency-groups]
dev = [
    "pytest>=8.3.4",
    "ruff>=0.8.6",
    "pre-commit>=4.0.1",
    "ipykernel>=6.29.5",
    "ipywidgets>=8.1.5",
    "matplotlib>=3.10.0",
    "pynvml>=12.0.0",
]
rlds = [
    "dlimp",
    "tensorflow-cpu==2.15.0",
    "tensorflow-datasets==4.9.9",
]

[tool.uv]
override-dependencies = ["datasets==3.6.0", "ml-dtypes==0.4.1", "tensorstore==0.1.74"]

[tool.uv.sources]
openpi-client = { workspace = true }
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }

[tool.uv.workspace]
members = ["packages/*", "src/vggt"]

[tool.ruff]
line-length = 120
target-version = "py311"
extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]

[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
select = [
    "B",
    "C4",
    "DTZ",
    "E4",
    "E7",
    "E9",
    "F",
    "FBT",
    "FURB",
    "I",
    "ICN",
    "ISC",
    "LOG",
    "N",
    "PD",
    "PERF",
    "PIE",
    "PLC",
    "PLE",
    "PLR1",
    "PLR5",
    "PLW",
    "PT",
    "Q",
    "RET",
    "RUF",
    "SIM",
    "SLF",
    "T10",
    "T20",
    "UP",
    "W",
]
ignore = [
    "F722",   # Conflicts with array typing.
    "T201",   # We use print statements.
    "PD008",  # Lots of false positives.
    "ISC001", # Disabling to support ruff format.
    "LOG015", # Use logger.info.
]
unfixable = [
    "B905", # Fix defaults to strict=False, which is not what we want.
]

[tool.ruff.lint.isort]
force-single-line = true
force-sort-within-sections = true
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
known-third-party = ["wandb"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.pytest.ini_options]
markers = ["manual: should be run manually."]
testpaths = ["src", "scripts", "packages"]
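Since requires-python is >=3.11, the pins declared above can be inspected with the standard-library tomllib. The snippet below is a small sanity-check sketch, not part of the repo.

# Hedged sketch: list the hard-pinned dependencies and uv overrides declared in
# the pyproject.toml above (run from the repository root).
import tomllib

with open("pyproject.toml", "rb") as f:
    pyproject = tomllib.load(f)

for dep in pyproject["project"]["dependencies"]:
    if "==" in dep:
        print("pinned:", dep)
print("overrides:", pyproject["tool"]["uv"]["override-dependencies"])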