haofuly committed on
Commit 45ac12e · verified · Parent: cf587f4

Add files using upload-large-folder tool

Files changed (50)
  1. capvector-oft/scripts/extern/convert_prismatic_weights_to_hf.py +237 -0
  2. capvector-oft/training_scripts/training.sh +36 -0
  3. capvector-oft/vla-scripts/extern/convert_openvla_weights_to_hf.py +272 -0
  4. capvector-oft/vla-scripts/extern/verify_openvla.py +89 -0
  5. capvector-oft/vla-scripts/finetune.py +1152 -0
  6. capvector-oft/vla-scripts/finetune_regular_loss.py +1790 -0
  7. capvector-oft/vla-scripts/merge_lora_weights_and_save.py +73 -0
  8. capvector-pi05/.dockerignore +3 -0
  9. capvector-pi05/.gitignore +169 -0
  10. capvector-pi05/.gitmodules +6 -0
  11. capvector-pi05/.pre-commit-config.yaml +16 -0
  12. capvector-pi05/.python-version +1 -0
  13. capvector-pi05/LICENSE +201 -0
  14. capvector-pi05/README.md +128 -0
  15. capvector-pi05/capvector/apply_param_diff.py +135 -0
  16. capvector-pi05/capvector/compute_param_diff.py +142 -0
  17. capvector-pi05/docs/docker.md +25 -0
  18. capvector-pi05/docs/norm_stats.md +69 -0
  19. capvector-pi05/docs/remote_inference.md +71 -0
  20. capvector-pi05/examples/aloha_real/Dockerfile +70 -0
  21. capvector-pi05/examples/aloha_real/README.md +126 -0
  22. capvector-pi05/examples/aloha_real/compose.yml +66 -0
  23. capvector-pi05/examples/aloha_real/constants.py +71 -0
  24. capvector-pi05/examples/aloha_real/convert_aloha_data_to_lerobot.py +263 -0
  25. capvector-pi05/examples/aloha_real/env.py +57 -0
  26. capvector-pi05/examples/aloha_real/main.py +51 -0
  27. capvector-pi05/examples/aloha_real/real_env.py +176 -0
  28. capvector-pi05/examples/aloha_real/requirements.in +18 -0
  29. capvector-pi05/examples/aloha_real/requirements.txt +156 -0
  30. capvector-pi05/examples/aloha_real/robot_utils.py +275 -0
  31. capvector-pi05/examples/aloha_real/video_display.py +36 -0
  32. capvector-pi05/examples/aloha_sim/Dockerfile +41 -0
  33. capvector-pi05/examples/aloha_sim/README.md +36 -0
  34. capvector-pi05/examples/aloha_sim/compose.yml +42 -0
  35. capvector-pi05/examples/aloha_sim/env.py +56 -0
  36. capvector-pi05/examples/aloha_sim/main.py +55 -0
  37. capvector-pi05/examples/aloha_sim/requirements.in +8 -0
  38. capvector-pi05/examples/aloha_sim/requirements.txt +132 -0
  39. capvector-pi05/examples/aloha_sim/saver.py +40 -0
  40. capvector-pi05/examples/convert_jax_model_to_pytorch.py +587 -0
  41. capvector-pi05/examples/droid/README.md +84 -0
  42. capvector-pi05/examples/droid/README_train.md +106 -0
  43. capvector-pi05/examples/droid/compute_droid_nonidle_ranges.py +103 -0
  44. capvector-pi05/examples/droid/convert_droid_data_to_lerobot.py +477 -0
  45. capvector-pi05/examples/droid/main.py +246 -0
  46. capvector-pi05/examples/inference.ipynb +137 -0
  47. capvector-pi05/examples/libero/compose.yml +54 -0
  48. capvector-pi05/examples/libero/convert_libero_data_to_lerobot.py +104 -0
  49. capvector-pi05/examples/policy_records.ipynb +134 -0
  50. capvector-pi05/pyproject.toml +142 -0
capvector-oft/scripts/extern/convert_prismatic_weights_to_hf.py ADDED
@@ -0,0 +1,237 @@
+ """
+ convert_prismatic_weights_to_hf.py
+
+ Utility script for converting full Prismatic VLM weights (from this repository, in the default "Prismatic" format) to
+ the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers`
+ via `trust_remote_code = True`.
+
+ Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the
+ line, with first-class support.
+ """
+
+ import json
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, List, Union
+
+ import draccus
+ import timm
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import hf_hub_download
+ from timm.models.vision_transformer import LayerScale
+ from transformers import AutoTokenizer
+
+ from prismatic.extern.hf.configuration_prismatic import PrismaticConfig
+ from prismatic.extern.hf.modeling_prismatic import PrismaticForConditionalGeneration
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
+
+
+ @dataclass
+ class HFConvertConfig:
+     # fmt: off
+     prismatic_model_path_or_id: Union[str, Path] = (  # Path to Pretrained VLM (on disk or HF Hub)
+         "siglip-224px+7b"
+         # "prism-dinosiglip-224px+7b"
+     )
+     output_hf_model_local_path: Path = Path(          # Local path to save the converted HF model
+         "hf-convert/prismatic-siglip-224px-7b"
+     )
+     output_hf_model_hub_path: str = (                 # HF Hub path for the "final" HF model
+         "TRI-ML/prismatic-siglip-224px-7b"            # => huggingface.co/TRI-ML/prismatic-{...}
+     )
+
+     # HF Hub Credentials (required for Gated Models like LLaMa-2)
+     hf_token: Union[str, Path] = Path(".hf_token")    # Environment variable or Path to HF Token
+
+     def __post_init__(self) -> None:
+         self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token
+
+     # fmt: on
+
+
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
+ #   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
+ #   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
+     return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
+
+
+ def ls_apply_patch(ls_module: LayerScale):
+     ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
+     ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
+     del ls_module.gamma
+
+
+ # === Conversion Constants ===
+ PROJECTOR_KEY_MAPPING = {
+     "projector.0.weight": "projector.fc1.weight",
+     "projector.0.bias": "projector.fc1.bias",
+     "projector.2.weight": "projector.fc2.weight",
+     "projector.2.bias": "projector.fc2.bias",
+     "projector.4.weight": "projector.fc3.weight",
+     "projector.4.bias": "projector.fc3.bias",
+ }
+
+
+ def remap_state_dicts_for_hf(
+     projector_state_dict: Dict[str, torch.Tensor],
+     llm_backbone_state_dict: Dict[str, torch.Tensor],
+     vision_backbone_state_dicts: List[Dict[str, torch.Tensor]],
+ ) -> Dict[str, torch.Tensor]:
+     """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion."""
+     hf_state_dict = {}
+
+     # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING`
+     for key, value in projector_state_dict.items():
+         hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value
+
+     # Iterate through LLM Backbone =>> replace `llm.` with `language_model.`
+     for key, value in llm_backbone_state_dict.items():
+         hf_state_dict[key.replace("llm.", "language_model.")] = value
+
+     # Iterate through Vision Backbone =>> add "vision_backbone." prefix
+     assert len(vision_backbone_state_dicts) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
+     for idx, vision_backbone_state_dict in enumerate(vision_backbone_state_dicts):
+         prefix = "vision_backbone.featurizer" if idx == 0 else "vision_backbone.fused_featurizer"
+         for key, value in vision_backbone_state_dict.items():
+             hf_state_dict[f"{prefix}.{key}"] = value
+
+     return hf_state_dict
+
+
+ @draccus.wrap()
+ def convert_prismatic_weights_to_hf(cfg: HFConvertConfig) -> None:
+     print(f"[*] Converting Prismatic Model `{cfg.prismatic_model_path_or_id}` to HF Transformers Format")
+     torch.set_default_dtype(torch.bfloat16)
+
+     # Get `config.json` and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py`
+     if os.path.isdir(cfg.prismatic_model_path_or_id):
+         print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.prismatic_model_path_or_id))}`")
+         config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
+
+         assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
+         assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
+     else:
+         print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.prismatic_model_path_or_id}`")
+         config_json = hf_hub_download("TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/config.json")
+         checkpoint_pt = hf_hub_download(
+             "TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/checkpoints/latest-checkpoint.pt"
+         )
+
+     # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
+     with open(config_json, "r") as f:
+         prismatic_config = json.load(f)["model"]
+
+     # Create HF PrismaticConfig (`transformers.PretrainedConfig`)
+     hf_config = PrismaticConfig(
+         vision_backbone_id=prismatic_config["vision_backbone_id"],
+         llm_backbone_id=prismatic_config["llm_backbone_id"],
+         arch_specifier=prismatic_config["arch_specifier"],
+         image_resize_strategy=prismatic_config["image_resize_strategy"],
+         llm_max_length=prismatic_config["llm_max_length"],
+         torch_dtype=torch.bfloat16,
+     )
+
+     # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
+     #   TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
+     print("[*] Instantiating and Patching Tokenizer, LLM Config")
+     tokenizer = AutoTokenizer.from_pretrained(
+         hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
+     )
+     tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+     tokenizer.init_kwargs.pop("add_prefix_space", None)  # Pop to prevent unnecessary warning on reload...
+     assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
+     assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"
+
+     # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
+     hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
+     hf_config.text_config.pad_token_id = hf_config.pad_token_id
+     hf_config.text_config.torch_dtype = torch.bfloat16
+     assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"
+
+     # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform`
+     #   =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py`
+     print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor")
+     timm_vision_backbones, input_sizes, interpolations, means, stds = [], [], [], [], []
+     for idx, timm_model_id in enumerate(hf_config.timm_model_ids):
+         timm_vision_backbone = timm.create_model(
+             timm_model_id,
+             pretrained=True,
+             num_classes=0,
+             img_size=hf_config.image_sizes[idx],
+             act_layer=hf_config.timm_override_act_layers[idx],
+         )
+         timm_vision_backbones.append(timm_vision_backbone)
+
+         # Get Per-Backbone Image Processing
+         data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone)
+         input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx]))
+         interpolations.append(data_cfg["interpolation"])
+         means.append(data_cfg["mean"])
+         stds.append(data_cfg["std"])
+
+         # Patch `LayerScale` because of HF annoying `fix_key` overwrite...
+         for module in timm_vision_backbone.modules():
+             if isinstance(module, LayerScale):
+                 ls_apply_patch(module)
+
+     # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`)
+     hf_image_processor = PrismaticImageProcessor(
+         use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
+         image_resize_strategy=hf_config.image_resize_strategy,
+         input_sizes=input_sizes,
+         interpolations=interpolations,
+         means=means,
+         stds=stds,
+     )
+
+     # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor)
+     print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor")
+     hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer)
+
+     # Load Prismatic Model State Dictionary (in preparation for conversion)
+     print("[*] Loading Prismatic VLM State Dictionary from Checkpoint")
+     model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
+     assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?"
+     assert ("projector" in model_state_dict) and ("llm_backbone" in model_state_dict), "Missing keys!"
+
+     # Convert
+     print("[*] Running Conversion")
+     converted_state_dict = remap_state_dicts_for_hf(
+         model_state_dict["projector"],
+         model_state_dict["llm_backbone"],
+         vision_backbone_state_dicts=[vb.state_dict() for vb in timm_vision_backbones],
+     )
+
+     # Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM
+     print("[*] Building (Randomly Initialized) Model =>> PrismaticForConditionalGeneration")
+     hf_model = PrismaticForConditionalGeneration(hf_config)
+     hf_model.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+     # Cast Model to BF16 before Saving
+     hf_model.to(torch.bfloat16)
+
+     # Save Pretrained Versions to Local Path
+     print("[*] Saving Model & Processor to Local Path")
+     hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB")
+     hf_image_processor.save_pretrained(cfg.output_hf_model_local_path)
+     hf_processor.save_pretrained(cfg.output_hf_model_local_path)
+
+     # Register AutoClasses
+     PrismaticConfig.register_for_auto_class()
+     PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor")
+     PrismaticProcessor.register_for_auto_class("AutoProcessor")
+     PrismaticForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")
+
+     # Push to Hub
+     print("[*] Pushing Model & Processor to HF Hub")
+     hf_config.push_to_hub(cfg.output_hf_model_hub_path)
+     hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB")
+     hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
+     hf_processor.push_to_hub(cfg.output_hf_model_hub_path)
+
+
+ if __name__ == "__main__":
+     convert_prismatic_weights_to_hf()
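Once the AutoClasses are registered and the checkpoint is saved as above, the converted model is meant to load through the standard `transformers` entry points. A minimal sketch, assuming the default `output_hf_model_local_path` from this script (illustrative only, not part of the commit):

    import torch
    from transformers import AutoModelForVision2Seq, AutoProcessor

    local_path = "hf-convert/prismatic-siglip-224px-7b"  # default output path above
    processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
    vlm = AutoModelForVision2Seq.from_pretrained(local_path, torch_dtype=torch.bfloat16, trust_remote_code=True)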
capvector-oft/training_scripts/training.sh ADDED
@@ -0,0 +1,36 @@
+ VERSION="v0"
+ TASK="10"  # spatial / object / goal / 10 / 90
+ VLA_PATH="checkpoints/initialized_pt_vla/initailized_openvla_with_SF_spatial_v0.4.2"
+ DATA_ROOT_DIR="data/libero_openvla"
+ RUN_ROOT_DIR="experiments/training_results"
+ REGULARIZATION_LORA_VECTOR_PATH="checkpoints/lora_diff/sf_150000_steps_spatial_adapter_diff.safetensors"
+ WANDB_ENTITY="YOUR_WANDB_ENTITY"
+ WANDB_PROJECT="YOUR_WANDB_PROJECT"
+ EVAL_LOG_PATH="experiments/eval_logs/${VERSION}_output.log"
+
+ torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune_regular_loss.py \
+     --vla_path "$VLA_PATH" \
+     --data_root_dir "$DATA_ROOT_DIR" \
+     --dataset_name libero_${TASK}_no_noops \
+     --run_root_dir "$RUN_ROOT_DIR" \
+     --use_l1_regression True \
+     --use_diffusion False \
+     --use_film False \
+     --num_images_in_input 2 \
+     --use_proprio True \
+     --batch_size 8 \
+     --learning_rate 5e-4 \
+     --scheduler CosineAnnealingLR \
+     --max_steps 150100 \
+     --save_freq 150000 \
+     --save_latest_checkpoint_only True \
+     --merge_lora_during_training True \
+     --regularization_lora_vector_path "$REGULARIZATION_LORA_VECTOR_PATH" \
+     --regularization_weight 1e-4 \
+     --image_aug True \
+     --lora_rank 32 \
+     --wandb_entity "$WANDB_ENTITY" \
+     --wandb_project "$WANDB_PROJECT" \
+     --run_id_override "$VERSION"
+
+ python experiments/robot/libero/run_libero_eval.py --pretrained_checkpoint "$RUN_ROOT_DIR/$VERSION" --task_suite_name libero_${TASK} > "$EVAL_LOG_PATH" 2>&1
capvector-oft/vla-scripts/extern/convert_openvla_weights_to_hf.py ADDED
@@ -0,0 +1,272 @@
+ """
+ convert_openvla_weights_to_hf.py
+
+ Utility script for converting full OpenVLA VLA weights (from this repository, in the default "Prismatic" format) to
+ the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers`
+ via `trust_remote_code = True`.
+
+ Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the
+ line, with first-class support.
+
+ Usage:
+     python vla-scripts/extern/convert_openvla_weights_to_hf.py \
+         --openvla_model_path_or_id <PATH TO PRISMATIC TRAINING RUN DIR> \
+         --output_hf_model_local_path <OUTPUT DIR FOR CONVERTED CHECKPOINT>
+ """
+
+ import json
+ import os
+ import shutil
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Union
+
+ import draccus
+ import timm
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import hf_hub_download
+ from timm.models.vision_transformer import LayerScale
+ from transformers import AutoTokenizer
+
+ from prismatic.conf import ModelConfig
+ from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
+ from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
+
+
+ @dataclass
+ class HFConvertConfig:
+     # fmt: off
+     openvla_model_path_or_id: Union[str, Path] = (         # Path to Pretrained VLA (on disk or HF Hub)
+         "runs/prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7"
+     )
+     output_hf_model_local_path: Path = Path(               # Local path to save the converted HF model
+         "hf-convert/openvla-7b"
+     )
+     output_hf_model_hub_path: str = "openvla/openvla-7b"   # (Optional) Path to HF Hub Path to push
+                                                            #   model to
+
+     # HF Hub Credentials (required for Gated Models like LLaMa-2)
+     hf_token: Union[str, Path] = Path(".hf_token")         # Environment variable or Path to HF Token
+
+     def __post_init__(self) -> None:
+         self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token
+
+     # fmt: on
+
+
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
+ #   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
+ #   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
+     return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
+
+
+ def ls_apply_patch(ls_module: LayerScale):
+     ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
+     ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
+     del ls_module.gamma
+
+
+ # === Conversion Constants ===
+ PROJECTOR_KEY_MAPPING = {
+     "projector.0.weight": "projector.fc1.weight",
+     "projector.0.bias": "projector.fc1.bias",
+     "projector.2.weight": "projector.fc2.weight",
+     "projector.2.bias": "projector.fc2.bias",
+     "projector.4.weight": "projector.fc3.weight",
+     "projector.4.bias": "projector.fc3.bias",
+ }
+
+
+ def remap_state_dicts_for_hf(
+     prismatic_vision_backbone_state_dict: Dict[str, torch.Tensor],
+     projector_state_dict: Dict[str, torch.Tensor],
+     llm_backbone_state_dict: Dict[str, torch.Tensor],
+     use_fused_vision_backbone: bool = False,
+ ) -> Dict[str, torch.Tensor]:
+     """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion."""
+     hf_state_dict = {}
+
+     # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING`
+     for key, value in projector_state_dict.items():
+         hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value
+
+     # Iterate through LLM Backbone =>> replace `llm.` with `language_model.`
+     for key, value in llm_backbone_state_dict.items():
+         hf_state_dict[key.replace("llm.", "language_model.")] = value
+
+     # Iterate through Vision Backbone =>> add "vision_backbone." prefix
+     if not use_fused_vision_backbone:
+         for key, value in prismatic_vision_backbone_state_dict.items():
+             hf_state_dict[key.replace("featurizer.", "vision_backbone.featurizer.")] = value
+     else:
+         # Note =>> Assumes that backbones are always DINO + SigLIP...
+         for key, value in prismatic_vision_backbone_state_dict.items():
+             if key.startswith("dino_featurizer"):
+                 if key.endswith(".gamma"):
+                     # Handle `LayerScale gamma` =>> DINOv2 only!
+                     key = key.replace(".gamma", ".scale_factor")
+                 hf_state_dict[key.replace("dino_featurizer.", "vision_backbone.featurizer.")] = value
+             elif key.startswith("siglip_featurizer"):
+                 hf_state_dict[key.replace("siglip_featurizer.", "vision_backbone.fused_featurizer.")] = value
+
+     return hf_state_dict
+
+
+ @draccus.wrap()
+ def convert_openvla_weights_to_hf(cfg: HFConvertConfig) -> None:
+     print(f"[*] Converting OpenVLA Model `{cfg.openvla_model_path_or_id}` to HF Transformers Format")
+     torch.set_default_dtype(torch.bfloat16)
+
+     # Get `config.json`, `dataset_statistics.json` and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py`
+     if os.path.isdir(cfg.openvla_model_path_or_id):
+         print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.openvla_model_path_or_id))}`")
+         config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
+         dataset_statistics_json = run_dir / "dataset_statistics.json"
+
+         assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
+         assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
+         assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
+     else:
+         print(f"[*] Downloading OpenVLA Checkpoint from HF Hub :: `openvla/openvla-dev/{cfg.openvla_model_path_or_id}`")
+         config_json = hf_hub_download("openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/config.json")
+         checkpoint_pt = hf_hub_download(
+             "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/checkpoints/latest-checkpoint.pt"
+         )
+         dataset_statistics_json = hf_hub_download(
+             "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/dataset_statistics.json"
+         )
+
+     # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
+     with open(config_json, "r") as f:
+         vla_cfg = json.load(f)["vla"]
+         prismatic_config = ModelConfig.get_choice_class(vla_cfg["base_vlm"])().__dict__
+
+     # Load Normalization Statistics
+     with open(dataset_statistics_json, "r") as f:
+         norm_stats = json.load(f)
+
+     # Create HF OpenVLAConfig (`transformers.PretrainedConfig`)
+     hf_config = OpenVLAConfig(
+         vision_backbone_id=prismatic_config["vision_backbone_id"],
+         llm_backbone_id=prismatic_config["llm_backbone_id"],
+         arch_specifier=prismatic_config["arch_specifier"],
+         image_resize_strategy=prismatic_config["image_resize_strategy"],
+         llm_max_length=prismatic_config["llm_max_length"],
+         torch_dtype=torch.bfloat16,
+         norm_stats=norm_stats,
+     )
+
+     # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
+     #   TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
+     print("[*] Instantiating and Patching Tokenizer, LLM Config")
+     tokenizer = AutoTokenizer.from_pretrained(
+         hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
+     )
+     tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+     tokenizer.init_kwargs.pop("add_prefix_space", None)  # Pop to prevent unnecessary warning on reload...
+     assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
+     assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"
+
+     # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
+     hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
+     hf_config.text_config.pad_token_id = hf_config.pad_token_id
+     hf_config.text_config.torch_dtype = torch.bfloat16
+     assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"
+
+     # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform`
+     #   =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py`
+     print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor")
+     input_sizes, interpolations, means, stds = [], [], [], []
+     for idx, timm_model_id in enumerate(hf_config.timm_model_ids):
+         timm_vision_backbone = timm.create_model(
+             timm_model_id,
+             pretrained=True,
+             num_classes=0,
+             img_size=hf_config.image_sizes[idx],
+             act_layer=hf_config.timm_override_act_layers[idx],
+         )
+
+         # Get Per-Backbone Image Processing
+         data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone)
+         input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx]))
+         interpolations.append(data_cfg["interpolation"])
+         means.append(data_cfg["mean"])
+         stds.append(data_cfg["std"])
+
+         # Patch `LayerScale` because of HF annoying `fix_key` overwrite...
+         for module in timm_vision_backbone.modules():
+             if isinstance(module, LayerScale):
+                 ls_apply_patch(module)
+
+     # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`)
+     hf_image_processor = PrismaticImageProcessor(
+         use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
+         image_resize_strategy=hf_config.image_resize_strategy,
+         input_sizes=input_sizes,
+         interpolations=interpolations,
+         means=means,
+         stds=stds,
+     )
+
+     # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor)
+     print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor")
+     hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer)
+
+     # Load Prismatic Model State Dictionary (in preparation for conversion)
+     print("[*] Loading Prismatic VLM State Dictionary from Checkpoint")
+     model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
+     assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?"
+     assert all([k in model_state_dict for k in ["vision_backbone", "projector", "llm_backbone"]]), "Missing keys!"
+
+     # Convert
+     print("[*] Running Conversion")
+     converted_state_dict = remap_state_dicts_for_hf(
+         model_state_dict["vision_backbone"],
+         model_state_dict["projector"],
+         model_state_dict["llm_backbone"],
+         use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
+     )
+
+     # Create OpenVLAForActionPrediction =>> Note that we can't initialize on `meta` device because TIMM
+     print("[*] Building (Randomly Initialized) Model =>> OpenVLAForActionPrediction")
+     hf_model = OpenVLAForActionPrediction(hf_config)
+     hf_model.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+     # Cast Model to BF16 before Saving
+     hf_model.to(torch.bfloat16)
+
+     # Save Pretrained Versions to Local Path
+     print("[*] Saving Model & Processor to Local Path")
+     hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB")
+     hf_image_processor.save_pretrained(cfg.output_hf_model_local_path)
+     hf_processor.save_pretrained(cfg.output_hf_model_local_path)
+
+     # Copy `dataset_statistics.json` File to Converted Checkpoint Directory
+     output_dataset_statistics_json = cfg.output_hf_model_local_path / "dataset_statistics.json"
+     shutil.copyfile(dataset_statistics_json, output_dataset_statistics_json)
+
+     print(f"[*] Saving Complete! Saved converted checkpoint to: {cfg.output_hf_model_local_path}")
+
+     #####################################################################################
+     # Optional: Push Model to Hugging Face Hub
+     #####################################################################################
+
+     # # Register AutoClasses
+     # OpenVLAConfig.register_for_auto_class()
+     # PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor")
+     # PrismaticProcessor.register_for_auto_class("AutoProcessor")
+     # OpenVLAForActionPrediction.register_for_auto_class("AutoModelForVision2Seq")
+
+     # # Push to HF Hub
+     # print("[*] Pushing Model & Processor to HF Hub")
+     # hf_config.push_to_hub(cfg.output_hf_model_hub_path)
+     # hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB")
+     # hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
+     # hf_processor.push_to_hub(cfg.output_hf_model_hub_path)
+
+
+ if __name__ == "__main__":
+     convert_openvla_weights_to_hf()
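For the fused DINO + SigLIP case handled in `remap_state_dicts_for_hf` above, the remapping amounts to renames of the following shape (illustrative keys only; the actual block indices depend on the TIMM checkpoints, and the `.gamma` to `.scale_factor` rename applies only to the DINOv2 LayerScale parameters):

    # dino_featurizer.blocks.0.ls1.gamma          -> vision_backbone.featurizer.blocks.0.ls1.scale_factor
    # siglip_featurizer.blocks.0.attn.qkv.weight  -> vision_backbone.fused_featurizer.blocks.0.attn.qkv.weight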
capvector-oft/vla-scripts/extern/verify_openvla.py ADDED
@@ -0,0 +1,89 @@
+ """
+ verify_openvla.py
+
+ Given an HF-exported OpenVLA model, attempt to load via AutoClasses, and verify forward() and predict_action().
+ """
+
+ import time
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForVision2Seq, AutoProcessor
+
+ # === Verification Arguments
+ MODEL_PATH = "openvla/openvla-7b"
+ SYSTEM_PROMPT = (
+     "A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ )
+ INSTRUCTION = "put spoon on towel"
+
+
+ def get_openvla_prompt(instruction: str) -> str:
+     if "v01" in MODEL_PATH:
+         return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT:"
+     else:
+         return f"In: What action should the robot take to {instruction.lower()}?\nOut:"
+
+
+ @torch.inference_mode()
+ def verify_openvla() -> None:
+     print(f"[*] Verifying OpenVLAForActionPrediction using Model `{MODEL_PATH}`")
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+     # Load Processor & VLA
+     print("[*] Instantiating Processor and Pretrained OpenVLA")
+     processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
+
+     # === BFLOAT16 + FLASH-ATTN MODE ===
+     print("[*] Loading in BF16 with Flash-Attention Enabled")
+     vla = AutoModelForVision2Seq.from_pretrained(
+         MODEL_PATH,
+         attn_implementation="flash_attention_2",
+         torch_dtype=torch.bfloat16,
+         low_cpu_mem_usage=True,
+         trust_remote_code=True,
+     ).to(device)
+
+     # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] ===
+     # print("[*] Loading in 8-Bit Quantization Mode")
+     # vla = AutoModelForVision2Seq.from_pretrained(
+     #     MODEL_PATH,
+     #     attn_implementation="flash_attention_2",
+     #     torch_dtype=torch.float16,
+     #     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+     #     low_cpu_mem_usage=True,
+     #     trust_remote_code=True,
+     # )
+
+     # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] ===
+     # print("[*] Loading in 4-Bit Quantization Mode")
+     # vla = AutoModelForVision2Seq.from_pretrained(
+     #     MODEL_PATH,
+     #     attn_implementation="flash_attention_2",
+     #     torch_dtype=torch.float16,
+     #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+     #     low_cpu_mem_usage=True,
+     #     trust_remote_code=True,
+     # )
+
+     print("[*] Iterating with Randomly Generated Images")
+     for _ in range(100):
+         prompt = get_openvla_prompt(INSTRUCTION)
+         image = Image.fromarray(np.asarray(np.random.rand(256, 256, 3) * 255, dtype=np.uint8))
+
+         # === BFLOAT16 MODE ===
+         inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
+
+         # === 8-BIT/4-BIT QUANTIZATION MODE ===
+         # inputs = processor(prompt, image).to(device, dtype=torch.float16)
+
+         # Run OpenVLA Inference
+         start_time = time.time()
+         action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
+         print(f"\t=>> Time: {time.time() - start_time:.4f} || Action: {action}")
+
+
+ if __name__ == "__main__":
+     verify_openvla()
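Note that the commented-out 8-bit and 4-bit modes above reference `BitsAndBytesConfig`, which is not imported in the file as committed; enabling them would also require the corresponding import (shown here as a sketch, not part of the commit):

    from transformers import BitsAndBytesConfig  # needed only for the quantization modes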
capvector-oft/vla-scripts/finetune.py ADDED
@@ -0,0 +1,1152 @@
+ """
+ finetune.py
+
+ Fine-tunes OpenVLA via LoRA.
+ """
+
+ import os
+ import time
+ from collections import deque
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, Optional, Tuple, Type
+
+ import draccus
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import tqdm
+ from accelerate import PartialState
+ from huggingface_hub import HfApi, snapshot_download
+ from peft import LoraConfig, PeftModel, get_peft_model
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import MultiStepLR
+ from torch.utils.data import DataLoader
+ from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ import wandb
+ os.environ["WANDB_MODE"] = "offline"
+
+ from experiments.robot.openvla_utils import (
+     check_model_logic_mismatch,
+     model_is_on_hf_hub,
+     update_auto_map,
+ )
+
+ from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
+ from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
+ from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead
+ from prismatic.models.backbones.llm.prompting import PurePromptBuilder
+ from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone
+ from prismatic.models.projectors import (
+     NoisyActionProjector,
+     ProprioProjector,
+ )
+ from prismatic.training.train_utils import (
+     compute_actions_l1_loss,
+     compute_token_accuracy,
+     get_current_action_mask,
+     get_next_actions_mask,
+ )
+ from prismatic.util.data_utils import PaddedCollatorForActionPrediction
+ from prismatic.vla.action_tokenizer import ActionTokenizer
+ from prismatic.vla.constants import (
+     ACTION_DIM,
+     ACTION_PROPRIO_NORMALIZATION_TYPE,
+     NUM_ACTIONS_CHUNK,
+     PROPRIO_DIM,
+ )
+ from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
+ from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
+
+ # Sane Defaults
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ import debugpy
+ try:
+     debugpy.listen(("localhost", 9501))
+     print("Waiting for debugger attach")
+     debugpy.wait_for_client()
+ except Exception as e:
+     pass
+
+
+ @dataclass
+ class FinetuneConfig:
+     # fmt: off
+     vla_path: str = "openvla/openvla-7b"           # Path to OpenVLA model (on HuggingFace Hub or stored locally)
+
+     # Dataset
+     data_root_dir: Path = Path("datasets/rlds")    # Directory containing RLDS datasets
+     dataset_name: str = "aloha_scoop_x_into_bowl"  # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`)
+     run_root_dir: Path = Path("runs")              # Path to directory to store logs & checkpoints
+     shuffle_buffer_size: int = 100_000             # Dataloader shuffle buffer size (can reduce if OOM errors occur)
+
+     # Algorithm and architecture
+     use_l1_regression: bool = True                 # If True, trains continuous action head with L1 regression objective
+     use_diffusion: bool = False                    # If True, trains continuous action head with diffusion modeling objective (DDIM)
+     num_diffusion_steps_train: int = 50            # (When `diffusion==True`) Number of diffusion steps used for training
+     use_film: bool = False                         # If True, uses FiLM to infuse language inputs into visual features
+     num_images_in_input: int = 1                   # Number of images in the VLA input (default: 1)
+     use_proprio: bool = False                      # If True, includes robot proprioceptive state in input
+
+     # Training configuration
+     batch_size: int = 8                            # Batch size per device (total batch size = batch_size * num GPUs)
+     learning_rate: float = 5e-4                    # Learning rate
+     lr_warmup_steps: int = 0                       # Number of steps to warm up learning rate (from 10% to 100%)
+     num_steps_before_decay: int = 100_000          # Number of steps before LR decays by 10x
+     grad_accumulation_steps: int = 1               # Number of gradient accumulation steps
+     max_steps: int = 200_000                       # Max number of training steps
+     use_val_set: bool = False                      # If True, uses validation set and logs validation metrics
+     val_freq: int = 10_000                         # (When `use_val_set==True`) Validation set logging frequency in steps
+     val_time_limit: int = 180                      # (When `use_val_set==True`) Time limit for computing validation metrics
+     save_freq: int = 10_000                        # Checkpoint saving frequency in steps
+     save_latest_checkpoint_only: bool = False      # If True, saves only 1 checkpoint, overwriting latest checkpoint
+                                                    #   (If False, saves all checkpoints)
+     resume: bool = False                           # If True, resumes from checkpoint
+     resume_step: Optional[int] = None              # (When `resume==True`) Step number that we are resuming from
+     image_aug: bool = True                         # If True, trains with image augmentations (HIGHLY RECOMMENDED)
+     diffusion_sample_freq: int = 50                # (When `use_diffusion==True`) Frequency for sampling in steps
+
+     # LoRA
+     use_lora: bool = True                          # If True, uses LoRA fine-tuning
+     lora_rank: int = 32                            # Rank of LoRA weight matrix
+     lora_dropout: float = 0.0                      # Dropout applied to LoRA weights
+     merge_lora_during_training: bool = True        # If True, merges LoRA weights and saves result during training
+                                                    #   Note: Merging can be very slow on some machines. If so, set to
+                                                    #   False and merge final checkpoint offline!
+
+     # Logging
+     wandb_entity: str = "your-wandb-entity"        # Name of WandB entity
+     wandb_project: str = "your-wandb-project"      # Name of WandB project
+     run_id_note: Optional[str] = None              # Extra note to add to end of run ID for logging
+     run_id_override: Optional[str] = None          # Optional string to override the run ID with
+     wandb_log_freq: int = 10                       # WandB logging frequency in steps
+
+     # fmt: on
+
+
+ def remove_ddp_in_checkpoint(state_dict) -> dict:
+     """
+     Removes the 'module.' prefix from parameter names in a PyTorch model state dictionary that was saved using
+     DistributedDataParallel (DDP).
+
+     When a model is trained using PyTorch's DistributedDataParallel, the saved state dictionary contains parameters
+     prefixed with 'module.'. This function removes these prefixes to make the state dictionary compatible when
+     loading into models that are not yet wrapped in DDP.
+
+     Args:
+         state_dict (dict): PyTorch model state dictionary.
+
+     Returns:
+         dict: A new state dictionary with the same contents but with 'module.' prefixes removed from parameter names.
+             Parameters without the 'module.' prefix remain unchanged.
+     """
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         if k[:7] == "module.":
+             new_state_dict[k[7:]] = v
+         else:
+             new_state_dict[k] = v
+     return new_state_dict
+
+
+ def get_run_id(cfg) -> str:
+     """
+     Generates or retrieves an identifier string for an experiment run.
+
+     Args:
+         cfg (FinetuneConfig): Training configuration.
+
+     Returns:
+         str: Experiment run ID.
+     """
+     if cfg.run_id_override is not None:
+         # Override the run ID with the user-provided ID
+         run_id = cfg.run_id_override
+     elif cfg.resume:
+         # Override run ID with the previous resumed run's ID
+         run_id = cfg.vla_path.split("/")[-1]
+         # Remove the "--XXX_chkpt" suffix from the run ID if it exists
+         if "chkpt" in run_id.split("--")[-1]:
+             run_id = "--".join(run_id.split("--")[:-1])
+     else:
+         run_id = (
+             f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}"
+             f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
+             f"+lr-{cfg.learning_rate}"
+         )
+         if cfg.use_lora:
+             run_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
+         if cfg.image_aug:
+             run_id += "--image_aug"
+         if cfg.run_id_note is not None:
+             run_id += f"--{cfg.run_id_note}"
+     return run_id
+
+
+ def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict:
+     """
+     Loads a checkpoint for a given module.
+
+     Args:
+         module_name (str): Name of model component to load checkpoint for.
+         path (str): Path to checkpoint directory.
+         step (int): Gradient step number of saved checkpoint.
+         device (str): String specifying how to remap storage locations (default = "cpu").
+
+     Returns:
+         dict: PyTorch model state dictionary.
+     """
+     checkpoint_path = os.path.join(path, f"{module_name}--{step}_checkpoint.pt")
+     print(f"Loading checkpoint: {checkpoint_path}")
+     state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
+     return remove_ddp_in_checkpoint(state_dict)
+
+
+ def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP:
+     """
+     Wrap a module with DistributedDataParallel.
+
+     Args:
+         module (nn.Module): PyTorch module.
+         device_id (int): Device ID.
+         find_unused (bool): Whether to detect parameters without gradients in distributed training.
+
+     Returns:
+         DistributedDataParallel: PyTorch module wrapped with DDP.
+     """
+     return DDP(module, device_ids=[device_id], find_unused_parameters=find_unused, gradient_as_bucket_view=True)
+
+
+ def count_parameters(module: nn.Module, name: str) -> None:
+     """
+     Counts and prints the number of trainable parameters in a module.
+
+     Args:
+         module (nn.Module): PyTorch module.
+         name (str): Name of model component.
+
+     Returns:
+         None.
+     """
+     num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
+     print(f"# trainable params in {name}: {num_params}")
+
+
+ def init_module(
+     module_class: Type[nn.Module],
+     module_name: str,
+     cfg: FinetuneConfig,
+     device_id: int,
+     module_args: dict,
+     to_bf16: bool = False,
+     find_unused_params: bool = False,
+ ) -> DDP:
+     """
+     Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP.
+
+     Args:
+         module_class (Type[nn.Module]): Class of PyTorch module to initialize.
+         module_name (str): Name of model component to load checkpoint for.
+         cfg (FinetuneConfig): Training configuration.
+         device_id (int): Device ID.
+         module_args (dict): Args for initializing the module.
+         to_bf16 (bool): Whether to convert to torch.bfloat16 data type.
+         find_unused_params (bool): Whether to detect parameters without gradients in distributed training.
+
+     Returns:
+         DistributedDataParallel: PyTorch module wrapped with DDP.
+     """
+     module = module_class(**module_args)
+     count_parameters(module, module_name)
+
+     if cfg.resume:
+         state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step)
+         module.load_state_dict(state_dict)
+
+     if to_bf16:
+         module = module.to(torch.bfloat16)
+     module = module.to(device_id)
+
+     return wrap_ddp(module, device_id, find_unused_params)
+
+
+ def run_forward_pass(
+     vla,
+     action_head,
+     noisy_action_projector,
+     proprio_projector,
+     batch,
+     action_tokenizer,
+     device_id,
+     use_l1_regression,
+     use_diffusion,
+     use_proprio,
+     use_film,
+     num_patches,
+     compute_diffusion_l1=False,
+     num_diffusion_steps_train=None,
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
+     """
+     Compute model forward pass and metrics for both training and validation.
+
+     Args:
+         vla (OpenVLAForActionPrediction): Vision-language-action policy.
+         action_head (nn.Module): Action head module.
+         noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
+         proprio_projector (nn.Module): Proprioceptive state projector module.
+         batch (dict): Input batch.
+         action_tokenizer (ActionTokenizer): Action tokenizer.
+         device_id (str): Device ID.
+         use_l1_regression (bool): Whether to use L1 regression.
+         use_diffusion (bool): Whether to use diffusion.
+         use_proprio (bool): Whether to use proprioceptive state as input.
+         use_film (bool): Whether to use FiLM for better language following.
+         num_patches (int): Number of vision patches.
+         compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every
+             diffusion_sample_freq steps during training; do it every batch for validation).
+         num_diffusion_steps_train (int): Number of diffusion steps for training (only used for diffusion).
+
+     Returns:
+         tuple: (loss, metrics_dict)
+             loss: The loss tensor with gradient for backpropagation.
+             metrics_dict: Dictionary of computed metrics (detached values for logging).
+     """
+     metrics = {}
+
+     # Get ground-truth action labels
+     ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16)
+
+     # [Only for diffusion] Sample noisy actions used as input for noise predictor network
+     if use_diffusion:
+         noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions)
+         noise, noisy_actions, diffusion_timestep_embeddings = (
+             noisy_dict["noise"],
+             noisy_dict["noisy_actions"],
+             noisy_dict["diffusion_timestep_embeddings"],
+         )
+     else:
+         noise, noisy_actions, diffusion_timestep_embeddings = None, None, None
+
+     # VLA forward pass
+     with torch.autocast("cuda", dtype=torch.bfloat16):
+         output: CausalLMOutputWithPast = vla(
+             input_ids=batch["input_ids"].to(device_id),
+             attention_mask=batch["attention_mask"].to(device_id),
+             pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
+             labels=batch["labels"],
+             output_hidden_states=True,
+             proprio=batch["proprio"] if use_proprio else None,
+             proprio_projector=proprio_projector if use_proprio else None,
+             noisy_actions=noisy_actions if use_diffusion else None,
+             noisy_action_projector=noisy_action_projector if use_diffusion else None,
+             diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None,
+             use_film=use_film,
+         )
+
+     # Get action masks needed for logging
+     ground_truth_token_ids = batch["labels"][:, 1:].to(device_id)
+     current_action_mask = get_current_action_mask(ground_truth_token_ids)
+     next_actions_mask = get_next_actions_mask(ground_truth_token_ids)
+
+     # Compute metrics for discrete action representation (next-token prediction)
+     if not (use_l1_regression or use_diffusion):
+         loss = output.loss
+         predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2)
+         curr_action_accuracy = compute_token_accuracy(
+             predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
+         )
+         curr_action_l1_loss = compute_actions_l1_loss(
+             action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
+         )
+         next_actions_accuracy = compute_token_accuracy(
+             predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
+         )
+         next_actions_l1_loss = compute_actions_l1_loss(
+             action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
+         )
+         metrics.update(
+             {
+                 "loss_value": loss.item(),  # Detached value for logging
+                 "curr_action_accuracy": curr_action_accuracy.item(),
+                 "curr_action_l1_loss": curr_action_l1_loss.item(),
+                 "next_actions_accuracy": next_actions_accuracy.item(),
+                 "next_actions_l1_loss": next_actions_l1_loss.item(),
+             }
+         )
+     # Compute metrics for continuous action representations (L1 regression | diffusion)
+     else:
+         # Get last layer hidden states
+         last_hidden_states = output.hidden_states[-1]  # (B, seq_len, D)
+         # Get hidden states for text portion of prompt+response (after the vision patches)
+         text_hidden_states = last_hidden_states[:, num_patches:-1]
+         # Get hidden states for action portion of response
+         batch_size = batch["input_ids"].shape[0]
+         actions_hidden_states = (
+             text_hidden_states[current_action_mask | next_actions_mask]
+             .reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1)
+             .to(torch.bfloat16)
+         )  # (B, act_chunk_len, D)
+
+         if use_l1_regression:
+             # Predict action
+             predicted_actions = action_head.module.predict_action(actions_hidden_states)
+             # Get full L1 loss
+             loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions)
+
+         if use_diffusion:
+             # Predict noise
+             noise_pred = action_head.module.predict_noise(actions_hidden_states)
+             # Get diffusion noise prediction MSE loss
+             noise_pred = noise_pred.reshape(noise.shape)
+             loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean")
+
+             # Only sample actions and compute L1 losses if specified
+             if compute_diffusion_l1:
+                 with torch.no_grad():
+                     predicted_actions = run_diffusion_sampling(
+                         vla=vla,
+                         action_head=action_head,
+                         noisy_action_projector=noisy_action_projector,
+                         proprio_projector=proprio_projector,
+                         batch=batch,
+                         batch_size=batch_size,
+                         num_patches=num_patches,
+                         actions_shape=ground_truth_actions.shape,
+                         device_id=device_id,
+                         current_action_mask=current_action_mask,
+                         next_actions_mask=next_actions_mask,
+                         use_proprio=use_proprio,
+                         use_film=use_film,
+                     )
+
+         metrics.update(
+             {
+                 "loss_value": loss.item(),  # Detached value for logging
+             }
+         )
+
+         # Get detailed L1 losses for logging
+         should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1)
+         if should_log_l1_loss:
+             ground_truth_curr_action = ground_truth_actions[:, 0]
+             predicted_curr_action = predicted_actions[:, 0]
+             ground_truth_next_actions = ground_truth_actions[:, 1:]
+             predicted_next_actions = predicted_actions[:, 1:]
+             curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action)
+             next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions)
+             metrics.update(
+                 {
+                     "curr_action_l1_loss": curr_action_l1_loss.item(),
+                     "next_actions_l1_loss": next_actions_l1_loss.item(),
+                 }
+             )
+
+     # Return both the loss tensor (with gradients) and the metrics dictionary (with detached values)
+     return loss, metrics
+
+
+ def run_diffusion_sampling(
+     vla,
+     action_head,
+     noisy_action_projector,
+     proprio_projector,
+     batch,
+     batch_size,
+     num_patches,
+     actions_shape,
+     device_id,
+     current_action_mask,
+     next_actions_mask,
+     use_proprio,
+     use_film,
+ ) -> torch.Tensor:
+     """
+     Run diffusion sampling (reverse diffusion) to generate actions.
+
+     Args:
+         vla (OpenVLAForActionPrediction): Vision-language-action policy.
+         action_head (nn.Module): Action head module.
+         noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
+         proprio_projector (nn.Module): Proprioceptive state projector module.
+         batch (dict): Input batch.
+         batch_size (int): Batch size.
+         num_patches (int): Number of vision patches.
+         actions_shape (tuple): Shape of ground-truth actions.
+         device_id (str): Device ID.
+         current_action_mask (torch.Tensor): Mask for current action.
+         next_actions_mask (torch.Tensor): Mask for next actions.
+         use_proprio (bool): Whether to use proprioceptive state as input.
+         use_film (bool): Whether to use FiLM for better language following.
+
+     Returns:
+         torch.Tensor: Predicted actions.
+     """
+     # Sample random noisy action, used as the starting point for reverse diffusion
+     noise = torch.randn(
+         size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM),
+         device=device_id,
+         dtype=torch.bfloat16,
+     )  # (B, chunk_len, action_dim)
+
+     # Set diffusion timestep values
+     action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train)
+
+     # Reverse diffusion: Iteratively denoise to generate action, conditioned on observation
+     curr_noisy_actions = noise
+     for t in action_head.module.noise_scheduler.timesteps:
+         # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action embedding,
+         # and diffusion timestep embedding)
+         timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id)
+         diffusion_timestep_embeddings = (
+             action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
+         )  # (B, llm_dim)
+         diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)  # (B, 1, llm_dim)
+
+         with torch.autocast("cuda", dtype=torch.bfloat16):
+             output = vla(
+                 input_ids=batch["input_ids"].to(device_id),
+                 attention_mask=batch["attention_mask"].to(device_id),
+                 pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
+                 labels=batch["labels"],
+                 output_hidden_states=True,
+                 proprio=batch["proprio"] if use_proprio else None,
+                 proprio_projector=proprio_projector if use_proprio else None,
+                 noisy_actions=curr_noisy_actions,
+                 noisy_action_projector=noisy_action_projector,
+                 diffusion_timestep_embeddings=diffusion_timestep_embeddings,
+                 use_film=use_film,
+             )
+             # Get last layer hidden states
+             last_hidden_states = output.hidden_states[-1]  # (B, seq_len, D)
+             # Get hidden states for text portion of prompt+response (after the vision patches)
+             text_hidden_states = last_hidden_states[:, num_patches:-1]
+             # Get hidden states for action portion of response
+             actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape(
+                 batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1
+             )  # (B, act_chunk_len, D)
+             actions_hidden_states = actions_hidden_states.to(torch.bfloat16)
+             # Predict noise
+             noise_pred = action_head.module.predict_noise(actions_hidden_states)
+
+         # Compute the action at the previous diffusion timestep: x_t -> x_{t-1}
+         curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
+
+     return curr_noisy_actions.reshape(actions_shape)
+
+
+ def compute_smoothened_metrics(metrics_deques) -> dict:
+     """
+     Compute smoothened metrics from recent deques.
+
+     Args:
+         metrics_deques (dict): Dictionary of deques containing recent metrics.
+
+     Returns:
+         dict: Dictionary of smoothened metrics.
+     """
+     smoothened_metrics = {}
+     for name, deque in metrics_deques.items():
+ if deque and len(deque) > 0:
556
+ smoothened_metrics[name] = sum(deque) / len(deque)
557
+ return smoothened_metrics
558
+
559
+
560
+ def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None:
561
+ """
562
+ Log metrics to Weights & Biases.
563
+
564
+ Args:
565
+ metrics (dict): Dictionary of metrics to log
566
+ prefix (str): Prefix for metric names
567
+ step (int): Training step
568
+ wandb_entity (str): W&B entity instance
569
+
570
+ Returns:
571
+ None.
572
+ """
573
+ log_dict = {}
574
+ for name, value in metrics.items():
575
+ # Map loss_value to Loss for better readability in W&B
576
+ if name == "loss_value":
577
+ log_dict[f"{prefix}/Loss"] = value
578
+ # Keep other metrics as is
579
+ else:
580
+ log_dict[f"{prefix}/{name.replace('_', ' ').title()}"] = value
581
+ wandb_entity.log(log_dict, step=step)
582
+
583
+
584
+ def save_training_checkpoint(
585
+ cfg,
586
+ run_dir,
587
+ log_step,
588
+ vla,
589
+ processor,
590
+ proprio_projector,
591
+ noisy_action_projector,
592
+ action_head,
593
+ train_dataset,
594
+ distributed_state,
595
+ ) -> None:
596
+ """
597
+ Save all training checkpoints including model components, LoRA adapter, and dataset statistics.
598
+
599
+ Args:
600
+ cfg (FinetuneConfig): Training configuration.
601
+ run_dir (Path): Experiment run directory path.
602
+ log_step (int): Current logging step.
603
+ vla (OpenVLAForActionPrediction): Vision-language-action policy.
604
+ processor (PrismaticProcessor): OpenVLA inputs processor.
605
+ proprio_projector (nn.Module): Proprioceptive state projector module.
606
+ noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
607
+ action_head (nn.Module): Action head module.
608
+ train_dataset (RLDSDataset): Training dataset.
609
+ distributed_state (PartialState): Distributed training state.
610
+
611
+ Returns:
612
+ None.
613
+ """
614
+ # Determine checkpoint paths and naming
615
+ if cfg.save_latest_checkpoint_only:
616
+ checkpoint_dir = run_dir
617
+ checkpoint_name_suffix = "latest_checkpoint.pt"
618
+ else:
619
+ checkpoint_dir = Path(str(run_dir) + f"--{log_step}_chkpt")
620
+ checkpoint_name_suffix = f"{log_step}_checkpoint.pt"
621
+
622
+ adapter_dir = checkpoint_dir / "lora_adapter"
623
+
624
+ # Create directories and save dataset statistics (main process only)
625
+ if distributed_state.is_main_process:
626
+ os.makedirs(checkpoint_dir, exist_ok=True)
627
+ os.makedirs(adapter_dir, exist_ok=True)
628
+ save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir)
629
+ print(f"Saving Model Checkpoint for Step {log_step}")
630
+
631
+ # Wait for directories to be created
632
+ dist.barrier()
633
+
634
+ # Save model components (main process only)
635
+ if distributed_state.is_main_process:
636
+ # Save processor and LoRA adapter
637
+ processor.save_pretrained(checkpoint_dir)
638
+ vla.module.save_pretrained(adapter_dir)
639
+
640
+ # Save other components
641
+ if cfg.use_proprio and proprio_projector is not None:
642
+ torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}")
643
+
644
+ if cfg.use_diffusion and noisy_action_projector is not None:
645
+ torch.save(
646
+ noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}"
647
+ )
648
+
649
+ if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None:
650
+ torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}")
651
+
652
+ if cfg.use_film:
653
+ # To be safe, just save the entire vision backbone (not just FiLM components)
654
+ torch.save(
655
+ vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}"
656
+ )
657
+
658
+ # Wait for model components to be saved
659
+ dist.barrier()
660
+
661
+ # Merge LoRA weights into base model and save resulting model checkpoint
662
+ # Note: Can be very slow on some devices; if so, we recommend merging offline
663
+ if cfg.use_lora and cfg.merge_lora_during_training:
664
+ base_vla = AutoModelForVision2Seq.from_pretrained(
665
+ cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
666
+ )
667
+ merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir)
668
+ merged_vla = merged_vla.merge_and_unload()
669
+
670
+ if distributed_state.is_main_process:
671
+ merged_vla.save_pretrained(checkpoint_dir)
672
+ print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}")
673
+
674
+ # Wait for merged model to be saved
675
+ dist.barrier()
676
+
677
+
678
+ def run_validation(
679
+ vla,
680
+ action_head,
681
+ noisy_action_projector,
682
+ proprio_projector,
683
+ val_dataloader,
684
+ action_tokenizer,
685
+ device_id,
686
+ cfg,
687
+ num_patches,
688
+ log_step,
689
+ distributed_state,
690
+ val_time_limit,
691
+ ) -> None:
692
+ """
693
+ Compute validation set metrics for logging.
694
+
695
+ Args:
696
+ vla (OpenVLAForActionPrediction): Vision-language-action policy.
697
+ action_head (nn.Module): Action head module.
698
+ noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
699
+ proprio_projector (nn.Module): Proprioceptive state projector module.
700
+ val_dataloader (DataLoader): Validation data loader.
701
+ action_tokenizer (ActionTokenizer): Action tokenizer.
702
+ device_id (str): Device ID.
703
+ cfg (FinetuneConfig): Training configuration.
704
+ num_patches (int): Number of vision patches.
705
+ log_step (int): Current logging step.
706
+ distributed_state (PartialState): Distributed training state.
707
+ val_time_limit (int): Time limit for computing validation metrics.
708
+
709
+ Returns:
710
+ None.
711
+ """
712
+ val_start_time = time.time()
713
+ vla.eval()
714
+ val_batches_count = 0
715
+
716
+ # List to store validation metrics
717
+ all_val_metrics = []
718
+
719
+ with torch.no_grad():
720
+ for batch in val_dataloader:
721
+ # Always compute L1 loss for validation, even for diffusion
722
+ _, metrics = run_forward_pass(
723
+ vla=vla,
724
+ action_head=action_head,
725
+ noisy_action_projector=noisy_action_projector,
726
+ proprio_projector=proprio_projector,
727
+ batch=batch,
728
+ action_tokenizer=action_tokenizer,
729
+ device_id=device_id,
730
+ use_l1_regression=cfg.use_l1_regression,
731
+ use_diffusion=cfg.use_diffusion,
732
+ use_proprio=cfg.use_proprio,
733
+ use_film=cfg.use_film,
734
+ num_patches=num_patches,
735
+ compute_diffusion_l1=True,
736
+ num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
737
+ )
738
+
739
+ # Add the loss value to the metrics
740
+ metrics["loss"] = metrics["loss_value"]
741
+ all_val_metrics.append(metrics)
742
+ val_batches_count += 1
743
+
744
+ # Cut testing on validation set short if it exceeds time limit
745
+ if time.time() - val_start_time > val_time_limit:
746
+ break
747
+
748
+ # Compute average validation metrics
749
+ avg_val_metrics = {}
750
+ for metric_name in all_val_metrics[0].keys():
751
+ values = [metrics[metric_name] for metrics in all_val_metrics if metric_name in metrics]
752
+ if values:
753
+ avg_val_metrics[metric_name] = sum(values) / len(values)
754
+
755
+ # Add batch count to metrics
756
+ avg_val_metrics["val_batches_count"] = val_batches_count
757
+
758
+ # Log validation metrics to W&B
759
+ if distributed_state.is_main_process:
760
+ log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb)
761
+
762
+
763
+ @draccus.wrap()
764
+ def finetune(cfg: FinetuneConfig) -> None:
765
+ """
766
+ Fine-tunes base VLA on demonstration dataset via LoRA.
767
+
768
+ Allows toggling different action representations (discrete vs. continuous), different learning objectives
769
+ (next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs,
770
+ such as additional camera images and robot proprioceptive state. Assumes parallel action generation with
771
+ action chunking.
772
+
773
+ Args:
774
+ cfg (FinetuneConfig): Training configuration.
775
+
776
+ Returns:
777
+ None.
778
+ """
779
+ assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!"
780
+ assert not (cfg.use_l1_regression and cfg.use_diffusion), (
781
+ "Cannot do both L1 regression and diffusion. Please pick one of them!"
782
+ )
783
+
784
+ # Trim trailing forward slash ('/') in VLA path if it exists
785
+ cfg.vla_path = cfg.vla_path.rstrip("/")
786
+ print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")
787
+
788
+ # Get experiment run ID
789
+ run_id = get_run_id(cfg)
790
+
791
+ # Create experiment run directory
792
+ run_dir = cfg.run_root_dir / run_id
793
+ os.makedirs(run_dir, exist_ok=True)
794
+
795
+ # GPU setup
796
+ distributed_state = PartialState()
797
+ device_id = distributed_state.local_process_index
798
+ torch.cuda.set_device(device_id)
799
+ torch.cuda.empty_cache()
800
+
801
+ # Initialize wandb logging
802
+ if distributed_state.is_main_process:
803
+ wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id)
804
+
805
+ # Print detected constants
806
+ print(
807
+ "Detected constants:\n"
808
+ f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n"
809
+ f"\tACTION_DIM: {ACTION_DIM}\n"
810
+ f"\tPROPRIO_DIM: {PROPRIO_DIM}\n"
811
+ f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}"
812
+ )
813
+
814
+ # Two options:
815
+ # (1) Base model is on Hugging Face Hub
816
+ # - Then download it and record the path to the download directory
817
+ # (2) Base model is stored locally
818
+ # - Then register model config in HF Auto Classes
819
+ # In both cases, we want to check whether any changes have been made to
820
+ # the `modeling_prismatic.py` file in this codebase; if so, we will copy
821
+ # the file to the downloaded or locally stored checkpoint directory so
822
+ # that the user's changes to the VLA class logic go into effect
823
+ if model_is_on_hf_hub(cfg.vla_path):
824
+ # Download model directly from Hugging Face Hub
825
+ vla_download_path = snapshot_download(repo_id=cfg.vla_path)
826
+ # Overwrite VLA path
827
+ cfg.vla_path = vla_download_path
828
+ else:
829
+ # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
830
+ AutoConfig.register("openvla", OpenVLAConfig)
831
+ AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
832
+ AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
833
+ AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
834
+
835
+ # Update config.json and sync model files
836
+ if distributed_state.is_main_process:
837
+ update_auto_map(cfg.vla_path)
838
+ check_model_logic_mismatch(cfg.vla_path)
839
+
840
+ # Wait for model files to be synced
841
+ dist.barrier()
842
+
843
+ # Load processor and VLA
844
+ processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
845
+ vla = AutoModelForVision2Seq.from_pretrained(
846
+ cfg.vla_path,
847
+ torch_dtype=torch.bfloat16,
848
+ low_cpu_mem_usage=True,
849
+ trust_remote_code=True,
850
+ ).to(device_id)
851
+
852
+ # Set number of images in VLA input
853
+ vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
854
+
855
+ # LoRA setup
856
+ if cfg.use_lora:
857
+ lora_config = LoraConfig(
858
+ r=cfg.lora_rank,
859
+ lora_alpha=min(cfg.lora_rank, 16),
860
+ lora_dropout=cfg.lora_dropout,
861
+ target_modules="all-linear",
862
+ init_lora_weights="gaussian",
863
+ )
864
+ vla = get_peft_model(vla, lora_config)
865
+ vla.print_trainable_parameters()
866
+
867
+ # FiLM setup
868
+ if cfg.use_film:
869
+ count_parameters(vla.vision_backbone, "vla.vision_backbone (original)")
870
+ # Wrap vision backbone with FiLM wrapper
871
+ # Important: For this, must specify `vla.model.vision_backbone` instead of just `vla.vision_backbone`, since the
872
+ # latter would cause the new wrapped backbone to be saved as a new attribute of `vla` instead of overwriting the
873
+ # original one (due to the LoRA wrapper)
874
+ vla.model.vision_backbone = FiLMedPrismaticVisionBackbone(
875
+ vision_backbone=vla.model.vision_backbone,
876
+ llm_dim=vla.llm_dim,
877
+ )
878
+ count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)")
879
+ if cfg.resume:
880
+ state_dict = load_checkpoint("vision_backbone", cfg.vla_path, cfg.resume_step)
881
+ vla.model.vision_backbone.load_state_dict(state_dict)
882
+ vla.model.vision_backbone = vla.model.vision_backbone.to(device_id)
883
+
884
+ # Wrap VLA with DDP
885
+ vla = wrap_ddp(vla, device_id, find_unused=True)
886
+
887
+ # If applicable, instantiate proprio projector
888
+ if cfg.use_proprio:
889
+ proprio_projector = init_module(
890
+ ProprioProjector,
891
+ "proprio_projector",
892
+ cfg,
893
+ device_id,
894
+ {"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM},
895
+ )
896
+
897
+ # If applicable, instantiate continuous action head for L1 regression
898
+ if cfg.use_l1_regression:
899
+ action_head = init_module(
900
+ L1RegressionActionHead,
901
+ "action_head",
902
+ cfg,
903
+ device_id,
904
+ {"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM},
905
+ to_bf16=True,
906
+ )
907
+
908
+ # If applicable, instantiate diffusion action head and noisy action projector
909
+ if cfg.use_diffusion:
910
+ action_head = init_module(
911
+ DiffusionActionHead,
912
+ "action_head",
913
+ cfg,
914
+ device_id,
915
+ {
916
+ "input_dim": vla.module.llm_dim,
917
+ "hidden_dim": vla.module.llm_dim,
918
+ "action_dim": ACTION_DIM,
919
+ "num_diffusion_steps_train": cfg.num_diffusion_steps_train,
920
+ },
921
+ to_bf16=True,
922
+ )
923
+ noisy_action_projector = init_module(
924
+ NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim}
925
+ )
926
+
927
+ # Get number of vision patches
928
+ NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input()
929
+ # If we have proprio inputs, a single proprio embedding is appended to the end of the vision patch embeddings
930
+ if cfg.use_proprio:
931
+ NUM_PATCHES += 1
932
+ # For diffusion, a single diffusion timestep embedding is appended to the end of the vision patch embeddings
933
+ if cfg.use_diffusion:
934
+ NUM_PATCHES += 1
935
+
936
+ # Instantiate optimizer
937
+ trainable_params = [param for param in vla.parameters() if param.requires_grad]
938
+ if cfg.use_l1_regression or cfg.use_diffusion:
939
+ trainable_params += [param for param in action_head.parameters() if param.requires_grad]
940
+ if cfg.use_diffusion:
941
+ trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad]
942
+ if cfg.use_proprio:
943
+ trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad]
944
+ print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}")
945
+ optimizer = AdamW(trainable_params, lr=cfg.learning_rate)
946
+
947
+ # Record original learning rate
948
+ original_lr = optimizer.param_groups[0]["lr"]
949
+
950
+ # Create learning rate scheduler
951
+ scheduler = MultiStepLR(
952
+ optimizer,
953
+ milestones=[cfg.num_steps_before_decay], # Number of steps after which LR will change
954
+ gamma=0.1, # Multiplicative factor of learning rate decay
955
+ )
956
+
957
+ # Create Action Tokenizer
958
+ action_tokenizer = ActionTokenizer(processor.tokenizer)
959
+
960
+ # Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default.
961
+ # =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block.
962
+ # =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using
963
+ # your own Dataset, make sure to add the appropriate logic to the training loop!
964
+ #
965
+ # ---
966
+ # from prismatic.vla.datasets import DummyDataset
967
+ #
968
+ # train_dataset = DummyDataset(
969
+ # action_tokenizer,
970
+ # processor.tokenizer,
971
+ # image_transform=processor.image_processor.apply_transform,
972
+ # prompt_builder_fn=PurePromptBuilder,
973
+ # )
974
+ # ---
975
+
976
+ # We assume that the model takes as input one third-person camera image and 1 or 2 optional wrist camera image(s)
977
+ use_wrist_image = cfg.num_images_in_input > 1
978
+
979
+ # Create training and optional validation datasets
980
+ batch_transform = RLDSBatchTransform(
981
+ action_tokenizer,
982
+ processor.tokenizer,
983
+ image_transform=processor.image_processor.apply_transform,
984
+ prompt_builder_fn=PurePromptBuilder,
985
+ use_wrist_image=use_wrist_image,
986
+ use_proprio=cfg.use_proprio,
987
+ )
988
+ train_dataset = RLDSDataset(
989
+ cfg.data_root_dir,
990
+ cfg.dataset_name,
991
+ batch_transform,
992
+ resize_resolution=tuple(vla.module.config.image_sizes),
993
+ shuffle_buffer_size=cfg.shuffle_buffer_size,
994
+ image_aug=cfg.image_aug,
995
+ )
996
+ if cfg.use_val_set:
997
+ val_dataset = RLDSDataset(
998
+ cfg.data_root_dir,
999
+ cfg.dataset_name,
1000
+ batch_transform,
1001
+ resize_resolution=tuple(vla.module.config.image_sizes),
1002
+ shuffle_buffer_size=cfg.shuffle_buffer_size // 10,
1003
+ image_aug=cfg.image_aug,
1004
+ train=False,
1005
+ )
1006
+
1007
+ # [Important] Save dataset statistics so that we can unnormalize actions during inference
1008
+ if distributed_state.is_main_process:
1009
+ save_dataset_statistics(train_dataset.dataset_statistics, run_dir)
1010
+
1011
+ # Create collator and dataloader
1012
+ collator = PaddedCollatorForActionPrediction(
1013
+ processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
1014
+ )
1015
+ dataloader = DataLoader(
1016
+ train_dataset,
1017
+ batch_size=cfg.batch_size,
1018
+ sampler=None,
1019
+ collate_fn=collator,
1020
+ num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
1021
+ )
1022
+ if cfg.use_val_set:
1023
+ val_batch_size = cfg.batch_size
1024
+ val_dataloader = DataLoader(
1025
+ val_dataset,
1026
+ batch_size=val_batch_size,
1027
+ sampler=None,
1028
+ collate_fn=collator,
1029
+ num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
1030
+ )
1031
+
1032
+ # Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation)
1033
+ recent_metrics = {
1034
+ "loss_value": deque(maxlen=cfg.grad_accumulation_steps),
1035
+ "curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
1036
+ "curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
1037
+ "next_actions_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
1038
+ "next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
1039
+ }
1040
+
1041
+ # Start training
1042
+ with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress:
1043
+ vla.train()
1044
+ optimizer.zero_grad()
1045
+ for batch_idx, batch in enumerate(dataloader):
1046
+ # Compute training metrics and loss
1047
+ compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0
1048
+ loss, metrics = run_forward_pass(
1049
+ vla=vla,
1050
+ action_head=action_head,
1051
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1052
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1053
+ batch=batch,
1054
+ action_tokenizer=action_tokenizer,
1055
+ device_id=device_id,
1056
+ use_l1_regression=cfg.use_l1_regression,
1057
+ use_diffusion=cfg.use_diffusion,
1058
+ use_proprio=cfg.use_proprio,
1059
+ use_film=cfg.use_film,
1060
+ num_patches=NUM_PATCHES,
1061
+ compute_diffusion_l1=compute_diffusion_l1,
1062
+ num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
1063
+ )
1064
+
1065
+ # Normalize loss to account for gradient accumulation
1066
+ normalized_loss = loss / cfg.grad_accumulation_steps
1067
+
1068
+ # Backward pass
1069
+ normalized_loss.backward()
1070
+
1071
+ # Store recent train metrics
1072
+ for metric_name, value in metrics.items():
1073
+ if metric_name in recent_metrics:
1074
+ recent_metrics[metric_name].append(value)
1075
+
1076
+ # Compute gradient step index
1077
+ gradient_step_idx = batch_idx // cfg.grad_accumulation_steps
1078
+
1079
+ # Compute smoothened train metrics
1080
+ smoothened_metrics = compute_smoothened_metrics(recent_metrics)
1081
+
1082
+ # Push Metrics to W&B (every wandb_log_freq gradient steps)
1083
+ log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx
1084
+ if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0:
1085
+ log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb)
1086
+
1087
+ # [If applicable] Linearly warm up learning rate from 10% to 100% of original
1088
+ if cfg.lr_warmup_steps > 0:
1089
+ lr_progress = min((gradient_step_idx + 1) / cfg.lr_warmup_steps, 1.0) # Cap at 1.0
1090
+ current_lr = original_lr * (0.1 + 0.9 * lr_progress)
1091
+ for param_group in optimizer.param_groups:
1092
+ param_group["lr"] = current_lr
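+                 # For example, with lr_warmup_steps=1000 and learning_rate=5e-4 (hypothetical values), halfway
+                 # through warmup the LR is 5e-4 * (0.1 + 0.9 * 0.5) = 2.75e-4; once warmup completes, the full
+                 # 5e-4 is used (until the MultiStepLR decay kicks in).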
1093
+
1094
+ if distributed_state.is_main_process and gradient_step_idx % cfg.wandb_log_freq == 0:
1095
+ # Log the learning rate
1096
+ # Make sure to do this AFTER any learning rate modifications (e.g., warmup/decay)
1097
+ wandb.log(
1098
+ {
1099
+ "VLA Train/Learning Rate": scheduler.get_last_lr()[0],
1100
+ },
1101
+ step=log_step,
1102
+ )
1103
+
1104
+ # Optimizer and LR scheduler step
1105
+ if (batch_idx + 1) % cfg.grad_accumulation_steps == 0:
1106
+ optimizer.step()
1107
+ scheduler.step()
1108
+ optimizer.zero_grad()
1109
+ progress.update()
1110
+
1111
+ # Save model checkpoint: either keep latest checkpoint only or all checkpoints
1112
+ if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
1113
+ save_training_checkpoint(
1114
+ cfg=cfg,
1115
+ run_dir=run_dir,
1116
+ log_step=log_step,
1117
+ vla=vla,
1118
+ processor=processor,
1119
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1120
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1121
+ action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None,
1122
+ train_dataset=train_dataset,
1123
+ distributed_state=distributed_state,
1124
+ )
1125
+
1126
+ # Test model on validation set
1127
+ if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0:
1128
+ run_validation(
1129
+ vla=vla,
1130
+ action_head=action_head,
1131
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1132
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1133
+ val_dataloader=val_dataloader,
1134
+ action_tokenizer=action_tokenizer,
1135
+ device_id=device_id,
1136
+ cfg=cfg,
1137
+ num_patches=NUM_PATCHES,
1138
+ log_step=log_step,
1139
+ distributed_state=distributed_state,
1140
+ val_time_limit=cfg.val_time_limit,
1141
+ )
1142
+ # Set model back to training mode after validation
1143
+ vla.train()
1144
+
1145
+ # Stop training when max_steps is reached
1146
+ if log_step == cfg.max_steps:
1147
+ print(f"Max step {cfg.max_steps} reached! Stopping training...")
1148
+ break
1149
+
1150
+
1151
+ if __name__ == "__main__":
1152
+ finetune()
capvector-oft/vla-scripts/finetune_regular_loss.py ADDED
@@ -0,0 +1,1790 @@
1
+ # This script is for the CapVector experiment: it stops gradient propagation along the direction of the newly added vector.
2
+ """
3
+ finetune.py
4
+
5
+ Fine-tunes OpenVLA via LoRA.
6
+ """
7
+
8
+ import os
9
+ import ctypes
10
+
11
+ lib_path = "/share/miniconda3/lib/libstdc++.so.6"
12
+
13
+ try:
14
+ ctypes.CDLL(lib_path)
15
+ print(f"Successfully preloaded {lib_path}")
16
+ except Exception as e:
17
+ print(f"Failed to preload {lib_path}: {e}")
18
+
19
+ import os
20
+ import time
21
+ from collections import deque
22
+ from dataclasses import dataclass
23
+ from pathlib import Path
24
+ from typing import Dict, Optional, Tuple, Type
25
+
26
+ import draccus
27
+ import torch
28
+ import torch.distributed as dist
29
+ import torch.nn as nn
30
+ import tqdm
31
+ import numpy as np
32
+ from accelerate import PartialState
33
+ from huggingface_hub import HfApi, snapshot_download
34
+ from peft import LoraConfig, PeftModel, get_peft_model
35
+ from torch.nn.parallel import DistributedDataParallel as DDP
36
+ from torch.optim import AdamW
37
+ from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
38
+ from torch.utils.data import DataLoader
39
+ from transformers import get_cosine_schedule_with_warmup
40
+ from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
41
+ from transformers.modeling_outputs import CausalLMOutputWithPast
42
+
43
+ import wandb
44
+ os.environ["WANDB_MODE"] = "offline"
45
+
46
+ try:
47
+ from safetensors import safe_open
48
+ SAFETENSORS_AVAILABLE = True
49
+ except ImportError:
50
+ SAFETENSORS_AVAILABLE = False
51
+ print("Warning: safetensors not available, will try torch.load instead")
52
+
53
+ from experiments.robot.openvla_utils import (
54
+ check_model_logic_mismatch,
55
+ model_is_on_hf_hub,
56
+ update_auto_map,
57
+ )
58
+
59
+ from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
60
+ from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
61
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
62
+ from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead
63
+ from prismatic.models.backbones.llm.prompting import PurePromptBuilder
64
+ from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone
65
+ from prismatic.models.ema_model import EMAModel
66
+ from prismatic.models.projectors import (
67
+ NoisyActionProjector,
68
+ ProprioProjector,
69
+ )
70
+ from prismatic.training.train_utils import (
71
+ compute_actions_l1_loss,
72
+ compute_token_accuracy,
73
+ get_current_action_mask,
74
+ get_next_actions_mask,
75
+ )
76
+ from prismatic.util.data_utils import PaddedCollatorForActionPrediction
77
+ from prismatic.vla.action_tokenizer import ActionTokenizer
78
+ from prismatic.vla.constants import (
79
+ ACTION_DIM,
80
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
81
+ NUM_ACTIONS_CHUNK,
82
+ PROPRIO_DIM,
83
+ )
84
+ from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
85
+ from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
86
+
87
+ # Sane Defaults
88
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
89
+
90
+ #wx: stop gradient in the feature vector direction
91
+
92
+ EPS = 1e-12
93
+
94
+ def register_orthogonal_grad_hook(model, vector_W, debug=False):
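+     # For every weight listed in vector_W, register backward hooks that remove the gradient
+     # component aligned with that weight's CapVector direction. Since the LoRA update is
+     # delta_W = B @ A, grad_A = B^T grad_W and grad_B = grad_W A^T (up to the LoRA scaling), so the
+     # vW-aligned direction appears as vA = B^T vW in A-space and vB = vW A^T in B-space; parameters
+     # without LoRA factors (e.g., layernorms) fall back to a direct hook on vW.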
95
+ name_to_param = dict(model.named_parameters())
96
+
97
+ hooked_A = 0
98
+ hooked_B = 0
99
+ hooked_direct = 0
100
+
101
+ missed = 0
102
+ missed_name = []
103
+
104
+ direct_missed = 0
105
+ direct_missed_name = []
106
+
107
+ printed = {"A": False, "B": False, "D": False}
108
+
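+     # proj_out removes from g2 its component along v2: g_new = g - (<g, v> / (||v||^2 + EPS)) * v,
+     # so the returned gradient is (numerically) orthogonal to v2. If ||v2||^2 <= EPS, g2 is returned unchanged.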
109
+ def proj_out(g2, v2):
110
+ vn2 = (v2 * v2).sum().detach()
111
+ if vn2.item() <= EPS:
112
+ return g2
113
+ gv = (g2 * v2).sum()
114
+ return g2 - (gv / (vn2 + EPS)) * v2
115
+
116
+ for w_name, vW in vector_W.items():
117
+ if "vision_backbone" in w_name:
118
+ continue
119
+
120
+ prefix = "base_model.model."
121
+ A_name = prefix + w_name.replace(".weight", ".lora_A.default.weight")
122
+ B_name = prefix + w_name.replace(".weight", ".lora_B.default.weight")
123
+
124
+ # ===== 1) First, try the LoRA hook =====
125
+ if A_name in name_to_param and B_name in name_to_param:
126
+ A = name_to_param[A_name]
127
+ B = name_to_param[B_name]
128
+
129
+ # If neither A nor B is trainable, do not hook
130
+ if (not A.requires_grad) and (not B.requires_grad):
131
+ continue
132
+
133
+ # Move vW to A's device/dtype
134
+ vW = vW.to(device=A.device, dtype=A.dtype)
135
+ vW2 = vW.reshape(vW.shape[0], -1) if vW.ndim != 2 else vW # [out, in_flat]
136
+
137
+ # ---- hook A: dynamically compute vA = B^T vW using the current B ----
138
+ if A.requires_grad:
139
+ def hook_A(g, A_ref=A, B_ref=B, vW2_ref=vW2):
140
+ if g is None:
141
+ return None
142
+ g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
143
+
144
+ B_mat = B_ref.detach()
145
+ B2 = B_mat.reshape(B_mat.shape[0], -1) if B_mat.ndim != 2 else B_mat # [out, r]
146
+
147
+ if B2.shape[0] != vW2_ref.shape[0]:
148
+ return g
149
+
150
+ vA = torch.matmul(B2.transpose(0, 1), vW2_ref) # [r, in_flat]
151
+
152
+ if debug and not printed["A"]:
153
+ print(f"[hook fired] A: ||B||={B2.norm().item():.4e}, ||vA||={vA.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
154
+ printed["A"] = True
155
+
156
+ g2_new = proj_out(g2, vA)
157
+ return g2_new.view_as(g)
158
+
159
+ A.register_hook(hook_A)
160
+ hooked_A += 1
161
+
162
+ # ---- hook B: dynamically compute vB = vW A^T using the current A ----
163
+ if B.requires_grad:
164
+ def hook_B(g, A_ref=A, B_ref=B, vW2_ref=vW2):
165
+ if g is None:
166
+ return None
167
+ g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
168
+
169
+ A_mat = A_ref.detach()
170
+ A2 = A_mat.reshape(A_mat.shape[0], -1) if A_mat.ndim != 2 else A_mat # [r, in_flat]
171
+
172
+ if A2.shape[1] != vW2_ref.shape[1]:
173
+ return g
174
+
175
+ vB = torch.matmul(vW2_ref, A2.transpose(0, 1)) # [out, r]
176
+
177
+ if debug and not printed["B"]:
178
+ print(f"[hook fired] B: ||A||={A2.norm().item():.4e}, ||vB||={vB.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
179
+ printed["B"] = True
180
+
181
+ g2_new = proj_out(g2, vB)
182
+ return g2_new.view_as(g)
183
+
184
+ B.register_hook(hook_B)
185
+ hooked_B += 1
186
+
187
+ # This entry was successfully handled by the LoRA branch
188
+ continue
189
+
190
+ # ===== 2) No LoRA weights found: fall back to hooking the parameter directly (e.g., layernorm) =====
191
+ missed += 1
192
+ missed_name.append(w_name)
193
+
194
+ # Try to align with the non-LoRA parameter name
195
+ # In the vast majority of cases: base_model.model.<w_name>
196
+ direct_name = prefix + w_name
197
+
198
+ # Some vector keys may lack the base_model.model prefix, while the model's parameter names may use a different prefix
199
+ # As a fallback, try once more: if direct_name is not found, try dropping prefixes such as language_model/
200
+ # (You can add more matching rules to fit your own project)
201
+ if direct_name not in name_to_param:
202
+ # Try again: if w_name already contains base_model.model, do not add the prefix
203
+ if w_name in name_to_param:
204
+ direct_name = w_name
205
+ else:
206
+ direct_missed += 1
207
+ direct_missed_name.append(w_name)
208
+ continue
209
+
210
+ P = name_to_param[direct_name]
211
+ if not P.requires_grad:
212
+ # Found but not trainable: do not hook it, and do not count it as direct_missed
213
+ continue
214
+
215
+ vP = vector_W[w_name].to(device=P.device, dtype=P.dtype)
216
+ vP2 = vP.reshape(vP.shape[0], -1) if vP.ndim != 2 else vP
217
+
218
+ def hook_direct(g, v_ref=vP2, name_ref=direct_name):
219
+ if g is None:
220
+ return None
221
+ g2 = g.reshape(g.shape[0], -1) if g.ndim != 2 else g
222
+
223
+ # If shapes do not match, leave the gradient unchanged (avoids shape errors inside the hook)
224
+ if g2.shape != v_ref.shape:
225
+ return g
226
+
227
+ if debug and not printed["D"]:
228
+ print(f"[hook fired] Direct: param={name_ref}, ||v||={v_ref.norm().item():.4e}, ||g||={g2.norm().item():.4e}")
229
+ printed["D"] = True
230
+
231
+ g2_new = proj_out(g2, v_ref)
232
+ return g2_new.view_as(g)
233
+
234
+ P.register_hook(hook_direct)
235
+ hooked_direct += 1
236
+
237
+ print(
238
+ f"[hook summary] hooked lora_A: {hooked_A}, lora_B: {hooked_B}, direct: {hooked_direct}, "
239
+ f"missed(lora-not-found): {missed}, direct_missed: {direct_missed}"
240
+ )
241
+
242
+ # To inspect the specific missed names:
243
+ # print("[missed lora-not-found names]")
244
+ # for n in missed_name: print(" -", n)
245
+ # print("[direct_missed names]")
246
+ # for n in direct_missed_name: print(" -", n)
247
+
248
+ # import pdb; pdb.set_trace()
249
+
250
+
251
+ # def register_orthogonal_grad_hook(model, vector_W, debug=False):
252
+ # name_to_param = dict(model.named_parameters())
253
+
254
+ # hooked_A = 0
255
+ # hooked_B = 0
256
+ # missed = 0
257
+
258
+ # printed = {"A": False, "B": False}  # used to print only once
259
+
260
+ # for w_name, vW in vector_W.items():
261
+ # if "vision_backbone" in w_name:
262
+ # continue
263
+ # # import pdb; pdb.set_trace()
264
+ # prefix = "base_model.model."
265
+ # A_name = prefix + w_name.replace(".weight", ".lora_A.default.weight")
266
+ # B_name = prefix + w_name.replace(".weight", ".lora_B.default.weight")
267
+
268
+ # if A_name not in name_to_param or B_name not in name_to_param:
269
+ # missed += 1
270
+ # continue
271
+
272
+ # A = name_to_param[A_name]
273
+ # B = name_to_param[B_name]
274
+
275
+ # if (not A.requires_grad) and (not B.requires_grad):
276
+ # continue
277
+
278
+ # vW = vW.to(device=A.device, dtype=A.dtype)
279
+
280
+ # with torch.no_grad():
281
+ # # A_mat = A.detach().view(1, -1) # (1, in)
282
+ # # B_mat = B.detach().view(-1, 1) # (out,1)
283
+
284
+ # # vA = torch.matmul(B_mat.T, vW) # (1,in)
285
+ # # vB = torch.matmul(vW, A_mat.T) # (out,1)
286
+ # B_mat = B.detach()
287
+ # A_mat = A.detach()
288
+ # # import pdb; pdb.set_trace()
289
+
290
+ # # Flatten vW to 2D: [out, in_flat]
291
+ # if vW.ndim != 2:
292
+ # vW2 = vW.reshape(vW.shape[0], -1)
293
+ # else:
294
+ # vW2 = vW
295
+
296
+ # # A may not be strictly 2D (it usually is, but reshape to be safe); in practice both A and B are 2D
297
+ # if A_mat.ndim != 2:
298
+ # A2 = A_mat.reshape(A_mat.shape[0], -1) # [r, in_flat]
299
+ # else:
300
+ # A2 = A_mat
301
+
302
+ # # B is usually 2D: [out, r]
303
+ # if B_mat.ndim != 2:
304
+ # B2 = B_mat.reshape(B_mat.shape[0], -1) # [out, r]
305
+ # else:
306
+ # B2 = B_mat
307
+
308
+ # # Shape check: skip this w_name if shapes do not match (to avoid further errors)
309
+ # # Required: the out dim of B2 [out, r] matches the out dim of vW2 [out, in_flat]
310
+ # # Required: the in_flat dim of A2 [r, in_flat] matches the in_flat dim of vW2 [out, in_flat]
311
+ # if B2.shape[0] != vW2.shape[0] or A2.shape[1] != vW2.shape[1] or A2.shape[0] != B2.shape[1]:
312
+ # missed += 1
313
+ # continue
314
+
315
+ # vA = torch.matmul(B2.transpose(0, 1), vW2) # [r, in_flat]
316
+ # vB = torch.matmul(vW2, A2.transpose(0, 1)) # [out, r]
317
+
318
+
319
+ # # hook A
320
+ # if A.requires_grad:
321
+ # vA_norm2 = (vA * vA).sum().detach()
322
+ # if vA_norm2.item() > EPS:
323
+ # def make_hook_A(v, vn2):
324
+ # def hook(g):
325
+ # if debug and not printed["A"]:
326
+ # print(f"[hook fired] lora_A grad norm: {g.norm().item():.4e}")
327
+ # printed["A"] = True
328
+ # gv = (g * v).sum()
329
+ # proj = (gv / (vn2 + EPS)) * v
330
+ # return g - proj
331
+ # return hook
332
+
333
+ # A.register_hook(make_hook_A(vA, vA_norm2))
334
+ # hooked_A += 1
335
+
336
+ # # hook B
337
+ # if B.requires_grad:
338
+ # vB_norm2 = (vB * vB).sum().detach()
339
+ # if vB_norm2.item() > EPS:
340
+ # def make_hook_B(v, vn2):
341
+ # def hook(g):
342
+ # if debug and not printed["B"]:
343
+ # print(f"[hook fired] lora_B grad norm: {g.norm().item():.4e}")
344
+ # printed["B"] = True
345
+ # gv = (g * v).sum()
346
+ # proj = (gv / (vn2 + EPS)) * v
347
+ # return g - proj
348
+ # return hook
349
+
350
+ # B.register_hook(make_hook_B(vB, vB_norm2))
351
+ # hooked_B += 1
352
+
353
+ # print(f"[hook summary] hooked lora_A: {hooked_A}, hooked lora_B: {hooked_B}, missed: {missed}")
354
+ # import pdb; pdb.set_trace()
355
+
356
+
357
+
358
+ # Usage:
359
+ # vector_sd = torch.load("your_vector.pth")["state_dict"] or similar
360
+ # register_orthogonal_grad_hook(model, vector_sd)
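+ # Note: register these hooks after wrapping the model with get_peft_model (the parameter lookups
+ # above assume PEFT's "base_model.model." prefix) and before the first backward pass of training.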
361
+
362
+
363
+ # import debugpy
364
+ # try:
365
+ # debugpy.listen(("localhost", 9501))
366
+ # print("Waiting for debugger attach")
367
+ # debugpy.wait_for_client()
368
+ # except Exception as e:
369
+ # pass
370
+
371
+
372
+ @dataclass
373
+ class FinetuneConfig:
374
+ # fmt: off
375
+ vla_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally)
376
+
377
+ # Dataset
378
+ data_root_dir: Path = Path("datasets/rlds") # Directory containing RLDS datasets
379
+ dataset_name: str = "aloha_scoop_x_into_bowl" # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`)
380
+ run_root_dir: Path = Path("runs") # Path to directory to store logs & checkpoints
381
+ shuffle_buffer_size: int = 100_000 # Dataloader shuffle buffer size (can reduce if OOM errors occur)
382
+
383
+ # Algorithm and architecture
384
+ use_l1_regression: bool = True # If True, trains continuous action head with L1 regression objective
385
+ use_diffusion: bool = False # If True, trains continuous action head with diffusion modeling objective (DDIM)
386
+ num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training
387
+ use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features
388
+ num_images_in_input: int = 1 # Number of images in the VLA input (default: 1)
389
+ use_proprio: bool = False # If True, includes robot proprioceptive state in input
390
+
391
+ # Training configuration
392
+ batch_size: int = 8 # Batch size per device (total batch size = batch_size * num GPUs)
393
+ learning_rate: float = 5e-4 # Learning rate
394
+ lr_warmup_steps: int = 0 # Number of steps to warm up learning rate (from 10% to 100%)
395
+ num_steps_before_decay: int = 100_000 # Number of steps before LR decays by 10x
396
+ grad_accumulation_steps: int = 1 # Number of gradient accumulation steps
397
+ max_steps: int = 200_000 # Max number of training steps
398
+ use_val_set: bool = False # If True, uses validation set and log validation metrics
399
+ val_freq: int = 10_000 # (When `use_val_set==True`) Validation set logging frequency in steps
400
+ val_time_limit: int = 180 # (When `use_val_set==True`) Time limit for computing validation metrics
401
+ save_freq: int = 10_000 # Checkpoint saving frequency in steps
402
+ save_latest_checkpoint_only: bool = False # If True, saves only 1 checkpoint, overwriting latest checkpoint
403
+ # (If False, saves all checkpoints)
404
+ scheduler: str = 'MultiStepLR' # "MultiStepLR" or "CosineAnnealingLR" or "WarmupCosineLR"
405
+ resume: bool = False # If True, resumes from checkpoint
406
+ resume_step: Optional[int] = None # (When `resume==True`) Step number that we are resuming from
407
+ image_aug: bool = True # If True, trains with image augmentations (HIGHLY RECOMMENDED)
408
+ diffusion_sample_freq: int = 50 # (When `use_diffusion==True`) Frequency for sampling in steps
409
+
410
+ # LoRA
411
+ use_lora: bool = True # If True, uses LoRA fine-tuning
412
+ lora_rank: int = 32 # Rank of LoRA weight matrix
413
+ lora_dropout: float = 0.0 # Dropout applied to LoRA weights
414
+ merge_lora_during_training: bool = True # If True, merges LoRA weights and saves result during training
415
+ # Note: Merging can be very slow on some machines. If so, set to
416
+ # False and merge final checkpoint offline!
417
+
418
+ # Regularization
419
+ regularization_lora_vector_path: Optional[str] = None       # Path to the regularization vector
420
+ regularization_weight: float = 1e-3 # Weight of regularization loss
421
+
422
+ # Logging
423
+ wandb_entity: str = "your-wandb-entity" # Name of WandB entity
424
+ wandb_project: str = "your-wandb-project" # Name of WandB project
425
+ run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging
426
+ run_id_override: Optional[str] = None # Optional string to override the run ID with
427
+ wandb_log_freq: int = 10 # WandB logging frequency in steps
428
+
429
+ # EMA
430
+ use_ema: bool = False # If True, maintains an EMA copy of the model
431
+ inv_gamma: float = 1 # EMA inverse gamma parameter
432
+
433
+ # fmt: on
434
+
435
+
436
+ def remove_ddp_in_checkpoint(state_dict) -> dict:
437
+ """
438
+ Removes the 'module.' prefix from parameter names in a PyTorch model state dictionary that was saved using
439
+ DistributedDataParallel (DDP).
440
+
441
+ When a model is trained using PyTorch's DistributedDataParallel, the saved state dictionary contains parameters
442
+ prefixed with 'module.'. This function removes these prefixes to make the state dictionary compatible when
443
+ loading into models that are not yet wrapped in DDP.
444
+
445
+ Args:
446
+ state_dict (dict): PyTorch model state dictionary.
447
+
448
+ Returns:
449
+ dict: A new state dictionary with the same contents but with 'module.' prefixes removed from parameter names.
450
+ Parameters without the 'module.' prefix remain unchanged.
451
+ """
452
+ new_state_dict = {}
453
+ for k, v in state_dict.items():
454
+ if k[:7] == "module.":
455
+ new_state_dict[k[7:]] = v
456
+ else:
457
+ new_state_dict[k] = v
458
+ return new_state_dict
459
+
460
+
461
+ def get_run_id(cfg) -> str:
462
+ """
463
+ Generates or retrieves an identifier string for an experiment run.
464
+
465
+ Args:
466
+ cfg (FinetuneConfig): Training configuration.
467
+
468
+ Returns:
469
+ str: Experiment run ID.
470
+ """
471
+ if cfg.run_id_override is not None:
472
+ # Override the run ID with the user-provided ID
473
+ run_id = cfg.run_id_override
474
+ elif cfg.resume:
475
+ # Override run ID with the previous resumed run's ID
476
+ run_id = cfg.vla_path.split("/")[-1]
477
+ # Remove the "--XXX_chkpt" suffix from the run ID if it exists
478
+ if "chkpt" in run_id.split("--")[-1]:
479
+ run_id = "--".join(run_id.split("--")[:-1])
480
+ else:
481
+ run_id = (
482
+ f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}"
483
+ f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
484
+ f"+lr-{cfg.learning_rate}"
485
+ )
486
+ if cfg.use_lora:
487
+ run_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
488
+ if cfg.image_aug:
489
+ run_id += "--image_aug"
490
+ if cfg.run_id_note is not None:
491
+ run_id += f"--{cfg.run_id_note}"
492
+ return run_id
493
+
494
+
495
+ def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict:
496
+ """
497
+ Loads a checkpoint for a given module.
498
+
499
+ Args:
500
+ module_name (str): Name of model component to load checkpoint for.
501
+ path (str): Path to checkpoint directory.
502
+ step (int): Gradient step number of saved checkpoint.
503
+ device (str): String specifying how to remap storage locations (default = "cpu").
504
+
505
+ Returns:
506
+ dict: PyTorch model state dictionary.
507
+ """
508
+ checkpoint_path = os.path.join(path, f"{module_name}--{step}_checkpoint.pt")
509
+ print(f"Loading checkpoint: {checkpoint_path}")
510
+ state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
511
+ return remove_ddp_in_checkpoint(state_dict)
512
+
513
+
514
+ def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP:
515
+ """
516
+ Wrap a module with DistributedDataParallel.
517
+
518
+ Args:
519
+ module (nn.Module): PyTorch module.
520
+ device_id (str): Device ID.
521
+ find_unused (bool): Whether to detect parameters without gradients in distributed training.
522
+
523
+ Returns:
524
+ DistributedDataParallel: PyTorch module wrapped with DDP.
525
+ """
526
+ return DDP(module, device_ids=[device_id], find_unused_parameters=find_unused, gradient_as_bucket_view=True)
527
+
528
+
529
+ def count_parameters(module: nn.Module, name: str) -> None:
530
+ """
531
+ Counts and prints the number of trainable parameters in a module.
532
+
533
+ Args:
534
+ module (nn.Module): PyTorch module.
535
+ module_name (str): Name of model component.
536
+
537
+ Returns:
538
+ None.
539
+ """
540
+ num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
541
+ print(f"# trainable params in {name}: {num_params}")
542
+
543
+
544
+ def init_module(
545
+ module_class: Type[nn.Module],
546
+ module_name: str,
547
+ cfg: FinetuneConfig,
548
+ device_id: int,
549
+ module_args: dict,
550
+ to_bf16: bool = False,
551
+ find_unused_params: bool = False,
552
+ ) -> DDP:
553
+ """
554
+ Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP.
555
+
556
+ Args:
557
+ module_class (Type[nn.Module]): Class of PyTorch module to initialize.
558
+ module_name (str): Name of model component to load checkpoint for.
559
+ cfg (FinetuneConfig): Training configuration.
560
+ device_id (str): Device ID.
561
+ module_args (dict): Args for initializing the module.
562
+ to_bf16 (bool): Whether to convert to torch.bfloat16 data type.
563
+ find_unused_params (bool): Whether to detect parameters without gradients in distributed training.
564
+
565
+ Returns:
566
+ DistributedDataParallel: PyTorch module wrapped with DDP.
567
+ """
568
+ module = module_class(**module_args)
569
+ count_parameters(module, module_name)
570
+
571
+ if cfg.resume:
572
+ state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step)
573
+ module.load_state_dict(state_dict)
574
+
575
+ if to_bf16:
576
+ module = module.to(torch.bfloat16)
577
+ module = module.to(device_id)
578
+
579
+ return wrap_ddp(module, device_id, find_unused_params)
580
+
581
+
582
+ def run_forward_pass(
583
+ vla,
584
+ action_head,
585
+ noisy_action_projector,
586
+ proprio_projector,
587
+ batch,
588
+ action_tokenizer,
589
+ device_id,
590
+ use_l1_regression,
591
+ use_diffusion,
592
+ use_proprio,
593
+ use_film,
594
+ num_patches,
595
+ compute_diffusion_l1=False,
596
+ num_diffusion_steps_train=None,
597
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
598
+ """
599
+ Compute model forward pass and metrics for both training and validation.
600
+
601
+ Args:
602
+ vla (OpenVLAForActionPrediction): Vision-language-action policy.
603
+ action_head (nn.Module): Action head module.
604
+ noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
605
+ proprio_projector (nn.Module): Proprioceptive state projector module.
606
+ batch (dict): Input batch.
607
+ action_tokenizer (ActionTokenizer): Action tokenizer.
608
+ device_id (str): Device ID.
609
+ use_l1_regression (bool): Whether to use L1 regression.
610
+ use_diffusion (bool): Whether to use diffusion.
611
+ use_proprio (bool): Whether to use proprioceptive state as input.
612
+ use_film (bool): Whether to use FiLM for better language following.
613
+ num_patches (int): Number of vision patches.
614
+ compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every
615
+ diffusion_sample_freq steps during training; do it every batch for validation)
616
+ num_diffusion_steps_train (int): Number of diffusion steps for training (only used for diffusion).
617
+
618
+ Returns:
619
+ tuple: (loss, metrics_dict)
620
+ loss: The loss tensor with gradient for backpropagation.
621
+ metrics_dict: Dictionary of computed metrics (detached values for logging).
622
+ """
623
+ metrics = {}
624
+
625
+ # Get ground-truth action labels
626
+ ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16)
627
+
628
+ # [Only for diffusion] Sample noisy actions used as input for noise predictor network
629
+ if use_diffusion:
630
+ noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions)
631
+ noise, noisy_actions, diffusion_timestep_embeddings = (
632
+ noisy_dict["noise"],
633
+ noisy_dict["noisy_actions"],
634
+ noisy_dict["diffusion_timestep_embeddings"],
635
+ )
636
+ else:
637
+ noise, noisy_actions, diffusion_timestep_embeddings = None, None, None
638
+
639
+ # VLA forward pass
640
+ with torch.autocast("cuda", dtype=torch.bfloat16):
641
+ output: CausalLMOutputWithPast = vla(
642
+ input_ids=batch["input_ids"].to(device_id),
643
+ attention_mask=batch["attention_mask"].to(device_id),
644
+ pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
645
+ labels=batch["labels"],
646
+ output_hidden_states=True,
647
+ proprio=batch["proprio"] if use_proprio else None,
648
+ proprio_projector=proprio_projector if use_proprio else None,
649
+ noisy_actions=noisy_actions if use_diffusion else None,
650
+ noisy_action_projector=noisy_action_projector if use_diffusion else None,
651
+ diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None,
652
+ use_film=use_film,
653
+ )
654
+
655
+ # Get action masks needed for logging
656
+ ground_truth_token_ids = batch["labels"][:, 1:].to(device_id)
657
+ current_action_mask = get_current_action_mask(ground_truth_token_ids)
658
+ next_actions_mask = get_next_actions_mask(ground_truth_token_ids)
659
+
660
+ # Compute metrics for discrete action representation (next-token prediction)
661
+ if not (use_l1_regression or use_diffusion):
662
+ loss = output.loss
663
+ predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2)
664
+ curr_action_accuracy = compute_token_accuracy(
665
+ predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
666
+ )
667
+ curr_action_l1_loss = compute_actions_l1_loss(
668
+ action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask
669
+ )
670
+ next_actions_accuracy = compute_token_accuracy(
671
+ predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
672
+ )
673
+ next_actions_l1_loss = compute_actions_l1_loss(
674
+ action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask
675
+ )
676
+ metrics.update(
677
+ {
678
+ "loss_value": loss.item(), # Detached value for logging
679
+ "curr_action_accuracy": curr_action_accuracy.item(),
680
+ "curr_action_l1_loss": curr_action_l1_loss.item(),
681
+ "next_actions_accuracy": next_actions_accuracy.item(),
682
+ "next_actions_l1_loss": next_actions_l1_loss.item(),
683
+ }
684
+ )
685
+ # Compute metrics for continuous action representations (L1 regression | diffusion)
686
+ else:
687
+ # Get last layer hidden states
688
+ last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
689
+ # Get hidden states for text portion of prompt+response (after the vision patches)
690
+ text_hidden_states = last_hidden_states[:, num_patches:-1]
691
+ # Get hidden states for action portion of response
692
+ batch_size = batch["input_ids"].shape[0]
693
+ actions_hidden_states = (
694
+ text_hidden_states[current_action_mask | next_actions_mask]
695
+ .reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1)
696
+ .to(torch.bfloat16)
697
+ ) # (B, act_chunk_len, D)
698
+
699
+ if use_l1_regression:
700
+ # Predict action
701
+ predicted_actions = action_head.module.predict_action(actions_hidden_states)
702
+ # Get full L1 loss
703
+ loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions)
704
+
705
+ if use_diffusion:
706
+ # Predict noise
707
+ noise_pred = action_head.module.predict_noise(actions_hidden_states)
708
+ # Get diffusion noise prediction MSE loss
709
+ noise_pred = noise_pred.reshape(noise.shape)
710
+ loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean")
711
+
712
+ # Only sample actions and compute L1 losses if specified
713
+ if compute_diffusion_l1:
714
+ with torch.no_grad():
715
+ predicted_actions = run_diffusion_sampling(
716
+ vla=vla,
717
+ action_head=action_head,
718
+ noisy_action_projector=noisy_action_projector,
719
+ proprio_projector=proprio_projector,
720
+ batch=batch,
721
+ batch_size=batch_size,
722
+ num_patches=num_patches,
723
+ actions_shape=ground_truth_actions.shape,
724
+ device_id=device_id,
725
+ current_action_mask=current_action_mask,
726
+ next_actions_mask=next_actions_mask,
727
+ use_proprio=use_proprio,
728
+ use_film=use_film,
729
+ )
730
+
731
+ metrics.update(
732
+ {
733
+ "loss_value": loss.item(), # Detached value for logging
734
+ }
735
+ )
736
+
737
+ # Get detailed L1 losses for logging
738
+ should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1)
739
+ if should_log_l1_loss:
740
+ ground_truth_curr_action = ground_truth_actions[:, 0]
741
+ predicted_curr_action = predicted_actions[:, 0]
742
+ ground_truth_next_actions = ground_truth_actions[:, 1:]
743
+ predicted_next_actions = predicted_actions[:, 1:]
744
+ curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action)
745
+ next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions)
746
+ metrics.update(
747
+ {
748
+ "curr_action_l1_loss": curr_action_l1_loss.item(),
749
+ "next_actions_l1_loss": next_actions_l1_loss.item(),
750
+ }
751
+ )
752
+
753
+ # Return both the loss tensor (with gradients) and the metrics dictionary (with detached values)
754
+ return loss, metrics
755
+
756
+
757
+ def run_diffusion_sampling(
758
+ vla,
759
+ action_head,
760
+ noisy_action_projector,
761
+ proprio_projector,
762
+ batch,
763
+ batch_size,
764
+ num_patches,
765
+ actions_shape,
766
+ device_id,
767
+ current_action_mask,
768
+ next_actions_mask,
769
+ use_proprio,
770
+ use_film,
771
+ ) -> torch.Tensor:
772
+ """
773
+ Run diffusion sampling (reverse diffusion) to generate actions.
774
+
775
+ Args:
776
+ vla (OpenVLAForActionPrediction): Vision-language-action policy.
777
+ action_head (nn.Module): Action head module.
778
+ noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
779
+ proprio_projector (nn.Module): Proprioceptive state projector module.
780
+ batch (dict): Input batch.
781
+ batch_size (int): Batch size.
782
+ num_patches (int): Number of vision patches.
783
+ actions_shape (tuple): Shape of ground-truth actions.
784
+ device_id (str): Device ID.
785
+ current_action_mask (torch.Tensor): Mask for current action.
786
+ next_actions_mask (torch.Tensor): Mask for next actions.
787
+ use_proprio (bool): Whether to use proprioceptive state as input.
788
+ use_film (bool): Whether to use FiLM for better language following.
789
+
790
+ Returns:
791
+ torch.Tensor: Predicted actions.
792
+ """
793
+ # Sample random noisy action, used as the starting point for reverse diffusion
794
+ noise = torch.randn(
795
+ size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM),
796
+ device=device_id,
797
+ dtype=torch.bfloat16,
798
+ ) # (B, chunk_len, action_dim)
799
+
800
+ # Set diffusion timestep values
801
+ action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train)
802
+
803
+ # Reverse diffusion: Iteratively denoise to generate action, conditioned on observation
804
+ curr_noisy_actions = noise
805
+ for t in action_head.module.noise_scheduler.timesteps:
806
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action embedding,
807
+ # and diffusion timestep embedding)
808
+ timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id)
809
+ diffusion_timestep_embeddings = (
810
+ action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
811
+ ) # (B, llm_dim)
812
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
813
+
814
+ with torch.autocast("cuda", dtype=torch.bfloat16):
815
+ output = vla(
816
+ input_ids=batch["input_ids"].to(device_id),
817
+ attention_mask=batch["attention_mask"].to(device_id),
818
+ pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
819
+ labels=batch["labels"],
820
+ output_hidden_states=True,
821
+ proprio=batch["proprio"] if use_proprio else None,
822
+ proprio_projector=proprio_projector if use_proprio else None,
823
+ noisy_actions=curr_noisy_actions,
824
+ noisy_action_projector=noisy_action_projector,
825
+ diffusion_timestep_embeddings=diffusion_timestep_embeddings,
826
+ use_film=use_film,
827
+ )
828
+ # Get last layer hidden states
829
+ last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
830
+ # Get hidden states for text portion of prompt+response (after the vision patches)
831
+ text_hidden_states = last_hidden_states[:, num_patches:-1]
832
+ # Get hidden states for action portion of response
833
+ actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape(
834
+ batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1
835
+ ) # (B, act_chunk_len, D)
836
+ actions_hidden_states = actions_hidden_states.to(torch.bfloat16)
837
+ # Predict noise
838
+ noise_pred = action_head.module.predict_noise(actions_hidden_states)
839
+
840
+ # Compute the action at the previous diffusion timestep: x_t -> x_{t-1}
841
+ curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
842
+
843
+ return curr_noisy_actions.reshape(actions_shape)
844
+
845
+
846
+ def compute_smoothened_metrics(metrics_deques) -> dict:
847
+ """
848
+ Compute smoothened metrics from recent deques.
849
+
850
+ Args:
851
+ metrics_deques (dict): Dictionary of deques containing recent metrics.
852
+
853
+ Returns:
854
+ dict: Dictionary of smoothened metrics.
855
+ """
856
+ smoothened_metrics = {}
857
+ for name, deque in metrics_deques.items():
858
+ if deque and len(deque) > 0:
859
+ smoothened_metrics[name] = sum(deque) / len(deque)
860
+ return smoothened_metrics
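A quick sketch of the helper above (hypothetical deque values): empty deques are skipped and the remaining ones are averaged.
from collections import deque
recent = {"loss_value": deque([0.5, 0.25], maxlen=4), "curr_action_l1_loss": deque(maxlen=4)}
print(compute_smoothened_metrics(recent))  # {'loss_value': 0.375} -- the empty deque is skipped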
861
+
862
+
863
+ def compute_diff_regularization_loss(model, diff_params_dict, regularization_weight=1.0):
864
+ """
865
+ Compute a regularization loss between the model's parameters and the same-named parameters from diff_path, to keep the model from updating in the direction of the diff_path parameters.
866
+ Following the orthogonality-loss formulation, the inner product between parameters is used to penalize similarity.
867
+
868
+ Args:
869
+ model: The model (possibly DDP-wrapped).
870
+ diff_params_dict: Parameter dictionary loaded from diff_path.
871
+ regularization_weight: Regularization weight.
872
+
873
+ Returns:
874
+ regularization_loss: Regularization loss value.
875
+ """
876
+ orthogonal_loss = 0.
877
+ matched_count = 0
878
+
879
+ # Get the underlying module (in case the model is DDP-wrapped)
880
+ model_module = model.module if hasattr(model, 'module') else model
881
+
882
+ for name, param in model_module.named_parameters():
883
+ if "lora" in name:
884
+ if not param.requires_grad:
885
+ continue
886
+
887
+ # Try to match the same-named parameter in diff_params_dict
888
+ # Possible naming differences have to be handled:
889
+ # 1. diff_path may be missing the "base_model.model." prefix
890
+ # 2. diff_path may have an extra ".default" after .lora_A or .lora_B
891
+ # e.g., the model has "xxx.lora_A.weight"
892
+ # while the diff has "xxx.lora_A.default.weight"
893
+ matched_diff_param = None
894
+
895
+ # First try a direct match
896
+ if name in diff_params_dict:
897
+ # import pdb; pdb.set_trace()
898
+ matched_diff_param = diff_params_dict[name]
899
+ else:
900
+ # import pdb; pdb.set_trace()
901
+ # Try to handle the ".default" naming difference after .lora_A or .lora_B
902
+ # Following O-LoRA, only the lora_A parameters are constrained
903
+ if ".lora_A." in name:
904
+ name_with_default = name.replace(".lora_A.default.", ".lora_A.")
905
+ if name_with_default in diff_params_dict:
906
+ matched_diff_param = diff_params_dict[name_with_default]
907
+ # elif ".lora_B." in name:
908
+ # name_with_default = name.replace(".lora_B.default.", ".lora_B.")
909
+ # if name_with_default in diff_params_dict:
910
+ # matched_diff_param = diff_params_dict[name_with_default]
911
+
912
+ if matched_diff_param is not None:
913
+ # print(f"Matched parameter: {name}")
914
+ # Make sure the parameters are on the same device
915
+ diff_param = matched_diff_param.to(device=param.device, dtype=param.dtype)
916
+
917
+ # Check that the shapes match
918
+ if param.shape == diff_param.shape:
919
+ # Use detach().clone().requires_grad_() to avoid DDP's duplicate parameter-marking issue
920
+ # This creates a new tensor that keeps the gradient connection without triggering DDP's duplicate marking
921
+ param_safe = param.clone()
922
+ diff_param_safe = diff_param.detach().clone()
923
+
924
+ # For multi-dimensional LoRA parameters inside the vision model
925
+ param_flat = param_safe.reshape(-1) # [N]
926
+ diff_param_flat = diff_param_safe.reshape(-1) # [N]
927
+ inner_product = torch.abs((param_flat * diff_param_flat).sum())
928
+ orthogonal_loss += inner_product
929
+ matched_count += 1
930
+ # print(f"Regularization loss for matched parameter {name}: {inner_product}")
931
+
932
+ # print(f"Regularization loss: {orthogonal_loss}")
933
+ if matched_count > 0:
934
+ orthogonal_loss = orthogonal_loss * regularization_weight
935
+ else:
936
+ # If no parameters were matched, return 0 (with requires_grad so that backward does not error out)
937
+ # The actual gradient is 0, so training is not affected
938
+ device = next(model_module.parameters()).device
939
+ orthogonal_loss = torch.tensor(0.0, device=device, requires_grad=True)
940
+
941
+ return orthogonal_loss
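A minimal sketch of the per-parameter penalty computed above, with two hypothetical LoRA-A matrices standing in for a trainable weight and the same-named weight from diff_path: the penalty is the absolute inner product of the flattened tensors, which vanishes exactly when they are orthogonal.
import torch
lora_A_model = torch.randn(8, 16, requires_grad=True)  # hypothetical trainable LoRA-A weight
lora_A_diff = torch.randn(8, 16)                        # hypothetical same-named weight from diff_path
# Same computation as for one matched parameter inside compute_diff_regularization_loss
inner_product = torch.abs((lora_A_model.reshape(-1) * lora_A_diff.reshape(-1)).sum())
inner_product.backward()  # the gradient pushes lora_A_model toward orthogonality with lora_A_diff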
942
+
943
+
944
+ def load_diff_params(diff_path, device="cpu"):
945
+ """
946
+ Load parameters from a safetensors or pth file.
947
+
948
+ Args:
949
+ diff_path: Path to the parameter file.
950
+ device: Device to load the parameters onto.
951
+
952
+ Returns:
953
+ diff_params_dict: Parameter dictionary.
954
+ """
955
+ diff_params_dict = {}
956
+
957
+ if diff_path.endswith('.safetensors'):
958
+ if not SAFETENSORS_AVAILABLE:
959
+ raise ImportError("safetensors library is required to load .safetensors files")
960
+
961
+ with safe_open(diff_path, framework="pt", device=device) as f:
962
+ for key in f.keys():
963
+ diff_params_dict[key] = f.get_tensor(key)
964
+ else:
965
+ # Assume pth or another torch format
966
+ loaded = torch.load(diff_path, map_location=device)
967
+ if isinstance(loaded, dict):
968
+ if "state_dict" in loaded:
969
+ diff_params_dict = loaded["state_dict"]
970
+ else:
971
+ diff_params_dict = loaded
972
+ else:
973
+ diff_params_dict = loaded
974
+
975
+ return diff_params_dict
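A small round-trip sketch for the loader above, using a hypothetical .pt file (a .safetensors file would go through the safe_open branch instead); the returned dict is what compute_diff_regularization_loss expects as diff_params_dict.
import torch
torch.save({"layer.lora_A.weight": torch.zeros(4, 4)}, "/tmp/diff_example.pt")  # hypothetical path
diff_params_dict = load_diff_params("/tmp/diff_example.pt", device="cpu")
print(list(diff_params_dict.keys()))  # ['layer.lora_A.weight']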
976
+
977
+
978
+ def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None:
979
+ """
980
+ Log metrics to Weights & Biases.
981
+
982
+ Args:
983
+ metrics (dict): Dictionary of metrics to log
984
+ prefix (str): Prefix for metric names
985
+ step (int): Training step
986
+ wandb_entity (str): W&B entity instance
987
+
988
+ Returns:
989
+ None.
990
+ """
991
+ log_dict = {}
992
+ for name, value in metrics.items():
993
+ # Map loss_value to Loss for better readability in W&B
994
+ if name == "loss_value":
995
+ log_dict[f"{prefix}/Loss"] = value
996
+ # Keep other metrics as is
997
+ else:
998
+ log_dict[f"{prefix}/{name.replace('_', ' ').title()}"] = value
999
+ wandb_entity.log(log_dict, step=step)
1000
+
1001
+
1002
+ def save_training_checkpoint(
1003
+ cfg,
1004
+ run_dir,
1005
+ log_step,
1006
+ vla,
1007
+ processor,
1008
+ proprio_projector,
1009
+ noisy_action_projector,
1010
+ action_head,
1011
+ train_dataset,
1012
+ distributed_state,
1013
+ ) -> None:
1014
+ """
1015
+ Save all training checkpoints including model components, LoRA adapter, and dataset statistics.
1016
+
1017
+ Args:
1018
+ cfg (FinetuneConfig): Training configuration.
1019
+ run_dir (Path): Experiment run directory path.
1020
+ log_step (int): Current logging step.
1021
+ vla (OpenVLAForActionPrediction): Vision-language-action policy.
1022
+ processor (PrismaticProcessor): OpenVLA inputs processor.
1023
+ proprio_projector (nn.Module): Proprioceptive state projector module.
1024
+ noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
1025
+ action_head (nn.Module): Action head module.
1026
+ train_dataset (RLDSDataset): Training dataset.
1027
+ distributed_state (PartialState): Distributed training state.
1028
+
1029
+ Returns:
1030
+ None.
1031
+ """
1032
+ # Determine checkpoint paths and naming
1033
+ if cfg.save_latest_checkpoint_only:
1034
+ checkpoint_dir = run_dir
1035
+ checkpoint_name_suffix = "latest_checkpoint.pt"
1036
+ else:
1037
+ checkpoint_dir = run_dir / f"{log_step}_chkpt"
1038
+ checkpoint_name_suffix = f"{log_step}_checkpoint.pt"
1039
+
1040
+ adapter_dir = checkpoint_dir / "lora_adapter"
1041
+
1042
+ # Create directories and save dataset statistics (main process only)
1043
+ if distributed_state.is_main_process:
1044
+ os.makedirs(checkpoint_dir, exist_ok=True)
1045
+ os.makedirs(adapter_dir, exist_ok=True)
1046
+ save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir)
1047
+ print(f"Saving Model Checkpoint for Step {log_step}")
1048
+
1049
+ # Wait for directories to be created
1050
+ dist.barrier()
1051
+
1052
+ # Save model components (main process only)
1053
+ if distributed_state.is_main_process:
1054
+ # Save processor and LoRA adapter
1055
+ processor.save_pretrained(checkpoint_dir)
1056
+ vla.module.save_pretrained(adapter_dir)
1057
+
1058
+ # Save other components
1059
+ if cfg.use_proprio and proprio_projector is not None:
1060
+ torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}")
1061
+
1062
+ if cfg.use_diffusion and noisy_action_projector is not None:
1063
+ torch.save(
1064
+ noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}"
1065
+ )
1066
+
1067
+ if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None:
1068
+ torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}")
1069
+
1070
+ if cfg.use_film:
1071
+ # To be safe, just save the entire vision backbone (not just FiLM components)
1072
+ torch.save(
1073
+ vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}"
1074
+ )
1075
+
1076
+ # Wait for model components to be saved
1077
+ dist.barrier()
1078
+
1079
+ # Merge LoRA weights into base model and save resulting model checkpoint
1080
+ # Note: Can be very slow on some devices; if so, we recommend merging offline
1081
+ if cfg.use_lora and cfg.merge_lora_during_training:
1082
+ base_vla = AutoModelForVision2Seq.from_pretrained(
1083
+ cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
1084
+ )
1085
+ merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir)
1086
+ merged_vla = merged_vla.merge_and_unload()
1087
+
1088
+ if distributed_state.is_main_process:
1089
+ merged_vla.save_pretrained(checkpoint_dir)
1090
+ print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}")
1091
+
1092
+ # Wait for merged model to be saved
1093
+ dist.barrier()
1094
+
1095
+
1096
+ def run_validation(
1097
+ vla,
1098
+ action_head,
1099
+ noisy_action_projector,
1100
+ proprio_projector,
1101
+ val_dataloader,
1102
+ action_tokenizer,
1103
+ device_id,
1104
+ cfg,
1105
+ num_patches,
1106
+ log_step,
1107
+ distributed_state,
1108
+ val_time_limit,
1109
+ ) -> None:
1110
+ """
1111
+ Compute validation set metrics for logging.
1112
+
1113
+ Args:
1114
+ vla (OpenVLAForActionPrediction): Vision-language-action policy.
1115
+ action_head (nn.Module): Action head module.
1116
+ noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion).
1117
+ proprio_projector (nn.Module): Proprioceptive state projector module.
1118
+ val_dataloader (DataLoader): Validation data loader.
1119
+ action_tokenizer (ActionTokenizer): Action tokenizer.
1120
+ device_id (str): Device ID.
1121
+ cfg (FinetuneConfig): Training configuration.
1122
+ num_patches (int): Number of vision patches.
1123
+ log_step (int): Current logging step.
1124
+ distributed_state (PartialState): Distributed training state.
1125
+ val_time_limit (int): Time limit for computing validation metrics.
1126
+
1127
+ Returns:
1128
+ None.
1129
+ """
1130
+ val_start_time = time.time()
1131
+ vla.eval()
1132
+ val_batches_count = 0
1133
+
1134
+ # List to store validation metrics
1135
+ all_val_metrics = []
1136
+
1137
+ with torch.no_grad():
1138
+ for batch in val_dataloader:
1139
+ # Always compute L1 loss for validation, even for diffusion
1140
+ _, metrics = run_forward_pass(
1141
+ vla=vla,
1142
+ action_head=action_head,
1143
+ noisy_action_projector=noisy_action_projector,
1144
+ proprio_projector=proprio_projector,
1145
+ batch=batch,
1146
+ action_tokenizer=action_tokenizer,
1147
+ device_id=device_id,
1148
+ use_l1_regression=cfg.use_l1_regression,
1149
+ use_diffusion=cfg.use_diffusion,
1150
+ use_proprio=cfg.use_proprio,
1151
+ use_film=cfg.use_film,
1152
+ num_patches=num_patches,
1153
+ compute_diffusion_l1=True,
1154
+ num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
1155
+ )
1156
+
1157
+ # Add the loss value to the metrics
1158
+ metrics["loss"] = metrics["loss_value"]
1159
+ all_val_metrics.append(metrics)
1160
+ val_batches_count += 1
1161
+
1162
+ # Cut testing on validation set short if it exceeds time limit
1163
+ if time.time() - val_start_time > val_time_limit:
1164
+ break
1165
+
1166
+ # Compute average validation metrics
1167
+ avg_val_metrics = {}
1168
+ for metric_name in all_val_metrics[0].keys():
1169
+ values = [metrics[metric_name] for metrics in all_val_metrics if metric_name in metrics]
1170
+ if values:
1171
+ avg_val_metrics[metric_name] = sum(values) / len(values)
1172
+
1173
+ # Add batch count to metrics
1174
+ avg_val_metrics["val_batches_count"] = val_batches_count
1175
+
1176
+ # Log validation metrics to W&B
1177
+ if distributed_state.is_main_process:
1178
+ log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb)
1179
+
1180
+
1181
+ @draccus.wrap()
1182
+ def finetune(cfg: FinetuneConfig) -> None:
1183
+ """
1184
+ Fine-tunes base VLA on demonstration dataset via LoRA.
1185
+
1186
+ Allows toggling different action representations (discrete vs. continuous), different learning objectives
1187
+ (next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs,
1188
+ such as additional camera images and robot proprioceptive state. Assumes parallel action generation with
1189
+ action chunking.
1190
+
1191
+ Args:
1192
+ cfg (FinetuneConfig): Training configuration.
1193
+
1194
+ Returns:
1195
+ None.
1196
+ """
1197
+ assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!"
1198
+ assert not (cfg.use_l1_regression and cfg.use_diffusion), (
1199
+ "Cannot do both L1 regression and diffusion. Please pick one of them!"
1200
+ )
1201
+
1202
+ # Trim trailing forward slash ('/') in VLA path if it exists
1203
+ cfg.vla_path = cfg.vla_path.rstrip("/")
1204
+ print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")
1205
+
1206
+ # Get experiment run ID
1207
+ run_id = get_run_id(cfg)
1208
+
1209
+ # Create experiment run directory
1210
+ run_dir = cfg.run_root_dir / run_id
1211
+ os.makedirs(run_dir, exist_ok=True)
1212
+
1213
+ # GPU setup
1214
+ distributed_state = PartialState()
1215
+ device_id = distributed_state.local_process_index
1216
+ torch.cuda.set_device(device_id)
1217
+ torch.cuda.empty_cache()
1218
+
1219
+ # Initialize wandb logging
1220
+ if distributed_state.is_main_process:
1221
+ wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id, id=run_id)
1222
+
1223
+ # Print detected constants
1224
+ print(
1225
+ "Detected constants:\n"
1226
+ f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n"
1227
+ f"\tACTION_DIM: {ACTION_DIM}\n"
1228
+ f"\tPROPRIO_DIM: {PROPRIO_DIM}\n"
1229
+ f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}"
1230
+ )
1231
+
1232
+ # Two options:
1233
+ # (1) Base model is on Hugging Face Hub
1234
+ # - Then download it and record the path to the download directory
1235
+ # (2) Base model is stored locally
1236
+ # - Then register model config in HF Auto Classes
1237
+ # In both cases, we want to check whether any changes have been made to
1238
+ # the `modeling_prismatic.py` file in this codebase; if so, we will copy
1239
+ # the file to the downloaded or locally stored checkpoint directory so
1240
+ # that the user's changes to the VLA class logic go into effect
1241
+ if model_is_on_hf_hub(cfg.vla_path):
1242
+ # Download model directly from Hugging Face Hub
1243
+ vla_download_path = snapshot_download(repo_id=cfg.vla_path)
1244
+ # Overwrite VLA path
1245
+ cfg.vla_path = vla_download_path
1246
+ else:
1247
+ # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
1248
+ AutoConfig.register("openvla", OpenVLAConfig)
1249
+ AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
1250
+ AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
1251
+ AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
1252
+
1253
+ # Update config.json and sync model files
1254
+ if distributed_state.is_main_process:
1255
+ update_auto_map(cfg.vla_path)
1256
+ check_model_logic_mismatch(cfg.vla_path)
1257
+
1258
+ # Wait for model files to be synced
1259
+ dist.barrier()
1260
+
1261
+ # Load processor and VLA
1262
+ processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
1263
+ vla = AutoModelForVision2Seq.from_pretrained(
1264
+ cfg.vla_path,
1265
+ torch_dtype=torch.bfloat16,
1266
+ low_cpu_mem_usage=True,
1267
+ trust_remote_code=True,
1268
+ ).to(device_id)
1269
+
1270
+ # Set number of images in VLA input
1271
+ vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input)
1272
+
1273
+ # LoRA setup
1274
+ if cfg.use_lora:
1275
+ lora_config = LoraConfig(
1276
+ r=cfg.lora_rank,
1277
+ lora_alpha=min(cfg.lora_rank, 16),
1278
+ lora_dropout=cfg.lora_dropout,
1279
+ target_modules="all-linear",
1280
+ init_lora_weights="gaussian",
1281
+ )
1282
+ vla = get_peft_model(vla, lora_config)
1283
+ vla.print_trainable_parameters()
1284
+
1285
+ # FiLM setup
1286
+ if cfg.use_film:
1287
+ count_parameters(vla.vision_backbone, "vla.vision_backbone (original)")
1288
+ # Wrap vision backbone with FiLM wrapper
1289
+ # Important: For this, must specify `vla.model.vision_backbone` instead of just `vla.vision_backbone`, since the
1290
+ # latter would cause the new wrapped backbone to be saved as a new attribute of `vla` instead of overwriting the
1291
+ # original one (due to the LoRA wrapper)
1292
+ vla.model.vision_backbone = FiLMedPrismaticVisionBackbone(
1293
+ vision_backbone=vla.model.vision_backbone,
1294
+ llm_dim=vla.llm_dim,
1295
+ )
1296
+ count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)")
1297
+ if cfg.resume:
1298
+ state_dict = load_checkpoint("vision_backbone", cfg.vla_path, cfg.resume_step)
1299
+ vla.model.vision_backbone.load_state_dict(state_dict)
1300
+ vla.model.vision_backbone = vla.model.vision_backbone.to(device_id)
1301
+
1302
+ # Wrap VLA with DDP
1303
+ vla = wrap_ddp(vla, device_id, find_unused=False)
1304
+
1305
+ # vla._set_static_graph()
1306
+
1307
+ # If applicable, instantiate proprio projector
1308
+ if cfg.use_proprio:
1309
+ proprio_projector = init_module(
1310
+ ProprioProjector,
1311
+ "proprio_projector",
1312
+ cfg,
1313
+ device_id,
1314
+ {"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM},
1315
+ )
1316
+ else:
1317
+ proprio_projector = None
1318
+
1319
+ # If applicable, instantiate continuous action head for L1 regression
1320
+ if cfg.use_l1_regression:
1321
+ action_head = init_module(
1322
+ L1RegressionActionHead,
1323
+ "action_head",
1324
+ cfg,
1325
+ device_id,
1326
+ {"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM},
1327
+ to_bf16=True,
1328
+ )
1329
+ else:
1330
+ action_head = None
1331
+
1332
+ # If applicable, instantiate diffusion action head and noisy action projector
1333
+ if cfg.use_diffusion:
1334
+ action_head = init_module(
1335
+ DiffusionActionHead,
1336
+ "action_head",
1337
+ cfg,
1338
+ device_id,
1339
+ {
1340
+ "input_dim": vla.module.llm_dim,
1341
+ "hidden_dim": vla.module.llm_dim,
1342
+ "action_dim": ACTION_DIM,
1343
+ "num_diffusion_steps_train": cfg.num_diffusion_steps_train,
1344
+ },
1345
+ to_bf16=True,
1346
+ )
1347
+ noisy_action_projector = init_module(
1348
+ NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim}
1349
+ )
1350
+ else:
1351
+ noisy_action_projector = None
1352
+
1353
+ # EMA
1354
+ if cfg.use_ema:
1355
+ ema_vla = EMAModel(vla,
1356
+ action_head,
1357
+ proprio_projector,
1358
+ noisy_action_projector,
1359
+ inv_gamma=cfg.inv_gamma
1360
+ )
1361
+
1362
+ # Get number of vision patches
1363
+ NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input()
1364
+ # If we have proprio inputs, a single proprio embedding is appended to the end of the vision patch embeddings
1365
+ if cfg.use_proprio:
1366
+ NUM_PATCHES += 1
1367
+ # For diffusion, a single diffusion timestep embedding is appended to the end of the vision patch embeddings
1368
+ if cfg.use_diffusion:
1369
+ NUM_PATCHES += 1
1370
+
1371
+ diff_path = cfg.regularization_lora_vector_path # <- set this to your own path
1372
+
1373
+ # Load diff parameters for regularization
1374
+ diff_params_dict = {}
1375
+ if diff_path and os.path.exists(diff_path):
1376
+ print(f"Loading diff parameters from {diff_path}")
1377
+ diff_params_dict = load_diff_params(diff_path, device="cpu")
1378
+ print(f"Loaded {len(diff_params_dict)} parameters from diff_path")
1379
+ else:
1380
+ print(f"Warning: diff_path {diff_path} does not exist, skipping regularization loss")
1381
+
1382
+ # Regularization weight (you can make this configurable via cfg if needed)
1383
+ regularization_weight = cfg.regularization_weight # adjust this weight as needed
1384
+
1385
+ # Instantiate optimizer
1386
+ trainable_params = [param for param in vla.parameters() if param.requires_grad]
1387
+ if cfg.use_l1_regression or cfg.use_diffusion:
1388
+ trainable_params += [param for param in action_head.parameters() if param.requires_grad]
1389
+ if cfg.use_diffusion:
1390
+ trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad]
1391
+ if cfg.use_proprio:
1392
+ trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad]
1393
+ print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}")
1394
+ optimizer = AdamW(trainable_params, lr=cfg.learning_rate)
1395
+
1396
+ # Record original learning rate
1397
+ original_lr = optimizer.param_groups[0]["lr"]
1398
+
1399
+ # Create learning rate scheduler
1400
+ if cfg.scheduler == 'MultiStepLR':
1401
+ scheduler = MultiStepLR(
1402
+ optimizer,
1403
+ milestones=[cfg.num_steps_before_decay], # Number of steps after which LR will change
1404
+ gamma=0.1, # Multiplicative factor of learning rate decay
1405
+ )
1406
+ elif cfg.scheduler == 'CosineAnnealingLR':
1407
+ scheduler = CosineAnnealingLR(
1408
+ optimizer,
1409
+ T_max=cfg.max_steps, # Total number of steps for the cosine annealing
1410
+ eta_min=cfg.learning_rate * 1e-3,
1411
+ )
1412
+ elif cfg.scheduler == 'WarmupCosineLR':
1413
+ scheduler = get_cosine_schedule_with_warmup(
1414
+ optimizer,
1415
+ num_warmup_steps=500,
1416
+ num_training_steps=cfg.max_steps,
1417
+ )
1418
+ else:
1419
+ raise ValueError(f"Unsupported scheduler type: {cfg.scheduler}")
1420
+
1421
+ # Create Action Tokenizer
1422
+ action_tokenizer = ActionTokenizer(processor.tokenizer)
1423
+
1424
+ # Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default.
1425
+ # =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block.
1426
+ # =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using
1427
+ # your own Dataset, make sure to add the appropriate logic to the training loop!
1428
+ #
1429
+ # ---
1430
+ # from prismatic.vla.datasets import DummyDataset
1431
+ #
1432
+ # train_dataset = DummyDataset(
1433
+ # action_tokenizer,
1434
+ # processor.tokenizer,
1435
+ # image_transform=processor.image_processor.apply_transform,
1436
+ # prompt_builder_fn=PurePromptBuilder,
1437
+ # )
1438
+ # ---
1439
+
1440
+ # We assume that the model takes as input one third-person camera image and 1 or 2 optional wrist camera image(s)
1441
+ use_wrist_image = cfg.num_images_in_input > 1
1442
+
1443
+ # Create training and optional validation datasets
1444
+ batch_transform = RLDSBatchTransform(
1445
+ action_tokenizer,
1446
+ processor.tokenizer,
1447
+ image_transform=processor.image_processor.apply_transform,
1448
+ prompt_builder_fn=PurePromptBuilder,
1449
+ use_wrist_image=use_wrist_image,
1450
+ use_proprio=cfg.use_proprio,
1451
+ )
1452
+ train_dataset = RLDSDataset(
1453
+ cfg.data_root_dir,
1454
+ cfg.dataset_name,
1455
+ batch_transform,
1456
+ resize_resolution=tuple(vla.module.config.image_sizes),
1457
+ shuffle_buffer_size=cfg.shuffle_buffer_size,
1458
+ image_aug=cfg.image_aug,
1459
+ )
1460
+ if cfg.use_val_set:
1461
+ val_dataset = RLDSDataset(
1462
+ cfg.data_root_dir,
1463
+ cfg.dataset_name,
1464
+ batch_transform,
1465
+ resize_resolution=tuple(vla.module.config.image_sizes),
1466
+ shuffle_buffer_size=cfg.shuffle_buffer_size // 10,
1467
+ image_aug=cfg.image_aug,
1468
+ train=False,
1469
+ )
1470
+
1471
+ # [Important] Save dataset statistics so that we can unnormalize actions during inference
1472
+ if distributed_state.is_main_process:
1473
+ save_dataset_statistics(train_dataset.dataset_statistics, run_dir)
1474
+
1475
+ # Create collator and dataloader
1476
+ collator = PaddedCollatorForActionPrediction(
1477
+ processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
1478
+ )
1479
+ dataloader = DataLoader(
1480
+ train_dataset,
1481
+ batch_size=cfg.batch_size,
1482
+ sampler=None,
1483
+ collate_fn=collator,
1484
+ num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
1485
+ )
1486
+ if cfg.use_val_set:
1487
+ val_batch_size = cfg.batch_size
1488
+ val_dataloader = DataLoader(
1489
+ val_dataset,
1490
+ batch_size=val_batch_size,
1491
+ sampler=None,
1492
+ collate_fn=collator,
1493
+ num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism
1494
+ )
1495
+
1496
+ # Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation)
1497
+ recent_metrics = {
1498
+ "loss_value": deque(maxlen=cfg.grad_accumulation_steps),
1499
+ "curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
1500
+ "curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
1501
+ "next_actions_accuracy": deque(maxlen=cfg.grad_accumulation_steps),
1502
+ "next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps),
1503
+ "regularization_loss": deque(maxlen=cfg.grad_accumulation_steps),
1504
+ }
1505
+
1506
+ # Start training
1507
+ with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress:
1508
+ vla.train()
1509
+ optimizer.zero_grad()
1510
+ for batch_idx, batch in enumerate(dataloader):
1511
+ # Compute training metrics and loss
1512
+ compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0
1513
+ loss, metrics = run_forward_pass(
1514
+ vla=vla,
1515
+ action_head=action_head,
1516
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1517
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1518
+ batch=batch,
1519
+ action_tokenizer=action_tokenizer,
1520
+ device_id=device_id,
1521
+ use_l1_regression=cfg.use_l1_regression,
1522
+ use_diffusion=cfg.use_diffusion,
1523
+ use_proprio=cfg.use_proprio,
1524
+ use_film=cfg.use_film,
1525
+ num_patches=NUM_PATCHES,
1526
+ compute_diffusion_l1=compute_diffusion_l1,
1527
+ num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
1528
+ )
1529
+
1530
+ # Add regularization loss if diff_params_dict is available
1531
+ if diff_params_dict:
1532
+ ########################### Regularization Loss ##########################
1533
+ regularization_loss = compute_diff_regularization_loss(
1534
+ vla, diff_params_dict, regularization_weight=regularization_weight
1535
+ )
1536
+ # print(f"Regularization loss: {regularization_loss}")
1537
+ # print(f"Main loss: {loss}")
1538
+ # The two lines below are for gradient checking
1539
+ # Save the main loss for gradient checking
1540
+ # main_loss = loss.clone()
1541
+ # reg_loss = regularization_loss.clone()
1542
+ # print('loss:', loss)
1543
+ # print('regularization_loss:', regularization_loss)
1544
+
1545
+ # with vla.no_sync():
1546
+ # regularization_loss.backward()
1547
+
1548
+ # model_module = vla.module if hasattr(vla, 'module') else vla
1549
+ # reg_grads = {}
1550
+ # for name, param in model_module.named_parameters():
1551
+ # if "lora_A" in name and param.requires_grad and param.grad is not None:
1552
+
1553
+ # reg_grads[name] = param.grad.clone()
1554
+
1555
+
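+ # Zero-valued dummy term touching every trainable parameter so that all of them participate in the backward graph (e.g., to avoid DDP unused-parameter errors)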
1556
+ dummy_loss = 0.0
1557
+ for p in vla.parameters():
1558
+ if p.requires_grad:
1559
+ dummy_loss = dummy_loss + p.sum() * 0.0
1560
+
1561
+ print('action loss:', loss)
1562
+ print('regularization_loss:', regularization_loss)
1563
+ print('dummy_loss:', dummy_loss)
1564
+
1565
+ loss = loss + regularization_loss + dummy_loss
1566
+
1567
+
1568
+
1569
+ loss.backward()
1570
+ # main_grads = {}
1571
+ # for name, param in model_module.named_parameters():
1572
+ # if "lora_A" in name and param.requires_grad and param.grad is not None:
1573
+
1574
+ # main_grads[name] = param.grad.clone()
1575
+
1576
+ # print('################################################')
1577
+ # for name in main_grads.keys():
1578
+ # if name in reg_grads:
1579
+ # main_grad_norm = main_grads[name].norm().item()
1580
+ # reg_grad_norm = reg_grads[name].norm().item()
1581
+ # combined_grad_norm = (main_grads[name] + reg_grads[name]).norm().item()
1582
+ # print(f" {name}:")
1583
+ # print(f" Main-loss grad norm: {main_grad_norm:.6f}")
1584
+ # print(f" Regularization-loss grad norm: {reg_grad_norm:.6f}")
1585
+ # print(f" Combined grad norm: {combined_grad_norm:.6f}")
1586
+
1587
+
1588
+ # print('################################################')
1589
+ # # Log regularization loss
1590
+ # metrics["regularization_loss"] = regularization_loss.item()
1591
+ # #############################################################################
1592
+
1593
+ # # The if-block below is for gradient checking
1594
+ # # Check the gradients of the two losses separately (before backward)
1595
+ # if diff_params_dict and batch_idx % cfg.wandb_log_freq == 0:
1596
+ # # Get model parameters for gradient checking
1597
+ # model_module = vla.module if hasattr(vla, 'module') else vla
1598
+
1599
+ # # Zero the gradients first
1600
+ # optimizer.zero_grad()
1601
+
1602
+ # # Backward only on the main loss
1603
+ # main_loss_normalized = main_loss / cfg.grad_accumulation_steps
1604
+ # main_loss_normalized.backward(retain_graph=True)
1605
+
1606
+ # # Save the gradients of the main loss
1607
+ # main_grads = {}
1608
+ # for name, param in model_module.named_parameters():
1609
+ # if "lora_A" in name and param.requires_grad and param.grad is not None:
1610
+
1611
+ # main_grads[name] = param.grad.clone()
1612
+
1613
+ # # Zero the gradients, then backward only on the regularization loss
1614
+ # optimizer.zero_grad()
1615
+ # reg_loss_normalized = reg_loss / cfg.grad_accumulation_steps
1616
+ # reg_loss_normalized.backward(retain_graph=True)
1617
+
1618
+ # # Save the gradients of the regularization loss
1619
+ # reg_grads = {}
1620
+ # for name, param in model_module.named_parameters():
1621
+ # if "lora_A" in name and param.requires_grad and param.grad is not None:
1622
+ # reg_grads[name] = param.grad.clone()
1623
+
1624
+ # # Print gradient info
1625
+ # print(f"\n[Gradient check] Step {batch_idx // cfg.grad_accumulation_steps}")
1626
+ # sample_count = 0
1627
+ # for name in main_grads.keys():
1628
+ # if name in reg_grads:
1629
+ # main_grad_norm = main_grads[name].norm().item()
1630
+ # reg_grad_norm = reg_grads[name].norm().item()
1631
+ # combined_grad_norm = (main_grads[name] + reg_grads[name]).norm().item()
1632
+ # print(f" {name}:")
1633
+ # print(f" Main-loss grad norm: {main_grad_norm:.6f}")
1634
+ # print(f" Regularization-loss grad norm: {reg_grad_norm:.6f}")
1635
+ # print(f" Combined grad norm: {combined_grad_norm:.6f}")
1636
+ # sample_count += 1
1637
+ # if sample_count >= 3: # Only check the first 3 parameters as examples
1638
+ # break
1639
+ # print()
1640
+
1641
+ # # Zero the gradients in preparation for the normal backward
1642
+ # optimizer.zero_grad()
1643
+
1644
+ # # Normalize loss to account for gradient accumulation
1645
+ # normalized_loss = loss / cfg.grad_accumulation_steps
1646
+
1647
+ # # Backward pass
1648
+ # normalized_loss.backward()
1649
+
1650
+ # Store recent train metrics
1651
+ for metric_name, value in metrics.items():
1652
+ if metric_name in recent_metrics:
1653
+ recent_metrics[metric_name].append(value)
1654
+
1655
+ # Compute gradient step index
1656
+ gradient_step_idx = batch_idx // cfg.grad_accumulation_steps
1657
+
1658
+ # Compute smoothened train metrics
1659
+ smoothened_metrics = compute_smoothened_metrics(recent_metrics)
1660
+
1661
+ # Push Metrics to W&B (every wandb_log_freq gradient steps)
1662
+ log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx
1663
+ if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0:
1664
+ log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb)
1665
+
1666
+ # [If applicable] Linearly warm up learning rate from 10% to 100% of original
1667
+ if cfg.lr_warmup_steps > 0:
1668
+ lr_progress = min((gradient_step_idx + 1) / cfg.lr_warmup_steps, 1.0) # Cap at 1.0
1669
+ current_lr = original_lr * (0.1 + 0.9 * lr_progress)
1670
+ for param_group in optimizer.param_groups:
1671
+ param_group["lr"] = current_lr
1672
+
1673
+ # Optimizer and LR scheduler step
1674
+ if (batch_idx + 1) % cfg.grad_accumulation_steps == 0:
1675
+ optimizer.step()
1676
+ scheduler.step()
1677
+ optimizer.zero_grad()
1678
+ progress.update()
1679
+ if cfg.use_ema:
1680
+ ema_vla.step(vla, action_head, proprio_projector, noisy_action_projector)
1681
+
1682
+ if distributed_state.is_main_process and gradient_step_idx % cfg.wandb_log_freq == 0:
1683
+ # Log the learning rate
1684
+ # Make sure to do this AFTER any learning rate modifications (e.g., warmup/decay)
1685
+ wandb.log(
1686
+ {
1687
+ "VLA Train/Learning Rate": scheduler.get_last_lr()[0],
1688
+ },
1689
+ step=log_step,
1690
+ )
1691
+
1692
+ if cfg.use_ema:
1693
+ # Log the EMA decay value
1694
+ wandb.log(
1695
+ {
1696
+ "VLA Train/EMA Decay": ema_vla.decay,
1697
+ },
1698
+ step=log_step,
1699
+ )
1700
+ # Log the EMA eval loss
1701
+ ema_vla.apply_shadow(vla, action_head, proprio_projector, noisy_action_projector)
1702
+ with torch.no_grad():
1703
+ vla.eval()
1704
+ action_head.eval() if action_head else None
1705
+ _, ema_metrics = run_forward_pass(
1706
+ vla=vla,
1707
+ action_head=action_head,
1708
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1709
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1710
+ batch=batch,
1711
+ action_tokenizer=action_tokenizer,
1712
+ device_id=device_id,
1713
+ use_l1_regression=cfg.use_l1_regression,
1714
+ use_diffusion=cfg.use_diffusion,
1715
+ use_proprio=cfg.use_proprio,
1716
+ use_film=cfg.use_film,
1717
+ num_patches=NUM_PATCHES,
1718
+ compute_diffusion_l1=compute_diffusion_l1,
1719
+ num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None,
1720
+ )
1721
+ ema_loss = ema_metrics['loss_value']
1722
+ vla.train()
1723
+ action_head.train() if action_head else None
1724
+ ema_vla.restore(vla, action_head, proprio_projector, noisy_action_projector)
1725
+ wandb.log(
1726
+ {
1727
+ "VLA Train/EMA Loss": ema_loss,
1728
+ },
1729
+ step=log_step,
1730
+ )
1731
+
1732
+ # Save model checkpoint: either keep latest checkpoint only or all checkpoints
1733
+ if gradient_step_idx > 0 and log_step % cfg.save_freq == 0:
1734
+ save_training_checkpoint(
1735
+ cfg=cfg,
1736
+ run_dir=run_dir,
1737
+ log_step=log_step,
1738
+ vla=vla,
1739
+ processor=processor,
1740
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1741
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1742
+ action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None,
1743
+ train_dataset=train_dataset,
1744
+ distributed_state=distributed_state,
1745
+ )
1746
+
1747
+ if cfg.use_ema:
1748
+ # Also save EMA model checkpoint
1749
+ ema_vla.apply_shadow(vla, action_head, proprio_projector, noisy_action_projector)
1750
+ save_training_checkpoint(
1751
+ cfg=cfg,
1752
+ run_dir=run_dir / "ema_model",
1753
+ log_step=log_step,
1754
+ vla=vla,
1755
+ processor=processor,
1756
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1757
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1758
+ action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None,
1759
+ train_dataset=train_dataset,
1760
+ distributed_state=distributed_state,
1761
+ )
1762
+ ema_vla.restore(vla, action_head, proprio_projector, noisy_action_projector)
1763
+
1764
+ # Test model on validation set
1765
+ if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0:
1766
+ run_validation(
1767
+ vla=vla,
1768
+ action_head=action_head,
1769
+ noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None,
1770
+ proprio_projector=proprio_projector if cfg.use_proprio else None,
1771
+ val_dataloader=val_dataloader,
1772
+ action_tokenizer=action_tokenizer,
1773
+ device_id=device_id,
1774
+ cfg=cfg,
1775
+ num_patches=NUM_PATCHES,
1776
+ log_step=log_step,
1777
+ distributed_state=distributed_state,
1778
+ val_time_limit=cfg.val_time_limit,
1779
+ )
1780
+ # Set model back to training mode after validation
1781
+ vla.train()
1782
+
1783
+ # Stop training when max_steps is reached
1784
+ if log_step == cfg.max_steps:
1785
+ print(f"Max step {cfg.max_steps} reached! Stopping training...")
1786
+ break
1787
+
1788
+
1789
+ if __name__ == "__main__":
1790
+ finetune()
capvector-oft/vla-scripts/merge_lora_weights_and_save.py ADDED
@@ -0,0 +1,73 @@
1
+ """
2
+ Loads a checkpoint that only has a LoRA adapter (no merged model) and merges the adapter
3
+ into the base OpenVLA model. Saves the final checkpoint in the same directory.
4
+
5
+ Make sure to specify the correct base checkpoint when running this script. For example,
6
+ - if you fine-tuned the default OpenVLA-7B model without modifications, then use `--base_checkpoint openvla/openvla-7b`
7
+ - if you fine-tuned a different model or resumed fine-tuning from a different checkpoint, then specify that base checkpoint
8
+ - if you fine-tuned the default OpenVLA-7B model with modifications to `modeling_prismatic.py` (OpenVLA class definition),
9
+ then the base checkpoint path should point to the checkpoint containing the modifications
10
+
11
+ Usage:
12
+ python vla-scripts/merge_lora_weights_and_save.py \
13
+ --base_checkpoint openvla/openvla-7b \
14
+ --lora_finetuned_checkpoint_dir /PATH/TO/CHECKPOINT/DIR/
15
+ """
16
+
17
+ import os
18
+ import time
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+ from typing import Union
22
+
23
+ import draccus
24
+ import torch
25
+ from peft import PeftModel
26
+ from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
27
+
28
+ from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
29
+ from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
30
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
31
+
32
+
33
+ @dataclass
34
+ class ConvertConfig:
35
+ # fmt: off
36
+
37
+ base_checkpoint: Union[str, Path] = "" # Base model checkpoint path/dir (either openvla/openvla-7b or whichever model you fine-tuned / resumed training from)
38
+ lora_finetuned_checkpoint_dir: Union[str, Path] = "" # Checkpoint directory containing the LoRA adapter
39
+
40
+ # fmt: on
41
+
42
+
43
+ @draccus.wrap()
44
+ def main(cfg: ConvertConfig) -> None:
45
+ # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
46
+ AutoConfig.register("openvla", OpenVLAConfig)
47
+ AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
48
+ AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
49
+ AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
50
+
51
+ # Load Model using HF AutoClasses
52
+ print(f"Loading base model: {cfg.base_checkpoint}")
53
+ vla = AutoModelForVision2Seq.from_pretrained(
54
+ cfg.base_checkpoint,
55
+ torch_dtype=torch.bfloat16,
56
+ low_cpu_mem_usage=True,
57
+ trust_remote_code=True,
58
+ )
59
+
60
+ # Load LoRA weights and merge into base model, then save final checkpoint
61
+ print("Merging LoRA weights into base model...")
62
+ start_time = time.time()
63
+ merged_vla = PeftModel.from_pretrained(vla, os.path.join(cfg.lora_finetuned_checkpoint_dir, "lora_adapter")).to(
64
+ "cuda"
65
+ )
66
+ merged_vla = merged_vla.merge_and_unload()
67
+ merged_vla.save_pretrained(cfg.lora_finetuned_checkpoint_dir)
68
+ print(f"\nMerging complete! Time elapsed (sec): {time.time() - start_time}")
69
+ print(f"\nSaved merged model checkpoint at:\n{cfg.lora_finetuned_checkpoint_dir}")
70
+
71
+
72
+ if __name__ == "__main__":
73
+ main()
capvector-pi05/.dockerignore ADDED
@@ -0,0 +1,3 @@
1
+ .venv
2
+ checkpoints
3
+ data
capvector-pi05/.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ # Data directories.
2
+ assets/
3
+ checkpoints/
4
+ data/
5
+ wandb/
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ .idea/
169
+ .vscode/
capvector-pi05/.gitmodules ADDED
@@ -0,0 +1,6 @@
1
+ [submodule "third_party/aloha"]
2
+ path = third_party/aloha
3
+ url = https://github.com/Physical-Intelligence/aloha.git
4
+ [submodule "third_party/libero"]
5
+ path = third_party/libero
6
+ url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
capvector-pi05/.pre-commit-config.yaml ADDED
@@ -0,0 +1,16 @@
1
+ exclude: third_party/
2
+
3
+ repos:
4
+ - repo: https://github.com/astral-sh/uv-pre-commit
5
+ # uv version.
6
+ rev: 0.5.14
7
+ hooks:
8
+ - id: uv-lock
9
+ - repo: https://github.com/astral-sh/ruff-pre-commit
10
+ # Ruff version.
11
+ rev: v0.8.6
12
+ hooks:
13
+ # Run the linter.
14
+ - id: ruff
15
+ args: [--fix]
16
+ - id: ruff-format
capvector-pi05/.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.11
capvector-pi05/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
capvector-pi05/README.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## 1. Environment Setup
2
+ We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment:
3
+
4
+ ```bash
5
+ GIT_LFS_SKIP_SMUDGE=1 uv sync
6
+ GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
7
+ cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
8
+ source .venv/bin/activate
9
+ ```
10
+
11
+ NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
12
+
13
+
14
+ ## 2. Data Preparation
15
+ Here we take the real-world Aloha data as an example; for more details on simulation data, refer to the [official openpi repo](https://github.com/Physical-Intelligence/openpi/).
16
+
17
+ First, you need to collect the task-specific raw data with your own robot, and save it in the `.hdf5` format.
18
+
19
+ Then, convert the data to LeRobot dataset format.
20
+ ```bash
21
+ uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
22
+ # By default, the converted data is stored in ~/.cache/huggingface/lerobot/<org>/<dataset-name>/
23
+ ```
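+
+ The conversion script expects raw episodes named `episode_*.hdf5`. Roughly the following layout is assumed (field names are taken from [convert_aloha_data_to_lerobot.py](examples/aloha_real/convert_aloha_data_to_lerobot.py); `qvel` and `effort` are optional, and shapes are shown for a standard bi-manual Aloha):
+ ```
+ episode_0.hdf5
+ ├── /action                      # (T, 14) per-step joint targets
+ ├── /observations/qpos           # (T, 14) joint positions
+ ├── /observations/qvel           # optional joint velocities
+ ├── /observations/effort         # optional joint efforts
+ └── /observations/images/<cam>   # (T, 480, 640, 3) raw or JPEG-compressed frames
+ ```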
24
+
25
+
26
+ ## 3. Obtain the capability vectors and merge them to obtain $\theta_{meta}$
27
+
28
+ First, define your task-specific config in [config.py](src/openpi/training/config.py); we provide an example for our real-world task [here](src/openpi/training/config.py#L776-L808).
29
+
30
+ Then, convert a JAX model checkpoint to PyTorch format:
31
+ ```bash
32
+ uv run examples/convert_jax_model_to_pytorch.py \
33
+ --checkpoint_dir gs://openpi-assets/checkpoints/pi05_base \
34
+ --config_name <config_name> \
35
+ --output_path checkpoints/pytorch_pi05_base
36
+ # This command will automatically download the pi05_base checkpoint to ~/.cache/openpi/openpi-assets/checkpoints/pi05_base/
37
+ # Otherwise you can download it manually and modify the --checkpoint_dir
38
+ ```
39
+
40
+ > ⭐ If you don't use the regularization strategy, you can download the [capability-merged meta model](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/merged_model) we provide, place it at `./checkpoints/vector_init/pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial/`, and skip directly to the [Training step](#4-training).
41
+
42
+ Then, the capability vectors are obtained by simply conducting parameter arithmetic between two models finetuned with different strategies, so we need to prepare these two trained models, *e.g.*, [Pi0.5 on LIBERO-Spatial](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/pi05_baseline_30000step_spatial) and [Pi0.5-SF on LIBERO-Spatial](https://huggingface.co/haofuly/capvector_models_collection/capvector_pi05/pi05_spatialforcing_30000step_spatial). The directory structure is as follows:
43
+ ```
44
+ capvector-pi05
+ ├── checkpoints
+ │   ├── pi05-LIBEROspatial
+ │   │   ├── model.safetensors
+ │   │   └── ...
+ │   ├── pi05SF-LIBEROspatial
+ │   │   ├── model.safetensors
+ │   │   └── ...
+ │   ├── diff
+ │   ├── vector_init
+ │   └── ...
+ └── ...
55
+ ```
56
+
57
+ Next, conduct parameter arithmetic between these two models:
58
+ ```bash
59
+ CONFIG=pi05_capvector_aloha_place_block && \
60
+ EXT=pi05SF-LIBEROspatial && \
61
+ DOWN=pi05-LIBEROspatial && \
62
+ uv run capvector/compute_param_diff.py \
63
+ --config $CONFIG \
64
+ --a.dir checkpoints/$EXT \
65
+ --b.dir checkpoints/$DOWN \
66
+ --out checkpoints/diff/${EXT}_minus_${DOWN}.pth \
67
+ --strict-keys \
68
+ --dtype fp32
69
+ ```
70
+
71
+ Finally, merge the diff parameters into the base model to obtain $\theta_{meta}$:
72
+ ```bash
73
+ DIFF=pi05SF-LIBEROspatial_minus_pi05-LIBEROspatial && \
74
+ uv run capvector/apply_param_diff.py \
75
+ --base-safetensors checkpoints/pytorch_pi05_base/model.safetensors \
76
+ --diff-pth checkpoints/diff/${DIFF}.pth \
77
+ --out-safetensors checkpoints/vector_init/${DIFF}/model.safetensors \
78
+ --scale 1.0 \
79
+ --no-strict-keys \
80
+ --dtype fp32 \
81
+ --device cpu
82
+ ```
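+
+ Together, these two steps compute, for every floating-point tensor, $\theta_{meta} = \theta_{base} + s\,(\theta_{A} - \theta_{B})$, where $\theta_A$ and $\theta_B$ are the two finetuned checkpoints (here Pi0.5-SF and Pi0.5 on LIBERO-Spatial), $\theta_{base}$ is the converted `pi05_base` checkpoint, and $s$ is `--scale` (1.0 above). Non-floating-point tensors are copied from the base model unchanged.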
83
+
84
+
85
+ ## 4. Training
86
+ First, you need to compute the normalization statistics for the training data.
87
+ ```bash
88
+ uv run scripts/compute_norm_stats.py --config-name <config_name>
89
+ ```
90
+
91
+ Then, launch training using one of these modes:
92
+ ```bash
93
+ # Single GPU training:
94
+ uv run scripts/train_regular_loss_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
95
+ # Example:
96
+ uv run scripts/train_regular_loss_pytorch.py pi05_capvector_aloha_place_block --exp_name pytorch_test
97
+ uv run scripts/train_regular_loss_pytorch.py pi05_capvector_aloha_place_block --exp_name pytorch_test --overwrite # Overwrite existing checkpoints
98
+
99
+ # Multi-GPU training (single node):
100
+ uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_regular_loss_pytorch.py <config_name> --exp_name <run_name>
101
+
102
+ # Multi-Node Training:
103
+ uv run torchrun \
104
+ --nnodes=<num_nodes> \
105
+ --nproc_per_node=<gpus_per_node> \
106
+ --node_rank=<rank_of_node> \
107
+ --master_addr=<master_ip> \
108
+ --master_port=<port> \
109
+ scripts/train_regular_loss_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
110
+ ```
111
+
112
+
113
+ ## 5. Inference
114
+ Real-world inference runs in a server-client setup.
115
+
116
+ First, launch a model server (we use the checkpoint for iteration 20,000 for this example, modify as needed):
117
+ ```bash
118
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=<config_name> --policy.dir=checkpoints/<config_name>/<run_name>/20000
119
+ ```
120
+
121
+ This will spin up a server that listens on port 8000 and waits for observations to be sent to it.
122
+
123
+ Then, we can run a client robot script that queries the server.
124
+
125
+ You need to write your own client script for your robot. A simple [client example](examples/simple_client/main.py) can be run as below:
126
+ ```bash
127
+ uv run examples/simple_client/main.py --env ALOHA
128
+ ```
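+
+ If you prefer to embed the query loop directly in your own robot code, a minimal sketch is given below. It follows the `openpi_client` usage described in [docs/remote_inference.md](docs/remote_inference.md); the observation keys are illustrative and must match the policy inputs of the config you serve, and `get_robot_observation` / `send_to_robot` are placeholders for your own robot I/O.
+ ```python
+ from openpi_client import image_tools
+ from openpi_client import websocket_client_policy
+
+ max_steps = 1000  # placeholder episode length
+
+ # Connect once, outside the control loop (host/port of the policy server started above).
+ client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
+
+ for _ in range(max_steps):
+     img, state = get_robot_observation()  # placeholder: read a camera image + joint state from your robot
+     observation = {
+         "observation/image": image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224)),
+         "observation/state": state,
+         "prompt": "put the block into the bowl",
+     }
+     # The server returns an action chunk of shape (action_horizon, action_dim); execute it open-loop.
+     for action in client.infer(observation)["actions"]:
+         send_to_robot(action)  # placeholder: command your robot
+ ```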
capvector-pi05/capvector/apply_param_diff.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
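+ """Apply a parameter diff (capability vector) to base pretrained weights: merged = base + scale * diff."""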
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import tyro
7
+ from safetensors.torch import load_file, save_file
8
+
9
+
10
+ @dataclasses.dataclass
11
+ class Args:
12
+ # Base pretrained weights in safetensors
13
+ base_safetensors: str
14
+
15
+ # Diff checkpoint in .pth (either {"state_dict": ...} or raw state_dict)
16
+ diff_pth: str
17
+
18
+ # Output safetensors path
19
+ out_safetensors: str = "model_merged.safetensors"
20
+
21
+ # final = base + scale * diff
22
+ scale: float = 1.0
23
+
24
+ # whether keys must match exactly
25
+ strict_keys: bool = True # use --strict-keys / --no-strict-keys
26
+
27
+ # arithmetic dtype
28
+ dtype: str = "fp32" # fp32/fp16/bf16
29
+
30
+ # compute device
31
+ device: str = "cpu" # cpu/cuda
32
+
33
+
34
+ def cast(t: torch.Tensor, dtype: str) -> torch.Tensor:
35
+ if dtype == "fp32":
36
+ return t.float()
37
+ if dtype == "fp16":
38
+ return t.half()
39
+ if dtype == "bf16":
40
+ return t.bfloat16()
41
+ raise ValueError(f"Unknown dtype: {dtype}")
42
+
43
+
44
+ def load_diff_state_dict(path: str) -> dict[str, torch.Tensor]:
45
+ obj = torch.load(path, map_location="cpu")
46
+ if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict):
47
+ sd = obj["state_dict"]
48
+ elif isinstance(obj, dict):
49
+ sd = obj
50
+ else:
51
+ raise RuntimeError(f"Unexpected diff format: {type(obj)}")
52
+
53
+ for k, v in sd.items():
54
+ if not isinstance(v, torch.Tensor):
55
+ raise RuntimeError(f"Diff contains non-tensor at key={k}: {type(v)}")
56
+ return sd
57
+
58
+
59
+ def main(args: Args) -> None:
60
+ logging.info("Loading base safetensors: %s", args.base_safetensors)
61
+ base_sd = load_file(args.base_safetensors, device="cpu") # dict[str, Tensor]
62
+
63
+ logging.info("Loading diff pth: %s", args.diff_pth)
64
+ diff_sd = load_diff_state_dict(args.diff_pth)
65
+
66
+ keys_base = set(base_sd.keys())
67
+ keys_diff = set(diff_sd.keys())
68
+
69
+ if args.strict_keys:
70
+ if keys_base != keys_diff:
71
+ only_base = sorted(list(keys_base - keys_diff))[:30]
72
+ only_diff = sorted(list(keys_diff - keys_base))[:30]
73
+ raise RuntimeError(
74
+ "Keys mismatch between base safetensors and diff.\n"
75
+ f"Only in base (up to 30): {only_base}\n"
76
+ f"Only in diff (up to 30): {only_diff}\n"
77
+ "Use --no-strict-keys to apply on intersection only."
78
+ )
79
+ keys_apply = keys_base
80
+ else:
81
+ keys_apply = keys_base & keys_diff
82
+ logging.warning("Non-strict mode: applying on intersection keys: %d", len(keys_apply))
83
+
84
+ dev = torch.device(args.device)
85
+
86
+ merged_sd: dict[str, torch.Tensor] = {}
87
+ applied_float = 0
88
+ skipped_nonfloat = 0
89
+ skipped_missing = 0
90
+
91
+ for k, base_t_cpu in base_sd.items():
92
+ base_t = base_t_cpu # already on cpu
93
+
94
+ if k not in keys_apply:
95
+ merged_sd[k] = base_t
96
+ skipped_missing += 1
97
+ continue
98
+
99
+ diff_t_cpu = diff_sd[k]
100
+
101
+ if base_t.shape != diff_t_cpu.shape:
102
+ raise RuntimeError(f"Shape mismatch at key={k}: base {base_t.shape} vs diff {diff_t_cpu.shape}")
103
+
104
+ # only add for floating-point tensors
105
+ if base_t.is_floating_point() and diff_t_cpu.is_floating_point():
106
+ a = cast(base_t.to(dev), args.dtype)
107
+ d = cast(diff_t_cpu.to(dev), args.dtype)
108
+ out = a + args.scale * d
109
+ merged_sd[k] = out.to(base_t.dtype).detach().cpu()
110
+ applied_float += 1
111
+ else:
112
+ merged_sd[k] = base_t
113
+ skipped_nonfloat += 1
114
+
115
+ out_path = Path(args.out_safetensors)
116
+ out_path.parent.mkdir(parents=True, exist_ok=True)
117
+
118
+ # safetensors requires all tensors to be on CPU
119
+ for k, v in merged_sd.items():
120
+ if v.device.type != "cpu":
121
+ merged_sd[k] = v.cpu()
122
+
123
+ logging.info(
124
+ "Done. applied_float=%d, skipped_nonfloat=%d, skipped_missing=%d",
125
+ applied_float,
126
+ skipped_nonfloat,
127
+ skipped_missing,
128
+ )
129
+ logging.info("Saving merged safetensors to: %s", str(out_path))
130
+ save_file(merged_sd, str(out_path))
131
+
132
+
133
+ if __name__ == "__main__":
134
+ logging.basicConfig(level=logging.INFO, force=True)
135
+ main(tyro.cli(Args))
capvector-pi05/capvector/compute_param_diff.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
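+ """Compute a parameter diff (capability vector) between two finetuned checkpoints: diff = A - B."""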
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import torch
7
+ import tyro
8
+
9
+ from openpi.training import config as _config
10
+
11
+
12
+ @dataclasses.dataclass
13
+ class CkptSpec:
14
+ dir: str
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class Args:
19
+ config: str
20
+ a: CkptSpec
21
+ b: CkptSpec
22
+ out: str = "checkpoints/diff/a_minus_b.pth"
23
+ only_vlm: bool = False
24
+ strict_keys: bool = False
25
+ dtype: str = "fp32"
26
+ device: str = "cpu"
27
+
28
+
29
+ def _extract_state_dict(obj: Any) -> dict[str, torch.Tensor]:
30
+ """
31
+ Try best to get a torch state_dict from a Policy or Module-like object.
32
+ """
33
+ # Case 1: policy itself has state_dict()
34
+ if hasattr(obj, "state_dict") and callable(obj.state_dict):
35
+ sd = obj.state_dict()
36
+ if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
37
+ return sd
38
+
39
+ # Case 2: common attributes that hold torch.nn.Module
40
+ for attr in ["model", "_model", "module", "net", "_net", "policy", "_policy"]:
41
+ if hasattr(obj, attr):
42
+ m = getattr(obj, attr)
43
+ if hasattr(m, "state_dict") and callable(m.state_dict):
44
+ sd = m.state_dict()
45
+ if isinstance(sd, dict) and all(isinstance(v, torch.Tensor) for v in sd.values()):
46
+ return sd
47
+
48
+ raise RuntimeError(
49
+ "Cannot extract state_dict. "
50
+ "Please inspect Policy object and update attribute list in _extract_state_dict()."
51
+ )
52
+
53
+
54
+ def _cast_tensor(t: torch.Tensor, dtype: str) -> torch.Tensor:
55
+ if dtype == "fp32":
56
+ return t.float()
57
+ if dtype == "fp16":
58
+ return t.half()
59
+ if dtype == "bf16":
60
+ return t.bfloat16()
61
+ raise ValueError(f"Unknown dtype: {dtype}")
62
+
63
+
64
+ def load_model(config_name: str, spec: CkptSpec):
65
+ cfg = _config.get_config(config_name)
66
+ weight_path = Path(spec.dir) / "model.safetensors"
67
+ if not weight_path.exists():
68
+ raise FileNotFoundError(f"Missing model.safetensors in checkpoint directory: {spec.dir}")
69
+ return cfg.model.load_pytorch(cfg, str(weight_path))
70
+
71
+
72
+ def main(args: Args) -> None:
73
+ logging.info("Loading A model from %s with config %s", args.a.dir, args.config)
74
+ model_a = load_model(args.config, args.a)
75
+ logging.info("Loading B model from %s with config %s", args.b.dir, args.config)
76
+ model_b = load_model(args.config, args.b)
77
+
78
+ sd_a = _extract_state_dict(model_a)
79
+ sd_b = _extract_state_dict(model_b)
80
+
81
+ keys_a = set(sd_a.keys())
82
+ keys_b = set(sd_b.keys())
83
+
84
+ if args.strict_keys:
85
+ if keys_a != keys_b:
86
+ only_a = sorted(list(keys_a - keys_b))[:20]
87
+ only_b = sorted(list(keys_b - keys_a))[:20]
88
+ raise RuntimeError(
89
+ f"State dict keys mismatch.\n"
90
+ f"Only in A (show up to 20): {only_a}\n"
91
+ f"Only in B (show up to 20): {only_b}\n"
92
+ f"Set --strict-keys False to subtract intersection only."
93
+ )
94
+ keys = sorted(keys_a)
95
+ else:
96
+ keys = sorted(list(keys_a & keys_b))
97
+ logging.warning("Non-strict mode: subtracting only intersection keys: %d", len(keys))
98
+
99
+ device = torch.device(args.device)
100
+ diff: dict[str, torch.Tensor] = {}
101
+
102
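+ # When --only-vlm is set, zero out the action-expert and action-projection parameters so the
+ # saved diff only carries changes to the VLM backbone.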
+ if args.only_vlm:
103
+ ZERO_PREFIXES = [
104
+ "paligemma_with_expert.gemma_expert.",
105
+ "action_in_proj.",
106
+ "action_out_proj.",
107
+ "action_time_mlp_in",
108
+ "action_time_mlp_oout",
109
+ ]
110
+ else:
111
+ ZERO_PREFIXES = []
112
+
113
+ for k in keys:
114
+ ta = sd_a[k].to(device)
115
+ tb = sd_b[k].to(device)
116
+
117
+ if ta.shape != tb.shape:
118
+ raise RuntimeError(f"Shape mismatch at key={k}: {ta.shape} vs {tb.shape}")
119
+
120
+ zero_this = any(k.startswith(p) for p in ZERO_PREFIXES)
121
+
122
+ if zero_this:
123
+ out = torch.zeros_like(ta)
124
+ else:
125
+ if ta.is_floating_point():
126
+ out = _cast_tensor(ta, args.dtype) - _cast_tensor(tb, args.dtype)
127
+ else:
128
+ out = ta
129
+
130
+ diff[k] = out.detach().cpu()
131
+
132
+
133
+
134
+ out_path = Path(args.out)
135
+ out_path.parent.mkdir(parents=True, exist_ok=True)
136
+ torch.save({"state_dict": diff, "a": dataclasses.asdict(args.a), "b": dataclasses.asdict(args.b)}, out_path)
137
+ logging.info("Saved diff checkpoint to: %s", str(out_path))
138
+
139
+
140
+ if __name__ == "__main__":
141
+ logging.basicConfig(level=logging.INFO, force=True)
142
+ main(tyro.cli(Args))
capvector-pi05/docs/docker.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Docker Setup
2
+
3
+ All of the examples in this repo include instructions for running both natively and with Docker. Although not required, Docker is recommended: it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
4
+
5
+ - Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
6
+ - Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
7
+ - To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
8
+ - The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
9
+ - Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
10
+
11
+
12
+ If you are starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
13
+
14
+ Build the Docker image and start the container with the following command:
15
+ ```bash
16
+ docker compose -f scripts/docker/compose.yml up --build
17
+ ```
18
+
19
+ To build and run the Docker image for a specific example, use the following command:
20
+ ```bash
21
+ docker compose -f examples/<example_name>/compose.yml up --build
22
+ ```
23
+ where `<example_name>` is the name of the example you want to run.
24
+
25
+ During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
capvector-pi05/docs/norm_stats.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Normalization statistics
2
+
3
+ Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
4
+
5
+ ## Reloading normalization statistics
6
+
7
+ When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
8
+
9
+ **If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
10
+
11
+ ```python
12
+ TrainConfig(
13
+ ...
14
+ data=LeRobotAlohaDataConfig(
15
+ ...
16
+ assets=AssetsConfig(
17
+ assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
18
+ asset_id="trossen",
19
+ ),
20
+ ),
21
+ )
22
+ ```
23
+
24
+ For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
25
+
26
+ **Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
27
+
28
+ **Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend always trying both options, reloading the pre-training statistics and computing a fresh set on your new dataset (see the [main README](../README.md) for instructions on how to compute new statistics), and picking the one that works better for your task.
29
+
30
+
31
+ ## Provided Pre-training Normalization Statistics
32
+
33
+ Below is a list of all the pre-training normalization statistics we provide. We provide them for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
34
+ | Robot | Description | Asset ID |
35
+ |-------|-------------|----------|
36
+ | ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
37
+ | Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
38
+ | Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
39
+ | Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
40
+ | UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
41
+ | UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
42
+ | ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
43
+ | ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
44
+ | Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
45
+
46
+
47
+ ## Pi0 Model Action Space Definitions
48
+
49
+ Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
50
+ ```
51
+ "dim_0:dim_5": "left arm joint angles",
52
+ "dim_6": "left arm gripper position",
53
+ "dim_7:dim_12": "right arm joint angles (for bi-manual only)",
54
+ "dim_13": "right arm gripper position (for bi-manual only)",
55
+
56
+ # For mobile robots:
57
+ "dim_14:dim_15": "x-y base velocity (for mobile robots only)",
58
+ ```
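+
+ As a concrete example, a single bi-manual (ALOHA-style) action step is a 14-dimensional vector, or 16-dimensional for mobile bases. The sketch below is illustrative only; it simply assembles the layout above with NumPy, and the command values are placeholders.
+ ```python
+ import numpy as np
+
+ left_joints = np.zeros(6)    # dim_0:dim_5, radians
+ left_gripper = 0.0           # dim_6, 0.0 = fully open, 1.0 = fully closed
+ right_joints = np.zeros(6)   # dim_7:dim_12, radians (bi-manual only)
+ right_gripper = 0.0          # dim_13 (bi-manual only)
+
+ action = np.concatenate([left_joints, [left_gripper], right_joints, [right_gripper]])
+ assert action.shape == (14,)
+
+ # Mobile robots append an x-y base velocity (dim_14:dim_15) to the action;
+ # these two dimensions are not part of the proprioceptive state.
+ mobile_action = np.concatenate([action, np.zeros(2)])
+ assert mobile_action.shape == (16,)
+ ```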
59
+
60
+ The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
61
+
62
+ For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
63
+
64
+ General info for Pi robots:
65
+ - Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
66
+ - Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
67
+ - Control frequencies are 20 Hz for the UR5e and Franka arms, and 50 Hz for the ARX and Trossen (ALOHA) arms.
68
+
69
+ For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension, at a control frequency of 15 Hz.
capvector-pi05/docs/remote_inference.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Running openpi models remotely
3
+
4
+ We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
5
+
6
+ ## Starting a remote policy server
7
+
8
+ To start a remote policy server, you can simply run the following command:
9
+
10
+ ```bash
11
+ uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
12
+ ```
13
+
14
+ The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
15
+
16
+ ```bash
17
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
18
+ ```
19
+
20
+ This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
21
+
22
+ ## Querying the remote policy server from your robot code
23
+
24
+ We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
25
+
26
+ First, install the `openpi-client` package in your robot environment:
27
+
28
+ ```bash
29
+ cd $OPENPI_ROOT/packages/openpi-client
30
+ pip install -e .
31
+ ```
32
+
33
+ Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
34
+
35
+ ```python
36
+ from openpi_client import image_tools
37
+ from openpi_client import websocket_client_policy
38
+
39
+ # Outside of episode loop, initialize the policy client.
40
+ # Point to the host and port of the policy server (localhost and 8000 are the defaults).
41
+ client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
42
+
43
+ for step in range(num_steps):
44
+ # Inside the episode loop, construct the observation.
45
+ # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
46
+ # We provide utilities for resizing images + uint8 conversion so you match the training routines.
47
+ # The typical resize_size for pre-trained pi0 models is 224.
48
+ # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
49
+ observation = {
50
+ "observation/image": image_tools.convert_to_uint8(
51
+ image_tools.resize_with_pad(img, 224, 224)
52
+ ),
53
+ "observation/wrist_image": image_tools.convert_to_uint8(
54
+ image_tools.resize_with_pad(wrist_img, 224, 224)
55
+ ),
56
+ "observation/state": state,
57
+ "prompt": task_instruction,
58
+ }
59
+
60
+ # Call the policy server with the current observation.
61
+ # This returns an action chunk of shape (action_horizon, action_dim).
62
+ # Note that you typically only need to call the policy every N steps and execute steps
63
+ # from the predicted action chunk open-loop in the remaining steps.
64
+ action_chunk = client.infer(observation)["actions"]
65
+
66
+ # Execute the actions in the environment.
67
+ ...
68
+
69
+ ```
70
+
71
+ Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](../examples/simple_client/main.py).
capvector-pi05/examples/aloha_real/Dockerfile ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for the Aloha real environment.
2
+
3
+ # Build the container:
4
+ # docker build . -t aloha_real -f examples/aloha_real/Dockerfile
5
+
6
+ # Run the container:
7
+ # docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
8
+
9
+ FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
10
+ SHELL ["/bin/bash", "-c"]
11
+
12
+ ENV DEBIAN_FRONTEND=noninteractive
13
+ RUN apt-get update && \
14
+ apt-get install -y --no-install-recommends \
15
+ cmake \
16
+ curl \
17
+ libffi-dev \
18
+ python3-rosdep \
19
+ python3-rosinstall \
20
+ python3-rosinstall-generator \
21
+ whiptail \
22
+ git \
23
+ wget \
24
+ openssh-client \
25
+ ros-noetic-cv-bridge \
26
+ ros-noetic-usb-cam \
27
+ ros-noetic-realsense2-camera \
28
+ keyboard-configuration
29
+
30
+ WORKDIR /root
31
+ RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
32
+ RUN chmod +x xsarm_amd64_install.sh
33
+ RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
34
+
35
+ COPY ./third_party/aloha /root/interbotix_ws/src/aloha
36
+ RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
37
+
38
+ # Install python 3.10 because this ROS image comes with 3.8
39
+ RUN mkdir /python && \
40
+ cd /python && \
41
+ wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
42
+ tar -zxvf Python-3.10.14.tgz && \
43
+ cd Python-3.10.14 && \
44
+ ls -lhR && \
45
+ ./configure --enable-optimizations && \
46
+ make install && \
47
+ echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
48
+ echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
49
+ cd ~ && rm -rf /python && \
50
+ rm -rf /var/lib/apt/lists/*
51
+
52
+ COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
53
+ ENV UV_HTTP_TIMEOUT=120
54
+ ENV UV_LINK_MODE=copy
55
+ COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
56
+ COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
57
+ RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
58
+
59
+ ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
60
+ WORKDIR /app
61
+
62
+ # Create an entrypoint script to run the setup commands, followed by the command passed in.
63
+ RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
64
+ #!/bin/bash
65
+ source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
66
+ EOF
67
+ RUN chmod +x /usr/local/bin/entrypoint.sh
68
+
69
+ ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
70
+ CMD ["python3", "/app/examples/aloha_real/main.py"]
capvector-pi05/examples/aloha_real/README.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run Aloha (Real Robot)
2
+
3
+ This example demonstrates how to run a policy on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
4
+
5
+ ## Prerequisites
6
+
7
+ This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
8
+
9
+ 1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
10
+ 1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
11
+
12
+ ## With Docker
13
+
14
+ ```bash
15
+ export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
16
+ docker compose -f examples/aloha_real/compose.yml up --build
17
+ ```
18
+
19
+ ## Without Docker
20
+
21
+ Terminal window 1:
22
+
23
+ ```bash
24
+ # Create virtual environment
25
+ uv venv --python 3.10 examples/aloha_real/.venv
26
+ source examples/aloha_real/.venv/bin/activate
27
+ uv pip sync examples/aloha_real/requirements.txt
28
+ uv pip install -e packages/openpi-client
29
+
30
+ # Run the robot
31
+ python -m examples.aloha_real.main
32
+ ```
33
+
34
+ Terminal window 2:
35
+
36
+ ```bash
37
+ roslaunch aloha ros_nodes.launch
38
+ ```
39
+
40
+ Terminal window 3:
41
+
42
+ ```bash
43
+ uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
44
+ ```
45
+
46
+ ## **ALOHA Checkpoint Guide**
47
+
48
+
49
+ The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
50
+
51
+ While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot.
52
+
53
+
54
+ ---
55
+
56
+ ### **Toast Task**
57
+
58
+ This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
59
+
60
+ - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
61
+ - **Prompt**: "take the toast out of the toaster"
62
+ - **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
63
+ - **Object Distribution**:
64
+ - Works on both real toast and rubber fake toast
65
+ - Compatible with standard 2-slice toasters
66
+ - Works with plates of varying colors
67
+
68
+ ### **Scene Setup Guidelines**
69
+ <img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
70
+
71
+ - The toaster should be positioned in the top-left quadrant of the workspace.
72
+ - Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
73
+ - The plate should be placed roughly in the lower-center of the workspace.
74
+ - Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
75
+
76
+
77
+ ### **Towel Task**
78
+
79
+ This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
80
+
81
+ - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
82
+ - **Prompt**: "fold the towel"
83
+ - **Object Distribution**:
84
+ - Works on towels of varying solid colors
85
+ - Performance is worse on heavily textured or striped towels
86
+
87
+ ### **Scene Setup Guidelines**
88
+ <img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
89
+
90
+ - The towel should be flattened and roughly centered on the table.
91
+ - Choose a towel that does not blend in with the table surface.
92
+
93
+
94
+ ### **Tupperware Task**
95
+
96
+ This task involves opening a tupperware filled with food and pouring the contents onto a plate.
97
+
98
+ - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
99
+ - **Prompt**: "open the tupperware and put the food on the plate"
100
+ - **Objects needed**: Tupperware, food (or food-like items), and a plate.
101
+ - **Object Distribution**:
102
+ - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
103
+ - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
104
+ - The policy has seen plates of varying solid colors.
105
+
106
+ ### **Scene Setup Guidelines**
107
+ <img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
108
+
109
+ - Best performance observed when both the tupperware and plate are roughly centered in the workspace.
110
+ - Positioning:
111
+ - Tupperware should be on the left.
112
+ - Plate should be on the right or bottom.
113
+ - The tupperware flap should point toward the plate.
114
+
115
+ ## Training on your own Aloha dataset
116
+
117
+ 1. Convert the dataset to the LeRobot dataset v2.0 format.
118
+
119
+ We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
120
+
121
+
122
+ 2. Define a training config that uses the custom dataset.
123
+
124
+ We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
125
+
126
+ IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
capvector-pi05/examples/aloha_real/compose.yml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run with:
2
+ # docker compose -f examples/aloha_real/compose.yml up --build
3
+ services:
4
+ runtime:
5
+ image: aloha_real
6
+ depends_on:
7
+ - aloha_ros_nodes
8
+ - ros_master
9
+ - openpi_server
10
+ build:
11
+ context: ../..
12
+ dockerfile: examples/aloha_real/Dockerfile
13
+ init: true
14
+ tty: true
15
+ network_mode: host
16
+ privileged: true
17
+ volumes:
18
+ - $PWD:/app
19
+ - ../../data:/data
20
+
21
+ aloha_ros_nodes:
22
+ image: aloha_real
23
+ depends_on:
24
+ - ros_master
25
+ build:
26
+ context: ../..
27
+ dockerfile: examples/aloha_real/Dockerfile
28
+ init: true
29
+ tty: true
30
+ network_mode: host
31
+ privileged: true
32
+ volumes:
33
+ - /dev:/dev
34
+ command: roslaunch --wait aloha ros_nodes.launch
35
+
36
+ ros_master:
37
+ image: ros:noetic-robot
38
+ network_mode: host
39
+ privileged: true
40
+ command:
41
+ - roscore
42
+
43
+ openpi_server:
44
+ image: openpi_server
45
+ build:
46
+ context: ../..
47
+ dockerfile: scripts/docker/serve_policy.Dockerfile
48
+ init: true
49
+ tty: true
50
+ network_mode: host
51
+ volumes:
52
+ - $PWD:/app
53
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
54
+ environment:
55
+ - SERVER_ARGS
56
+ - OPENPI_DATA_HOME=/openpi_assets
57
+ - IS_DOCKER=true
58
+
59
+ # Comment out this block if not running on a machine with GPUs.
60
+ deploy:
61
+ resources:
62
+ reservations:
63
+ devices:
64
+ - driver: nvidia
65
+ count: 1
66
+ capabilities: [gpu]
capvector-pi05/examples/aloha_real/constants.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
2
+ # ruff: noqa
3
+
4
+ ### Task parameters
5
+
6
+ ### ALOHA fixed constants
7
+ DT = 0.001
8
+ JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
9
+ START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
10
+
11
+ # Left finger position limits (qpos[7]), right_finger = -1 * left_finger
12
+ MASTER_GRIPPER_POSITION_OPEN = 0.02417
13
+ MASTER_GRIPPER_POSITION_CLOSE = 0.01244
14
+ PUPPET_GRIPPER_POSITION_OPEN = 0.05800
15
+ PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
16
+
17
+ # Gripper joint limits (qpos[6])
18
+ MASTER_GRIPPER_JOINT_OPEN = 0.3083
19
+ MASTER_GRIPPER_JOINT_CLOSE = -0.6842
20
+ PUPPET_GRIPPER_JOINT_OPEN = 1.4910
21
+ PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
22
+
23
+ ############################ Helper functions ############################
24
+
25
+ MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
26
+ MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
27
+ )
28
+ PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
29
+ PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
30
+ )
31
+ MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
32
+ lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
33
+ )
34
+ PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
35
+ lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
36
+ )
37
+ MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
38
+
39
+ MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
40
+ MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
41
+ )
42
+ PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
43
+ PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
44
+ )
45
+ MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
46
+ lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
47
+ )
48
+ PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
49
+ lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
50
+ )
51
+ MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
52
+
53
+ MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
54
+ PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
55
+
56
+ MASTER_POS2JOINT = (
57
+ lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
58
+ + MASTER_GRIPPER_JOINT_CLOSE
59
+ )
60
+ MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
61
+ (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
62
+ )
63
+ PUPPET_POS2JOINT = (
64
+ lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
65
+ + PUPPET_GRIPPER_JOINT_CLOSE
66
+ )
67
+ PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
68
+ (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
69
+ )
70
+
71
+ MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
capvector-pi05/examples/aloha_real/convert_aloha_data_to_lerobot.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
3
+
4
+ Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
5
+ """
6
+
7
+ import dataclasses
8
+ from pathlib import Path
9
+ import shutil
10
+ from typing import Literal
11
+
12
+ import h5py
13
+ from lerobot.common.constants import HF_LEROBOT_HOME
14
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+ import tyro
19
+
20
+
21
+ @dataclasses.dataclass(frozen=True)
22
+ class DatasetConfig:
23
+ use_videos: bool = True
24
+ tolerance_s: float = 0.0001
25
+ image_writer_processes: int = 10
26
+ image_writer_threads: int = 5
27
+ video_backend: str | None = None
28
+
29
+
30
+ DEFAULT_DATASET_CONFIG = DatasetConfig()
31
+
32
+
33
+ def create_empty_dataset(
34
+ repo_id: str,
35
+ robot_type: str,
36
+ cameras: list[str],
37
+ mode: Literal["video", "image"] = "video",
38
+ *,
39
+ has_velocity: bool = False,
40
+ has_effort: bool = False,
41
+ dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
42
+ ) -> LeRobotDataset:
43
+ motors = [
44
+ "right_waist",
45
+ "right_shoulder",
46
+ "right_elbow",
47
+ "right_forearm_roll",
48
+ "right_wrist_angle",
49
+ "right_wrist_rotate",
50
+ "right_gripper",
51
+ "left_waist",
52
+ "left_shoulder",
53
+ "left_elbow",
54
+ "left_forearm_roll",
55
+ "left_wrist_angle",
56
+ "left_wrist_rotate",
57
+ "left_gripper",
58
+ ]
59
+
60
+ features = {
61
+ "observation.state": {
62
+ "dtype": "float32",
63
+ "shape": (len(motors),),
64
+ "names": [
65
+ motors,
66
+ ],
67
+ },
68
+ "action": {
69
+ "dtype": "float32",
70
+ "shape": (len(motors),),
71
+ "names": [
72
+ motors,
73
+ ],
74
+ },
75
+ }
76
+
77
+ if has_velocity:
78
+ features["observation.velocity"] = {
79
+ "dtype": "float32",
80
+ "shape": (len(motors),),
81
+ "names": [
82
+ motors,
83
+ ],
84
+ }
85
+
86
+ if has_effort:
87
+ features["observation.effort"] = {
88
+ "dtype": "float32",
89
+ "shape": (len(motors),),
90
+ "names": [
91
+ motors,
92
+ ],
93
+ }
94
+
95
+ for cam in cameras:
96
+ features[f"observation.images.{cam}"] = {
97
+ "dtype": mode,
98
+ "shape": (3, 480, 640),
99
+ "names": [
100
+ "channels",
101
+ "height",
102
+ "width",
103
+ ],
104
+ }
105
+
106
+ if Path(HF_LEROBOT_HOME / repo_id).exists():
107
+ shutil.rmtree(HF_LEROBOT_HOME / repo_id)
108
+
109
+ return LeRobotDataset.create(
110
+ repo_id=repo_id,
111
+ fps=50,
112
+ robot_type=robot_type,
113
+ features=features,
114
+ use_videos=dataset_config.use_videos,
115
+ tolerance_s=dataset_config.tolerance_s,
116
+ image_writer_processes=dataset_config.image_writer_processes,
117
+ image_writer_threads=dataset_config.image_writer_threads,
118
+ video_backend=dataset_config.video_backend,
119
+ )
120
+
121
+
122
+ def get_cameras(hdf5_files: list[Path]) -> list[str]:
123
+ with h5py.File(hdf5_files[0], "r") as ep:
124
+ # ignore depth channel, not currently handled
125
+ return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
126
+
127
+
128
+ def has_velocity(hdf5_files: list[Path]) -> bool:
129
+ with h5py.File(hdf5_files[0], "r") as ep:
130
+ return "/observations/qvel" in ep
131
+
132
+
133
+ def has_effort(hdf5_files: list[Path]) -> bool:
134
+ with h5py.File(hdf5_files[0], "r") as ep:
135
+ return "/observations/effort" in ep
136
+
137
+
138
+ def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
139
+ imgs_per_cam = {}
140
+ for camera in cameras:
141
+ uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
142
+
143
+ if uncompressed:
144
+ # load all images in RAM
145
+ imgs_array = ep[f"/observations/images/{camera}"][:]
146
+ else:
147
+ import cv2
148
+
149
+ # load one compressed image after the other in RAM and uncompress
150
+ imgs_array = []
151
+ for data in ep[f"/observations/images/{camera}"]:
152
+ imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
153
+ imgs_array = np.array(imgs_array)
154
+
155
+ imgs_per_cam[camera] = imgs_array
156
+ return imgs_per_cam
157
+
158
+
159
+ def load_raw_episode_data(
160
+ ep_path: Path,
161
+ cameras: list[str],
162
+ ) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
163
+ with h5py.File(ep_path, "r") as ep:
164
+ state = torch.from_numpy(ep["/observations/qpos"][:])
165
+ action = torch.from_numpy(ep["/action"][:])
166
+
167
+ velocity = None
168
+ if "/observations/qvel" in ep:
169
+ velocity = torch.from_numpy(ep["/observations/qvel"][:])
170
+
171
+ effort = None
172
+ if "/observations/effort" in ep:
173
+ effort = torch.from_numpy(ep["/observations/effort"][:])
174
+
175
+ imgs_per_cam = load_raw_images_per_camera(ep, cameras)
176
+
177
+ return imgs_per_cam, state, action, velocity, effort
178
+
179
+
180
+ def populate_dataset(
181
+ dataset: LeRobotDataset,
182
+ hdf5_files: list[Path],
183
+ cameras: list[str],
184
+ task: str,
185
+ episodes: list[int] | None = None,
186
+ ) -> LeRobotDataset:
187
+ if episodes is None:
188
+ episodes = range(len(hdf5_files))
189
+
190
+ for ep_idx in tqdm.tqdm(episodes):
191
+ ep_path = hdf5_files[ep_idx]
192
+
193
+ imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path, cameras)
194
+ num_frames = state.shape[0]
195
+
196
+ for i in range(num_frames):
197
+ frame = {
198
+ "observation.state": state[i],
199
+ "action": action[i],
200
+ "task": task,
201
+ }
202
+
203
+ for camera, img_array in imgs_per_cam.items():
204
+ frame[f"observation.images.{camera}"] = img_array[i]
205
+
206
+ if velocity is not None:
207
+ frame["observation.velocity"] = velocity[i]
208
+ if effort is not None:
209
+ frame["observation.effort"] = effort[i]
210
+
211
+ dataset.add_frame(frame)
212
+
213
+ dataset.save_episode()
214
+
215
+ return dataset
216
+
217
+
218
+ def port_aloha(
219
+ raw_dir: Path,
220
+ repo_id: str,
221
+ task: str = "DEBUG",
222
+ *,
223
+ episodes: list[int] | None = None,
224
+ push_to_hub: bool = False,
225
+ is_mobile: bool = False,
226
+ mode: Literal["video", "image"] = "image",
227
+ dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
228
+ ):
229
+ if (HF_LEROBOT_HOME / repo_id).exists():
230
+ shutil.rmtree(HF_LEROBOT_HOME / repo_id)
231
+
232
+ if not raw_dir.exists():
233
+ raise ValueError(f"Raw directory {raw_dir} does not exist. Please provide a valid path to the raw data.")
234
+
235
+ hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
236
+
237
+ # Get camera names from the first episode
238
+ cameras = get_cameras(hdf5_files)
239
+ print(f"Detected cameras: {cameras}")
240
+
241
+ dataset = create_empty_dataset(
242
+ repo_id,
243
+ robot_type="mobile_aloha" if is_mobile else "aloha",
244
+ cameras=cameras,
245
+ mode=mode,
246
+ has_effort=has_effort(hdf5_files),
247
+ has_velocity=has_velocity(hdf5_files),
248
+ dataset_config=dataset_config,
249
+ )
250
+ dataset = populate_dataset(
251
+ dataset,
252
+ hdf5_files,
253
+ cameras=cameras,
254
+ task=task,
255
+ episodes=episodes,
256
+ )
257
+
258
+ if push_to_hub:
259
+ dataset.push_to_hub()
260
+
261
+
262
+ if __name__ == "__main__":
263
+ tyro.cli(port_aloha)
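For reference, a minimal sketch of driving the conversion from Python rather than through the tyro CLI; the module path, raw-data path, repo id, and task string below are placeholders, not values from this repository:

```python
from pathlib import Path

# Assumed import path for this script; adjust to wherever it lives in your checkout.
from examples.aloha_real.convert_aloha_data_to_lerobot import port_aloha

port_aloha(
    raw_dir=Path("/data/aloha_raw/my_task"),  # directory containing episode_*.hdf5 files
    repo_id="my-org/aloha_my_task",           # LeRobot repo id; an existing local copy is deleted first
    task="pick up the cube",                  # language instruction stored with every frame
    episodes=[0, 1, 2],                       # convert only these episodes; omit to convert all
    push_to_hub=False,
)
```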
capvector-pi05/examples/aloha_real/env.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import List, Optional # noqa: UP035
2
+
3
+ import einops
4
+ from openpi_client import image_tools
5
+ from openpi_client.runtime import environment as _environment
6
+ from typing_extensions import override
7
+
8
+ from examples.aloha_real import real_env as _real_env
9
+
10
+
11
+ class AlohaRealEnvironment(_environment.Environment):
12
+ """An environment for an Aloha robot on real hardware."""
13
+
14
+ def __init__(
15
+ self,
16
+ reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
17
+ render_height: int = 224,
18
+ render_width: int = 224,
19
+ ) -> None:
20
+ self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
21
+ self._render_height = render_height
22
+ self._render_width = render_width
23
+
24
+ self._ts = None
25
+
26
+ @override
27
+ def reset(self) -> None:
28
+ self._ts = self._env.reset()
29
+
30
+ @override
31
+ def is_episode_complete(self) -> bool:
32
+ return False
33
+
34
+ @override
35
+ def get_observation(self) -> dict:
36
+ if self._ts is None:
37
+ raise RuntimeError("Timestep is not set. Call reset() first.")
38
+
39
+ obs = self._ts.observation
40
+ for k in list(obs["images"].keys()):
41
+ if "_depth" in k:
42
+ del obs["images"][k]
43
+
44
+ for cam_name in obs["images"]:
45
+ img = image_tools.convert_to_uint8(
46
+ image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
47
+ )
48
+ obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
49
+
50
+ return {
51
+ "state": obs["qpos"],
52
+ "images": obs["images"],
53
+ }
54
+
55
+ @override
56
+ def apply_action(self, action: dict) -> None:
57
+ self._ts = self._env.step(action["actions"])
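As a quick reference, a small sanity-check sketch of the observation dict produced by `get_observation` above; the 224×224 shape assumes the default render size:

```python
import numpy as np

def check_observation(obs: dict) -> None:
    """Sketch: validate the dict returned by AlohaRealEnvironment.get_observation()."""
    assert obs["state"].ndim == 1                  # flat qpos vector
    for name, img in obs["images"].items():
        assert img.dtype == np.uint8, name
        assert img.shape == (3, 224, 224), name    # channel-first after resize_with_pad
```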
capvector-pi05/examples/aloha_real/main.py ADDED
@@ -0,0 +1,51 @@
1
+ import dataclasses
2
+ import logging
3
+
4
+ from openpi_client import action_chunk_broker
5
+ from openpi_client import websocket_client_policy as _websocket_client_policy
6
+ from openpi_client.runtime import runtime as _runtime
7
+ from openpi_client.runtime.agents import policy_agent as _policy_agent
8
+ import tyro
9
+
10
+ from examples.aloha_real import env as _env
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class Args:
15
+ host: str = "0.0.0.0"
16
+ port: int = 8000
17
+
18
+ action_horizon: int = 25
19
+
20
+ num_episodes: int = 1
21
+ max_episode_steps: int = 1000
22
+
23
+
24
+ def main(args: Args) -> None:
25
+ ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
26
+ host=args.host,
27
+ port=args.port,
28
+ )
29
+ logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
30
+
31
+ metadata = ws_client_policy.get_server_metadata()
32
+ runtime = _runtime.Runtime(
33
+ environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
34
+ agent=_policy_agent.PolicyAgent(
35
+ policy=action_chunk_broker.ActionChunkBroker(
36
+ policy=ws_client_policy,
37
+ action_horizon=args.action_horizon,
38
+ )
39
+ ),
40
+ subscribers=[],
41
+ max_hz=50,
42
+ num_episodes=args.num_episodes,
43
+ max_episode_steps=args.max_episode_steps,
44
+ )
45
+
46
+ runtime.run()
47
+
48
+
49
+ if __name__ == "__main__":
50
+ logging.basicConfig(level=logging.INFO, force=True)
51
+ tyro.cli(main)
capvector-pi05/examples/aloha_real/real_env.py ADDED
@@ -0,0 +1,176 @@
1
+ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
2
+ # ruff: noqa
3
+ import collections
4
+ import time
5
+ from typing import Optional, List
6
+ import dm_env
7
+ from interbotix_xs_modules.arm import InterbotixManipulatorXS
8
+ from interbotix_xs_msgs.msg import JointSingleCommand
9
+ import numpy as np
10
+
11
+ from examples.aloha_real import constants
12
+ from examples.aloha_real import robot_utils
13
+
14
+ # This is the reset position that is used by the standard Aloha runtime.
15
+ DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
16
+
17
+
18
+ class RealEnv:
19
+ """
20
+ Environment for real robot bi-manual manipulation
21
+ Action space: [left_arm_qpos (6), # absolute joint position
22
+ left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
23
+ right_arm_qpos (6), # absolute joint position
24
+ right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
25
+
26
+ Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
27
+ left_gripper_position (1), # normalized gripper position (0: close, 1: open)
28
+ right_arm_qpos (6), # absolute joint position
29
+ right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
30
+ "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
31
+ left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
32
+ right_arm_qvel (6), # absolute joint velocity (rad)
33
+ right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
34
+ "images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
35
+ "cam_low": (480x640x3), # h, w, c, dtype='uint8'
36
+ "cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
37
+ "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
38
+ """
39
+
40
+ def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
41
+ # reset_position = START_ARM_POSE[:6]
42
+ self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
43
+
44
+ self.puppet_bot_left = InterbotixManipulatorXS(
45
+ robot_model="vx300s",
46
+ group_name="arm",
47
+ gripper_name="gripper",
48
+ robot_name="puppet_left",
49
+ init_node=init_node,
50
+ )
51
+ self.puppet_bot_right = InterbotixManipulatorXS(
52
+ robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
53
+ )
54
+ if setup_robots:
55
+ self.setup_robots()
56
+
57
+ self.recorder_left = robot_utils.Recorder("left", init_node=False)
58
+ self.recorder_right = robot_utils.Recorder("right", init_node=False)
59
+ self.image_recorder = robot_utils.ImageRecorder(init_node=False)
60
+ self.gripper_command = JointSingleCommand(name="gripper")
61
+
62
+ def setup_robots(self):
63
+ robot_utils.setup_puppet_bot(self.puppet_bot_left)
64
+ robot_utils.setup_puppet_bot(self.puppet_bot_right)
65
+
66
+ def get_qpos(self):
67
+ left_qpos_raw = self.recorder_left.qpos
68
+ right_qpos_raw = self.recorder_right.qpos
69
+ left_arm_qpos = left_qpos_raw[:6]
70
+ right_arm_qpos = right_qpos_raw[:6]
71
+ left_gripper_qpos = [
72
+ constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
73
+ ] # this is position not joint
74
+ right_gripper_qpos = [
75
+ constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
76
+ ] # this is position not joint
77
+ return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
78
+
79
+ def get_qvel(self):
80
+ left_qvel_raw = self.recorder_left.qvel
81
+ right_qvel_raw = self.recorder_right.qvel
82
+ left_arm_qvel = left_qvel_raw[:6]
83
+ right_arm_qvel = right_qvel_raw[:6]
84
+ left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
85
+ right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
86
+ return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
87
+
88
+ def get_effort(self):
89
+ left_effort_raw = self.recorder_left.effort
90
+ right_effort_raw = self.recorder_right.effort
91
+ left_robot_effort = left_effort_raw[:7]
92
+ right_robot_effort = right_effort_raw[:7]
93
+ return np.concatenate([left_robot_effort, right_robot_effort])
94
+
95
+ def get_images(self):
96
+ return self.image_recorder.get_images()
97
+
98
+ def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
99
+ left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
100
+ self.gripper_command.cmd = left_gripper_desired_joint
101
+ self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
102
+
103
+ right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
104
+ right_gripper_desired_pos_normalized
105
+ )
106
+ self.gripper_command.cmd = right_gripper_desired_joint
107
+ self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
108
+
109
+ def _reset_joints(self):
110
+ robot_utils.move_arms(
111
+ [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
112
+ )
113
+
114
+ def _reset_gripper(self):
115
+ """Set to position mode and do position resets: first close then open. Then change back to PWM mode
116
+
117
+ NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
118
+ was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
119
+ increase the frequency of motor faults.
120
+ """
121
+ robot_utils.move_grippers(
122
+ [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
123
+ )
124
+ robot_utils.move_grippers(
125
+ [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
126
+ )
127
+
128
+ def get_observation(self):
129
+ obs = collections.OrderedDict()
130
+ obs["qpos"] = self.get_qpos()
131
+ obs["qvel"] = self.get_qvel()
132
+ obs["effort"] = self.get_effort()
133
+ obs["images"] = self.get_images()
134
+ return obs
135
+
136
+ def get_reward(self):
137
+ return 0
138
+
139
+ def reset(self, *, fake=False):
140
+ if not fake:
141
+ # Reboot puppet robot gripper motors
142
+ self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
143
+ self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
144
+ self._reset_joints()
145
+ self._reset_gripper()
146
+ return dm_env.TimeStep(
147
+ step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
148
+ )
149
+
150
+ def step(self, action):
151
+ state_len = int(len(action) / 2)
152
+ left_action = action[:state_len]
153
+ right_action = action[state_len:]
154
+ self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
155
+ self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
156
+ self.set_gripper_pose(left_action[-1], right_action[-1])
157
+ time.sleep(constants.DT)
158
+ return dm_env.TimeStep(
159
+ step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
160
+ )
161
+
162
+
163
+ def get_action(master_bot_left, master_bot_right):
164
+ action = np.zeros(14) # 6 joint + 1 gripper, for two arms
165
+ # Arm actions
166
+ action[:6] = master_bot_left.dxl.joint_states.position[:6]
167
+ action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
168
+ # Gripper actions
169
+ action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
170
+ action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
171
+
172
+ return action
173
+
174
+
175
+ def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
176
+ return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
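A short sketch of assembling a 14-dimensional action in the layout documented in the `RealEnv` docstring above; the joint values are placeholders:

```python
import numpy as np

# Layout: [left_arm_qpos (6), left_gripper (1), right_arm_qpos (6), right_gripper (1)]
left_arm_qpos = np.zeros(6)    # absolute joint positions (placeholder values)
right_arm_qpos = np.zeros(6)
left_gripper = 1.0             # normalized gripper position: 0 = closed, 1 = open
right_gripper = 0.0

action = np.concatenate([left_arm_qpos, [left_gripper], right_arm_qpos, [right_gripper]])
assert action.shape == (14,)
# RealEnv.step(action) splits this vector in half, commands each arm, then sets both grippers.
```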
capvector-pi05/examples/aloha_real/requirements.in ADDED
@@ -0,0 +1,18 @@
1
+ Pillow
2
+ dm_control
3
+ einops
4
+ h5py
5
+ matplotlib
6
+ modern_robotics
7
+ msgpack
8
+ numpy>=1.22.4,<2.0.0
9
+ opencv-python
10
+ packaging
11
+ pexpect
12
+ pyquaternion
13
+ pyrealsense2
14
+ pyyaml
15
+ requests
16
+ rospkg
17
+ tyro
18
+ websockets
capvector-pi05/examples/aloha_real/requirements.txt ADDED
@@ -0,0 +1,156 @@
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
3
+ absl-py==2.1.0
4
+ # via
5
+ # dm-control
6
+ # dm-env
7
+ # labmaze
8
+ # mujoco
9
+ catkin-pkg==1.0.0
10
+ # via rospkg
11
+ certifi==2024.8.30
12
+ # via requests
13
+ charset-normalizer==3.4.0
14
+ # via requests
15
+ contourpy==1.1.1
16
+ # via matplotlib
17
+ cycler==0.12.1
18
+ # via matplotlib
19
+ distro==1.9.0
20
+ # via rospkg
21
+ dm-control==1.0.23
22
+ # via -r examples/aloha_real/requirements.in
23
+ dm-env==1.6
24
+ # via dm-control
25
+ dm-tree==0.1.8
26
+ # via
27
+ # dm-control
28
+ # dm-env
29
+ docstring-parser==0.16
30
+ # via tyro
31
+ docutils==0.20.1
32
+ # via catkin-pkg
33
+ einops==0.8.0
34
+ # via -r examples/aloha_real/requirements.in
35
+ etils==1.3.0
36
+ # via mujoco
37
+ fonttools==4.55.2
38
+ # via matplotlib
39
+ glfw==2.8.0
40
+ # via
41
+ # dm-control
42
+ # mujoco
43
+ h5py==3.11.0
44
+ # via -r examples/aloha_real/requirements.in
45
+ idna==3.10
46
+ # via requests
47
+ importlib-resources==6.4.5
48
+ # via etils
49
+ kiwisolver==1.4.7
50
+ # via matplotlib
51
+ labmaze==1.0.6
52
+ # via dm-control
53
+ lxml==5.3.0
54
+ # via dm-control
55
+ markdown-it-py==3.0.0
56
+ # via rich
57
+ matplotlib==3.7.5
58
+ # via -r examples/aloha_real/requirements.in
59
+ mdurl==0.1.2
60
+ # via markdown-it-py
61
+ modern-robotics==1.1.1
62
+ # via -r examples/aloha_real/requirements.in
63
+ msgpack==1.1.0
64
+ # via -r examples/aloha_real/requirements.in
65
+ mujoco==3.2.3
66
+ # via dm-control
67
+ numpy==1.24.4
68
+ # via
69
+ # -r examples/aloha_real/requirements.in
70
+ # contourpy
71
+ # dm-control
72
+ # dm-env
73
+ # h5py
74
+ # labmaze
75
+ # matplotlib
76
+ # modern-robotics
77
+ # mujoco
78
+ # opencv-python
79
+ # pyquaternion
80
+ # scipy
81
+ opencv-python==4.10.0.84
82
+ # via -r examples/aloha_real/requirements.in
83
+ packaging==24.2
84
+ # via
85
+ # -r examples/aloha_real/requirements.in
86
+ # matplotlib
87
+ pexpect==4.9.0
88
+ # via -r examples/aloha_real/requirements.in
89
+ pillow==10.4.0
90
+ # via
91
+ # -r examples/aloha_real/requirements.in
92
+ # matplotlib
93
+ protobuf==5.29.1
94
+ # via dm-control
95
+ ptyprocess==0.7.0
96
+ # via pexpect
97
+ pygments==2.18.0
98
+ # via rich
99
+ pyopengl==3.1.7
100
+ # via
101
+ # dm-control
102
+ # mujoco
103
+ pyparsing==3.1.4
104
+ # via
105
+ # catkin-pkg
106
+ # dm-control
107
+ # matplotlib
108
+ pyquaternion==0.9.9
109
+ # via -r examples/aloha_real/requirements.in
110
+ pyrealsense2==2.55.1.6486
111
+ # via -r examples/aloha_real/requirements.in
112
+ python-dateutil==2.9.0.post0
113
+ # via
114
+ # catkin-pkg
115
+ # matplotlib
116
+ pyyaml==6.0.2
117
+ # via
118
+ # -r examples/aloha_real/requirements.in
119
+ # rospkg
120
+ requests==2.32.3
121
+ # via
122
+ # -r examples/aloha_real/requirements.in
123
+ # dm-control
124
+ rich==13.9.4
125
+ # via tyro
126
+ rospkg==1.5.1
127
+ # via -r examples/aloha_real/requirements.in
128
+ scipy==1.10.1
129
+ # via dm-control
130
+ setuptools==75.3.0
131
+ # via
132
+ # catkin-pkg
133
+ # dm-control
134
+ # labmaze
135
+ shtab==1.7.1
136
+ # via tyro
137
+ six==1.17.0
138
+ # via python-dateutil
139
+ tqdm==4.67.1
140
+ # via dm-control
141
+ typeguard==4.4.0
142
+ # via tyro
143
+ typing-extensions==4.12.2
144
+ # via
145
+ # etils
146
+ # rich
147
+ # typeguard
148
+ # tyro
149
+ tyro==0.9.2
150
+ # via -r examples/aloha_real/requirements.in
151
+ urllib3==2.2.3
152
+ # via requests
153
+ websockets==14.1
154
+ # via -r examples/aloha_real/requirements.in
155
+ zipp==3.20.2
156
+ # via etils
capvector-pi05/examples/aloha_real/robot_utils.py ADDED
@@ -0,0 +1,275 @@
1
+ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
2
+ # ruff: noqa
3
+ from collections import deque
4
+ import datetime
5
+ import json
6
+ import time
7
+
8
+ from aloha.msg import RGBGrayscaleImage
9
+ from cv_bridge import CvBridge
10
+ from interbotix_xs_msgs.msg import JointGroupCommand
11
+ from interbotix_xs_msgs.msg import JointSingleCommand
12
+ import numpy as np
13
+ import rospy
14
+ from sensor_msgs.msg import JointState
15
+
16
+ from examples.aloha_real import constants
17
+
18
+
19
+ class ImageRecorder:
20
+ def __init__(self, init_node=True, is_debug=False):
21
+ self.is_debug = is_debug
22
+ self.bridge = CvBridge()
23
+ self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
24
+
25
+ if init_node:
26
+ rospy.init_node("image_recorder", anonymous=True)
27
+ for cam_name in self.camera_names:
28
+ setattr(self, f"{cam_name}_rgb_image", None)
29
+ setattr(self, f"{cam_name}_depth_image", None)
30
+ setattr(self, f"{cam_name}_timestamp", 0.0)
31
+ if cam_name == "cam_high":
32
+ callback_func = self.image_cb_cam_high
33
+ elif cam_name == "cam_low":
34
+ callback_func = self.image_cb_cam_low
35
+ elif cam_name == "cam_left_wrist":
36
+ callback_func = self.image_cb_cam_left_wrist
37
+ elif cam_name == "cam_right_wrist":
38
+ callback_func = self.image_cb_cam_right_wrist
39
+ else:
40
+ raise NotImplementedError
41
+ rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
42
+ if self.is_debug:
43
+ setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
44
+
45
+ self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
46
+ time.sleep(0.5)
47
+
48
+ def image_cb(self, cam_name, data):
49
+ setattr(
50
+ self,
51
+ f"{cam_name}_rgb_image",
52
+ self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
53
+ )
54
+ # setattr(
55
+ # self,
56
+ # f"{cam_name}_depth_image",
57
+ # self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
58
+ # )
59
+ setattr(
60
+ self,
61
+ f"{cam_name}_timestamp",
62
+ data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
63
+ )
64
+ # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
65
+ # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
66
+ # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
67
+ if self.is_debug:
68
+ getattr(self, f"{cam_name}_timestamps").append(
69
+ data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
70
+ )
71
+
72
+ def image_cb_cam_high(self, data):
73
+ cam_name = "cam_high"
74
+ return self.image_cb(cam_name, data)
75
+
76
+ def image_cb_cam_low(self, data):
77
+ cam_name = "cam_low"
78
+ return self.image_cb(cam_name, data)
79
+
80
+ def image_cb_cam_left_wrist(self, data):
81
+ cam_name = "cam_left_wrist"
82
+ return self.image_cb(cam_name, data)
83
+
84
+ def image_cb_cam_right_wrist(self, data):
85
+ cam_name = "cam_right_wrist"
86
+ return self.image_cb(cam_name, data)
87
+
88
+ def get_images(self):
89
+ image_dict = {}
90
+ for cam_name in self.camera_names:
91
+ while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
92
+ time.sleep(0.00001)
93
+ rgb_image = getattr(self, f"{cam_name}_rgb_image")
94
+ depth_image = getattr(self, f"{cam_name}_depth_image")
95
+ self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
96
+ image_dict[cam_name] = rgb_image
97
+ image_dict[f"{cam_name}_depth"] = depth_image
98
+ return image_dict
99
+
100
+ def print_diagnostics(self):
101
+ def dt_helper(l):
102
+ l = np.array(l)
103
+ diff = l[1:] - l[:-1]
104
+ return np.mean(diff)
105
+
106
+ for cam_name in self.camera_names:
107
+ image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
108
+ print(f"{cam_name} {image_freq=:.2f}")
109
+ print()
110
+
111
+
112
+ class Recorder:
113
+ def __init__(self, side, init_node=True, is_debug=False):
114
+ self.secs = None
115
+ self.nsecs = None
116
+ self.qpos = None
117
+ self.effort = None
118
+ self.arm_command = None
119
+ self.gripper_command = None
120
+ self.is_debug = is_debug
121
+
122
+ if init_node:
123
+ rospy.init_node("recorder", anonymous=True)
124
+ rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
125
+ rospy.Subscriber(
126
+ f"/puppet_{side}/commands/joint_group",
127
+ JointGroupCommand,
128
+ self.puppet_arm_commands_cb,
129
+ )
130
+ rospy.Subscriber(
131
+ f"/puppet_{side}/commands/joint_single",
132
+ JointSingleCommand,
133
+ self.puppet_gripper_commands_cb,
134
+ )
135
+ if self.is_debug:
136
+ self.joint_timestamps = deque(maxlen=50)
137
+ self.arm_command_timestamps = deque(maxlen=50)
138
+ self.gripper_command_timestamps = deque(maxlen=50)
139
+ time.sleep(0.1)
140
+
141
+ def puppet_state_cb(self, data):
142
+ self.qpos = data.position
143
+ self.qvel = data.velocity
144
+ self.effort = data.effort
145
+ self.data = data
146
+ if self.is_debug:
147
+ self.joint_timestamps.append(time.time())
148
+
149
+ def puppet_arm_commands_cb(self, data):
150
+ self.arm_command = data.cmd
151
+ if self.is_debug:
152
+ self.arm_command_timestamps.append(time.time())
153
+
154
+ def puppet_gripper_commands_cb(self, data):
155
+ self.gripper_command = data.cmd
156
+ if self.is_debug:
157
+ self.gripper_command_timestamps.append(time.time())
158
+
159
+ def print_diagnostics(self):
160
+ def dt_helper(l):
161
+ l = np.array(l)
162
+ diff = l[1:] - l[:-1]
163
+ return np.mean(diff)
164
+
165
+ joint_freq = 1 / dt_helper(self.joint_timestamps)
166
+ arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
167
+ gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
168
+
169
+ print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
170
+
171
+
172
+ def get_arm_joint_positions(bot):
173
+ return bot.arm.core.joint_states.position[:6]
174
+
175
+
176
+ def get_arm_gripper_positions(bot):
177
+ return bot.gripper.core.joint_states.position[6]
178
+
179
+
180
+ def move_arms(bot_list, target_pose_list, move_time=1):
181
+ num_steps = int(move_time / constants.DT)
182
+ curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
183
+ traj_list = [
184
+ np.linspace(curr_pose, target_pose, num_steps)
185
+ for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
186
+ ]
187
+ for t in range(num_steps):
188
+ for bot_id, bot in enumerate(bot_list):
189
+ bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
190
+ time.sleep(constants.DT)
191
+
192
+
193
+ def move_grippers(bot_list, target_pose_list, move_time):
194
+ print(f"Moving grippers to {target_pose_list=}")
195
+ gripper_command = JointSingleCommand(name="gripper")
196
+ num_steps = int(move_time / constants.DT)
197
+ curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
198
+ traj_list = [
199
+ np.linspace(curr_pose, target_pose, num_steps)
200
+ for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
201
+ ]
202
+
203
+ with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
204
+ for t in range(num_steps):
205
+ d = {}
206
+ for bot_id, bot in enumerate(bot_list):
207
+ gripper_command.cmd = traj_list[bot_id][t]
208
+ bot.gripper.core.pub_single.publish(gripper_command)
209
+ d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
210
+ f.write(json.dumps(d) + "\n")
211
+ time.sleep(constants.DT)
212
+
213
+
214
+ def setup_puppet_bot(bot):
215
+ bot.dxl.robot_reboot_motors("single", "gripper", True)
216
+ bot.dxl.robot_set_operating_modes("group", "arm", "position")
217
+ bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
218
+ torque_on(bot)
219
+
220
+
221
+ def setup_master_bot(bot):
222
+ bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
223
+ bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
224
+ torque_off(bot)
225
+
226
+
227
+ def set_standard_pid_gains(bot):
228
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
229
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
230
+
231
+
232
+ def set_low_pid_gains(bot):
233
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
234
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
235
+
236
+
237
+ def torque_off(bot):
238
+ bot.dxl.robot_torque_enable("group", "arm", False)
239
+ bot.dxl.robot_torque_enable("single", "gripper", False)
240
+
241
+
242
+ def torque_on(bot):
243
+ bot.dxl.robot_torque_enable("group", "arm", True)
244
+ bot.dxl.robot_torque_enable("single", "gripper", True)
245
+
246
+
247
+ # for DAgger
248
+ def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
249
+ print("\nSyncing!")
250
+
251
+ # activate master arms
252
+ torque_on(master_bot_left)
253
+ torque_on(master_bot_right)
254
+
255
+ # get puppet arm positions
256
+ puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
257
+ puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
258
+
259
+ # get puppet gripper positions
260
+ puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
261
+ puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
262
+
263
+ # move master arms to puppet positions
264
+ move_arms(
265
+ [master_bot_left, master_bot_right],
266
+ [puppet_left_qpos, puppet_right_qpos],
267
+ move_time=1,
268
+ )
269
+
270
+ # move master grippers to puppet positions
271
+ move_grippers(
272
+ [master_bot_left, master_bot_right],
273
+ [puppet_left_gripper, puppet_right_gripper],
274
+ move_time=1,
275
+ )
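`move_arms` and `move_grippers` above both reduce to a linear interpolation executed once per control tick; a standalone sketch of that pattern, with 0.02 s standing in for `constants.DT`:

```python
import time
import numpy as np

def interpolate_move(curr_pose, target_pose, move_time: float, dt: float = 0.02):
    """Sketch of the interpolation used by move_arms/move_grippers."""
    num_steps = int(move_time / dt)
    for waypoint in np.linspace(curr_pose, target_pose, num_steps):
        # In the real helpers this waypoint is published to the arm or gripper controller.
        yield waypoint
        time.sleep(dt)
```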
capvector-pi05/examples/aloha_real/video_display.py ADDED
@@ -0,0 +1,36 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from openpi_client.runtime import subscriber as _subscriber
4
+ from typing_extensions import override
5
+
6
+
7
+ class VideoDisplay(_subscriber.Subscriber):
8
+ """Displays video frames."""
9
+
10
+ def __init__(self) -> None:
11
+ self._ax: plt.Axes | None = None
12
+ self._plt_img: plt.Image | None = None
13
+
14
+ @override
15
+ def on_episode_start(self) -> None:
16
+ plt.ion()
17
+ self._ax = plt.subplot()
18
+ self._plt_img = None
19
+
20
+ @override
21
+ def on_step(self, observation: dict, action: dict) -> None:
22
+ assert self._ax is not None
23
+
24
+ im = observation["image"][0] # [C, H, W]
25
+ im = np.transpose(im, (1, 2, 0)) # [H, W, C]
26
+
27
+ if self._plt_img is None:
28
+ self._plt_img = self._ax.imshow(im)
29
+ else:
30
+ self._plt_img.set_data(im)
31
+ plt.pause(0.001)
32
+
33
+ @override
34
+ def on_episode_end(self) -> None:
35
+ plt.ioff()
36
+ plt.close()
capvector-pi05/examples/aloha_sim/Dockerfile ADDED
@@ -0,0 +1,41 @@
1
+ # Dockerfile for the Aloha simulation environment.
2
+
3
+ # Build the container:
4
+ # docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
5
+
6
+ # Run the container:
7
+ # docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
8
+
9
+ FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
10
+ COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
11
+
12
+ RUN apt-get update && \
13
+ apt-get install -y \
14
+ libosmesa6-dev \
15
+ libgl1-mesa-glx \
16
+ libglew-dev \
17
+ libglfw3-dev \
18
+ libgles2-mesa-dev
19
+ ENV MUJOCO_GL=egl
20
+
21
+ WORKDIR /app
22
+
23
+ # Copy from the cache instead of linking since it's a mounted volume
24
+ ENV UV_LINK_MODE=copy
25
+
26
+ # Write the virtual environment outside of the project directory so it doesn't
27
+ # leak out of the container when we mount the application code.
28
+ ENV UV_PROJECT_ENVIRONMENT=/.venv
29
+
30
+ # Copy the requirements files so we can install dependencies.
31
+ # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
32
+ # This strategy is best for development-style usage.
33
+ COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
34
+ COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
35
+
36
+ # Install python dependencies.
37
+ RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
38
+ RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
39
+ ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
40
+
41
+ CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
capvector-pi05/examples/aloha_sim/README.md ADDED
@@ -0,0 +1,36 @@
1
+ # Run Aloha Sim
2
+
3
+ ## With Docker
4
+
5
+ ```bash
6
+ export SERVER_ARGS="--env ALOHA_SIM"
7
+ docker compose -f examples/aloha_sim/compose.yml up --build
8
+ ```
9
+
10
+ ## Without Docker
11
+
12
+ Terminal window 1:
13
+
14
+ ```bash
15
+ # Create virtual environment
16
+ uv venv --python 3.10 examples/aloha_sim/.venv
17
+ source examples/aloha_sim/.venv/bin/activate
18
+ uv pip sync examples/aloha_sim/requirements.txt
19
+ uv pip install -e packages/openpi-client
20
+
21
+ # Run the simulation
22
+ MUJOCO_GL=egl python examples/aloha_sim/main.py
23
+ ```
24
+
25
+ Note: If you are seeing EGL errors, you may need to install the following dependencies:
26
+
27
+ ```bash
28
+ sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
29
+ ```
30
+
31
+ Terminal window 2:
32
+
33
+ ```bash
34
+ # Run the server
35
+ uv run scripts/serve_policy.py --env ALOHA_SIM
36
+ ```
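As an optional sanity check (a sketch using the client API from this repo), you can confirm the policy server from terminal window 2 is reachable before launching the sim client; host and port match the defaults in `main.py`:

```python
from openpi_client import websocket_client_policy

client = websocket_client_policy.WebsocketClientPolicy(host="0.0.0.0", port=8000)
print(client.get_server_metadata())  # prints the server's metadata once it is up
```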
capvector-pi05/examples/aloha_sim/compose.yml ADDED
@@ -0,0 +1,42 @@
1
+ # Run with:
2
+ # docker compose -f examples/aloha_sim/compose.yml up --build
3
+ services:
4
+ runtime:
5
+ image: aloha_sim
6
+ depends_on:
7
+ - openpi_server
8
+ build:
9
+ context: ../..
10
+ dockerfile: examples/aloha_sim/Dockerfile
11
+ init: true
12
+ tty: true
13
+ network_mode: host
14
+ privileged: true
15
+ volumes:
16
+ - $PWD:/app
17
+ - ../../data:/data
18
+
19
+ openpi_server:
20
+ image: openpi_server
21
+ build:
22
+ context: ../..
23
+ dockerfile: scripts/docker/serve_policy.Dockerfile
24
+ init: true
25
+ tty: true
26
+ network_mode: host
27
+ volumes:
28
+ - $PWD:/app
29
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
30
+ environment:
31
+ - SERVER_ARGS
32
+ - OPENPI_DATA_HOME=/openpi_assets
33
+ - IS_DOCKER=true
34
+
35
+ # Comment out this block if not running on a machine with GPUs.
36
+ deploy:
37
+ resources:
38
+ reservations:
39
+ devices:
40
+ - driver: nvidia
41
+ count: 1
42
+ capabilities: [gpu]
capvector-pi05/examples/aloha_sim/env.py ADDED
@@ -0,0 +1,56 @@
1
+ import gym_aloha # noqa: F401
2
+ import gymnasium
3
+ import numpy as np
4
+ from openpi_client import image_tools
5
+ from openpi_client.runtime import environment as _environment
6
+ from typing_extensions import override
7
+
8
+
9
+ class AlohaSimEnvironment(_environment.Environment):
10
+ """An environment for an Aloha robot in simulation."""
11
+
12
+ def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
13
+ np.random.seed(seed)
14
+ self._rng = np.random.default_rng(seed)
15
+
16
+ self._gym = gymnasium.make(task, obs_type=obs_type)
17
+
18
+ self._last_obs = None
19
+ self._done = True
20
+ self._episode_reward = 0.0
21
+
22
+ @override
23
+ def reset(self) -> None:
24
+ gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
25
+ self._last_obs = self._convert_observation(gym_obs) # type: ignore
26
+ self._done = False
27
+ self._episode_reward = 0.0
28
+
29
+ @override
30
+ def is_episode_complete(self) -> bool:
31
+ return self._done
32
+
33
+ @override
34
+ def get_observation(self) -> dict:
35
+ if self._last_obs is None:
36
+ raise RuntimeError("Observation is not set. Call reset() first.")
37
+
38
+ return self._last_obs # type: ignore
39
+
40
+ @override
41
+ def apply_action(self, action: dict) -> None:
42
+ gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
43
+ self._last_obs = self._convert_observation(gym_obs) # type: ignore
44
+ self._done = terminated or truncated
45
+ self._episode_reward = max(self._episode_reward, reward)
46
+
47
+ def _convert_observation(self, gym_obs: dict) -> dict:
48
+ img = gym_obs["pixels"]["top"]
49
+ img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
50
+ # Convert axis order from [H, W, C] --> [C, H, W]
51
+ img = np.transpose(img, (2, 0, 1))
52
+
53
+ return {
54
+ "state": gym_obs["agent_pos"],
55
+ "images": {"cam_high": img},
56
+ }
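The observation conversion above boils down to a pad-preserving resize to 224×224, uint8 conversion, and a channel-first transpose; a standalone sketch:

```python
import numpy as np
from openpi_client import image_tools

def to_policy_image(img: np.ndarray) -> np.ndarray:
    """Sketch of the preprocessing in _convert_observation: HWC image in, CHW uint8 out."""
    img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
    return np.transpose(img, (2, 0, 1))  # [H, W, C] -> [C, H, W]

dummy = np.zeros((480, 640, 3), dtype=np.uint8)   # e.g. a raw camera frame
assert to_policy_image(dummy).shape == (3, 224, 224)
```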
capvector-pi05/examples/aloha_sim/main.py ADDED
@@ -0,0 +1,55 @@
1
+ import dataclasses
2
+ import logging
3
+ import pathlib
4
+
5
+ import env as _env
6
+ from openpi_client import action_chunk_broker
7
+ from openpi_client import websocket_client_policy as _websocket_client_policy
8
+ from openpi_client.runtime import runtime as _runtime
9
+ from openpi_client.runtime.agents import policy_agent as _policy_agent
10
+ import saver as _saver
11
+ import tyro
12
+
13
+
14
+ @dataclasses.dataclass
15
+ class Args:
16
+ out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
17
+
18
+ task: str = "gym_aloha/AlohaTransferCube-v0"
19
+ seed: int = 0
20
+
21
+ action_horizon: int = 10
22
+
23
+ host: str = "0.0.0.0"
24
+ port: int = 8000
25
+
26
+ display: bool = False
27
+
28
+
29
+ def main(args: Args) -> None:
30
+ runtime = _runtime.Runtime(
31
+ environment=_env.AlohaSimEnvironment(
32
+ task=args.task,
33
+ seed=args.seed,
34
+ ),
35
+ agent=_policy_agent.PolicyAgent(
36
+ policy=action_chunk_broker.ActionChunkBroker(
37
+ policy=_websocket_client_policy.WebsocketClientPolicy(
38
+ host=args.host,
39
+ port=args.port,
40
+ ),
41
+ action_horizon=args.action_horizon,
42
+ )
43
+ ),
44
+ subscribers=[
45
+ _saver.VideoSaver(args.out_dir),
46
+ ],
47
+ max_hz=50,
48
+ )
49
+
50
+ runtime.run()
51
+
52
+
53
+ if __name__ == "__main__":
54
+ logging.basicConfig(level=logging.INFO, force=True)
55
+ tyro.cli(main)
capvector-pi05/examples/aloha_sim/requirements.in ADDED
@@ -0,0 +1,8 @@
1
+ gym-aloha
2
+ imageio
3
+ matplotlib
4
+ msgpack
5
+ numpy>=1.22.4,<2.0.0
6
+ typing-extensions
7
+ tyro
8
+ websockets
capvector-pi05/examples/aloha_sim/requirements.txt ADDED
@@ -0,0 +1,132 @@
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
3
+ absl-py==2.1.0
4
+ # via
5
+ # dm-control
6
+ # dm-env
7
+ # labmaze
8
+ # mujoco
9
+ certifi==2024.8.30
10
+ # via requests
11
+ charset-normalizer==3.4.0
12
+ # via requests
13
+ cloudpickle==3.1.0
14
+ # via gymnasium
15
+ contourpy==1.3.1
16
+ # via matplotlib
17
+ cycler==0.12.1
18
+ # via matplotlib
19
+ dm-control==1.0.14
20
+ # via gym-aloha
21
+ dm-env==1.6
22
+ # via dm-control
23
+ dm-tree==0.1.8
24
+ # via
25
+ # dm-control
26
+ # dm-env
27
+ docstring-parser==0.16
28
+ # via tyro
29
+ farama-notifications==0.0.4
30
+ # via gymnasium
31
+ fonttools==4.55.2
32
+ # via matplotlib
33
+ glfw==2.8.0
34
+ # via
35
+ # dm-control
36
+ # mujoco
37
+ gym-aloha==0.1.1
38
+ # via -r examples/aloha_sim/requirements.in
39
+ gymnasium==1.0.0
40
+ # via gym-aloha
41
+ idna==3.10
42
+ # via requests
43
+ imageio==2.36.1
44
+ # via
45
+ # -r examples/aloha_sim/requirements.in
46
+ # gym-aloha
47
+ imageio-ffmpeg==0.5.1
48
+ # via imageio
49
+ kiwisolver==1.4.7
50
+ # via matplotlib
51
+ labmaze==1.0.6
52
+ # via dm-control
53
+ lxml==5.3.0
54
+ # via dm-control
55
+ markdown-it-py==3.0.0
56
+ # via rich
57
+ matplotlib==3.9.3
58
+ # via -r examples/aloha_sim/requirements.in
59
+ mdurl==0.1.2
60
+ # via markdown-it-py
61
+ msgpack==1.1.0
62
+ # via -r examples/aloha_sim/requirements.in
63
+ mujoco==2.3.7
64
+ # via
65
+ # dm-control
66
+ # gym-aloha
67
+ numpy==1.26.4
68
+ # via
69
+ # -r examples/aloha_sim/requirements.in
70
+ # contourpy
71
+ # dm-control
72
+ # dm-env
73
+ # gymnasium
74
+ # imageio
75
+ # labmaze
76
+ # matplotlib
77
+ # mujoco
78
+ # scipy
79
+ packaging==24.2
80
+ # via matplotlib
81
+ pillow==11.0.0
82
+ # via
83
+ # imageio
84
+ # matplotlib
85
+ protobuf==5.29.1
86
+ # via dm-control
87
+ psutil==6.1.0
88
+ # via imageio
89
+ pygments==2.18.0
90
+ # via rich
91
+ pyopengl==3.1.7
92
+ # via
93
+ # dm-control
94
+ # mujoco
95
+ pyparsing==3.2.0
96
+ # via
97
+ # dm-control
98
+ # matplotlib
99
+ python-dateutil==2.9.0.post0
100
+ # via matplotlib
101
+ requests==2.32.3
102
+ # via dm-control
103
+ rich==13.9.4
104
+ # via tyro
105
+ scipy==1.14.1
106
+ # via dm-control
107
+ setuptools==75.6.0
108
+ # via
109
+ # dm-control
110
+ # imageio-ffmpeg
111
+ # labmaze
112
+ shtab==1.7.1
113
+ # via tyro
114
+ six==1.17.0
115
+ # via python-dateutil
116
+ tqdm==4.67.1
117
+ # via dm-control
118
+ typeguard==4.4.1
119
+ # via tyro
120
+ typing-extensions==4.12.2
121
+ # via
122
+ # -r examples/aloha_sim/requirements.in
123
+ # gymnasium
124
+ # rich
125
+ # typeguard
126
+ # tyro
127
+ tyro==0.9.2
128
+ # via -r examples/aloha_sim/requirements.in
129
+ urllib3==2.2.3
130
+ # via requests
131
+ websockets==14.1
132
+ # via -r examples/aloha_sim/requirements.in
capvector-pi05/examples/aloha_sim/saver.py ADDED
@@ -0,0 +1,40 @@
1
+ import logging
2
+ import pathlib
3
+
4
+ import imageio
5
+ import numpy as np
6
+ from openpi_client.runtime import subscriber as _subscriber
7
+ from typing_extensions import override
8
+
9
+
10
+ class VideoSaver(_subscriber.Subscriber):
11
+ """Saves episode data."""
12
+
13
+ def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
14
+ out_dir.mkdir(parents=True, exist_ok=True)
15
+ self._out_dir = out_dir
16
+ self._images: list[np.ndarray] = []
17
+ self._subsample = subsample
18
+
19
+ @override
20
+ def on_episode_start(self) -> None:
21
+ self._images = []
22
+
23
+ @override
24
+ def on_step(self, observation: dict, action: dict) -> None:
25
+ im = observation["images"]["cam_high"] # [C, H, W]
26
+ im = np.transpose(im, (1, 2, 0)) # [H, W, C]
27
+ self._images.append(im)
28
+
29
+ @override
30
+ def on_episode_end(self) -> None:
31
+ existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
32
+ next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
33
+ out_path = self._out_dir / f"out_{next_idx}.mp4"
34
+
35
+ logging.info(f"Saving video to {out_path}")
36
+ imageio.mimwrite(
37
+ out_path,
38
+ [np.asarray(x) for x in self._images[:: self._subsample]],
39
+ fps=50 // max(1, self._subsample),
40
+ )
capvector-pi05/examples/convert_jax_model_to_pytorch.py ADDED
@@ -0,0 +1,587 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
4
+
5
+ This script loads a JAX model checkpoint using orbax and can either:
6
+ 1. Print out all the parameter keys in a hierarchical structure for inspection
7
+ 2. Convert the JAX model to PyTorch format using our PI0Pytorch model
8
+
9
+ Usage:
10
+ # Just inspect keys:
11
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
13
+
14
+ # Convert to PyTorch:
15
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
17
+
18
+ Example:
19
+ # pi0_droid
20
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
21
+
22
+ # pi0_aloha_sim
23
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
24
+
25
+ # pi05_droid
26
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
27
+ """
28
+
29
+ import json
30
+ import os
31
+ import pathlib
32
+ import shutil
33
+ from typing import Literal
34
+
35
+ from flax.nnx import traversals
36
+ import numpy as np
37
+ import orbax.checkpoint as ocp
38
+ import safetensors
39
+ import torch
40
+ import tyro
41
+
42
+ import openpi.models.gemma
43
+ import openpi.models.model
44
+ import openpi.models.pi0_config
45
+ import openpi.models_pytorch.pi0_pytorch
46
+ from openpi.training import utils
47
+ import openpi.training.config as _config
48
+
49
+
50
+ def slice_paligemma_state_dict(state_dict, config):
51
+ """Convert PaliGemma JAX parameters to PyTorch format."""
52
+ suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
53
+
54
+ # patch embeddings
55
+ jax_key = f"img/embedding/kernel{suffix}"
56
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
57
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
58
+
59
+ jax_key = f"img/embedding/bias{suffix}"
60
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
61
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
62
+
63
+ # positional embeddings
64
+ jax_key = f"img/pos_embedding{suffix}"
65
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
66
+ state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
67
+
68
+ # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
69
+ encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
70
+ encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
71
+ encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
72
+ encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
73
+
74
+ encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
75
+ encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
76
+ encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
77
+ encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
78
+
79
+ encoderblock_attention_0_key_kernel = state_dict.pop(
80
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
81
+ )
82
+ encoderblock_attention_0_key_bias = state_dict.pop(
83
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
84
+ )
85
+ encoderblock_attention_0_value_kernel = state_dict.pop(
86
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
87
+ )
88
+ encoderblock_attention_0_value_bias = state_dict.pop(
89
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
90
+ )
91
+ encoderblock_attention_0_query_kernel = state_dict.pop(
92
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
93
+ )
94
+ encoderblock_attention_0_query_bias = state_dict.pop(
95
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
96
+ )
97
+ encoderblock_attention_0_out_kernel = state_dict.pop(
98
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
99
+ )
100
+ encoderblock_attention_0_out_bias = state_dict.pop(
101
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
102
+ )
103
+
104
+ for i in range(config.vision_config.num_hidden_layers):
105
+ state_dict[
106
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
107
+ ] = encoderblock_layernorm0_scale[i].transpose()
108
+ state_dict[
109
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
110
+ ] = encoderblock_layernorm0_bias[i]
111
+ state_dict[
112
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
113
+ ] = encoderblock_layernorm1_scale[i].transpose()
114
+ state_dict[
115
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
116
+ ] = encoderblock_layernorm1_bias[i]
117
+ state_dict[
118
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
119
+ ] = encoderblock_mlp_dense0_kernel[i].transpose()
120
+ state_dict[
121
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
122
+ ] = encoderblock_mlp_dense0_bias[i]
123
+ state_dict[
124
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
125
+ ] = encoderblock_mlp_dense1_kernel[i].transpose()
126
+ state_dict[
127
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
128
+ ] = encoderblock_mlp_dense1_bias[i]
129
+ state_dict[
130
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
131
+ ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
132
+ state_dict[
133
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
134
+ ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
135
+ state_dict[
136
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
137
+ ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
138
+ state_dict[
139
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
140
+ ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
141
+ state_dict[
142
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
143
+ ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
144
+ state_dict[
145
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
146
+ ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
147
+ state_dict[
148
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
149
+ ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
150
+ state_dict[
151
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
152
+ ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
153
+
154
+ jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
155
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
156
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
157
+
158
+ jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
159
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
160
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
161
+
162
+ # multimodal projector
163
+ jax_key = f"img/head/kernel{suffix}"
164
+ pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
165
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
166
+
167
+ jax_key = f"img/head/bias{suffix}"
168
+ pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
169
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
170
+
171
+ # text decoder (gemma)
172
+ jax_key = f"llm/embedder/input_embedding{suffix}"
173
+ pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
174
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
175
+
176
+ # pop the einsum attention + mlp representations
177
+ llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
178
+ llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
179
+ llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
180
+
181
+ llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
182
+ llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
183
+
184
+ llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
185
+ llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
186
+
187
+ for i in range(config.text_config.num_hidden_layers):
188
+ q_proj_weight_reshaped = (
189
+ llm_attention_q_einsum[i]
190
+ .transpose(0, 2, 1)
191
+ .reshape(
192
+ config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
193
+ )
194
+ )
195
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
196
+ q_proj_weight_reshaped
197
+ )
198
+
199
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
200
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
201
+ k_proj_weight_reshaped
202
+ )
203
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
204
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
205
+ v_proj_weight_reshaped
206
+ )
207
+
208
+ o_proj_weight_reshaped = (
209
+ llm_attention_attn_vec_einsum[i]
210
+ .transpose(2, 0, 1)
211
+ .reshape(
212
+ config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
213
+ )
214
+ )
215
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
216
+ o_proj_weight_reshaped
217
+ )
218
+
219
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
220
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
221
+ gate_proj_weight.transpose()
222
+ )
223
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
224
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
225
+ up_proj_weight.transpose()
226
+ )
227
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
228
+ llm_mlp_linear[i].transpose()
229
+ )
230
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
231
+ llm_input_layernorm[i]
232
+ )
233
+ state_dict[
234
+ f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
235
+ ] = llm_post_attention_layernorm[i]
236
+
237
+ jax_key = f"llm/final_norm/scale{suffix}"
238
+ pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
239
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
240
+
241
+ expert_dict = {}
242
+ final_state_dict = {}
243
+
244
+ # Expert-related keys to extract (including pi05 Dense layer parameters)
245
+ expert_keys = [
246
+ f"llm/final_norm_1/scale{suffix}",
247
+ f"llm/final_norm_1/Dense_0/bias{suffix}",
248
+ f"llm/final_norm_1/Dense_0/kernel{suffix}",
249
+ f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
250
+ f"llm/layers/attn/kv_einsum_1/w{suffix}",
251
+ f"llm/layers/attn/q_einsum_1/w{suffix}",
252
+ f"llm/layers/mlp_1/gating_einsum{suffix}",
253
+ f"llm/layers/mlp_1/linear{suffix}",
254
+ f"llm/layers/pre_attention_norm_1/scale{suffix}",
255
+ f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
256
+ f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
257
+ f"llm/layers/pre_ffw_norm_1/scale{suffix}",
258
+ f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
259
+ f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
260
+ ]
261
+
262
+ for key, value in state_dict.items():
263
+ if key not in expert_keys:
264
+ final_state_dict[key] = torch.from_numpy(value)
265
+ else:
266
+ expert_dict[key] = value
267
+
268
+ return final_state_dict, expert_dict
269
+
270
+
271
+ def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
272
+ """Convert Gemma JAX parameters to PyTorch format."""
273
+ # Add missing attributes to config if they don't exist
274
+ if not hasattr(config, "vocab_size"):
275
+ config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
276
+ if not hasattr(config, "hidden_size"):
277
+ config.hidden_size = config.width
278
+ if not hasattr(config, "num_hidden_layers"):
279
+ config.num_hidden_layers = config.depth
280
+ if not hasattr(config, "num_attention_heads"):
281
+ config.num_attention_heads = config.num_heads
282
+
283
+ suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
284
+
285
+ llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
286
+ llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
287
+ llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
288
+
289
+ llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
290
+ llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
291
+
292
+ # Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
293
+ if "pi05" in checkpoint_dir:
294
+ # Pi05 with adaptive normalization
295
+ llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
296
+ llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
297
+ llm_input_layernorm_kernel = state_dict.pop(
298
+ f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
299
+ )
300
+ llm_post_attention_layernorm_kernel = state_dict.pop(
301
+ f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
302
+ )
303
+ else:
304
+ # Regular pi0 with standard RMSNorm
305
+ llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
306
+ llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
307
+
308
+ for i in range(config.num_hidden_layers):
309
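+ # Shape note for the reshapes below: the per-layer JAX q einsum weight has shape
+ # [num_heads, hidden_size, head_dim]; transpose(0, 2, 1) gives [num_heads, head_dim, hidden_size],
+ # and the reshape flattens the head axes so the result matches nn.Linear's
+ # (out_features, in_features) weight layout.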
+ q_proj_weight_reshaped = (
310
+ llm_attention_q_einsum[i]
311
+ .transpose(0, 2, 1)
312
+ .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
313
+ )
314
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
315
+ q_proj_weight_reshaped
316
+ )
317
+
318
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
319
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
320
+ k_proj_weight_reshaped
321
+ )
322
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
323
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
324
+ v_proj_weight_reshaped
325
+ )
326
+
327
+ o_proj_weight_reshaped = (
328
+ llm_attention_attn_vec_einsum[i]
329
+ .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
330
+ .transpose(1, 0)
331
+ )
332
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
333
+ o_proj_weight_reshaped
334
+ )
335
+
336
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
337
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
338
+ gate_proj_weight.transpose()
339
+ )
340
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
341
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
342
+ up_proj_weight.transpose()
343
+ )
344
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
345
+ i
346
+ ].transpose()
347
+
348
+ if "pi05" in checkpoint_dir:
349
+ # Pi05 with adaptive normalization - use Dense layer parameters directly
350
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
351
+ llm_input_layernorm_bias[i]
352
+ )
353
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
354
+ llm_post_attention_layernorm_bias[i]
355
+ )
356
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
357
+ llm_input_layernorm_kernel[i].transpose()
358
+ )
359
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
360
+ llm_post_attention_layernorm_kernel[i].transpose()
361
+ )
362
+ else:
363
+ # Regular pi0 with standard RMSNorm
364
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
365
+ llm_input_layernorm[i]
366
+ )
367
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
368
+ llm_post_attention_layernorm[i]
369
+ )
370
+
371
+ # Handle final norm layer
372
+ if "pi05" in checkpoint_dir:
373
+ # Pi05 with adaptive normalization - use Dense layer parameters directly
374
+ final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
375
+ final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
376
+ state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
377
+ state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
378
+ else:
379
+ # Regular pi0 with standard RMSNorm
380
+ state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
381
+ f"llm/final_norm_{num_expert}/scale{suffix}"
382
+ )
383
+
384
+ # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
385
+
386
+ final_state_dict = {}
387
+ for key, value in state_dict.items():
388
+ if not isinstance(value, torch.Tensor):
389
+ final_state_dict[key] = torch.from_numpy(value)
390
+ else:
391
+ final_state_dict[key] = value
392
+
393
+ return final_state_dict
394
+
395
+
396
+ def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
397
+ """Load and process params by restoring via JAX model loader first.
398
+ This respects dtype conversions that occur during model restore.
399
+ """
400
+ # Use repository restore utility to load a pure dict of params (value suffix removed)
401
+ params = openpi.models.model.restore_params(
402
+ f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
403
+ )
404
+
405
+ return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
406
+
407
+
408
+ def load_jax_model_and_print_keys(checkpoint_dir: str):
409
+ """
410
+ Load JAX model from checkpoint and print all parameter keys.
411
+
412
+ Args:
413
+ checkpoint_dir: Path to the checkpoint directory
414
+ """
415
+ checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
416
+ # Initialize checkpointer
417
+ checkpointer = ocp.PyTreeCheckpointer()
418
+ metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
419
+ print(utils.array_tree_to_info(metadata))
420
+
421
+
422
+ def convert_pi0_checkpoint(
423
+ checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
424
+ ):
425
+ """
426
+ Convert PI0 JAX checkpoint to PyTorch format.
427
+
428
+ Args:
429
+ checkpoint_dir: Path to the JAX checkpoint
430
+ precision: Model precision (float32, bfloat16, float16)
431
+ output_path: Path to save the converted PyTorch model
432
+ model_config: Model config
433
+ """
434
+ print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
435
+ print(f"Model config: {model_config}")
436
+
437
+ # Break down orbax ckpts by restoring via JAX to respect dtype
438
+ initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
439
+
440
+ # Process projection params
441
+ if model_config.pi05:
442
+ keys = [
443
+ "action_in_proj",
444
+ "action_out_proj",
445
+ "time_mlp_in",
446
+ "time_mlp_out",
447
+ ]
448
+ else:
449
+ keys = [
450
+ "state_proj",
451
+ "action_in_proj",
452
+ "action_out_proj",
453
+ "action_time_mlp_in",
454
+ "action_time_mlp_out",
455
+ ]
456
+
457
+ projection_params = {}
458
+ for key in keys:
459
+ kernel_params = initial_params["projection_params"][key]["kernel"]
460
+ bias_params = initial_params["projection_params"][key]["bias"]
461
+ if isinstance(kernel_params, dict):
462
+ weight = kernel_params["value"]
463
+ bias = bias_params["value"]
464
+ else:
465
+ weight = kernel_params
466
+ bias = bias_params
467
+
468
+ pytorch_weight_key = f"{key}.weight"
469
+ pytorch_bias_key = f"{key}.bias"
470
+
471
+ projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
472
+ projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
473
+
474
+ # Create configs based on checkpoint path
475
+ # All models use the same PaliGemma config structure
476
+ class PaliGemmaConfig:
477
+ def __init__(self):
478
+ self.vision_config = type(
479
+ "obj",
480
+ (object,),
481
+ {
482
+ "hidden_size": 1152,
483
+ "num_hidden_layers": 27,
484
+ "num_attention_heads": 16,
485
+ "intermediate_size": 4304,
486
+ "patch_size": 14,
487
+ "projection_dim": 2048,
488
+ },
489
+ )()
490
+ self.text_config = type(
491
+ "obj",
492
+ (object,),
493
+ {
494
+ "hidden_size": 2048,
495
+ "num_hidden_layers": 18,
496
+ "num_attention_heads": 8,
497
+ "head_dim": 256,
498
+ "intermediate_size": 16384,
499
+ },
500
+ )()
501
+
502
+ paligemma_config = PaliGemmaConfig()
503
+ action_expert_config = openpi.models.gemma.get_config("gemma_300m")
504
+
505
+ # Process PaliGemma weights
506
+ paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
507
+
508
+ # Process Gemma weights from expert_params
509
+ gemma_params = slice_gemma_state_dict(
510
+ expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
511
+ )
512
+
513
+ # Instantiate model
514
+ pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
515
+
516
+ # Combine all parameters (no prefix needed for our model structure)
517
+ all_params = {**paligemma_params, **gemma_params, **projection_params}
518
+
519
+ # Load state dict
520
+ pi0_model.load_state_dict(all_params, strict=False)
521
+
522
+ if precision == "float32":
523
+ pi0_model = pi0_model.to(torch.float32)
524
+ elif precision == "bfloat16":
525
+ pi0_model = pi0_model.to(torch.bfloat16)
526
+ else:
527
+ raise ValueError(f"Invalid precision: {precision}")
528
+
529
+ # Save the converted model using safetensors
530
+ os.makedirs(output_path, exist_ok=True)
531
+
532
+ # Save model weights as SafeTensors using save_model to handle tied weights
533
+ safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
534
+
535
+ # Copy assets folder if it exists
536
+ assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
537
+ if assets_source.exists():
538
+ assets_dest = pathlib.Path(output_path) / "assets"
539
+ if assets_dest.exists():
540
+ shutil.rmtree(assets_dest)
541
+ shutil.copytree(assets_source, assets_dest)
542
+
543
+ # Save config as JSON for reference
544
+ config_dict = {
545
+ "action_dim": model_config.action_dim,
546
+ "action_horizon": model_config.action_horizon,
547
+ "paligemma_variant": model_config.paligemma_variant,
548
+ "action_expert_variant": model_config.action_expert_variant,
549
+ "precision": precision,
550
+ }
551
+ with open(os.path.join(output_path, "config.json"), "w") as f:
552
+ json.dump(config_dict, f, indent=2)
553
+
554
+ print("Model conversion completed successfully!")
555
+ print(f"Model saved to {output_path}")
556
+
557
+
558
+ def main(
559
+ checkpoint_dir: str,
560
+ config_name: str,
561
+ output_path: str | None = None,
562
+ precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
563
+ *,
564
+ inspect_only: bool = False,
565
+ ):
566
+ """Load JAX model and optionally convert to PyTorch.
567
+
568
+ Args:
569
+ checkpoint_dir: Path to the JAX checkpoint directory
+ config_name: Name of the training config whose model config is used for the conversion
570
+ output_path: Path to save converted PyTorch model (required for conversion)
571
+ precision: Precision for model conversion
572
+ inspect_only: Only inspect parameter keys, don't convert
573
+ """
574
+ model_config = _config.get_config(config_name).model
575
+ if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
576
+ raise ValueError(f"Config {config_name} is not a Pi0Config")
577
+ if inspect_only:
578
+ load_jax_model_and_print_keys(checkpoint_dir)
579
+ else:
580
+ if not output_path:
581
+ print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
582
+ return
583
+ convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
584
+
585
+
586
+ if __name__ == "__main__":
587
+ tyro.cli(main)
capvector-pi05/examples/droid/README.md ADDED
@@ -0,0 +1,84 @@
1
+ # DROID Policies in openpi
2
+
3
+ We offer instructions for:
4
+ - [Running inference for our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
5
+ - [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-other-policies)
6
+ - [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
7
+ - [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)
8
+
9
+ ## Running DROID Inference
10
+
11
+ This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.
12
+
13
+
14
+ ### Step 1: Start a policy server
15
+
16
+ Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
17
+
18
+ 1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
19
+ 2. Start the OpenPI server via the following command:
20
+
21
+ ```bash
22
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
23
+ ```
24
+
25
+ You can also run the equivalent command below:
26
+
27
+ ```bash
28
+ uv run scripts/serve_policy.py --env=DROID
29
+ ```
30
+
31
+ ### Step 2: Run the DROID robot
32
+
33
+ 1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
34
+ 2. On the control laptop, activate your DROID conda environment.
35
+ 3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
36
+ 4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
37
+ 5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
38
+ 6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
39
+ 7. Run the `main.py` file. Make sure to point the remote host and port arguments at the policy server. (To check that the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify which external camera to feed to the policy (we only input one external camera); choose from ["left", "right"].
40
+
41
+ ```bash
42
+ python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
43
+ ```
44
+
45
+ The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
46
+
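+ Under the hood, `main.py` simply packages the camera images and robot state into a request, sends it to the policy server over a websocket, and executes the returned action chunk. The sketch below shows that core call with dummy observations (the observation keys and the (10, 8) chunk shape match what `main.py` uses; the image resolution and instruction are placeholders) -- it is illustrative only, use `main.py` for actual rollouts:
+ 
+ ```python
+ import numpy as np
+ from openpi_client import image_tools, websocket_client_policy
+ 
+ client = websocket_client_policy.WebsocketClientPolicy("<server_ip>", 8000)
+ 
+ dummy_image = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder for a ZED camera frame
+ request = {
+     "observation/exterior_image_1_left": image_tools.resize_with_pad(dummy_image, 224, 224),
+     "observation/wrist_image_left": image_tools.resize_with_pad(dummy_image, 224, 224),
+     "observation/joint_position": np.zeros(7),
+     "observation/gripper_position": np.zeros(1),
+     "prompt": "<your language instruction>",
+ }
+ actions = client.infer(request)["actions"]  # chunk of shape (10, 8): 7 joint velocities + 1 gripper position
+ ```
+ 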
47
+ ## Troubleshooting
48
+
49
+ | Issue | Solution |
50
+ |-------|----------|
51
+ | Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
52
+ | Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
53
+ | Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
54
+ | Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
55
+
56
+
57
+ ## Running Other Policies
58
+
59
+ We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
60
+
61
+ ```
62
+ # Train from pi0-FAST, using FAST tokenizer
63
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
64
+
65
+ # Train from pi0, using flow matching
66
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid
67
+
68
+ # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
69
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
70
+
71
+ # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
72
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
73
+
74
+ # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
75
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
76
+
77
+ # Trained from PaliGemma, using FSQ tokenizer.
78
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
79
+
80
+ # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
81
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
82
+ ```
83
+
84
+ You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
capvector-pi05/examples/droid/README_train.md ADDED
@@ -0,0 +1,106 @@
1
+ # Training on DROID
2
+
3
+ Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline
4
+ (there are small differences in data loading and the action space used). For a tutorial on how to fine-tune your model on a smaller, custom dataset collected on the DROID platform, see below.
5
+
6
+ In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough
7
+ for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset.
8
+
9
+ ## Install
10
+
11
+ We need a few additional dependencies for RLDS data loading. Run:
12
+ ```bash
13
+ uv sync --group rlds
14
+ ```
15
+
16
+ ## Download DROID dataset
17
+
18
+ You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
19
+ ```
20
+ gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
21
+ ```
22
+
23
+ Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).
24
+
25
+ You will need 1.8TB of disk storage to download the DROID RLDS dataset.
26
+
27
+ ## Run
28
+
29
+ First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
30
+
31
+ Then, compute normalization statistics (this will take ~10 minutes):
32
+ ```bash
33
+ uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
34
+ ```
35
+
36
+ Run training:
37
+ ```bash
38
+ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
39
+ ```
40
+
41
+ **Note**: The original pi0.5-DROID model was trained with joint velocity actions.
42
+ Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate).
43
+ Thus, we do not recommend training with joint velocity actions and instead use joint position actions here.
44
+
45
+
46
+ ## Compute Requirements
47
+
48
+ Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, bs256, approx. 1 epoch).
49
+ If you start from PaliGemma instead of pi0 initialization, plan with ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).
50
+
51
+ We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
52
+
53
+
54
+ ## Data Filtering
55
+
56
+ Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection, we will not go into too much detail here). Appropriate filtering of these idle transitions can improve policy performance.
57
+
58
+ By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
59
+
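+ For reference, the filter dict generated by [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) is a JSON file that maps each episode's unique ID (its recording folder path and trajectory file path, joined by `--`) to the list of frame ranges to keep. A minimal sketch of the structure the script builds before `json.dump` (the paths and ranges below are placeholders):
+ 
+ ```python
+ keep_ranges_map = {
+     "<recording_folderpath>--<file_path>": [(0, 20), (50, 80)],
+     "<recording_folderpath_2>--<file_path_2>": [(3, 41)],
+ }
+ ```
+ 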
60
+ **Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
61
+
62
+ ## RoboArena
63
+
64
+ Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
65
+
66
+ If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).
67
+
68
+
69
+ # Fine-Tuning on Custom DROID Datasets
70
+
71
+ Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. Like for other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.
72
+
73
+ Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset is relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above).
74
+
75
+
76
+ ## Step 1: Converting your custom DROID dataset to LeRobot
77
+
78
+ We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
79
+ ```
80
+ gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
81
+ ```
82
+
83
+ We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
84
+ ```
85
+ gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
86
+ ```
87
+
88
+ For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
89
+
90
+ Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations):
91
+ ```
92
+ uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
93
+ ```
94
+
95
+ ## Step 2: Run fine-tuning with your custom dataset
96
+
97
+ Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
98
+ You can easily modify the config to work with other base models, or to use your own custom DROID dataset, in `config.py` (search for `pi05_droid_finetune`).
99
+
100
+ To launch training:
101
+ ```
102
+ uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
103
+ ```
104
+
105
+ Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
106
+
capvector-pi05/examples/droid/compute_droid_nonidle_ranges.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
3
+ that should be sampled during training (all others are filtered out).
4
+
5
+ Filtering logic:
6
+ We look for ranges of consecutive steps that contain fewer than min_idle_len consecutive idle frames
7
+ (defaults to 7 -- since most DROID action-chunking policies execute the first 8 actions generated in each chunk, filtering
8
+ this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
9
+ ranges of length at least min_non_idle_len (defaults to 16 frames = ~1 second), while also removing the last
10
+ filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).
11
+
12
+ This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
13
+ yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
14
+ """
15
+
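+ # Worked example (illustrative): with min_idle_len=7, min_non_idle_len=16, and
+ # filter_last_n_in_ranges=10, an episode whose idle mask is
+ # [False]*30 + [True]*20 + [False]*40 yields keep ranges [(0, 20), (50, 80)]:
+ # the 20-step idle block is dropped, both surviving non-idle segments are at least
+ # min_non_idle_len long, and 10 frames are trimmed from the end of each kept range.
+ 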
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+
20
+ import numpy as np
21
+ import tensorflow as tf
22
+ import tensorflow_datasets as tfds
23
+ from tqdm import tqdm
24
+
25
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU
26
+
27
+ builder = tfds.builder_from_directory(
28
+ # path to the `droid` directory (not its parent)
29
+ builder_dir="<path_to_droid_dataset_tfds_files>",
30
+ )
31
+ ds = builder.as_dataset(split="train", shuffle_files=False)
32
+ tf.data.experimental.ignore_errors(ds)
33
+
34
+ keep_ranges_path = "<path_to_where_to_save_the_json>"
35
+
36
+ min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out
37
+ min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out
38
+ filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range
39
+
40
+ keep_ranges_map = {}
41
+ if Path(keep_ranges_path).exists():
42
+ with Path(keep_ranges_path).open("r") as f:
43
+ keep_ranges_map = json.load(f)
44
+ print(f"Resuming from {len(keep_ranges_map)} episodes already processed")
45
+
46
+ for ep_idx, ep in enumerate(tqdm(ds)):
47
+ recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
48
+ file_path = ep["episode_metadata"]["file_path"].numpy().decode()
49
+
50
+ key = f"{recording_folderpath}--{file_path}"
51
+ if key in keep_ranges_map:
52
+ continue
53
+
54
+ joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
55
+ joint_velocities = np.array(joint_velocities)
56
+
57
+ is_idle_array = np.hstack(
58
+ [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
59
+ )
60
+
61
+ # Find what steps go from idle to non-idle and vice-versa
62
+ is_idle_padded = np.concatenate(
63
+ [[False], is_idle_array, [False]]
64
+ ) # Start and end with False, so idle at first step is a start of motion
65
+
66
+ is_idle_diff = np.diff(is_idle_padded.astype(int))
67
+ is_idle_true_starts = np.where(is_idle_diff == 1)[0] # +1 transitions --> going from idle to non-idle
68
+ is_idle_true_ends = np.where(is_idle_diff == -1)[0] # -1 transitions --> going from non-idle to idle
69
+
70
+ # Find which steps correspond to idle segments of length at least min_idle_len
71
+ true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
72
+ is_idle_true_starts = is_idle_true_starts[true_segment_masks]
73
+ is_idle_true_ends = is_idle_true_ends[true_segment_masks]
74
+
75
+ keep_mask = np.ones(len(joint_velocities), dtype=bool)
76
+ for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
77
+ keep_mask[start:end] = False
78
+
79
+ # Get all non-idle ranges of length at least min_non_idle_len
80
+ # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
81
+ keep_padded = np.concatenate([[False], keep_mask, [False]])
82
+
83
+ keep_diff = np.diff(keep_padded.astype(int))
84
+ keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep
85
+ keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out
86
+
87
+ # Find which steps correspond to non-idle segments of length at least min_non_idle_len
88
+ true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
89
+ keep_true_starts = keep_true_starts[true_segment_masks]
90
+ keep_true_ends = keep_true_ends[true_segment_masks]
91
+
92
+ # Add mapping from episode unique ID key to list of non-idle ranges to keep
93
+ keep_ranges_map[key] = []
94
+ for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
95
+ keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))
96
+
97
+ if ep_idx % 1000 == 0:
98
+ with Path(keep_ranges_path).open("w") as f:
99
+ json.dump(keep_ranges_map, f)
100
+
101
+ print("Done!")
102
+ with Path(keep_ranges_path).open("w") as f:
103
+ json.dump(keep_ranges_map, f)
capvector-pi05/examples/droid/convert_droid_data_to_lerobot.py ADDED
@@ -0,0 +1,477 @@
1
+ """
2
+ Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
3
+
4
+ Usage:
5
+ uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
6
+
7
+ If you want to push your dataset to the Hugging Face Hub, you can use the following command:
8
+ uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
9
+
10
+ The resulting dataset will get saved to the $LEROBOT_HOME directory.
11
+ """
12
+
13
+ from collections import defaultdict
14
+ import copy
15
+ import glob
16
+ import json
17
+ from pathlib import Path
18
+ import shutil
19
+
20
+ import cv2
21
+ import h5py
22
+ from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
23
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
24
+ import numpy as np
25
+ from PIL import Image
26
+ from tqdm import tqdm
27
+ import tyro
28
+
29
+ REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
30
+
31
+
32
+ def resize_image(image, size):
33
+ image = Image.fromarray(image)
34
+ return np.array(image.resize(size, resample=Image.BICUBIC))
35
+
36
+
37
+ def main(data_dir: str, *, push_to_hub: bool = False):
38
+ # Clean up any existing dataset in the output directory
39
+ output_path = HF_LEROBOT_HOME / REPO_NAME
40
+ if output_path.exists():
41
+ shutil.rmtree(output_path)
42
+ data_dir = Path(data_dir)
43
+
44
+ # Create LeRobot dataset, define features to store
45
+ # We will follow the DROID data naming conventions here.
46
+ # LeRobot assumes that dtype of image data is `image`
47
+ dataset = LeRobotDataset.create(
48
+ repo_id=REPO_NAME,
49
+ robot_type="panda",
50
+ fps=15, # DROID data is typically recorded at 15fps
51
+ features={
52
+ # We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
53
+ "exterior_image_1_left": {
54
+ "dtype": "image",
55
+ "shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
56
+ "names": ["height", "width", "channel"],
57
+ },
58
+ "exterior_image_2_left": {
59
+ "dtype": "image",
60
+ "shape": (180, 320, 3),
61
+ "names": ["height", "width", "channel"],
62
+ },
63
+ "wrist_image_left": {
64
+ "dtype": "image",
65
+ "shape": (180, 320, 3),
66
+ "names": ["height", "width", "channel"],
67
+ },
68
+ "joint_position": {
69
+ "dtype": "float32",
70
+ "shape": (7,),
71
+ "names": ["joint_position"],
72
+ },
73
+ "gripper_position": {
74
+ "dtype": "float32",
75
+ "shape": (1,),
76
+ "names": ["gripper_position"],
77
+ },
78
+ "actions": {
79
+ "dtype": "float32",
80
+ "shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D)
81
+ "names": ["actions"],
82
+ },
83
+ },
84
+ image_writer_threads=10,
85
+ image_writer_processes=5,
86
+ )
87
+
88
+ # Load language annotations
89
+ # Note: we load the DROID language annotations for this example, but you can manually define them for your own data
90
+ with (data_dir / "aggregated-annotations-030724.json").open() as f:
91
+ language_annotations = json.load(f)
92
+
93
+ # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
94
+ # We assume the following directory structure:
95
+ # RAW_DROID_PATH/
96
+ # - <...>/
97
+ # - recordings/
98
+ # - MP4/
99
+ # - <camera_id>.mp4 # single-view video of left stereo pair camera
100
+ # - trajectory.h5
101
+ # - <...>/
102
+ episode_paths = list(data_dir.glob("**/trajectory.h5"))
103
+ print(f"Found {len(episode_paths)} episodes for conversion")
104
+
105
+ # We will loop over each dataset_name and write episodes to the LeRobot dataset
106
+ for episode_path in tqdm(episode_paths, desc="Converting episodes"):
107
+ # Load raw data
108
+ recording_folderpath = episode_path.parent / "recordings" / "MP4"
109
+ trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))
110
+
111
+ # To load the language instruction, we need to parse out the episode_id from the metadata file
112
+ # Again, you can modify this step for your own data, to load your own language instructions
113
+ metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
114
+ episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
115
+ language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
116
+ "language_instruction1"
117
+ ]
118
+ print(f"Converting episode with language instruction: {language_instruction}")
119
+
120
+ # Write to LeRobot dataset
121
+ for step in trajectory:
122
+ camera_type_dict = step["observation"]["camera_type"]
123
+ wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
124
+ exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
125
+ dataset.add_frame(
126
+ {
127
+ # Note: need to flip BGR --> RGB for loaded images
128
+ "exterior_image_1_left": resize_image(
129
+ step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
130
+ ),
131
+ "exterior_image_2_left": resize_image(
132
+ step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
133
+ ),
134
+ "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
135
+ "joint_position": np.asarray(
136
+ step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
137
+ ),
138
+ "gripper_position": np.asarray(
139
+ step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
140
+ ),
141
+ # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
142
+ "actions": np.concatenate(
143
+ [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
144
+ ),
145
+ "task": language_instruction,
146
+ }
147
+ )
148
+ dataset.save_episode()
149
+
150
+ # Optionally push to the Hugging Face Hub
151
+ if push_to_hub:
152
+ dataset.push_to_hub(
153
+ tags=["libero", "panda", "rlds"],
154
+ private=False,
155
+ push_videos=True,
156
+ license="apache-2.0",
157
+ )
158
+
159
+
160
+ ##########################################################################################################
161
+ ################ The rest of this file are functions to parse the raw DROID data #########################
162
+ ################ You don't need to worry about understanding this part #########################
163
+ ################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
164
+ ##########################################################################################################
165
+
166
+
167
+ camera_type_dict = {
168
+ "hand_camera_id": 0,
169
+ "varied_camera_1_id": 1,
170
+ "varied_camera_2_id": 1,
171
+ }
172
+
173
+ camera_type_to_string_dict = {
174
+ 0: "hand_camera",
175
+ 1: "varied_camera",
176
+ 2: "fixed_camera",
177
+ }
178
+
179
+
180
+ def get_camera_type(cam_id):
181
+ if cam_id not in camera_type_dict:
182
+ return None
183
+ type_int = camera_type_dict[cam_id]
184
+ return camera_type_to_string_dict[type_int]
185
+
186
+
187
+ class MP4Reader:
188
+ def __init__(self, filepath, serial_number):
189
+ # Save Parameters #
190
+ self.serial_number = serial_number
191
+ self._index = 0
192
+
193
+ # Open Video Reader #
194
+ self._mp4_reader = cv2.VideoCapture(filepath)
195
+ if not self._mp4_reader.isOpened():
196
+ raise RuntimeError("Corrupted MP4 File")
197
+
198
+ def set_reading_parameters(
199
+ self,
200
+ image=True, # noqa: FBT002
201
+ concatenate_images=False, # noqa: FBT002
202
+ resolution=(0, 0),
203
+ resize_func=None,
204
+ ):
205
+ # Save Parameters #
206
+ self.image = image
207
+ self.concatenate_images = concatenate_images
208
+ self.resolution = resolution
209
+ self.resize_func = cv2.resize
210
+ self.skip_reading = not image
211
+ if self.skip_reading:
212
+ return
213
+
214
+ def get_frame_resolution(self):
215
+ width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
216
+ height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
217
+ return (width, height)
218
+
219
+ def get_frame_count(self):
220
+ if self.skip_reading:
221
+ return 0
222
+ return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))
223
+
224
+ def set_frame_index(self, index):
225
+ if self.skip_reading:
226
+ return
227
+
228
+ if index < self._index:
229
+ self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
230
+ self._index = index
231
+
232
+ while self._index < index:
233
+ self.read_camera(ignore_data=True)
234
+
235
+ def _process_frame(self, frame):
236
+ frame = copy.deepcopy(frame)
237
+ if self.resolution == (0, 0):
238
+ return frame
239
+ return self.resize_func(frame, self.resolution)
240
+
241
+ def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002
242
+ # Skip if Read Unnecesary #
243
+ if self.skip_reading:
244
+ return {}
245
+
246
+ # Read Camera #
247
+ success, frame = self._mp4_reader.read()
248
+
249
+ self._index += 1
250
+ if not success:
251
+ return None
252
+ if ignore_data:
253
+ return None
254
+
255
+ # Return Data #
256
+ data_dict = {}
257
+
258
+ if self.concatenate_images or "stereo" not in self.serial_number:
259
+ data_dict["image"] = {self.serial_number: self._process_frame(frame)}
260
+ else:
261
+ single_width = frame.shape[1] // 2
262
+ data_dict["image"] = {
263
+ self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
264
+ self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
265
+ }
266
+
267
+ return data_dict
268
+
269
+ def disable_camera(self):
270
+ if hasattr(self, "_mp4_reader"):
271
+ self._mp4_reader.release()
272
+
273
+
274
+ class RecordedMultiCameraWrapper:
275
+ def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006
276
+ # Save Camera Info #
277
+ self.camera_kwargs = camera_kwargs
278
+
279
+ # Open Camera Readers #
280
+ mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
281
+ all_filepaths = mp4_filepaths
282
+
283
+ self.camera_dict = {}
284
+ for f in all_filepaths:
285
+ serial_number = f.split("/")[-1][:-4]
286
+ cam_type = get_camera_type(serial_number)
287
+ camera_kwargs.get(cam_type, {})
288
+
289
+ if f.endswith(".mp4"):
290
+ Reader = MP4Reader # noqa: N806
291
+ else:
292
+ raise ValueError
293
+
294
+ self.camera_dict[serial_number] = Reader(f, serial_number)
295
+
296
+ def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006
297
+ full_obs_dict = defaultdict(dict)
298
+
299
+ # Read Cameras In Randomized Order #
300
+ all_cam_ids = list(self.camera_dict.keys())
301
+ # random.shuffle(all_cam_ids)
302
+
303
+ for cam_id in all_cam_ids:
304
+ if "stereo" in cam_id:
305
+ continue
306
+ try:
307
+ cam_type = camera_type_dict[cam_id]
308
+ except KeyError:
309
+ print(f"{self.camera_dict} -- {camera_type_dict}")
310
+ raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") # noqa: B904
311
+ curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
312
+ self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)
313
+
314
+ timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
315
+ if index is not None:
316
+ self.camera_dict[cam_id].set_frame_index(index)
317
+
318
+ data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)
319
+
320
+ # Process Returned Data #
321
+ if data_dict is None:
322
+ return None
323
+ for key in data_dict:
324
+ full_obs_dict[key].update(data_dict[key])
325
+
326
+ return full_obs_dict
327
+
328
+
329
+ def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006
330
+ length = None
331
+
332
+ for key in hdf5_file:
333
+ if key in keys_to_ignore:
334
+ continue
335
+
336
+ curr_data = hdf5_file[key]
337
+ if isinstance(curr_data, h5py.Group):
338
+ curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
339
+ elif isinstance(curr_data, h5py.Dataset):
340
+ curr_length = len(curr_data)
341
+ else:
342
+ raise ValueError
343
+
344
+ if length is None:
345
+ length = curr_length
346
+ assert curr_length == length
347
+
348
+ return length
349
+
350
+
351
+ def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006
352
+ data_dict = {}
353
+
354
+ for key in hdf5_file:
355
+ if key in keys_to_ignore:
356
+ continue
357
+
358
+ curr_data = hdf5_file[key]
359
+ if isinstance(curr_data, h5py.Group):
360
+ data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
361
+ elif isinstance(curr_data, h5py.Dataset):
362
+ data_dict[key] = curr_data[index]
363
+ else:
364
+ raise ValueError
365
+
366
+ return data_dict
367
+
368
+
369
+ class TrajectoryReader:
370
+ def __init__(self, filepath, read_images=True): # noqa: FBT002
371
+ self._hdf5_file = h5py.File(filepath, "r")
372
+ is_video_folder = "observations/videos" in self._hdf5_file
373
+ self._read_images = read_images and is_video_folder
374
+ self._length = get_hdf5_length(self._hdf5_file)
375
+ self._video_readers = {}
376
+ self._index = 0
377
+
378
+ def length(self):
379
+ return self._length
380
+
381
+ def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006
382
+ # Make Sure We Read Within Range #
383
+ if index is None:
384
+ index = self._index
385
+ else:
386
+ assert not self._read_images
387
+ self._index = index
388
+ assert index < self._length
389
+
390
+ # Load Low Dimensional Data #
391
+ keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
392
+ timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)
393
+
394
+ # Increment Read Index #
395
+ self._index += 1
396
+
397
+ # Return Timestep #
398
+ return timestep
399
+
400
+ def close(self):
401
+ self._hdf5_file.close()
402
+
403
+
404
+ def load_trajectory(
405
+ filepath=None,
406
+ read_cameras=True, # noqa: FBT002
407
+ recording_folderpath=None,
408
+ camera_kwargs={}, # noqa: B006
409
+ remove_skipped_steps=False, # noqa: FBT002
410
+ num_samples_per_traj=None,
411
+ num_samples_per_traj_coeff=1.5,
412
+ ):
413
+ read_recording_folderpath = read_cameras and (recording_folderpath is not None)
414
+
415
+ traj_reader = TrajectoryReader(filepath)
416
+ if read_recording_folderpath:
417
+ camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)
418
+
419
+ horizon = traj_reader.length()
420
+ timestep_list = []
421
+
422
+ # Choose Timesteps To Save #
423
+ if num_samples_per_traj:
424
+ num_to_save = num_samples_per_traj
425
+ if remove_skipped_steps:
426
+ num_to_save = int(num_to_save * num_samples_per_traj_coeff)
427
+ max_size = min(num_to_save, horizon)
428
+ indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
429
+ else:
430
+ indices_to_save = np.arange(horizon)
431
+
432
+ # Iterate Over Trajectory #
433
+ for i in indices_to_save:
434
+ # Get HDF5 Data #
435
+ timestep = traj_reader.read_timestep(index=i)
436
+
437
+ # If Applicable, Get Recorded Data #
438
+ if read_recording_folderpath:
439
+ timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
440
+ camera_type_dict = {
441
+ k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
442
+ }
443
+ camera_obs = camera_reader.read_cameras(
444
+ index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
445
+ )
446
+ camera_failed = camera_obs is None
447
+
448
+ # Add Data To Timestep If Successful #
449
+ if camera_failed:
450
+ break
451
+ timestep["observation"].update(camera_obs)
452
+
453
+ # Filter Steps #
454
+ step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
455
+ delete_skipped_step = step_skipped and remove_skipped_steps
456
+
457
+ # Save Filtered Timesteps #
458
+ if delete_skipped_step:
459
+ del timestep
460
+ else:
461
+ timestep_list.append(timestep)
462
+
463
+ # Remove Extra Transitions #
464
+ timestep_list = np.array(timestep_list)
465
+ if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
466
+ ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
467
+ timestep_list = timestep_list[ind_to_keep]
468
+
469
+ # Close Readers #
470
+ traj_reader.close()
471
+
472
+ # Return Data #
473
+ return timestep_list
474
+
475
+
476
+ if __name__ == "__main__":
477
+ tyro.cli(main)
capvector-pi05/examples/droid/main.py ADDED
@@ -0,0 +1,246 @@
1
+ # ruff: noqa
2
+
3
+ import contextlib
4
+ import dataclasses
5
+ import datetime
6
+ import faulthandler
7
+ import os
8
+ import signal
9
+ import time
10
+ from moviepy.editor import ImageSequenceClip
11
+ import numpy as np
12
+ from openpi_client import image_tools
13
+ from openpi_client import websocket_client_policy
14
+ import pandas as pd
15
+ from PIL import Image
16
+ from droid.robot_env import RobotEnv
17
+ import tqdm
18
+ import tyro
19
+
20
+ faulthandler.enable()
21
+
22
+ # DROID data collection frequency -- we slow down execution to match this frequency
23
+ DROID_CONTROL_FREQUENCY = 15
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class Args:
28
+ # Hardware parameters
29
+ left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
30
+ right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
31
+ wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
32
+
33
+ # Policy parameters
34
+ external_camera: str | None = (
35
+ None # which external camera should be fed to the policy, choose from ["left", "right"]
36
+ )
37
+
38
+ # Rollout parameters
39
+ max_timesteps: int = 600
40
+ # How many actions to execute from a predicted action chunk before querying policy server again
41
+ # 8 is usually a good default (equals 0.5 seconds of action execution).
42
+ open_loop_horizon: int = 8
43
+
44
+ # Remote server parameters
45
+ remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
46
+ remote_port: int = (
47
+ 8000 # point this to the port of the policy server, default server port for openpi servers is 8000
48
+ )
49
+
50
+
51
+ # We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
52
+ # waiting for a new action chunk, it will raise an exception and the server connection dies.
53
+ # This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
54
+ @contextlib.contextmanager
55
+ def prevent_keyboard_interrupt():
56
+ """Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
57
+ interrupted = False
58
+ original_handler = signal.getsignal(signal.SIGINT)
59
+
60
+ def handler(signum, frame):
61
+ nonlocal interrupted
62
+ interrupted = True
63
+
64
+ signal.signal(signal.SIGINT, handler)
65
+ try:
66
+ yield
67
+ finally:
68
+ signal.signal(signal.SIGINT, original_handler)
69
+ if interrupted:
70
+ raise KeyboardInterrupt
71
+
72
+
73
+ def main(args: Args):
74
+ # Make sure external camera is specified by user -- we only use one external camera for the policy
75
+ assert (
76
+ args.external_camera is not None and args.external_camera in ["left", "right"]
77
+ ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
78
+
79
+ # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
80
+ env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
81
+ print("Created the droid env!")
82
+
83
+ # Connect to the policy server
84
+ policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
85
+
86
+ df = pd.DataFrame(columns=["success", "duration", "video_filename"])
87
+
88
+ while True:
89
+ instruction = input("Enter instruction: ")
90
+
91
+ # Rollout parameters
92
+ actions_from_chunk_completed = 0
93
+ pred_action_chunk = None
94
+
95
+ # Prepare to save video of rollout
96
+ timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
97
+ video = []
98
+ bar = tqdm.tqdm(range(args.max_timesteps))
99
+ print("Running rollout... press Ctrl+C to stop early.")
100
+ for t_step in bar:
101
+ start_time = time.time()
102
+ try:
103
+ # Get the current observation
104
+ curr_obs = _extract_observation(
105
+ args,
106
+ env.get_observation(),
107
+ # Save the first observation to disk
108
+ save_to_disk=t_step == 0,
109
+ )
110
+
111
+ video.append(curr_obs[f"{args.external_camera}_image"])
112
+
113
+ # Send websocket request to policy server if it's time to predict a new chunk
114
+ if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
115
+ actions_from_chunk_completed = 0
116
+
117
+ # We resize images on the robot laptop to minimize the amount of data sent to the policy server
118
+ # and improve latency.
119
+ request_data = {
120
+ "observation/exterior_image_1_left": image_tools.resize_with_pad(
121
+ curr_obs[f"{args.external_camera}_image"], 224, 224
122
+ ),
123
+ "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
124
+ "observation/joint_position": curr_obs["joint_position"],
125
+ "observation/gripper_position": curr_obs["gripper_position"],
126
+ "prompt": instruction,
127
+ }
128
+
129
+ # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
130
+ # Ctrl+C will be handled after the server call is complete
131
+ with prevent_keyboard_interrupt():
132
+ # this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
133
+ pred_action_chunk = policy_client.infer(request_data)["actions"]
134
+ assert pred_action_chunk.shape == (10, 8)
135
+
136
+ # Select current action to execute from chunk
137
+ action = pred_action_chunk[actions_from_chunk_completed]
138
+ actions_from_chunk_completed += 1
139
+
140
+ # Binarize gripper action
141
+ if action[-1].item() > 0.5:
142
+ # action[-1] = 1.0
143
+ action = np.concatenate([action[:-1], np.ones((1,))])
144
+ else:
145
+ # action[-1] = 0.0
146
+ action = np.concatenate([action[:-1], np.zeros((1,))])
147
+
148
+ # clip all dimensions of action to [-1, 1]
149
+ action = np.clip(action, -1, 1)
150
+
151
+ env.step(action)
152
+
153
+ # Sleep to match DROID data collection frequency
154
+ elapsed_time = time.time() - start_time
155
+ if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
156
+ time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
157
+ except KeyboardInterrupt:
158
+ break
159
+
160
+ video = np.stack(video)
161
+ save_filename = "video_" + timestamp
162
+ ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
163
+
164
+ success: float | None = None
+ while not isinstance(success, float):
+     user_input = input(
+         "Did the rollout succeed? (enter y for 100%, n for 0%, or a numeric value 0-100 based on the evaluation spec) "
+     )
+     if user_input == "y":
+         success = 1.0
+     elif user_input == "n":
+         success = 0.0
+     else:
+         # Interpret anything else as a percentage in [0, 100] and convert to [0, 1].
+         try:
+             success = float(user_input) / 100
+         except ValueError:
+             print(f"Success must be y, n, or a number in [0, 100], but got: {user_input!r}")
+             continue
+         if not (0 <= success <= 1):
+             print(f"Success must be a number in [0, 100] but got: {user_input}")
+             success = None
177
+
178
+ # DataFrame.append was removed in pandas 2.0; build the new row and concatenate instead.
+ new_row = pd.DataFrame([{"success": success, "duration": t_step, "video_filename": save_filename}])
+ df = pd.concat([df, new_row], ignore_index=True)
186
+
187
+ if input("Do one more eval? (enter y or n) ").lower() != "y":
188
+ break
189
+ env.reset()
190
+
191
+ os.makedirs("results", exist_ok=True)
192
+ timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
193
+ csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
194
+ df.to_csv(csv_filename)
195
+ print(f"Results saved to {csv_filename}")
196
+
197
+
198
+ def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
199
+ image_observations = obs_dict["image"]
200
+ left_image, right_image, wrist_image = None, None, None
201
+ for key in image_observations:
202
+ # Note the "left" below refers to the left camera in the stereo pair.
203
+ # The model is only trained on left stereo cams, so we only feed those.
204
+ if args.left_camera_id in key and "left" in key:
205
+ left_image = image_observations[key]
206
+ elif args.right_camera_id in key and "left" in key:
207
+ right_image = image_observations[key]
208
+ elif args.wrist_camera_id in key and "left" in key:
209
+ wrist_image = image_observations[key]
210
+
211
+ # Drop the alpha channel
212
+ left_image = left_image[..., :3]
213
+ right_image = right_image[..., :3]
214
+ wrist_image = wrist_image[..., :3]
215
+
216
+ # Convert from BGR to RGB by reversing the channel order
217
+ left_image = left_image[..., ::-1]
218
+ right_image = right_image[..., ::-1]
219
+ wrist_image = wrist_image[..., ::-1]
220
+
221
+ # In addition to image observations, also capture the proprioceptive state
222
+ robot_state = obs_dict["robot_state"]
223
+ cartesian_position = np.array(robot_state["cartesian_position"])
224
+ joint_position = np.array(robot_state["joint_positions"])
225
+ gripper_position = np.array([robot_state["gripper_position"]])
226
+
227
+ # Save the images to disk so that they can be viewed live while the robot is running
228
+ # Create one combined image to make live viewing easy
229
+ if save_to_disk:
230
+ combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
231
+ combined_image = Image.fromarray(combined_image)
232
+ combined_image.save("robot_camera_views.png")
233
+
234
+ return {
235
+ "left_image": left_image,
236
+ "right_image": right_image,
237
+ "wrist_image": wrist_image,
238
+ "cartesian_position": cartesian_position,
239
+ "joint_position": joint_position,
240
+ "gripper_position": gripper_position,
241
+ }
242
+
243
+
244
+ if __name__ == "__main__":
245
+ args: Args = tyro.cli(Args)
246
+ main(args)
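
The rollout loop above uses a receding-horizon pattern: the policy server returns a chunk of 10 actions, and a new chunk is requested only after `open_loop_horizon` of them have been executed (any remaining actions in the old chunk are dropped). Below is a minimal, self-contained sketch of that pattern; the stub policy and environment, and the `OPEN_LOOP_HORIZON` value, are placeholders for illustration, not part of this repo (the real script reads the horizon from `Args`).

```python
import numpy as np

CHUNK_LEN, ACTION_DIM = 10, 8   # matches the [10, 8] chunks asserted above
OPEN_LOOP_HORIZON = 8           # placeholder value


def stub_policy_infer() -> np.ndarray:
    """Stand-in for policy_client.infer(request)["actions"]."""
    return np.random.uniform(-1.0, 1.0, size=(CHUNK_LEN, ACTION_DIM))


def stub_env_step(action: np.ndarray) -> None:
    """Stand-in for env.step(action)."""
    assert action.shape == (ACTION_DIM,)


actions_from_chunk_completed = 0
pred_action_chunk = None
for t_step in range(40):
    # Request a new chunk only when the previous one has been consumed.
    if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= OPEN_LOOP_HORIZON:
        actions_from_chunk_completed = 0
        pred_action_chunk = stub_policy_infer()
    action = pred_action_chunk[actions_from_chunk_completed]
    actions_from_chunk_completed += 1
    stub_env_step(np.clip(action, -1.0, 1.0))
```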
capvector-pi05/examples/inference.ipynb ADDED
@@ -0,0 +1,137 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import dataclasses\n",
10
+ "\n",
11
+ "import jax\n",
12
+ "\n",
13
+ "from openpi.models import model as _model\n",
14
+ "from openpi.policies import droid_policy\n",
15
+ "from openpi.policies import policy_config as _policy_config\n",
16
+ "from openpi.shared import download\n",
17
+ "from openpi.training import config as _config\n",
18
+ "from openpi.training import data_loader as _data_loader"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "# Policy inference\n",
26
+ "\n",
27
+ "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "config = _config.get_config(\"pi0_fast_droid\")\n",
37
+ "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
38
+ "\n",
39
+ "# Create a trained policy.\n",
40
+ "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
41
+ "\n",
42
+ "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
43
+ "example = droid_policy.make_droid_example()\n",
44
+ "result = policy.infer(example)\n",
45
+ "\n",
46
+ "# Delete the policy to free up memory.\n",
47
+ "del policy\n",
48
+ "\n",
49
+ "print(\"Actions shape:\", result[\"actions\"].shape)"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {},
55
+ "source": [
56
+ "# Working with a live model\n",
57
+ "\n",
58
+ "\n",
59
+ "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "config = _config.get_config(\"pi0_aloha_sim\")\n",
69
+ "\n",
70
+ "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
71
+ "key = jax.random.key(0)\n",
72
+ "\n",
73
+ "# Create a model from the checkpoint.\n",
74
+ "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
75
+ "\n",
76
+ "# We can create fake observations and actions to test the model.\n",
77
+ "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
78
+ "\n",
79
+ "# Sample actions from the model.\n",
80
+ "loss = model.compute_loss(key, obs, act)\n",
81
+ "print(\"Loss shape:\", loss.shape)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {},
87
+ "source": [
88
+ "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# Reduce the batch size to reduce memory usage.\n",
98
+ "config = dataclasses.replace(config, batch_size=2)\n",
99
+ "\n",
100
+ "# Load a single batch of data. This is the same data that will be used during training.\n",
101
+ "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
102
+ "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
103
+ "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
104
+ "obs, act = next(iter(loader))\n",
105
+ "\n",
106
+ "# Sample actions from the model.\n",
107
+ "loss = model.compute_loss(key, obs, act)\n",
108
+ "\n",
109
+ "# Delete the model to free up memory.\n",
110
+ "del model\n",
111
+ "\n",
112
+ "print(\"Loss shape:\", loss.shape)"
113
+ ]
114
+ }
115
+ ],
116
+ "metadata": {
117
+ "kernelspec": {
118
+ "display_name": ".venv",
119
+ "language": "python",
120
+ "name": "python3"
121
+ },
122
+ "language_info": {
123
+ "codemirror_mode": {
124
+ "name": "ipython",
125
+ "version": 3
126
+ },
127
+ "file_extension": ".py",
128
+ "mimetype": "text/x-python",
129
+ "name": "python",
130
+ "nbconvert_exporter": "python",
131
+ "pygments_lexer": "ipython3",
132
+ "version": "3.11.9"
133
+ }
134
+ },
135
+ "nbformat": 4,
136
+ "nbformat_minor": 2
137
+ }
capvector-pi05/examples/libero/compose.yml ADDED
@@ -0,0 +1,54 @@
1
+ # Run with:
2
+ # docker compose -f examples/libero/compose.yml up --build
3
+ services:
4
+ runtime:
5
+ image: libero
6
+ depends_on:
7
+ - openpi_server
8
+ build:
9
+ context: ../..
10
+ dockerfile: examples/libero/Dockerfile
11
+ init: true
12
+ tty: true
13
+ network_mode: host
14
+ privileged: true
15
+ volumes:
16
+ - $PWD:/app
17
+ - ../../data:/data
18
+ - /tmp/.X11-unix:/tmp/.X11-unix:ro
19
+ environment:
20
+ - CLIENT_ARGS
21
+ - DISPLAY=$DISPLAY
22
+ - MUJOCO_GL=${MUJOCO_GL:-egl}
23
+ deploy:
24
+ resources:
25
+ reservations:
26
+ devices:
27
+ - driver: nvidia
28
+ count: 1
29
+ capabilities: [gpu]
30
+
31
+ openpi_server:
32
+ image: openpi_server
33
+ build:
34
+ context: ../..
35
+ dockerfile: scripts/docker/serve_policy.Dockerfile
36
+ init: true
37
+ tty: true
38
+ network_mode: host
39
+ volumes:
40
+ - $PWD:/app
41
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
42
+ environment:
43
+ - SERVER_ARGS
44
+ - OPENPI_DATA_HOME=/openpi_assets
45
+ - IS_DOCKER=true
46
+
47
+ # Comment out this block if not running on a machine with GPUs.
48
+ deploy:
49
+ resources:
50
+ reservations:
51
+ devices:
52
+ - driver: nvidia
53
+ count: 1
54
+ capabilities: [gpu]
capvector-pi05/examples/libero/convert_libero_data_to_lerobot.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ Minimal example script for converting a dataset to LeRobot format.
3
+
4
+ We use the Libero dataset (stored in RLDS) for this example, but it can be easily
5
+ modified for any other data you have saved in a custom format.
6
+
7
+ Usage:
8
+ uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
9
+
10
+ If you want to push your dataset to the Hugging Face Hub, you can use the following command:
11
+ uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
12
+
13
+ Note: to run the script, you need to install tensorflow_datasets:
14
+ `uv pip install tensorflow tensorflow_datasets`
15
+
16
+ You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
17
+ The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
18
+ Running this conversion script will take approximately 30 minutes.
19
+ """
20
+
21
+ import shutil
22
+
23
+ from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
24
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
25
+ import tensorflow_datasets as tfds
26
+ import tyro
27
+
28
+ REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
29
+ RAW_DATASET_NAMES = [
30
+ "libero_10_no_noops",
31
+ "libero_goal_no_noops",
32
+ "libero_object_no_noops",
33
+ "libero_spatial_no_noops",
34
+ ] # For simplicity we will combine multiple Libero datasets into one training dataset
35
+
36
+
37
+ def main(data_dir: str, *, push_to_hub: bool = False):
38
+ # Clean up any existing dataset in the output directory
39
+ output_path = HF_LEROBOT_HOME / REPO_NAME
40
+ if output_path.exists():
41
+ shutil.rmtree(output_path)
42
+
43
+ # Create LeRobot dataset, define features to store
44
+ # OpenPi assumes that proprio is stored in `state` and actions in `actions` (matching the feature keys below)
45
+ # LeRobot assumes that dtype of image data is `image`
46
+ dataset = LeRobotDataset.create(
47
+ repo_id=REPO_NAME,
48
+ robot_type="panda",
49
+ fps=10,
50
+ features={
51
+ "image": {
52
+ "dtype": "image",
53
+ "shape": (256, 256, 3),
54
+ "names": ["height", "width", "channel"],
55
+ },
56
+ "wrist_image": {
57
+ "dtype": "image",
58
+ "shape": (256, 256, 3),
59
+ "names": ["height", "width", "channel"],
60
+ },
61
+ "state": {
62
+ "dtype": "float32",
63
+ "shape": (8,),
64
+ "names": ["state"],
65
+ },
66
+ "actions": {
67
+ "dtype": "float32",
68
+ "shape": (7,),
69
+ "names": ["actions"],
70
+ },
71
+ },
72
+ image_writer_threads=10,
73
+ image_writer_processes=5,
74
+ )
75
+
76
+ # Loop over raw Libero datasets and write episodes to the LeRobot dataset
77
+ # You can modify this for your own data format
78
+ for raw_dataset_name in RAW_DATASET_NAMES:
79
+ raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
80
+ for episode in raw_dataset:
81
+ for step in episode["steps"].as_numpy_iterator():
82
+ dataset.add_frame(
83
+ {
84
+ "image": step["observation"]["image"],
85
+ "wrist_image": step["observation"]["wrist_image"],
86
+ "state": step["observation"]["state"],
87
+ "actions": step["action"],
88
+ "task": step["language_instruction"].decode(),
89
+ }
90
+ )
91
+ dataset.save_episode()
92
+
93
+ # Optionally push to the Hugging Face Hub
94
+ if push_to_hub:
95
+ dataset.push_to_hub(
96
+ tags=["libero", "panda", "rlds"],
97
+ private=False,
98
+ push_videos=True,
99
+ license="apache-2.0",
100
+ )
101
+
102
+
103
+ if __name__ == "__main__":
104
+ tyro.cli(main)
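
For reference, each `dataset.add_frame(...)` call above passes one timestep whose entries match the `features` spec declared at dataset creation. A synthetic sketch of such a frame is shown below; the values and the instruction string are dummy data for illustration, not real Libero samples.

```python
import numpy as np

frame = {
    "image": np.zeros((256, 256, 3), dtype=np.uint8),        # third-person camera, HWC
    "wrist_image": np.zeros((256, 256, 3), dtype=np.uint8),  # wrist camera, HWC
    "state": np.zeros(8, dtype=np.float32),                   # proprioceptive state
    "actions": np.zeros(7, dtype=np.float32),                  # 7-dim action
    "task": "put the bowl on the plate",                       # language instruction
}
```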
capvector-pi05/examples/policy_records.ipynb ADDED
@@ -0,0 +1,134 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import pathlib\n",
10
+ "\n",
11
+ "import numpy as np\n",
12
+ "\n",
13
+ "record_path = pathlib.Path(\"../policy_records\")\n",
14
+ "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
15
+ "\n",
16
+ "records = []\n",
17
+ "for i in range(num_steps):\n",
18
+ " record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
19
+ " records.append(record)"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "print(\"length of records\", len(records))\n",
29
+ "print(\"keys in records\", records[0].keys())\n",
30
+ "\n",
31
+ "for k in records[0]:\n",
32
+ " print(f\"{k} shape: {records[0][k].shape}\")"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "from PIL import Image\n",
42
+ "\n",
43
+ "\n",
44
+ "def get_image(step: int, idx: int = 0):\n",
45
+ " img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
46
+ " return img[idx].transpose(1, 2, 0)\n",
47
+ "\n",
48
+ "\n",
49
+ "def show_image(step: int, idx_lst: list[int]):\n",
50
+ " imgs = [get_image(step, idx) for idx in idx_lst]\n",
51
+ " return Image.fromarray(np.hstack(imgs))\n",
52
+ "\n",
53
+ "\n",
54
+ "for i in range(2):\n",
55
+ " display(show_image(i, [0]))"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 14,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "import pandas as pd\n",
65
+ "\n",
66
+ "\n",
67
+ "def get_axis(name, axis):\n",
68
+ " return np.array([record[name][axis] for record in records])\n",
69
+ "\n",
70
+ "\n",
71
+ "# qpos is [..., 14] of type float:\n",
72
+ "# 0-5: left arm joint angles\n",
73
+ "# 6: left arm gripper\n",
74
+ "# 7-12: right arm joint angles\n",
75
+ "# 13: right arm gripper\n",
76
+ "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
77
+ "\n",
78
+ "\n",
79
+ "def make_data():\n",
80
+ " cur_dim = 0\n",
81
+ " in_data = {}\n",
82
+ " out_data = {}\n",
83
+ " for name, dim_size in names:\n",
84
+ " for i in range(dim_size):\n",
85
+ " in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
86
+ " out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
87
+ " cur_dim += 1\n",
88
+ " return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
89
+ "\n",
90
+ "\n",
91
+ "in_data, out_data = make_data()"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "for name in in_data.columns:\n",
101
+ " data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
102
+ " data.plot()"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": []
111
+ }
112
+ ],
113
+ "metadata": {
114
+ "kernelspec": {
115
+ "display_name": ".venv",
116
+ "language": "python",
117
+ "name": "python3"
118
+ },
119
+ "language_info": {
120
+ "codemirror_mode": {
121
+ "name": "ipython",
122
+ "version": 3
123
+ },
124
+ "file_extension": ".py",
125
+ "mimetype": "text/x-python",
126
+ "name": "python",
127
+ "nbconvert_exporter": "python",
128
+ "pygments_lexer": "ipython3",
129
+ "version": "3.11.9"
130
+ }
131
+ },
132
+ "nbformat": 4,
133
+ "nbformat_minor": 2
134
+ }
capvector-pi05/pyproject.toml ADDED
@@ -0,0 +1,142 @@
1
+ [project]
2
+ name = "openpi"
3
+ version = "0.1.0"
4
+ description = "Physical Intelligence open source repo"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ license = { file = "LICENSE" }
8
+ dependencies = [
9
+ "augmax>=0.3.4",
10
+ "dm-tree>=0.1.8",
11
+ "einops>=0.8.0",
12
+ "equinox>=0.11.8",
13
+ "flatbuffers>=24.3.25",
14
+ "flax==0.10.2",
15
+ "fsspec[gcs]>=2024.6.0",
16
+ "gym-aloha>=0.1.1",
17
+ "imageio>=2.36.1",
18
+ "jax[cuda12]==0.5.3",
19
+ "jaxtyping==0.2.36",
20
+ "lerobot",
21
+ "ml_collections==1.0.0",
22
+ "numpy>=1.22.4,<2.0.0",
23
+ "numpydantic>=1.6.6",
24
+ "opencv-python>=4.10.0.84",
25
+ "openpi-client",
26
+ "orbax-checkpoint==0.11.13",
27
+ "pillow>=11.0.0",
28
+ "sentencepiece>=0.2.0",
29
+ "torch==2.7.1",
30
+ "tqdm-loggable>=0.2",
31
+ "typing-extensions>=4.12.2",
32
+ "tyro>=0.9.5",
33
+ "wandb>=0.19.1",
34
+ "filelock>=3.16.1",
35
+ "beartype==0.19.0",
36
+ "treescope>=0.1.7",
37
+ "transformers==4.53.2",
38
+ "rich>=14.0.0",
39
+ "polars>=1.30.0",
40
+ "gradio==5.17.1",
41
+ "viser==0.2.23",
42
+ "hydra-core",
43
+ "onnxruntime",
44
+ "safetensors",
45
+ ]
46
+
47
+
48
+ [project.urls]
49
+ Repository = "https://github.com/Physical-Intelligence/openpi"
50
+
51
+ [dependency-groups]
52
+ dev = [
53
+ "pytest>=8.3.4",
54
+ "ruff>=0.8.6",
55
+ "pre-commit>=4.0.1",
56
+ "ipykernel>=6.29.5",
57
+ "ipywidgets>=8.1.5",
58
+ "matplotlib>=3.10.0",
59
+ "pynvml>=12.0.0",
60
+ ]
61
+ rlds = [
62
+ "dlimp",
63
+ "tensorflow-cpu==2.15.0",
64
+ "tensorflow-datasets==4.9.9",
65
+ ]
66
+
67
+ [tool.uv]
68
+ override-dependencies = ["datasets==3.6.0", "ml-dtypes==0.4.1", "tensorstore==0.1.74"]
69
+
70
+ [tool.uv.sources]
71
+ openpi-client = { workspace = true }
72
+ lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
73
+ dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }
74
+
75
+ [tool.uv.workspace]
76
+ members = ["packages/*", "src/vggt"]
77
+
78
+ [tool.ruff]
79
+ line-length = 120
80
+ target-version = "py311"
81
+ extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]
82
+
83
+ [tool.ruff.lint]
84
+ # https://docs.astral.sh/ruff/rules/
85
+ select = [
86
+ "B",
87
+ "C4",
88
+ "DTZ",
89
+ "E4",
90
+ "E7",
91
+ "E9",
92
+ "F",
93
+ "FBT",
94
+ "FURB",
95
+ "I",
96
+ "ICN",
97
+ "ISC",
98
+ "LOG",
99
+ "N",
100
+ "PD",
101
+ "PERF",
102
+ "PIE",
103
+ "PLC",
104
+ "PLE",
105
+ "PLR1",
106
+ "PLR5",
107
+ "PLW",
108
+ "PT",
109
+ "Q",
110
+ "RET",
111
+ "RUF",
112
+ "SIM",
113
+ "SLF",
114
+ "T10",
115
+ "T20",
116
+ "UP",
117
+ "W",
118
+ ]
119
+ ignore = [
120
+ "F722", # Conflicts with array typing.
121
+ "T201", # We use print statements.
122
+ "PD008", # Lots of false positives.
123
+ "ISC001", # Disabling to support ruff format.
124
+ "LOG015", # Use logger.info.
125
+ ]
126
+ unfixable = [
127
+ "B905", # Fix defaults to strict=False, which is not what we want.
128
+ ]
129
+
130
+ [tool.ruff.lint.isort]
131
+ force-single-line = true
132
+ force-sort-within-sections = true
133
+ single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
134
+ known-third-party = ["wandb"]
135
+
136
+ [build-system]
137
+ requires = ["hatchling"]
138
+ build-backend = "hatchling.build"
139
+
140
+ [tool.pytest.ini_options]
141
+ markers = ["manual: should be run manually."]
142
+ testpaths = ["src", "scripts", "packages"]