Add files using upload-large-folder tool
- prismatic/__init__.py +1 -0
- prismatic/conf/datasets.py +133 -0
- prismatic/conf/vla.py +235 -0
- prismatic/models/backbones/__init__.py +0 -0
- prismatic/models/backbones/vision/dinov2_vit.py +19 -0
- prismatic/models/load.py +226 -0
- prismatic/models/materialize.py +130 -0
- prismatic/models/projectors.py +67 -0
- prismatic/preprocessing/datasets/datasets.py +200 -0
- prismatic/preprocessing/materialize.py +69 -0
- prismatic/py.typed +0 -0
- prismatic/util/nn_utils.py +53 -0
- prismatic/vla/datasets/rlds/__init__.py +1 -0
- prismatic/vla/datasets/rlds/dataset.py +655 -0
- prismatic/vla/datasets/rlds/oxe/transforms.py +951 -0
- prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py +178 -0
- prismatic/vla/datasets/rlds/utils/__init__.py +0 -0
- prismatic/vla/datasets/rlds/utils/goal_relabeling.py +32 -0
- prismatic/vla/materialize.py +56 -0
- run_scripts/ac/ac.sh +87 -0
- run_scripts/ffn/3ffn2.sh +87 -0
- run_scripts/ffn/3postffn2.sh +87 -0
- run_scripts/ffn/3postffn6.sh +87 -0
- run_scripts/ffn/debug_5ffn_withactionprojector.sh +87 -0
- run_scripts/ffn/ffn4.sh +87 -0
- run_scripts/ffn/ffn8.sh +87 -0
- run_scripts/ffn/test.sh +87 -0
- run_scripts/ffn_long_chunks/run.sh +4 -0
- run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_base.sh +88 -0
- run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_50_l2.sh +102 -0
- run_scripts/ffn_q2a/bridge/exffn_relu_connector_linear_relu.sh +95 -0
- run_scripts/ffn_q2a/bridge/run_bridge.sh +2 -0
- run_scripts/ffn_q2a/franka/exffn_gelu_franka.sh +95 -0
- run_scripts/ffn_q2a/libero_moe/debug_moe_lit.sh +101 -0
- run_scripts/ffn_q2a/simhead/simhead_contrastive.sh +100 -0
- run_scripts/pp/pp.sh +87 -0
- run_scripts/run.sh +35 -0
- scripts/extern/verify_prismatic.py +134 -0
- scripts/pretrain.py +238 -0
- test_deepseek_moe.py +246 -0
- vla-scripts/extern/convert_openvla_weights_to_hf.py +272 -0
prismatic/__init__.py
ADDED
@@ -0,0 +1 @@
from .models import available_model_names, available_models, get_model_description, load
prismatic/conf/datasets.py
ADDED
@@ -0,0 +1,133 @@
"""
datasets.py

Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant
and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes:
    - Dataset Variant (Identifier) --> e.g., "llava-v15"
    - Align Stage Dataset Components (annotations, images)
    - Finetune Stage Dataset Components (annotations, images)
    - Dataset Root Directory (Path)
"""

from dataclasses import dataclass
from enum import Enum, unique
from pathlib import Path
from typing import Tuple

from draccus import ChoiceRegistry


@dataclass
class DatasetConfig(ChoiceRegistry):
    # fmt: off
    dataset_id: str                                 # Unique ID that fully specifies a dataset variant

    # Dataset Components for each Stage in < align | finetune >
    align_stage_components: Tuple[Path, Path]       # Path to annotation file and images directory for `align` stage
    finetune_stage_components: Tuple[Path, Path]    # Path to annotation file and images directory for `finetune` stage

    dataset_root_dir: Path                          # Path to dataset root directory; other paths are relative to root
    # fmt: on


# [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models)
@dataclass
class LLaVa_V15_Config(DatasetConfig):
    dataset_id: str = "llava-v15"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")


# [Multimodal-Only] LLaVa-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training)
@dataclass
class LLaVa_Multimodal_Only_Config(DatasetConfig):
    dataset_id: str = "llava-multimodal"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")


# LLaVa-v15 + LVIS-Instruct-4V
@dataclass
class LLaVa_LVIS4V_Config(DatasetConfig):
    dataset_id: str = "llava-lvis4v"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")


# LLaVa-v15 + LRV-Instruct
@dataclass
class LLaVa_LRV_Config(DatasetConfig):
    dataset_id: str = "llava-lrv"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")


# LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct
@dataclass
class LLaVa_LVIS4V_LRV_Config(DatasetConfig):
    dataset_id: str = "llava-lvis4v-lrv"

    align_stage_components: Tuple[Path, Path] = (
        Path("download/llava-laion-cc-sbu-558k/chat.json"),
        Path("download/llava-laion-cc-sbu-558k/"),
    )
    finetune_stage_components: Tuple[Path, Path] = (
        Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"),
        Path("download/llava-v1.5-instruct/"),
    )
    dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms")


# === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! ===
@unique
class DatasetRegistry(Enum):
    # === LLaVa v1.5 ===
    LLAVA_V15 = LLaVa_V15_Config

    LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config

    LLAVA_LVIS4V = LLaVa_LVIS4V_Config
    LLAVA_LRV = LLaVa_LRV_Config

    LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config

    @property
    def dataset_id(self) -> str:
        return self.value.dataset_id


# Register Datasets in Choice Registry
for dataset_variant in DatasetRegistry:
    DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value)
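Usage note: the registration loop above means a dataset variant can be resolved by its `dataset_id` through draccus's `ChoiceRegistry` machinery; `get_choice_class` is the same accessor that `load.py` later uses on `ModelConfig`. A minimal sketch, assuming only the classes defined in this file:

# Sketch: resolve a registered dataset variant by ID via the draccus ChoiceRegistry.
from prismatic.conf.datasets import DatasetConfig

# `get_choice_class` returns the subclass registered under the given ID ...
cfg_cls = DatasetConfig.get_choice_class("llava-v15")
cfg = cfg_cls()  # ... and instantiating it yields the default component paths

print(cfg.dataset_id)  # "llava-v15"
print(cfg.dataset_root_dir / cfg.align_stage_components[0])
# /mnt/fsx/skaramcheti/datasets/prismatic-vlms/download/llava-laion-cc-sbu-558k/chat.json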
prismatic/conf/vla.py
ADDED
@@ -0,0 +1,235 @@
"""
vla.py

Draccus Dataclass Definition for a VLAConfig object, with various registered subclasses for each VLA experiment and
model configuration thereof. A given VLA model (`policy`) configures the following attributes:
    - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.)
    - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`)
    - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning)
    - Training / Optimization Hyperparameters
"""

from dataclasses import dataclass
from enum import Enum, unique
from pathlib import Path
from typing import Optional, Union

from draccus import ChoiceRegistry


@dataclass
class VLAConfig(ChoiceRegistry):
    # fmt: off
    vla_id: str                                     # Unique VLA Policy ID that fully specifies a configuration variant
    base_vlm: Union[str, Path]                      # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`)
    freeze_vision_backbone: bool                    # Freeze Vision Backbone Parameters (akin to pretraining)
    freeze_llm_backbone: bool                       # Freeze LLM Backbone parameters
    unfreeze_last_llm_layer: bool                   # Unfreeze final layer of LLM (only takes effect if LLM is frozen)

    # Data Mixture Parameters
    data_mix: str                                   # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`)
    shuffle_buffer_size: int                        # Size of Shuffle Buffer (100K for Bridge, 1M for OXE)

    # Optimization Parameters
    epochs: int                                     # Epochs to Run (in case `max_steps` is not specified)
    max_steps: Optional[int]                        # [Optional] Max Gradient Steps to Run (overrides `epochs`)

    expected_world_size: int                        # Expected # of GPUs =>> allows us to gate training on hardware
    global_batch_size: int                          # Global Batch Size (divided across processes / world size)
    per_device_batch_size: int                      # Per-Device Batch Size (per-process / individual GPU)
                                                    #   =>> # of accumulation steps is auto-computed

    learning_rate: float                            # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay)
    weight_decay: float                             # Weight Decay for AdamW Optimizer
    max_grad_norm: float                            # Max Grad Norm (for global gradient clipping)
    lr_scheduler_type: str                          # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay")
    warmup_ratio: float                             # Fraction of Steps to Warmup (for warmup LR schedulers)

    train_strategy: str                             # Train Strategy (default "fsdp-full-shard")

    # Enable Gradient/Activation Checkpointing (for the LLM Backbone)
    enable_gradient_checkpointing: bool = True      # Enable Gradient/Activation Checkpointing during Training

    # Mixed Precision Training via Torch Native AMP (`autocast`)
    enable_mixed_precision_training: bool = True    # Enable Traditional BF16 Mixed Precision
    reduce_in_full_precision: bool = True           # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision

    # fmt: on


# === OpenVLA Training Configurations ===


# = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge =
@dataclass
class Exp_SigLIP_224px_Bridge(VLAConfig):
    vla_id: str = "siglip-224px+mx-bridge"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    freeze_vision_backbone: bool = False
    freeze_llm_backbone: bool = False
    unfreeze_last_llm_layer: bool = False

    # Data Mixture Parameters
    data_mix: str = "bridge"
    shuffle_buffer_size: int = 256_000

    # Optimization Parameters
    epochs: int = 1000
    max_steps: Optional[int] = None

    expected_world_size: int = 8
    global_batch_size: int = 256
    per_device_batch_size: int = 32

    learning_rate: float = 2e-5
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    lr_scheduler_type: str = "constant"
    warmup_ratio: float = 0.0

    train_strategy: str = "fsdp-full-shard"


# = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge =
@dataclass
class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px-icy+mx-bridge"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    freeze_vision_backbone: bool = True


# = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge =
@dataclass
class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
    vla_id: str = "prism-dinosiglip-224px+mx-bridge"
    base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"

    data_mix: str = "bridge"


# = [64 GPU] SigLIP 224px + OXE Magic Soup =
@dataclass
class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px+mx-oxe-magic-soup"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    data_mix: str = "oxe_magic_soup"

    expected_world_size: int = 64
    global_batch_size: int = 2048
    per_device_batch_size: int = 32


# = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ =
@dataclass
class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge):
    vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus"
    base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"

    # Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling!
    # data_mix: str = "oxe_magic_soup_plus"
    data_mix: str = "oxe_magic_soup_plus_minus"

    expected_world_size: int = 64
    global_batch_size: int = 2048
    per_device_batch_size: int = 32


# === OpenVLA Fine-tuning Configurations ===


# = [8 GPU] SigLIP 224px + T-DROID =
@dataclass
class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    data_mix: str = "tdroid_carrot_in_bowl"


@dataclass
class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    data_mix: str = "tdroid_pour_corn_in_pot"


# = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning =
@dataclass
class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    freeze_vision_backbone: bool = True
    freeze_llm_backbone: bool = False

    data_mix: str = "tdroid_carrot_in_bowl"


@dataclass
class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    freeze_vision_backbone: bool = True
    freeze_llm_backbone: bool = True
    unfreeze_last_llm_layer: bool = True

    data_mix: str = "tdroid_carrot_in_bowl"


@dataclass
class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    freeze_vision_backbone: bool = False
    freeze_llm_backbone: bool = True
    unfreeze_last_llm_layer: bool = True

    data_mix: str = "tdroid_carrot_in_bowl"


# === [8 GPU] SigLIP 224px + FrankaWipe ===
@dataclass
class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge):
    vla_id: str = "siglip-224px+mx-droid_wipe"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    data_mix: str = "droid_wipe"


# === Define a VLA Registry Enum for Reference & Validation ===
@unique
class VLARegistry(Enum):
    # Sanity Check Configurations =>> BridgeV2
    SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge
    DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge

    # SigLIP Frozen Backbone Experiment
    FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge

    # [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup
    SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup

    # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++
    DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus

    # === TDROID Fine-tuning Configs ===
    SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl
    SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot

    SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl
    SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl
    SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl

    # === DROID Fine-tuning Configs ===
    SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe

    @property
    def vla_id(self) -> str:
        return self.value.vla_id


# Register VLAs in Choice Registry
for vla_variant in VLARegistry:
    VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value)
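Usage note: the comment on `per_device_batch_size` says gradient-accumulation steps are auto-computed rather than configured. A small worked sketch of that arithmetic for two of the registered variants; the formula is inferred from the field comments (the actual computation lives in the training-strategy code, which is not part of this diff):

# Sketch: relation between global batch size, per-device batch size, and world size.
from prismatic.conf.vla import VLAConfig

for vla_id in ["siglip-224px+mx-bridge", "siglip-224px+mx-oxe-magic-soup"]:
    cfg = VLAConfig.get_choice_class(vla_id)()
    accumulation_steps = cfg.global_batch_size // (cfg.per_device_batch_size * cfg.expected_world_size)
    print(vla_id, accumulation_steps)
    # bridge:          256 // (32 * 8)  == 1  (no accumulation needed)
    # oxe-magic-soup: 2048 // (32 * 64) == 1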
prismatic/models/backbones/__init__.py
ADDED
File without changes
prismatic/models/backbones/vision/dinov2_vit.py
ADDED
@@ -0,0 +1,19 @@
"""
dinov2_vit.py
"""

from prismatic.models.backbones.vision.base_vision import TimmViTBackbone

# Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers!
#   => Reference: https://arxiv.org/abs/2309.16588
DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"}


class DinoV2ViTBackbone(TimmViTBackbone):
    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
        super().__init__(
            vision_backbone_id,
            DINOv2_VISION_BACKBONES[vision_backbone_id],
            image_resize_strategy,
            default_image_size=default_image_size,
        )
prismatic/models/load.py
ADDED
@@ -0,0 +1,226 @@
"""
load.py

Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical
IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub).
"""

import json
import os
from pathlib import Path
from typing import List, Optional, Union

from huggingface_hub import HfFileSystem, hf_hub_download

from prismatic.conf import ModelConfig
from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform
from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY
from prismatic.models.vlas import OpenVLA
from prismatic.models.vlms import PrismaticVLM
from prismatic.overwatch import initialize_overwatch
from prismatic.vla.action_tokenizer import ActionTokenizer

# Initialize Overwatch =>> Wraps `logging.Logger`
overwatch = initialize_overwatch(__name__)


# === HF Hub Repository ===
HF_HUB_REPO = "TRI-ML/prismatic-vlms"
VLA_HF_HUB_REPO = "openvla/openvla-dev"


# === Available Models ===
def available_models() -> List[str]:
    return list(MODEL_REGISTRY.keys())


def available_model_names() -> List[str]:
    return list(GLOBAL_REGISTRY.items())


def get_model_description(model_id_or_name: str) -> str:
    if model_id_or_name not in GLOBAL_REGISTRY:
        raise ValueError(f"Couldn't find `{model_id_or_name = }; check `prismatic.available_model_names()`")

    # Print Description & Return
    print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2))

    return description


# === Load Pretrained Model ===
def load(
    model_id_or_path: Union[str, Path],
    hf_token: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    load_for_training: bool = False,
) -> PrismaticVLM:
    """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub."""
    if os.path.isdir(model_id_or_path):
        overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`")

        # Get paths for `config.json` and pretrained checkpoint
        config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
        assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
    else:
        if model_id_or_path not in GLOBAL_REGISTRY:
            raise ValueError(f"Couldn't find `{model_id_or_path = }; check `prismatic.available_model_names()`")

        overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])} from HF Hub")
        with overwatch.local_zero_first():
            config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir)
            checkpoint_pt = hf_hub_download(
                repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir
            )

    # Load Model Config from `config.json`
    with open(config_json, "r") as f:
        model_cfg = json.load(f)["model"]

    # = Load Individual Components necessary for Instantiating a VLM =
    #   =>> Print Minimal Config
    overwatch.info(
        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n"
        f"             Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n"
        f"             LLM Backbone    =>> [bold]{model_cfg['llm_backbone_id']}[/]\n"
        f"             Arch Specifier  =>> [bold]{model_cfg['arch_specifier']}[/]\n"
        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
    )

    # Load Vision Backbone
    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]")
    vision_backbone, image_transform = get_vision_backbone_and_transform(
        model_cfg["vision_backbone_id"],
        model_cfg["image_resize_strategy"],
    )

    # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers")
    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
        model_cfg["llm_backbone_id"],
        llm_max_length=model_cfg.get("llm_max_length", 2048),
        hf_token=hf_token,
        inference_mode=not load_for_training,
    )

    # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
    overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint")
    vlm = PrismaticVLM.from_pretrained(
        checkpoint_pt,
        model_cfg["model_id"],
        vision_backbone,
        llm_backbone,
        arch_specifier=model_cfg["arch_specifier"],
        freeze_weights=not load_for_training,
    )

    return vlm


# === Load Pretrained VLA Model ===
def load_vla(
    model_id_or_path: Union[str, Path],
    hf_token: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    load_for_training: bool = False,
    step_to_load: Optional[int] = None,
    model_type: str = "pretrained",
) -> OpenVLA:
    """Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub."""

    # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to
    #   checkpoint `.pt` file, rather than the top-level run directory!
    if os.path.isfile(model_id_or_path):
        overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`")

        # [Validate] Checkpoint Path should look like `.../<RUN_ID>/checkpoints/<CHECKPOINT_PATH>.pt`
        assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!"
        run_dir = checkpoint_pt.parents[1]

        # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint
        config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json"
        assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"

    # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`)
    else:
        # Search HF Hub Repo via fsspec API
        overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`")
        if not (tmpfs := HfFileSystem()).exists(hf_path):
            raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`")

        # Identify Checkpoint to Load (via `step_to_load`)
        step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None
        valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt")
        if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1):
            raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/")

        # Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element
        target_ckpt = Path(valid_ckpts[-1]).name

        overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`")
        with overwatch.local_zero_first():
            relpath = Path(model_type) / model_id_or_path
            config_json = hf_hub_download(
                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir
            )
            dataset_statistics_json = hf_hub_download(
                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir
            )
            checkpoint_pt = hf_hub_download(
                repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir
            )

    # Load VLA Config (and corresponding base VLM `ModelConfig`) from `config.json`
    with open(config_json, "r") as f:
        vla_cfg = json.load(f)["vla"]
        model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])()

    # Load Dataset Statistics for Action Denormalization
    with open(dataset_statistics_json, "r") as f:
        norm_stats = json.load(f)

    # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) =
    #   =>> Print Minimal Config
    overwatch.info(
        f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n"
        f"             Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n"
        f"             LLM Backbone    =>> [bold]{model_cfg.llm_backbone_id}[/]\n"
        f"             Arch Specifier  =>> [bold]{model_cfg.arch_specifier}[/]\n"
        f"             Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]"
    )

    # Load Vision Backbone
    overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]")
    vision_backbone, image_transform = get_vision_backbone_and_transform(
        model_cfg.vision_backbone_id,
        model_cfg.image_resize_strategy,
    )

    # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()`
    overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers")
    llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
        model_cfg.llm_backbone_id,
        llm_max_length=model_cfg.llm_max_length,
        hf_token=hf_token,
        inference_mode=not load_for_training,
    )

    # Create Action Tokenizer
    action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer())

    # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile)
    overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint")
    vla = OpenVLA.from_pretrained(
        checkpoint_pt,
        model_cfg.model_id,
        vision_backbone,
        llm_backbone,
        arch_specifier=model_cfg.arch_specifier,
        freeze_weights=not load_for_training,
        norm_stats=norm_stats,
        action_tokenizer=action_tokenizer,
    )

    return vla
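Usage note: a minimal sketch of the two entry points above. The model ID is illustrative (the `vla.py` docstring mentions `prism-dinosiglip+7b` as a base VLM); `available_model_names()` is the authoritative listing, and `load_vla()` expects either a local `.../checkpoints/<step>.pt` file or a run-directory name under the `openvla/openvla-dev` Hub repo. The token and checkpoint path below are placeholders.

# Sketch: loading a pretrained VLM / VLA for inference (IDs and paths are examples only).
from prismatic import available_models, load
from prismatic.models.load import load_vla

print(available_models())

# Load a Prismatic VLM from the HF Hub (weights frozen, inference mode by default)
vlm = load("prism-dinosiglip+7b", hf_token="<your_hf_token>")

# Load an OpenVLA policy from a local run checkpoint instead of the Hub
vla = load_vla("/path/to/run_dir/checkpoints/step-080000.pt", load_for_training=False)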
prismatic/models/materialize.py
ADDED
@@ -0,0 +1,130 @@
"""
materialize.py

Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports
individual functions for clear control flow.
"""

from typing import Optional, Tuple

from transformers import PreTrainedTokenizerBase

from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone
from prismatic.models.backbones.vision import (
    CLIPViTBackbone,
    DinoCLIPViTBackbone,
    DinoSigLIPViTBackbone,
    DinoV2ViTBackbone,
    ImageTransform,
    IN1KViTBackbone,
    SigLIPViTBackbone,
    VisionBackbone,
)
from prismatic.models.vlms import PrismaticVLM

# === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs ===
# fmt: off

# === Vision Backbone Registry ===
VISION_BACKBONES = {
    # === 224px Backbones ===
    "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}},
    "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}},
    "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}},

    # === Assorted CLIP Backbones ===
    "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}},

    # === Assorted SigLIP Backbones ===
    "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}},
    "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
    "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},

    # === Fused Backbones ===
    "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}},
    "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
}


# === Language Model Registry ===
LLM_BACKBONES = {
    # === LLaMa-2 Pure (Non-Chat) Backbones ===
    "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
    "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},

    # === LLaMa-2 Chat Backbones ===
    "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
    "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},

    # === Vicuna-v1.5 Backbones ===
    "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
    "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},

    # === Mistral v0.1 Backbones ===
    "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
    "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},

    # === Phi-2 Backbone ===
    "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
}

# fmt: on


def get_vision_backbone_and_transform(
    vision_backbone_id: str, image_resize_strategy: str
) -> Tuple[VisionBackbone, ImageTransform]:
    """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform."""
    if vision_backbone_id in VISION_BACKBONES:
        vision_cfg = VISION_BACKBONES[vision_backbone_id]
        vision_backbone: VisionBackbone = vision_cfg["cls"](
            vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"]
        )
        image_transform = vision_backbone.get_image_transform()
        return vision_backbone, image_transform

    else:
        raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!")


def get_llm_backbone_and_tokenizer(
    llm_backbone_id: str,
    llm_max_length: int = 2048,
    hf_token: Optional[str] = None,
    inference_mode: bool = False,
) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]:
    if llm_backbone_id in LLM_BACKBONES:
        llm_cfg = LLM_BACKBONES[llm_backbone_id]
        llm_backbone: LLMBackbone = llm_cfg["cls"](
            llm_backbone_id,
            llm_max_length=llm_max_length,
            hf_token=hf_token,
            inference_mode=inference_mode,
            **llm_cfg["kwargs"],
        )
        tokenizer = llm_backbone.get_tokenizer()
        return llm_backbone, tokenizer

    else:
        raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!")


def get_vlm(
    model_id: str,
    arch_specifier: str,
    vision_backbone: VisionBackbone,
    llm_backbone: LLMBackbone,
    enable_mixed_precision_training: bool = True,
) -> PrismaticVLM:
    """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM)."""
    return PrismaticVLM(
        model_id,
        vision_backbone,
        llm_backbone,
        enable_mixed_precision_training=enable_mixed_precision_training,
        arch_specifier=arch_specifier,
    )
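Usage note: a sketch of how these factories compose, mirroring the sequence `load.py` follows. The backbone IDs come from the registries above; the `image_resize_strategy` string and the `arch_specifier` are normally read from a model's `config.json`, so the values here are stand-ins.

# Sketch: materialize backbones and wrap them into a PrismaticVLM, as load.py does.
# "resize-naive" and "no-align+gelu-mlp" are placeholders; real values come from config.json.
from prismatic.models.materialize import (
    get_llm_backbone_and_tokenizer,
    get_vision_backbone_and_transform,
    get_vlm,
)

vision_backbone, image_transform = get_vision_backbone_and_transform(
    "siglip-vit-so400m", image_resize_strategy="resize-naive"
)
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
    "llama2-7b-pure", llm_max_length=2048, hf_token="<your_hf_token>", inference_mode=True
)
vlm = get_vlm(
    "siglip-224px+7b",       # model_id (example drawn from the VLA configs above)
    "no-align+gelu-mlp",     # arch_specifier -- placeholder value
    vision_backbone,
    llm_backbone,
)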
prismatic/models/projectors.py
ADDED
@@ -0,0 +1,67 @@
"""Implementation of additional projectors for additional inputs to the VLA models."""
import torch
import torch.nn as nn
from einops import rearrange


class ProprioProjector(nn.Module):
    """
    Projects proprio state inputs into the LLM's embedding space.
    """
    def __init__(self, llm_dim: int, proprio_dim: int) -> None:
        super().__init__()
        self.llm_dim = llm_dim
        self.proprio_dim = proprio_dim

        self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
        # proprio: (bsz, proprio_dim)
        projected_features = self.fc1(proprio)
        projected_features = self.act_fn1(projected_features)
        projected_features = self.fc2(projected_features)
        return projected_features


class NoisyActionProjector(nn.Module):
    """
    [Diffusion] Projects noisy action inputs into the LLM's embedding space.

    Note that since each action is tokenized into 7 tokens in OpenVLA (rather
    than having 1 token per action), each noisy action token will have dimension 1
    instead of 7.
    """
    def __init__(self, llm_dim: int) -> None:
        super().__init__()
        self.llm_dim = llm_dim
        self.action_token_dim = 1

        self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor:
        # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1)
        projected_features = self.fc1(noisy_actions)
        projected_features = self.act_fn1(projected_features)
        projected_features = self.fc2(projected_features)
        return projected_features


class VisualProjector(nn.Module):
    def __init__(self, llm_dim: int, visual_dim: int) -> None:
        super().__init__()
        self.visual_dim, self.llm_dim = visual_dim, llm_dim
        self.fc1 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.visual_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, img_hidden_embedding: torch.Tensor) -> torch.Tensor:
        projected_features = self.fc1(img_hidden_embedding)
        projected_features = self.act_fn1(projected_features)
        projected_features = self.fc2(projected_features)
        return projected_features
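Usage note: a quick shape sanity check for the projectors above. The dimensions are illustrative only: `llm_dim=4096` matches a LLaMa-2 7B hidden size and `proprio_dim=7` a typical single-arm state vector, but neither value is fixed by this file.

# Sketch: shape check for the projectors; 4096 / 7 / 8 are example dimensions only.
import torch

from prismatic.models.projectors import NoisyActionProjector, ProprioProjector

proprio_proj = ProprioProjector(llm_dim=4096, proprio_dim=7)
print(proprio_proj(torch.randn(2, 7)).shape)        # torch.Size([2, 4096])

# Each noisy action scalar is its own "token" (trailing dim of 1), so a chunk of
# 8 actions x 7 dims maps to 56 embedding vectors per batch element.
noisy_proj = NoisyActionProjector(llm_dim=4096)
print(noisy_proj(torch.randn(2, 8 * 7, 1)).shape)   # torch.Size([2, 56, 4096])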
prismatic/preprocessing/datasets/datasets.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
datasets.py
|
| 3 |
+
|
| 4 |
+
PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with
|
| 5 |
+
utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected
|
| 6 |
+
formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models).
|
| 7 |
+
|
| 8 |
+
We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that
|
| 9 |
+
random access image reading is relatively cheap/fast.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import copy
|
| 13 |
+
import json
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Dict, List, Tuple, Type
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from torch.utils.data import Dataset
|
| 20 |
+
from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase
|
| 21 |
+
|
| 22 |
+
from prismatic.models.backbones.llm.prompting import PromptBuilder
|
| 23 |
+
from prismatic.models.backbones.vision import ImageTransform
|
| 24 |
+
|
| 25 |
+
# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels)
|
| 26 |
+
IGNORE_INDEX = -100
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class AlignDataset(Dataset[Dict[str, torch.Tensor]]):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
chat_json: Path,
|
| 33 |
+
image_dir: Path,
|
| 34 |
+
image_transform: ImageTransform,
|
| 35 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 36 |
+
) -> None:
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.chat_json, self.image_dir = chat_json, image_dir
|
| 39 |
+
self.image_transform, self.tokenizer = image_transform, tokenizer
|
| 40 |
+
self.dataset_type = "align"
|
| 41 |
+
|
| 42 |
+
# Create Prompt Template
|
| 43 |
+
self.prompt_template = "{caption}" + self.tokenizer.eos_token
|
| 44 |
+
|
| 45 |
+
# Load Chat JSON
|
| 46 |
+
with open(self.chat_json, "r") as f:
|
| 47 |
+
self.examples = json.load(f)
|
| 48 |
+
|
| 49 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 50 |
+
"""
|
| 51 |
+
Following the *actual* code executed from the LLaVa codebase, during the "align" phase, we actually discard
|
| 52 |
+
the "prompt" from the human, and instead directly predict the caption from the image.
|
| 53 |
+
|
| 54 |
+
As a concrete example given the "raw data" for the first example:
|
| 55 |
+
example = self.examples[0]["conversations"]` = {
|
| 56 |
+
[
|
| 57 |
+
{"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
|
| 58 |
+
{"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
|
| 59 |
+
]
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
Return =>> self.tokenizer("<image> select luxury furniture 3 - inch gel memory foam mattress topper\n")
|
| 63 |
+
|
| 64 |
+
:param idx: Index to retrieve from the dataset.
|
| 65 |
+
|
| 66 |
+
:return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
|
| 67 |
+
"""
|
| 68 |
+
image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"]
|
| 69 |
+
assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!"
|
| 70 |
+
|
| 71 |
+
# Format Caption --> {caption}{eos_token}
|
| 72 |
+
caption = self.prompt_template.format(caption=conversation[-1]["value"].strip())
|
| 73 |
+
|
| 74 |
+
# We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens.
|
| 75 |
+
# => Critically, we find that inserting *after* the BOS token leads to the strongest performance!
|
| 76 |
+
# - input_ids = "<s> p1 p2 p3 ... <caption_text> \n"
|
| 77 |
+
# - labels = "IGNORE IGNORE ..." (copy `input_ids` replacing <s> and p{1...K} with IGNORE)
|
| 78 |
+
#
|
| 79 |
+
# IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
|
| 80 |
+
input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0]
|
| 81 |
+
labels = copy.deepcopy(input_ids)
|
| 82 |
+
|
| 83 |
+
# Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
|
| 84 |
+
labels[0] = IGNORE_INDEX
|
| 85 |
+
|
| 86 |
+
# Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor])
|
| 87 |
+
pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))
|
| 88 |
+
|
| 89 |
+
return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)
|
| 90 |
+
|
| 91 |
+
def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]:
|
| 92 |
+
"""Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
|
| 93 |
+
modality_lengths = []
|
| 94 |
+
for example in self.examples:
|
| 95 |
+
is_multimodal = "image" in example
|
| 96 |
+
n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]])
|
| 97 |
+
modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words))
|
| 98 |
+
return modality_lengths
|
| 99 |
+
|
| 100 |
+
def __len__(self) -> int:
|
| 101 |
+
return len(self.examples)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]):
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
instruct_json: Path,
|
| 108 |
+
image_dir: Path,
|
| 109 |
+
image_transform: ImageTransform,
|
| 110 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 111 |
+
prompt_builder_fn: Type[PromptBuilder],
|
| 112 |
+
) -> None:
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.instruct_json, self.image_dir = instruct_json, image_dir
|
| 115 |
+
self.image_transform, self.tokenizer = image_transform, tokenizer
|
| 116 |
+
self.prompt_builder_fn = prompt_builder_fn
|
| 117 |
+
self.dataset_type = "finetune"
|
| 118 |
+
|
| 119 |
+
# Load Instruct JSON
|
| 120 |
+
with open(self.instruct_json, "r") as f:
|
| 121 |
+
self.examples = json.load(f)
|
| 122 |
+
|
| 123 |
+
# === Unimodal + Multimodal Handling ===
|
| 124 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 125 |
+
"""
|
| 126 |
+
Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of
|
| 127 |
+
dialog grounded in a single image.
|
| 128 |
+
|
| 129 |
+
To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the
|
| 130 |
+
methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example.
|
| 131 |
+
|
| 132 |
+
:param idx: Index to retrieve from the dataset.
|
| 133 |
+
|
| 134 |
+
        :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor}
        """
        conversation = self.examples[idx]["conversations"]

        # Create Prompt Builder --> add each message sequentially
        prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], []
        for turn_idx, turn in enumerate(conversation):
            # Get "effective" string added to prompt --> handle whitespace for tokenizer type!
            msg = prompt_builder.add_turn(turn["from"], turn["value"])

            # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
            if isinstance(self.tokenizer, LlamaTokenizerFast):
                msg = msg.rstrip()

            # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
            elif isinstance(self.tokenizer, CodeGenTokenizerFast):
                pass

            else:
                raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")

            # Tokenize Input IDs
            turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids

            # [CRITICAL] We do not want to take the loss for the "USER: <msg>" prompts =>> just the responses!
            turn_labels = (
                [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids)
            )

            # Add to Trackers
            input_ids.extend(turn_input_ids)
            labels.extend(turn_labels)

        # Tensorize =>> Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches after)
        #   - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL!
        input_ids, labels = torch.tensor(input_ids), torch.tensor(labels)

        # Handle Truncation (if necessary)
        input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length]

        # === Handle "unimodal" (language-only) vs. "multimodal" ===
        if "image" in self.examples[idx]:
            image_path = Path(self.examples[idx]["image"])

            # Set the <BOS> token's label to IGNORE_INDEX (since we're inserting the image patches right after)
            labels[0] = IGNORE_INDEX

            # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str, torch.Tensor])
            pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB"))

            return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels)

        else:
            # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us!
            return dict(pixel_values=None, input_ids=input_ids, labels=labels)

    def get_modality_lengths(self) -> List[Tuple[bool, int]]:
        """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example."""
        modality_lengths = []
        for example in self.examples:
            is_multimodal = "image" in example
            n_words = sum([len(turn["value"].split()) for turn in example["conversations"]])
            modality_lengths.append((is_multimodal, n_words))
        return modality_lengths

    def __len__(self) -> int:
        return len(self.examples)
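As a quick aside on the masking rule implemented in `__getitem__` above: only assistant turns (odd turn indices) contribute to the loss, while user turns receive IGNORE_INDEX labels. The standalone sketch below mirrors that rule with made-up token ids; the helper name and toy conversation are illustrative and not part of this repository.

# Minimal sketch of the per-turn label masking used in __getitem__ (hypothetical token ids).
IGNORE_INDEX = -100  # sentinel value ignored by Hugging Face cross-entropy losses

def build_labels(turn_token_ids):
    """Mask even-indexed (user) turns; keep odd-indexed (assistant) turns as prediction targets."""
    input_ids, labels = [], []
    for turn_idx, token_ids in enumerate(turn_token_ids):
        input_ids.extend(token_ids)
        labels.extend([IGNORE_INDEX] * len(token_ids) if turn_idx % 2 == 0 else list(token_ids))
    return input_ids, labels

ids, lbls = build_labels([[101, 2054, 2003], [2009, 2003, 1037, 4937]])  # user turn, assistant turn
print(ids)   # [101, 2054, 2003, 2009, 2003, 1037, 4937]
print(lbls)  # [-100, -100, -100, 2009, 2003, 1037, 4937]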
prismatic/preprocessing/materialize.py
ADDED
|
@@ -0,0 +1,69 @@
"""
materialize.py

Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for
clear control flow.
"""

from typing import Tuple, Type

from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase

from prismatic.conf import DatasetConfig
from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import ImageTransform
from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset
from prismatic.util.data_utils import PaddedCollatorForLanguageModeling

# Dataset Initializers =>> Maps Stage --> cls()
DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset}


def get_dataset_and_collator(
    stage: str,
    dataset_cfg: DatasetConfig,
    image_transform: ImageTransform,
    tokenizer: PreTrainedTokenizerBase,
    prompt_builder_fn: Type[PromptBuilder],
    default_image_resolution: Tuple[int, int, int],
    padding_side: str = "right",
) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]:
    dataset_cls = DATASET_INITIALIZER[stage]
    dataset_root_dir = dataset_cfg.dataset_root_dir
    collator = PaddedCollatorForLanguageModeling(
        tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side
    )

    # Switch on `stage`
    if stage == "align":
        annotation_json, image_dir = dataset_cfg.align_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer
        )
        return dataset, collator

    elif stage == "finetune":
        annotation_json, image_dir = dataset_cfg.finetune_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json,
            dataset_root_dir / image_dir,
            image_transform,
            tokenizer,
            prompt_builder_fn=prompt_builder_fn,
        )
        return dataset, collator

    elif stage == "full-finetune":
        annotation_json, image_dir = dataset_cfg.finetune_stage_components
        dataset = dataset_cls(
            dataset_root_dir / annotation_json,
            dataset_root_dir / image_dir,
            image_transform,
            tokenizer,
            prompt_builder_fn=prompt_builder_fn,
        )
        return dataset, collator

    else:
        raise ValueError(f"Stage `{stage}` is not supported!")
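For orientation, here is a hedged sketch of how `get_dataset_and_collator` is typically wired into a PyTorch DataLoader; the helper name, batch size, and image resolution are illustrative assumptions, and the argument objects are whatever the VLM's backbones already provide.

# Hypothetical wiring of the factory above into a DataLoader (names and values are illustrative).
from torch.utils.data import DataLoader

from prismatic.preprocessing.materialize import get_dataset_and_collator

def build_finetune_loader(dataset_cfg, image_transform, tokenizer, prompt_builder_fn, batch_size=16):
    dataset, collator = get_dataset_and_collator(
        stage="finetune",
        dataset_cfg=dataset_cfg,
        image_transform=image_transform,
        tokenizer=tokenizer,
        prompt_builder_fn=prompt_builder_fn,
        default_image_resolution=(3, 224, 224),  # assumed; use the vision backbone's actual resolution
        padding_side="right",
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collator, num_workers=4)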
prismatic/py.typed
ADDED
|
File without changes
|
prismatic/util/nn_utils.py
ADDED
|
@@ -0,0 +1,53 @@
"""
nn_utils.py

Utility functions and PyTorch submodule definitions.
"""

import torch
import torch.nn as nn


# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
class LinearProjector(nn.Module):
    def __init__(self, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(vision_dim, llm_dim, bias=True)

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        return self.projector(img_patches)


class MLPProjector(nn.Module):
    def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(vision_dim, llm_dim, bias=True),
                nn.GELU(),
                nn.Linear(llm_dim, llm_dim, bias=True),
            )
        else:
            raise ValueError(f"Projector with `{mlp_type = }` is not supported!")

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        return self.projector(img_patches)


class FusedMLPProjector(nn.Module):
    def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None:
        super().__init__()
        self.initial_projection_dim = fused_vision_dim * 4
        if mlp_type == "fused-gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True),
                nn.GELU(),
                nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
                nn.GELU(),
                nn.Linear(llm_dim, llm_dim, bias=True),
            )
        else:
            raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!")

    def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
        return self.projector(fused_img_patches)
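As a sanity check on the shapes these projectors expect, the snippet below pushes random patch features through `MLPProjector` and `FusedMLPProjector`; the dimensions (a 1024-wide vision backbone, a 4096-wide LLM, 256 patches) are illustrative, not prescribed by this file.

# Shape check for the projectors defined above (illustrative dimensions).
import torch

from prismatic.util.nn_utils import FusedMLPProjector, MLPProjector

vision_dim, llm_dim, num_patches = 1024, 4096, 256
patches = torch.randn(2, num_patches, vision_dim)                        # [batch, patches, vision_dim]
print(MLPProjector(vision_dim, llm_dim)(patches).shape)                  # torch.Size([2, 256, 4096])

# Fused variant: features from two vision backbones concatenated along the channel dimension.
fused_patches = torch.randn(2, num_patches, 2 * vision_dim)
print(FusedMLPProjector(2 * vision_dim, llm_dim)(fused_patches).shape)   # torch.Size([2, 256, 4096])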
prismatic/vla/datasets/rlds/__init__.py
ADDED
|
@@ -0,0 +1 @@
from .dataset import make_interleaved_dataset, make_single_dataset
prismatic/vla/datasets/rlds/dataset.py
ADDED
|
@@ -0,0 +1,655 @@
| 1 |
+
"""
|
| 2 |
+
dataset.py
|
| 3 |
+
|
| 4 |
+
Core interface script for configuring and initializing RLDS datasets.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import copy
|
| 8 |
+
import inspect
|
| 9 |
+
import json
|
| 10 |
+
import random  # import the random module
|
| 11 |
+
from functools import partial
|
| 12 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import dlimp as dl
|
| 15 |
+
import numpy as np
|
| 16 |
+
import tensorflow as tf
|
| 17 |
+
import tensorflow_datasets as tfds
|
| 18 |
+
|
| 19 |
+
from prismatic.overwatch import initialize_overwatch
|
| 20 |
+
from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX
|
| 21 |
+
from prismatic.vla.datasets.rlds import obs_transforms, traj_transforms
|
| 22 |
+
from prismatic.vla.datasets.rlds.utils import goal_relabeling, task_augmentation
|
| 23 |
+
from prismatic.vla.datasets.rlds.utils.data_utils import (
|
| 24 |
+
allocate_threads,
|
| 25 |
+
get_dataset_statistics,
|
| 26 |
+
normalize_action_and_proprio,
|
| 27 |
+
pprint_data_mixture,
|
| 28 |
+
tree_map,
|
| 29 |
+
shuffle_dataset,  # newly added import of the shuffle_dataset function
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# Initialize Overwatch =>> Wraps `logging.Logger`
|
| 33 |
+
overwatch = initialize_overwatch(__name__)
|
| 34 |
+
|
| 35 |
+
# # Adds a function to set all random seeds
|
| 36 |
+
# def set_all_seeds(seed):
|
| 37 |
+
# """Set the seeds of all random number generators to ensure reproducibility."""
|
| 38 |
+
# random.seed(seed)
|
| 39 |
+
# np.random.seed(seed)
|
| 40 |
+
# tf.random.set_seed(seed)
|
| 41 |
+
# # Enable TensorFlow deterministic operations (if supported by the TensorFlow version)
|
| 42 |
+
# try:
|
| 43 |
+
# tf.config.experimental.enable_op_determinism()
|
| 44 |
+
# except AttributeError:
|
| 45 |
+
# overwatch.warning("The TensorFlow version does not support enable_op_determinism, and the results may not be fully reproducible.")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch)
|
| 49 |
+
tf.config.set_visible_devices([], "GPU")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# # Try to get seeds from environment variables or global Settings and set them
|
| 53 |
+
# try:
|
| 54 |
+
# from prismatic.training.train_utils import get_global_seed
|
| 55 |
+
# seed = get_global_seed()
|
| 56 |
+
# if seed is not None:
|
| 57 |
+
# set_all_seeds(seed)
|
| 58 |
+
# overwatch.info(f"The Dataset module has been set with a random seed: {seed}")
|
| 59 |
+
# except (ImportError, NameError):
|
| 60 |
+
# overwatch.warning("The global seed setting cannot be obtained, so the data processing may not be fully reproducible.")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ruff: noqa: B006
|
| 64 |
+
def make_dataset_from_rlds(
|
| 65 |
+
name: str,
|
| 66 |
+
data_dir: str,
|
| 67 |
+
*,
|
| 68 |
+
train: bool,
|
| 69 |
+
shuffle_seed: int,
|
| 70 |
+
standardize_fn: Optional[Callable[[dict], dict]] = None,
|
| 71 |
+
shuffle: bool = True,
|
| 72 |
+
image_obs_keys: Dict[str, Optional[str]] = {},
|
| 73 |
+
depth_obs_keys: Dict[str, Optional[str]] = {},
|
| 74 |
+
state_obs_keys: List[Optional[str]] = (),
|
| 75 |
+
language_key: Optional[str] = None,
|
| 76 |
+
action_proprio_normalization_type: ACTION_PROPRIO_NORMALIZATION_TYPE,
|
| 77 |
+
dataset_statistics: Optional[Union[dict, str]] = None,
|
| 78 |
+
absolute_action_mask: Optional[List[bool]] = None,
|
| 79 |
+
action_normalization_mask: Optional[List[bool]] = None,
|
| 80 |
+
num_parallel_reads: int = tf.data.AUTOTUNE,
|
| 81 |
+
num_parallel_calls: int = tf.data.AUTOTUNE,
|
| 82 |
+
) -> Tuple[dl.DLataset, dict]:
|
| 83 |
+
"""
|
| 84 |
+
This function is responsible for loading a specific RLDS dataset from storage and getting it into a standardized
|
| 85 |
+
format. Yields a dataset of trajectories. Does not include CPU-intensive operations.
|
| 86 |
+
|
| 87 |
+
If `standardize_fn` is provided, it will be applied to each trajectory. This function should get the trajectory
|
| 88 |
+
into a standard format, which includes the keys "observation" and "action". Entry "observation" should be a
|
| 89 |
+
dictionary containing some number of additional keys, which will be extracted into an even more standardized format
|
| 90 |
+
according to the "*_obs_keys" arguments.
|
| 91 |
+
|
| 92 |
+
The `image_obs_keys` and `depth_obs_keys` arguments are mappings from new names to old names, or None in place of an
|
| 93 |
+
old name to insert padding. For example, if after `standardize_fn`, your "observation" dict has RGB images called
|
| 94 |
+
"workspace" and "wrist", and `image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"}`, then
|
| 95 |
+
the resulting dataset will have an "observation" dict containing the keys "image_primary", "image_secondary", and
|
| 96 |
+
"image_wrist", where "image_primary" corresponds to "workspace", "image_secondary" is a padding image, and
|
| 97 |
+
"image_wrist" corresponds to "wrist".
|
| 98 |
+
|
| 99 |
+
Entry `state_obs_keys` is a list of 1-dimensional proprioceptive keys to concatenate into a single array, which will
|
| 100 |
+
be placed in the "proprio" key of the "observation" dict. A single padding element (zero) will be inserted for each
|
| 101 |
+
None entry.
|
| 102 |
+
|
| 103 |
+
The dataset will also include a "task" dict. If `language_key` is provided, then the "task" dict will contain the
|
| 104 |
+
key "language_instruction", extracted from `traj[language_key]`.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
name (str): The name of the RLDS dataset (usually "name" or "name:version").
|
| 108 |
+
data_dir (str): The path to the data directory.
|
| 109 |
+
train (bool): Whether to use the training or validation split.
|
| 110 |
+
shuffle (bool, optional): Whether to shuffle the file read order (does NOT fully shuffle the dataset, since one
|
| 111 |
+
file usually contains many trajectories)!
|
| 112 |
+
standardize_fn (Callable[[dict], dict], optional): A function that, if provided, will be the first
|
| 113 |
+
thing applied to each trajectory.
|
| 114 |
+
image_obs_keys (Mapping[str, str|None]): Mapping from {new: old} indicating which RGB images to extract from the
|
| 115 |
+
"observation" dict. `new_obs = {f"image_{new}": old_obs[old] for new, old in image_obs_keys.items()}`.
|
| 116 |
+
If a value of `old` is None, inserts a padding image instead (empty string).
|
| 117 |
+
depth_obs_keys (Mapping[str, str|None]): Same as `image_obs_keys`, but for depth images. Keys will be
|
| 118 |
+
prefixed with "depth_" instead of "image_".
|
| 119 |
+
state_obs_keys (Sequence[str|None]): List of 1-dimensional proprioception keys to be extracted from the
|
| 120 |
+
"observation" dict, concatenated, and mapped to "proprio". Inserts 1 element of padding for each None entry.
|
| 121 |
+
language_key (str, optional): If provided, the "task" dict will contain the key "language_instruction",
|
| 122 |
+
extracted from `traj[language_key]`.
|
| 123 |
+
action_proprio_normalization_type (str, optional): The type of normalization to perform on the action,
|
| 124 |
+
proprio, or both. Can be "normal" (mean 0, std 1) or "bounds" (normalized to [-1, 1]).
|
| 125 |
+
dataset_statistics: (dict|str, optional): dict (or path to JSON file) that contains dataset statistics
|
| 126 |
+
for normalization. If `action_proprio_normalization_type` is "normal", this should contain "mean" and
|
| 127 |
+
"std" keys. If `action_proprio_normalization_type` is "bounds", this should contain "min" and "max"
|
| 128 |
+
keys. May also provide "num_transitions" and "num_trajectories" keys for downstream usage (e.g., for
|
| 129 |
+
`make_interleaved_dataset`). If not provided, the statistics will be computed on the fly.
|
| 130 |
+
absolute_action_mask (Sequence[bool], optional): By default, all action dimensions are assumed to be
|
| 131 |
+
relative. This is important for when `future_action_window_size > 0`: actions that are taken
|
| 132 |
+
from beyond the end of the trajectory (or beyond the goal timestep when goal relabeling is used)
|
| 133 |
+
need to be made "neutral" to indicate that the task has been completed. For relative actions,
|
| 134 |
+
"neutral" means zero, but for absolute actions, "neutral" means repeating the last valid action.
|
| 135 |
+
This mask, if provided, indicates which action dimensions are absolute.
|
| 136 |
+
action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions
|
| 137 |
+
should be normalized. For example, you might not want to normalize the gripper action dimension if
|
| 138 |
+
it's always exactly 0 or 1. By default, all action dimensions are normalized.
|
| 139 |
+
num_parallel_reads (int): number of parallel read workers. Defaults to AUTOTUNE.
|
| 140 |
+
num_parallel_calls (int): number of parallel calls for traj_map operations. Defaults to AUTOTUNE.
|
| 141 |
+
Returns:
|
| 142 |
+
Dataset of trajectories where each step has the following fields:
|
| 143 |
+
- observation:
|
| 144 |
+
- image_{name1, name2, ...} # RGB image observations
|
| 145 |
+
- depth_{name1, name2, ...} # depth image observations
|
| 146 |
+
- proprio # 1-dimensional array of proprioceptive observations
|
| 147 |
+
- timestep # timestep of each frame
|
| 148 |
+
- task:
|
| 149 |
+
- language_instruction # language instruction, present if `language_key` is provided
|
| 150 |
+
- action # action vector
|
| 151 |
+
- dataset_name # name of the dataset
|
| 152 |
+
"""
|
| 153 |
+
REQUIRED_KEYS = {"observation", "action"}
|
| 154 |
+
if language_key is not None:
|
| 155 |
+
REQUIRED_KEYS.add(language_key)
|
| 156 |
+
|
| 157 |
+
def restructure(traj):
|
| 158 |
+
# apply a standardization function, if provided
|
| 159 |
+
if standardize_fn is not None:
|
| 160 |
+
traj = standardize_fn(traj)
|
| 161 |
+
|
| 162 |
+
if not all(k in traj for k in REQUIRED_KEYS):
|
| 163 |
+
raise ValueError(
|
| 164 |
+
f"Trajectory is missing keys: {REQUIRED_KEYS - set(traj.keys())}. " "Did you write a `standardize_fn`?"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# extracts images, depth images and proprio from the "observation" dict
|
| 168 |
+
traj_len = tf.shape(traj["action"])[0]
|
| 169 |
+
old_obs = traj["observation"]
|
| 170 |
+
new_obs = {}
|
| 171 |
+
for new, old in image_obs_keys.items():
|
| 172 |
+
if old is None:
|
| 173 |
+
new_obs[f"image_{new}"] = tf.repeat("", traj_len) # padding
|
| 174 |
+
else:
|
| 175 |
+
new_obs[f"image_{new}"] = old_obs[old]
|
| 176 |
+
|
| 177 |
+
for new, old in depth_obs_keys.items():
|
| 178 |
+
if old is None:
|
| 179 |
+
new_obs[f"depth_{new}"] = tf.repeat("", traj_len) # padding
|
| 180 |
+
else:
|
| 181 |
+
new_obs[f"depth_{new}"] = old_obs[old]
|
| 182 |
+
|
| 183 |
+
if state_obs_keys:
|
| 184 |
+
new_obs["proprio"] = tf.concat(
|
| 185 |
+
[
|
| 186 |
+
(
|
| 187 |
+
tf.zeros((traj_len, 1), dtype=tf.float32) # padding
|
| 188 |
+
if key is None
|
| 189 |
+
else tf.cast(old_obs[key], tf.float32)
|
| 190 |
+
)
|
| 191 |
+
for key in state_obs_keys
|
| 192 |
+
],
|
| 193 |
+
axis=1,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# add timestep info
|
| 197 |
+
new_obs["timestep"] = tf.range(traj_len)
|
| 198 |
+
|
| 199 |
+
# extracts `language_key` into the "task" dict
|
| 200 |
+
task = {}
|
| 201 |
+
if language_key is not None:
|
| 202 |
+
if traj[language_key].dtype != tf.string:
|
| 203 |
+
raise ValueError(
|
| 204 |
+
f"Language key {language_key} has dtype {traj[language_key].dtype}, " "but it must be tf.string."
|
| 205 |
+
)
|
| 206 |
+
task["language_instruction"] = traj.pop(language_key)
|
| 207 |
+
|
| 208 |
+
traj = {
|
| 209 |
+
"observation": new_obs,
|
| 210 |
+
"task": task,
|
| 211 |
+
"action": tf.cast(traj["action"], tf.float32),
|
| 212 |
+
"dataset_name": tf.repeat(name, traj_len),
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
if absolute_action_mask is not None:
|
| 216 |
+
if len(absolute_action_mask) != traj["action"].shape[-1]:
|
| 217 |
+
raise ValueError(
|
| 218 |
+
f"Length of absolute_action_mask ({len(absolute_action_mask)}) "
|
| 219 |
+
f"does not match action dimension ({traj['action'].shape[-1]})."
|
| 220 |
+
)
|
| 221 |
+
traj["absolute_action_mask"] = tf.tile(
|
| 222 |
+
tf.convert_to_tensor(absolute_action_mask, dtype=tf.bool)[None],
|
| 223 |
+
[traj_len, 1],
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
return traj
|
| 227 |
+
|
| 228 |
+
builder = tfds.builder(name, data_dir=data_dir)
|
| 229 |
+
|
| 230 |
+
# load or compute dataset statistics
|
| 231 |
+
if isinstance(dataset_statistics, str):
|
| 232 |
+
with tf.io.gfile.GFile(dataset_statistics, "r") as f:
|
| 233 |
+
dataset_statistics = json.load(f)
|
| 234 |
+
elif dataset_statistics is None:
|
| 235 |
+
full_dataset = dl.DLataset.from_rlds(
|
| 236 |
+
builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads
|
| 237 |
+
).traj_map(restructure, num_parallel_calls)
|
| 238 |
+
# tries to load from cache, otherwise computes on the fly
|
| 239 |
+
dataset_statistics = get_dataset_statistics(
|
| 240 |
+
full_dataset,
|
| 241 |
+
hash_dependencies=(
|
| 242 |
+
str(builder.info),
|
| 243 |
+
str(state_obs_keys),
|
| 244 |
+
inspect.getsource(standardize_fn) if standardize_fn is not None else "",
|
| 245 |
+
),
|
| 246 |
+
save_dir=builder.data_dir,
|
| 247 |
+
)
|
| 248 |
+
dataset_statistics = tree_map(np.array, dataset_statistics)
|
| 249 |
+
|
| 250 |
+
# skip normalization for certain action dimensions
|
| 251 |
+
if action_normalization_mask is not None:
|
| 252 |
+
if len(action_normalization_mask) != dataset_statistics["action"]["mean"].shape[-1]:
|
| 253 |
+
raise ValueError(
|
| 254 |
+
f"Length of skip_normalization_mask ({len(action_normalization_mask)}) "
|
| 255 |
+
f"does not match action dimension ({dataset_statistics['action']['mean'].shape[-1]})."
|
| 256 |
+
)
|
| 257 |
+
dataset_statistics["action"]["mask"] = np.array(action_normalization_mask)
|
| 258 |
+
|
| 259 |
+
# construct the dataset
|
| 260 |
+
split = "train" if train else "val"
|
| 261 |
+
|
| 262 |
+
dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads, shuffle_seed=shuffle_seed)
|
| 263 |
+
|
| 264 |
+
dataset = dataset.traj_map(restructure, num_parallel_calls)
|
| 265 |
+
dataset = dataset.traj_map(
|
| 266 |
+
partial(
|
| 267 |
+
normalize_action_and_proprio,
|
| 268 |
+
metadata=dataset_statistics,
|
| 269 |
+
normalization_type=action_proprio_normalization_type,
|
| 270 |
+
),
|
| 271 |
+
num_parallel_calls,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
return dataset, dataset_statistics
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def apply_trajectory_transforms(
|
| 278 |
+
dataset: dl.DLataset,
|
| 279 |
+
*,
|
| 280 |
+
train: bool,
|
| 281 |
+
goal_relabeling_strategy: Optional[str] = None,
|
| 282 |
+
goal_relabeling_kwargs: dict = {},
|
| 283 |
+
window_size: int = 1,
|
| 284 |
+
future_action_window_size: int = 0,
|
| 285 |
+
subsample_length: Optional[int] = None,
|
| 286 |
+
skip_unlabeled: bool = False,
|
| 287 |
+
max_action: Optional[float] = None,
|
| 288 |
+
max_proprio: Optional[float] = None,
|
| 289 |
+
task_augment_strategy: Optional[str] = None,
|
| 290 |
+
task_augment_kwargs: dict = {},
|
| 291 |
+
num_parallel_calls: int = tf.data.AUTOTUNE,
|
| 292 |
+
use_predict_future_prop: bool = False,
|
| 293 |
+
) -> dl.DLataset:
|
| 294 |
+
"""
|
| 295 |
+
Applies common transforms that happen at a trajectory level. Such transforms are usually some sort of "relabeling"
|
| 296 |
+
(e.g., filtering, chunking, adding goals, dropping keys).
|
| 297 |
+
|
| 298 |
+
Transforms in this function should have the following properties:
|
| 299 |
+
- They require access to an entire trajectory (i.e., they cannot be applied frame-wise).
|
| 300 |
+
- They are generally not CPU-intensive, mostly involving moving and copying data.
|
| 301 |
+
- They do not require decoded images.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
dataset (dl.DLataset): The dataset to transform.
|
| 305 |
+
train (bool): Whether the dataset is for training (affects subsampling).
|
| 306 |
+
goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for
|
| 307 |
+
no goal relabeling. See `goal_relabeling.py`.
|
| 308 |
+
goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function.
|
| 309 |
+
window_size (int, optional): The length of the snippets that trajectories are chunked into.
|
| 310 |
+
future_action_window_size (int, optional): The number of future actions beyond window_size to include
|
| 311 |
+
in the chunked actions.
|
| 312 |
+
subsample_length (int, optional): If provided, trajectories longer than this will be subsampled to
|
| 313 |
+
this length (after goal relabeling and chunking).
|
| 314 |
+
skip_unlabeled (bool, optional): Whether to skip trajectories with no language labels.
|
| 315 |
+
max_action: (float, optional): If provided, trajectories in which *any* action dimension
|
| 316 |
+
of *any* transition has an absolute value larger than this will be skipped.
|
| 317 |
+
max_proprio: (float, optional): If provided, trajectories in which *any* proprio dimension
|
| 318 |
+
of *any* transition has an absolute value larger than this will be skipped.
|
| 319 |
+
task_augment_strategy (str, optional): The task augmentation strategy to use, or None for no task
|
| 320 |
+
augmentation. See `task_augmentation.py`.
|
| 321 |
+
task_augment_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation
|
| 322 |
+
function.
|
| 323 |
+
num_parallel_calls (int, optional): number of parallel calls for map operations. Defaults to AUTOTUNE.
|
| 324 |
+
"""
|
| 325 |
+
if skip_unlabeled:
|
| 326 |
+
if "language_instruction" not in dataset.element_spec["task"]:
|
| 327 |
+
raise ValueError("skip_unlabeled=True but dataset does not have language labels.")
|
| 328 |
+
|
| 329 |
+
dataset = dataset.filter(lambda x: tf.math.reduce_any(x["task"]["language_instruction"] != ""))
|
| 330 |
+
|
| 331 |
+
if max_action is not None:
|
| 332 |
+
dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["action"]) <= max_action))
|
| 333 |
+
|
| 334 |
+
if max_proprio is not None and "proprio" in dataset.element_spec["observation"]:
|
| 335 |
+
dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["observation"]["proprio"]) <= max_proprio))
|
| 336 |
+
|
| 337 |
+
# Filter out trajectories that are too short for action chunking
|
| 338 |
+
# Required minimum length: window_size + future_action_window_size
|
| 339 |
+
# required_min_length = window_size + future_action_window_size
|
| 340 |
+
# if required_min_length > 1:
|
| 341 |
+
# overwatch.info(f"Filtering trajectories shorter than {required_min_length} steps for action chunking (window_size={window_size}, future_action_window_size={future_action_window_size})")
|
| 342 |
+
|
| 343 |
+
# # Quick statistics: sample a subset of data to estimate filtering ratio
|
| 344 |
+
# try:
|
| 345 |
+
# sample_size = 1000 # Number of samples
|
| 346 |
+
# before_sample = dataset.take(sample_size)
|
| 347 |
+
|
| 348 |
+
# # Count total and valid trajectories in the sample
|
| 349 |
+
# total_sampled = 0
|
| 350 |
+
# valid_sampled = 0
|
| 351 |
+
|
| 352 |
+
# for item in before_sample:
|
| 353 |
+
# total_sampled += 1
|
| 354 |
+
# traj_length = tf.shape(item["action"])[0].numpy()
|
| 355 |
+
# if traj_length >= required_min_length:
|
| 356 |
+
# valid_sampled += 1
|
| 357 |
+
|
| 358 |
+
# if total_sampled > 0:
|
| 359 |
+
# filter_ratio = valid_sampled / total_sampled
|
| 360 |
+
# filtered_ratio = (total_sampled - valid_sampled) / total_sampled
|
| 361 |
+
# overwatch.info(f"Sample statistics ({sample_size} trajectories): keep rate {filter_ratio:.2%}, filter rate {filtered_ratio:.2%}")
|
| 362 |
+
# overwatch.info(f"Estimated ~{filtered_ratio:.1%} of trajectories will be filtered due to insufficient length")
|
| 363 |
+
# else:
|
| 364 |
+
# overwatch.info("Unable to obtain sample data for statistics")
|
| 365 |
+
|
| 366 |
+
# except Exception as e:
|
| 367 |
+
# overwatch.warning(f"Error during quick statistics: {e}, continuing with filtering operation")
|
| 368 |
+
|
| 369 |
+
# Execute the actual filtering operation
|
| 370 |
+
# dataset = dataset.filter(lambda x: tf.shape(x["action"])[0] >= required_min_length)
|
| 371 |
+
# overwatch.info("Trajectory length filtering completed")
|
| 372 |
+
# marks which entries of the observation and task dicts are padding
|
| 373 |
+
dataset = dataset.traj_map(traj_transforms.add_pad_mask_dict, num_parallel_calls)
|
| 374 |
+
|
| 375 |
+
# updates the "task" dict
|
| 376 |
+
if goal_relabeling_strategy is not None:
|
| 377 |
+
dataset = dataset.traj_map(
|
| 378 |
+
partial(getattr(goal_relabeling, goal_relabeling_strategy), **goal_relabeling_kwargs),
|
| 379 |
+
num_parallel_calls,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# must run task augmentation before chunking, in case it changes goal timesteps
|
| 383 |
+
if train and task_augment_strategy is not None:
|
| 384 |
+
# perform task augmentation (e.g., dropping keys)
|
| 385 |
+
dataset = dataset.traj_map(
|
| 386 |
+
partial(
|
| 387 |
+
getattr(task_augmentation, task_augment_strategy),
|
| 388 |
+
**task_augment_kwargs,
|
| 389 |
+
),
|
| 390 |
+
num_parallel_calls,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# chunks observations and actions, giving them a new axis at index 1 of size `window_size` and
|
| 394 |
+
# `window_size + future_action_window_size`, respectively
|
| 395 |
+
if use_predict_future_prop:
|
| 396 |
+
traj_transforms_strategy = traj_transforms.chunk_act_future_obs
|
| 397 |
+
else:
|
| 398 |
+
traj_transforms_strategy = traj_transforms.chunk_act_obs
|
| 399 |
+
|
| 400 |
+
dataset = dataset.traj_map(
|
| 401 |
+
partial(
|
| 402 |
+
traj_transforms_strategy,
|
| 403 |
+
window_size=window_size,
|
| 404 |
+
future_action_window_size=future_action_window_size,
|
| 405 |
+
),
|
| 406 |
+
num_parallel_calls,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
if train and subsample_length is not None:
|
| 410 |
+
dataset = dataset.traj_map(
|
| 411 |
+
partial(traj_transforms.subsample, subsample_length=subsample_length),
|
| 412 |
+
num_parallel_calls,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
return dataset
|
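# Illustration (not from the original file): with window_size=2 and future_action_window_size=1,
# the chunking above gives frame t the observation window [t-1, t] and the action chunk
# [t-1, t, t+1]; out-of-range indices at the start/end of a trajectory are handled by the
# padding / neutral-action logic in `traj_transforms`, so every frame yields fixed-size chunks.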
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def apply_per_dataset_frame_transforms(
|
| 419 |
+
dataset: dl.DLataset,
|
| 420 |
+
chunk_filter_fn: Optional[Callable] = None,
|
| 421 |
+
):
|
| 422 |
+
"""
|
| 423 |
+
Optionally applied *per-dataset* transforms that happen at a frame level.
|
| 424 |
+
|
| 425 |
+
Args:
|
| 426 |
+
chunk_filter_fn (callable, optional): Filter function for chunks.
|
| 427 |
+
"""
|
| 428 |
+
if chunk_filter_fn:
|
| 429 |
+
dataset = dataset.filter(chunk_filter_fn)
|
| 430 |
+
return dataset
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def apply_frame_transforms(
|
| 434 |
+
dataset: dl.DLataset,
|
| 435 |
+
*,
|
| 436 |
+
train: bool,
|
| 437 |
+
image_augment_kwargs: Union[Dict, Dict[str, Dict]] = {},
|
| 438 |
+
resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
|
| 439 |
+
depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {},
|
| 440 |
+
num_parallel_calls: int = tf.data.AUTOTUNE,
|
| 441 |
+
) -> dl.DLataset:
|
| 442 |
+
"""
|
| 443 |
+
Applies common transforms that happen at a frame level. These transforms are usually more CPU-intensive, (e.g.,
|
| 444 |
+
decoding or resizing images).
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
train (bool): Whether the dataset is for training (affects image augmentation).
|
| 448 |
+
dataset (dl.DLataset): The dataset to transform.
|
| 449 |
+
image_augment_kwargs (dict|Mapping[str, dict]): Keyword arguments to pass to the image augmentation
|
| 450 |
+
function. See `dlimp.transforms.augment_image` for documentation of these kwargs. If a dict of
|
| 451 |
+
dicts is provided, then key "k" will be used for "image_{k}" (names determined by `image_obs_keys`
|
| 452 |
+
in `make_dataset_from_rlds`). Augmentation will be skipped for missing keys (so pass an empty dict
|
| 453 |
+
to skip augmentation for all images).
|
| 454 |
+
resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): If provided, images will be resized to
|
| 455 |
+
this size. If a dict of tuples is provided, then key "k" will be used for "image_{k}" (names
|
| 456 |
+
determined by `image_obs_keys` in `make_dataset_from_rlds`). Resizing will be skipped for missing
|
| 457 |
+
keys (so pass an empty dict to skip resizing for all images).
|
| 458 |
+
depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth
|
| 459 |
+
images.
|
| 460 |
+
num_parallel_calls (int): number of parallel calls for frame_map operations. Defaults to AUTOTUNE.
|
| 461 |
+
"""
|
| 462 |
+
|
| 463 |
+
# Convenience wrapper that takes a function that operates on a non-chunked "observation" dict and applies
|
| 464 |
+
# it to the chunked "observation" dict as well as the non-chunked "task" dict
|
| 465 |
+
def apply_obs_transform(fn: Callable[[Dict], Dict], frame: Dict) -> Dict:
|
| 466 |
+
frame["task"] = fn(frame["task"])
|
| 467 |
+
frame["observation"] = dl.vmap(fn)(frame["observation"])
|
| 468 |
+
return frame
|
| 469 |
+
|
| 470 |
+
# Decode + resize images (and depth images)
|
| 471 |
+
dataset = dataset.frame_map(
|
| 472 |
+
partial(
|
| 473 |
+
apply_obs_transform,
|
| 474 |
+
partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size),
|
| 475 |
+
),
|
| 476 |
+
num_parallel_calls,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
if train:
|
| 480 |
+
# Augment all images with the same seed, skipping padding images
|
| 481 |
+
def aug(frame: dict):
|
| 482 |
+
seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32)
|
| 483 |
+
aug_fn = partial(obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs)
|
| 484 |
+
return apply_obs_transform(aug_fn, frame)
|
| 485 |
+
|
| 486 |
+
dataset = dataset.frame_map(aug, num_parallel_calls)
|
| 487 |
+
|
| 488 |
+
return dataset
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def make_single_dataset(
|
| 492 |
+
dataset_kwargs: dict,
|
| 493 |
+
*,
|
| 494 |
+
train: bool,
|
| 495 |
+
traj_transform_kwargs: dict = {},
|
| 496 |
+
frame_transform_kwargs: dict = {},
|
| 497 |
+
) -> dl.DLataset:
|
| 498 |
+
"""Creates a single dataset from kwargs. Returns a dataset of trajectories.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
dataset_kwargs: kwargs passed to `make_dataset_from_rlds` that are dataset-specific.
|
| 502 |
+
train: whether this is a training or validation dataset.
|
| 503 |
+
traj_transform_kwargs: kwargs passed to 'apply_trajectory_transforms'.
|
| 504 |
+
frame_transform_kwargs: kwargs passed to 'apply_frame_transforms'.
|
| 505 |
+
"""
|
| 506 |
+
dataset, dataset_statistics = make_dataset_from_rlds(
|
| 507 |
+
**dataset_kwargs,
|
| 508 |
+
train=train,
|
| 509 |
+
)
|
| 510 |
+
dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train)
|
| 511 |
+
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
|
| 512 |
+
|
| 513 |
+
# this seems to reduce memory usage without affecting speed
|
| 514 |
+
dataset = dataset.with_ram_budget(1)
|
| 515 |
+
|
| 516 |
+
# save for later
|
| 517 |
+
return dataset, dataset_statistics["num_trajectories"], dataset_statistics
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
# === Core Initializer ===
|
| 521 |
+
def make_interleaved_dataset(
|
| 522 |
+
dataset_kwargs_list: List[Dict],
|
| 523 |
+
sample_weights: Optional[List[float]] = None,
|
| 524 |
+
*,
|
| 525 |
+
train: bool,
|
| 526 |
+
shuffle_buffer_size: int,
|
| 527 |
+
shuffle_seed: int,
|
| 528 |
+
traj_transform_kwargs: Optional[Dict] = None,
|
| 529 |
+
frame_transform_kwargs: Optional[Dict] = None,
|
| 530 |
+
batch_size: Optional[int] = None,
|
| 531 |
+
balance_weights: bool = False,
|
| 532 |
+
traj_transform_threads: Optional[int] = None,
|
| 533 |
+
traj_read_threads: Optional[int] = None,
|
| 534 |
+
) -> dl.DLataset:
|
| 535 |
+
"""
|
| 536 |
+
Creates an interleaved dataset from list of dataset configs (kwargs). Returns a dataset of batched frames.
|
| 537 |
+
|
| 538 |
+
Args:
|
| 539 |
+
dataset_kwargs_list: list of kwargs, each element of which is passed to `make_dataset_from_rlds`.
|
| 540 |
+
"num_parallel_calls" and "num_parallel_reads" are overridden using `traj_transform_threads` and
|
| 541 |
+
`traj_read_threads`, respectively.
|
| 542 |
+
sample_weights: sampling weights for each dataset in list. If None, defaults to uniform.
|
| 543 |
+
train: whether this is a training or validation dataset.
|
| 544 |
+
shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames).
|
| 545 |
+
traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is
|
| 546 |
+
overridden using `traj_transform_threads`.
|
| 547 |
+
frame_transform_kwargs: kwargs passed to `apply_frame_transforms`.
|
| 548 |
+
batch_size: batch size, if not provided output is not batched.
|
| 549 |
+
balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset.
|
| 550 |
+
This makes it so that, if all the sample weights are equal, one full iteration through the interleaved
|
| 551 |
+
dataset will correspond to one full iteration through each individual dataset (only in expectation,
|
| 552 |
+
since in practice the sampling is random).
|
| 553 |
+
traj_transform_threads: total number of parallel calls for trajectory transforms, distributed across
|
| 554 |
+
datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
|
| 555 |
+
traj_read_threads: total number of parallel read workers for trajectory transforms, distributed across
|
| 556 |
+
datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset.
|
| 557 |
+
"""
|
| 558 |
+
# Default to uniform sampling (if `sample_weights` is not specified)
|
| 559 |
+
|
| 560 |
+
if not sample_weights:
|
| 561 |
+
sample_weights = [1.0] * len(dataset_kwargs_list)
|
| 562 |
+
|
| 563 |
+
if len(sample_weights) != len(dataset_kwargs_list):
|
| 564 |
+
raise ValueError(f"sample_weights must be None or have length {len(dataset_kwargs_list)}.")
|
| 565 |
+
|
| 566 |
+
# Check valid `traj_transform_kwargs` and `frame_transform_kwargs`
|
| 567 |
+
if (traj_transform_kwargs is None) or (frame_transform_kwargs is None):
|
| 568 |
+
raise ValueError("Missing `traj_transform_kwargs` and `frame_transform_kwargs`!")
|
| 569 |
+
|
| 570 |
+
# Get Dataset Sizes
|
| 571 |
+
dataset_sizes, all_dataset_statistics = [], {}
|
| 572 |
+
for dataset_kwargs in dataset_kwargs_list:
|
| 573 |
+
data_kwargs = copy.deepcopy(dataset_kwargs)
|
| 574 |
+
if "dataset_frame_transform_kwargs" in data_kwargs:
|
| 575 |
+
data_kwargs.pop("dataset_frame_transform_kwargs")
|
| 576 |
+
_, dataset_statistics = make_dataset_from_rlds(**data_kwargs, train=train, shuffle_seed=shuffle_seed)
|
| 577 |
+
dataset_sizes.append(dataset_statistics["num_transitions"])
|
| 578 |
+
all_dataset_statistics[dataset_kwargs["name"]] = dataset_statistics
|
| 579 |
+
|
| 580 |
+
# Get the indices of the "primary" datasets (i.e., datasets with sample_weight == 1.0)
|
| 581 |
+
primary_dataset_indices = np.array([idx for idx in range(len(sample_weights)) if sample_weights[idx] == 1.0])
|
| 582 |
+
|
| 583 |
+
# Balance and Normalize Weights
|
| 584 |
+
if balance_weights:
|
| 585 |
+
sample_weights = np.array(sample_weights) * np.array(dataset_sizes)
|
| 586 |
+
sample_weights = np.array(sample_weights) / np.sum(sample_weights)
|
| 587 |
+
pprint_data_mixture(dataset_kwargs_list, sample_weights)
|
| 588 |
+
|
| 589 |
+
# Effective Dataset Length = Number of samples until each dataset has completed at least one epoch
|
| 590 |
+
# =>> Note :: Only counting the "primary" datasets (i.e., datasets with sample_weight == 1.0)
|
| 591 |
+
dataset_len = int((np.array(dataset_sizes) / sample_weights)[primary_dataset_indices].max())
|
| 592 |
+
|
| 593 |
+
# Allocate Threads based on Weights
|
| 594 |
+
threads_per_dataset = allocate_threads(traj_transform_threads, sample_weights)
|
| 595 |
+
reads_per_dataset = allocate_threads(traj_read_threads, sample_weights)
|
| 596 |
+
|
| 597 |
+
overwatch.info("Threads per Dataset: %s", threads_per_dataset)
|
| 598 |
+
overwatch.info("Reads per Dataset: %s", reads_per_dataset)
|
| 599 |
+
|
| 600 |
+
# Construct Datasets
|
| 601 |
+
overwatch.info("Constructing datasets...")
|
| 602 |
+
datasets = []
|
| 603 |
+
for dataset_kwargs, threads, reads in zip(
|
| 604 |
+
dataset_kwargs_list,
|
| 605 |
+
threads_per_dataset,
|
| 606 |
+
reads_per_dataset,
|
| 607 |
+
):
|
| 608 |
+
dataset_frame_transform_kwargs = (
|
| 609 |
+
dataset_kwargs.pop("dataset_frame_transform_kwargs")
|
| 610 |
+
if "dataset_frame_transform_kwargs" in dataset_kwargs
|
| 611 |
+
else {}
|
| 612 |
+
)
|
| 613 |
+
dataset, _ = make_dataset_from_rlds(
|
| 614 |
+
**dataset_kwargs,
|
| 615 |
+
train=train,
|
| 616 |
+
shuffle_seed=shuffle_seed,
|
| 617 |
+
num_parallel_calls=threads,
|
| 618 |
+
num_parallel_reads=reads,
|
| 619 |
+
dataset_statistics=all_dataset_statistics[dataset_kwargs["name"]],
|
| 620 |
+
)
|
| 621 |
+
dataset = apply_trajectory_transforms(
|
| 622 |
+
dataset.repeat(),
|
| 623 |
+
**traj_transform_kwargs,
|
| 624 |
+
num_parallel_calls=threads,
|
| 625 |
+
train=train,
|
| 626 |
+
).flatten(num_parallel_calls=threads)
|
| 627 |
+
dataset = apply_per_dataset_frame_transforms(dataset, **dataset_frame_transform_kwargs)
|
| 628 |
+
datasets.append(dataset)
|
| 629 |
+
|
| 630 |
+
# Interleave at the Frame Level
|
| 631 |
+
dataset: dl.DLataset = dl.DLataset.sample_from_datasets(datasets, sample_weights, seed=shuffle_seed)
|
| 632 |
+
|
| 633 |
+
# Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase!
|
| 634 |
+
if not train:
|
| 635 |
+
dataset = dataset.take(shuffle_buffer_size).cache()
|
| 636 |
+
|
| 637 |
+
# Shuffle the Dataset
|
| 638 |
+
# =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak!
|
| 639 |
+
dataset = dataset.shuffle(shuffle_buffer_size, seed=shuffle_seed)
|
| 640 |
+
|
| 641 |
+
# Apply Frame Transforms
|
| 642 |
+
overwatch.info("Applying frame transforms on dataset...")
|
| 643 |
+
dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train)
|
| 644 |
+
|
| 645 |
+
# [Contract] When training VLA Policies, we let the Collator handle Batching!
|
| 646 |
+
if batch_size is not None:
|
| 647 |
+
dataset = dataset.batch(batch_size)
|
| 648 |
+
|
| 649 |
+
# Note =>> Seems to reduce memory usage without affecting speed?
|
| 650 |
+
dataset = dataset.with_ram_budget(1)
|
| 651 |
+
|
| 652 |
+
# Save for Later
|
| 653 |
+
dataset.sample_weights = sample_weights
|
| 654 |
+
|
| 655 |
+
return dataset, dataset_len, all_dataset_statistics
|
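To make the weight-balancing arithmetic in `make_interleaved_dataset` concrete, this standalone sketch reproduces the `balance_weights` normalization and the effective-length computation with made-up dataset sizes; it illustrates the formulas above rather than reusing code from the repository.

# Standalone illustration of balance_weights and the effective dataset length (made-up sizes).
import numpy as np

dataset_sizes = np.array([100_000, 25_000])    # num_transitions per dataset (hypothetical)
sample_weights = np.array([1.0, 1.0])

# balance_weights=True: scale weights by dataset size, then normalize so they sum to 1
weights = sample_weights * dataset_sizes
weights = weights / weights.sum()
print(weights)                                  # [0.8 0.2]

# Effective length: samples until each "primary" dataset (original weight == 1.0) finishes ~one epoch in expectation
primary = sample_weights == 1.0
dataset_len = int((dataset_sizes / weights)[primary].max())
print(dataset_len)                              # 125000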
prismatic/vla/datasets/rlds/oxe/transforms.py
ADDED
|
@@ -0,0 +1,951 @@
| 1 |
+
"""
|
| 2 |
+
transforms.py
|
| 3 |
+
|
| 4 |
+
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
|
| 5 |
+
|
| 6 |
+
Transforms adopt the following structure:
|
| 7 |
+
Input: Dictionary of *batched* features (i.e., has leading time dimension)
|
| 8 |
+
Output: Dictionary `step` =>> {
|
| 9 |
+
"observation": {
|
| 10 |
+
<image_keys, depth_image_keys>
|
| 11 |
+
State (in chosen state representation)
|
| 12 |
+
},
|
| 13 |
+
"action": Action (in chosen action representation),
|
| 14 |
+
"language_instruction": str
|
| 15 |
+
}
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from typing import Any, Dict
|
| 19 |
+
|
| 20 |
+
import tensorflow as tf
|
| 21 |
+
|
| 22 |
+
from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import droid_baseact_transform, droid_finetuning_transform
|
| 23 |
+
from prismatic.vla.datasets.rlds.utils.data_utils import (
|
| 24 |
+
binarize_gripper_actions,
|
| 25 |
+
invert_gripper_actions,
|
| 26 |
+
rel2abs_gripper_actions,
|
| 27 |
+
relabel_bridge_actions,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 32 |
+
"""
|
| 33 |
+
Applies to the version of Bridge V2 included in the Open X-Embodiment mixture.
|
| 34 |
+
|
| 35 |
+
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
| 36 |
+
"""
|
| 37 |
+
for key in trajectory.keys():
|
| 38 |
+
if key == "traj_metadata":
|
| 39 |
+
continue
|
| 40 |
+
elif key in ["observation", "action"]:
|
| 41 |
+
for key2 in trajectory[key]:
|
| 42 |
+
trajectory[key][key2] = trajectory[key][key2][1:]
|
| 43 |
+
else:
|
| 44 |
+
trajectory[key] = trajectory[key][1:]
|
| 45 |
+
|
| 46 |
+
trajectory["action"] = tf.concat(
|
| 47 |
+
(
|
| 48 |
+
trajectory["action"]["world_vector"],
|
| 49 |
+
trajectory["action"]["rotation_delta"],
|
| 50 |
+
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
|
| 51 |
+
),
|
| 52 |
+
axis=-1,
|
| 53 |
+
)
|
| 54 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 55 |
+
trajectory = relabel_bridge_actions(trajectory)
|
| 56 |
+
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
| 57 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 58 |
+
return trajectory
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 62 |
+
"""
|
| 63 |
+
Applies to the original version of Bridge V2 from the official project website.
|
| 64 |
+
|
| 65 |
+
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
| 66 |
+
"""
|
| 67 |
+
for key in trajectory.keys():
|
| 68 |
+
if key == "traj_metadata":
|
| 69 |
+
continue
|
| 70 |
+
elif key == "observation":
|
| 71 |
+
for key2 in trajectory[key]:
|
| 72 |
+
trajectory[key][key2] = trajectory[key][key2][1:]
|
| 73 |
+
else:
|
| 74 |
+
trajectory[key] = trajectory[key][1:]
|
| 75 |
+
|
| 76 |
+
trajectory["action"] = tf.concat(
|
| 77 |
+
[
|
| 78 |
+
trajectory["action"][:, :6],
|
| 79 |
+
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
|
| 80 |
+
],
|
| 81 |
+
axis=1,
|
| 82 |
+
)
|
| 83 |
+
trajectory = relabel_bridge_actions(trajectory)
|
| 84 |
+
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
| 85 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 86 |
+
return trajectory
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 90 |
+
trajectory["action"] = tf.concat(
|
| 91 |
+
[
|
| 92 |
+
trajectory["action"][:, :6],
|
| 93 |
+
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
|
| 94 |
+
],
|
| 95 |
+
axis=1,
|
| 96 |
+
)
|
| 97 |
+
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
|
| 98 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
|
| 99 |
+
return trajectory
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 103 |
+
# make gripper action absolute action, +1 = open, 0 = close
|
| 104 |
+
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
| 105 |
+
gripper_action = rel2abs_gripper_actions(gripper_action)
|
| 106 |
+
|
| 107 |
+
trajectory["action"] = tf.concat(
|
| 108 |
+
(
|
| 109 |
+
trajectory["action"]["world_vector"],
|
| 110 |
+
trajectory["action"]["rotation_delta"],
|
| 111 |
+
gripper_action[:, None],
|
| 112 |
+
),
|
| 113 |
+
axis=-1,
|
| 114 |
+
)
|
| 115 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 116 |
+
return trajectory
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 120 |
+
# make gripper action absolute action, +1 = open, 0 = close
|
| 121 |
+
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
| 122 |
+
gripper_action = rel2abs_gripper_actions(gripper_action)
|
| 123 |
+
|
| 124 |
+
trajectory["action"] = tf.concat(
|
| 125 |
+
(
|
| 126 |
+
trajectory["action"]["world_vector"],
|
| 127 |
+
trajectory["action"]["rotation_delta"],
|
| 128 |
+
gripper_action[:, None],
|
| 129 |
+
),
|
| 130 |
+
axis=-1,
|
| 131 |
+
)
|
| 132 |
+
# decode compressed state
|
| 133 |
+
eef_value = tf.io.decode_compressed(
|
| 134 |
+
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
|
| 135 |
+
compression_type="ZLIB",
|
| 136 |
+
)
|
| 137 |
+
eef_value = tf.io.decode_raw(eef_value, tf.float32)
|
| 138 |
+
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
|
| 139 |
+
gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
|
| 140 |
+
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
|
| 141 |
+
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
|
| 142 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 143 |
+
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
|
| 144 |
+
# ) # delete uninformative language instruction
|
| 145 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 146 |
+
return trajectory
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 150 |
+
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
|
| 151 |
+
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
|
| 152 |
+
trajectory["action"] = trajectory["action"]["rel_actions_world"]
|
| 153 |
+
|
| 154 |
+
# invert gripper action + clip, +1 = open, 0 = close
|
| 155 |
+
trajectory["action"] = tf.concat(
|
| 156 |
+
(
|
| 157 |
+
trajectory["action"][:, :6],
|
| 158 |
+
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
|
| 159 |
+
),
|
| 160 |
+
axis=-1,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 164 |
+
return trajectory
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 168 |
+
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
|
| 169 |
+
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][:, -1:]
|
| 170 |
+
|
| 171 |
+
# make gripper action absolute action, +1 = open, 0 = close
|
| 172 |
+
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
| 173 |
+
gripper_action = rel2abs_gripper_actions(gripper_action)
|
| 174 |
+
|
| 175 |
+
trajectory["action"] = tf.concat(
|
| 176 |
+
(
|
| 177 |
+
trajectory["action"]["world_vector"],
|
| 178 |
+
tf.zeros_like(trajectory["action"]["world_vector"]),
|
| 179 |
+
gripper_action[:, None],
|
| 180 |
+
),
|
| 181 |
+
axis=-1,
|
| 182 |
+
)
|
| 183 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 184 |
+
return trajectory
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 188 |
+
trajectory["action"] = tf.concat(
|
| 189 |
+
(
|
| 190 |
+
trajectory["action"]["world_vector"],
|
| 191 |
+
trajectory["action"]["rotation_delta"],
|
| 192 |
+
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
|
| 193 |
+
),
|
| 194 |
+
axis=-1,
|
| 195 |
+
)
|
| 196 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 197 |
+
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
|
| 198 |
+
# ) # delete uninformative language instruction
|
| 199 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 200 |
+
return trajectory
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 204 |
+
# invert absolute gripper action, +1 = open, 0 = close
|
| 205 |
+
gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1))
|
| 206 |
+
|
| 207 |
+
trajectory["action"] = tf.concat(
|
| 208 |
+
(
|
| 209 |
+
trajectory["action"]["world_vector"],
|
| 210 |
+
trajectory["action"]["rotation_delta"],
|
| 211 |
+
gripper_action,
|
| 212 |
+
),
|
| 213 |
+
axis=-1,
|
| 214 |
+
)
|
| 215 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 216 |
+
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
|
| 217 |
+
# ) # delete uninformative language instruction
|
| 218 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 219 |
+
return trajectory
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 223 |
+
# make gripper action absolute action, +1 = open, 0 = close
|
| 224 |
+
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
| 225 |
+
gripper_action = rel2abs_gripper_actions(gripper_action)
|
| 226 |
+
|
| 227 |
+
trajectory["action"] = tf.concat(
|
| 228 |
+
(
|
| 229 |
+
trajectory["action"]["world_vector"],
|
| 230 |
+
trajectory["action"]["rotation_delta"],
|
| 231 |
+
gripper_action[:, None],
|
| 232 |
+
),
|
| 233 |
+
axis=-1,
|
| 234 |
+
)
|
| 235 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 236 |
+
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
|
| 237 |
+
# ) # delete uninformative language instruction
|
| 238 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 239 |
+
return trajectory
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 243 |
+
# make gripper action, +1 = open, 0 = close
|
| 244 |
+
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
|
| 245 |
+
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
|
| 246 |
+
gripper_action = invert_gripper_actions(gripper_action)
|
| 247 |
+
|
| 248 |
+
trajectory["action"] = tf.concat(
|
| 249 |
+
(
|
| 250 |
+
trajectory["action"]["world_vector"],
|
| 251 |
+
trajectory["action"]["rotation_delta"],
|
| 252 |
+
gripper_action,
|
| 253 |
+
),
|
| 254 |
+
axis=-1,
|
| 255 |
+
)
|
| 256 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 257 |
+
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
|
| 258 |
+
# ) # delete uninformative language instruction
|
| 259 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 260 |
+
return trajectory
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 264 |
+
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
|
| 265 |
+
trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth")
|
| 266 |
+
|
| 267 |
+
# make gripper action absolute action, +1 = open, 0 = close
|
| 268 |
+
gripper_action = trajectory["action"]["gripper_closedness_action"]
|
| 269 |
+
gripper_action = rel2abs_gripper_actions(gripper_action)
|
| 270 |
+
|
| 271 |
+
trajectory["action"] = tf.concat(
|
| 272 |
+
(
|
| 273 |
+
trajectory["action"]["world_vector"],
|
| 274 |
+
trajectory["action"]["rotation_delta"],
|
| 275 |
+
gripper_action[:, None],
|
| 276 |
+
),
|
| 277 |
+
axis=-1,
|
| 278 |
+
)
|
| 279 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 280 |
+
return trajectory
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 284 |
+
trajectory["action"] = tf.concat(
|
| 285 |
+
(
|
| 286 |
+
trajectory["action"]["world_vector"],
|
| 287 |
+
trajectory["action"]["rotation_delta"],
|
| 288 |
+
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
|
| 289 |
+
),
|
| 290 |
+
axis=-1,
|
| 291 |
+
)
|
| 292 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 293 |
+
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
|
| 294 |
+
# ) # delete uninformative language instruction
|
| 295 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 296 |
+
return trajectory
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 300 |
+
# default to "open" gripper
|
| 301 |
+
trajectory["action"] = tf.concat(
|
| 302 |
+
(
|
| 303 |
+
trajectory["action"],
|
| 304 |
+
tf.zeros_like(trajectory["action"]),
|
| 305 |
+
tf.zeros_like(trajectory["action"]),
|
| 306 |
+
tf.ones_like(trajectory["action"][:, :1]),
|
| 307 |
+
),
|
| 308 |
+
axis=-1,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# decode language instruction
|
| 312 |
+
instruction_bytes = trajectory["observation"]["instruction"]
|
| 313 |
+
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
|
| 314 |
+
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
|
| 315 |
+
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0]
|
| 316 |
+
return trajectory
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 320 |
+
trajectory["action"] = tf.concat(
|
| 321 |
+
(
|
| 322 |
+
trajectory["action"]["world_vector"],
|
| 323 |
+
trajectory["action"]["rotation_delta"],
|
| 324 |
+
trajectory["action"]["gripper_closedness_action"][:, None],
|
| 325 |
+
),
|
| 326 |
+
axis=-1,
|
| 327 |
+
)
|
| 328 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 329 |
+
return trajectory
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 333 |
+
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
|
| 334 |
+
trajectory["action"] = tf.concat(
|
| 335 |
+
(
|
| 336 |
+
trajectory["action"][:, :3],
|
| 337 |
+
tf.zeros_like(trajectory["action"][:, :3]),
|
| 338 |
+
trajectory["action"][:, -1:],
|
| 339 |
+
),
|
| 340 |
+
axis=-1,
|
| 341 |
+
)
|
| 342 |
+
return trajectory
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 346 |
+
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
|
| 347 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
|
| 348 |
+
trajectory["action"] = trajectory["action"][..., :7]
|
| 349 |
+
return trajectory
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 353 |
+
# invert gripper action, +1 = open, 0 = close
|
| 354 |
+
trajectory["action"] = tf.concat(
|
| 355 |
+
(
|
| 356 |
+
trajectory["action"][:, :6],
|
| 357 |
+
invert_gripper_actions(trajectory["action"][:, -1:]),
|
| 358 |
+
),
|
| 359 |
+
axis=-1,
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
trajectory["observation"]["eef_state"] = tf.concat(
|
| 363 |
+
(
|
| 364 |
+
trajectory["observation"]["state"][:, :3],
|
| 365 |
+
trajectory["observation"]["state"][:, 7:10],
|
| 366 |
+
),
|
| 367 |
+
axis=-1,
|
| 368 |
+
)
|
| 369 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
|
| 370 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 371 |
+
# tf.shape(trajectory["language_instruction"]), ""
|
| 372 |
+
# ) # delete uninformative language instruction
|
| 373 |
+
return trajectory
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 377 |
+
# invert gripper action + clip, +1 = open, 0 = close
|
| 378 |
+
trajectory["action"] = tf.concat(
|
| 379 |
+
(
|
| 380 |
+
trajectory["action"][:, :6],
|
| 381 |
+
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
| 382 |
+
),
|
| 383 |
+
axis=-1,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
|
| 387 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 388 |
+
# tf.shape(trajectory["language_instruction"]), ""
|
| 389 |
+
# ) # delete uninformative language instruction
|
| 390 |
+
return trajectory
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 394 |
+
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
|
| 395 |
+
trajectory["observation"]["depth_additional_view"] = tf.cast(
|
| 396 |
+
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
|
| 397 |
+
)
|
| 398 |
+
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
|
| 399 |
+
|
| 400 |
+
# clip gripper action, +1 = open, 0 = close
|
| 401 |
+
trajectory["action"] = tf.concat(
|
| 402 |
+
(
|
| 403 |
+
trajectory["action"][:, -8:-2],
|
| 404 |
+
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
|
| 405 |
+
),
|
| 406 |
+
axis=-1,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 410 |
+
# tf.shape(trajectory["language_instruction"]), ""
|
| 411 |
+
# ) # delete uninformative language instruction
|
| 412 |
+
return trajectory
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 416 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
|
| 417 |
+
return trajectory
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 421 |
+
import tensorflow_graphics.geometry.transformation as tft
|
| 422 |
+
|
| 423 |
+
trajectory["observation"]["state"] = tf.concat(
|
| 424 |
+
(
|
| 425 |
+
trajectory["observation"]["state"][:, :7],
|
| 426 |
+
trajectory["observation"]["state"][:, -1:],
|
| 427 |
+
),
|
| 428 |
+
axis=-1,
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
# invert gripper action + clip, +1 = open, 0 = close
|
| 432 |
+
trajectory["action"] = tf.concat(
|
| 433 |
+
(
|
| 434 |
+
trajectory["action"][:, :3],
|
| 435 |
+
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
| 436 |
+
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
| 437 |
+
),
|
| 438 |
+
axis=-1,
|
| 439 |
+
)
|
| 440 |
+
return trajectory
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 444 |
+
trajectory["action"] = trajectory["action"][..., :-1]
|
| 445 |
+
return trajectory
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 449 |
+
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
|
| 450 |
+
trajectory["action"] = trajectory["action"][..., :-1]
|
| 451 |
+
return trajectory
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 455 |
+
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
| 456 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 457 |
+
trajectory["action"] = tf.concat(
|
| 458 |
+
(
|
| 459 |
+
trajectory["action"][:, :3],
|
| 460 |
+
tf.zeros_like(trajectory["action"][:, :3]),
|
| 461 |
+
trajectory["action"][:, -1:],
|
| 462 |
+
),
|
| 463 |
+
axis=-1,
|
| 464 |
+
)
|
| 465 |
+
return trajectory
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 469 |
+
# invert gripper action + clip, +1 = open, 0 = close
|
| 470 |
+
trajectory["action"] = tf.concat(
|
| 471 |
+
(
|
| 472 |
+
trajectory["action"][:, :6],
|
| 473 |
+
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
| 474 |
+
),
|
| 475 |
+
axis=-1,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 479 |
+
# tf.shape(trajectory["language_instruction"]), ""
|
| 480 |
+
# ) # delete uninformative language instruction
|
| 481 |
+
return trajectory
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 485 |
+
# invert gripper action + clip, +1 = open, 0 = close
|
| 486 |
+
trajectory["action"] = tf.concat(
|
| 487 |
+
(
|
| 488 |
+
trajectory["action"][:, :6],
|
| 489 |
+
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
| 490 |
+
),
|
| 491 |
+
axis=-1,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 495 |
+
# tf.shape(trajectory["language_instruction"]), ""
|
| 496 |
+
# ) # delete uninformative language instruction
|
| 497 |
+
return trajectory
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 501 |
+
trajectory["action"] = tf.concat(
|
| 502 |
+
(
|
| 503 |
+
trajectory["action"]["future/xyz_residual"][:, :3],
|
| 504 |
+
trajectory["action"]["future/axis_angle_residual"][:, :3],
|
| 505 |
+
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
|
| 506 |
+
),
|
| 507 |
+
axis=-1,
|
| 508 |
+
)
|
| 509 |
+
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
| 510 |
+
return trajectory
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 514 |
+
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
| 515 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 516 |
+
trajectory["action"] = trajectory["action"][..., :-1]
|
| 517 |
+
return trajectory
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 521 |
+
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
| 522 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 523 |
+
trajectory["action"] = trajectory["action"][..., :-1]
|
| 524 |
+
return trajectory
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 528 |
+
return trajectory
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 532 |
+
trajectory["action"] = trajectory["action"][..., -7:]
|
| 533 |
+
return trajectory
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 537 |
+
trajectory["observation"]["eef_state"] = tf.concat(
|
| 538 |
+
(
|
| 539 |
+
trajectory["observation"]["state"][:, :4],
|
| 540 |
+
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
|
| 541 |
+
),
|
| 542 |
+
axis=-1,
|
| 543 |
+
)
|
| 544 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 545 |
+
trajectory["action"] = tf.concat(
|
| 546 |
+
(
|
| 547 |
+
trajectory["action"][:, :4],
|
| 548 |
+
tf.zeros_like(trajectory["action"][:, :2]),
|
| 549 |
+
trajectory["action"][:, -1:],
|
| 550 |
+
),
|
| 551 |
+
axis=-1,
|
| 552 |
+
)
|
| 553 |
+
return trajectory
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 557 |
+
return trajectory
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 561 |
+
return trajectory
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 565 |
+
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
|
| 566 |
+
trajectory["action"] = tf.concat(
|
| 567 |
+
(
|
| 568 |
+
trajectory["action"][:, :6],
|
| 569 |
+
tf.zeros_like(trajectory["action"][:, :1]),
|
| 570 |
+
),
|
| 571 |
+
axis=-1,
|
| 572 |
+
)
|
| 573 |
+
return trajectory
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 577 |
+
trajectory["observation"]["eef_state"] = tf.concat(
|
| 578 |
+
(
|
| 579 |
+
trajectory["observation"]["end_effector_pose"][:, :4],
|
| 580 |
+
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
|
| 581 |
+
),
|
| 582 |
+
axis=-1,
|
| 583 |
+
)
|
| 584 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
|
| 585 |
+
trajectory["action"] = tf.concat(
|
| 586 |
+
(
|
| 587 |
+
trajectory["action"][:, :4],
|
| 588 |
+
tf.zeros_like(trajectory["action"][:, :2]),
|
| 589 |
+
trajectory["action"][:, -1:],
|
| 590 |
+
),
|
| 591 |
+
axis=-1,
|
| 592 |
+
)
|
| 593 |
+
return trajectory
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 597 |
+
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
| 598 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 599 |
+
return trajectory
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 603 |
+
return trajectory
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 607 |
+
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
|
| 608 |
+
return trajectory
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 612 |
+
# invert gripper action, +1 = open, 0 = close
|
| 613 |
+
trajectory["action"] = tf.concat(
|
| 614 |
+
(
|
| 615 |
+
trajectory["action"][:, :6],
|
| 616 |
+
invert_gripper_actions(trajectory["action"][:, -1:]),
|
| 617 |
+
),
|
| 618 |
+
axis=-1,
|
| 619 |
+
)
|
| 620 |
+
return trajectory
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 624 |
+
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
|
| 625 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 626 |
+
return trajectory
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 630 |
+
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
| 631 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 632 |
+
return trajectory
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 636 |
+
trajectory["action"] = trajectory["action"][..., :-1]
|
| 637 |
+
return trajectory
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 641 |
+
import tensorflow_graphics.geometry.transformation as tft
|
| 642 |
+
|
| 643 |
+
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
|
| 644 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
|
| 645 |
+
trajectory["action"] = tf.concat(
|
| 646 |
+
(
|
| 647 |
+
trajectory["action"][:, :3],
|
| 648 |
+
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
| 649 |
+
trajectory["action"][:, 7:8],
|
| 650 |
+
),
|
| 651 |
+
axis=-1,
|
| 652 |
+
)
|
| 653 |
+
return trajectory
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 657 |
+
trajectory["action"] = tf.concat(
|
| 658 |
+
(
|
| 659 |
+
trajectory["action"],
|
| 660 |
+
tf.zeros_like(trajectory["action"]),
|
| 661 |
+
tf.zeros_like(trajectory["action"][:, :1]),
|
| 662 |
+
),
|
| 663 |
+
axis=-1,
|
| 664 |
+
)
|
| 665 |
+
return trajectory
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 669 |
+
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
|
| 670 |
+
|
| 671 |
+
# invert gripper action + clip, +1 = open, 0 = close
|
| 672 |
+
trajectory["action"] = tf.concat(
|
| 673 |
+
(
|
| 674 |
+
trajectory["action"][:, :6],
|
| 675 |
+
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
| 676 |
+
),
|
| 677 |
+
axis=-1,
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# trajectory["language_instruction"] = tf.fill(
|
| 681 |
+
# tf.shape(trajectory["language_instruction"]), ""
|
| 682 |
+
# ) # delete uninformative language instruction
|
| 683 |
+
return trajectory
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 687 |
+
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
|
| 688 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
|
| 689 |
+
|
| 690 |
+
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
|
| 691 |
+
trajectory["action"] = tf.concat(
|
| 692 |
+
(
|
| 693 |
+
trajectory["action"],
|
| 694 |
+
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
|
| 695 |
+
),
|
| 696 |
+
axis=-1,
|
| 697 |
+
)
|
| 698 |
+
return trajectory
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 702 |
+
import tensorflow_graphics.geometry.transformation as tft
|
| 703 |
+
|
| 704 |
+
trajectory["action"] = tf.concat(
|
| 705 |
+
(
|
| 706 |
+
trajectory["action"][:, :3],
|
| 707 |
+
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
| 708 |
+
trajectory["action"][:, -1:],
|
| 709 |
+
),
|
| 710 |
+
axis=-1,
|
| 711 |
+
)
|
| 712 |
+
return trajectory
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 716 |
+
trajectory["action"] = tf.concat(
|
| 717 |
+
(
|
| 718 |
+
trajectory["action"][:, :3],
|
| 719 |
+
trajectory["action"][:, -4:],
|
| 720 |
+
),
|
| 721 |
+
axis=-1,
|
| 722 |
+
)
|
| 723 |
+
return trajectory
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 727 |
+
trajectory["observation"]["eef_state"] = tf.concat(
|
| 728 |
+
(
|
| 729 |
+
trajectory["observation"]["state"][:, :3],
|
| 730 |
+
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
|
| 731 |
+
),
|
| 732 |
+
axis=-1,
|
| 733 |
+
)
|
| 734 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
| 735 |
+
trajectory["action"] = trajectory["action"][..., :-1]
|
| 736 |
+
return trajectory
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 740 |
+
trajectory["observation"]["state"] = tf.concat(
|
| 741 |
+
(
|
| 742 |
+
trajectory["observation"]["position"],
|
| 743 |
+
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
|
| 744 |
+
trajectory["observation"]["yaw"],
|
| 745 |
+
),
|
| 746 |
+
axis=-1,
|
| 747 |
+
)
|
| 748 |
+
trajectory["action"] = tf.concat(
|
| 749 |
+
(
|
| 750 |
+
trajectory["action"],
|
| 751 |
+
tf.zeros_like(trajectory["action"]),
|
| 752 |
+
tf.zeros_like(trajectory["action"]),
|
| 753 |
+
tf.zeros_like(trajectory["action"][:, :1]),
|
| 754 |
+
),
|
| 755 |
+
axis=-1,
|
| 756 |
+
)
|
| 757 |
+
return trajectory
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 761 |
+
# every input feature is batched, ie has leading batch dimension
|
| 762 |
+
trajectory["observation"]["proprio"] = tf.concat(
|
| 763 |
+
(
|
| 764 |
+
trajectory["observation"]["eef_pose"],
|
| 765 |
+
trajectory["observation"]["state_gripper_pose"][..., None],
|
| 766 |
+
),
|
| 767 |
+
axis=-1,
|
| 768 |
+
)
|
| 769 |
+
return trajectory
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 773 |
+
# every input feature is batched, ie has leading batch dimension
|
| 774 |
+
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
|
| 775 |
+
return trajectory
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 779 |
+
# every input feature is batched, ie has leading batch dimension
|
| 780 |
+
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
|
| 781 |
+
|
| 782 |
+
# gripper action is in -1...1 --> clip to 0...1, flip
|
| 783 |
+
gripper_action = trajectory["action"][:, -1:]
|
| 784 |
+
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
|
| 785 |
+
|
| 786 |
+
trajectory["action"] = tf.concat(
|
| 787 |
+
(
|
| 788 |
+
trajectory["action"][:, :7],
|
| 789 |
+
gripper_action,
|
| 790 |
+
),
|
| 791 |
+
axis=-1,
|
| 792 |
+
)
|
| 793 |
+
return trajectory
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 797 |
+
trajectory["action"] = tf.concat(
|
| 798 |
+
(
|
| 799 |
+
trajectory["action"]["tcp_base"],
|
| 800 |
+
tf.cast(trajectory["action"]["gripper"][:, None], tf.float32),
|
| 801 |
+
),
|
| 802 |
+
axis=-1,
|
| 803 |
+
)
|
| 804 |
+
trajectory["observation"]["proprio"] = tf.concat(
|
| 805 |
+
(
|
| 806 |
+
trajectory["observation"]["tcp_base"],
|
| 807 |
+
trajectory["observation"]["gripper_width"][..., None],
|
| 808 |
+
),
|
| 809 |
+
axis=-1,
|
| 810 |
+
)
|
| 811 |
+
return trajectory
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 815 |
+
trajectory["action"] = tf.concat(
|
| 816 |
+
[
|
| 817 |
+
trajectory["action"][:, :6],
|
| 818 |
+
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
|
| 819 |
+
],
|
| 820 |
+
axis=1,
|
| 821 |
+
)
|
| 822 |
+
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
|
| 823 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
|
| 824 |
+
return trajectory
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 828 |
+
# gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close
|
| 829 |
+
gripper_action = trajectory["action"][:, -1:]
|
| 830 |
+
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
|
| 831 |
+
|
| 832 |
+
trajectory["action"] = tf.concat(
|
| 833 |
+
[
|
| 834 |
+
trajectory["action"][:, :6],
|
| 835 |
+
gripper_action,
|
| 836 |
+
],
|
| 837 |
+
axis=1,
|
| 838 |
+
)
|
| 839 |
+
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
| 840 |
+
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state
|
| 841 |
+
return trajectory
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
def aloha_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
| 845 |
+
# Don't need to do anything because dataset is already in the correct format
|
| 846 |
+
return trajectory
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
# === Registry ===
|
| 850 |
+
OXE_STANDARDIZATION_TRANSFORMS = {
|
| 851 |
+
"bridge_oxe": bridge_oxe_dataset_transform,
|
| 852 |
+
"bridge_orig": bridge_orig_dataset_transform,
|
| 853 |
+
"bridge_dataset": bridge_orig_dataset_transform,
|
| 854 |
+
"ppgm": ppgm_dataset_transform,
|
| 855 |
+
"ppgm_static": ppgm_dataset_transform,
|
| 856 |
+
"ppgm_wrist": ppgm_dataset_transform,
|
| 857 |
+
"fractal20220817_data": rt1_dataset_transform,
|
| 858 |
+
"kuka": kuka_dataset_transform,
|
| 859 |
+
"taco_play": taco_play_dataset_transform,
|
| 860 |
+
"jaco_play": jaco_play_dataset_transform,
|
| 861 |
+
"berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
|
| 862 |
+
"roboturk": roboturk_dataset_transform,
|
| 863 |
+
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
|
| 864 |
+
"viola": viola_dataset_transform,
|
| 865 |
+
"berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
|
| 866 |
+
"toto": toto_dataset_transform,
|
| 867 |
+
"language_table": language_table_dataset_transform,
|
| 868 |
+
"columbia_cairlab_pusht_real": pusht_dataset_transform,
|
| 869 |
+
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
|
| 870 |
+
"nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
|
| 871 |
+
"stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
|
| 872 |
+
"austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
|
| 873 |
+
"nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
|
| 874 |
+
"maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
|
| 875 |
+
"furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
|
| 876 |
+
"cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
|
| 877 |
+
"ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
|
| 878 |
+
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
|
| 879 |
+
"austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
|
| 880 |
+
"austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
|
| 881 |
+
"bc_z": bc_z_dataset_transform,
|
| 882 |
+
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
|
| 883 |
+
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
|
| 884 |
+
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform,
|
| 885 |
+
"utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
|
| 886 |
+
"robo_net": robo_net_dataset_transform,
|
| 887 |
+
"berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
|
| 888 |
+
"berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
|
| 889 |
+
"kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
|
| 890 |
+
"stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
|
| 891 |
+
"tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
|
| 892 |
+
"dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform,
|
| 893 |
+
"dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
|
| 894 |
+
"dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
|
| 895 |
+
"asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
|
| 896 |
+
"stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
|
| 897 |
+
"imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
|
| 898 |
+
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
|
| 899 |
+
"uiuc_d3field": uiuc_d3field_dataset_transform,
|
| 900 |
+
"utaustin_mutex": utaustin_mutex_dataset_transform,
|
| 901 |
+
"berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
|
| 902 |
+
"cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
|
| 903 |
+
"cmu_play_fusion": playfusion_dataset_transform,
|
| 904 |
+
"cmu_stretch": cmu_stretch_dataset_transform,
|
| 905 |
+
"berkeley_gnm_recon": gnm_dataset_transform,
|
| 906 |
+
"berkeley_gnm_cory_hall": gnm_dataset_transform,
|
| 907 |
+
"berkeley_gnm_sac_son": gnm_dataset_transform,
|
| 908 |
+
"droid": droid_baseact_transform,
|
| 909 |
+
"fmb_dataset": fmb_dataset_transform,
|
| 910 |
+
"dobbe": dobbe_dataset_transform,
|
| 911 |
+
"roboset": roboset_dataset_transform,
|
| 912 |
+
"rh20t": rh20t_dataset_transform,
|
| 913 |
+
### T-DROID datasets
|
| 914 |
+
"tdroid_carrot_in_bowl": tdroid_dataset_transform,
|
| 915 |
+
"tdroid_pour_corn_in_pot": tdroid_dataset_transform,
|
| 916 |
+
"tdroid_flip_pot_upright": tdroid_dataset_transform,
|
| 917 |
+
"tdroid_move_object_onto_plate": tdroid_dataset_transform,
|
| 918 |
+
"tdroid_knock_object_over": tdroid_dataset_transform,
|
| 919 |
+
"tdroid_cover_object_with_towel": tdroid_dataset_transform,
|
| 920 |
+
### DROID Finetuning datasets
|
| 921 |
+
"droid_wipe": droid_finetuning_transform,
|
| 922 |
+
### LIBERO datasets (modified versions)
|
| 923 |
+
"libero_spatial_no_noops": libero_dataset_transform,
|
| 924 |
+
"libero_object_no_noops": libero_dataset_transform,
|
| 925 |
+
"libero_goal_no_noops": libero_dataset_transform,
|
| 926 |
+
"libero_10_no_noops": libero_dataset_transform,
|
| 927 |
+
"libero_4_task_suites_no_noops": libero_dataset_transform,
|
| 928 |
+
### ALOHA fine-tuning datasets
|
| 929 |
+
"aloha1_fold_shorts_20_demos": aloha_dataset_transform,
|
| 930 |
+
"aloha1_fold_shirt_30_demos": aloha_dataset_transform,
|
| 931 |
+
"aloha1_scoop_X_into_bowl_45_demos": aloha_dataset_transform,
|
| 932 |
+
"aloha1_put_X_into_pot_300_demos": aloha_dataset_transform,
|
| 933 |
+
|
| 934 |
+
"aloha_dual_bottles_pick_hard_d435_20": aloha_dataset_transform,
|
| 935 |
+
|
| 936 |
+
# robotwin2
|
| 937 |
+
"grab_roller_aloha_agilex_50": aloha_dataset_transform,
|
| 938 |
+
"handover_mic_aloha_agilex_50": aloha_dataset_transform,
|
| 939 |
+
"lift_pot_aloha_agilex_50": aloha_dataset_transform,
|
| 940 |
+
"move_can_pot_aloha_agilex_50": aloha_dataset_transform,
|
| 941 |
+
"open_laptop_aloha_agilex_50": aloha_dataset_transform,
|
| 942 |
+
"pick_dual_bottles_aloha_agilex_50":aloha_dataset_transform,
|
| 943 |
+
"place_dual_shoes_aloha_agilex_50": aloha_dataset_transform,
|
| 944 |
+
"place_object_basket_aloha_agilex_50": aloha_dataset_transform,
|
| 945 |
+
"place_phone_stand_aloha_agilex_50": aloha_dataset_transform,
|
| 946 |
+
"put_bottles_dustbin_aloha_agilex_50": aloha_dataset_transform,
|
| 947 |
+
"put_object_cabinet_aloha_agilex_50": aloha_dataset_transform,
|
| 948 |
+
"stack_blocks_two_aloha_agilex_50": aloha_dataset_transform,
|
| 949 |
+
"stack_bowls_two_aloha_agilex_50": aloha_dataset_transform,
|
| 950 |
+
|
| 951 |
+
}
|
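The registry above is a plain name-to-callable map that the RLDS loading pipeline uses to standardize each dataset's actions and observations. The following is a minimal usage sketch, not part of the repository: the toy trajectory's key names and tensor shapes are assumptions chosen so that `libero_dataset_transform` accepts them, while real RLDS trajectories carry many more fields.

import tensorflow as tf

# Hypothetical toy trajectory: 10 timesteps, 6-DoF delta action + gripper, 8-dim state.
toy_traj = {
    "action": tf.random.uniform((10, 7), minval=-1.0, maxval=1.0),
    "observation": {"state": tf.random.uniform((10, 8))},
}
transform = OXE_STANDARDIZATION_TRANSFORMS["libero_4_task_suites_no_noops"]
std_traj = transform(toy_traj)
print(std_traj["action"].shape)                    # (10, 7) -- gripper remapped so +1 = open, 0 = close
print(std_traj["observation"]["EEF_state"].shape)  # (10, 6)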
prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py
ADDED
@@ -0,0 +1,178 @@
"""Episode transforms for DROID dataset."""

from typing import Any, Dict

import tensorflow as tf
import tensorflow_graphics.geometry.transformation as tfg


def rmat_to_euler(rot_mat):
    return tfg.euler.from_rotation_matrix(rot_mat)


def euler_to_rmat(euler):
    return tfg.rotation_matrix_3d.from_euler(euler)


def invert_rmat(rot_mat):
    return tfg.rotation_matrix_3d.inverse(rot_mat)


def rotmat_to_rot6d(mat):
    """
    Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).

    Args:
        mat: rotation matrix

    Returns: 6d vector (first two rows of rotation matrix)
    """
    r6 = mat[..., :2, :]
    r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
    r6_flat = tf.concat([r6_0, r6_1], axis=-1)
    return r6_flat


def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
    """
    Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.

    Args:
        velocity: 6d velocity action (3 x translation, 3 x rotation)
        wrist_in_robot_frame: 6d pose of the end-effector in robot base frame

    Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
    """
    R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
    R_frame_inv = invert_rmat(R_frame)

    # world to wrist: dT_pi = R^-1 dT_rbt
    vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0]

    # world to wrist: dR_pi = R^-1 dR_rbt R
    dR = euler_to_rmat(velocity[:, 3:6])
    dR = R_frame_inv @ (dR @ R_frame)
    dR_r6 = rotmat_to_rot6d(dR)
    return tf.concat([vel_t, dR_r6], axis=-1)


def rand_swap_exterior_images(img1, img2):
    """
    Randomly swaps the two exterior images (for training with single exterior input).
    """
    return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))


def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in *base* frame of the robot.
    """
    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
    dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]

    trajectory["action"] = tf.concat(
        (
            dt,
            dR,
            1 - trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
        rand_swap_exterior_images(
            trajectory["observation"]["exterior_image_1_left"],
            trajectory["observation"]["exterior_image_2_left"],
        )
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in *wrist* frame of the robot.
    """
    wrist_act = velocity_act_to_wrist_frame(
        trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
    )
    trajectory["action"] = tf.concat(
        (
            wrist_act,
            trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
        rand_swap_exterior_images(
            trajectory["observation"]["exterior_image_1_left"],
            trajectory["observation"]["exterior_image_2_left"],
        )
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in *base* frame of the robot.
    """
    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
    dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
    trajectory["action"] = tf.concat(
        (
            dt,
            dR,
            1 - trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def zero_action_filter(traj: Dict) -> bool:
    """
    Filters transitions whose actions are all-0 (only relative actions, no gripper action).
    Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
    """
    DROID_Q01 = tf.convert_to_tensor(
        [
            -0.7776297926902771,
            -0.5803514122962952,
            -0.5795090794563293,
            -0.6464047729969025,
            -0.7041108310222626,
            -0.8895104378461838,
        ]
    )
    DROID_Q99 = tf.convert_to_tensor(
        [
            0.7597932070493698,
            0.5726242214441299,
            0.7351000607013702,
            0.6705610305070877,
            0.6464948207139969,
            0.8897542208433151,
        ]
    )
    DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1

    return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5)
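As a quick shape sanity check for the frame-conversion helpers above, the following is a hedged sketch (not part of the repository); the tensor values are random placeholders and only the shapes matter.

import tensorflow as tf

# Hypothetical check: a (T, 6) base-frame velocity plus a (T, 6) end-effector pose
# should map to a (T, 9) wrist-frame action (3 translation dims + 6 R6 rotation dims).
T = 5
velocity = tf.random.uniform((T, 6), minval=-0.1, maxval=0.1)     # 3 x translation + 3 x rotation
wrist_pose = tf.random.uniform((T, 6), minval=-0.5, maxval=0.5)   # xyz + euler of EEF in base frame
wrist_act = velocity_act_to_wrist_frame(velocity, wrist_pose)
print(wrist_act.shape)  # (5, 9)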
prismatic/vla/datasets/rlds/utils/__init__.py
ADDED
File without changes
prismatic/vla/datasets/rlds/utils/goal_relabeling.py
ADDED
@@ -0,0 +1,32 @@
"""
goal_relabeling.py

Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required.
Each function should add entries to the "task" dict.
"""

from typing import Dict

import tensorflow as tf

from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge


def uniform(traj: Dict) -> Dict:
    """Relabels with a true uniform distribution over future states."""
    traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0]

    # Select a random future index for each transition i in the range [i + 1, traj_len)
    rand = tf.random.uniform([traj_len])
    low = tf.cast(tf.range(traj_len) + 1, tf.float32)
    high = tf.cast(traj_len, tf.float32)
    goal_idxs = tf.cast(rand * (high - low) + low, tf.int32)

    # Sometimes there are floating-point errors that cause an out-of-bounds
    goal_idxs = tf.minimum(goal_idxs, traj_len - 1)

    # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly)
    goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"])
    traj["task"] = tree_merge(traj["task"], goal)

    return traj
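To make the `[i + 1, traj_len)` sampling range concrete, here is a hedged sketch of just the index-sampling step used by `uniform` on a toy trajectory length (not part of the repository; the printed values are one possible draw).

import tensorflow as tf

# Hypothetical standalone reproduction of the goal-index sampling in `uniform`.
traj_len = 5
rand = tf.random.uniform([traj_len])
low = tf.cast(tf.range(traj_len) + 1, tf.float32)   # earliest allowed goal: i + 1
high = tf.cast(traj_len, tf.float32)
goal_idxs = tf.minimum(tf.cast(rand * (high - low) + low, tf.int32), traj_len - 1)
print(goal_idxs.numpy())  # e.g. [3 2 4 4 4] -- always >= i + 1, except the final step, which clips to itself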
prismatic/vla/materialize.py
ADDED
@@ -0,0 +1,56 @@
"""
materialize.py

Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and
exports individual functions for clear control flow.
"""

from pathlib import Path
from typing import Tuple, Type

from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase

from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import ImageTransform
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset


def get_vla_dataset_and_collator(
    data_root_dir: Path,
    data_mix: str,
    image_transform: ImageTransform,
    tokenizer: PreTrainedTokenizerBase,
    prompt_builder_fn: Type[PromptBuilder],
    default_image_resolution: Tuple[int, int, int],
    padding_side: str = "right",
    predict_stop_token: bool = True,
    shuffle_buffer_size: int = 100_000,
    train: bool = True,
    episodic: bool = False,
    image_aug: bool = False,
) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]:
    """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions."""
    action_tokenizer = ActionTokenizer(tokenizer)
    batch_transform = RLDSBatchTransform(
        action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token
    )
    collator = PaddedCollatorForActionPrediction(
        tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side
    )

    # Build RLDS Iterable Dataset
    cls = RLDSDataset if not episodic else EpisodicRLDSDataset
    dataset = cls(
        data_root_dir,
        data_mix,
        batch_transform,
        resize_resolution=default_image_resolution[1:],
        shuffle_buffer_size=shuffle_buffer_size,
        train=train,
        image_aug=image_aug,
    )

    return dataset, action_tokenizer, collator
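For orientation, here is a hedged wiring sketch of how a training script might call this factory. It is not part of the repository: `vlm` stands for an already-loaded Prismatic-style model, and the accessor names, data root, and mixture name shown here are assumptions for illustration only.

from pathlib import Path

# Hypothetical wiring: `vlm` is assumed to be a materialized VLA model loaded elsewhere,
# exposing an image transform, a tokenizer, and a prompt builder for its backbones.
vla_dataset, action_tokenizer, collator = get_vla_dataset_and_collator(
    data_root_dir=Path("/data/oxe"),                             # assumed local RLDS root
    data_mix="bridge",                                           # assumed registered mixture name
    image_transform=vlm.vision_backbone.get_image_transform(),   # assumed accessor
    tokenizer=vlm.llm_backbone.get_tokenizer(),                  # assumed accessor
    prompt_builder_fn=vlm.llm_backbone.prompt_builder_fn,        # assumed accessor
    default_image_resolution=(3, 224, 224),
    shuffle_buffer_size=100_000,
    image_aug=True,
)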
run_scripts/ac/ac.sh
ADDED
@@ -0,0 +1,87 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_query
#========== !NOTE! ==========#
RUN_MODE=simvla_ffn_AC
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=2
robot_platform=libero
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=5000
max_steps=40000
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
#     note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
#     note_parts+=("3rd_person_img")
# else
#     note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
#     note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
#     note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
run_scripts/ffn/3ffn2.sh
ADDED
@@ -0,0 +1,87 @@
| 1 |
+
#========== settings ==========#
|
| 2 |
+
PROJECT_PATH=fastvla_multi_scale_query
|
| 3 |
+
#========== !NOTE! ==========#
|
| 4 |
+
RUN_MODE=simvla3_ffn
|
| 5 |
+
use_predict_future_prop=False
|
| 6 |
+
batch_size=16
|
| 7 |
+
use_action_ts_head=True
|
| 8 |
+
use_one_embed=True
|
| 9 |
+
use_multi_scaling=False
|
| 10 |
+
mlp_type=ffn
|
| 11 |
+
decoder_num_blocks=2
|
| 12 |
+
robot_platform=libero
|
| 13 |
+
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
|
| 14 |
+
#========== !NOTE! ==========#
|
| 15 |
+
use_l1_regression=True
|
| 16 |
+
num_images_in_input=1
|
| 17 |
+
wandb_entity=chenghaha
|
| 18 |
+
wandb_project=fastvla
|
| 19 |
+
wandb_log_freq=1
|
| 20 |
+
use_proprio=False
|
| 21 |
+
use_diffusion=False
|
| 22 |
+
use_film=False
|
| 23 |
+
num_steps_before_decay=20000
|
| 24 |
+
save_freq=10000
|
| 25 |
+
max_steps=40000
|
| 26 |
+
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
|
| 27 |
+
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
|
| 28 |
+
dataset_name=libero_4_task_suites_no_noops
|
| 29 |
+
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
|
| 30 |
+
#========== get run_id ==========#
|
| 31 |
+
note_parts=("${MODE}")
|
| 32 |
+
|
| 33 |
+
# if [ "$use_l1_regression" = "True" ]; then
|
| 34 |
+
# note_parts+=("L1_regression")
|
| 35 |
+
# fi
|
| 36 |
+
|
| 37 |
+
# if [ "$num_images_in_input" == 1 ]; then
|
| 38 |
+
# note_parts+=("3rd_person_img")
|
| 39 |
+
# else
|
| 40 |
+
# note_parts+=("3rd_person_img_and_wrist")
|
| 41 |
+
# fi
|
| 42 |
+
|
| 43 |
+
# if [ "$use_l1_regression" = "True" ]; then
|
| 44 |
+
# note_parts+=("proprio_state")
|
| 45 |
+
# fi
|
| 46 |
+
|
| 47 |
+
# if [ "$use_film" = "True" ]; then
|
| 48 |
+
# note_parts+=("Film")
|
| 49 |
+
# fi
|
| 50 |
+
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
|
| 51 |
+
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
|
| 52 |
+
|
| 53 |
+
#========== enter environment ==========#
|
| 54 |
+
conda activate openvla-oft
|
| 55 |
+
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
|
| 56 |
+
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
|
| 57 |
+
|
| 58 |
+
#========== run ==========#
|
| 59 |
+
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
|
| 60 |
+
--vla_path "$vla_path" \
|
| 61 |
+
--data_root_dir "$data_root_dir" \
|
| 62 |
+
--dataset_name "$dataset_name" \
|
| 63 |
+
--run_root_dir "$run_root_dir" \
|
| 64 |
+
--use_l1_regression "$use_l1_regression" \
|
| 65 |
+
--use_diffusion "$use_diffusion" \
|
| 66 |
+
--use_film "$use_film" \
|
| 67 |
+
--num_images_in_input "$num_images_in_input" \
|
| 68 |
+
--use_proprio "$use_proprio" \
|
| 69 |
+
--batch_size "$batch_size" \
|
| 70 |
+
--learning_rate 5e-4 \
|
| 71 |
+
--num_steps_before_decay "$num_steps_before_decay" \
|
| 72 |
+
--max_steps "$max_steps" \
|
| 73 |
+
--save_freq "$save_freq" \
|
| 74 |
+
--save_latest_checkpoint_only False \
|
| 75 |
+
--image_aug True \
|
| 76 |
+
--lora_rank 32 \
|
| 77 |
+
--wandb_entity "$wandb_entity" \
|
| 78 |
+
--wandb_project "$wandb_project" \
|
| 79 |
+
--wandb_log_freq "$wandb_log_freq" \
|
| 80 |
+
--run_id_note "$run_id_note_value" \
|
| 81 |
+
--use_predict_future_prop "$use_predict_future_prop" \
|
| 82 |
+
--use_action_ts_head "$use_action_ts_head" \
|
| 83 |
+
--use_one_embed "$use_one_embed" \
|
| 84 |
+
--use_multi_scaling "$use_multi_scaling" \
|
| 85 |
+
--mlp_type "$mlp_type" \
|
| 86 |
+
--decoder_num_blocks "$decoder_num_blocks" \
|
| 87 |
+
--robot_platform "$robot_platform"
|
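One practical caveat with the `conda activate openvla-oft` call used throughout these scripts: in a non-interactive batch shell, `conda activate` only works after conda's shell hook has been sourced. A minimal sketch of the usual workaround (assuming a standard Miniconda/Anaconda layout; the base path may differ on the cluster):

```bash
# Make `conda activate` available in a non-interactive job script.
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate openvla-oft
```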
run_scripts/ffn/3postffn2.sh
ADDED
@@ -0,0 +1,87 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_query
#========== !NOTE! ==========#
RUN_MODE=simvla3_postffn
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=postffn
decoder_num_blocks=2
robot_platform=libero
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=10000
max_steps=40000
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
run_scripts/ffn/3postffn6.sh
ADDED
@@ -0,0 +1,87 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_query
#========== !NOTE! ==========#
RUN_MODE=simvla3_postffn
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=postffn
decoder_num_blocks=6
robot_platform=libero
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=10000
max_steps=40000
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
run_scripts/ffn/debug_5ffn_withactionprojector.sh
ADDED
@@ -0,0 +1,87 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_query
#========== !NOTE! ==========#
RUN_MODE=simvla4_ffn_withprojector
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=2
robot_platform=libero
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=10000
max_steps=40000
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline python -m debugpy --listen 1234 --wait-for-client '/opt/conda/envs/spatialvla/bin/torchrun' --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-5 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
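The debug variant above wraps torchrun in debugpy, so the launcher listens on port 1234 and blocks until a debugger client attaches. To confirm from another shell that the listener is up before attaching (a quick check, assuming `ss` from iproute2 is available on the node):

```bash
# Show listening TCP sockets and filter for the debugpy port used above.
ss -ltn | grep ':1234'
```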
run_scripts/ffn/ffn4.sh
ADDED
@@ -0,0 +1,87 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_query
#========== !NOTE! ==========#
RUN_MODE=simvla_ffn
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=4
robot_platform=libero
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=5000
max_steps=40000
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
run_scripts/ffn/ffn8.sh
ADDED
@@ -0,0 +1,87 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_query
#========== !NOTE! ==========#
RUN_MODE=simvla_ffn
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=8
robot_platform=libero
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=5000
max_steps=40000
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
run_scripts/ffn/test.sh
ADDED
@@ -0,0 +1,87 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_query
#========== !NOTE! ==========#
RUN_MODE=test
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=2
robot_platform=libero
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=5000
max_steps=40000
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
run_scripts/ffn_long_chunks/run.sh
ADDED
@@ -0,0 +1,4 @@
bash run_scripts/ffn_long_chunks/li4.sh
bash run_scripts/ffn_long_chunks/li16.sh
bash run_scripts/ffn_long_chunks/li24.sh
bash run_scripts/ffn_long_chunks/li32.sh
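run_scripts/ffn_long_chunks/run.sh launches the four chunk-length runs back to back; as written, a failure in one run does not stop the later ones. A hedged variant (same scripts, just wrapped in a loop) that aborts on the first failure:

```bash
#!/usr/bin/env bash
set -e  # stop at the first failing run
for n in 4 16 24 32; do
    bash "run_scripts/ffn_long_chunks/li${n}.sh"
done
```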
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_base.sh
ADDED
@@ -0,0 +1,88 @@
#========== settings ==========#
PROJECT_PATH=simvla_twin2
ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
#========== !NOTE! ==========#
RUN_MODE=base
use_predict_future_prop=False
batch_size=4
use_action_ts_head=False
use_one_embed=False
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=2
robot_platform=aloha
MODE=${RUN_MODE}_robot_platform_${robot_platform}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=3
wandb_entity=chenghaha
wandb_project=robotwin
wandb_log_freq=1
use_proprio=True
use_diffusion=False
use_film=True
num_steps_before_decay=1000
save_freq=2000
max_steps=2000
vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
dataset_name=grab_roller_aloha_agilex_50
run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

if [ "$use_l1_regression" = "True" ]; then
    note_parts+=("L1_regression")
fi

if [ "$num_images_in_input" == 1 ]; then
    note_parts+=("3rd_person_img")
else
    note_parts+=("3rd_person_img_and_wrist")
fi

if [ "$use_l1_regression" = "True" ]; then
    note_parts+=("proprio_state")
fi

if [ "$use_film" = "True" ]; then
    note_parts+=("Film")
fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
source activate openvla-oft
cd $ROOT_PATH/vla_projects/$PROJECT_PATH
export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-5 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform"
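In the camera check above, `[ "$num_images_in_input" == 1 ]` is a string comparison (and `==` inside `[ ]` is a bashism); with `num_images_in_input=3` the test is false, so the wrist-camera note is appended, which matches the intent. A small sketch of the numeric form of the same test, which avoids relying on string equality:

```bash
# -eq performs an integer comparison instead of string equality.
if [ "$num_images_in_input" -eq 1 ]; then
    note_parts+=("3rd_person_img")
else
    note_parts+=("3rd_person_img_and_wrist")
fi
```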
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_50_l2.sh
ADDED
@@ -0,0 +1,102 @@
#========== settings ==========#
PROJECT_PATH=simvla_twin2
ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
#========== !NOTE! ==========#
RUN_MODE=simvla_50
use_predict_future_prop=False
batch_size=4
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=4
robot_platform=50_al
proj_type=gelu_linear
ffn_type=gelu
expand_inner_ratio=1
linear_drop_ratio=0.1
multi_queries_num=50
multi_query_norm_type=layernorm
action_norm=l2
MODE=${RUN_MODE}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=3
wandb_entity=chenghaha
wandb_project=robotwin
wandb_log_freq=1
use_proprio=True
use_diffusion=False
use_film=True
num_steps_before_decay=2000
save_freq=3000
max_steps=3000
vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
dataset_name=grab_roller_aloha_agilex_50
run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd $ROOT_PATH/vla_projects/$PROJECT_PATH
export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-5 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform" \
  --proj_type "$proj_type" \
  --ffn_type "$ffn_type" \
  --expand_inner_ratio "$expand_inner_ratio" \
  --linear_drop_ratio "$linear_drop_ratio" \
  --multi_query_norm_type "$multi_query_norm_type" \
  --multi_queries_num "$multi_queries_num" \
  --action_norm "$action_norm"
run_scripts/ffn_q2a/bridge/exffn_relu_connector_linear_relu.sh
ADDED
@@ -0,0 +1,95 @@
#========== settings ==========#
PROJECT_PATH=fastvla_multi_scale_q2a
ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
#========== !NOTE! ==========#
RUN_MODE=simvla_q2a
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=6
robot_platform=bridge
without_head_drop_out=True
proj_type=linear_relu
ffn_type=relu
expand_actiondim_ratio=2.0
MODE=${RUN_MODE}_exffn${expand_actiondim_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=10000
max_steps=50000
vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
data_root_dir=$ROOT_PATH/datasets/openx/data/origin
dataset_name=bridge
run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd $ROOT_PATH/vla_projects/$PROJECT_PATH
export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform" \
  --proj_type "$proj_type" \
  --ffn_type "$ffn_type" \
  --expand_actiondim_ratio "$expand_actiondim_ratio"
run_scripts/ffn_q2a/bridge/run_bridge.sh
ADDED
@@ -0,0 +1,2 @@
bash run_scripts/ffn_q2a/bridge/exffn_gelu_bridge_drop0_5.sh
bash run_scripts/ffn_q2a/bridge/exffn_gelu_bridge.sh
run_scripts/ffn_q2a/franka/exffn_gelu_franka.sh
ADDED
@@ -0,0 +1,95 @@
#========== settings ==========#
PROJECT_PATH=SimVLA_Condition
ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
#========== !NOTE! ==========#
RUN_MODE=simvla_q2a
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=4
robot_platform=rt1
without_head_drop_out=True
proj_type=gelu_linear
ffn_type=gelu
expand_actiondim_ratio=1.0
MODE=${RUN_MODE}_exffn${expand_actiondim_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=30000
save_freq=10000
max_steps=60000
vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
data_root_dir=$ROOT_PATH/datasets/openx/data/origin
dataset_name=rt1
run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd $ROOT_PATH/vla_projects/$PROJECT_PATH
export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform" \
  --proj_type "$proj_type" \
  --ffn_type "$ffn_type" \
  --expand_actiondim_ratio "$expand_actiondim_ratio"
run_scripts/ffn_q2a/libero_moe/debug_moe_lit.sh
ADDED
@@ -0,0 +1,101 @@
#========== settings ==========#
PROJECT_PATH=SimVLA
ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/jiajiuyang-240108580167/chengdongzhou
#========== !NOTE! ==========#
RUN_MODE=simvla_q2a_lit
use_predict_future_prop=False
batch_size=2
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=moe
decoder_num_blocks=2
robot_platform=16_li
without_head_drop_out=True
proj_type=gelu_linear
ffn_type=gelu
num_experts=8
expand_inner_ratio=2
top_k=2
expand_actiondim_ratio=0.5
MODE=${RUN_MODE}_ex${expand_actiondim_ratio}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}_num_experts${num_experts}_top_k{$top_k}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=20000
save_freq=10000
max_steps=50000
vla_path=$ROOT_PATH/ai_models/openvla
data_root_dir=$ROOT_PATH/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd $ROOT_PATH/vla_projects/$PROJECT_PATH
export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline python -m debugpy --listen 1234 --wait-for-client '/opt/conda/envs/openvla-oft/bin/torchrun' --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
  --use_multi_scaling "$use_multi_scaling" \
  --mlp_type "$mlp_type" \
  --decoder_num_blocks "$decoder_num_blocks" \
  --robot_platform "$robot_platform" \
  --proj_type "$proj_type" \
  --ffn_type "$ffn_type" \
  --expand_inner_ratio "$expand_inner_ratio" \
  --expand_actiondim_ratio "$expand_actiondim_ratio" \
  --num_experts "$num_experts" \
  --top_k "$top_k"
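A small note on the `MODE=` line in debug_moe_lit.sh: `top_k{$top_k}` expands the variable but keeps the literal braces, whereas `top_k${top_k}` would drop them, so the run-ID note ends up containing `top_k{2}`. A quick sketch of the difference:

```bash
top_k=2
echo "top_k{$top_k}"   # -> top_k{2}  (braces are literal characters)
echo "top_k${top_k}"   # -> top_k2    (standard parameter expansion)
```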
run_scripts/ffn_q2a/simhead/simhead_contrastive.sh
ADDED
@@ -0,0 +1,100 @@
#========== settings ==========#
PROJECT_PATH=SimVLA_Condition
ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
#========== !NOTE! ==========#
RUN_MODE=simvla_q2a
use_predict_future_prop=False
batch_size=16
use_action_ts_head=True
use_one_embed=True
use_multi_scaling=False
mlp_type=ffn
decoder_num_blocks=2
robot_platform=16_li
without_head_drop_out=True
without_action_projector=True
ffn_type=gelu
use_l2norm=False
expand_inner_ratio=2.0
use_contrastive_loss=True
MODE=${RUN_MODE}_usecons${use_contrastive_loss}_newexinner_${expand_inner_ratio}_without_ap_ffn_type_${ffn_type}_use_l2norm${use_l2norm}_mlp_${mlp_type}_num_${decoder_num_blocks}
#========== !NOTE! ==========#
use_l1_regression=True
num_images_in_input=1
wandb_entity=chenghaha
wandb_project=fastvla
wandb_log_freq=1
use_proprio=False
use_diffusion=False
use_film=False
num_steps_before_decay=30000
save_freq=10000
max_steps=50000
vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
data_root_dir=$ROOT_PATH/datasets/openvla/modified_libero_rlds
dataset_name=libero_4_task_suites_no_noops
run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
#========== get run_id ==========#
note_parts=("${MODE}")

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("L1_regression")
# fi

# if [ "$num_images_in_input" == 1 ]; then
# note_parts+=("3rd_person_img")
# else
# note_parts+=("3rd_person_img_and_wrist")
# fi

# if [ "$use_l1_regression" = "True" ]; then
# note_parts+=("proprio_state")
# fi

# if [ "$use_film" = "True" ]; then
# note_parts+=("Film")
# fi
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")

#========== enter environment ==========#
conda activate openvla-oft
cd $ROOT_PATH/vla_projects/$PROJECT_PATH
export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH

#========== run ==========#
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
  --vla_path "$vla_path" \
  --data_root_dir "$data_root_dir" \
  --dataset_name "$dataset_name" \
  --run_root_dir "$run_root_dir" \
  --use_l1_regression "$use_l1_regression" \
  --use_diffusion "$use_diffusion" \
  --use_film "$use_film" \
  --num_images_in_input "$num_images_in_input" \
  --use_proprio "$use_proprio" \
  --batch_size "$batch_size" \
  --learning_rate 5e-4 \
  --num_steps_before_decay "$num_steps_before_decay" \
  --max_steps "$max_steps" \
  --save_freq "$save_freq" \
  --save_latest_checkpoint_only False \
  --image_aug True \
  --lora_rank 32 \
  --wandb_entity "$wandb_entity" \
  --wandb_project "$wandb_project" \
  --wandb_log_freq "$wandb_log_freq" \
  --run_id_note "$run_id_note_value" \
  --use_predict_future_prop "$use_predict_future_prop" \
  --use_action_ts_head "$use_action_ts_head" \
  --use_one_embed "$use_one_embed" \
--use_multi_scaling "$use_multi_scaling" \
|
| 92 |
+
--mlp_type "$mlp_type" \
|
| 93 |
+
--decoder_num_blocks "$decoder_num_blocks" \
|
| 94 |
+
--robot_platform "$robot_platform" \
|
| 95 |
+
--proj_type "$proj_type" \
|
| 96 |
+
--ffn_type "$ffn_type" \
|
| 97 |
+
--use_l2norm "$use_l2norm" \
|
| 98 |
+
--expand_inner_ratio "$expand_inner_ratio" \
|
| 99 |
+
--without_action_projector "$without_action_projector" \
|
| 100 |
+
--use_contrastive_loss "$use_contrastive_loss"
|
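`simhead_contrastive.sh` enables `use_contrastive_loss=True` alongside the L1-regression objective. The contrastive head's implementation is not shown here, so the following is only a generic InfoNCE-style sketch of what such a loss typically looks like over L2-normalized paired embeddings; it is an assumption for illustration, not the repository's code.

# Generic InfoNCE-style contrastive loss over L2-normalized embeddings.
# Illustrative assumption only; NOT the repository's `use_contrastive_loss` implementation.
import torch
import torch.nn.functional as F

def info_nce(query: torch.Tensor, key: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """query/key: (B, D) paired embeddings; positives lie on the diagonal."""
    q = F.normalize(query, dim=-1)
    k = F.normalize(key, dim=-1)
    logits = q @ k.t() / temperature                   # (B, B) cosine-similarity logits
    targets = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, targets)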
run_scripts/pp/pp.sh
ADDED
|
@@ -0,0 +1,87 @@
| 1 |
+
#========== settings ==========#
|
| 2 |
+
PROJECT_PATH=fastvla_multi_scale_query
|
| 3 |
+
#========== !NOTE! ==========#
|
| 4 |
+
RUN_MODE=simvla_PP
|
| 5 |
+
use_predict_future_prop=True
|
| 6 |
+
batch_size=16
|
| 7 |
+
use_action_ts_head=True
|
| 8 |
+
use_one_embed=True
|
| 9 |
+
use_multi_scaling=False
|
| 10 |
+
mlp_type=ffn
|
| 11 |
+
decoder_num_blocks=4
|
| 12 |
+
robot_platform=libero
|
| 13 |
+
MODE=${RUN_MODE}_use_pp_${use_predict_future_prop}_use_ts_${use_action_ts_head}_use_one_${use_one_embed}_use_ms_${use_multi_scaling}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
|
| 14 |
+
#========== !NOTE! ==========#
|
| 15 |
+
use_l1_regression=True
|
| 16 |
+
num_images_in_input=1
|
| 17 |
+
wandb_entity=chenghaha
|
| 18 |
+
wandb_project=fastvla
|
| 19 |
+
wandb_log_freq=1
|
| 20 |
+
use_proprio=True
|
| 21 |
+
use_diffusion=False
|
| 22 |
+
use_film=False
|
| 23 |
+
num_steps_before_decay=20000
|
| 24 |
+
save_freq=5000
|
| 25 |
+
max_steps=50000
|
| 26 |
+
vla_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b
|
| 27 |
+
data_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/openvla/modified_libero_rlds
|
| 28 |
+
dataset_name=libero_4_task_suites_no_noops
|
| 29 |
+
run_root_dir=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH/results/$RUN_MODE
|
| 30 |
+
#========== get run_id ==========#
|
| 31 |
+
note_parts=("${MODE}")
|
| 32 |
+
|
| 33 |
+
# if [ "$use_l1_regression" = "True" ]; then
|
| 34 |
+
# note_parts+=("L1_regression")
|
| 35 |
+
# fi
|
| 36 |
+
|
| 37 |
+
# if [ "$num_images_in_input" == 1 ]; then
|
| 38 |
+
# note_parts+=("3rd_person_img")
|
| 39 |
+
# else
|
| 40 |
+
# note_parts+=("3rd_person_img_and_wrist")
|
| 41 |
+
# fi
|
| 42 |
+
|
| 43 |
+
# if [ "$use_l1_regression" = "True" ]; then
|
| 44 |
+
# note_parts+=("proprio_state")
|
| 45 |
+
# fi
|
| 46 |
+
|
| 47 |
+
# if [ "$use_film" = "True" ]; then
|
| 48 |
+
# note_parts+=("Film")
|
| 49 |
+
# fi
|
| 50 |
+
note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
|
| 51 |
+
run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
|
| 52 |
+
|
| 53 |
+
#========== enter environment ==========#
|
| 54 |
+
# conda activate openvla-oft
|
| 55 |
+
cd /inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
|
| 56 |
+
export PYTHONPATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/$PROJECT_PATH
|
| 57 |
+
|
| 58 |
+
#========== run ==========#
|
| 59 |
+
WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
|
| 60 |
+
--vla_path "$vla_path" \
|
| 61 |
+
--data_root_dir "$data_root_dir" \
|
| 62 |
+
--dataset_name "$dataset_name" \
|
| 63 |
+
--run_root_dir "$run_root_dir" \
|
| 64 |
+
--use_l1_regression "$use_l1_regression" \
|
| 65 |
+
--use_diffusion "$use_diffusion" \
|
| 66 |
+
--use_film "$use_film" \
|
| 67 |
+
--num_images_in_input "$num_images_in_input" \
|
| 68 |
+
--use_proprio "$use_proprio" \
|
| 69 |
+
--batch_size "$batch_size" \
|
| 70 |
+
--learning_rate 5e-4 \
|
| 71 |
+
--num_steps_before_decay "$num_steps_before_decay" \
|
| 72 |
+
--max_steps "$max_steps" \
|
| 73 |
+
--save_freq "$save_freq" \
|
| 74 |
+
--save_latest_checkpoint_only False \
|
| 75 |
+
--image_aug True \
|
| 76 |
+
--lora_rank 32 \
|
| 77 |
+
--wandb_entity "$wandb_entity" \
|
| 78 |
+
--wandb_project "$wandb_project" \
|
| 79 |
+
--wandb_log_freq "$wandb_log_freq" \
|
| 80 |
+
--run_id_note "$run_id_note_value" \
|
| 81 |
+
--use_predict_future_prop "$use_predict_future_prop" \
|
| 82 |
+
--use_action_ts_head "$use_action_ts_head" \
|
| 83 |
+
--use_one_embed "$use_one_embed" \
|
| 84 |
+
--use_multi_scaling "$use_multi_scaling" \
|
| 85 |
+
--mlp_type "$mlp_type" \
|
| 86 |
+
--decoder_num_blocks "$decoder_num_blocks" \
|
| 87 |
+
--robot_platform "$robot_platform"
|
run_scripts/run.sh
ADDED
|
@@ -0,0 +1,35 @@
| 1 |
+
# bash run_scripts/multiscaling/exp_multiscaling.sh
|
| 2 |
+
|
| 3 |
+
# bash run_scripts/ffn_or_gating/gating.sh
|
| 4 |
+
|
| 5 |
+
# bash run_scripts/ffn/ffn0.sh
|
| 6 |
+
# bash run_scripts/ffn/ffn2.sh
|
| 7 |
+
# bash run_scripts/ffn/ffn4.sh
|
| 8 |
+
# bash run_scripts/ffn/ffn6.sh
|
| 9 |
+
# bash run_scripts/ffn/ffn8.sh
|
| 10 |
+
# bash run_scripts/multiscaling/2latentmsahead.sh
|
| 11 |
+
# bash run_scripts/multiscaling/2msahead.sh
|
| 12 |
+
# bash run_scripts/ffn/2ffn6.sh
|
| 13 |
+
# bash run_scripts/all_input/2all_inputs.sh
|
| 14 |
+
|
| 15 |
+
# bash run_scripts/ffn_or_gating/gating.sh
|
| 16 |
+
|
| 17 |
+
# bash run_scripts/ffn/ffn0.sh
|
| 18 |
+
# bash run_scripts/ffn/ffn2.sh
|
| 19 |
+
# bash run_scripts/ffn/ffn4.sh
|
| 20 |
+
# bash run_scripts/ffn/ffn6.sh
|
| 21 |
+
# bash run_scripts/ffn/ffn8.sh
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# bash run_scripts/ffn/3ffn2.sh
|
| 25 |
+
# bash run_scripts/ffn/3ffn6.sh
|
| 26 |
+
# bash run_scripts/ffn/3postffn2.sh
|
| 27 |
+
# bash run_scripts/ffn/3postffn6.sh
|
| 28 |
+
|
| 29 |
+
# bash run_scripts/ffn/4ffn_withactionprojector.sh
|
| 30 |
+
# bash run_scripts/ffn/4ffn6_withactionprojector.sh
|
| 31 |
+
|
| 32 |
+
# bash run_scripts/ffn/4ffn_withactionprojector.sh
|
| 33 |
+
|
| 34 |
+
bash run_scripts/ffn/5ffn_withactionprojector.sh
|
| 35 |
+
bash run_scripts/ffn/5ffn6_withactionprojector.sh
|
scripts/extern/verify_prismatic.py
ADDED
|
@@ -0,0 +1,134 @@
| 1 |
+
"""
|
| 2 |
+
verify_prismatic.py
|
| 3 |
+
|
| 4 |
+
Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate().
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import requests
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from transformers import AutoModelForVision2Seq, AutoProcessor
|
| 13 |
+
|
| 14 |
+
# === Verification Arguments ===
|
| 15 |
+
MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b"
|
| 16 |
+
DEFAULT_IMAGE_URL = (
|
| 17 |
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
if "-prism-" in MODEL_PATH:
|
| 21 |
+
SAMPLE_PROMPTS_FOR_GENERATION = [
|
| 22 |
+
"In: What is sitting in the coffee?\nOut:",
|
| 23 |
+
"In: What's the name of the food on the plate?\nOut:",
|
| 24 |
+
"In: caption.\nOut:",
|
| 25 |
+
"In: how many beinets..?\nOut:",
|
| 26 |
+
"In: Can you give me a lyrical description of the scene\nOut:",
|
| 27 |
+
]
|
| 28 |
+
else:
|
| 29 |
+
SYSTEM_PROMPT = (
|
| 30 |
+
"A chat between a curious user and an artificial intelligence assistant. "
|
| 31 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
| 32 |
+
)
|
| 33 |
+
SAMPLE_PROMPTS_FOR_GENERATION = [
|
| 34 |
+
f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:",
|
| 35 |
+
f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:",
|
| 36 |
+
f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:",
|
| 37 |
+
f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:",
|
| 38 |
+
f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@torch.inference_mode()
|
| 43 |
+
def verify_prismatic() -> None:
|
| 44 |
+
print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`")
|
| 45 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 46 |
+
|
| 47 |
+
# Load Processor & VLM
|
| 48 |
+
print("[*] Instantiating Processor and Pretrained VLM")
|
| 49 |
+
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
| 50 |
+
|
| 51 |
+
# === AUTOCAST MODE ===
|
| 52 |
+
# print("[*] Loading in BF16 Autocast Mode")
|
| 53 |
+
# vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to(
|
| 54 |
+
# device, dtype=torch.bfloat16
|
| 55 |
+
# )
|
| 56 |
+
|
| 57 |
+
# === NATIVE BFLOAT16 MODE ===
|
| 58 |
+
# print("[*] Loading in BF16")
|
| 59 |
+
# vlm = AutoModelForVision2Seq.from_pretrained(
|
| 60 |
+
# MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
|
| 61 |
+
# ).to(device)
|
| 62 |
+
|
| 63 |
+
# === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] ===
|
| 64 |
+
print("[*] Loading in BF16 with Flash-Attention Enabled")
|
| 65 |
+
vlm = AutoModelForVision2Seq.from_pretrained(
|
| 66 |
+
MODEL_PATH,
|
| 67 |
+
attn_implementation="flash_attention_2",
|
| 68 |
+
torch_dtype=torch.bfloat16,
|
| 69 |
+
low_cpu_mem_usage=True,
|
| 70 |
+
trust_remote_code=True,
|
| 71 |
+
).to(device)
|
| 72 |
+
|
| 73 |
+
# === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] ===
|
| 74 |
+
# print("[*] Loading in 8-Bit Quantization Mode")
|
| 75 |
+
# vlm = AutoModelForVision2Seq.from_pretrained(
|
| 76 |
+
# MODEL_PATH,
|
| 77 |
+
# attn_implementation="flash_attention_2",
|
| 78 |
+
# torch_dtype=torch.float16,
|
| 79 |
+
# quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
| 80 |
+
# low_cpu_mem_usage=True,
|
| 81 |
+
# trust_remote_code=True,
|
| 82 |
+
# )
|
| 83 |
+
|
| 84 |
+
# === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] ===
|
| 85 |
+
# print("[*] Loading in 4-Bit Quantization Mode")
|
| 86 |
+
# vlm = AutoModelForVision2Seq.from_pretrained(
|
| 87 |
+
# MODEL_PATH,
|
| 88 |
+
# attn_implementation="flash_attention_2",
|
| 89 |
+
# torch_dtype=torch.float16,
|
| 90 |
+
# quantization_config=BitsAndBytesConfig(load_in_4bit=True),
|
| 91 |
+
# low_cpu_mem_usage=True,
|
| 92 |
+
# trust_remote_code=True,
|
| 93 |
+
# )
|
| 94 |
+
|
| 95 |
+
# Iterate over Sample Prompts =>> Generate
|
| 96 |
+
image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB")
|
| 97 |
+
num_tokens, total_time = 0, 0.0
|
| 98 |
+
|
| 99 |
+
print("[*] Iterating over Sample Prompts\n===\n")
|
| 100 |
+
for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
|
| 101 |
+
# === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) ===
|
| 102 |
+
# inputs = processor(prompt, image).to(device)
|
| 103 |
+
#
|
| 104 |
+
# # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py`
|
| 105 |
+
# # =>> Running in native BF16 is also fine (but leads to slightly different generations)
|
| 106 |
+
# with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
|
| 107 |
+
# gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
|
| 108 |
+
|
| 109 |
+
# === BFLOAT16 MODE ===
|
| 110 |
+
inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
|
| 111 |
+
|
| 112 |
+
# === 8-BIT/4-BIT QUANTIZATION MODE ===
|
| 113 |
+
# inputs = processor(prompt, image).to(device, dtype=torch.float16)
|
| 114 |
+
|
| 115 |
+
# Run Inference
|
| 116 |
+
gen_ids = None
|
| 117 |
+
for _ in range(5):
|
| 118 |
+
start_time = time.time()
|
| 119 |
+
gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
|
| 120 |
+
total_time += time.time() - start_time
|
| 121 |
+
|
| 122 |
+
gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]
|
| 123 |
+
num_tokens += len(gen_ids)
|
| 124 |
+
|
| 125 |
+
# ===
|
| 126 |
+
gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
|
| 127 |
+
print(f"[{idx + 1}] Input Prompt => {prompt}\n Generated => {gen_text}\n")
|
| 128 |
+
|
| 129 |
+
# Compute Tokens / Second
|
| 130 |
+
print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
|
| 134 |
+
verify_prismatic()
|
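`verify_prismatic.py` accumulates the time of five `generate()` calls per prompt but counts only the final call's tokens, so the printed tokens-per-second figure is conservative relative to per-call throughput. A minimal sketch of per-call bookkeeping, assuming an HF-style `generate()` and a processor output exposing `.input_ids` as in the script:

# Minimal throughput-measurement sketch: count newly generated tokens on every timed call.
# Assumes an HF-style `generate()` and an `inputs` object with `.input_ids`, as above.
import time

def measure_tokens_per_second(vlm, inputs, num_runs: int = 5) -> float:
    num_tokens, total_time = 0, 0.0
    prompt_len = inputs.input_ids.shape[1]
    for _ in range(num_runs):
        start = time.time()
        gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
        total_time += time.time() - start
        num_tokens += gen_ids.shape[1] - prompt_len  # only tokens generated in this call
    return num_tokens / total_time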
scripts/pretrain.py
ADDED
|
@@ -0,0 +1,238 @@
| 1 |
+
"""
|
| 2 |
+
pretrain.py
|
| 3 |
+
|
| 4 |
+
Pretraining script for Prismatic VLM pretraining in native PyTorch, using Fully-Sharded Data Parallel (FSDP) to run
|
| 5 |
+
distributed training across GPUs. By default, assumes that CUDA toolkit is >= 11.0 (to support BF16 mixed precision).
|
| 6 |
+
|
| 7 |
+
Notes & Prerequisites:
|
| 8 |
+
- We're loading LLaMa-2 (and possibly other) gated models from HuggingFace (HF Hub); these require an auth_token.
|
| 9 |
+
For LLaMa-2, make sure to first get Meta approval, then fill out the form at the top of the HF LLaMa-2 page:
|
| 10 |
+
=> Link: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
| 11 |
+
=> Generate Token (from `huggingface.co`): Settings / Access Tokens / New "Read" Token
|
| 12 |
+
=> Set `cfg.hf_token` to file path with token (as single line text file) or environment variable name
|
| 13 |
+
|
| 14 |
+
- If you want to set a custom location for all HF / TIMM artifacts --> `export HF_HOME="<PATH>"` *before* running!
|
| 15 |
+
=> For example (add to end of .bashrc): `export HF_HOME="/mnt/fsx/skaramcheti/cache"`
|
| 16 |
+
|
| 17 |
+
Run with:
|
| 18 |
+
- [Single Node One-GPU (Debug)] : torchrun --standalone --nnodes 1 --nproc-per-node 1 scripts/pretrain.py
|
| 19 |
+
- [Single Node Multi-GPU (= $K)]: torchrun --standalone --nnodes 1 --nproc-per-node $K scripts/pretrain.py
|
| 20 |
+
- [Multi-Node/AWS Sagemaker] Depends on your individual setup; file an issue if you have trouble!
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Optional, Tuple, Union
|
| 28 |
+
|
| 29 |
+
import draccus
|
| 30 |
+
import torch
|
| 31 |
+
import torch.distributed as dist
|
| 32 |
+
import yaml
|
| 33 |
+
|
| 34 |
+
from prismatic.conf import DatasetConfig, DatasetRegistry, ModelConfig, ModelRegistry
|
| 35 |
+
from prismatic.models import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm
|
| 36 |
+
from prismatic.overwatch import initialize_overwatch
|
| 37 |
+
from prismatic.preprocessing import get_dataset_and_collator
|
| 38 |
+
from prismatic.training import Metrics, get_train_strategy
|
| 39 |
+
from prismatic.util import set_global_seed
|
| 40 |
+
|
| 41 |
+
# Disable Tokenizers Parallelism to Play Nice w/ PyTorch Multiprocessing DataLoaders
|
| 42 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 43 |
+
|
| 44 |
+
# Initialize Overwatch =>> Wraps `logging.Logger`
|
| 45 |
+
overwatch = initialize_overwatch(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class PretrainConfig:
|
| 50 |
+
# fmt: off
|
| 51 |
+
|
| 52 |
+
# ModelConfig (`prismatic/conf/models.py`); override with --model.type `ModelRegistry.<MODEL>.model_id`
|
| 53 |
+
model: ModelConfig = field(
|
| 54 |
+
default_factory=ModelConfig.get_choice_class(ModelRegistry.PRISM_DINOSIGLIP_CONTROLLED_7B.model_id)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# DatasetConfig (`prismatic/conf/datasets.py`); override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
|
| 58 |
+
dataset: DatasetConfig = field(
|
| 59 |
+
default_factory=DatasetConfig.get_choice_class(DatasetRegistry.LLAVA_V15.dataset_id)
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Pretraining Stage in < align (projector-only) | finetune (projector + LLM) | full-finetune (all) >
|
| 63 |
+
# ---
|
| 64 |
+
stage: str = "finetune" # Pretraining Stage in < align | finetune >
|
| 65 |
+
pretrained_checkpoint: Optional[Path] = None # Pretrained Checkpoint to Load (for `finetune`)
|
| 66 |
+
# if None =>> will match on (run_dir / `align`)
|
| 67 |
+
|
| 68 |
+
# Run Arguments
|
| 69 |
+
run_id: Optional[str] = None # Run ID for logging, Weights & Biases
|
| 70 |
+
run_root_dir: Path = Path("/mnt/fsx/x-prismatic-vlms/runs") # Path to directory to store logs & checkpoints
|
| 71 |
+
seed: int = 7 # Random seed (for reproducibility)
|
| 72 |
+
|
| 73 |
+
# HF Hub Credentials (for any gated models)
|
| 74 |
+
hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
|
| 75 |
+
|
| 76 |
+
# Tracking Parameters
|
| 77 |
+
trackers: Tuple[str, ...] = ("jsonl", "wandb") # Trackers to initialize (if W&B, add config!)
|
| 78 |
+
wandb_project: str = "onyx-vlms" # Name of W&B project (default: `prismatic`)
|
| 79 |
+
wandb_entity: Optional[str] = "stanford-voltron" # Name of W&B entity (default: None)
|
| 80 |
+
|
| 81 |
+
def __post_init__(self) -> None:
|
| 82 |
+
"""Set optimization parameters based on `stage` in {"align", "finetune"}."""
|
| 83 |
+
if self.stage == "align":
|
| 84 |
+
self.epochs = self.model.align_epochs
|
| 85 |
+
self.max_steps = self.model.align_max_steps
|
| 86 |
+
self.global_batch_size = self.model.align_global_batch_size
|
| 87 |
+
self.per_device_batch_size = self.model.align_per_device_batch_size
|
| 88 |
+
|
| 89 |
+
self.learning_rate = self.model.align_learning_rate
|
| 90 |
+
self.weight_decay = self.model.align_weight_decay
|
| 91 |
+
self.max_grad_norm = self.model.align_max_grad_norm
|
| 92 |
+
self.lr_scheduler_type = self.model.align_lr_scheduler_type
|
| 93 |
+
self.warmup_ratio = self.model.align_warmup_ratio
|
| 94 |
+
|
| 95 |
+
self.train_strategy = self.model.align_train_strategy
|
| 96 |
+
|
| 97 |
+
elif self.stage.endswith("finetune"):
|
| 98 |
+
self.epochs = self.model.finetune_epochs
|
| 99 |
+
self.max_steps = self.model.finetune_max_steps
|
| 100 |
+
self.global_batch_size = self.model.finetune_global_batch_size
|
| 101 |
+
self.per_device_batch_size = self.model.finetune_per_device_batch_size
|
| 102 |
+
|
| 103 |
+
self.learning_rate = self.model.finetune_learning_rate
|
| 104 |
+
self.weight_decay = self.model.finetune_weight_decay
|
| 105 |
+
self.max_grad_norm = self.model.finetune_max_grad_norm
|
| 106 |
+
self.lr_scheduler_type = self.model.finetune_lr_scheduler_type
|
| 107 |
+
self.warmup_ratio = self.model.finetune_warmup_ratio
|
| 108 |
+
|
| 109 |
+
self.train_strategy = self.model.finetune_train_strategy
|
| 110 |
+
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Stage `{self.stage}` is not supported!")
|
| 113 |
+
|
| 114 |
+
# fmt: on
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@draccus.wrap()
|
| 118 |
+
def pretrain(cfg: PretrainConfig) -> None:
|
| 119 |
+
overwatch.info("Prismatic VLM Training :: Gathering Light")
|
| 120 |
+
|
| 121 |
+
# Note => Under `torchrun` initializing `overwatch` will automatically set up `torch.distributed`
|
| 122 |
+
torch.cuda.set_device(device_id := overwatch.local_rank())
|
| 123 |
+
torch.cuda.empty_cache()
|
| 124 |
+
|
| 125 |
+
# Create Unique Run Name & Save Directory
|
| 126 |
+
model_id = cfg.model.model_id
|
| 127 |
+
if (dataset_id := cfg.dataset.dataset_id) == "llava-v15":
|
| 128 |
+
cfg.run_id = f"{model_id}+stage-{cfg.stage}+x{cfg.seed}" if cfg.run_id is None else cfg.run_id
|
| 129 |
+
else:
|
| 130 |
+
cfg.run_id = f"{dataset_id}+{model_id}+stage-{cfg.stage}+x{cfg.seed}" if cfg.run_id is None else cfg.run_id
|
| 131 |
+
|
| 132 |
+
# Start =>> Build Directories and Set Randomness
|
| 133 |
+
overwatch.info('"Life is like a prism; what you see depends on how you turn the glass."', ctx_level=1)
|
| 134 |
+
hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token]
|
| 135 |
+
worker_init_fn = set_global_seed(cfg.seed, get_worker_init_fn=True)
|
| 136 |
+
os.makedirs(run_dir := (cfg.run_root_dir / cfg.run_id), exist_ok=True)
|
| 137 |
+
os.makedirs(cfg.run_root_dir / cfg.run_id / "checkpoints", exist_ok=True)
|
| 138 |
+
if overwatch.is_rank_zero():
|
| 139 |
+
# Additionally save a JSON version of the config
|
| 140 |
+
draccus.dump(cfg, open(run_dir / "config.yaml", "w"))
|
| 141 |
+
with open(run_dir / "config.yaml", "r") as f_yaml, open(run_dir / "config.json", "w") as f_json:
|
| 142 |
+
yaml_cfg = yaml.safe_load(f_yaml)
|
| 143 |
+
json.dump(yaml_cfg, f_json, indent=2)
|
| 144 |
+
|
| 145 |
+
# Load Vision Backbone --> on CPU, in Full Precision (initializing model, image_transform via TIMM)
|
| 146 |
+
overwatch.info(f"Loading Vision Backbone [bold]{cfg.model.vision_backbone_id}[/] via TIMM ")
|
| 147 |
+
vision_backbone, image_transform = get_vision_backbone_and_transform(
|
| 148 |
+
cfg.model.vision_backbone_id, image_resize_strategy=cfg.model.image_resize_strategy
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Load LLM Backbone --> on CPU, in Full Precision (initializing Tokenizer + handling special tokens if necessary)
|
| 152 |
+
overwatch.info(f"Loading Pretrained LLM [bold]{cfg.model.llm_backbone_id}[/] via HF Transformers")
|
| 153 |
+
llm_backbone, tokenizer = get_llm_backbone_and_tokenizer(
|
| 154 |
+
cfg.model.llm_backbone_id, llm_max_length=cfg.model.llm_max_length, hf_token=hf_token
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Create VLM => wraps `vision_backbone` and `llm`
|
| 158 |
+
overwatch.info(f"Instantiating PrismaticVLM `{model_id}` for Training Stage = `{cfg.stage}`")
|
| 159 |
+
vlm = get_vlm(
|
| 160 |
+
model_id,
|
| 161 |
+
cfg.model.arch_specifier,
|
| 162 |
+
vision_backbone,
|
| 163 |
+
llm_backbone,
|
| 164 |
+
enable_mixed_precision_training=cfg.model.enable_mixed_precision_training,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# [Explicit] Call to `freeze_backbones` here for clarity => will log exactly what is frozen / what's not!
|
| 168 |
+
overwatch.info(f"Invoking `VLM.freeze_backbones()` for `{model_id}` => Training Stage: `{cfg.stage}`")
|
| 169 |
+
vlm.freeze_backbones(cfg.stage)
|
| 170 |
+
|
| 171 |
+
# Load Weights from Checkpoint (depends on stage, config)
|
| 172 |
+
overwatch.info(f"Invoking `VLM.load_checkpoint()` for `{model_id}` => Training Stage: `{cfg.stage}`")
|
| 173 |
+
vlm.load_from_checkpoint(cfg.stage, run_dir, pretrained_checkpoint=cfg.pretrained_checkpoint)
|
| 174 |
+
|
| 175 |
+
# Get Dataset for Specified Stage
|
| 176 |
+
overwatch.info(f"Creating Dataset `{cfg.dataset.dataset_id}` => Stage: `{cfg.stage}`")
|
| 177 |
+
train_dataset, collator = get_dataset_and_collator(
|
| 178 |
+
cfg.stage,
|
| 179 |
+
cfg.dataset,
|
| 180 |
+
image_transform,
|
| 181 |
+
tokenizer,
|
| 182 |
+
prompt_builder_fn=llm_backbone.prompt_builder_fn,
|
| 183 |
+
default_image_resolution=vision_backbone.default_image_resolution,
|
| 184 |
+
padding_side=tokenizer.padding_side,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Create Train Strategy
|
| 188 |
+
overwatch.info(f"Initializing Train Strategy `{cfg.train_strategy}`")
|
| 189 |
+
train_strategy = get_train_strategy(
|
| 190 |
+
train_strategy=cfg.train_strategy,
|
| 191 |
+
vlm=vlm,
|
| 192 |
+
device_id=device_id,
|
| 193 |
+
stage=cfg.stage,
|
| 194 |
+
epochs=cfg.epochs,
|
| 195 |
+
max_steps=cfg.max_steps,
|
| 196 |
+
global_batch_size=cfg.global_batch_size,
|
| 197 |
+
per_device_batch_size=cfg.per_device_batch_size,
|
| 198 |
+
learning_rate=cfg.learning_rate,
|
| 199 |
+
weight_decay=cfg.weight_decay,
|
| 200 |
+
max_grad_norm=cfg.max_grad_norm,
|
| 201 |
+
lr_scheduler_type=cfg.lr_scheduler_type,
|
| 202 |
+
warmup_ratio=cfg.warmup_ratio,
|
| 203 |
+
enable_gradient_checkpointing=cfg.model.enable_gradient_checkpointing,
|
| 204 |
+
enable_mixed_precision_training=cfg.model.enable_mixed_precision_training,
|
| 205 |
+
reduce_in_full_precision=cfg.model.reduce_in_full_precision,
|
| 206 |
+
worker_init_fn=worker_init_fn,
|
| 207 |
+
)
|
| 208 |
+
train_strategy.run_setup(run_dir=run_dir, n_train_examples=len(train_dataset))
|
| 209 |
+
|
| 210 |
+
# Create Metrics =>> Handles on the fly tracking, logging to specified trackers (e.g., JSONL, Weights & Biases)
|
| 211 |
+
overwatch.info(f"Creating Metrics with Active Trackers => `{cfg.trackers}`")
|
| 212 |
+
metrics = Metrics(
|
| 213 |
+
cfg.trackers,
|
| 214 |
+
cfg.run_id,
|
| 215 |
+
run_dir,
|
| 216 |
+
draccus.encode(cfg),
|
| 217 |
+
cfg.stage,
|
| 218 |
+
wandb_project=cfg.wandb_project,
|
| 219 |
+
wandb_entity=cfg.wandb_entity,
|
| 220 |
+
grad_accumulation_steps=train_strategy.grad_accumulation_steps,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Run Training
|
| 224 |
+
overwatch.info("Starting Training Loop")
|
| 225 |
+
train_strategy.run_training(train_dataset, collator, metrics, stage=cfg.stage, seed=cfg.seed)
|
| 226 |
+
|
| 227 |
+
# Finalize
|
| 228 |
+
overwatch.info("Done with Training =>> Finalizing Metrics")
|
| 229 |
+
metrics.finalize()
|
| 230 |
+
|
| 231 |
+
# And... we're done!
|
| 232 |
+
overwatch.info("... and that's all, folks!")
|
| 233 |
+
dist.barrier()
|
| 234 |
+
dist.destroy_process_group()
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
pretrain()
|
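`PretrainConfig.__post_init__` selects the optimization hyperparameters from the `ModelConfig` according to `stage`. A stripped-down sketch of that dispatch pattern with plain dataclasses follows; the `Tiny*` classes and their fields are made up for illustration.

# Stripped-down sketch of stage-based hyperparameter selection, mirroring
# PretrainConfig.__post_init__; the Tiny* classes and their fields are illustrative only.
from dataclasses import dataclass

@dataclass
class TinyModelConfig:
    align_learning_rate: float = 1e-3
    finetune_learning_rate: float = 2e-5

@dataclass
class TinyPretrainConfig:
    model: TinyModelConfig
    stage: str = "finetune"

    def __post_init__(self) -> None:
        if self.stage == "align":
            self.learning_rate = self.model.align_learning_rate
        elif self.stage.endswith("finetune"):
            self.learning_rate = self.model.finetune_learning_rate
        else:
            raise ValueError(f"Stage `{self.stage}` is not supported!")

if __name__ == "__main__":
    cfg = TinyPretrainConfig(model=TinyModelConfig(), stage="align")
    assert cfg.learning_rate == 1e-3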
test_deepseek_moe.py
ADDED
|
@@ -0,0 +1,246 @@
import torch
import torch.nn as nn
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Mock constant definitions
ACTION_DIM = 7
NUM_ACTIONS_CHUNK = 8
SHORT_NUM_ACTIONS_CHUNK = 4
MID_NUM_ACTIONS_CHUNK = 6

# Import the relevant modules
from prismatic.models.action_heads import (
    Expert,
    DeepSeekV3AdaptiveBiasRouter,
    MoELayer,
    DeepSeekV3MoEActionHead,
    TSActionHead
)

def test_deepseek_moe_components():
    """Test the DeepSeek V3 MoE components."""
    print("Testing DeepSeek V3 MoE components...")

    # Test parameters
    batch_size = 4
    seq_len = 8
    hidden_dim = 256
    num_experts = 8
    top_k = 2

    # Create test data
    x = torch.randn(batch_size, seq_len, hidden_dim)

    print("\n1. Testing the GELU Expert network:")
    try:
        expert = Expert(hidden_dim)
        output = expert(x.view(-1, hidden_dim))
        print(f"  Input shape: {x.view(-1, hidden_dim).shape}")
        print(f"  Output shape: {output.shape}")
        assert output.shape == (batch_size * seq_len, hidden_dim)

        # Verify that the GELU activation is used
        print(f"  Activation type: {type(expert.activation).__name__}")
        assert isinstance(expert.activation, nn.GELU)
        print("  ✓ GELU Expert network test passed")

    except Exception as e:
        print(f"  ✗ GELU Expert network test failed: {e}")

    print("\n2. Testing the DeepSeek V3 adaptive-bias router:")
    try:
        router = DeepSeekV3AdaptiveBiasRouter(hidden_dim, num_experts, top_k)
        weights, indices = router(x)
        print(f"  Input shape: {x.shape}")
        print(f"  Weights shape: {weights.shape}")
        print(f"  Indices shape: {indices.shape}")

        assert weights.shape == (batch_size, seq_len, top_k)
        assert indices.shape == (batch_size, seq_len, top_k)

        # Verify that the router has an adaptive bias
        if router.enable_bias_correction:
            print(f"  Adaptive bias shape: {router.adaptive_bias.shape}")
            assert router.adaptive_bias.shape == (num_experts,)

        # Verify the load-balancing loss
        loss = router.get_load_balancing_loss()
        print(f"  Load-balancing loss: {loss.item():.6f}")

        print("  ✓ DeepSeek V3 router test passed")

    except Exception as e:
        print(f"  ✗ DeepSeek V3 router test failed: {e}")

    print("\n3. Testing the DeepSeek V3 MoE layer:")
    try:
        # Test the variant without a shared expert
        moe_layer = MoELayer(
            hidden_dim,
            num_experts,
            top_k,
            enable_shared_expert=False
        )
        output = moe_layer(x)
        print(f"  Input shape: {x.shape}")
        print(f"  Output shape: {output.shape}")
        assert output.shape == x.shape

        # Test the variant with shared experts
        moe_layer_shared = MoELayer(
            hidden_dim,
            num_experts,
            top_k,
            enable_shared_expert=True,
            num_shared_experts=2
        )
        output_shared = moe_layer_shared(x)
        print(f"  Output shape with shared experts: {output_shared.shape}")
        assert output_shared.shape == x.shape

        # Verify load balancing
        load_loss = moe_layer.get_load_balancing_loss()
        print(f"  Load-balancing loss: {load_loss.item():.6f}")

        print("  ✓ DeepSeek V3 MoE layer test passed")

    except Exception as e:
        print(f"  ✗ DeepSeek V3 MoE layer test failed: {e}")

def test_deepseek_moe_action_head():
    """Test the DeepSeek V3 MoE action head."""
    print("\n4. Testing the DeepSeek V3 MoE action head:")

    # Test parameters
    batch_size = 2
    input_dim = 512
    hidden_dim = 256
    action_dim = 7

    try:
        # Build the model
        model = DeepSeekV3MoEActionHead(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            action_dim=action_dim,
            num_routed_experts=8,
            top_k=2,
            num_moe_layers=2,
            enable_shared_expert=True
        )

        # Test single-token input
        actions_hidden_states_single = torch.randn(batch_size, 1, input_dim)
        output_single = model.predict_action(actions_hidden_states_single)
        print(f"  Single-token input shape: {actions_hidden_states_single.shape}")
        print(f"  Single-token output shape: {output_single.shape}")
        assert output_single.shape == (batch_size, NUM_ACTIONS_CHUNK, action_dim)

        # Test multi-token input
        actions_hidden_states_multi = torch.randn(batch_size, ACTION_DIM, input_dim)
        output_multi = model.predict_action(actions_hidden_states_multi)
        print(f"  Multi-token input shape: {actions_hidden_states_multi.shape}")
        print(f"  Multi-token output shape: {output_multi.shape}")
        assert output_multi.shape == (batch_size, NUM_ACTIONS_CHUNK, action_dim)

        # Test the load-balancing loss
        load_loss = model.get_load_balancing_loss()
        print(f"  Model load-balancing loss: {load_loss.item():.6f}")

        # Test expert-usage statistics
        model.train()
        _ = model.predict_action(actions_hidden_states_single)  # trigger a statistics update
        stats = model.get_expert_usage_stats()
        print(f"  Number of layers with expert-usage stats: {len(stats)}")

        print("  ✓ DeepSeek V3 MoE action head test passed")

    except Exception as e:
        print(f"  ✗ DeepSeek V3 MoE action head test failed: {e}")

def test_comparison_with_traditional_methods():
    """Compare DeepSeek V3 MoE with traditional methods."""
    print("\n5. Performance comparison test:")

    # Test parameters
    batch_size = 2
    input_dim = 512
    hidden_dim = 256
    action_dim = 7

    try:
        # Traditional FFN method
        model_ffn = TSActionHead(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            action_dim=action_dim,
            mlp_type='ffn',
            decoder_num_blocks=2
        )

        # Legacy MoE method
        model_old_moe = TSActionHead(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            action_dim=action_dim,
            mlp_type='moe',
            num_experts=8,
            top_k=2,
            decoder_num_blocks=2
        )

        # DeepSeek V3 MoE method
        model_deepseek_moe = DeepSeekV3MoEActionHead(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            action_dim=action_dim,
            num_routed_experts=8,
            top_k=2,
            num_moe_layers=2,
            enable_shared_expert=True
        )

        # Count parameters
        params_ffn = sum(p.numel() for p in model_ffn.parameters())
        params_old_moe = sum(p.numel() for p in model_old_moe.parameters())
        params_deepseek_moe = sum(p.numel() for p in model_deepseek_moe.parameters())

        print(f"  FFN model parameter count: {params_ffn:,}")
        print(f"  Legacy MoE parameter count: {params_old_moe:,}")
        print(f"  DeepSeek V3 MoE parameter count: {params_deepseek_moe:,}")
        print(f"  DeepSeek V3 vs FFN parameter ratio: {params_deepseek_moe / params_ffn:.2f}x")
        print(f"  DeepSeek V3 vs legacy MoE parameter ratio: {params_deepseek_moe / params_old_moe:.2f}x")

        # Test inference time (simple benchmark)
        import time

        test_input = torch.randn(batch_size, 1, input_dim)

        # FFN inference time
        start_time = time.time()
        for _ in range(100):
            _ = model_ffn.predict_action(test_input)
        ffn_time = time.time() - start_time

        # DeepSeek V3 MoE inference time
        start_time = time.time()
        for _ in range(100):
            _ = model_deepseek_moe.predict_action(test_input)
        deepseek_time = time.time() - start_time

        print(f"  FFN inference time (100 runs): {ffn_time:.4f}s")
        print(f"  DeepSeek V3 MoE inference time (100 runs): {deepseek_time:.4f}s")
        print(f"  Inference time ratio: {deepseek_time / ffn_time:.2f}x")

        print("  ✓ Performance comparison test complete")

    except Exception as e:
        print(f"  ✗ Performance comparison test failed: {e}")

if __name__ == "__main__":
    test_deepseek_moe_components()
    test_deepseek_moe_action_head()
    test_comparison_with_traditional_methods()
    print("\nAll DeepSeek V3 MoE tests complete!")
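The router tests above exercise top-k expert selection. Stripped of the DeepSeek-V3-specific adaptive-bias correction and shared experts, the core routing step (gate, top-k, renormalize) can be sketched generically; this is not the repository's `DeepSeekV3AdaptiveBiasRouter`, only an illustrative baseline.

# Generic top-k routing sketch: softmax gate -> top-k -> renormalize.
# Omits the adaptive-bias load balancing and shared experts tested above.
import torch
import torch.nn as nn

class TopKRouter(nn.Module):
    def __init__(self, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        probs = self.gate(x).softmax(dim=-1)                   # (B, T, num_experts)
        weights, indices = probs.topk(self.top_k, dim=-1)      # (B, T, top_k)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over chosen experts
        return weights, indices

if __name__ == "__main__":
    router = TopKRouter(hidden_dim=256, num_experts=8, top_k=2)
    w, idx = router(torch.randn(4, 8, 256))
    assert w.shape == idx.shape == (4, 8, 2)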
vla-scripts/extern/convert_openvla_weights_to_hf.py
ADDED
|
@@ -0,0 +1,272 @@
| 1 |
+
"""
|
| 2 |
+
convert_openvla_weights_to_hf.py
|
| 3 |
+
|
| 4 |
+
Utility script for converting full OpenVLA VLA weights (from this repository, in the default "Prismatic" format) to
|
| 5 |
+
the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers``
|
| 6 |
+
via `trust_remote_code = True`.
|
| 7 |
+
|
| 8 |
+
Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the
|
| 9 |
+
line, with first-class support.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python vla-scripts/extern/convert_openvla_weights_to_hf.py \
|
| 13 |
+
--openvla_model_path_or_id <PATH TO PRISMATIC TRAINING RUN DIR> \
|
| 14 |
+
--output_hf_model_local_path <OUTPUT DIR FOR CONVERTED CHECKPOINT>
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import shutil
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Dict, Union
|
| 23 |
+
|
| 24 |
+
import draccus
|
| 25 |
+
import timm
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from huggingface_hub import hf_hub_download
|
| 29 |
+
from timm.models.vision_transformer import LayerScale
|
| 30 |
+
from transformers import AutoTokenizer
|
| 31 |
+
|
| 32 |
+
from prismatic.conf import ModelConfig
|
| 33 |
+
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
|
| 34 |
+
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
|
| 35 |
+
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class HFConvertConfig:
|
| 40 |
+
# fmt: off
|
| 41 |
+
openvla_model_path_or_id: Union[str, Path] = ( # Path to Pretrained VLA (on disk or HF Hub)
|
| 42 |
+
"runs/prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7"
|
| 43 |
+
)
|
| 44 |
+
output_hf_model_local_path: Path = Path( # Path to Local Path to save HF model
|
| 45 |
+
"hf-convert/openvla-7b"
|
| 46 |
+
)
|
| 47 |
+
output_hf_model_hub_path: str = "openvla/openvla-7b" # (Optional) Path to HF Hub Path to push
|
| 48 |
+
# model to
|
| 49 |
+
|
| 50 |
+
# HF Hub Credentials (required for Gated Models like LLaMa-2)
|
| 51 |
+
hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
|
| 52 |
+
|
| 53 |
+
def __post_init__(self) -> None:
|
| 54 |
+
self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token
|
| 55 |
+
|
| 56 |
+
# fmt: on
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
|
| 60 |
+
# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
|
| 61 |
+
# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
|
| 62 |
+
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def ls_apply_patch(ls_module: LayerScale):
|
| 67 |
+
ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
|
| 68 |
+
ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
|
| 69 |
+
del ls_module.gamma
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# === Conversion Constants ===
|
| 73 |
+
PROJECTOR_KEY_MAPPING = {
|
| 74 |
+
"projector.0.weight": "projector.fc1.weight",
|
| 75 |
+
"projector.0.bias": "projector.fc1.bias",
|
| 76 |
+
"projector.2.weight": "projector.fc2.weight",
|
| 77 |
+
"projector.2.bias": "projector.fc2.bias",
|
| 78 |
+
"projector.4.weight": "projector.fc3.weight",
|
| 79 |
+
"projector.4.bias": "projector.fc3.bias",
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def remap_state_dicts_for_hf(
|
| 84 |
+
prismatic_vision_backbone_state_dict: Dict[str, torch.Tensor],
|
| 85 |
+
projector_state_dict: Dict[str, torch.Tensor],
|
| 86 |
+
llm_backbone_state_dict: Dict[str, torch.Tensor],
|
| 87 |
+
use_fused_vision_backbone: bool = False,
|
| 88 |
+
) -> Dict[str, torch.Tensor]:
|
| 89 |
+
"""Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion."""
|
| 90 |
+
hf_state_dict = {}
|
| 91 |
+
|
| 92 |
+
# Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING`
|
| 93 |
+
for key, value in projector_state_dict.items():
|
| 94 |
+
hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value
|
| 95 |
+
|
| 96 |
+
# Iterate through LLM Backbone =>> replace `llm.` with `language_model.`
|
| 97 |
+
for key, value in llm_backbone_state_dict.items():
|
| 98 |
+
hf_state_dict[key.replace("llm.", "language_model.")] = value
|
| 99 |
+
|
| 100 |
+
# Iterate through Vision Backbone =>> add "vision_backbone." prefix
|
| 101 |
+
if not use_fused_vision_backbone:
|
| 102 |
+
for key, value in prismatic_vision_backbone_state_dict.items():
|
| 103 |
+
hf_state_dict[key.replace("featurizer.", "vision_backbone.featurizer.")] = value
|
| 104 |
+
else:
|
| 105 |
+
# Note =>> Assumes that backbones are always DINO + SigLIP...
|
| 106 |
+
for key, value in prismatic_vision_backbone_state_dict.items():
|
| 107 |
+
if key.startswith("dino_featurizer"):
|
| 108 |
+
if key.endswith(".gamma"):
|
| 109 |
+
# Handle `LayerScale gamma` =>> DINOv2 only!
|
| 110 |
+
key = key.replace(".gamma", ".scale_factor")
|
| 111 |
+
hf_state_dict[key.replace("dino_featurizer.", "vision_backbone.featurizer.")] = value
|
| 112 |
+
elif key.startswith("siglip_featurizer"):
|
| 113 |
+
hf_state_dict[key.replace("siglip_featurizer.", "vision_backbone.fused_featurizer.")] = value
|
| 114 |
+
|
| 115 |
+
return hf_state_dict
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@draccus.wrap()
|
| 119 |
+
def convert_openvla_weights_to_hf(cfg: HFConvertConfig) -> None:
|
| 120 |
+
print(f"[*] Converting OpenVLA Model `{cfg.openvla_model_path_or_id}` to HF Transformers Format")
|
| 121 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 122 |
+
|
| 123 |
+
# Get `config.json`, 'dataset_statistics.json' and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py`
|
| 124 |
+
if os.path.isdir(cfg.openvla_model_path_or_id):
|
| 125 |
+
print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.openvla_model_path_or_id))}`")
|
| 126 |
+
config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt"
|
| 127 |
+
dataset_statistics_json = run_dir / "dataset_statistics.json"
|
| 128 |
+
|
| 129 |
+
assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`"
|
| 130 |
+
assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`"
|
| 131 |
+
assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`"
|
| 132 |
+
else:
|
| 133 |
+
print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.openvla_model_path_or_id}`")
|
| 134 |
+
config_json = hf_hub_download("openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/config.json")
|
| 135 |
+
checkpoint_pt = hf_hub_download(
|
| 136 |
+
"openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/checkpoints/latest-checkpoint.pt"
|
| 137 |
+
)
|
| 138 |
+
dataset_statistics_json = hf_hub_download(
|
| 139 |
+
"openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/dataset_statistics.json"
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer
|
| 143 |
+
with open(config_json, "r") as f:
|
| 144 |
+
vla_cfg = json.load(f)["vla"]
|
| 145 |
+
prismatic_config = ModelConfig.get_choice_class(vla_cfg["base_vlm"])().__dict__
|
| 146 |
+
|
| 147 |
+
# Load Normalization Statistics
|
| 148 |
+
with open(dataset_statistics_json, "r") as f:
|
| 149 |
+
norm_stats = json.load(f)
|
| 150 |
+
|
| 151 |
+
# Create HF OpenVLAConfig (`transformers.PretrainedConfig`)
|
| 152 |
+
hf_config = OpenVLAConfig(
|
| 153 |
+
vision_backbone_id=prismatic_config["vision_backbone_id"],
|
| 154 |
+
llm_backbone_id=prismatic_config["llm_backbone_id"],
|
| 155 |
+
arch_specifier=prismatic_config["arch_specifier"],
|
| 156 |
+
image_resize_strategy=prismatic_config["image_resize_strategy"],
|
| 157 |
+
llm_max_length=prismatic_config["llm_max_length"],
|
| 158 |
+
torch_dtype=torch.bfloat16,
|
| 159 |
+
norm_stats=norm_stats,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer`
|
| 163 |
+
# TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`!
|
| 164 |
+
print("[*] Instantiating and Patching Tokenizer, LLM Config")
|
| 165 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 166 |
+
hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right"
|
| 167 |
+
)
|
| 168 |
+
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 169 |
+
tokenizer.init_kwargs.pop("add_prefix_space", None) # Pop to prevent unnecessary warning on reload...
|
| 170 |
+
assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!"
|
| 171 |
+
assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!"
|
| 172 |
+
|
| 173 |
+
# Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate
|
| 174 |
+
hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of
|
| 175 |
+
hf_config.text_config.pad_token_id = hf_config.pad_token_id
|
| 176 |
+
hf_config.text_config.torch_dtype = torch.bfloat16
|
| 177 |
+
assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!"
|
| 178 |
+
|
| 179 |
+
# Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform`
|
| 180 |
+
# =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py`
|
| 181 |
+
print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor")
|
| 182 |
+
input_sizes, interpolations, means, stds = [], [], [], []
|
| 183 |
+
for idx, timm_model_id in enumerate(hf_config.timm_model_ids):
|
| 184 |
+
timm_vision_backbone = timm.create_model(
|
| 185 |
+
timm_model_id,
|
| 186 |
+
pretrained=True,
|
| 187 |
+
num_classes=0,
|
| 188 |
+
img_size=hf_config.image_sizes[idx],
|
| 189 |
+
act_layer=hf_config.timm_override_act_layers[idx],
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# Get Per-Backbone Image Processing
|
| 193 |
+
data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone)
|
| 194 |
+
input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx]))
|
| 195 |
+
interpolations.append(data_cfg["interpolation"])
|
| 196 |
+
means.append(data_cfg["mean"])
|
| 197 |
+
stds.append(data_cfg["std"])
|
| 198 |
+
|
| 199 |
+
# Patch `LayerScale` because of HF annoying `fix_key` overwrite...
|
| 200 |
+
for module in timm_vision_backbone.modules():
|
| 201 |
+
if isinstance(module, LayerScale):
|
| 202 |
+
ls_apply_patch(module)
|
| 203 |
+
|
| 204 |
+
# Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`)
|
| 205 |
+
hf_image_processor = PrismaticImageProcessor(
|
| 206 |
+
use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
|
| 207 |
+
image_resize_strategy=hf_config.image_resize_strategy,
|
| 208 |
+
input_sizes=input_sizes,
|
| 209 |
+
interpolations=interpolations,
|
| 210 |
+
means=means,
|
| 211 |
+
stds=stds,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor)
|
| 215 |
+
print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor")
|
| 216 |
+
hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer)
|
| 217 |
+
|
| 218 |
+
# Load Prismatic Model State Dictionary (in preparation for conversion)
|
| 219 |
+
print("[*] Loading Prismatic VLM State Dictionary from Checkpoint")
|
| 220 |
+
model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"]
|
| 221 |
+
assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?"
|
| 222 |
+
assert all([k in model_state_dict for k in ["vision_backbone", "projector", "llm_backbone"]]), "Missing keys!"
|
| 223 |
+
|
| 224 |
+
# Convert
|
| 225 |
+
print("[*] Running Conversion")
|
| 226 |
+
converted_state_dict = remap_state_dicts_for_hf(
|
| 227 |
+
model_state_dict["vision_backbone"],
|
| 228 |
+
model_state_dict["projector"],
|
| 229 |
+
model_state_dict["llm_backbone"],
|
| 230 |
+
use_fused_vision_backbone=hf_config.use_fused_vision_backbone,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM
|
| 234 |
+
print("[*] Building (Randomly Initialized) Model =>> OpenVLAForActionPrediction")
|
| 235 |
+
hf_model = OpenVLAForActionPrediction(hf_config)
|
| 236 |
+
hf_model.load_state_dict(converted_state_dict, strict=True, assign=True)
|
| 237 |
+
|
| 238 |
+
# Cast Model to BF16 before Saving
|
| 239 |
+
hf_model.to(torch.bfloat16)
|
| 240 |
+
|
| 241 |
+
# Save Pretrained Versions to Local Path
|
| 242 |
+
print("[*] Saving Model & Processor to Local Path")
|
| 243 |
+
hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB")
|
| 244 |
+
hf_image_processor.save_pretrained(cfg.output_hf_model_local_path)
|
| 245 |
+
hf_processor.save_pretrained(cfg.output_hf_model_local_path)
|
| 246 |
+
|
| 247 |
+
# Copy `dataset_statistics.json` File to Converted Checkpoint Directory
|
| 248 |
+
output_dataset_statistics_json = cfg.output_hf_model_local_path / "dataset_statistics.json"
|
| 249 |
+
shutil.copyfile(dataset_statistics_json, output_dataset_statistics_json)
|
| 250 |
+
|
| 251 |
+
print(f"[*] Saving Complete! Saved converted checkpoint to: {cfg.output_hf_model_local_path}")
|
| 252 |
+
|
| 253 |
+
#####################################################################################
|
| 254 |
+
# Optional: Push Model to Hugging Face Hub
|
| 255 |
+
#####################################################################################
|
| 256 |
+
|
| 257 |
+
# # Register AutoClasses
|
| 258 |
+
# OpenVLAConfig.register_for_auto_class()
|
| 259 |
+
# PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor")
|
| 260 |
+
# PrismaticProcessor.register_for_auto_class("AutoProcessor")
|
| 261 |
+
# OpenVLAForActionPrediction.register_for_auto_class("AutoModelForVision2Seq")
|
| 262 |
+
|
| 263 |
+
# # Push to HF Hub
|
| 264 |
+
# print("[*] Pushing Model & Processor to HF Hub")
|
| 265 |
+
# hf_config.push_to_hub(cfg.output_hf_model_hub_path)
|
| 266 |
+
# hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB")
|
| 267 |
+
# hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
|
| 268 |
+
# hf_processor.push_to_hub(cfg.output_hf_model_hub_path)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
|
| 272 |
+
convert_openvla_weights_to_hf()
|
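The conversion above is mostly systematic key renaming inside `remap_state_dicts_for_hf` (e.g., `llm.` becomes `language_model.`, `featurizer.` becomes `vision_backbone.featurizer.`). A minimal sketch of that prefix-remapping pattern follows; the example key is made up for illustration.

# Minimal prefix-remapping sketch in the spirit of `remap_state_dicts_for_hf`;
# the example key below is made up for illustration.
from typing import Dict

import torch

def remap_prefix(state_dict: Dict[str, torch.Tensor], old: str, new: str) -> Dict[str, torch.Tensor]:
    """Return a copy of `state_dict` with the first occurrence of `old` in each key replaced by `new`."""
    return {k.replace(old, new, 1): v for k, v in state_dict.items()}

if __name__ == "__main__":
    llm_sd = {"llm.model.layers.0.self_attn.q_proj.weight": torch.zeros(1)}
    hf_sd = remap_prefix(llm_sd, "llm.", "language_model.")
    assert "language_model.model.layers.0.self_attn.q_proj.weight" in hf_sd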