iMihayo committed on
Commit e188f3d (verified)
1 Parent(s): dfa767d

Add files using upload-large-folder tool

Files changed (41)
  1. prismatic/__init__.py +1 -0
  2. prismatic/conf/datasets.py +133 -0
  3. prismatic/conf/vla.py +235 -0
  4. prismatic/models/backbones/__init__.py +0 -0
  5. prismatic/models/backbones/vision/dinov2_vit.py +19 -0
  6. prismatic/models/load.py +226 -0
  7. prismatic/models/materialize.py +130 -0
  8. prismatic/models/projectors.py +67 -0
  9. prismatic/preprocessing/datasets/datasets.py +200 -0
  10. prismatic/preprocessing/materialize.py +69 -0
  11. prismatic/py.typed +0 -0
  12. prismatic/util/nn_utils.py +53 -0
  13. prismatic/vla/datasets/rlds/__init__.py +1 -0
  14. prismatic/vla/datasets/rlds/dataset.py +655 -0
  15. prismatic/vla/datasets/rlds/oxe/transforms.py +951 -0
  16. prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py +178 -0
  17. prismatic/vla/datasets/rlds/utils/__init__.py +0 -0
  18. prismatic/vla/datasets/rlds/utils/goal_relabeling.py +32 -0
  19. prismatic/vla/materialize.py +56 -0
  20. run_scripts/ac/ac.sh +87 -0
  21. run_scripts/ffn/3ffn2.sh +87 -0
  22. run_scripts/ffn/3postffn2.sh +87 -0
  23. run_scripts/ffn/3postffn6.sh +87 -0
  24. run_scripts/ffn/debug_5ffn_withactionprojector.sh +87 -0
  25. run_scripts/ffn/ffn4.sh +87 -0
  26. run_scripts/ffn/ffn8.sh +87 -0
  27. run_scripts/ffn/test.sh +87 -0
  28. run_scripts/ffn_long_chunks/run.sh +4 -0
  29. run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_base.sh +88 -0
  30. run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_50_l2.sh +102 -0
  31. run_scripts/ffn_q2a/bridge/exffn_relu_connector_linear_relu.sh +95 -0
  32. run_scripts/ffn_q2a/bridge/run_bridge.sh +2 -0
  33. run_scripts/ffn_q2a/franka/exffn_gelu_franka.sh +95 -0
  34. run_scripts/ffn_q2a/libero_moe/debug_moe_lit.sh +101 -0
  35. run_scripts/ffn_q2a/simhead/simhead_contrastive.sh +100 -0
  36. run_scripts/pp/pp.sh +87 -0
  37. run_scripts/run.sh +35 -0
  38. scripts/extern/verify_prismatic.py +134 -0
  39. scripts/pretrain.py +238 -0
  40. test_deepseek_moe.py +246 -0
  41. vla-scripts/extern/convert_openvla_weights_to_hf.py +272 -0
prismatic/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .models import available_model_names, available_models, get_model_description, load
prismatic/conf/datasets.py ADDED
@@ -0,0 +1,133 @@
1
+ """
2
+ datasets.py
3
+
4
+ Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant
5
+ and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes:
6
+ - Dataset Variant (Identifier) --> e.g., "llava-v15"
7
+ - Align Stage Dataset Components (annotations, images)
8
+ - Finetune Stage Dataset Components (annotations, images)
9
+ - Dataset Root Directory (Path)
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from enum import Enum, unique
14
+ from pathlib import Path
15
+ from typing import Tuple
16
+
17
+ from draccus import ChoiceRegistry
18
+
19
+
20
+ @dataclass
21
+ class DatasetConfig(ChoiceRegistry):
22
+ # fmt: off
23
+ dataset_id: str # Unique ID that fully specifies a dataset variant
24
+
25
+ # Dataset Components for each Stage in < align | finetune >
26
+ align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage
27
+ finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage
28
+
29
+ dataset_root_dir: Path # Path to dataset root directory; other paths are relative to root
30
+ # fmt: on
31
+
32
+
33
+ # [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models)
34
+ @dataclass
35
+ class LLaVa_V15_Config(DatasetConfig):
36
+ dataset_id: str = "llava-v15"
37
+
38
+ align_stage_components: Tuple[Path, Path] = (
39
+ Path("download/llava-laion-cc-sbu-558k/chat.json"),
40
+ Path("download/llava-laion-cc-sbu-558k/"),
41
+ )
42
+ finetune_stage_components: Tuple[Path, Path] = (
43
+ Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"),
44
+ Path("download/llava-v1.5-instruct/"),
45
+ )
46
+ dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
47
+
48
+
49
+ # [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training)
50
+ @dataclass
51
+ class LLaVa_Multimodal_Only_Config(DatasetConfig):
52
+ dataset_id: str = "llava-multimodal"
53
+
54
+ align_stage_components: Tuple[Path, Path] = (
55
+ Path("download/llava-laion-cc-sbu-558k/chat.json"),
56
+ Path("download/llava-laion-cc-sbu-558k/"),
57
+ )
58
+ finetune_stage_components: Tuple[Path, Path] = (
59
+ Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"),
60
+ Path("download/llava-v1.5-instruct/"),
61
+ )
62
+ dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
63
+
64
+
65
+ # LLaVa-v15 + LVIS-Instruct-4V
66
+ @dataclass
67
+ class LLaVa_LVIS4V_Config(DatasetConfig):
68
+ dataset_id: str = "llava-lvis4v"
69
+
70
+ align_stage_components: Tuple[Path, Path] = (
71
+ Path("download/llava-laion-cc-sbu-558k/chat.json"),
72
+ Path("download/llava-laion-cc-sbu-558k/"),
73
+ )
74
+ finetune_stage_components: Tuple[Path, Path] = (
75
+ Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"),
76
+ Path("download/llava-v1.5-instruct/"),
77
+ )
78
+ dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
79
+
80
+
81
+ # LLaVa-v15 + LRV-Instruct
82
+ @dataclass
83
+ class LLaVa_LRV_Config(DatasetConfig):
84
+ dataset_id: str = "llava-lrv"
85
+
86
+ align_stage_components: Tuple[Path, Path] = (
87
+ Path("download/llava-laion-cc-sbu-558k/chat.json"),
88
+ Path("download/llava-laion-cc-sbu-558k/"),
89
+ )
90
+ finetune_stage_components: Tuple[Path, Path] = (
91
+ Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"),
92
+ Path("download/llava-v1.5-instruct/"),
93
+ )
94
+ dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
95
+
96
+
97
+ # LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct
98
+ @dataclass
99
+ class LLaVa_LVIS4V_LRV_Config(DatasetConfig):
100
+ dataset_id: str = "llava-lvis4v-lrv"
101
+
102
+ align_stage_components: Tuple[Path, Path] = (
103
+ Path("download/llava-laion-cc-sbu-558k/chat.json"),
104
+ Path("download/llava-laion-cc-sbu-558k/"),
105
+ )
106
+ finetune_stage_components: Tuple[Path, Path] = (
107
+ Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"),
108
+ Path("download/llava-v1.5-instruct/"),
109
+ )
110
+ dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")
111
+
112
+
113
+ # === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! ===
114
+ @unique
115
+ class DatasetRegistry(Enum):
116
+ # === LLaVa v1.5 ===
117
+ LLAVA_V15 = LLaVa_V15_Config
118
+
119
+ LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config
120
+
121
+ LLAVA_LVIS4V = LLaVa_LVIS4V_Config
122
+ LLAVA_LRV = LLaVa_LRV_Config
123
+
124
+ LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config
125
+
126
+ @property
127
+ def dataset_id(self) -> str:
128
+ return self.value.dataset_id
129
+
130
+
131
+ # Register Datasets in Choice Registry
132
+ for dataset_variant in DatasetRegistry:
133
+ DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value)
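For orientation, a minimal usage sketch (illustrative, not part of this commit) of how a registered dataset variant is resolved back into its config class via draccus's ChoiceRegistry interface, mirroring the `get_choice_class` call used in `prismatic/models/load.py` below:

from prismatic.conf.datasets import DatasetConfig, DatasetRegistry

# Resolve the registered subclass for the `llava-v15` variant and instantiate it with default paths.
cfg_cls = DatasetConfig.get_choice_class(DatasetRegistry.LLAVA_V15.dataset_id)
cfg = cfg_cls()
print(cfg.dataset_id, cfg.finetune_stage_components[0])  # llava-v15 download/llava-v1.5-instruct/llava_v1_5_mix665k.json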
prismatic/conf/vla.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ vla.py
3
+
4
+ Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and
5
+ model configuration thereof. A given VLA model (`policy`) configures the following attributes:
6
+ - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.)
7
+ - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`)
8
+ - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning)
9
+ - Training / Optimization Hyperparameters
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+ from enum import Enum, unique
14
+ from pathlib import Path
15
+ from typing import Optional, Union
16
+
17
+ from draccus import ChoiceRegistry
18
+
19
+
20
+ @dataclass
21
+ class VLAConfig(ChoiceRegistry):
22
+ # fmt: off
23
+ vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant
24
+ base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`)
25
+ freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining)
26
+ freeze_llm_backbone: bool # Freeze LLM Backbone parameters
27
+ unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen)
28
+
29
+ # Data Mixture Parameters
30
+ data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`)
31
+ shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE)
32
+
33
+ # Optimization Parameters
34
+ epochs: int # Epochs to Run (in case `max_steps` is not specified)
35
+ max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`)
36
+
37
+ expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware
38
+ global_batch_size: int # Global Batch Size (divided across processes / world size)
39
+ per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU)
40
+ # =>> # of accumulation steps is auto-computed
41
+
42
+ learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay)
43
+ weight_decay: float # Weight Decay for AdamW Optimizer
44
+ max_grad_norm: float # Max Grad Norm (for global gradient clipping)
45
+ lr_scheduler_type: str # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay")
46
+ warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers)
47
+
48
+ train_strategy: str # Train Strategy (default "fsdp-full-shard")
49
+
50
+ # Enable Gradient/Activation Checkpointing (for the LLM Backbone)
51
+ enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training
52
+
53
+ # Mixed Precision Training via Torch Native AMP (`autocast`)
54
+ enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision
55
+ reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision
56
+
57
+ # fmt: on
58
+
59
+
60
+ # === OpenVLA Training Configurations ===
61
+
62
+
63
+ # = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge =
64
+ @dataclass
65
+ class Exp_SigLIP_224px_Bridge(VLAConfig):
66
+ vla_id: str = "siglip-224px+mx-bridge"
67
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
68
+
69
+ freeze_vision_backbone: bool = False
70
+ freeze_llm_backbone: bool = False
71
+ unfreeze_last_llm_layer: bool = False
72
+
73
+ # Data Mixture Parameters
74
+ data_mix: str = "bridge"
75
+ shuffle_buffer_size: int = 256_000
76
+
77
+ # Optimization Parameters
78
+ epochs: int = 1000
79
+ max_steps: Optional[int] = None
80
+
81
+ expected_world_size: int = 8
82
+ global_batch_size: int = 256
83
+ per_device_batch_size: int = 32
84
+
85
+ learning_rate: float = 2e-5
86
+ weight_decay: float = 0.0
87
+ max_grad_norm: float = 1.0
88
+ lr_scheduler_type: str = "constant"
89
+ warmup_ratio: float = 0.0
90
+
91
+ train_strategy: str = "fsdp-full-shard"
92
+
93
+
94
+ # = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge =
95
+ @dataclass
96
+ class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
97
+ vla_id: str = "siglip-224px-icy+mx-bridge"
98
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
99
+ freeze_vision_backbone: bool = True
100
+
101
+
102
+ # = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge =
103
+ @dataclass
104
+ class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
105
+ vla_id: str = "prism-dinosiglip-224px+mx-bridge"
106
+ base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
107
+
108
+ data_mix: str = "bridge"
109
+
110
+
111
+ # = [64 GPU] SigLIP 224px + OXE Magic Soup =
112
+ @dataclass
113
+ class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
114
+ vla_id: str = "siglip-224px+mx-oxe-magic-soup"
115
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
116
+
117
+ data_mix: str = "oxe_magic_soup"
118
+
119
+ expected_world_size: int = 64
120
+ global_batch_size: int = 2048
121
+ per_device_batch_size: int = 32
122
+
123
+
124
+ # = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ =
125
+ @dataclass
126
+ class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge):
127
+ vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus"
128
+ base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
129
+
130
+ # Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling!
131
+ # data_mix: str = "oxe_magic_soup_plus"
132
+ data_mix: str = "oxe_magic_soup_plus_minus"
133
+
134
+ expected_world_size: int = 64
135
+ global_batch_size: int = 2048
136
+ per_device_batch_size: int = 32
137
+
138
+
139
+ # === OpenVLA Fine-tuning Configurations ===
140
+
141
+
142
+ # = [8 GPU] SigLIP 224px + T-DROID =
143
+ @dataclass
144
+ class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
145
+ vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl"
146
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
147
+
148
+ data_mix: str = "tdroid_carrot_in_bowl"
149
+
150
+
151
+ @dataclass
152
+ class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge):
153
+ vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot"
154
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
155
+
156
+ data_mix: str = "tdroid_pour_corn_in_pot"
157
+
158
+
159
+ # = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning =
160
+ @dataclass
161
+ class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
162
+ vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl"
163
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
164
+ freeze_vision_backbone: bool = True
165
+ freeze_llm_backbone: bool = False
166
+
167
+ data_mix: str = "tdroid_carrot_in_bowl"
168
+
169
+
170
+ @dataclass
171
+ class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
172
+ vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl"
173
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
174
+ freeze_vision_backbone: bool = True
175
+ freeze_llm_backbone: bool = True
176
+ unfreeze_last_llm_layer: bool = True
177
+
178
+ data_mix: str = "tdroid_carrot_in_bowl"
179
+
180
+
181
+ @dataclass
182
+ class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
183
+ vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl"
184
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
185
+ freeze_vision_backbone: bool = False
186
+ freeze_llm_backbone: bool = True
187
+ unfreeze_last_llm_layer: bool = True
188
+
189
+ data_mix: str = "tdroid_carrot_in_bowl"
190
+
191
+
192
+ # === [8 GPU] SigLIP 224px + FrankaWipe ===
193
+ @dataclass
194
+ class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge):
195
+ vla_id: str = "siglip-224px+mx-droid_wipe"
196
+ base_vlm: Union[str, Path] = "siglip-224px+7b"
197
+
198
+ data_mix: str = "droid_wipe"
199
+
200
+
201
+ # === Define a VLA Registry Enum for Reference & Validation ===
202
+ @unique
203
+ class VLARegistry(Enum):
204
+ # Sanity Check Configurations =>> BridgeV2
205
+ SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge
206
+ DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge
207
+
208
+ # SigLIP Frozen Backbone Experiment
209
+ FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge
210
+
211
+ # [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup
212
+ SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup
213
+
214
+ # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++
215
+ DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus
216
+
217
+ # === TDROID Fine-tuning Configs ===
218
+ SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl
219
+ SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot
220
+
221
+ SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl
222
+ SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl
223
+ SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl
224
+
225
+ # === DROID Fine-tuning Configs ===
226
+ SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe
227
+
228
+ @property
229
+ def vla_id(self) -> str:
230
+ return self.value.vla_id
231
+
232
+
233
+ # Register VLAs in Choice Registry
234
+ for vla_variant in VLARegistry:
235
+ VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value)
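Analogously, a minimal sketch (illustrative, not part of this commit) of looking up a registered VLA experiment by its `vla_id`; the "# of accumulation steps is auto-computed" comment above implies the derivation shown here:

from prismatic.conf.vla import VLAConfig, VLARegistry

vla_cfg = VLAConfig.get_choice_class(VLARegistry.SIGLIP_224PX_MX_BRIDGE.vla_id)()
grad_accum_steps = vla_cfg.global_batch_size // (vla_cfg.per_device_batch_size * vla_cfg.expected_world_size)
print(vla_cfg.vla_id, grad_accum_steps)  # siglip-224px+mx-bridge 1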
prismatic/models/backbones/__init__.py ADDED
File without changes
prismatic/models/backbones/vision/dinov2_vit.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ dinov2_vit.py
3
+ """
4
+
5
+ from prismatic.models.backbones.vision.base_vision import TimmViTBackbone
6
+
7
+ # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers!
8
+ # => Reference: https://arxiv.org/abs/2309.16588
9
+ DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"}
10
+
11
+
12
+ class DinoV2ViTBackbone(TimmViTBackbone):
13
+ def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
14
+ super().__init__(
15
+ vision_backbone_id,
16
+ DINOv2_VISION_BACKBONES[vision_backbone_id],
17
+ image_resize_strategy,
18
+ default_image_size=default_image_size,
19
+ )
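A minimal instantiation sketch (illustrative; assumes the TIMM weights are reachable and that "resize-naive" is one of the supported resize strategies):

from prismatic.models.backbones.vision import DinoV2ViTBackbone

# Build the DINOv2 ViT-L/14 (with registers) backbone and grab its default image transform.
backbone = DinoV2ViTBackbone("dinov2-vit-l", image_resize_strategy="resize-naive", default_image_size=224)
image_transform = backbone.get_image_transform()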
prismatic/models/load.py ADDED
@@ -0,0 +1,226 @@
1
+ """
2
+ load.py
3
+
4
+ Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical
5
+ IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub).
6
+ """
7
+
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ from typing import List, Optional, Union
12
+
13
+ from huggingface_hub import HfFileSystem, hf_hub_download
14
+
15
+ from prismatic.conf import ModelConfig
16
+ from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
17
+ from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY
18
+ from prismatic.models.vlas import OpenVLA
19
+ from prismatic.models.vlms import PrismaticVLM
20
+ from prismatic.overwatch import initialize_overwatch
21
+ from prismatic.vla.action_tokenizer import ActionTokenizer
22
+
23
+ # Initialize Overwatch =>> Wraps `logging.Logger`
24
+ overwatch = initialize_overwatch(__name__)
25
+
26
+
27
+ # === HF Hub Repository ===
28
+ HF_HUB_REPO = "TRI-ML/prismatic-vlms"
29
+ VLA_HF_HUB_REPO = "openvla/openvla-dev"
30
+
31
+
32
+ # === Available Models ===
33
+ def available_models() -> List[str]:
34
+ return list(MODEL_REGISTRY.keys())
35
+
36
+
37
+ def available_model_names() -> List[str]:
38
+ return list(GLOBAL_REGISTRY.keys())
39
+
40
+
41
+ def get_model_description(model_id_or_name: str) -> str:
42
+ if model_id_or_name not in GLOBAL_REGISTRY:
43
+ raise ValueError(f"Couldn't find `{model_id_or_name = }`; check `prismatic.available_model_names()`")
44
+
45
+ # Print Description & Return
46
+ print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2))
47
+
48
+ return description
49
+
50
+
51
+ # === Load Pretrained Model ===
52
+ def load(
53
+ model_id_or_path: Union[str, Path],
54
+ hf_token: Optional[str] = None,
55
+ cache_dir: Optional[Union[str, Path]] = None,
56
+ load_for_training: bool = False,
57
+ ) -> PrismaticVLM:
58
+ """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub."""
59
+ if os.path.isdir(model_id_or_path):
60
+ overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`")
61
+
62
+ # Get paths for `config.json` and pretrained checkpoint
63
+ config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
64
+ assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
65
+ assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
66
+ else:
67
+ if model_id_or_path not in GLOBAL_REGISTRY:
68
+ raise ValueError(f"Couldn't find `{model_id_or_path = }`; check `prismatic.available_model_names()`")
69
+
70
+ overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])} from HF Hub")
71
+ with overwatch.local_zero_first():
72
+ config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir)
73
+ checkpoint_pt = hf_hub_download(
74
+ repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir
75
+ )
76
+
77
+ # Load Model Config from `config.json`
78
+ with open(config_json, "r") as f:
79
+ model_cfg = json.load(f)["model"]
80
+
81
+ # = Load Individual Components necessary for Instantiating a VLM =
82
+ # =>> Print Minimal Config
83
+ overwatch.info(
84
+ f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n"
85
+ f" Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n"
86
+ f" LLM Backbone =>> [bold]{model_cfg['llm_backbone_id']}[/]\n"
87
+ f" Arch Specifier =>> [bold]{model_cfg['arch_specifier']}[/]\n"
88
+ f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
89
+ )
90
+
91
+ # Load Vision Backbone
92
+ overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]")
93
+ vision_backbone, image_transform = get_vision_backbone_and_transform(
94
+ model_cfg["vision_backbone_id"],
95
+ model_cfg["image_resize_strategy"],
96
+ )
97
+
98
+ # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
99
+ overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers")
100
+ llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
101
+ model_cfg["llm_backbone_id"],
102
+ llm_max_length=model_cfg.get("llm_max_length", 2048),
103
+ hf_token=hf_token,
104
+ inference_mode=not load_for_training,
105
+ )
106
+
107
+ # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
108
+ overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint")
109
+ vlm = PrismaticVLM.from_pretrained(
110
+ checkpoint_pt,
111
+ model_cfg["model_id"],
112
+ vision_backbone,
113
+ llm_backbone,
114
+ arch_specifier=model_cfg["arch_specifier"],
115
+ freeze_weights=not load_for_training,
116
+ )
117
+
118
+ return vlm
119
+
120
+
121
+ # === Load Pretrained VLA Model ===
122
+ def load_vla(
123
+ model_id_or_path: Union[str, Path],
124
+ hf_token: Optional[str] = None,
125
+ cache_dir: Optional[Union[str, Path]] = None,
126
+ load_for_training: bool = False,
127
+ step_to_load: Optional[int] = None,
128
+ model_type: str = "pretrained",
129
+ ) -> OpenVLA:
130
+ """Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub."""
131
+
132
+ # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to
133
+ # checkpoint `.pt` file, rather than the top-level run directory!
134
+ if os.path.isfile(model_id_or_path):
135
+ overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`")
136
+
137
+ # [Validate] Checkpoint Path should look like `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt`
138
+ assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
139
+ run_dir = checkpoint_pt.parents[1]
140
+
141
+ # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint
142
+ config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json"
143
+ assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
144
+ assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
145
+
146
+ # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`)
147
+ else:
148
+ # Search HF Hub Repo via fsspec API
149
+ overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`")
150
+ if not (tmpfs := HfFileSystem()).exists(hf_path):
151
+ raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`")
152
+
153
+ # Identify Checkpoint to Load (via `step_to_load`)
154
+ step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None
155
+ valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt")
156
+ if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1):
157
+ raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/`")
158
+
159
+ # Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element
160
+ target_ckpt = Path(valid_ckpts[-1]).name
161
+
162
+ overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`")
163
+ with overwatch.local_zero_first():
164
+ relpath = Path(model_type) / model_id_or_path
165
+ config_json = hf_hub_download(
166
+ repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir
167
+ )
168
+ dataset_statistics_json = hf_hub_download(
169
+ repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir
170
+ )
171
+ checkpoint_pt = hf_hub_download(
172
+ repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir
173
+ )
174
+
175
+ # Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json`
176
+ with open(config_json, "r") as f:
177
+ vla_cfg = json.load(f)["vla"]
178
+ model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])()
179
+
180
+ # Load Dataset Statistics for Action Denormalization
181
+ with open(dataset_statistics_json, "r") as f:
182
+ norm_stats = json.load(f)
183
+
184
+ # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) =
185
+ # =>> Print Minimal Config
186
+ overwatch.info(
187
+ f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n"
188
+ f" Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n"
189
+ f" LLM Backbone =>> [bold]{model_cfg.llm_backbone_id}[/]\n"
190
+ f" Arch Specifier =>> [bold]{model_cfg.arch_specifier}[/]\n"
191
+ f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
192
+ )
193
+
194
+ # Load Vision Backbone
195
+ overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]")
196
+ vision_backbone, image_transform = get_vision_backbone_and_transform(
197
+ model_cfg.vision_backbone_id,
198
+ model_cfg.image_resize_strategy,
199
+ )
200
+
201
+ # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
202
+ overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers")
203
+ llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
204
+ model_cfg.llm_backbone_id,
205
+ llm_max_length=model_cfg.llm_max_length,
206
+ hf_token=hf_token,
207
+ inference_mode=not load_for_training,
208
+ )
209
+
210
+ # Create Action Tokenizer
211
+ action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer())
212
+
213
+ # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
214
+ overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint")
215
+ vla = OpenVLA.from_pretrained(
216
+ checkpoint_pt,
217
+ model_cfg.model_id,
218
+ vision_backbone,
219
+ llm_backbone,
220
+ arch_specifier=model_cfg.arch_specifier,
221
+ freeze_weights=not load_for_training,
222
+ norm_stats=norm_stats,
223
+ action_tokenizer=action_tokenizer,
224
+ )
225
+
226
+ return vla
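For reference, a minimal usage sketch of the two entry points (the registry ID and paths are placeholders; as validated above, `load_vla()` expects a checkpoint `.pt` file under a `checkpoints/` directory rather than a run directory):

from prismatic.models.load import available_models, load, load_vla

print(available_models())
vlm = load("prism-dinosiglip+7b", hf_token=None)                     # VLM by registry ID or local run directory
vla = load_vla("/path/to/run_dir/checkpoints/latest-checkpoint.pt")  # VLA from a local checkpoint file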
prismatic/models/materialize.py ADDED
@@ -0,0 +1,130 @@
1
+ """
2
+ materialize.py
3
+
4
+ Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports
5
+ individual functions for clear control flow.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ from transformers import PreTrainedTokenizerBase
11
+
12
+ from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone
13
+ from prismatic.models.backbones.vision import (
14
+ CLIPViTBackbone,
15
+ DinoCLIPViTBackbone,
16
+ DinoSigLIPViTBackbone,
17
+ DinoV2ViTBackbone,
18
+ ImageTransform,
19
+ IN1KViTBackbone,
20
+ SigLIPViTBackbone,
21
+ VisionBackbone,
22
+ )
23
+ from prismatic.models.vlms import PrismaticVLM
24
+
25
+ # === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs ===
26
+ # fmt: off
27
+
28
+ # === Vision Backbone Registry ===
29
+ VISION_BACKBONES = {
30
+ # === 224px Backbones ===
31
+ "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
32
+ "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
33
+ "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}},
34
+ "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}},
35
+ "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
36
+
37
+ # === Assorted CLIP Backbones ===
38
+ "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
39
+ "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}},
40
+
41
+ # === Assorted SigLIP Backbones ===
42
+ "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
43
+ "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}},
44
+ "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
45
+ "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
46
+
47
+ # === Fused Backbones ===
48
+ "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}},
49
+ "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
50
+ }
51
+
52
+
53
+ # === Language Model Registry ===
54
+ LLM_BACKBONES = {
55
+ # === LLaMa-2 Pure (Non-Chat) Backbones ===
56
+ "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
57
+ "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
58
+
59
+ # === LLaMa-2 Chat Backbones ===
60
+ "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
61
+ "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
62
+
63
+ # === Vicuna-v1.5 Backbones ===
64
+ "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
65
+ "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
66
+
67
+ # === Mistral v0.1 Backbones ===
68
+ "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
69
+ "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},
70
+
71
+ # === Phi-2 Backbone ===
72
+ "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
73
+ }
74
+
75
+ # fmt: on
76
+
77
+
78
+ def get_vision_backbone_and_transform(
79
+ vision_backbone_id: str, image_resize_strategy: str
80
+ ) -> Tuple[VisionBackbone, ImageTransform]:
81
+ """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform."""
82
+ if vision_backbone_id in VISION_BACKBONES:
83
+ vision_cfg = VISION_BACKBONES[vision_backbone_id]
84
+ vision_backbone: VisionBackbone = vision_cfg["cls"](
85
+ vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"]
86
+ )
87
+ image_transform = vision_backbone.get_image_transform()
88
+ return vision_backbone, image_transform
89
+
90
+ else:
91
+ raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!")
92
+
93
+
94
+ def get_llm_backbone_and_tokenizer(
95
+ llm_backbone_id: str,
96
+ llm_max_length: int = 2048,
97
+ hf_token: Optional[str] = None,
98
+ inference_mode: bool = False,
99
+ ) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]:
100
+ if llm_backbone_id in LLM_BACKBONES:
101
+ llm_cfg = LLM_BACKBONES[llm_backbone_id]
102
+ llm_backbone: LLMBackbone = llm_cfg["cls"](
103
+ llm_backbone_id,
104
+ llm_max_length=llm_max_length,
105
+ hf_token=hf_token,
106
+ inference_mode=inference_mode,
107
+ **llm_cfg["kwargs"],
108
+ )
109
+ tokenizer = llm_backbone.get_tokenizer()
110
+ return llm_backbone, tokenizer
111
+
112
+ else:
113
+ raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!")
114
+
115
+
116
+ def get_vlm(
117
+ model_id: str,
118
+ arch_specifier: str,
119
+ vision_backbone: VisionBackbone,
120
+ llm_backbone: LLMBackbone,
121
+ enable_mixed_precision_training: bool = True,
122
+ ) -> PrismaticVLM:
123
+ """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM)."""
124
+ return PrismaticVLM(
125
+ model_id,
126
+ vision_backbone,
127
+ llm_backbone,
128
+ enable_mixed_precision_training=enable_mixed_precision_training,
129
+ arch_specifier=arch_specifier,
130
+ )
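A minimal sketch tying the two registries together (illustrative; the resize strategy "resize-naive" and the `arch_specifier` value "gelu-mlp" are assumed to be supported, cf. the projector names in `prismatic/util/nn_utils.py` below):

from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm

vision_backbone, image_transform = get_vision_backbone_and_transform("siglip-vit-so400m", "resize-naive")
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer("llama2-7b-pure", llm_max_length=2048)
vlm = get_vlm("siglip-224px+7b", "gelu-mlp", vision_backbone, llm_backbone)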
prismatic/models/projectors.py ADDED
@@ -0,0 +1,67 @@
1
+ """Implementation of additional projectors for additional inputs to the VLA models."""
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+
7
+ class ProprioProjector(nn.Module):
8
+ """
9
+ Projects proprio state inputs into the LLM's embedding space.
10
+ """
11
+ def __init__(self, llm_dim: int, proprio_dim: int) -> None:
12
+ super().__init__()
13
+ self.llm_dim = llm_dim
14
+ self.proprio_dim = proprio_dim
15
+
16
+ self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
17
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
18
+ self.act_fn1 = nn.GELU()
19
+
20
+ def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
21
+ # proprio: (bsz, proprio_dim)
22
+ projected_features = self.fc1(proprio)
23
+ projected_features = self.act_fn1(projected_features)
24
+ projected_features = self.fc2(projected_features)
25
+ return projected_features
26
+
27
+
28
+ class NoisyActionProjector(nn.Module):
29
+ """
30
+ [Diffusion] Projects noisy action inputs into the LLM's embedding space.
31
+
32
+ Note that since each action is tokenized into 7 tokens in OpenVLA (rather
33
+ than having 1 token per action), each noisy action token will have dimension 1
34
+ instead of 7.
35
+ """
36
+ def __init__(self, llm_dim: int) -> None:
37
+ super().__init__()
38
+ self.llm_dim = llm_dim
39
+ self.action_token_dim = 1
40
+
41
+ self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True)
42
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
43
+ self.act_fn1 = nn.GELU()
44
+
45
+ def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor:
46
+ # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1)
47
+ projected_features = self.fc1(noisy_actions)
48
+ projected_features = self.act_fn1(projected_features)
49
+ projected_features = self.fc2(projected_features)
50
+ return projected_features
51
+
52
+
53
+
54
+
55
+ class VisualProjector(nn.Module):
56
+ def __init__(self, llm_dim: int, visual_dim: int) -> None:
57
+ super().__init__()
58
+ self.visual_dim, self.llm_dim = visual_dim, llm_dim
59
+ self.fc1 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
60
+ self.fc2 = nn.Linear(self.llm_dim, self.visual_dim, bias=True)
61
+ self.act_fn1 = nn.GELU()
62
+
63
+ def forward(self, img_hidden_embedding: torch.Tensor) -> torch.Tensor:
64
+ projected_features = self.fc1(img_hidden_embedding)
65
+ projected_features = self.act_fn1(projected_features)
66
+ projected_features = self.fc2(projected_features)
67
+ return projected_features
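A minimal shape check for the projectors above (dimensions are illustrative: 4096 matches a LLaMa-2 7B hidden size, and 56 stands in for `chunk_len * action_dim`):

import torch

from prismatic.models.projectors import NoisyActionProjector, ProprioProjector

proprio = ProprioProjector(llm_dim=4096, proprio_dim=8)(torch.randn(2, 8))
assert proprio.shape == (2, 4096)                    # (bsz, proprio_dim) -> (bsz, llm_dim)

noisy = NoisyActionProjector(llm_dim=4096)(torch.randn(2, 56, 1))
assert noisy.shape == (2, 56, 4096)                  # one embedding per (1-dimensional) noisy action token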
prismatic/preprocessing/datasets/datasets.py ADDED
@@ -0,0 +1,200 @@
1
+ """
2
+ datasets.py
3
+
4
+ PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with
5
+ utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected
6
+ formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models).
7
+
8
+ We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that
9
+ random access image reading is relatively cheap/fast.
10
+ """
11
+
12
+ import copy
13
+ import json
14
+ from pathlib import Path
15
+ from typing import Dict, List, Tuple, Type
16
+
17
+ import torch
18
+ from PIL import Image
19
+ from torch.utils.data import Dataset
20
+ from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase
21
+
22
+ from prismatic.models.backbones.llm.prompting import PromptBuilder
23
+ from prismatic.models.backbones.vision import ImageTransform
24
+
25
+ # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
26
+ IGNORE_INDEX = -100
27
+
28
+
29
+ class AlignDataset(Dataset[Dict[str, torch.Tensor]]):
30
+ def __init__(
31
+ self,
32
+ chat_json: Path,
33
+ image_dir: Path,
34
+ image_transform: ImageTransform,
35
+ tokenizer: PreTrainedTokenizerBase,
36
+ ) -> None:
37
+ super().__init__()
38
+ self.chat_json, self.image_dir = chat_json, image_dir
39
+ self.image_transform, self.tokenizer = image_transform, tokenizer
40
+ self.dataset_type = "align"
41
+
42
+ # Create Prompt Template
43
+ self.prompt_template = "{caption}" + self.tokenizer.eos_token
44
+
45
+ # Load Chat JSON
46
+ with open(self.chat_json, "r") as f:
47
+ self.examples = json.load(f)
48
+
49
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
50
+ """
51
+ Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard
52
+ the "prompt" from the human, and instead directly predict the caption from the image.
53
+
54
+ As a concrete example given the "raw data" for the first example:
55
+ example = self.examples[0]["conversations"] = {
56
+ [
57
+ {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
58
+ {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
59
+ ]
60
+ }
61
+
62
+ Return =>> self.tokenizer("<image> select luxury furniture 3 - inch gel memory foam mattress topper\n")
63
+
64
+ :param idx: Index to retrieve from the dataset.
65
+
66
+ :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
67
+ """
68
+ image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"]
69
+ assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"
70
+
71
+ # Format Caption --> {caption}{eos_token}
72
+ caption = self.prompt_template.format(caption=conversation[-1]["value"].strip())
73
+
74
+ # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens.
75
+ # => Critically, we find that inserting *after* the BOS token leads to the strongest performance!
76
+ # - input_ids = "<s> p1 p2 p3 ... <caption_text> \n"
77
+ # - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing <s> and p{1...K} with IGNORE)
78
+ #
79
+ # IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
80
+ input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0]
81
+ labels = copy.deepcopy(input_ids)
82
+
83
+ # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
84
+ labels[0] = IGNORE_INDEX
85
+
86
+ # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
87
+ pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
88
+
89
+ return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
90
+
91
+ def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
92
+ """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
93
+ modality_lengths = []
94
+ for example in self.examples:
95
+ is_multimodal = "image" in example
96
+ n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
97
+ modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
98
+ return modality_lengths
99
+
100
+ def __len__(self) -> int:
101
+ return len(self.examples)
102
+
103
+
104
+ class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]):
105
+ def __init__(
106
+ self,
107
+ instruct_json: Path,
108
+ image_dir: Path,
109
+ image_transform: ImageTransform,
110
+ tokenizer: PreTrainedTokenizerBase,
111
+ prompt_builder_fn: Type[PromptBuilder],
112
+ ) -> None:
113
+ super().__init__()
114
+ self.instruct_json, self.image_dir = instruct_json, image_dir
115
+ self.image_transform, self.tokenizer = image_transform, tokenizer
116
+ self.prompt_builder_fn = prompt_builder_fn
117
+ self.dataset_type = "finetune"
118
+
119
+ # Load Instruct JSON
120
+ with open(self.instruct_json, "r") as f:
121
+ self.examples = json.load(f)
122
+
123
+ # === Unimodal + Multimodal Handling ===
124
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
125
+ """
126
+ Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of
127
+ dialog grounded in a single image.
128
+
129
+ To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the
130
+ methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example.
131
+
132
+ :param idx: Index to retrieve from the dataset.
133
+
134
+ :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
135
+ """
136
+ conversation = self.examples[idx]["conversations"]
137
+
138
+ # Create Prompt Builder --> add each message sequentially
139
+ prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
140
+ for turn_idx, turn in enumerate(conversation):
141
+ # Get "effective" string added to prompt --> handle whitespace for tokenizer type!
142
+ msg = prompt_builder.add_turn(turn["from"], turn["value"])
143
+
144
+ # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
145
+ if isinstance(self.tokenizer, LlamaTokenizerFast):
146
+ msg = msg.rstrip()
147
+
148
+ # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
149
+ elif isinstance(self.tokenizer, CodeGenTokenizerFast):
150
+ pass
151
+
152
+ else:
153
+ raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")
154
+
155
+ # Tokenize Input IDs
156
+ turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids
157
+
158
+ # [CRITICAL] We do not want to take the loss for the "USER: <msg>" prompts =>> just the responses!
159
+ turn_labels = (
160
+ [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids)
161
+ )
162
+
163
+ # Add to Trackers
164
+ input_ids.extend(turn_input_ids)
165
+ labels.extend(turn_labels)
166
+
167
+ # Tensorize =>> Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches after)
168
+ # - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
169
+ input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)
170
+
171
+ # Handle Truncation (if necessary)
172
+ input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]
173
+
174
+ # === Handle "unimodal" (language-only) vs. "multimodal" ===
175
+ if "image" in self.examples[idx]:
176
+ image_path = Path(self.examples[idx]["image"])
177
+
178
+ # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
179
+ labels[0] = IGNORE_INDEX
180
+
181
+ # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
182
+ pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
183
+
184
+ return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
185
+
186
+ else:
187
+ # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us!
188
+ return dict(pixel_values=None, input_ids=input_ids, labels=labels)
189
+
190
+ def get_modality_lengths(self) -> List[Tuple[bool, int]]:
191
+ """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
192
+ modality_lengths = []
193
+ for example in self.examples:
194
+ is_multimodal = "image" in example
195
+ n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
196
+ modality_lengths.append((is_multimodal, n_words))
197
+ return modality_lengths
198
+
199
+ def __len__(self) -> int:
200
+ return len(self.examples)
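The per-turn masking rule in `FinetuneDataset.__getitem__` is easiest to see on toy data; a minimal sketch (the token ids are made up): user turns (even indices) contribute only IGNORE_INDEX labels, while assistant turns (odd indices) are supervised directly.

IGNORE_INDEX = -100
turn_token_ids = [[101, 7, 8], [9, 10], [11, 12, 13], [14]]   # hypothetical: user, gpt, user, gpt turns
labels = []
for turn_idx, ids in enumerate(turn_token_ids):
    labels.extend([IGNORE_INDEX] * len(ids) if (turn_idx % 2) == 0 else list(ids))
print(labels)  # [-100, -100, -100, 9, 10, -100, -100, -100, 14]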
prismatic/preprocessing/materialize.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ materialize.py
3
+
4
+ Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
5
+ clear control flow.
6
+ """
7
+
8
+ from typing import Tuple, Type
9
+
10
+ from torch.utils.data import Dataset
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from prismatic.conf import DatasetConfig
14
+ from prismatic.models.backbones.llm.prompting import PromptBuilder
15
+ from prismatic.models.backbones.vision import ImageTransform
16
+ from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset
17
+ from prismatic.util.data_utils import PaddedCollatorForLanguageModeling
18
+
19
+ # Dataset Initializers =>> Maps Stage --> cls()
20
+ DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset}
21
+
22
+
23
+ def get_dataset_and_collator(
24
+ stage: str,
25
+ dataset_cfg: DatasetConfig,
26
+ image_transform: ImageTransform,
27
+ tokenizer: PreTrainedTokenizerBase,
28
+ prompt_builder_fn: Type[PromptBuilder],
29
+ default_image_resolution: Tuple[int, int, int],
30
+ padding_side: str = "right",
31
+ ) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
32
+ dataset_cls = DATASET_INITIALIZER[stage]
33
+ dataset_root_dir = dataset_cfg.dataset_root_dir
34
+ collator = PaddedCollatorForLanguageModeling(
35
+ tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side
36
+ )
37
+
38
+ # Switch on `stage`
39
+ if stage == "align":
40
+ annotation_json, image_dir = dataset_cfg.align_stage_components
41
+ dataset = dataset_cls(
42
+ dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer
43
+ )
44
+ return dataset, collator
45
+
46
+ elif stage == "finetune":
47
+ annotation_json, image_dir = dataset_cfg.finetune_stage_components
48
+ dataset = dataset_cls(
49
+ dataset_root_dir / annotation_json,
50
+ dataset_root_dir / image_dir,
51
+ image_transform,
52
+ tokenizer,
53
+ prompt_builder_fn=prompt_builder_fn,
54
+ )
55
+ return dataset, collator
56
+
57
+ elif stage == "full-finetune":
58
+ annotation_json, image_dir = dataset_cfg.finetune_stage_components
59
+ dataset = dataset_cls(
60
+ dataset_root_dir / annotation_json,
61
+ dataset_root_dir / image_dir,
62
+ image_transform,
63
+ tokenizer,
64
+ prompt_builder_fn=prompt_builder_fn,
65
+ )
66
+ return dataset, collator
67
+
68
+ else:
69
+ raise ValueError(f"Stage `{stage}` is not supported!")
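A minimal sketch of wiring this into a DataLoader (backbone IDs are placeholders; `prompt_builder_fn` and `default_image_resolution` are assumed to be exposed by the LLM and vision backbones, which are not part of this commit):

from torch.utils.data import DataLoader

from prismatic.conf import DatasetConfig
from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
from prismatic.preprocessing.materialize import get_dataset_and_collator

vision_backbone, image_transform = get_vision_backbone_and_transform("clip-vit-l", "resize-naive")
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer("vicuna-v15-7b")
dataset_cfg = DatasetConfig.get_choice_class("llava-v15")()

dataset, collator = get_dataset_and_collator(
    "align", dataset_cfg, image_transform, tokenizer,
    prompt_builder_fn=llm_backbone.prompt_builder_fn,                    # assumed LLMBackbone attribute
    default_image_resolution=vision_backbone.default_image_resolution,  # assumed VisionBackbone attribute
)
loader = DataLoader(dataset, batch_size=16, collate_fn=collator, num_workers=4)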
prismatic/py.typed ADDED
File without changes
prismatic/util/nn_utils.py ADDED
@@ -0,0 +1,53 @@
1
+ """
2
+ nn_utils.py
3
+
4
+ Utility functions and PyTorch submodule definitions.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
12
+ class LinearProjector(nn.Module):
13
+ def __init__(self, vision_dim: int, llm_dim: int) -> None:
14
+ super().__init__()
15
+ self.projector = nn.Linear(vision_dim, llm_dim, bias=True)
16
+
17
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
18
+ return self.projector(img_patches)
19
+
20
+
21
+ class MLPProjector(nn.Module):
22
+ def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None:
23
+ super().__init__()
24
+ if mlp_type == "gelu-mlp":
25
+ self.projector = nn.Sequential(
26
+ nn.Linear(vision_dim, llm_dim, bias=True),
27
+ nn.GELU(),
28
+ nn.Linear(llm_dim, llm_dim, bias=True),
29
+ )
30
+ else:
31
+ raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
32
+
33
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
34
+ return self.projector(img_patches)
35
+
36
+
37
+ class FusedMLPProjector(nn.Module):
38
+ def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None:
39
+ super().__init__()
40
+ self.initial_projection_dim = fused_vision_dim * 4
41
+ if mlp_type == "fused-gelu-mlp":
42
+ self.projector = nn.Sequential(
43
+ nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True),
44
+ nn.GELU(),
45
+ nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
46
+ nn.GELU(),
47
+ nn.Linear(llm_dim, llm_dim, bias=True),
48
+ )
49
+ else:
50
+ raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!")
51
+
52
+ def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
53
+ return self.projector(fused_img_patches)
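Shape-wise, all three projectors map `[..., in_dim] -> [..., llm_dim]`; a quick illustrative check (1024 and 1152 stand in for DINOv2-L and SigLIP feature widths):

import torch

from prismatic.util.nn_utils import FusedMLPProjector, MLPProjector

patches = torch.randn(2, 256, 1024)                          # (bsz, n_patches, vision_dim)
print(MLPProjector(1024, 4096)(patches).shape)               # torch.Size([2, 256, 4096])

fused = torch.randn(2, 256, 1024 + 1152)                     # concatenated DINOv2 + SigLIP features
print(FusedMLPProjector(1024 + 1152, 4096)(fused).shape)     # torch.Size([2, 256, 4096])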
prismatic/vla/datasets/rlds/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .dataset import make_interleaved_dataset, make_single_dataset
prismatic/vla/datasets/rlds/dataset.py ADDED
@@ -0,0 +1,655 @@
1
+ """
2
+ dataset.py
3
+
4
+ Core interface script for configuring and initializing RLDS datasets.
5
+ """
6
+
7
+ import copy
8
+ import inspect
9
+ import json
10
+ import random # Import the random module
11
+ from functools import partial
12
+ from typing import Callable, Dict, List, Optional, Tuple, Union
13
+
14
+ import dlimp as dl
15
+ import numpy as np
16
+ import tensorflow as tf
17
+ import tensorflow_datasets as tfds
18
+
19
+ from prismatic.overwatch import initialize_overwatch
20
+ from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX
21
+ from prismatic.vla.datasets.rlds import obs_transforms, traj_transforms
22
+ from prismatic.vla.datasets.rlds.utils import goal_relabeling, task_augmentation
23
+ from prismatic.vla.datasets.rlds.utils.data_utils import (
24
+ allocate_threads,
25
+ get_dataset_statistics,
26
+ normalize_action_and_proprio,
27
+ pprint_data_mixture,
28
+ tree_map,
29
+ shuffle_dataset, # Newly added: import the shuffle_dataset function
30
+ )
31
+
32
+ # Initialize Overwatch =>> Wraps `logging.Logger`
33
+ overwatch = initialize_overwatch(__name__)
34
+
35
+ # # Adds a function to set all random seeds
36
+ # def set_all_seeds(seed):
37
+ # """Set the seeds of all random number generators to ensure reproducibility."""
38
+ # random.seed(seed)
39
+ # np.random.seed(seed)
40
+ # tf.random.set_seed(seed)
41
+ # # Enable TensorFlow deterministic operations (if supported by the TensorFlow version)
42
+ # try:
43
+ # tf.config.experimental.enable_op_determinism()
44
+ # except AttributeError:
45
+ # overwatch.warning("The TensorFlow version does not support enable_op_determinism, and the results may not be fully reproducible.")
46
+
47
+
48
+ # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch)
49
+ tf.config.set_visible_devices([], "GPU")
50
+
51
+
52
+ # # Try to get seeds from environment variables or global Settings and set them
53
+ # try:
54
+ # from prismatic.training.train_utils import get_global_seed
55
+ # seed = get_global_seed()
56
+ # if seed is not None:
57
+ # set_all_seeds(seed)
58
+ # overwatch.info(f"The Dataset module has been set with a random seed: {seed}")
59
+ # except (ImportError, NameError):
60
+ # overwatch.warning("The global seed setting cannot be obtained, so the data processing may not be fully reproducible.")
61
+
62
+
63
+ # ruff: noqa: B006
64
+ def make_dataset_from_rlds(
65
+ name: str,
66
+ data_dir: str,
67
+ *,
68
+ train: bool,
69
+ shuffle_seed: int,
70
+ standardize_fn: Optional[Callable[[dict], dict]] = None,
71
+ shuffle: bool = True,
72
+ image_obs_keys: Dict[str, Optional[str]] = {},
73
+ depth_obs_keys: Dict[str, Optional[str]] = {},
74
+ state_obs_keys: List[Optional[str]] = (),
75
+ language_key: Optional[str] = None,
76
+ action_proprio_normalization_type: ACTION_PROPRIO_NORMALIZATION_TYPE,
77
+ dataset_statistics: Optional[Union[dict, str]] = None,
78
+ absolute_action_mask: Optional[List[bool]] = None,
79
+ action_normalization_mask: Optional[List[bool]] = None,
80
+ num_parallel_reads: int = tf.data.AUTOTUNE,
81
+ num_parallel_calls: int = tf.data.AUTOTUNE,
82
+ ) -> Tuple[dl.DLataset, dict]:
83
+ """
84
+ This function is responsible for loading a specific RLDS dataset from storage and getting it into a standardized
85
+ format. Yields a dataset of trajectories. Does not include CPU-intensive operations.
86
+
87
+ If `standardize_fn` is provided, it will be applied to each trajectory. This function should get the trajectory
88
+ into a standard format, which includes the keys "observation" and "action". Entry "observation" should be a
89
+ dictionary containing some number of additional keys, which will be extracted into an even more standardized format
90
+ according to the "*_obs_keys" arguments.
91
+
92
+ The `image_obs_keys` and `depth_obs_keys` arguments are mappings from new names to old names, or None in place of an
93
+ old name to insert padding. For example, if after `standardize_fn`, your "observation" dict has RGB images called
94
+ "workspace" and "wrist", and `image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"}`, then
95
+ the resulting dataset will have an "observation" dict containing the keys "image_primary", "image_secondary", and
96
+ "image_wrist", where "image_primary" corresponds to "workspace", "image_secondary" is a padding image, and
97
+ "image_wrist" corresponds to "wrist".
98
+
99
+ Entry `state_obs_keys` is a list of 1-dimensional proprioceptive keys to concatenate into a single array, which will
100
+ be placed in the "proprio" key of the "observation" dict. A single padding element (zero) will be inserted for each
101
+ None entry.
102
+
103
+ The dataset will also include a "task" dict. If `language_key` is provided, then the "task" dict will contain the
104
+ key "language_instruction", extracted from `traj[language_key]`.
105
+
106
+ Args:
107
+ name (str): The name of the RLDS dataset (usually "name" or "name:version").
108
+ data_dir (str): The path to the data directory.
109
+ train (bool): Whether to use the training or validation split.
+ shuffle_seed (int): Random seed used when shuffling the file read order, for reproducibility.
110
+ shuffle (bool, optional): Whether to shuffle the file read order (does NOT fully shuffle the dataset, since one
111
+ file usually contains many trajectories)!
112
+ standardize_fn (Callable[[dict], dict], optional): A function that, if provided, will be the first
113
+ thing applied to each trajectory.
114
+ image_obs_keys (Mapping[str, str|None]): Mapping from {new: old} indicating which RGB images to extract from the
115
+ "observation" dict. `new_obs = {f"image_{new}": old_obs[old] for new, old in image_obs_keys.items()}`.
116
+ If a value of `old` is None, inserts a padding image instead (empty string).
117
+ depth_obs_keys (Mapping[str, str|None]): Same as `image_obs_keys`, but for depth images. Keys will be
118
+ prefixed with "depth_" instead of "image_".
119
+ state_obs_keys (Sequence[str|None]): List of 1-dimensional proprioception keys to be extracted from the
120
+ "observation" dict, concatenated, and mapped to "proprio". Inserts 1 element of padding for each None entry.
121
+ language_key (str, optional): If provided, the "task" dict will contain the key "language_instruction",
122
+ extracted from `traj[language_key]`.
123
+ action_proprio_normalization_type (str, optional): The type of normalization to perform on the action,
124
+ proprio, or both. Can be "normal" (mean 0, std 1) or "bounds" (normalized to [-1, 1]).
125
+ dataset_statistics: (dict|str, optional): dict (or path to JSON file) that contains dataset statistics
126
+ for normalization. If `action_proprio_normalization_type` is "normal", this should contain "mean" and
127
+ "std" keys. If `action_proprio_normalization_type` is "bounds", this should contain "min" and "max"
128
+ keys. May also provide "num_transitions" and "num_trajectories" keys for downstream usage (e.g., for
129
+ `make_interleaved_dataset`). If not provided, the statistics will be computed on the fly.
130
+ absolute_action_mask (Sequence[bool], optional): By default, all action dimensions are assumed to be
131
+ relative. This is important for when `future_action_window_size > 0`: actions that are taken
132
+ from beyond the end of the trajectory (or beyond the goal timestep when goal relabeling is used)
133
+ need to be made "neutral" to indicate that the task has been completed. For relative actions,
134
+ "neutral" means zero, but for absolute actions, "neutral" means repeating the last valid action.
135
+ This mask, if provided, indicates which action dimensions are absolute.
136
+ action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions
137
+ should be normalized. For example, you might not want to normalize the gripper action dimension if
138
+ it's always exactly 0 or 1. By default, all action dimensions are normalized.
139
+ num_parallel_reads (int): number of parallel read workers. Default to AUTOTUNE.
140
+ num_parallel_calls (int): number of parallel calls for traj_map operations. Default to AUTOTUNE.
141
+ Returns:
142
+ Dataset of trajectories where each step has the following fields:
143
+ - observation:
144
+ - image_{name1, name2, ...} # RGB image observations
145
+ - depth_{name1, name2, ...} # depth image observations
146
+ - proprio # 1-dimensional array of proprioceptive observations
147
+ - timestep # timestep of each frame
148
+ - task:
149
+ - language_instruction # language instruction, present if `language_key` is provided
150
+ - action # action vector
151
+ - dataset_name # name of the dataset
152
+ """
153
+ REQUIRED_KEYS = {"observation", "action"}
154
+ if language_key is not None:
155
+ REQUIRED_KEYS.add(language_key)
156
+
157
+ def restructure(traj):
158
+ # apply a standardization function, if provided
159
+ if standardize_fn is not None:
160
+ traj = standardize_fn(traj)
161
+
162
+ if not all(k in traj for k in REQUIRED_KEYS):
163
+ raise ValueError(
164
+ f"Trajectory is missing keys: {REQUIRED_KEYS - set(traj.keys())}. " "Did you write a `standardize_fn`?"
165
+ )
166
+
167
+ # extracts images, depth images and proprio from the "observation" dict
168
+ traj_len = tf.shape(traj["action"])[0]
169
+ old_obs = traj["observation"]
170
+ new_obs = {}
171
+ for new, old in image_obs_keys.items():
172
+ if old is None:
173
+ new_obs[f"image_{new}"] = tf.repeat("", traj_len) # padding
174
+ else:
175
+ new_obs[f"image_{new}"] = old_obs[old]
176
+
177
+ for new, old in depth_obs_keys.items():
178
+ if old is None:
179
+ new_obs[f"depth_{new}"] = tf.repeat("", traj_len) # padding
180
+ else:
181
+ new_obs[f"depth_{new}"] = old_obs[old]
182
+
183
+ if state_obs_keys:
184
+ new_obs["proprio"] = tf.concat(
185
+ [
186
+ (
187
+ tf.zeros((traj_len, 1), dtype=tf.float32) # padding
188
+ if key is None
189
+ else tf.cast(old_obs[key], tf.float32)
190
+ )
191
+ for key in state_obs_keys
192
+ ],
193
+ axis=1,
194
+ )
195
+
196
+ # add timestep info
197
+ new_obs["timestep"] = tf.range(traj_len)
198
+
199
+ # extracts `language_key` into the "task" dict
200
+ task = {}
201
+ if language_key is not None:
202
+ if traj[language_key].dtype != tf.string:
203
+ raise ValueError(
204
+ f"Language key {language_key} has dtype {traj[language_key].dtype}, " "but it must be tf.string."
205
+ )
206
+ task["language_instruction"] = traj.pop(language_key)
207
+
208
+ traj = {
209
+ "observation": new_obs,
210
+ "task": task,
211
+ "action": tf.cast(traj["action"], tf.float32),
212
+ "dataset_name": tf.repeat(name, traj_len),
213
+ }
214
+
215
+ if absolute_action_mask is not None:
216
+ if len(absolute_action_mask) != traj["action"].shape[-1]:
217
+ raise ValueError(
218
+ f"Length of absolute_action_mask ({len(absolute_action_mask)}) "
219
+ f"does not match action dimension ({traj['action'].shape[-1]})."
220
+ )
221
+ traj["absolute_action_mask"] = tf.tile(
222
+ tf.convert_to_tensor(absolute_action_mask, dtype=tf.bool)[None],
223
+ [traj_len, 1],
224
+ )
225
+
226
+ return traj
227
+
228
+ builder = tfds.builder(name, data_dir=data_dir)
229
+
230
+ # load or compute dataset statistics
231
+ if isinstance(dataset_statistics, str):
232
+ with tf.io.gfile.GFile(dataset_statistics, "r") as f:
233
+ dataset_statistics = json.load(f)
234
+ elif dataset_statistics is None:
235
+ full_dataset = dl.DLataset.from_rlds(
236
+ builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads
237
+ ).traj_map(restructure, num_parallel_calls)
238
+ # tries to load from cache, otherwise computes on the fly
239
+ dataset_statistics = get_dataset_statistics(
240
+ full_dataset,
241
+ hash_dependencies=(
242
+ str(builder.info),
243
+ str(state_obs_keys),
244
+ inspect.getsource(standardize_fn) if standardize_fn is not None else "",
245
+ ),
246
+ save_dir=builder.data_dir,
247
+ )
248
+ dataset_statistics = tree_map(np.array, dataset_statistics)
249
+
250
+ # skip normalization for certain action dimensions
251
+ if action_normalization_mask is not None:
252
+ if len(action_normalization_mask) != dataset_statistics["action"]["mean"].shape[-1]:
253
+ raise ValueError(
254
+ f"Length of skip_normalization_mask ({len(action_normalization_mask)}) "
255
+ f"does not match action dimension ({dataset_statistics['action']['mean'].shape[-1]})."
256
+ )
257
+ dataset_statistics["action"]["mask"] = np.array(action_normalization_mask)
258
+
259
+ # construct the dataset
260
+ split = "train" if train else "val"
261
+
262
+ dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads, shuffle_seed=shuffle_seed)
263
+
264
+ dataset = dataset.traj_map(restructure, num_parallel_calls)
265
+ dataset = dataset.traj_map(
266
+ partial(
267
+ normalize_action_and_proprio,
268
+ metadata=dataset_statistics,
269
+ normalization_type=action_proprio_normalization_type,
270
+ ),
271
+ num_parallel_calls,
272
+ )
273
+
274
+ return dataset, dataset_statistics
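+
+ # NOTE :: Illustrative usage sketch (kept as a comment, not executed). The dataset name, data directory, and
+ # observation keys below are placeholders / assumptions; real calls would also typically pass a matching
+ # `standardize_fn` from `prismatic.vla.datasets.rlds.oxe.transforms`.
+ #
+ # ds, stats = make_dataset_from_rlds(
+ #     "some_rlds_dataset",                 # hypothetical dataset name
+ #     "/path/to/rlds/data",                # hypothetical data directory
+ #     train=True,
+ #     shuffle_seed=42,
+ #     image_obs_keys={"primary": "image_0", "wrist": None},   # "wrist" becomes a padding image
+ #     state_obs_keys=["state"],
+ #     language_key="language_instruction",
+ #     action_proprio_normalization_type=ACTION_PROPRIO_NORMALIZATION_TYPE,
+ # )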
275
+
276
+
277
+ def apply_trajectory_transforms(
278
+ dataset: dl.DLataset,
279
+ *,
280
+ train: bool,
281
+ goal_relabeling_strategy: Optional[str] = None,
282
+ goal_relabeling_kwargs: dict = {},
283
+ window_size: int = 1,
284
+ future_action_window_size: int = 0,
285
+ subsample_length: Optional[int] = None,
286
+ skip_unlabeled: bool = False,
287
+ max_action: Optional[float] = None,
288
+ max_proprio: Optional[float] = None,
289
+ task_augment_strategy: Optional[str] = None,
290
+ task_augment_kwargs: dict = {},
291
+ num_parallel_calls: int = tf.data.AUTOTUNE,
292
+ use_predict_future_prop: bool = False,
293
+ ) -> dl.DLataset:
294
+ """
295
+ Applies common transforms that happen at a trajectory level. Such transforms are usually some sort of "relabeling"
296
+ (e.g., filtering, chunking, adding goals, dropping keys).
297
+
298
+ Transforms in this function should have the following properties:
299
+ - They require access to an entire trajectory (i.e., they cannot be applied frame-wise).
300
+ - They are generally not CPU-intensive, mostly involving moving and copying data.
301
+ - They do not require decoded images.
302
+
303
+ Args:
304
+ dataset (dl.DLataset): The dataset to transform.
305
+ train (bool): Whether the dataset is for training (affects subsampling).
306
+ goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for
307
+ no goal relabeling. See `goal_relabeling.py`.
308
+ goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function.
309
+ window_size (int, optional): The length of the snippets that trajectories are chunked into.
310
+ future_action_window_size (int, optional): The number of future actions beyond window_size to include
311
+ in the chunked actions.
312
+ subsample_length (int, optional): If provided, trajectories longer than this will be subsampled to
313
+ this length (after goal relabeling and chunking).
314
+ skip_unlabeled (bool, optional): Whether to skip trajectories with no language labels.
315
+ max_action: (float, optional): If provided, trajectories in which *any* action dimension
316
+ of *any* transition has an absolute value larger than this will be skipped.
317
+ max_proprio: (float, optional): If provided, trajectories in which *any* proprio dimension
318
+ of *any* transition has an absolute value larger than this will be skipped.
319
+ task_augment_strategy (str, optional): The task augmentation strategy to use, or None for no task
320
+ augmentation. See `task_augmentation.py`.
321
+ task_augment_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation
322
+ function.
323
+ num_parallel_calls (int, optional): number of parallel calls for map operations. Default to AUTOTUNE.
+ use_predict_future_prop (bool, optional): If True, chunks future proprio observations alongside actions
+ (uses `chunk_act_future_obs` instead of `chunk_act_obs`).
324
+ """
325
+ if skip_unlabeled:
326
+ if "language_instruction" not in dataset.element_spec["task"]:
327
+ raise ValueError("skip_unlabeled=True but dataset does not have language labels.")
328
+
329
+ dataset = dataset.filter(lambda x: tf.math.reduce_any(x["task"]["language_instruction"] != ""))
330
+
331
+ if max_action is not None:
332
+ dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["action"]) <= max_action))
333
+
334
+ if max_proprio is not None and "proprio" in dataset.element_spec["observation"]:
335
+ dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["observation"]["proprio"]) <= max_proprio))
336
+
337
+ # Filter out trajectories that are too short for action chunking
338
+ # Required minimum length: window_size + future_action_window_size
339
+ # required_min_length = window_size + future_action_window_size
340
+ # if required_min_length > 1:
341
+ # overwatch.info(f"Filtering trajectories shorter than {required_min_length} steps for action chunking (window_size={window_size}, future_action_window_size={future_action_window_size})")
342
+
343
+ # # Quick statistics: sample a subset of data to estimate filtering ratio
344
+ # try:
345
+ # sample_size = 1000 # Number of samples
346
+ # before_sample = dataset.take(sample_size)
347
+
348
+ # # Count total and valid trajectories in the sample
349
+ # total_sampled = 0
350
+ # valid_sampled = 0
351
+
352
+ # for item in before_sample:
353
+ # total_sampled += 1
354
+ # traj_length = tf.shape(item["action"])[0].numpy()
355
+ # if traj_length >= required_min_length:
356
+ # valid_sampled += 1
357
+
358
+ # if total_sampled > 0:
359
+ # filter_ratio = valid_sampled / total_sampled
360
+ # filtered_ratio = (total_sampled - valid_sampled) / total_sampled
361
+ # overwatch.info(f"Sample statistics ({sample_size} trajectories): keep rate {filter_ratio:.2%}, filter rate {filtered_ratio:.2%}")
362
+ # overwatch.info(f"Estimated ~{filtered_ratio:.1%} of trajectories will be filtered due to insufficient length")
363
+ # else:
364
+ # overwatch.info("Unable to obtain sample data for statistics")
365
+
366
+ # except Exception as e:
367
+ # overwatch.warning(f"Error during quick statistics: {e}, continuing with filtering operation")
368
+
369
+ # Execute the actual filtering operation
370
+ # dataset = dataset.filter(lambda x: tf.shape(x["action"])[0] >= required_min_length)
371
+ # overwatch.info("Trajectory length filtering completed")
372
+ # marks which entries of the observation and task dicts are padding
373
+ dataset = dataset.traj_map(traj_transforms.add_pad_mask_dict, num_parallel_calls)
374
+
375
+ # updates the "task" dict
376
+ if goal_relabeling_strategy is not None:
377
+ dataset = dataset.traj_map(
378
+ partial(getattr(goal_relabeling, goal_relabeling_strategy), **goal_relabeling_kwargs),
379
+ num_parallel_calls,
380
+ )
381
+
382
+ # must run task augmentation before chunking, in case it changes goal timesteps
383
+ if train and task_augment_strategy is not None:
384
+ # perform task augmentation (e.g., dropping keys)
385
+ dataset = dataset.traj_map(
386
+ partial(
387
+ getattr(task_augmentation, task_augment_strategy),
388
+ **task_augment_kwargs,
389
+ ),
390
+ num_parallel_calls,
391
+ )
392
+
393
+ # chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
394
+ # `window_size + future_action_window_size`, respectively
395
+ if use_predict_future_prop:
396
+ traj_transforms_strategy = traj_transforms.chunk_act_future_obs
397
+ else:
398
+ traj_transforms_strategy = traj_transforms.chunk_act_obs
399
+
400
+ dataset = dataset.traj_map(
401
+ partial(
402
+ traj_transforms_strategy,
403
+ window_size=window_size,
404
+ future_action_window_size=future_action_window_size,
405
+ ),
406
+ num_parallel_calls,
407
+ )
408
+
409
+ if train and subsample_length is not None:
410
+ dataset = dataset.traj_map(
411
+ partial(traj_transforms.subsample, subsample_length=subsample_length),
412
+ num_parallel_calls,
413
+ )
414
+
415
+ return dataset
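+
+ # NOTE :: Illustrative sketch (kept as a comment, not executed); `traj_ds` stands for a trajectory-level dataset
+ # returned by `make_dataset_from_rlds`. With window_size=1 and future_action_window_size=K-1, each resulting
+ # frame carries an action chunk of shape [1 + (K - 1), action_dim] = [K, action_dim].
+ #
+ # traj_ds = apply_trajectory_transforms(
+ #     traj_ds,
+ #     train=True,
+ #     window_size=1,
+ #     future_action_window_size=NUM_ACTIONS_CHUNK - 1,   # chunk NUM_ACTIONS_CHUNK actions per frame
+ #     skip_unlabeled=True,                                # assumes `language_instruction` exists in the task dict
+ # )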
416
+
417
+
418
+ def apply_per_dataset_frame_transforms(
419
+ dataset: dl.DLataset,
420
+ chunk_filter_fn: Optional[Callable] = None,
421
+ ):
422
+ """
423
+ Optionally applied *per-dataset* transforms that happen at a frame level.
424
+
425
+ Args:
426
+ chunk_filter_fn (callable, optional): Filter function for chunks.
427
+ """
428
+ if chunk_filter_fn:
429
+ dataset = dataset.filter(chunk_filter_fn)
430
+ return dataset
431
+
432
+
433
+ def apply_frame_transforms(
434
+ dataset: dl.DLataset,
435
+ *,
436
+ train: bool,
437
+ image_augment_kwargs: Union[Dict, Dict[str, Dict]] = {},
438
+ resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
439
+ depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
440
+ num_parallel_calls: int = tf.data.AUTOTUNE,
441
+ ) -> dl.DLataset:
442
+ """
443
+ Applies common transforms that happen at a frame level. These transforms are usually more CPU-intensive, (e.g.,
444
+ decoding or resizing images).
445
+
446
+ Args:
447
+ train (bool): Whether the dataset is for training (affects image augmentation).
448
+ dataset (dl.DLataset): The dataset to transform.
449
+ image_augment_kwargs (dict|Mapping[str, dict]): Keyword arguments to pass to the image augmentation
450
+ function. See `dlimp.transforms.augment_image` for documentation of these kwargs. If a dict of
451
+ dicts is provided, then key "k" will be used for "image_{k}" (names determined by `image_obs_keys`
452
+ in `make_dataset_from_rlds`). Augmentation will be skipped for missing keys (so pass an empty dict
453
+ to skip augmentation for all images).
454
+ resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): If provided, images will be resized to
455
+ this size. If a dict of tuples is provided, then key "k" will be used for "image_{k}" (names
456
+ determined by `image_obs_keys` in `make_dataset_from_rlds`). Resizing will be skipped for missing
457
+ keys (so pass an empty dict to skip resizing for all images).
458
+ depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth
459
+ images.
460
+ num_parallel_calls (int): number of parallel calls for frame_map operations. Default to AUTOTUNE.
461
+ """
462
+
463
+ # Convenience wrapper that takes a function that operates on a non-chunked "observation" dict and applies
464
+ # it to the chunked "observation" dict as well as the non-chunked "task" dict
465
+ def apply_obs_transform(fn: Callable[[Dict], Dict], frame: Dict) -> Dict:
466
+ frame["task"] = fn(frame["task"])
467
+ frame["observation"] = dl.vmap(fn)(frame["observation"])
468
+ return frame
469
+
470
+ # Decode + resize images (and depth images)
471
+ dataset = dataset.frame_map(
472
+ partial(
473
+ apply_obs_transform,
474
+ partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size),
475
+ ),
476
+ num_parallel_calls,
477
+ )
478
+
479
+ if train:
480
+ # Augment all images with the same seed, skipping padding images
481
+ def aug(frame: dict):
482
+ seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32)
483
+ aug_fn = partial(obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs)
484
+ return apply_obs_transform(aug_fn, frame)
485
+
486
+ dataset = dataset.frame_map(aug, num_parallel_calls)
487
+
488
+ return dataset
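+
+ # NOTE :: Illustrative sketch (kept as a comment, not executed); the per-camera keys ("primary", "wrist") are
+ # assumptions and must match the `image_obs_keys` passed to `make_dataset_from_rlds`. See
+ # `dlimp.transforms.augment_image` for the supported augmentation kwargs.
+ #
+ # frame_ds = apply_frame_transforms(
+ #     frame_ds,
+ #     train=True,
+ #     resize_size={"primary": (224, 224), "wrist": (224, 224)},
+ #     image_augment_kwargs={"primary": {"random_brightness": [0.2], "augment_order": ["random_brightness"]}},
+ # )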
489
+
490
+
491
+ def make_single_dataset(
492
+ dataset_kwargs: dict,
493
+ *,
494
+ train: bool,
495
+ traj_transform_kwargs: dict = {},
496
+ frame_transform_kwargs: dict = {},
497
+ ) -> Tuple[dl.DLataset, int, dict]:
498
+ """Creates a single dataset from kwargs. Returns a dataset of trajectories.
499
+
500
+ Args:
501
+ dataset_kwargs: kwargs passed to `make_dataset_from_rlds` that are dataset-specific.
502
+ train: whether this is a training or validation dataset.
503
+ traj_transform_kwargs: kwargs passed to 'apply_trajectory_transforms'.
504
+ frame_transform_kwargs: kwargs passed to 'apply_frame_transforms'.
505
+ """
506
+ dataset, dataset_statistics = make_dataset_from_rlds(
507
+ **dataset_kwargs,
508
+ train=train,
509
+ )
510
+ dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train)
511
+ dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
512
+
513
+ # this seems to reduce memory usage without affecting speed
514
+ dataset = dataset.with_ram_budget(1)
515
+
516
+ # return the dataset along with its trajectory count and statistics for later use
517
+ return dataset, dataset_statistics["num_trajectories"], dataset_statistics
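+
+ # NOTE :: Illustrative sketch (kept as a comment, not executed); `single_dataset_kwargs` is a hypothetical dict
+ # carrying the dataset-specific arguments of `make_dataset_from_rlds` (name, data_dir, shuffle_seed,
+ # action_proprio_normalization_type, ...).
+ #
+ # ds, n_trajectories, stats = make_single_dataset(
+ #     single_dataset_kwargs,
+ #     train=True,
+ #     traj_transform_kwargs={"window_size": 1},
+ #     frame_transform_kwargs={"resize_size": (224, 224)},
+ # )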
518
+
519
+
520
+ # === Core Initializer ===
521
+ def make_interleaved_dataset(
522
+ dataset_kwargs_list: List[Dict],
523
+ sample_weights: Optional[List[float]] = None,
524
+ *,
525
+ train: bool,
526
+ shuffle_buffer_size: int,
527
+ shuffle_seed: int,
528
+ traj_transform_kwargs: Optional[Dict] = None,
529
+ frame_transform_kwargs: Optional[Dict] = None,
530
+ batch_size: Optional[int] = None,
531
+ balance_weights: bool = False,
532
+ traj_transform_threads: Optional[int] = None,
533
+ traj_read_threads: Optional[int] = None,
534
+ ) -> Tuple[dl.DLataset, int, dict]:
535
+ """
536
+ Creates an interleaved dataset from list of dataset configs (kwargs). Returns a dataset of batched frames.
537
+
538
+ Args:
539
+ dataset_kwargs_list: list of kwargs, each element of which is passed to `make_dataset_from_rlds`.
540
+ "num_parallel_calls" and "num_parallel_reads" are overridden using `traj_transform_threads` and
541
+ `traj_read_threads`, respectively.
542
+ sample_weights: sampling weights for each dataset in list. If None, defaults to uniform.
543
+ train: whether this is a training or validation dataset.
544
+ shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames).
+ shuffle_seed: random seed used for dataset interleaving and frame-level shuffling, for reproducibility.
545
+ traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is
546
+ overridden using `traj_transform_threads`.
547
+ frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
548
+ batch_size: batch size, if not provided output is not batched.
549
+ balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset.
550
+ This makes it so that, if all the sample weights are equal, one full iteration through the interleaved
551
+ dataset will correspond to one full iteration through each individual dataset (only in expectation,
552
+ since in practice the sampling is random).
553
+ traj_transform_threads: total number of parallel calls for trajectory transforms, distributed across
554
+ datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
555
+ traj_read_threads: total number of parallel read workers for trajectory transforms, distributed across
556
+ datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
557
+ """
558
+ # Default to uniform sampling (if `sample_weights` is not specified)
559
+
560
+ if not sample_weights:
561
+ sample_weights = [1.0] * len(dataset_kwargs_list)
562
+
563
+ if len(sample_weights) != len(dataset_kwargs_list):
564
+ raise ValueError(f"sample_weights must be None or have length {len(dataset_kwargs_list)}.")
565
+
566
+ # Check valid `traj_transform_kwargs` and `frame_transform_kwargs`
567
+ if (traj_transform_kwargs is None) or (frame_transform_kwargs is None):
568
+ raise ValueError("Missing `traj_transform_kwargs` and `frame_transform_kwargs`!")
569
+
570
+ # Get Dataset Sizes
571
+ dataset_sizes, all_dataset_statistics = [], {}
572
+ for dataset_kwargs in dataset_kwargs_list:
573
+ data_kwargs = copy.deepcopy(dataset_kwargs)
574
+ if "dataset_frame_transform_kwargs" in data_kwargs:
575
+ data_kwargs.pop("dataset_frame_transform_kwargs")
576
+ _, dataset_statistics = make_dataset_from_rlds(**data_kwargs, train=train, shuffle_seed=shuffle_seed)
577
+ dataset_sizes.append(dataset_statistics["num_transitions"])
578
+ all_dataset_statistics[dataset_kwargs["name"]] = dataset_statistics
579
+
580
+ # Get the indices of the "primary" datasets (i.e., datasets with sample_weight == 1.0)
581
+ primary_dataset_indices = np.array([idx for idx in range(len(sample_weights)) if sample_weights[idx] == 1.0])
582
+
583
+ # Balance and Normalize Weights
584
+ if balance_weights:
585
+ sample_weights = np.array(sample_weights) * np.array(dataset_sizes)
586
+ sample_weights = np.array(sample_weights) / np.sum(sample_weights)
587
+ pprint_data_mixture(dataset_kwargs_list, sample_weights)
588
+
589
+ # Effective Dataset Length = Number of samples until each dataset has completed at least one epoch
590
+ # =>> Note :: Only counting the "primary" datasets (i.e., datasets with sample_weight == 1.0)
591
+ dataset_len = int((np.array(dataset_sizes) / sample_weights)[primary_dataset_indices].max())
592
+
593
+ # Allocate Threads based on Weights
594
+ threads_per_dataset = allocate_threads(traj_transform_threads, sample_weights)
595
+ reads_per_dataset = allocate_threads(traj_read_threads, sample_weights)
596
+
597
+ overwatch.info("Threads per Dataset: %s", threads_per_dataset)
598
+ overwatch.info("Reads per Dataset: %s", reads_per_dataset)
599
+
600
+ # Construct Datasets
601
+ overwatch.info("Constructing datasets...")
602
+ datasets = []
603
+ for dataset_kwargs, threads, reads in zip(
604
+ dataset_kwargs_list,
605
+ threads_per_dataset,
606
+ reads_per_dataset,
607
+ ):
608
+ dataset_frame_transform_kwargs = (
609
+ dataset_kwargs.pop("dataset_frame_transform_kwargs")
610
+ if "dataset_frame_transform_kwargs" in dataset_kwargs
611
+ else {}
612
+ )
613
+ dataset, _ = make_dataset_from_rlds(
614
+ **dataset_kwargs,
615
+ train=train,
616
+ shuffle_seed=shuffle_seed,
617
+ num_parallel_calls=threads,
618
+ num_parallel_reads=reads,
619
+ dataset_statistics=all_dataset_statistics[dataset_kwargs["name"]],
620
+ )
621
+ dataset = apply_trajectory_transforms(
622
+ dataset.repeat(),
623
+ **traj_transform_kwargs,
624
+ num_parallel_calls=threads,
625
+ train=train,
626
+ ).flatten(num_parallel_calls=threads)
627
+ dataset = apply_per_dataset_frame_transforms(dataset, **dataset_frame_transform_kwargs)
628
+ datasets.append(dataset)
629
+
630
+ # Interleave at the Frame Level
631
+ dataset: dl.DLataset = dl.DLataset.sample_from_datasets(datasets, sample_weights, seed=shuffle_seed)
632
+
633
+ # Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase!
634
+ if not train:
635
+ dataset = dataset.take(shuffle_buffer_size).cache()
636
+
637
+ # Shuffle the Dataset
638
+ # =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak!
639
+ dataset = dataset.shuffle(shuffle_buffer_size, seed=shuffle_seed)
640
+
641
+ # Apply Frame Transforms
642
+ overwatch.info("Applying frame transforms on dataset...")
643
+ dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
644
+
645
+ # [Contract] When training VLA Policies, we let the Collator handle Batching!
646
+ if batch_size is not None:
647
+ dataset = dataset.batch(batch_size)
648
+
649
+ # Note =>> Seems to reduce memory usage without affecting speed?
650
+ dataset = dataset.with_ram_budget(1)
651
+
652
+ # Save for Later
653
+ dataset.sample_weights = sample_weights
654
+
655
+ return dataset, dataset_len, all_dataset_statistics
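+
+ # NOTE :: Illustrative sketch (kept as a comment, not executed); `dataset_a_kwargs` / `dataset_b_kwargs` are
+ # hypothetical dicts mirroring the arguments of `make_dataset_from_rlds` (each must include a "name" key). At
+ # least one sample weight should equal 1.0 so that the effective dataset length can be computed.
+ #
+ # dataset, dataset_len, all_stats = make_interleaved_dataset(
+ #     [dataset_a_kwargs, dataset_b_kwargs],
+ #     sample_weights=[1.0, 0.5],
+ #     train=True,
+ #     shuffle_buffer_size=100_000,
+ #     shuffle_seed=42,
+ #     traj_transform_kwargs={"window_size": 1},
+ #     frame_transform_kwargs={"resize_size": {"primary": (224, 224)}},
+ #     balance_weights=True,
+ # )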
prismatic/vla/datasets/rlds/oxe/transforms.py ADDED
@@ -0,0 +1,951 @@
1
+ """
2
+ transforms.py
3
+
4
+ Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
5
+
6
+ Transforms adopt the following structure:
7
+ Input: Dictionary of *batched* features (i.e., has leading time dimension)
8
+ Output: Dictionary `step` =>> {
9
+ "observation": {
10
+ <image_keys, depth_image_keys>
11
+ State (in chosen state representation)
12
+ },
13
+ "action": Action (in chosen action representation),
14
+ "language_instruction": str
15
+ }
16
+ """
17
+
18
+ from typing import Any, Dict
19
+
20
+ import tensorflow as tf
21
+
22
+ from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import droid_baseact_transform, droid_finetuning_transform
23
+ from prismatic.vla.datasets.rlds.utils.data_utils import (
24
+ binarize_gripper_actions,
25
+ invert_gripper_actions,
26
+ rel2abs_gripper_actions,
27
+ relabel_bridge_actions,
28
+ )
29
+
30
+
31
+ def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
32
+ """
33
+ Applies to version of Bridge V2 in Open X-Embodiment mixture.
34
+
35
+ Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
36
+ """
37
+ for key in trajectory.keys():
38
+ if key == "traj_metadata":
39
+ continue
40
+ elif key in ["observation", "action"]:
41
+ for key2 in trajectory[key]:
42
+ trajectory[key][key2] = trajectory[key][key2][1:]
43
+ else:
44
+ trajectory[key] = trajectory[key][1:]
45
+
46
+ trajectory["action"] = tf.concat(
47
+ (
48
+ trajectory["action"]["world_vector"],
49
+ trajectory["action"]["rotation_delta"],
50
+ tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
51
+ ),
52
+ axis=-1,
53
+ )
54
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
55
+ trajectory = relabel_bridge_actions(trajectory)
56
+ trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
57
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
58
+ return trajectory
59
+
60
+
61
+ def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
62
+ """
63
+ Applies to original version of Bridge V2 from the official project website.
64
+
65
+ Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
66
+ """
67
+ for key in trajectory.keys():
68
+ if key == "traj_metadata":
69
+ continue
70
+ elif key == "observation":
71
+ for key2 in trajectory[key]:
72
+ trajectory[key][key2] = trajectory[key][key2][1:]
73
+ else:
74
+ trajectory[key] = trajectory[key][1:]
75
+
76
+ trajectory["action"] = tf.concat(
77
+ [
78
+ trajectory["action"][:, :6],
79
+ binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
80
+ ],
81
+ axis=1,
82
+ )
83
+ trajectory = relabel_bridge_actions(trajectory)
84
+ trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
85
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
86
+ return trajectory
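+
+ # NOTE :: Illustrative sketch (kept as a comment, not executed) of how a standardization transform like the one
+ # above is consumed: it is passed as `standardize_fn` to `make_dataset_from_rlds` (defined in
+ # `prismatic.vla.datasets.rlds.dataset`); the dataset name and data_dir below are placeholders.
+ #
+ # ds, stats = make_dataset_from_rlds(
+ #     "bridge_orig",                       # hypothetical dataset name
+ #     "/path/to/rlds/data",
+ #     train=True,
+ #     shuffle_seed=0,
+ #     standardize_fn=bridge_orig_dataset_transform,
+ #     image_obs_keys={"primary": "image_0"},
+ #     language_key="language_instruction",
+ #     action_proprio_normalization_type=ACTION_PROPRIO_NORMALIZATION_TYPE,
+ # )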
87
+
88
+
89
+ def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
90
+ trajectory["action"] = tf.concat(
91
+ [
92
+ trajectory["action"][:, :6],
93
+ binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
94
+ ],
95
+ axis=1,
96
+ )
97
+ trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
98
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
99
+ return trajectory
100
+
101
+
102
+ def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
103
+ # make gripper action absolute action, +1 = open, 0 = close
104
+ gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
105
+ gripper_action = rel2abs_gripper_actions(gripper_action)
106
+
107
+ trajectory["action"] = tf.concat(
108
+ (
109
+ trajectory["action"]["world_vector"],
110
+ trajectory["action"]["rotation_delta"],
111
+ gripper_action[:, None],
112
+ ),
113
+ axis=-1,
114
+ )
115
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
116
+ return trajectory
117
+
118
+
119
+ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
120
+ # make gripper action absolute action, +1 = open, 0 = close
121
+ gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
122
+ gripper_action = rel2abs_gripper_actions(gripper_action)
123
+
124
+ trajectory["action"] = tf.concat(
125
+ (
126
+ trajectory["action"]["world_vector"],
127
+ trajectory["action"]["rotation_delta"],
128
+ gripper_action[:, None],
129
+ ),
130
+ axis=-1,
131
+ )
132
+ # decode compressed state
133
+ eef_value = tf.io.decode_compressed(
134
+ trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
135
+ compression_type="ZLIB",
136
+ )
137
+ eef_value = tf.io.decode_raw(eef_value, tf.float32)
138
+ trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
139
+ gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
140
+ gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
141
+ trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
142
+ # trajectory["language_instruction"] = tf.fill(
143
+ # tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
144
+ # ) # delete uninformative language instruction
145
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
146
+ return trajectory
147
+
148
+
149
+ def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
150
+ trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
151
+ trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
152
+ trajectory["action"] = trajectory["action"]["rel_actions_world"]
153
+
154
+ # invert gripper action + clip, +1 = open, 0 = close
155
+ trajectory["action"] = tf.concat(
156
+ (
157
+ trajectory["action"][:, :6],
158
+ tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
159
+ ),
160
+ axis=-1,
161
+ )
162
+
163
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
164
+ return trajectory
165
+
166
+
167
+ def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
168
+ trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
169
+ trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:]
170
+
171
+ # make gripper action absolute action, +1 = open, 0 = close
172
+ gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
173
+ gripper_action = rel2abs_gripper_actions(gripper_action)
174
+
175
+ trajectory["action"] = tf.concat(
176
+ (
177
+ trajectory["action"]["world_vector"],
178
+ tf.zeros_like(trajectory["action"]["world_vector"]),
179
+ gripper_action[:, None],
180
+ ),
181
+ axis=-1,
182
+ )
183
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
184
+ return trajectory
185
+
186
+
187
+ def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
188
+ trajectory["action"] = tf.concat(
189
+ (
190
+ trajectory["action"]["world_vector"],
191
+ trajectory["action"]["rotation_delta"],
192
+ tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
193
+ ),
194
+ axis=-1,
195
+ )
196
+ # trajectory["language_instruction"] = tf.fill(
197
+ # tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
198
+ # ) # delete uninformative language instruction
199
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
200
+ return trajectory
201
+
202
+
203
+ def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
204
+ # invert absolute gripper action, +1 = open, 0 = close
205
+ gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1))
206
+
207
+ trajectory["action"] = tf.concat(
208
+ (
209
+ trajectory["action"]["world_vector"],
210
+ trajectory["action"]["rotation_delta"],
211
+ gripper_action,
212
+ ),
213
+ axis=-1,
214
+ )
215
+ # trajectory["language_instruction"] = tf.fill(
216
+ # tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
217
+ # ) # delete uninformative language instruction
218
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
219
+ return trajectory
220
+
221
+
222
+ def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
223
+ # make gripper action absolute action, +1 = open, 0 = close
224
+ gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
225
+ gripper_action = rel2abs_gripper_actions(gripper_action)
226
+
227
+ trajectory["action"] = tf.concat(
228
+ (
229
+ trajectory["action"]["world_vector"],
230
+ trajectory["action"]["rotation_delta"],
231
+ gripper_action[:, None],
232
+ ),
233
+ axis=-1,
234
+ )
235
+ # trajectory["language_instruction"] = tf.fill(
236
+ # tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
237
+ # ) # delete uninformative language instruction
238
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
239
+ return trajectory
240
+
241
+
242
+ def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
243
+ # make gripper action, +1 = open, 0 = close
244
+ gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
245
+ gripper_action = tf.clip_by_value(gripper_action, 0, 1)
246
+ gripper_action = invert_gripper_actions(gripper_action)
247
+
248
+ trajectory["action"] = tf.concat(
249
+ (
250
+ trajectory["action"]["world_vector"],
251
+ trajectory["action"]["rotation_delta"],
252
+ gripper_action,
253
+ ),
254
+ axis=-1,
255
+ )
256
+ # trajectory["language_instruction"] = tf.fill(
257
+ # tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
258
+ # ) # delete uninformative language instruction
259
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
260
+ return trajectory
261
+
262
+
263
+ def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
264
+ trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
265
+ trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth")
266
+
267
+ # make gripper action absolute action, +1 = open, 0 = close
268
+ gripper_action = trajectory["action"]["gripper_closedness_action"]
269
+ gripper_action = rel2abs_gripper_actions(gripper_action)
270
+
271
+ trajectory["action"] = tf.concat(
272
+ (
273
+ trajectory["action"]["world_vector"],
274
+ trajectory["action"]["rotation_delta"],
275
+ gripper_action[:, None],
276
+ ),
277
+ axis=-1,
278
+ )
279
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
280
+ return trajectory
281
+
282
+
283
+ def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
284
+ trajectory["action"] = tf.concat(
285
+ (
286
+ trajectory["action"]["world_vector"],
287
+ trajectory["action"]["rotation_delta"],
288
+ tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
289
+ ),
290
+ axis=-1,
291
+ )
292
+ # trajectory["language_instruction"] = tf.fill(
293
+ # tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
294
+ # ) # delete uninformative language instruction
295
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
296
+ return trajectory
297
+
298
+
299
+ def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
300
+ # default to "open" gripper
301
+ trajectory["action"] = tf.concat(
302
+ (
303
+ trajectory["action"],
304
+ tf.zeros_like(trajectory["action"]),
305
+ tf.zeros_like(trajectory["action"]),
306
+ tf.ones_like(trajectory["action"][:, :1]),
307
+ ),
308
+ axis=-1,
309
+ )
310
+
311
+ # decode language instruction
312
+ instruction_bytes = trajectory["observation"]["instruction"]
313
+ instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
314
+ # Remove trailing padding --> convert RaggedTensor to regular Tensor.
315
+ trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0]
316
+ return trajectory
317
+
318
+
319
+ def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
320
+ trajectory["action"] = tf.concat(
321
+ (
322
+ trajectory["action"]["world_vector"],
323
+ trajectory["action"]["rotation_delta"],
324
+ trajectory["action"]["gripper_closedness_action"][:, None],
325
+ ),
326
+ axis=-1,
327
+ )
328
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
329
+ return trajectory
330
+
331
+
332
+ def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
333
+ trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
334
+ trajectory["action"] = tf.concat(
335
+ (
336
+ trajectory["action"][:, :3],
337
+ tf.zeros_like(trajectory["action"][:, :3]),
338
+ trajectory["action"][:, -1:],
339
+ ),
340
+ axis=-1,
341
+ )
342
+ return trajectory
343
+
344
+
345
+ def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
346
+ trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
347
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
348
+ trajectory["action"] = trajectory["action"][..., :7]
349
+ return trajectory
350
+
351
+
352
+ def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
353
+ # invert gripper action, +1 = open, 0 = close
354
+ trajectory["action"] = tf.concat(
355
+ (
356
+ trajectory["action"][:, :6],
357
+ invert_gripper_actions(trajectory["action"][:, -1:]),
358
+ ),
359
+ axis=-1,
360
+ )
361
+
362
+ trajectory["observation"]["eef_state"] = tf.concat(
363
+ (
364
+ trajectory["observation"]["state"][:, :3],
365
+ trajectory["observation"]["state"][:, 7:10],
366
+ ),
367
+ axis=-1,
368
+ )
369
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
370
+ # trajectory["language_instruction"] = tf.fill(
371
+ # tf.shape(trajectory["language_instruction"]), ""
372
+ # ) # delete uninformative language instruction
373
+ return trajectory
374
+
375
+
376
+ def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
377
+ # invert gripper action + clip, +1 = open, 0 = close
378
+ trajectory["action"] = tf.concat(
379
+ (
380
+ trajectory["action"][:, :6],
381
+ invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
382
+ ),
383
+ axis=-1,
384
+ )
385
+
386
+ trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
387
+ # trajectory["language_instruction"] = tf.fill(
388
+ # tf.shape(trajectory["language_instruction"]), ""
389
+ # ) # delete uninformative language instruction
390
+ return trajectory
391
+
392
+
393
+ def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
394
+ trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
395
+ trajectory["observation"]["depth_additional_view"] = tf.cast(
396
+ trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
397
+ )
398
+ trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
399
+
400
+ # clip gripper action, +1 = open, 0 = close
401
+ trajectory["action"] = tf.concat(
402
+ (
403
+ trajectory["action"][:, -8:-2],
404
+ tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
405
+ ),
406
+ axis=-1,
407
+ )
408
+
409
+ # trajectory["language_instruction"] = tf.fill(
410
+ # tf.shape(trajectory["language_instruction"]), ""
411
+ # ) # delete uninformative language instruction
412
+ return trajectory
413
+
414
+
415
+ def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
416
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
417
+ return trajectory
418
+
419
+
420
+ def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
421
+ import tensorflow_graphics.geometry.transformation as tft
422
+
423
+ trajectory["observation"]["state"] = tf.concat(
424
+ (
425
+ trajectory["observation"]["state"][:, :7],
426
+ trajectory["observation"]["state"][:, -1:],
427
+ ),
428
+ axis=-1,
429
+ )
430
+
431
+ # invert gripper action + clip, +1 = open, 0 = close
432
+ trajectory["action"] = tf.concat(
433
+ (
434
+ trajectory["action"][:, :3],
435
+ tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
436
+ invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
437
+ ),
438
+ axis=-1,
439
+ )
440
+ return trajectory
441
+
442
+
443
+ def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
444
+ trajectory["action"] = trajectory["action"][..., :-1]
445
+ return trajectory
446
+
447
+
448
+ def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
449
+ trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
450
+ trajectory["action"] = trajectory["action"][..., :-1]
451
+ return trajectory
452
+
453
+
454
+ def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
455
+ trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
456
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
457
+ trajectory["action"] = tf.concat(
458
+ (
459
+ trajectory["action"][:, :3],
460
+ tf.zeros_like(trajectory["action"][:, :3]),
461
+ trajectory["action"][:, -1:],
462
+ ),
463
+ axis=-1,
464
+ )
465
+ return trajectory
466
+
467
+
468
+ def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
469
+ # invert gripper action + clip, +1 = open, 0 = close
470
+ trajectory["action"] = tf.concat(
471
+ (
472
+ trajectory["action"][:, :6],
473
+ invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
474
+ ),
475
+ axis=-1,
476
+ )
477
+
478
+ # trajectory["language_instruction"] = tf.fill(
479
+ # tf.shape(trajectory["language_instruction"]), ""
480
+ # ) # delete uninformative language instruction
481
+ return trajectory
482
+
483
+
484
+ def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
485
+ # invert gripper action + clip, +1 = open, 0 = close
486
+ trajectory["action"] = tf.concat(
487
+ (
488
+ trajectory["action"][:, :6],
489
+ invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
490
+ ),
491
+ axis=-1,
492
+ )
493
+
494
+ # trajectory["language_instruction"] = tf.fill(
495
+ # tf.shape(trajectory["language_instruction"]), ""
496
+ # ) # delete uninformative language instruction
497
+ return trajectory
498
+
499
+
500
+ def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
501
+ trajectory["action"] = tf.concat(
502
+ (
503
+ trajectory["action"]["future/xyz_residual"][:, :3],
504
+ trajectory["action"]["future/axis_angle_residual"][:, :3],
505
+ invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
506
+ ),
507
+ axis=-1,
508
+ )
509
+ trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
510
+ return trajectory
511
+
512
+
513
+ def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
514
+ trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
515
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
516
+ trajectory["action"] = trajectory["action"][..., :-1]
517
+ return trajectory
518
+
519
+
520
+ def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
521
+ trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
522
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
523
+ trajectory["action"] = trajectory["action"][..., :-1]
524
+ return trajectory
525
+
526
+
527
+ def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
528
+ return trajectory
529
+
530
+
531
+ def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
532
+ trajectory["action"] = trajectory["action"][..., -7:]
533
+ return trajectory
534
+
535
+
536
+ def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
537
+ trajectory["observation"]["eef_state"] = tf.concat(
538
+ (
539
+ trajectory["observation"]["state"][:, :4],
540
+ tf.zeros_like(trajectory["observation"]["state"][:, :2]),
541
+ ),
542
+ axis=-1,
543
+ )
544
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
545
+ trajectory["action"] = tf.concat(
546
+ (
547
+ trajectory["action"][:, :4],
548
+ tf.zeros_like(trajectory["action"][:, :2]),
549
+ trajectory["action"][:, -1:],
550
+ ),
551
+ axis=-1,
552
+ )
553
+ return trajectory
554
+
555
+
556
+ def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
557
+ return trajectory
558
+
559
+
560
+ def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
561
+ return trajectory
562
+
563
+
564
+ def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
565
+ trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
566
+ trajectory["action"] = tf.concat(
567
+ (
568
+ trajectory["action"][:, :6],
569
+ tf.zeros_like(trajectory["action"][:, :1]),
570
+ ),
571
+ axis=-1,
572
+ )
573
+ return trajectory
574
+
575
+
576
+ def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
577
+ trajectory["observation"]["eef_state"] = tf.concat(
578
+ (
579
+ trajectory["observation"]["end_effector_pose"][:, :4],
580
+ tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
581
+ ),
582
+ axis=-1,
583
+ )
584
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
585
+ trajectory["action"] = tf.concat(
586
+ (
587
+ trajectory["action"][:, :4],
588
+ tf.zeros_like(trajectory["action"][:, :2]),
589
+ trajectory["action"][:, -1:],
590
+ ),
591
+ axis=-1,
592
+ )
593
+ return trajectory
594
+
595
+
596
+ def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
597
+ trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
598
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
599
+ return trajectory
600
+
601
+
602
+ def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
603
+ return trajectory
604
+
605
+
606
+ def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
607
+ trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
608
+ return trajectory
609
+
610
+
611
+ def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
612
+ # invert gripper action, +1 = open, 0 = close
613
+ trajectory["action"] = tf.concat(
614
+ (
615
+ trajectory["action"][:, :6],
616
+ invert_gripper_actions(trajectory["action"][:, -1:]),
617
+ ),
618
+ axis=-1,
619
+ )
620
+ return trajectory
621
+
622
+
623
+ def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
624
+ trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
625
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
626
+ return trajectory
627
+
628
+
629
+ def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
630
+ trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
631
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
632
+ return trajectory
633
+
634
+
635
+ def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
636
+ trajectory["action"] = trajectory["action"][..., :-1]
637
+ return trajectory
638
+
639
+
640
+ def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
641
+ import tensorflow_graphics.geometry.transformation as tft
642
+
643
+ trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
644
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
645
+ trajectory["action"] = tf.concat(
646
+ (
647
+ trajectory["action"][:, :3],
648
+ tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
649
+ trajectory["action"][:, 7:8],
650
+ ),
651
+ axis=-1,
652
+ )
653
+ return trajectory
654
+
655
+
656
+ def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
657
+ trajectory["action"] = tf.concat(
658
+ (
659
+ trajectory["action"],
660
+ tf.zeros_like(trajectory["action"]),
661
+ tf.zeros_like(trajectory["action"][:, :1]),
662
+ ),
663
+ axis=-1,
664
+ )
665
+ return trajectory
666
+
667
+
668
+ def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
669
+ trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
670
+
671
+ # invert gripper action + clip, +1 = open, 0 = close
672
+ trajectory["action"] = tf.concat(
673
+ (
674
+ trajectory["action"][:, :6],
675
+ invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
676
+ ),
677
+ axis=-1,
678
+ )
679
+
680
+ # trajectory["language_instruction"] = tf.fill(
681
+ # tf.shape(trajectory["language_instruction"]), ""
682
+ # ) # delete uninformative language instruction
683
+ return trajectory
684
+
685
+
686
+ def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
687
+ trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
688
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
689
+
690
+ # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
691
+ trajectory["action"] = tf.concat(
692
+ (
693
+ trajectory["action"],
694
+ invert_gripper_actions(trajectory["observation"]["gripper_state"]),
695
+ ),
696
+ axis=-1,
697
+ )
698
+ return trajectory
699
+
700
+
701
+ def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
702
+ import tensorflow_graphics.geometry.transformation as tft
703
+
704
+ trajectory["action"] = tf.concat(
705
+ (
706
+ trajectory["action"][:, :3],
707
+ tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
708
+ trajectory["action"][:, -1:],
709
+ ),
710
+ axis=-1,
711
+ )
712
+ return trajectory
713
+
714
+
715
+ def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
716
+ trajectory["action"] = tf.concat(
717
+ (
718
+ trajectory["action"][:, :3],
719
+ trajectory["action"][:, -4:],
720
+ ),
721
+ axis=-1,
722
+ )
723
+ return trajectory
724
+
725
+
726
+ def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
727
+ trajectory["observation"]["eef_state"] = tf.concat(
728
+ (
729
+ trajectory["observation"]["state"][:, :3],
730
+ tf.zeros_like(trajectory["observation"]["state"][:, :3]),
731
+ ),
732
+ axis=-1,
733
+ )
734
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
735
+ trajectory["action"] = trajectory["action"][..., :-1]
736
+ return trajectory
737
+
738
+
739
+ def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
740
+ trajectory["observation"]["state"] = tf.concat(
741
+ (
742
+ trajectory["observation"]["position"],
743
+ tf.zeros_like(trajectory["observation"]["state"][:, :3]),
744
+ trajectory["observation"]["yaw"],
745
+ ),
746
+ axis=-1,
747
+ )
748
+ trajectory["action"] = tf.concat(
749
+ (
750
+ trajectory["action"],
751
+ tf.zeros_like(trajectory["action"]),
752
+ tf.zeros_like(trajectory["action"]),
753
+ tf.zeros_like(trajectory["action"][:, :1]),
754
+ ),
755
+ axis=-1,
756
+ )
757
+ return trajectory
758
+
759
+
760
+ def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
761
+ # every input feature is batched, i.e., has a leading batch dimension
762
+ trajectory["observation"]["proprio"] = tf.concat(
763
+ (
764
+ trajectory["observation"]["eef_pose"],
765
+ trajectory["observation"]["state_gripper_pose"][..., None],
766
+ ),
767
+ axis=-1,
768
+ )
769
+ return trajectory
770
+
771
+
772
+ def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
773
+ # every input feature is batched, i.e., has a leading batch dimension
774
+ trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
775
+ return trajectory
776
+
777
+
778
+ def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
779
+ # every input feature is batched, i.e., has a leading batch dimension
780
+ trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
781
+
782
+ # gripper action is in -1...1 --> clip to 0...1, flip
783
+ gripper_action = trajectory["action"][:, -1:]
784
+ gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
785
+
786
+ trajectory["action"] = tf.concat(
787
+ (
788
+ trajectory["action"][:, :7],
789
+ gripper_action,
790
+ ),
791
+ axis=-1,
792
+ )
793
+ return trajectory
794
+
795
+
796
+ def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
797
+ trajectory["action"] = tf.concat(
798
+ (
799
+ trajectory["action"]["tcp_base"],
800
+ tf.cast(trajectory["action"]["gripper"][:, None], tf.float32),
801
+ ),
802
+ axis=-1,
803
+ )
804
+ trajectory["observation"]["proprio"] = tf.concat(
805
+ (
806
+ trajectory["observation"]["tcp_base"],
807
+ trajectory["observation"]["gripper_width"][..., None],
808
+ ),
809
+ axis=-1,
810
+ )
811
+ return trajectory
812
+
813
+
814
+ def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
815
+ trajectory["action"] = tf.concat(
816
+ [
817
+ trajectory["action"][:, :6],
818
+ binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
819
+ ],
820
+ axis=1,
821
+ )
822
+ trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
823
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
824
+ return trajectory
825
+
826
+
827
+ def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
828
+ # gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close
829
+ gripper_action = trajectory["action"][:, -1:]
830
+ gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
831
+
832
+ trajectory["action"] = tf.concat(
833
+ [
834
+ trajectory["action"][:, :6],
835
+ gripper_action,
836
+ ],
837
+ axis=1,
838
+ )
839
+ trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
840
+ trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state
841
+ return trajectory
842
+
843
+
844
+ def aloha_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
845
+ # Don't need to do anything because dataset is already in the correct format
846
+ return trajectory
847
+
848
+
849
+ # === Registry ===
850
+ OXE_STANDARDIZATION_TRANSFORMS = {
851
+ "bridge_oxe": bridge_oxe_dataset_transform,
852
+ "bridge_orig": bridge_orig_dataset_transform,
853
+ "bridge_dataset": bridge_orig_dataset_transform,
854
+ "ppgm": ppgm_dataset_transform,
855
+ "ppgm_static": ppgm_dataset_transform,
856
+ "ppgm_wrist": ppgm_dataset_transform,
857
+ "fractal20220817_data": rt1_dataset_transform,
858
+ "kuka": kuka_dataset_transform,
859
+ "taco_play": taco_play_dataset_transform,
860
+ "jaco_play": jaco_play_dataset_transform,
861
+ "berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
862
+ "roboturk": roboturk_dataset_transform,
863
+ "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
864
+ "viola": viola_dataset_transform,
865
+ "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
866
+ "toto": toto_dataset_transform,
867
+ "language_table": language_table_dataset_transform,
868
+ "columbia_cairlab_pusht_real": pusht_dataset_transform,
869
+ "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
870
+ "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
871
+ "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
872
+ "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
873
+ "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
874
+ "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
875
+ "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
876
+ "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
877
+ "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
878
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
879
+ "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
880
+ "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
881
+ "bc_z": bc_z_dataset_transform,
882
+ "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
883
+ "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
884
+ "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform,
885
+ "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
886
+ "robo_net": robo_net_dataset_transform,
887
+ "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
888
+ "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
889
+ "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
890
+ "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
891
+ "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
892
+ "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform,
893
+ "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
894
+ "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
895
+ "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
896
+ "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
897
+ "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
898
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
899
+ "uiuc_d3field": uiuc_d3field_dataset_transform,
900
+ "utaustin_mutex": utaustin_mutex_dataset_transform,
901
+ "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
902
+ "cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
903
+ "cmu_play_fusion": playfusion_dataset_transform,
904
+ "cmu_stretch": cmu_stretch_dataset_transform,
905
+ "berkeley_gnm_recon": gnm_dataset_transform,
906
+ "berkeley_gnm_cory_hall": gnm_dataset_transform,
907
+ "berkeley_gnm_sac_son": gnm_dataset_transform,
908
+ "droid": droid_baseact_transform,
909
+ "fmb_dataset": fmb_dataset_transform,
910
+ "dobbe": dobbe_dataset_transform,
911
+ "roboset": roboset_dataset_transform,
912
+ "rh20t": rh20t_dataset_transform,
913
+ ### T-DROID datasets
914
+ "tdroid_carrot_in_bowl": tdroid_dataset_transform,
915
+ "tdroid_pour_corn_in_pot": tdroid_dataset_transform,
916
+ "tdroid_flip_pot_upright": tdroid_dataset_transform,
917
+ "tdroid_move_object_onto_plate": tdroid_dataset_transform,
918
+ "tdroid_knock_object_over": tdroid_dataset_transform,
919
+ "tdroid_cover_object_with_towel": tdroid_dataset_transform,
920
+ ### DROID Finetuning datasets
921
+ "droid_wipe": droid_finetuning_transform,
922
+ ### LIBERO datasets (modified versions)
923
+ "libero_spatial_no_noops": libero_dataset_transform,
924
+ "libero_object_no_noops": libero_dataset_transform,
925
+ "libero_goal_no_noops": libero_dataset_transform,
926
+ "libero_10_no_noops": libero_dataset_transform,
927
+ "libero_4_task_suites_no_noops": libero_dataset_transform,
928
+ ### ALOHA fine-tuning datasets
929
+ "aloha1_fold_shorts_20_demos": aloha_dataset_transform,
930
+ "aloha1_fold_shirt_30_demos": aloha_dataset_transform,
931
+ "aloha1_scoop_X_into_bowl_45_demos": aloha_dataset_transform,
932
+ "aloha1_put_X_into_pot_300_demos": aloha_dataset_transform,
933
+
934
+ "aloha_dual_bottles_pick_hard_d435_20": aloha_dataset_transform,
935
+
936
+ ### RoboTwin 2.0 datasets
937
+ "grab_roller_aloha_agilex_50": aloha_dataset_transform,
938
+ "handover_mic_aloha_agilex_50": aloha_dataset_transform,
939
+ "lift_pot_aloha_agilex_50": aloha_dataset_transform,
940
+ "move_can_pot_aloha_agilex_50": aloha_dataset_transform,
941
+ "open_laptop_aloha_agilex_50": aloha_dataset_transform,
942
+ "pick_dual_bottles_aloha_agilex_50": aloha_dataset_transform,
943
+ "place_dual_shoes_aloha_agilex_50": aloha_dataset_transform,
944
+ "place_object_basket_aloha_agilex_50": aloha_dataset_transform,
945
+ "place_phone_stand_aloha_agilex_50": aloha_dataset_transform,
946
+ "put_bottles_dustbin_aloha_agilex_50": aloha_dataset_transform,
947
+ "put_object_cabinet_aloha_agilex_50": aloha_dataset_transform,
948
+ "stack_blocks_two_aloha_agilex_50": aloha_dataset_transform,
949
+ "stack_bowls_two_aloha_agilex_50": aloha_dataset_transform,
950
+
951
+ }
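The registry above is the single lookup point that maps each RLDS dataset name to its standardization transform. A minimal sketch of how an entry might be fetched and applied to a dummy trajectory (assumes TensorFlow is installed and this repository is on PYTHONPATH; the dataset key and tensor shapes below are invented for illustration):

import tensorflow as tf

from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS

# Dummy LIBERO-style trajectory: 10 timesteps, 8-dim state, 7-dim action (6-DoF delta pose + gripper).
T = 10
dummy_traj = {
    "observation": {"state": tf.zeros([T, 8], tf.float32)},
    "action": tf.random.uniform([T, 7], minval=-1.0, maxval=1.0),
}

transform = OXE_STANDARDIZATION_TRANSFORMS["libero_goal_no_noops"]
std_traj = transform(dummy_traj)

print(std_traj["action"].shape)                         # (10, 7), gripper flipped so +1 = open, 0 = close
print(std_traj["observation"]["EEF_state"].shape)       # (10, 6)
print(std_traj["observation"]["gripper_state"].shape)   # (10, 2)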
prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py ADDED
@@ -0,0 +1,178 @@
1
+ """Episode transforms for DROID dataset."""
2
+
3
+ from typing import Any, Dict
4
+
5
+ import tensorflow as tf
6
+ import tensorflow_graphics.geometry.transformation as tfg
7
+
8
+
9
+ def rmat_to_euler(rot_mat):
10
+ return tfg.euler.from_rotation_matrix(rot_mat)
11
+
12
+
13
+ def euler_to_rmat(euler):
14
+ return tfg.rotation_matrix_3d.from_euler(euler)
15
+
16
+
17
+ def invert_rmat(rot_mat):
18
+ return tfg.rotation_matrix_3d.inverse(rot_mat)
19
+
20
+
21
+ def rotmat_to_rot6d(mat):
22
+ """
23
+ Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
24
+ Args:
25
+ mat: rotation matrix
26
+
27
+ Returns: 6d vector (first two rows of rotation matrix)
28
+
29
+ """
30
+ r6 = mat[..., :2, :]
31
+ r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
32
+ r6_flat = tf.concat([r6_0, r6_1], axis=-1)
33
+ return r6_flat
34
+
35
+
36
+ def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
37
+ """
38
+ Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
39
+ Args:
40
+ velocity: 6d velocity action (3 x translation, 3 x rotation)
41
+ wrist_in_robot_frame: 6d pose of the end-effector in robot base frame
42
+
43
+ Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
44
+
45
+ """
46
+ R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
47
+ R_frame_inv = invert_rmat(R_frame)
48
+
49
+ # world to wrist: dT_pi = R^-1 dT_rbt
50
+ vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0]
51
+
52
+ # world to wrist: dR_pi = R^-1 dR_rbt R
53
+ dR = euler_to_rmat(velocity[:, 3:6])
54
+ dR = R_frame_inv @ (dR @ R_frame)
55
+ dR_r6 = rotmat_to_rot6d(dR)
56
+ return tf.concat([vel_t, dR_r6], axis=-1)
57
+
58
+
59
+ def rand_swap_exterior_images(img1, img2):
60
+ """
61
+ Randomly swaps the two exterior images (for training with single exterior input).
62
+ """
63
+ return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
64
+
65
+
66
+ def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
67
+ """
68
+ DROID dataset transformation for actions expressed in *base* frame of the robot.
69
+ """
70
+ dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
71
+ dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
72
+
73
+ trajectory["action"] = tf.concat(
74
+ (
75
+ dt,
76
+ dR,
77
+ 1 - trajectory["action_dict"]["gripper_position"],
78
+ ),
79
+ axis=-1,
80
+ )
81
+ trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
82
+ rand_swap_exterior_images(
83
+ trajectory["observation"]["exterior_image_1_left"],
84
+ trajectory["observation"]["exterior_image_2_left"],
85
+ )
86
+ )
87
+ trajectory["observation"]["proprio"] = tf.concat(
88
+ (
89
+ trajectory["observation"]["cartesian_position"],
90
+ trajectory["observation"]["gripper_position"],
91
+ ),
92
+ axis=-1,
93
+ )
94
+ return trajectory
95
+
96
+
97
+ def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
98
+ """
99
+ DROID dataset transformation for actions expressed in *wrist* frame of the robot.
100
+ """
101
+ wrist_act = velocity_act_to_wrist_frame(
102
+ trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
103
+ )
104
+ trajectory["action"] = tf.concat(
105
+ (
106
+ wrist_act,
107
+ trajectory["action_dict"]["gripper_position"],
108
+ ),
109
+ axis=-1,
110
+ )
111
+ trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
112
+ rand_swap_exterior_images(
113
+ trajectory["observation"]["exterior_image_1_left"],
114
+ trajectory["observation"]["exterior_image_2_left"],
115
+ )
116
+ )
117
+ trajectory["observation"]["proprio"] = tf.concat(
118
+ (
119
+ trajectory["observation"]["cartesian_position"],
120
+ trajectory["observation"]["gripper_position"],
121
+ ),
122
+ axis=-1,
123
+ )
124
+ return trajectory
125
+
126
+
127
+ def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
128
+ """
129
+ DROID dataset transformation for actions expressed in *base* frame of the robot.
130
+ """
131
+ dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
132
+ dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
133
+ trajectory["action"] = tf.concat(
134
+ (
135
+ dt,
136
+ dR,
137
+ 1 - trajectory["action_dict"]["gripper_position"],
138
+ ),
139
+ axis=-1,
140
+ )
141
+ trajectory["observation"]["proprio"] = tf.concat(
142
+ (
143
+ trajectory["observation"]["cartesian_position"],
144
+ trajectory["observation"]["gripper_position"],
145
+ ),
146
+ axis=-1,
147
+ )
148
+ return trajectory
149
+
150
+
151
+ def zero_action_filter(traj: Dict) -> bool:
152
+ """
153
+ Filters out transitions whose actions are all zero (checking only the relative action dims; the gripper action is ignored).
154
+ Note: this filter is applied *after* action normalization, so we need to compare against the "normalized zero" action.
155
+ """
156
+ DROID_Q01 = tf.convert_to_tensor(
157
+ [
158
+ -0.7776297926902771,
159
+ -0.5803514122962952,
160
+ -0.5795090794563293,
161
+ -0.6464047729969025,
162
+ -0.7041108310222626,
163
+ -0.8895104378461838,
164
+ ]
165
+ )
166
+ DROID_Q99 = tf.convert_to_tensor(
167
+ [
168
+ 0.7597932070493698,
169
+ 0.5726242214441299,
170
+ 0.7351000607013702,
171
+ 0.6705610305070877,
172
+ 0.6464948207139969,
173
+ 0.8897542208433151,
174
+ ]
175
+ )
176
+ DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1
177
+
178
+ return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
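As a sanity check on the arithmetic in `zero_action_filter`: the filter recomputes what an all-zero raw action looks like after 1st/99th-percentile normalization and keeps a trajectory only if some action deviates from that value. A small NumPy stand-in (the quantiles here are shortened example values, not the exact DROID statistics hard-coded above):

import numpy as np

# Example per-dimension 1st/99th percentiles (illustrative only).
q01 = np.array([-0.78, -0.58, -0.58, -0.65, -0.70, -0.89])
q99 = np.array([0.76, 0.57, 0.74, 0.67, 0.65, 0.89])

# Normalization used by the data pipeline: a_norm = 2 * (a - q01) / (q99 - q01 + eps) - 1
def normalize(a):
    return 2 * (a - q01) / (q99 - q01 + 1e-8) - 1

norm_zero = normalize(np.zeros(6))                 # what a raw all-zero action maps to
normalized_actions = normalize(np.zeros((5, 6)))   # a trajectory of five all-zero actions

keep = np.any(np.abs(normalized_actions - norm_zero) > 1e-5)
print(keep)  # False -> this trajectory would be dropped by the filter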
prismatic/vla/datasets/rlds/utils/__init__.py ADDED
File without changes
prismatic/vla/datasets/rlds/utils/goal_relabeling.py ADDED
@@ -0,0 +1,32 @@
1
+ """
2
+ goal_relabeling.py
3
+
4
+ Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required.
5
+ Each function should add entries to the "task" dict.
6
+ """
7
+
8
+ from typing import Dict
9
+
10
+ import tensorflow as tf
11
+
12
+ from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge
13
+
14
+
15
+ def uniform(traj: Dict) -> Dict:
16
+ """Relabels with a true uniform distribution over future states."""
17
+ traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0]
18
+
19
+ # Select a random future index for each transition i in the range [i + 1, traj_len)
20
+ rand = tf.random.uniform([traj_len])
21
+ low = tf.cast(tf.range(traj_len) + 1, tf.float32)
22
+ high = tf.cast(traj_len, tf.float32)
23
+ goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)
24
+
25
+ # Sometimes floating-point errors push an index out of bounds, so clamp to the last valid index
26
+ goal_idxs = tf.minimum(goal_idxs, traj_len - 1)
27
+
28
+ # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly)
29
+ goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"])
30
+ traj["task"] = tree_merge(traj["task"], goal)
31
+
32
+ return traj
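The index arithmetic in `uniform` draws, for each timestep i, a goal index uniformly from the future range [i + 1, traj_len), then clamps it to the last index to guard against floating-point overshoot. A NumPy stand-in for that computation (illustrative only; the implementation above stays in TensorFlow so it can run inside the tf.data pipeline):

import numpy as np

traj_len = 6
rand = np.random.uniform(size=traj_len)             # one uniform draw per timestep
low = np.arange(traj_len, dtype=np.float64) + 1.0   # earliest allowed goal index: i + 1
high = float(traj_len)                              # exclusive upper bound

goal_idxs = (rand * (high - low) + low).astype(np.int64)
goal_idxs = np.minimum(goal_idxs, traj_len - 1)     # clamp; the final step becomes its own goal

print(goal_idxs)  # e.g. [3 2 4 5 5 5]: every entry except the last is strictly in the future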
prismatic/vla/materialize.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ materialize.py
3
+
4
+ Factory for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and
5
+ exports individual functions for clear control flow.
6
+ """
7
+
8
+ from pathlib import Path
9
+ from typing import Tuple, Type
10
+
11
+ from torch.utils.data import Dataset
12
+ from transformers import PreTrainedTokenizerBase
13
+
14
+ from prismatic.models.backbones.llm.prompting import PromptBuilder
15
+ from prismatic.models.backbones.vision import ImageTransform
16
+ from prismatic.util.data_utils import PaddedCollatorForActionPrediction
17
+ from prismatic.vla.action_tokenizer import ActionTokenizer
18
+ from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset
19
+
20
+
21
+ def get_vla_dataset_and_collator(
22
+ data_root_dir: Path,
23
+ data_mix: str,
24
+ image_transform: ImageTransform,
25
+ tokenizer: PreTrainedTokenizerBase,
26
+ prompt_builder_fn: Type[PromptBuilder],
27
+ default_image_resolution: Tuple[int, int, int],
28
+ padding_side: str = "right",
29
+ predict_stop_token: bool = True,
30
+ shuffle_buffer_size: int = 100_000,
31
+ train: bool = True,
32
+ episodic: bool = False,
33
+ image_aug: bool = False,
34
+ ) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]:
35
+ """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions."""
36
+ action_tokenizer = ActionTokenizer(tokenizer)
37
+ batch_transform = RLDSBatchTransform(
38
+ action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token
39
+ )
40
+ collator = PaddedCollatorForActionPrediction(
41
+ tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side
42
+ )
43
+
44
+ # Build RLDS Iterable Dataset
45
+ cls = RLDSDataset if not episodic else EpisodicRLDSDataset
46
+ dataset = cls(
47
+ data_root_dir,
48
+ data_mix,
49
+ batch_transform,
50
+ resize_resolution=default_image_resolution[1:],
51
+ shuffle_buffer_size=shuffle_buffer_size,
52
+ train=train,
53
+ image_aug=image_aug,
54
+ )
55
+
56
+ return dataset, action_tokenizer, collator
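A rough usage sketch for `get_vla_dataset_and_collator`. The tokenizer, image transform, and prompt-builder class normally come from the instantiated VLA elsewhere in the codebase, so they are stubbed out here; every placeholder below is an assumption for illustration, not an API defined in this file:

from pathlib import Path

from torch.utils.data import DataLoader

from prismatic.vla.materialize import get_vla_dataset_and_collator

# Placeholders: supply the real components from the loaded VLA / its backbones.
tokenizer = ...          # transformers.PreTrainedTokenizerBase from the LLM backbone
image_transform = ...    # ImageTransform from the vision backbone
prompt_builder_fn = ...  # PromptBuilder subclass matching the LLM backbone

vla_dataset, action_tokenizer, collator = get_vla_dataset_and_collator(
    data_root_dir=Path("datasets/open-x-embodiment"),  # placeholder path
    data_mix="bridge",                                  # placeholder mixture name
    image_transform=image_transform,
    tokenizer=tokenizer,
    prompt_builder_fn=prompt_builder_fn,
    default_image_resolution=(3, 224, 224),
    shuffle_buffer_size=10_000,
    image_aug=True,
)

# RLDSDataset is an IterableDataset, so no shuffle/sampler arguments are passed here.
dataloader = DataLoader(vla_dataset, batch_size=8, collate_fn=collator, num_workers=0)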
run_scripts/ac/ac.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla_ffn_AC
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=ffn
11
+ decoder_num_blocks=2
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=5000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn/3ffn2.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla3_ffn
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=ffn
11
+ decoder_num_blocks=2
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=10000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn/3postffn2.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla3_postffn
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=postffn
11
+ decoder_num_blocks=2
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=10000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn/3postffn6.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla3_postffn
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=postffn
11
+ decoder_num_blocks=6
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=10000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn/debug_5ffn_withactionprojector.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla4_ffn_withprojector
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=ffn
11
+ decoder_num_blocks=2
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=10000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline python -m debugpy --listen 1234 --wait-for-client '/opt/conda/envs/spatialvla/bin/torchrun' --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-5 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn/ffn4.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla_ffn
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=ffn
11
+ decoder_num_blocks=4
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=5000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn/ffn8.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla_ffn
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=ffn
11
+ decoder_num_blocks=8
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=5000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn/test.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=test
5
+ use_predict_future_prop=False
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=ffn
11
+ decoder_num_blocks=2
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=False
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=5000
25
+ max_steps=40000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/ffn_long_chunks/run.sh ADDED
@@ -0,0 +1,4 @@
1
+ bash run_scripts/ffn_long_chunks/li4.sh
2
+ bash run_scripts/ffn_long_chunks/li16.sh
3
+ bash run_scripts/ffn_long_chunks/li24.sh
4
+ bash run_scripts/ffn_long_chunks/li32.sh
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_base.sh ADDED
@@ -0,0 +1,88 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=simvla_twin2
3
+ ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
4
+ #========== !NOTE! ==========#
5
+ RUN_MODE=base
6
+ use_predict_future_prop=False
7
+ batch_size=4
8
+ use_action_ts_head=False
9
+ use_one_embed=False
10
+ use_multi_scaling=False
11
+ mlp_type=ffn
12
+ decoder_num_blocks=2
13
+ robot_platform=aloha
14
+ MODE=${RUN_MODE}_robot_platform_${robot_platform}
15
+ #========== !NOTE! ==========#
16
+ use_l1_regression=True
17
+ num_images_in_input=3
18
+ wandb_entity=chenghaha
19
+ wandb_project=robotwin
20
+ wandb_log_freq=1
21
+ use_proprio=True
22
+ use_diffusion=False
23
+ use_film=True
24
+ num_steps_before_decay=1000
25
+ save_freq=2000
26
+ max_steps=2000
27
+ vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
28
+ data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
29
+ dataset_name=grab_roller_aloha_agilex_50
30
+ run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
31
+ #========== get run_id ==========#
32
+ note_parts=("${MODE}")
33
+
34
+ if [ "$use_l1_regression" = "True" ]; then
35
+ note_parts+=("L1_regression")
36
+ fi
37
+
38
+ if [ "$num_images_in_input" == 1 ]; then
39
+ note_parts+=("3rd_person_img")
40
+ else
41
+ note_parts+=("3rd_person_img_and_wrist")
42
+ fi
43
+
44
+ if [ "$use_proprio" = "True" ]; then
45
+ note_parts+=("proprio_state")
46
+ fi
47
+
48
+ if [ "$use_film" = "True" ]; then
49
+ note_parts+=("Film")
50
+ fi
51
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
52
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
53
+
54
+ #========== enter environment ==========#
55
+ source activate openvla-oft
56
+ cd $ROOT_PATH/vla_projects/$PROJECT_PATH
57
+ export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
58
+
59
+ #========== run ==========#
60
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
61
+ --vla_path "$vla_path" \
62
+ --data_root_dir "$data_root_dir" \
63
+ --dataset_name "$dataset_name" \
64
+ --run_root_dir "$run_root_dir" \
65
+ --use_l1_regression "$use_l1_regression" \
66
+ --use_diffusion "$use_diffusion" \
67
+ --use_film "$use_film" \
68
+ --num_images_in_input "$num_images_in_input" \
69
+ --use_proprio "$use_proprio" \
70
+ --batch_size "$batch_size" \
71
+ --learning_rate 5e-5 \
72
+ --num_steps_before_decay "$num_steps_before_decay" \
73
+ --max_steps "$max_steps" \
74
+ --save_freq "$save_freq" \
75
+ --save_latest_checkpoint_only False \
76
+ --image_aug True \
77
+ --lora_rank 32 \
78
+ --wandb_entity "$wandb_entity" \
79
+ --wandb_project "$wandb_project" \
80
+ --wandb_log_freq "$wandb_log_freq" \
81
+ --run_id_note "$run_id_note_value" \
82
+ --use_predict_future_prop "$use_predict_future_prop" \
83
+ --use_action_ts_head "$use_action_ts_head" \
84
+ --use_one_embed "$use_one_embed" \
85
+ --use_multi_scaling "$use_multi_scaling" \
86
+ --mlp_type "$mlp_type" \
87
+ --decoder_num_blocks "$decoder_num_blocks" \
88
+ --robot_platform "$robot_platform"
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_50_l2.sh ADDED
@@ -0,0 +1,102 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=simvla_twin2
3
+ ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
4
+ #========== !NOTE! ==========#
5
+ RUN_MODE=simvla_50
6
+ use_predict_future_prop=False
7
+ batch_size=4
8
+ use_action_ts_head=True
9
+ use_one_embed=True
10
+ use_multi_scaling=False
11
+ mlp_type=ffn
12
+ decoder_num_blocks=4
13
+ robot_platform=50_al
14
+ proj_type=gelu_linear
15
+ ffn_type=gelu
16
+ expand_inner_ratio=1
17
+ linear_drop_ratio=0.1
18
+ multi_queries_num=50
19
+ multi_query_norm_type=layernorm
20
+ action_norm=l2
21
+ MODE=${RUN_MODE}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
22
+ #========== !NOTE! ==========#
23
+ use_l1_regression=True
24
+ num_images_in_input=3
25
+ wandb_entity=chenghaha
26
+ wandb_project=robotwin
27
+ wandb_log_freq=1
28
+ use_proprio=True
29
+ use_diffusion=False
30
+ use_film=True
31
+ num_steps_before_decay=2000
32
+ save_freq=3000
33
+ max_steps=3000
34
+ vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
35
+ data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
36
+ dataset_name=grab_roller_aloha_agilex_50
37
+ run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
38
+ #========== get run_id ==========#
39
+ note_parts=("${MODE}")
40
+
41
+ # if [ "$use_l1_regression" = "True" ]; then
42
+ # note_parts+=("L1_regression")
43
+ # fi
44
+
45
+ # if [ "$num_images_in_input" == 1 ]; then
46
+ # note_parts+=("3rd_person_img")
47
+ # else
48
+ # note_parts+=("3rd_person_img_and_wrist")
49
+ # fi
50
+
51
+ # if [ "$use_l1_regression" = "True" ]; then
52
+ # note_parts+=("proprio_state")
53
+ # fi
54
+
55
+ # if [ "$use_film" = "True" ]; then
56
+ # note_parts+=("Film")
57
+ # fi
58
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
59
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
60
+
61
+ #========== enter environment ==========#
62
+ conda activate openvla-oft
63
+ cd $ROOT_PATH/vla_projects/$PROJECT_PATH
64
+ export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
65
+
66
+ #========== run ==========#
67
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
68
+ --vla_path "$vla_path" \
69
+ --data_root_dir "$data_root_dir" \
70
+ --dataset_name "$dataset_name" \
71
+ --run_root_dir "$run_root_dir" \
72
+ --use_l1_regression "$use_l1_regression" \
73
+ --use_diffusion "$use_diffusion" \
74
+ --use_film "$use_film" \
75
+ --num_images_in_input "$num_images_in_input" \
76
+ --use_proprio "$use_proprio" \
77
+ --batch_size "$batch_size" \
78
+ --learning_rate 5e-5 \
79
+ --num_steps_before_decay "$num_steps_before_decay" \
80
+ --max_steps "$max_steps" \
81
+ --save_freq "$save_freq" \
82
+ --save_latest_checkpoint_only False \
83
+ --image_aug True \
84
+ --lora_rank 32 \
85
+ --wandb_entity "$wandb_entity" \
86
+ --wandb_project "$wandb_project" \
87
+ --wandb_log_freq "$wandb_log_freq" \
88
+ --run_id_note "$run_id_note_value" \
89
+ --use_predict_future_prop "$use_predict_future_prop" \
90
+ --use_action_ts_head "$use_action_ts_head" \
91
+ --use_one_embed "$use_one_embed" \
92
+ --use_multi_scaling "$use_multi_scaling" \
93
+ --mlp_type "$mlp_type" \
94
+ --decoder_num_blocks "$decoder_num_blocks" \
95
+ --robot_platform "$robot_platform" \
96
+ --proj_type "$proj_type" \
97
+ --ffn_type "$ffn_type" \
98
+ --expand_inner_ratio "$expand_inner_ratio" \
99
+ --linear_drop_ratio "$linear_drop_ratio" \
100
+ --multi_query_norm_type "$multi_query_norm_type" \
101
+ --multi_queries_num "$multi_queries_num" \
102
+ --action_norm "$action_norm"
run_scripts/ffn_q2a/bridge/exffn_relu_connector_linear_relu.sh ADDED
@@ -0,0 +1,95 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_q2a
3
+ ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
4
+ #========== !NOTE! ==========#
5
+ RUN_MODE=simvla_q2a
6
+ use_predict_future_prop=False
7
+ batch_size=16
8
+ use_action_ts_head=True
9
+ use_one_embed=True
10
+ use_multi_scaling=False
11
+ mlp_type=ffn
12
+ decoder_num_blocks=6
13
+ robot_platform=bridge
14
+ without_head_drop_out=True
15
+ proj_type=linear_relu
16
+ ffn_type=relu
17
+ expand_actiondim_ratio=2.0
18
+ MODE=${RUN_MODE}_exffn${expand_actiondim_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
19
+ #========== !NOTE! ==========#
20
+ use_l1_regression=True
21
+ num_images_in_input=1
22
+ wandb_entity=chenghaha
23
+ wandb_project=fastvla
24
+ wandb_log_freq=1
25
+ use_proprio=False
26
+ use_diffusion=False
27
+ use_film=False
28
+ num_steps_before_decay=20000
29
+ save_freq=10000
30
+ max_steps=50000
31
+ vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
32
+ data_root_dir=$ROOT_PATH/datasets/openx/data/origin
33
+ dataset_name=bridge
34
+ run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
35
+ #========== get run_id ==========#
36
+ note_parts=("${MODE}")
37
+
38
+ # if [ "$use_l1_regression" = "True" ]; then
39
+ # note_parts+=("L1_regression")
40
+ # fi
41
+
42
+ # if [ "$num_images_in_input" == 1 ]; then
43
+ # note_parts+=("3rd_person_img")
44
+ # else
45
+ # note_parts+=("3rd_person_img_and_wrist")
46
+ # fi
47
+
48
+ # if [ "$use_l1_regression" = "True" ]; then
49
+ # note_parts+=("proprio_state")
50
+ # fi
51
+
52
+ # if [ "$use_film" = "True" ]; then
53
+ # note_parts+=("Film")
54
+ # fi
55
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
56
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
57
+
58
+ #========== enter environment ==========#
59
+ conda activate openvla-oft
60
+ cd $ROOT_PATH/vla_projects/$PROJECT_PATH
61
+ export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
62
+
63
+ #========== run ==========#
64
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
65
+ --vla_path "$vla_path" \
66
+ --data_root_dir "$data_root_dir" \
67
+ --dataset_name "$dataset_name" \
68
+ --run_root_dir "$run_root_dir" \
69
+ --use_l1_regression "$use_l1_regression" \
70
+ --use_diffusion "$use_diffusion" \
71
+ --use_film "$use_film" \
72
+ --num_images_in_input "$num_images_in_input" \
73
+ --use_proprio "$use_proprio" \
74
+ --batch_size "$batch_size" \
75
+ --learning_rate 5e-4 \
76
+ --num_steps_before_decay "$num_steps_before_decay" \
77
+ --max_steps "$max_steps" \
78
+ --save_freq "$save_freq" \
79
+ --save_latest_checkpoint_only False \
80
+ --image_aug True \
81
+ --lora_rank 32 \
82
+ --wandb_entity "$wandb_entity" \
83
+ --wandb_project "$wandb_project" \
84
+ --wandb_log_freq "$wandb_log_freq" \
85
+ --run_id_note "$run_id_note_value" \
86
+ --use_predict_future_prop "$use_predict_future_prop" \
87
+ --use_action_ts_head "$use_action_ts_head" \
88
+ --use_one_embed "$use_one_embed" \
89
+ --use_multi_scaling "$use_multi_scaling" \
90
+ --mlp_type "$mlp_type" \
91
+ --decoder_num_blocks "$decoder_num_blocks" \
92
+ --robot_platform "$robot_platform" \
93
+ --proj_type "$proj_type" \
94
+ --ffn_type "$ffn_type" \
95
+ --expand_actiondim_ratio "$expand_actiondim_ratio"
run_scripts/ffn_q2a/bridge/run_bridge.sh ADDED
@@ -0,0 +1,2 @@
1
+ bash run_scripts/ffn_q2a/bridge/exffn_gelu_bridge_drop0_5.sh
2
+ bash run_scripts/ffn_q2a/bridge/exffn_gelu_bridge.sh
run_scripts/ffn_q2a/franka/exffn_gelu_franka.sh ADDED
@@ -0,0 +1,95 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=SimVLA_Condition
3
+ ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
4
+ #========== !NOTE! ==========#
5
+ RUN_MODE=simvla_q2a
6
+ use_predict_future_prop=False
7
+ batch_size=16
8
+ use_action_ts_head=True
9
+ use_one_embed=True
10
+ use_multi_scaling=False
11
+ mlp_type=ffn
12
+ decoder_num_blocks=4
13
+ robot_platform=rt1
14
+ without_head_drop_out=True
15
+ proj_type=gelu_linear
16
+ ffn_type=gelu
17
+ expand_actiondim_ratio=1.0
18
+ MODE=${RUN_MODE}_exffn${expand_actiondim_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
19
+ #========== !NOTE! ==========#
20
+ use_l1_regression=True
21
+ num_images_in_input=1
22
+ wandb_entity=chenghaha
23
+ wandb_project=fastvla
24
+ wandb_log_freq=1
25
+ use_proprio=False
26
+ use_diffusion=False
27
+ use_film=False
28
+ num_steps_before_decay=30000
29
+ save_freq=10000
30
+ max_steps=60000
31
+ vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
32
+ data_root_dir=$ROOT_PATH/datasets/openx/data/origin
33
+ dataset_name=rt1
34
+ run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
35
+ #========== get run_id ==========#
36
+ note_parts=("${MODE}")
37
+
38
+ # if [ "$use_l1_regression" = "True" ]; then
39
+ # note_parts+=("L1_regression")
40
+ # fi
41
+
42
+ # if [ "$num_images_in_input" == 1 ]; then
43
+ # note_parts+=("3rd_person_img")
44
+ # else
45
+ # note_parts+=("3rd_person_img_and_wrist")
46
+ # fi
47
+
48
+ # if [ "$use_l1_regression" = "True" ]; then
49
+ # note_parts+=("proprio_state")
50
+ # fi
51
+
52
+ # if [ "$use_film" = "True" ]; then
53
+ # note_parts+=("Film")
54
+ # fi
55
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
56
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
57
+
58
+ #========== enter environment ==========#
59
+ conda activate openvla-oft
60
+ cd $ROOT_PATH/vla_projects/$PROJECT_PATH
61
+ export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
62
+
63
+ #========== run ==========#
64
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
65
+ --vla_path "$vla_path" \
66
+ --data_root_dir "$data_root_dir" \
67
+ --dataset_name "$dataset_name" \
68
+ --run_root_dir "$run_root_dir" \
69
+ --use_l1_regression "$use_l1_regression" \
70
+ --use_diffusion "$use_diffusion" \
71
+ --use_film "$use_film" \
72
+ --num_images_in_input "$num_images_in_input" \
73
+ --use_proprio "$use_proprio" \
74
+ --batch_size "$batch_size" \
75
+ --learning_rate 5e-4 \
76
+ --num_steps_before_decay "$num_steps_before_decay" \
77
+ --max_steps "$max_steps" \
78
+ --save_freq "$save_freq" \
79
+ --save_latest_checkpoint_only False \
80
+ --image_aug True \
81
+ --lora_rank 32 \
82
+ --wandb_entity "$wandb_entity" \
83
+ --wandb_project "$wandb_project" \
84
+ --wandb_log_freq "$wandb_log_freq" \
85
+ --run_id_note "$run_id_note_value" \
86
+ --use_predict_future_prop "$use_predict_future_prop" \
87
+ --use_action_ts_head "$use_action_ts_head" \
88
+ --use_one_embed "$use_one_embed" \
89
+ --use_multi_scaling "$use_multi_scaling" \
90
+ --mlp_type "$mlp_type" \
91
+ --decoder_num_blocks "$decoder_num_blocks" \
92
+ --robot_platform "$robot_platform" \
93
+ --proj_type "$proj_type" \
94
+ --ffn_type "$ffn_type" \
95
+ --expand_actiondim_ratio "$expand_actiondim_ratio"
run_scripts/ffn_q2a/libero_moe/debug_moe_lit.sh ADDED
@@ -0,0 +1,101 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=SimVLA
3
+ ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/jiajiuyang-240108580167/chengdongzhou
4
+ #========== !NOTE! ==========#
5
+ RUN_MODE=simvla_q2a_lit
6
+ use_predict_future_prop=False
7
+ batch_size=2
8
+ use_action_ts_head=True
9
+ use_one_embed=True
10
+ use_multi_scaling=False
11
+ mlp_type=moe
12
+ decoder_num_blocks=2
13
+ robot_platform=16_li
14
+ without_head_drop_out=True
15
+ proj_type=gelu_linear
16
+ ffn_type=gelu
17
+ num_experts=8
18
+ expand_inner_ratio=2
19
+ top_k=2
20
+ expand_actiondim_ratio=0.5
21
+ MODE=${RUN_MODE}_ex${expand_actiondim_ratio}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}_num_experts${num_experts}_top_k${top_k}
22
+ #========== !NOTE! ==========#
23
+ use_l1_regression=True
24
+ num_images_in_input=1
25
+ wandb_entity=chenghaha
26
+ wandb_project=fastvla
27
+ wandb_log_freq=1
28
+ use_proprio=False
29
+ use_diffusion=False
30
+ use_film=False
31
+ num_steps_before_decay=20000
32
+ save_freq=10000
33
+ max_steps=50000
34
+ vla_path=$ROOT_PATH/ai_models/openvla
35
+ data_root_dir=$ROOT_PATH/datasets/openvla/modified_libero_rlds
36
+ dataset_name=libero_4_task_suites_no_noops
37
+ run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
38
+ #========== get run_id ==========#
39
+ note_parts=("${MODE}")
40
+
41
+ # if [ "$use_l1_regression" = "True" ]; then
42
+ # note_parts+=("L1_regression")
43
+ # fi
44
+
45
+ # if [ "$num_images_in_input" == 1 ]; then
46
+ # note_parts+=("3rd_person_img")
47
+ # else
48
+ # note_parts+=("3rd_person_img_and_wrist")
49
+ # fi
50
+
51
+ # if [ "$use_l1_regression" = "True" ]; then
52
+ # note_parts+=("proprio_state")
53
+ # fi
54
+
55
+ # if [ "$use_film" = "True" ]; then
56
+ # note_parts+=("Film")
57
+ # fi
58
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
59
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
60
+
61
+ #========== enter environment ==========#
62
+ conda activate openvla-oft
63
+ cd $ROOT_PATH/vla_projects/$PROJECT_PATH
64
+ export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
65
+
66
+ #========== run ==========#
67
+ WANDB_CONSOLE=off WANDB_MODE=offline python -m debugpy --listen 1234 --wait-for-client '/opt/conda/envs/openvla-oft/bin/torchrun' --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
68
+ --vla_path "$vla_path" \
69
+ --data_root_dir "$data_root_dir" \
70
+ --dataset_name "$dataset_name" \
71
+ --run_root_dir "$run_root_dir" \
72
+ --use_l1_regression "$use_l1_regression" \
73
+ --use_diffusion "$use_diffusion" \
74
+ --use_film "$use_film" \
75
+ --num_images_in_input "$num_images_in_input" \
76
+ --use_proprio "$use_proprio" \
77
+ --batch_size "$batch_size" \
78
+ --learning_rate 5e-4 \
79
+ --num_steps_before_decay "$num_steps_before_decay" \
80
+ --max_steps "$max_steps" \
81
+ --save_freq "$save_freq" \
82
+ --save_latest_checkpoint_only False \
83
+ --image_aug True \
84
+ --lora_rank 32 \
85
+ --wandb_entity "$wandb_entity" \
86
+ --wandb_project "$wandb_project" \
87
+ --wandb_log_freq "$wandb_log_freq" \
88
+ --run_id_note "$run_id_note_value" \
89
+ --use_predict_future_prop "$use_predict_future_prop" \
90
+ --use_action_ts_head "$use_action_ts_head" \
91
+ --use_one_embed "$use_one_embed" \
92
+ --use_multi_scaling "$use_multi_scaling" \
93
+ --mlp_type "$mlp_type" \
94
+ --decoder_num_blocks "$decoder_num_blocks" \
95
+ --robot_platform "$robot_platform" \
96
+ --proj_type "$proj_type" \
97
+ --ffn_type "$ffn_type" \
98
+ --expand_inner_ratio "$expand_inner_ratio" \
99
+ --expand_actiondim_ratio "$expand_actiondim_ratio" \
100
+ --num_experts "$num_experts" \
101
+ --top_k "$top_k"
run_scripts/ffn_q2a/simhead/simhead_contrastive.sh ADDED
@@ -0,0 +1,100 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=SimVLA_Condition
3
+ ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
4
+ #========== !NOTE! ==========#
5
+ RUN_MODE=simvla_q2a
6
+ use_predict_future_prop=False
7
+ batch_size=16
8
+ use_action_ts_head=True
9
+ use_one_embed=True
10
+ use_multi_scaling=False
11
+ mlp_type=ffn
12
+ decoder_num_blocks=2
13
+ robot_platform=16_li
14
+ without_head_drop_out=True
15
+ without_action_projector=True
16
+ ffn_type=gelu
17
+ use_l2norm=False
18
+ expand_inner_ratio=2.0
19
+ use_contrastive_loss=True
20
+ MODE=${RUN_MODE}_usecons${use_contrastive_loss}_newexinner_${expand_inner_ratio}_without_ap_ffn_type_${ffn_type}_use_l2norm${use_l2norm}_mlp_${mlp_type}_num_${decoder_num_blocks}
21
+ #========== !NOTE! ==========#
22
+ use_l1_regression=True
23
+ num_images_in_input=1
24
+ wandb_entity=chenghaha
25
+ wandb_project=fastvla
26
+ wandb_log_freq=1
27
+ use_proprio=False
28
+ use_diffusion=False
29
+ use_film=False
30
+ num_steps_before_decay=30000
31
+ save_freq=10000
32
+ max_steps=50000
33
+ vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
34
+ data_root_dir=$ROOT_PATH/datasets/openvla/modified_libero_rlds
35
+ dataset_name=libero_4_task_suites_no_noops
36
+ run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
37
+ #========== get run_id ==========#
38
+ note_parts=("${MODE}")
39
+
40
+ # if [ "$use_l1_regression" = "True" ]; then
41
+ # note_parts+=("L1_regression")
42
+ # fi
43
+
44
+ # if [ "$num_images_in_input" == 1 ]; then
45
+ # note_parts+=("3rd_person_img")
46
+ # else
47
+ # note_parts+=("3rd_person_img_and_wrist")
48
+ # fi
49
+
50
+ # if [ "$use_l1_regression" = "True" ]; then
51
+ # note_parts+=("proprio_state")
52
+ # fi
53
+
54
+ # if [ "$use_film" = "True" ]; then
55
+ # note_parts+=("Film")
56
+ # fi
57
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
58
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
59
+
60
+ #========== enter environment ==========#
61
+ conda activate openvla-oft
62
+ cd $ROOT_PATH/vla_projects/$PROJECT_PATH
63
+ export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
64
+
65
+ #========== run ==========#
66
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
67
+ --vla_path "$vla_path" \
68
+ --data_root_dir "$data_root_dir" \
69
+ --dataset_name "$dataset_name" \
70
+ --run_root_dir "$run_root_dir" \
71
+ --use_l1_regression "$use_l1_regression" \
72
+ --use_diffusion "$use_diffusion" \
73
+ --use_film "$use_film" \
74
+ --num_images_in_input "$num_images_in_input" \
75
+ --use_proprio "$use_proprio" \
76
+ --batch_size "$batch_size" \
77
+ --learning_rate 5e-4 \
78
+ --num_steps_before_decay "$num_steps_before_decay" \
79
+ --max_steps "$max_steps" \
80
+ --save_freq "$save_freq" \
81
+ --save_latest_checkpoint_only False \
82
+ --image_aug True \
83
+ --lora_rank 32 \
84
+ --wandb_entity "$wandb_entity" \
85
+ --wandb_project "$wandb_project" \
86
+ --wandb_log_freq "$wandb_log_freq" \
87
+ --run_id_note "$run_id_note_value" \
88
+ --use_predict_future_prop "$use_predict_future_prop" \
89
+ --use_action_ts_head "$use_action_ts_head" \
90
+ --use_one_embed "$use_one_embed" \
91
+ --use_multi_scaling "$use_multi_scaling" \
92
+ --mlp_type "$mlp_type" \
93
+ --decoder_num_blocks "$decoder_num_blocks" \
94
+ --robot_platform "$robot_platform" \
95
+ --proj_type "$proj_type" \
96
+ --ffn_type "$ffn_type" \
97
+ --use_l2norm "$use_l2norm" \
98
+ --expand_inner_ratio "$expand_inner_ratio" \
99
+ --without_action_projector "$without_action_projector" \
100
+ --use_contrastive_loss "$use_contrastive_loss"
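Note that this script passes --proj_type "$proj_type" to finetune.py but never assigns proj_type above, so the flag receives an empty string unless the variable is inherited from the calling environment.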
run_scripts/pp/pp.sh ADDED
@@ -0,0 +1,87 @@
1
+ #========== settings ==========#
2
+ PROJECT_PATH=fastvla_multi_scale_query
3
+ #========== !NOTE! ==========#
4
+ RUN_MODE=simvla_PP
5
+ use_predict_future_prop=True
6
+ batch_size=16
7
+ use_action_ts_head=True
8
+ use_one_embed=True
9
+ use_multi_scaling=False
10
+ mlp_type=ffn
11
+ decoder_num_blocks=4
12
+ robot_platform=libero
13
+ MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
14
+ #========== !NOTE! ==========#
15
+ use_l1_regression=True
16
+ num_images_in_input=1
17
+ wandb_entity=chenghaha
18
+ wandb_project=fastvla
19
+ wandb_log_freq=1
20
+ use_proprio=True
21
+ use_diffusion=False
22
+ use_film=False
23
+ num_steps_before_decay=20000
24
+ save_freq=5000
25
+ max_steps=50000
26
+ vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
27
+ data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
28
+ dataset_name=libero_4_task_suites_no_noops
29
+ run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
30
+ #========== get run_id ==========#
31
+ note_parts=("${MODE}")
32
+
33
+ # if [ "$use_l1_regression" = "True" ]; then
34
+ # note_parts+=("L1_regression")
35
+ # fi
36
+
37
+ # if [ "$num_images_in_input" == 1 ]; then
38
+ # note_parts+=("3rd_person_img")
39
+ # else
40
+ # note_parts+=("3rd_person_img_and_wrist")
41
+ # fi
42
+
43
+ # if [ "$use_l1_regression" = "True" ]; then
44
+ # note_parts+=("proprio_state")
45
+ # fi
46
+
47
+ # if [ "$use_film" = "True" ]; then
48
+ # note_parts+=("Film")
49
+ # fi
50
+ note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
51
+ run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
52
+
53
+ #========== enter environment ==========#
54
+ # conda activate openvla-oft
55
+ cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
56
+ export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
57
+
58
+ #========== run ==========#
59
+ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
60
+ --vla_path "$vla_path" \
61
+ --data_root_dir "$data_root_dir" \
62
+ --dataset_name "$dataset_name" \
63
+ --run_root_dir "$run_root_dir" \
64
+ --use_l1_regression "$use_l1_regression" \
65
+ --use_diffusion "$use_diffusion" \
66
+ --use_film "$use_film" \
67
+ --num_images_in_input "$num_images_in_input" \
68
+ --use_proprio "$use_proprio" \
69
+ --batch_size "$batch_size" \
70
+ --learning_rate 5e-4 \
71
+ --num_steps_before_decay "$num_steps_before_decay" \
72
+ --max_steps "$max_steps" \
73
+ --save_freq "$save_freq" \
74
+ --save_latest_checkpoint_only False \
75
+ --image_aug True \
76
+ --lora_rank 32 \
77
+ --wandb_entity "$wandb_entity" \
78
+ --wandb_project "$wandb_project" \
79
+ --wandb_log_freq "$wandb_log_freq" \
80
+ --run_id_note "$run_id_note_value" \
81
+ --use_predict_future_prop "$use_predict_future_prop" \
82
+ --use_action_ts_head "$use_action_ts_head" \
83
+ --use_one_embed "$use_one_embed" \
84
+ --use_multi_scaling "$use_multi_scaling" \
85
+ --mlp_type "$mlp_type" \
86
+ --decoder_num_blocks "$decoder_num_blocks" \
87
+ --robot_platform "$robot_platform"
run_scripts/run.sh ADDED
@@ -0,0 +1,35 @@
1
+ # bash run_scripts/multiscaling/exp_multiscaling.sh
2
+
3
+ # bash run_scripts/ffn_or_gating/gating.sh
4
+
5
+ # bash run_scripts/ffn/ffn0.sh
6
+ # bash run_scripts/ffn/ffn2.sh
7
+ # bash run_scripts/ffn/ffn4.sh
8
+ # bash run_scripts/ffn/ffn6.sh
9
+ # bash run_scripts/ffn/ffn8.sh
10
+ # bash run_scripts/multiscaling/2latentmsahead.sh
11
+ # bash run_scripts/multiscaling/2msahead.sh
12
+ # bash run_scripts/ffn/2ffn6.sh
13
+ # bash run_scripts/all_input/2all_inputs.sh
14
+
15
+ # bash run_scripts/ffn_or_gating/gating.sh
16
+
17
+ # bash run_scripts/ffn/ffn0.sh
18
+ # bash run_scripts/ffn/ffn2.sh
19
+ # bash run_scripts/ffn/ffn4.sh
20
+ # bash run_scripts/ffn/ffn6.sh
21
+ # bash run_scripts/ffn/ffn8.sh
22
+
23
+
24
+ # bash run_scripts/ffn/3ffn2.sh
25
+ # bash run_scripts/ffn/3ffn6.sh
26
+ # bash run_scripts/ffn/3postffn2.sh
27
+ # bash run_scripts/ffn/3postffn6.sh
28
+
29
+ # bash run_scripts/ffn/4ffn_withactionprojector.sh
30
+ # bash run_scripts/ffn/4ffn6_withactionprojector.sh
31
+
32
+ # bash run_scripts/ffn/4ffn_withactionprojector.sh
33
+
34
+ bash run_scripts/ffn/5ffn_withactionprojector.sh
35
+ bash run_scripts/ffn/5ffn6_withactionprojector.sh
scripts/extern/verify_prismatic.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ verify_prismatic.py
3
+
4
+ Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate().
5
+ """
6
+
7
+ import time
8
+
9
+ import requests
10
+ import torch
11
+ from PIL import Image
12
+ from transformers import AutoModelForVision2Seq, AutoProcessor
13
+
14
+ # === Verification Arguments ===
15
+ MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b"
16
+ DEFAULT_IMAGE_URL = (
17
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
18
+ )
19
+
20
+ if "-prism-" in MODEL_PATH:
21
+ SAMPLE_PROMPTS_FOR_GENERATION = [
22
+ "In: What is sitting in the coffee?\nOut:",
23
+ "In: What's the name of the food on the plate?\nOut:",
24
+ "In: caption.\nOut:",
25
+ "In: how many beinets..?\nOut:",
26
+ "In: Can you give me a lyrical description of the scene\nOut:",
27
+ ]
28
+ else:
29
+ SYSTEM_PROMPT = (
30
+ "A chat between a curious user and an artificial intelligence assistant. "
31
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
32
+ )
33
+ SAMPLE_PROMPTS_FOR_GENERATION = [
34
+ f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:",
35
+ f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:",
36
+ f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:",
37
+ f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:",
38
+ f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:",
39
+ ]
40
+
41
+
42
+ @torch.inference_mode()
43
+ def verify_prismatic() -> None:
44
+ print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`")
45
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
46
+
47
+ # Load Processor & VLM
48
+ print("[*] Instantiating Processor and Pretrained VLM")
49
+ processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
50
+
51
+ # === AUTOCAST MODE ===
52
+ # print("[*] Loading in BF16 Autocast Mode")
53
+ # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to(
54
+ # device, dtype=torch.bfloat16
55
+ # )
56
+
57
+ # === NATIVE BFLOAT16 MODE ===
58
+ # print("[*] Loading in BF16")
59
+ # vlm = AutoModelForVision2Seq.from_pretrained(
60
+ # MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
61
+ # ).to(device)
62
+
63
+ # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] ===
64
+ print("[*] Loading in BF16 with Flash-Attention Enabled")
65
+ vlm = AutoModelForVision2Seq.from_pretrained(
66
+ MODEL_PATH,
67
+ attn_implementation="flash_attention_2",
68
+ torch_dtype=torch.bfloat16,
69
+ low_cpu_mem_usage=True,
70
+ trust_remote_code=True,
71
+ ).to(device)
72
+
73
+ # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] ===
74
+ # print("[*] Loading in 8-Bit Quantization Mode")
75
+ # vlm = AutoModelForVision2Seq.from_pretrained(
76
+ # MODEL_PATH,
77
+ # attn_implementation="flash_attention_2",
78
+ # torch_dtype=torch.float16,
79
+ # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
80
+ # low_cpu_mem_usage=True,
81
+ # trust_remote_code=True,
82
+ # )
83
+
84
+ # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] ===
85
+ # print("[*] Loading in 4-Bit Quantization Mode")
86
+ # vlm = AutoModelForVision2Seq.from_pretrained(
87
+ # MODEL_PATH,
88
+ # attn_implementation="flash_attention_2",
89
+ # torch_dtype=torch.float16,
90
+ # quantization_config=BitsAndBytesConfig(load_in_4bit=True),
91
+ # low_cpu_mem_usage=True,
92
+ # trust_remote_code=True,
93
+ # )
94
+
95
+ # Iterate over Sample Prompts =>> Generate
96
+ image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB")
97
+ num_tokens, total_time = 0, 0.0
98
+
99
+ print("[*] Iterating over Sample Prompts\n===\n")
100
+ for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
101
+ # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) ===
102
+ # inputs = processor(prompt, image).to(device)
103
+ #
104
+ # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py`
105
+ # # =>> Running in native BF16 is also fine (but leads to slightly different generations)
106
+ # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
107
+ # gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
108
+
109
+ # === BFLOAT16 MODE ===
110
+ inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
111
+
112
+ # === 8-BIT/4-BIT QUANTIZATION MODE ===
113
+ # inputs = processor(prompt, image).to(device, dtype=torch.float16)
114
+
115
+ # Run Inference
116
+ gen_ids = None
117
+ for _ in range(5):
118
+ start_time = time.time()
119
+ gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
120
+ total_time += time.time() - start_time
121
+
122
+ gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]
123
+ num_tokens += len(gen_ids)
124
+
125
+ # ===
126
+ gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
127
+ print(f"[{idx + 1}] Input Prompt => {prompt}\n Generated => {gen_text}\n")
128
+
129
+ # Compute Tokens / Second
130
+ print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")
131
+
132
+
133
+ if __name__ == "__main__":
134
+ verify_prismatic()
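A note on the throughput figure printed above: total_time accumulates five generate() calls per prompt, while num_tokens only counts the tokens of the final call, so the reported tokens-per-second is roughly one fifth of the per-run value (with do_sample=False the repeated generations should be identical). If per-run throughput is the intended metric, a minimal sketch of the alternative accounting inside the timing loop would be:

    for _ in range(5):
        start_time = time.time()
        gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
        total_time += time.time() - start_time
        # Count the newly generated tokens for every timed run, not just the last one.
        num_tokens += len(gen_ids[0, inputs.input_ids.shape[1]:])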
scripts/pretrain.py ADDED
@@ -0,0 +1,238 @@
1
+ """
2
+ pretrain.py
3
+
4
+ Pretraining script for Prismatic VLM pretraining in native PyTorch, using Fully-Sharded Data Parallel (FSDP) to run
5
+ distributed training across GPUs. By default, assumes that CUDA toolkit is >= 11.0 (to support BF16 mixed precision).
6
+
7
+ Notes & Prerequisites:
8
+ - We're loading LLaMa-2 (and possibly other) gated models from HuggingFace (HF Hub); these require an auth_token.
9
+ For LLaMa-2, make sure to first get Meta approval, then fill out the form at the top of the HF LLaMa-2 page:
10
+ => Link: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
11
+ => Generate Token (from `huggingface.co`): Settings / Access Tokens / New "Read" Token
12
+ => Set `cfg.hf_token` to file path with token (as single line text file) or environment variable name
13
+
14
+ - If you want to set a custom location for all HF / TIMM artifacts --> `export HF_HOME="<PATH>"` *before* running!
15
+ => For example (add to end of .bashrc): `export HF_HOME="/mnt/fsx/skaramcheti/cache"`
16
+
17
+ Run with:
18
+ - [Single Node One-GPU (Debug)] : torchrun --standalone --nnodes 1 --nproc-per-node 1 scripts/pretrain.py
19
+ - [Single Node Multi-GPU (= $K)]: torchrun --standalone --nnodes 1 --nproc-per-node $K scripts/pretrain.py
20
+ - [Multi-Node/AWS Sagemaker] Depends on your individual setup; file an issue if you have trouble!
21
+ """
22
+
23
+ import json
24
+ import os
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Optional, Tuple, Union
28
+
29
+ import draccus
30
+ import torch
31
+ import torch.distributed as dist
32
+ import yaml
33
+
34
+ from prismatic.conf import DatasetConfig, DatasetRegistry, ModelConfig, ModelRegistry
35
+ from prismatic.models import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm
36
+ from prismatic.overwatch import initialize_overwatch
37
+ from prismatic.preprocessing import get_dataset_and_collator
38
+ from prismatic.training import Metrics, get_train_strategy
39
+ from prismatic.util import set_global_seed
40
+
41
+ # Disable Tokenizers Parallelism to Play Nice w/ PyTorch Multiprocessing DataLoaders
42
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
43
+
44
+ # Initialize Overwatch =>> Wraps `logging.Logger`
45
+ overwatch = initialize_overwatch(__name__)
46
+
47
+
48
+ @dataclass
49
+ class PretrainConfig:
50
+ # fmt: off
51
+
52
+ # ModelConfig (`prismatic/conf/models.py`); override with --model.type `ModelRegistry.<MODEL>.model_id`
53
+ model: ModelConfig = field(
54
+ default_factory=ModelConfig.get_choice_class(ModelRegistry.PRISM_DINOSIGLIP_CONTROLLED_7B.model_id)
55
+ )
56
+
57
+ # DatasetConfig (`prismatic/conf/datasets.py`); override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
58
+ dataset: DatasetConfig = field(
59
+ default_factory=DatasetConfig.get_choice_class(DatasetRegistry.LLAVA_V15.dataset_id)
60
+ )
61
+
62
+ # Pretraining Stage in < align (projector-only) | finetune (projector + LLM) | full-finetune (all) >
63
+ # ---
64
+ stage: str = "finetune" # Pretraining Stage in < align | finetune >
65
+ pretrained_checkpoint: Optional[Path] = None # Pretrained Checkpoint to Load (for `finetune`)
66
+ # if None =>> will match on (run_dir / `align`)
67
+
68
+ # Run Arguments
69
+ run_id: Optional[str] = None # Run ID for logging, Weights & Biases
70
+ run_root_dir: Path = Path("/mnt/fsx/x-prismatic-vlms/runs") # Path to directory to store logs & checkpoints
71
+ seed: int = 7 # Random seed (for reproducibility)
72
+
73
+ # HF Hub Credentials (for any gated models)
74
+ hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
75
+
76
+ # Tracking Parameters
77
+ trackers: Tuple[str, ...] = ("jsonl", "wandb") # Trackers to initialize (if W&B, add config!)
78
+ wandb_project: str = "onyx-vlms" # Name of W&B project (default: `prismatic`)
79
+ wandb_entity: Optional[str] = "stanford-voltron" # Name of W&B entity (default: None)
80
+
81
+ def __post_init__(self) -> None:
82
+ """Set optimization parameters based on `stage` in {"align", "finetune"}."""
83
+ if self.stage == "align":
84
+ self.epochs = self.model.align_epochs
85
+ self.max_steps = self.model.align_max_steps
86
+ self.global_batch_size = self.model.align_global_batch_size
87
+ self.per_device_batch_size = self.model.align_per_device_batch_size
88
+
89
+ self.learning_rate = self.model.align_learning_rate
90
+ self.weight_decay = self.model.align_weight_decay
91
+ self.max_grad_norm = self.model.align_max_grad_norm
92
+ self.lr_scheduler_type = self.model.align_lr_scheduler_type
93
+ self.warmup_ratio = self.model.align_warmup_ratio
94
+
95
+ self.train_strategy = self.model.align_train_strategy
96
+
97
+ elif self.stage.endswith("finetune"):
98
+ self.epochs = self.model.finetune_epochs
99
+ self.max_steps = self.model.finetune_max_steps
100
+ self.global_batch_size = self.model.finetune_global_batch_size
101
+ self.per_device_batch_size = self.model.finetune_per_device_batch_size
102
+
103
+ self.learning_rate = self.model.finetune_learning_rate
104
+ self.weight_decay = self.model.finetune_weight_decay
105
+ self.max_grad_norm = self.model.finetune_max_grad_norm
106
+ self.lr_scheduler_type = self.model.finetune_lr_scheduler_type
107
+ self.warmup_ratio = self.model.finetune_warmup_ratio
108
+
109
+ self.train_strategy = self.model.finetune_train_strategy
110
+
111
+ else:
112
+ raise ValueError(f"Stage `{self.stage}` is not supported!")
113
+
114
+ # fmt: on
115
+
116
+
117
+ @draccus.wrap()
118
+ def pretrain(cfg: PretrainConfig) -> None:
119
+ overwatch.info("Prismatic VLM Training :: Gathering Light")
120
+
121
+ # Note => Under `torchrun` initializing `overwatch` will automatically set up `torch.distributed`
122
+ torch.cuda.set_device(device_id := overwatch.local_rank())
123
+ torch.cuda.empty_cache()
124
+
125
+ # Create Unique Run Name & Save Directory
126
+ model_id = cfg.model.model_id
127
+ if (dataset_id := cfg.dataset.dataset_id) == "llava-v15":
128
+ cfg.run_id = f"{model_id}+stage-{cfg.stage}+x{cfg.seed}" if cfg.run_id is None else cfg.run_id
129
+ else:
130
+ cfg.run_id = f"{dataset_id}+{model_id}+stage-{cfg.stage}+x{cfg.seed}" if cfg.run_id is None else cfg.run_id
131
+
132
+ # Start =>> Build Directories and Set Randomness
133
+ overwatch.info('"Life is like a prism; what you see depends on how you turn the glass."', ctx_level=1)
134
+ hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token]
135
+ worker_init_fn = set_global_seed(cfg.seed, get_worker_init_fn=True)
136
+ os.makedirs(run_dir := (cfg.run_root_dir / cfg.run_id), exist_ok=True)
137
+ os.makedirs(cfg.run_root_dir / cfg.run_id / "checkpoints", exist_ok=True)
138
+ if overwatch.is_rank_zero():
139
+ # Additionally save a JSON version of the config
140
+ draccus.dump(cfg, open(run_dir / "config.yaml", "w"))
141
+ with open(run_dir / "config.yaml", "r") as f_yaml, open(run_dir / "config.json", "w") as f_json:
142
+ yaml_cfg = yaml.safe_load(f_yaml)
143
+ json.dump(yaml_cfg, f_json, indent=2)
144
+
145
+ # Load Vision Backbone --> on CPU, in Full Precision (initializing model, image_transform via TIMM)
146
+ overwatch.info(f"Loading Vision Backbone [bold]{cfg.model.vision_backbone_id}[/] via TIMM ")
147
+ vision_backbone, image_transform = get_vision_backbone_and_transform(
148
+ cfg.model.vision_backbone_id, image_resize_strategy=cfg.model.image_resize_strategy
149
+ )
150
+
151
+ # Load LLM Backbone --> on CPU, in Full Precision (initializing Tokenizer + handling special tokens if necessary)
152
+ overwatch.info(f"Loading Pretrained LLM [bold]{cfg.model.llm_backbone_id}[/] via HF Transformers")
153
+ llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
154
+ cfg.model.llm_backbone_id, llm_max_length=cfg.model.llm_max_length, hf_token=hf_token
155
+ )
156
+
157
+ # Create VLM => wraps `vision_backbone` and `llm`
158
+ overwatch.info(f"Instantiating PrismaticVLM `{model_id}` for Training Stage = `{cfg.stage}`")
159
+ vlm = get_vlm(
160
+ model_id,
161
+ cfg.model.arch_specifier,
162
+ vision_backbone,
163
+ llm_backbone,
164
+ enable_mixed_precision_training=cfg.model.enable_mixed_precision_training,
165
+ )
166
+
167
+ # [Explicit] Call to `freeze_backbones` here for clarity => will log exactly what is frozen / what's not!
168
+ overwatch.info(f"Invoking `VLM.freeze_backbones()` for `{model_id}` => Training Stage: `{cfg.stage}`")
169
+ vlm.freeze_backbones(cfg.stage)
170
+
171
+ # Load Weights from Checkpoint (depends on stage, config)
172
+ overwatch.info(f"Invoking `VLM.load_checkpoint()` for `{model_id}` => Training Stage: `{cfg.stage}`")
173
+ vlm.load_from_checkpoint(cfg.stage, run_dir, pretrained_checkpoint=cfg.pretrained_checkpoint)
174
+
175
+ # Get Dataset for Specified Stage
176
+ overwatch.info(f"Creating Dataset `{cfg.dataset.dataset_id}` => Stage: `{cfg.stage}`")
177
+ train_dataset, collator = get_dataset_and_collator(
178
+ cfg.stage,
179
+ cfg.dataset,
180
+ image_transform,
181
+ tokenizer,
182
+ prompt_builder_fn=llm_backbone.prompt_builder_fn,
183
+ default_image_resolution=vision_backbone.default_image_resolution,
184
+ padding_side=tokenizer.padding_side,
185
+ )
186
+
187
+ # Create Train Strategy
188
+ overwatch.info(f"Initializing Train Strategy `{cfg.train_strategy}`")
189
+ train_strategy = get_train_strategy(
190
+ train_strategy=cfg.train_strategy,
191
+ vlm=vlm,
192
+ device_id=device_id,
193
+ stage=cfg.stage,
194
+ epochs=cfg.epochs,
195
+ max_steps=cfg.max_steps,
196
+ global_batch_size=cfg.global_batch_size,
197
+ per_device_batch_size=cfg.per_device_batch_size,
198
+ learning_rate=cfg.learning_rate,
199
+ weight_decay=cfg.weight_decay,
200
+ max_grad_norm=cfg.max_grad_norm,
201
+ lr_scheduler_type=cfg.lr_scheduler_type,
202
+ warmup_ratio=cfg.warmup_ratio,
203
+ enable_gradient_checkpointing=cfg.model.enable_gradient_checkpointing,
204
+ enable_mixed_precision_training=cfg.model.enable_mixed_precision_training,
205
+ reduce_in_full_precision=cfg.model.reduce_in_full_precision,
206
+ worker_init_fn=worker_init_fn,
207
+ )
208
+ train_strategy.run_setup(run_dir=run_dir, n_train_examples=len(train_dataset))
209
+
210
+ # Create Metrics =>> Handles on the fly tracking, logging to specified trackers (e.g., JSONL, Weights & Biases)
211
+ overwatch.info(f"Creating Metrics with Active Trackers => `{cfg.trackers}`")
212
+ metrics = Metrics(
213
+ cfg.trackers,
214
+ cfg.run_id,
215
+ run_dir,
216
+ draccus.encode(cfg),
217
+ cfg.stage,
218
+ wandb_project=cfg.wandb_project,
219
+ wandb_entity=cfg.wandb_entity,
220
+ grad_accumulation_steps=train_strategy.grad_accumulation_steps,
221
+ )
222
+
223
+ # Run Training
224
+ overwatch.info("Starting Training Loop")
225
+ train_strategy.run_training(train_dataset, collator, metrics, stage=cfg.stage, seed=cfg.seed)
226
+
227
+ # Finalize
228
+ overwatch.info("Done with Training =>> Finalizing Metrics")
229
+ metrics.finalize()
230
+
231
+ # And... we're done!
232
+ overwatch.info("... and that's all, folks!")
233
+ dist.barrier()
234
+ dist.destroy_process_group()
235
+
236
+
237
+ if __name__ == "__main__":
238
+ pretrain()
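As the module docstring notes, cfg.hf_token may be either a path to a single-line token file or the name of an environment variable; a minimal standalone sketch of the resolution logic used in pretrain() (the helper name here is illustrative):

    import os
    from pathlib import Path

    def resolve_hf_token(hf_token):
        # Path -> read the token from disk; str -> treat it as an environment variable name.
        return hf_token.read_text().strip() if isinstance(hf_token, Path) else os.environ[hf_token]

    # e.g. resolve_hf_token(Path(".hf_token")) or resolve_hf_token("HF_TOKEN")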
test_deepseek_moe.py ADDED
@@ -0,0 +1,246 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import sys
4
+ import os
5
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
+
7
+ # Mock constant definitions
8
+ ACTION_DIM = 7
9
+ NUM_ACTIONS_CHUNK = 8
10
+ SHORT_NUM_ACTIONS_CHUNK = 4
11
+ MID_NUM_ACTIONS_CHUNK = 6
12
+
13
+ # Import the relevant modules
14
+ from prismatic.models.action_heads import (
15
+ Expert,
16
+ DeepSeekV3AdaptiveBiasRouter,
17
+ MoELayer,
18
+ DeepSeekV3MoEActionHead,
19
+ TSActionHead
20
+ )
21
+
22
+ def test_deepseek_moe_components():
23
+ """Test the DeepSeek V3 MoE components."""
24
+ print("Testing DeepSeek V3 MoE components...")
25
+
26
+ # Test parameters
27
+ batch_size = 4
28
+ seq_len = 8
29
+ hidden_dim = 256
30
+ num_experts = 8
31
+ top_k = 2
32
+
33
+ # Create test data
34
+ x = torch.randn(batch_size, seq_len, hidden_dim)
35
+
36
+ print("\n1. Testing the GELU Expert network:")
37
+ try:
38
+ expert = Expert(hidden_dim)
39
+ output = expert(x.view(-1, hidden_dim))
40
+ print(f" Input shape: {x.view(-1, hidden_dim).shape}")
41
+ print(f" Output shape: {output.shape}")
42
+ assert output.shape == (batch_size * seq_len, hidden_dim)
43
+
44
+ # Verify that GELU activation is used
45
+ print(f" Activation function type: {type(expert.activation).__name__}")
46
+ assert isinstance(expert.activation, nn.GELU)
47
+ print(" ✓ GELU Expert network test passed")
48
+
49
+ except Exception as e:
50
+ print(f" ✗ GELU Expert network test failed: {e}")
51
+
52
+ print("\n2. Testing the DeepSeek V3 adaptive-bias router:")
53
+ try:
54
+ router = DeepSeekV3AdaptiveBiasRouter(hidden_dim, num_experts, top_k)
55
+ weights, indices = router(x)
56
+ print(f" Input shape: {x.shape}")
57
+ print(f" Weights shape: {weights.shape}")
58
+ print(f" Indices shape: {indices.shape}")
59
+
60
+ assert weights.shape == (batch_size, seq_len, top_k)
61
+ assert indices.shape == (batch_size, seq_len, top_k)
62
+
63
+ # Verify that the router has an adaptive bias
64
+ if router.enable_bias_correction:
65
+ print(f" Adaptive bias shape: {router.adaptive_bias.shape}")
66
+ assert router.adaptive_bias.shape == (num_experts,)
67
+
68
+ # Verify the load-balancing loss
69
+ loss = router.get_load_balancing_loss()
70
+ print(f" Load-balancing loss: {loss.item():.6f}")
71
+
72
+ print(" ✓ DeepSeek V3 router test passed")
73
+
74
+ except Exception as e:
75
+ print(f" ✗ DeepSeek V3 router test failed: {e}")
76
+
77
+ print("\n3. Testing the DeepSeek V3 MoE layer:")
78
+ try:
79
+ # Test the variant without a shared expert
80
+ moe_layer = MoELayer(
81
+ hidden_dim,
82
+ num_experts,
83
+ top_k,
84
+ enable_shared_expert=False
85
+ )
86
+ output = moe_layer(x)
87
+ print(f" Input shape: {x.shape}")
88
+ print(f" Output shape: {output.shape}")
89
+ assert output.shape == x.shape
90
+
91
+ # Test the variant with shared experts
92
+ moe_layer_shared = MoELayer(
93
+ hidden_dim,
94
+ num_experts,
95
+ top_k,
96
+ enable_shared_expert=True,
97
+ num_shared_experts=2
98
+ )
99
+ output_shared = moe_layer_shared(x)
100
+ print(f" Output shape with shared experts: {output_shared.shape}")
101
+ assert output_shared.shape == x.shape
102
+
103
+ # Verify load balancing
104
+ load_loss = moe_layer.get_load_balancing_loss()
105
+ print(f" Load-balancing loss: {load_loss.item():.6f}")
106
+
107
+ print(" ✓ DeepSeek V3 MoE layer test passed")
108
+
109
+ except Exception as e:
110
+ print(f" ✗ DeepSeek V3 MoE layer test failed: {e}")
111
+
112
+ def test_deepseek_moe_action_head():
113
+ """Test the DeepSeek V3 MoE action head."""
114
+ print("\n4. Testing the DeepSeek V3 MoE action head:")
115
+
116
+ # Test parameters
117
+ batch_size = 2
118
+ input_dim = 512
119
+ hidden_dim = 256
120
+ action_dim = 7
121
+
122
+ try:
123
+ # Create the model
124
+ model = DeepSeekV3MoEActionHead(
125
+ input_dim=input_dim,
126
+ hidden_dim=hidden_dim,
127
+ action_dim=action_dim,
128
+ num_routed_experts=8,
129
+ top_k=2,
130
+ num_moe_layers=2,
131
+ enable_shared_expert=True
132
+ )
133
+
134
+ # Test single-token input
135
+ actions_hidden_states_single = torch.randn(batch_size, 1, input_dim)
136
+ output_single = model.predict_action(actions_hidden_states_single)
137
+ print(f" Single-token input shape: {actions_hidden_states_single.shape}")
138
+ print(f" Single-token output shape: {output_single.shape}")
139
+ assert output_single.shape == (batch_size, NUM_ACTIONS_CHUNK, action_dim)
140
+
141
+ # Test multi-token input
142
+ actions_hidden_states_multi = torch.randn(batch_size, ACTION_DIM, input_dim)
143
+ output_multi = model.predict_action(actions_hidden_states_multi)
144
+ print(f" Multi-token input shape: {actions_hidden_states_multi.shape}")
145
+ print(f" Multi-token output shape: {output_multi.shape}")
146
+ assert output_multi.shape == (batch_size, NUM_ACTIONS_CHUNK, action_dim)
147
+
148
+ # Test the load-balancing loss
149
+ load_loss = model.get_load_balancing_loss()
150
+ print(f" Model load-balancing loss: {load_loss.item():.6f}")
151
+
152
+ # Test expert-usage statistics
153
+ model.train()
154
+ _ = model.predict_action(actions_hidden_states_single) # Trigger a statistics update
155
+ stats = model.get_expert_usage_stats()
156
+ print(f" Number of layers with expert-usage stats: {len(stats)}")
157
+
158
+ print(" ✓ DeepSeek V3 MoE action head test passed")
159
+
160
+ except Exception as e:
161
+ print(f" ✗ DeepSeek V3 MoE action head test failed: {e}")
162
+
163
+ def test_comparison_with_traditional_methods():
164
+ """Compare DeepSeek V3 MoE with traditional approaches."""
165
+ print("\n5. Performance comparison test:")
166
+
167
+ # Test parameters
168
+ batch_size = 2
169
+ input_dim = 512
170
+ hidden_dim = 256
171
+ action_dim = 7
172
+
173
+ try:
174
+ # Traditional FFN approach
175
+ model_ffn = TSActionHead(
176
+ input_dim=input_dim,
177
+ hidden_dim=hidden_dim,
178
+ action_dim=action_dim,
179
+ mlp_type='ffn',
180
+ decoder_num_blocks=2
181
+ )
182
+
183
+ # Legacy MoE approach
184
+ model_old_moe = TSActionHead(
185
+ input_dim=input_dim,
186
+ hidden_dim=hidden_dim,
187
+ action_dim=action_dim,
188
+ mlp_type='moe',
189
+ num_experts=8,
190
+ top_k=2,
191
+ decoder_num_blocks=2
192
+ )
193
+
194
+ # DeepSeek V3 MoE approach
195
+ model_deepseek_moe = DeepSeekV3MoEActionHead(
196
+ input_dim=input_dim,
197
+ hidden_dim=hidden_dim,
198
+ action_dim=action_dim,
199
+ num_routed_experts=8,
200
+ top_k=2,
201
+ num_moe_layers=2,
202
+ enable_shared_expert=True
203
+ )
204
+
205
+ # Count parameters
206
+ params_ffn = sum(p.numel() for p in model_ffn.parameters())
207
+ params_old_moe = sum(p.numel() for p in model_old_moe.parameters())
208
+ params_deepseek_moe = sum(p.numel() for p in model_deepseek_moe.parameters())
209
+
210
+ print(f" FFN model parameter count: {params_ffn:,}")
211
+ print(f" Legacy MoE parameter count: {params_old_moe:,}")
212
+ print(f" DeepSeek V3 MoE parameter count: {params_deepseek_moe:,}")
213
+ print(f" DeepSeek V3 vs FFN parameter ratio: {params_deepseek_moe / params_ffn:.2f}x")
214
+ print(f" DeepSeek V3 vs legacy MoE parameter ratio: {params_deepseek_moe / params_old_moe:.2f}x")
215
+
216
+ # Measure inference time (simple benchmark)
217
+ import time
218
+
219
+ test_input = torch.randn(batch_size, 1, input_dim)
220
+
221
+ # FFN inference time
222
+ start_time = time.time()
223
+ for _ in range(100):
224
+ _ = model_ffn.predict_action(test_input)
225
+ ffn_time = time.time() - start_time
226
+
227
+ # DeepSeek V3 MoE inference time
228
+ start_time = time.time()
229
+ for _ in range(100):
230
+ _ = model_deepseek_moe.predict_action(test_input)
231
+ deepseek_time = time.time() - start_time
232
+
233
+ print(f" FFN inference time (100 runs): {ffn_time:.4f}s")
235
+ print(f" DeepSeek V3 MoE inference time (100 runs): {deepseek_time:.4f}s")
236
+ print(f" Inference time ratio: {deepseek_time / ffn_time:.2f}x")
236
+
237
+ print(" ✓ Performance comparison test completed")
238
+
239
+ except Exception as e:
240
+ print(f" ✗ Performance comparison test failed: {e}")
241
+
242
+ if __name__ == "__main__":
243
+ test_deepseek_moe_components()
244
+ test_deepseek_moe_action_head()
245
+ test_comparison_with_traditional_methods()
246
+ print("\nAll DeepSeek V3 MoE tests completed!")
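The Expert, router, and MoELayer classes imported from prismatic.models.action_heads are not included in this commit listing. As a rough mental model of what the test above exercises, the sketch below shows a minimal, self-contained top-k routed MoE layer with GELU experts (shapes follow the test: [batch, seq, hidden]). This is an illustrative sketch under those assumptions, not the project's implementation; it omits the adaptive-bias correction, shared experts, and load-balancing loss.

    import torch
    import torch.nn as nn

    class TinyTopKMoE(nn.Module):
        # Route each token to its top-k experts and mix their outputs by the router weights.
        def __init__(self, hidden_dim: int, num_experts: int = 8, top_k: int = 2):
            super().__init__()
            self.top_k = top_k
            self.router = nn.Linear(hidden_dim, num_experts, bias=False)
            self.experts = nn.ModuleList(
                nn.Sequential(nn.Linear(hidden_dim, 4 * hidden_dim), nn.GELU(), nn.Linear(4 * hidden_dim, hidden_dim))
                for _ in range(num_experts)
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            logits = self.router(x)                              # [B, S, num_experts]
            weights, indices = logits.topk(self.top_k, dim=-1)   # both [B, S, top_k]
            weights = weights.softmax(dim=-1)                    # normalize over the selected experts
            out = torch.zeros_like(x)
            for slot in range(self.top_k):
                for e, expert in enumerate(self.experts):
                    mask = indices[..., slot] == e               # tokens routed to expert e in this slot
                    if mask.any():
                        out[mask] += weights[..., slot][mask].unsqueeze(-1) * expert(x[mask])
            return out

    moe = TinyTopKMoE(hidden_dim=256)
    y = moe(torch.randn(4, 8, 256))  # batch=4, seq=8, hidden=256, as in the test above
    assert y.shape == (4, 8, 256)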
vla-scripts/extern/convert_openvla_weights_to_hf.py ADDED
@@ -0,0 +1,272 @@
1
+ """
2
+ convert_openvla_weights_to_hf.py
3
+
4
+ Utility script for converting full OpenVLA VLA weights (from this repository, in the default "Prismatic" format) to
5
+ the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers``
6
+ via `trust_remote_code = True`.
7
+
8
+ Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the
9
+ line, with first-class support.
10
+
11
+ Usage:
12
+ python vla-scripts/extern/convert_openvla_weights_to_hf.py \
13
+ --openvla_model_path_or_id <PATH TO PRISMATIC TRAINING RUN DIR> \
14
+ --output_hf_model_local_path <OUTPUT DIR FOR CONVERTED CHECKPOINT>
15
+ """
16
+
17
+ import json
18
+ import os
19
+ import shutil
20
+ from dataclasses import dataclass
21
+ from pathlib import Path
22
+ from typing import Dict, Union
23
+
24
+ import draccus
25
+ import timm
26
+ import torch
27
+ import torch.nn as nn
28
+ from huggingface_hub import hf_hub_download
29
+ from timm.models.vision_transformer import LayerScale
30
+ from transformers import AutoTokenizer
31
+
32
+ from prismatic.conf import ModelConfig
33
+ from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
34
+ from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
35
+ from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
36
+
37
+
38
+ @dataclass
39
+ class HFConvertConfig:
40
+ # fmt: off
41
+ openvla_model_path_or_id: Union[str, Path] = ( # Path to Pretrained VLA (on disk or HF Hub)
42
+ "runs/prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7"
43
+ )
44
+ output_hf_model_local_path: Path = Path( # Path to Local Path to save HF model
45
+ "hf-convert/openvla-7b"
46
+ )
47
+ output_hf_model_hub_path: str = "openvla/openvla-7b" # (Optional) Path to HF Hub Path to push
48
+ # model to
49
+
50
+ # HF Hub Credentials (required for Gated Models like LLaMa-2)
51
+ hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
52
+
53
+ def __post_init__(self) -> None:
54
+ self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token
55
+
56
+ # fmt: on
57
+
58
+
59
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
60
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
61
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
62
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
64
+
65
+
66
+ def ls_apply_patch(ls_module: LayerScale):
67
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
68
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
69
+ del ls_module.gamma
70
+
71
+
72
+ # === Conversion Constants ===
73
+ PROJECTOR_KEY_MAPPING = {
74
+ "projector.0.weight": "projector.fc1.weight",
75
+ "projector.0.bias": "projector.fc1.bias",
76
+ "projector.2.weight": "projector.fc2.weight",
77
+ "projector.2.bias": "projector.fc2.bias",
78
+ "projector.4.weight": "projector.fc3.weight",
79
+ "projector.4.bias": "projector.fc3.bias",
80
+ }
81
+
82
+
83
+ def remap_state_dicts_for_hf(
84
+ prismatic_vision_backbone_state_dict: Dict[str, torch.Tensor],
85
+ projector_state_dict: Dict[str, torch.Tensor],
86
+ llm_backbone_state_dict: Dict[str, torch.Tensor],
87
+ use_fused_vision_backbone: bool = False,
88
+ ) -> Dict[str, torch.Tensor]:
89
+ """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion."""
90
+ hf_state_dict = {}
91
+
92
+ # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING`
93
+ for key, value in projector_state_dict.items():
94
+ hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value
95
+
96
+ # Iterate through LLM Backbone =>> replace `llm.` with `language_model.`
97
+ for key, value in llm_backbone_state_dict.items():
98
+ hf_state_dict[key.replace("llm.", "language_model.")] = value
99
+
100
+ # Iterate through Vision Backbone =>> add "vision_backbone." prefix
101
+ if not use_fused_vision_backbone:
102
+ for key, value in prismatic_vision_backbone_state_dict.items():
103
+ hf_state_dict[key.replace("featurizer.", "vision_backbone.featurizer.")] = value
104
+ else:
105
+ # Note =>> Assumes that backbones are always DINO + SigLIP...
106
+ for key, value in prismatic_vision_backbone_state_dict.items():
107
+ if key.startswith("dino_featurizer"):
108
+ if key.endswith(".gamma"):
109
+ # Handle `LayerScale gamma` =>> DINOv2 only!
110
+ key = key.replace(".gamma", ".scale_factor")
111
+ hf_state_dict[key.replace("dino_featurizer.", "vision_backbone.featurizer.")] = value
112
+ elif key.startswith("siglip_featurizer"):
113
+ hf_state_dict[key.replace("siglip_featurizer.", "vision_backbone.fused_featurizer.")] = value
114
+
115
+ return hf_state_dict
116
+
117
+
118
+ @draccus.wrap()
119
+ def convert_openvla_weights_to_hf(cfg: HFConvertConfig) -> None:
120
+ print(f"[*] Converting OpenVLA Model `{cfg.openvla_model_path_or_id}` to HF Transformers Format")
121
+ torch.set_default_dtype(torch.bfloat16)
122
+
123
+ # Get `config.json`, 'dataset_statistics.json' and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py`
124
+ if os.path.isdir(cfg.openvla_model_path_or_id):
125
+ print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.openvla_model_path_or_id))}`")
126
+ config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
127
+ dataset_statistics_json = run_dir / "dataset_statistics.json"
128
+
129
+ assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
130
+ assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
131
+ assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
132
+ else:
133
+ print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.openvla_model_path_or_id}`")
134
+ config_json = hf_hub_download("openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/config.json")
135
+ checkpoint_pt = hf_hub_download(
136
+ "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/checkpoints/latest-checkpoint.pt"
137
+ )
138
+ dataset_statistics_json = hf_hub_download(
139
+ "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/dataset_statistics.json"
140
+ )
141
+
142
+ # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
143
+ with open(config_json, "r") as f:
144
+ vla_cfg = json.load(f)["vla"]
145
+ prismatic_config = ModelConfig.get_choice_class(vla_cfg["base_vlm"])().__dict__
146
+
147
+ # Load Normalization Statistics
148
+ with open(dataset_statistics_json, "r") as f:
149
+ norm_stats = json.load(f)
150
+
+     # Create HF OpenVLAConfig (`transformers.PretrainedConfig`)
+     hf_config = OpenVLAConfig(
+         vision_backbone_id=prismatic_config["vision_backbone_id"],
+         llm_backbone_id=prismatic_config["llm_backbone_id"],
+         arch_specifier=prismatic_config["arch_specifier"],
+         image_resize_strategy=prismatic_config["image_resize_strategy"],
+         llm_max_length=prismatic_config["llm_max_length"],
+         torch_dtype=torch.bfloat16,
+         norm_stats=norm_stats,
+     )
+
+     # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
+     # TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
+     print("[*] Instantiating and Patching Tokenizer, LLM Config")
+     tokenizer = AutoTokenizer.from_pretrained(
+         hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
+     )
+     tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+     tokenizer.init_kwargs.pop("add_prefix_space", None)  # Pop to prevent unnecessary warning on reload...
+     assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
+     assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"
+
+     # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
+     hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
+     hf_config.text_config.pad_token_id = hf_config.pad_token_id
+     hf_config.text_config.torch_dtype = torch.bfloat16
+     assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"
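+     # (Editor's illustrative comment, not in the original script) For the released OpenVLA-7B configuration this
+     # works out to: Llama-2 base vocab_size of 32000, one added "<PAD>" token, and `pad_to_multiple_of = 64`
+     # padding of the embedding matrix =>> patched `text_config.vocab_size = 32000 + 64 = 32064`.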
+
+     # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform`
+     # =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py`
+     print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor")
+     input_sizes, interpolations, means, stds = [], [], [], []
+     for idx, timm_model_id in enumerate(hf_config.timm_model_ids):
+         timm_vision_backbone = timm.create_model(
+             timm_model_id,
+             pretrained=True,
+             num_classes=0,
+             img_size=hf_config.image_sizes[idx],
+             act_layer=hf_config.timm_override_act_layers[idx],
+         )
+
+         # Get Per-Backbone Image Processing
+         data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone)
+         input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx]))
+         interpolations.append(data_cfg["interpolation"])
+         means.append(data_cfg["mean"])
+         stds.append(data_cfg["std"])
+
+         # Patch `LayerScale` because of HF annoying `fix_key` overwrite...
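+         # (Editor's note, illustrative) HF `from_pretrained` has legacy key handling that rewrites parameter names
+         # containing "gamma"/"beta" when loading state dicts, which would silently break TIMM's `LayerScale.gamma`;
+         # renaming the parameter to `scale_factor` (as in `remap_state_dicts_for_hf` above) side-steps that rewrite.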
+         for module in timm_vision_backbone.modules():
+             if isinstance(module, LayerScale):
+                 ls_apply_patch(module)
+
+     # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`)
+     hf_image_processor = PrismaticImageProcessor(
+         use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
+         image_resize_strategy=hf_config.image_resize_strategy,
+         input_sizes=input_sizes,
+         interpolations=interpolations,
+         means=means,
+         stds=stds,
+     )
+
+     # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor)
+     print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor")
+     hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer)
+
+     # Load Prismatic Model State Dictionary (in preparation for conversion)
+     print("[*] Loading Prismatic VLM State Dictionary from Checkpoint")
+     model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
+     assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?"
+     assert all([k in model_state_dict for k in ["vision_backbone", "projector", "llm_backbone"]]), "Missing keys!"
+
+     # Convert
+     print("[*] Running Conversion")
+     converted_state_dict = remap_state_dicts_for_hf(
+         model_state_dict["vision_backbone"],
+         model_state_dict["projector"],
+         model_state_dict["llm_backbone"],
+         use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
+     )
+
+     # Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM
+     print("[*] Building (Randomly Initialized) Model =>> OpenVLAForActionPrediction")
+     hf_model = OpenVLAForActionPrediction(hf_config)
+     hf_model.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+     # Cast Model to BF16 before Saving
+     hf_model.to(torch.bfloat16)
+
+     # Save Pretrained Versions to Local Path
+     print("[*] Saving Model & Processor to Local Path")
+     hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB")
+     hf_image_processor.save_pretrained(cfg.output_hf_model_local_path)
+     hf_processor.save_pretrained(cfg.output_hf_model_local_path)
+
+     # Copy `dataset_statistics.json` File to Converted Checkpoint Directory
+     output_dataset_statistics_json = cfg.output_hf_model_local_path / "dataset_statistics.json"
+     shutil.copyfile(dataset_statistics_json, output_dataset_statistics_json)
+
+     print(f"[*] Saving Complete! Saved converted checkpoint to: {cfg.output_hf_model_local_path}")
+
+     #####################################################################################
+     # Optional: Push Model to Hugging Face Hub
+     #####################################################################################
+
+     # # Register AutoClasses
+     # OpenVLAConfig.register_for_auto_class()
+     # PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor")
+     # PrismaticProcessor.register_for_auto_class("AutoProcessor")
+     # OpenVLAForActionPrediction.register_for_auto_class("AutoModelForVision2Seq")
+
+     # # Push to HF Hub
+     # print("[*] Pushing Model & Processor to HF Hub")
+     # hf_config.push_to_hub(cfg.output_hf_model_hub_path)
+     # hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB")
+     # hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
+     # hf_processor.push_to_hub(cfg.output_hf_model_hub_path)
+
+
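+ # (Editor's illustrative addition, not part of the original script) Rough usage sketch. The CLI flags mirror the
+ # `HFConvertConfig` fields referenced above; loading via the `Auto*` classes assumes either the optional
+ # registration block above is enabled or `trust_remote_code=True` is passed.
+ #
+ #   python vla-scripts/extern/convert_openvla_weights_to_hf.py \
+ #       --openvla_model_path_or_id <RUN_DIR_OR_HUB_ID> \
+ #       --output_hf_model_local_path <OUTPUT_DIR>
+ #
+ #   from transformers import AutoModelForVision2Seq, AutoProcessor
+ #   processor = AutoProcessor.from_pretrained("<OUTPUT_DIR>", trust_remote_code=True)
+ #   vla = AutoModelForVision2Seq.from_pretrained("<OUTPUT_DIR>", torch_dtype=torch.bfloat16, trust_remote_code=True)
+
+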
+ if __name__ == "__main__":
+     convert_openvla_weights_to_hf()