diff --git a/.gitattributes b/.gitattributes index 82e33954dca281a89464c0fc6cc634f58a761794..597c33bb0e439485785b31c3a29614404b5156ad 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2407,3 +2407,4 @@ wandb/wandb/run-20260419_111433-oh7yfg1j/run-oh7yfg1j.wandb filter=lfs diff=lfs 0422_QwenLatent_13tasks_stateactionprior_50k/videos/pytorch_model/n_action_steps_10_max_episode_steps_720_n_envs_1_gr1_unified/PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_Env/eb75ba54-de01-4932-90a2-36f3cdeaefaa_success0.mp4 filter=lfs diff=lfs merge=lfs -text 0422_QwenLatent_13tasks_stateactionprior_50k/videos/pytorch_model/n_action_steps_10_max_episode_steps_720_n_envs_1_gr1_unified/PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_Env/fb67ee40-aa3d-4616-b022-a36af1f3d7d0_success1.mp4 filter=lfs diff=lfs merge=lfs -text 0422_QwenLatent_13tasks_stateactionprior_50k/videos/pytorch_model/n_action_steps_10_max_episode_steps_720_n_envs_1_gr1_unified/PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_Env/fcdb3d2a-4bc1-4a9b-ae8e-7d9900cea828_success1.mp4 filter=lfs diff=lfs merge=lfs -text +code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/code/__init__.py b/code/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/code/__pycache__/__init__.cpython-310.pyc b/code/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..666ad4051bce5456714e8bce666b79f1ab77b6e8 Binary files /dev/null and b/code/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/__pycache__/__init__.cpython-311.pyc b/code/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d9c8b0ed79534d2dba76a2641881c74481780d5 Binary files /dev/null and b/code/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/config/deepseeds/deepspeed_zero2.yaml b/code/config/deepseeds/deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ccb145ffa7a8995655e497b356ecfec8e416027 --- /dev/null +++ b/code/config/deepseeds/deepspeed_zero2.yaml @@ -0,0 +1,9 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_config_file: "./starVLA/config/deepseeds/ds_config.yaml" + deepspeed_multinode_launcher: standard + zero3_init_flag: false +distributed_type: DEEPSPEED +num_machines: 1 +num_processes: 8 \ No newline at end of file diff --git a/code/config/deepseeds/deepspeed_zero3.yaml b/code/config/deepseeds/deepspeed_zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1b7cca733e95616507bca81fe71a4c96530c19d --- /dev/null +++ b/code/config/deepseeds/deepspeed_zero3.yaml @@ -0,0 +1,7 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_config_file: "./starVLA/config/deepseeds/zero3.yaml" + deepspeed_multinode_launcher: standard + zero3_init_flag: false +distributed_type: DEEPSPEED \ No newline at end of file diff --git a/code/config/deepseeds/ds_config.yaml b/code/config/deepseeds/ds_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d461f95290fe6574255262ead9e7dcbb81a245e --- /dev/null +++ b/code/config/deepseeds/ds_config.yaml @@ -0,0 +1,23 @@ +{ + "fp16": { + "enabled": false + }, + "bf16": { + "enabled": true + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + 
"gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "overlap_comm": true, + "contiguous_gradients": true, + "cpu_offload": false + }, + "gradient_clipping": 1.0, + "steps_per_print": 10 +} \ No newline at end of file diff --git a/code/config/deepseeds/zero2.yaml b/code/config/deepseeds/zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4df9d409dfc682961b2e9e970721a75cdd3667f --- /dev/null +++ b/code/config/deepseeds/zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/code/config/deepseeds/zero3.yaml b/code/config/deepseeds/zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..90a219f8070df25ee448e7d0f6a850c1963644a6 --- /dev/null +++ b/code/config/deepseeds/zero3.yaml @@ -0,0 +1,28 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": 5e8, + "stage3_prefetch_bucket_size": 5e8, + "stage3_param_persistence_threshold": 1e6, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + } +} \ No newline at end of file diff --git a/code/config/training/starvla_train_actionmodel_oxe.yaml b/code/config/training/starvla_train_actionmodel_oxe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca4db3ec78a74b91823e0129bc6fe95626a4089e --- /dev/null +++ b/code/config/training/starvla_train_actionmodel_oxe.yaml @@ -0,0 +1,85 @@ +run_id: vla_jepa_temp +run_root_dir: ./runs +seed: 21 +trackers: [jsonl, wandb] +wandb_entity: timsty +wandb_project: vla_jepa +is_debug: false + +framework: + name: ActionModelFM + action_model: + action_size: 37 + state_size: 74 + use_state: ${datasets.vla_data.state_use_action_chunk} + hidden_size: 1024 + intermediate_size: 3072 + dataset_vocab_size: 256 + num_data_tokens: 32 + mask_ratio_mode: "uniform_per_traj" + mask_ratio_min: 0.25 + mask_ratio_max: 0.75 + min_action_len: 5 + num_encoder_layers: 28 + num_decoder_layers: 28 + num_attention_heads: 16 + num_key_value_heads: 8 + head_dim: 128 + max_position_embeddings: 4096 + max_action_chunk_size: 50 + rms_norm_eps: 1.0e-6 + attention_dropout: 0.0 + # --- Action model loss mode (choose one combination) --- + use_masked_action_recon: false # true = add reconstruction loss for masked-action view (two-view training) + qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B + +datasets: + vla_data: + dataset_py: lerobot_datasets + data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY + data_mix: cross_embodiedment_13tasks + require_language: 
false + # action_type: delta_ee + default_image_resolution: [3, 224, 224] + per_device_batch_size: 256 + load_all_data_for_training: true + load_video: false + obs: ["image_0"] + image_size: [224,224] + video_backend: torchcodec + chunk_size: 15 + # state chunk aligned with action: state shape (L, state_dim) like action (L, action_dim) + state_use_action_chunk: true + +trainer: + epochs: 1000 + max_train_steps: 5000 + num_warmup_steps: 1000 + save_interval: 5000 + eval_interval: 50 + learning_rate: + base: 1e-04 + lr_scheduler_type: cosine_with_min_lr + scheduler_specific_kwargs: + min_lr: 5.0e-07 + freeze_modules: '' + loss_scale: + vla: 1.0 + warmup_ratio: 0.1 + weight_decay: 0.0 + logging_frequency: 10 + gradient_clipping: 5 + gradient_accumulation_steps: 1 + + optimizer: + name: AdamW + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + + # parameters to be determined + is_resume: false + resume_epoch: null + resume_step: null + enable_gradient_checkpointing: true + enable_mixed_precision_training: true diff --git a/code/config/training/starvla_train_pi0.yaml b/code/config/training/starvla_train_pi0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce88c0ea70c89294c6fb85b402a2004210ddca6f --- /dev/null +++ b/code/config/training/starvla_train_pi0.yaml @@ -0,0 +1,104 @@ +# PI0 training config - uses the unified 37D action representation +# The action/state projection layers (hard-coded to 32D in the original openpi) are automatically replaced with 37D when PI0Framework is initialized; +# the matching 32D parameters in the checkpoint are skipped at load time, and the remaining backbone parameters are reused as usual. + +run_id: pi0_unified_37d +run_root_dir: ./runs +seed: 42 +trackers: [jsonl, wandb] +wandb_entity: timsty +wandb_project: vla_jepa +is_debug: false + +framework: + name: PI0 + # PI0 model config + # action_dim follows this project (unified 37D action representation). + # In the PI0Pytorch source, action_in_proj / action_out_proj / state_proj are hard-coded to 32D; + # PI0Framework.__init__ calls _replace_pi0_projection_layers to replace them with 37D, + # and when the checkpoint is loaded these layers are skipped because of the shape mismatch (they keep their random initialization). + # The remaining VLM backbone layers (PaliGemma, action expert transformer, etc.) are still loaded from the checkpoint as usual. + pi0: + paligemma_variant: "gemma_2b" + action_expert_variant: "gemma_300m" + pi05: false + action_dim: 37 # project-wide unified dimension; the projection layers are replaced automatically and the corresponding checkpoint parameters are skipped at load time + state_dim: 74 # unified state dimension; state_proj is replaced with Linear(74, width), independent of action_dim + action_horizon: 15 # aligned with chunk_size + dtype: "bfloat16" + + # Pretrained weight path (pi05_libero etc.; when action_dim does not match, weights are partially loaded with strict=False) + pi0_checkpoint: /mnt/data/fangyu/model/openpi/openpi-assets/checkpoints/pi0_base_torch/model.pt + + # PaliGemma tokenizer + tokenizer_path: /root/.cache/openpi/big_vision/paligemma_tokenizer.model + + # Image key names matching the openpi three-view format; for single-view gr1, combine with replicate_single_view + image_keys: + - "base_0_rgb" + - "left_wrist_0_rgb" + - "right_wrist_0_rgb" + + # When the dataset provides only one image, replicate it to all 3 views (e.g. fourier_gr1 video.ego_view) + replicate_single_view: true + + use_state: true + + # If true, dynamically use the first N image_keys according to the actual number of images; otherwise always use all keys and zero-pad missing views + dynamic_image_keys: false + + num_inference_steps: 10 + + # Output truncation dimension; null means the full action_dim is output + effective_action_dim: null + +datasets: + vla_data: + dataset_py: lerobot_datasets + data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY + data_mix: cross_embodiedment_simulator + default_image_resolution: [3, 224, 224] + per_device_batch_size: 32 + load_all_data_for_training: true + obs: ["image_0"] + image_size: [224, 224] + video_backend: torchcodec + load_video: true + chunk_size: 15 + state_use_action_chunk: false + num_history_steps: 0 + include_state: false # state is not used when training PI0 + +trainer: + epochs: 100 + max_train_steps: 20000 + num_warmup_steps: 5000 + num_stable_steps: 0 +
save_interval: 5000 + max_checkpoints_to_keep: 20 + + learning_rate: + base: 2.5e-5 + pi0_model: 2.5e-5 + + lr_scheduler_type: warmup_stable_cosine + scheduler_specific_kwargs: + min_lr_ratio: 0.001 + + freeze_modules: "" + warmup_ratio: 0.1 + weight_decay: 0.0 + logging_frequency: 10 + gradient_clipping: 5.0 + gradient_accumulation_steps: 1 + + optimizer: + name: AdamW + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + + is_resume: false + pretrained_checkpoint: null + enable_gradient_checkpointing: false + enable_mixed_precision_training: true diff --git a/code/config/training/starvla_train_qwengr00t.yaml b/code/config/training/starvla_train_qwengr00t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..567a84de5ef4bd65ff9cc8783d690316aad9b1a1 --- /dev/null +++ b/code/config/training/starvla_train_qwengr00t.yaml @@ -0,0 +1,99 @@ +run_id: qwengr00t_oxe +run_root_dir: ./runs +seed: 42 +trackers: [jsonl, wandb] +wandb_entity: timsty +wandb_project: vla_jepa +is_debug: false + +framework: + name: QwenGR00T + qwenvl: + base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct + attn_implementation: flash_attention_2 + vl_hidden_dim: 2048 + num_data_tokens: 32 # dataset soft prompt tokens prepended to VLM input (0 = disabled) + + # QwenGR00T required action head config + action_model: + dataset_vocab_size: 256 # number of distinct dataset IDs for soft prompt embedding + action_model_type: DiT-B + hidden_size: 1024 + add_pos_embed: true + max_seq_len: 1024 + action_dim: 37 + state_dim: 74 + future_action_window_size: 14 + action_horizon: 15 + past_action_window_size: 0 + noise_beta_alpha: 1.5 + noise_beta_beta: 1.0 + noise_s: 0.999 + num_timestep_buckets: 1000 + num_inference_timesteps: 10 + num_target_vision_tokens: 32 + diffusion_model_cfg: + cross_attention_dim: 2048 + dropout: 0.2 + final_dropout: true + interleave_self_attention: true + norm_type: "ada_norm" + num_layers: 16 + output_dim: 1024 + positional_embeddings: null + +datasets: + vla_data: + dataset_py: lerobot_datasets + data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY + data_mix: cross_embodiedment_13tasks + CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?" 
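+ # Illustration (hypothetical instruction "pick up the red block"): the prompt renders as "Task: pick up the red block. What are the next 15 actions to take?"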
+ default_image_resolution: [3, 224, 224] + per_device_batch_size: 32 + load_all_data_for_training: true + obs: ["image_0"] + image_size: [224, 224] + video_backend: torchcodec + load_video: true + chunk_size: 15 + state_use_action_chunk: false + num_history_steps: 0 + include_state: true + +trainer: + epochs: 100 + max_train_steps: 50000 + num_warmup_steps: 5000 + num_stable_steps: 0 + save_interval: 5000 + eval_interval: 50 + max_checkpoints_to_keep: 20 + + # Used in QwenGR00T.forward() to repeat diffusion training pairs + repeated_diffusion_steps: 1 + + learning_rate: + base: 5e-05 + qwen_vl_interface: 5e-05 + action_model: 5e-05 + lr_scheduler_type: warmup_stable_cosine + scheduler_specific_kwargs: + min_lr_ratio: 0.001 + + freeze_modules: '' + warmup_ratio: 0.1 + logging_frequency: 10 + gradient_clipping: 5.0 + gradient_accumulation_steps: 4 + + optimizer: + name: AdamW + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + + is_resume: false + resume_epoch: null + resume_step: null + enable_gradient_checkpointing: true + enable_mixed_precision_training: true diff --git a/code/config/training/starvla_train_qwenlatent_history_naive_oxe.yaml b/code/config/training/starvla_train_qwenlatent_history_naive_oxe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30e55e36d0f16a649cabe7e23be6d61acc0ab36c --- /dev/null +++ b/code/config/training/starvla_train_qwenlatent_history_naive_oxe.yaml @@ -0,0 +1,106 @@ +run_id: vla_jepa_temp +run_root_dir: ./runs +seed: 42 +trackers: [jsonl, wandb] +wandb_entity: timsty +wandb_project: vla_jepa +is_debug: false + +framework: + # Naive baseline: history tokens are projected directly via two-layer MLPs + # (history_action_projector + history_state_projector) without any action + # encoder. Directly comparable to QwenLatent_history which uses the full + # action-encoder path for history encoding. + name: QwenLatent_history_naive + qwenvl: + base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct + attn_implementation: flash_attention_2 + vl_hidden_dim: 2048 + num_data_tokens: 32 + action_model: + ckpt_path: /mnt/data/fangyu/code/reward_new/runs/0417_Action_9tasks_actionstate_fixchunk15/final_model/pytorch_model.pt + # ckpt_path: null + action_size: 37 + state_size: 74 # consistent with the action model; 0 means state is not used + use_state: ${datasets.vla_data.state_use_action_chunk} + hidden_size: 1024 + intermediate_size: 3072 + dataset_vocab_size: 256 + num_data_tokens: 32 + min_action_len: 5 + num_encoder_layers: 28 + num_decoder_layers: 28 + num_attention_heads: 16 + num_key_value_heads: 8 + head_dim: 128 + max_position_embeddings: 2048 + max_action_chunk_size: 50 + rms_norm_eps: 1.0e-6 + attention_dropout: 0.0 + use_vae_reparameterization: false + use_ema: false # whether to use EMA; if false, the encoder is frozen and only the VLM and decoder are trained + chunk_size: ${datasets.vla_data.chunk_size} + loss_mode: full # full, predict_only + qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B +datasets: + vla_data: + dataset_py: lerobot_datasets + data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY + data_mix: cross_embodiedment_simulator # bridge_rt_1 + # action_type: delta_ee + CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+ default_image_resolution: [3, 224, 224] + per_device_batch_size: 32 + load_all_data_for_training: true + obs: ["image_0"] + image_size: [224,224] + video_backend: torchcodec + load_video: true + chunk_size: 30 + # state chunk aligned with the action chunk (consistent with action model training) + state_use_action_chunk: true + # number of history state/action steps; if >0, each sample additionally returns state_history and action_history + num_history_steps: 15 + include_state: ${datasets.vla_data.state_use_action_chunk} + +trainer: + epochs: 100 + max_train_steps: 30000 + num_warmup_steps: 3000 + num_stable_steps: 0 # number of steps to hold max_lr (after warmup) + mode: freeze_action_encoder_decay_aux_loss # freeze_action_encoder_decay_aux_loss + loss_weights_decay_steps: 5000 + + save_interval: 5000 + eval_interval: 50 + max_checkpoints_to_keep: 10 # maximum number of checkpoints to keep; older checkpoints beyond this are deleted + learning_rate: + base: 2.5e-05 + qwen_vl_interface: 2.5e-05 + action_model: 2.5e-05 + lr_scheduler_type: warmup_stable_cosine # options: warmup_stable_cosine (default), onecycle + scheduler_specific_kwargs: + min_lr_ratio: 0.001 # final lr = base_lr * min_lr_ratio + freeze_modules: '' + loss_scale: + align_loss: 1.0 + recon_loss: 1.0 + predict_loss: 1.0 + warmup_ratio: 0.1 + weight_decay: 0.0 + logging_frequency: 10 + gradient_clipping: 5.0 + gradient_accumulation_steps: 1 + + optimizer: + name: AdamW + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + + # parameters to be determined + is_resume: false + resume_epoch: null + resume_step: null + enable_gradient_checkpointing: true + enable_mixed_precision_training: true diff --git a/code/config/training/starvla_train_qwenlatent_history_oxe.yaml b/code/config/training/starvla_train_qwenlatent_history_oxe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13b338d6af1248941394de94eedb1b5de0ccfa86 --- /dev/null +++ b/code/config/training/starvla_train_qwenlatent_history_oxe.yaml @@ -0,0 +1,102 @@ +run_id: vla_jepa_temp +run_root_dir: ./runs +seed: 42 +trackers: [jsonl, wandb] +wandb_entity: timsty +wandb_project: vla_jepa +is_debug: false + +framework: + name: QwenLatent_history + qwenvl: + base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct + attn_implementation: flash_attention_2 + vl_hidden_dim: 2048 + num_data_tokens: 32 + action_model: + ckpt_path: /mnt/data/fangyu/code/reward_new/runs/0418_Action_13tasks_actionstate_fixchunk15/final_model/pytorch_model.pt + # ckpt_path: null + action_size: 37 + state_size: 74 # consistent with the action model; 0 means state is not used + use_state: ${datasets.vla_data.state_use_action_chunk} + hidden_size: 1024 + intermediate_size: 3072 + dataset_vocab_size: 256 + num_data_tokens: 32 + min_action_len: 5 + num_encoder_layers: 28 + num_decoder_layers: 28 + num_attention_heads: 16 + num_key_value_heads: 8 + head_dim: 128 + max_position_embeddings: 2048 + max_action_chunk_size: 50 + rms_norm_eps: 1.0e-6 + attention_dropout: 0.0 + use_vae_reparameterization: false + use_ema: false # whether to use EMA; if false, the encoder is frozen and only the VLM and decoder are trained + chunk_size: ${datasets.vla_data.chunk_size} + loss_mode: full # full, predict_only + qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B +datasets: + vla_data: + dataset_py: lerobot_datasets + data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY + data_mix: cross_embodiedment_simulator # bridge_rt_1 + # action_type: delta_ee + CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+ default_image_resolution: [3, 224, 224] + per_device_batch_size: 32 + load_all_data_for_training: true + obs: ["image_0"] + image_size: [224,224] + video_backend: torchcodec + load_video: true + chunk_size: 30 + # state chunk aligned with the action chunk (consistent with action model training) + state_use_action_chunk: true + # number of history state/action steps; if >0, each sample additionally returns state_history and action_history + num_history_steps: 15 + include_state: ${datasets.vla_data.state_use_action_chunk} + +trainer: + epochs: 100 + max_train_steps: 50000 + num_warmup_steps: 5000 + num_stable_steps: 0 # number of steps to hold max_lr (after warmup) + mode: freeze_action_encoder_decay_aux_loss # freeze_action_encoder_decay_aux_loss + loss_weights_decay_steps: 5000 + + save_interval: 5000 + eval_interval: 50 + max_checkpoints_to_keep: 10 # maximum number of checkpoints to keep; older checkpoints beyond this are deleted + learning_rate: + base: 3e-05 + qwen_vl_interface: 3e-05 + action_model: 3e-05 + lr_scheduler_type: warmup_stable_cosine # options: warmup_stable_cosine (default), onecycle + scheduler_specific_kwargs: + min_lr_ratio: 0.001 # final lr = base_lr * min_lr_ratio + freeze_modules: '' + loss_scale: + align_loss: 1.0 + recon_loss: 1.0 + predict_loss: 1.0 + warmup_ratio: 0.1 + weight_decay: 0.0 + logging_frequency: 10 + gradient_clipping: 5.0 + gradient_accumulation_steps: 1 + + optimizer: + name: AdamW + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + + # parameters to be determined + is_resume: false + resume_epoch: null + resume_step: null + enable_gradient_checkpointing: true + enable_mixed_precision_training: true diff --git a/code/config/training/starvla_train_qwenlatent_oxe.yaml b/code/config/training/starvla_train_qwenlatent_oxe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e01e7a9f3a36bdcf93232d08cbd9f0b459d2d38b --- /dev/null +++ b/code/config/training/starvla_train_qwenlatent_oxe.yaml @@ -0,0 +1,103 @@ +run_id: vla_jepa_temp +run_root_dir: ./runs +seed: 21 +trackers: [jsonl, wandb] +wandb_entity: timsty +wandb_project: vla_jepa +is_debug: false + +framework: + name: QwenLatent + qwenvl: + base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct + attn_implementation: flash_attention_2 + vl_hidden_dim: 2048 + num_data_tokens: 32 + action_model: + ckpt_path: /mnt/data/fangyu/code/reward_new/runs/0418_Action_13tasks_actionstate_fixchunk15/final_model/pytorch_model.pt + # ckpt_path: null + action_size: 37 + state_size: 74 # consistent with the action model; 0 means state is not used + use_state: ${datasets.vla_data.state_use_action_chunk} + hidden_size: 1024 + intermediate_size: 3072 + dataset_vocab_size: 256 + num_data_tokens: 32 + num_t_samples: 4 + min_action_len: 5 + num_encoder_layers: 28 + num_decoder_layers: 28 + num_attention_heads: 16 + num_key_value_heads: 8 + head_dim: 128 + max_position_embeddings: 2048 + max_action_chunk_size: 50 + rms_norm_eps: 1.0e-6 + attention_dropout: 0.0 + use_vae_reparameterization: false + use_ema: false # whether to use EMA; if false, the encoder is frozen and only the VLM and decoder are trained + chunk_size: ${datasets.vla_data.chunk_size} + loss_mode: full # full, predict_only + qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B +datasets: + vla_data: + dataset_py: lerobot_datasets + data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY + data_mix: cross_embodiedment_13tasks # bridge_rt_1 + # action_type: delta_ee + CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+ default_image_resolution: [3, 224, 224] + per_device_batch_size: 32 + load_all_data_for_training: true + obs: ["image_0"] + image_size: [224,224] + video_backend: torchcodec + load_video: true + chunk_size: 15 + # state chunk aligned with the action chunk (consistent with action model training) + state_use_action_chunk: true + # number of history state/action steps; if >0, each sample additionally returns state_history and action_history + num_history_steps: 0 + include_state: ${datasets.vla_data.state_use_action_chunk} + +trainer: + epochs: 100 + max_train_steps: 50000 + num_warmup_steps: 5000 + num_stable_steps: 0 # number of steps to hold max_lr (after warmup) + mode: decay_aux_loss # freeze_action_encoder_decay_aux_loss + loss_weights_decay_steps: 5000 + + save_interval: 5000 + eval_interval: 50 + max_checkpoints_to_keep: 20 # maximum number of checkpoints to keep; older checkpoints beyond this are deleted + learning_rate: + base: 5e-05 + qwen_vl_interface: 5e-05 + action_model: 5e-05 + lr_scheduler_type: warmup_stable_cosine # options: warmup_stable_cosine (default), onecycle + scheduler_specific_kwargs: + min_lr_ratio: 0.001 # final lr = base_lr * min_lr_ratio + freeze_modules: '' + loss_scale: + align_loss: 1.0 + recon_loss: 1.0 + predict_loss: 1.0 + warmup_ratio: 0.1 + weight_decay: 0.0 + logging_frequency: 10 + gradient_clipping: 5.0 + gradient_accumulation_steps: 1 + + optimizer: + name: AdamW + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + + # parameters to be determined + is_resume: false + resume_epoch: null + resume_step: null + enable_gradient_checkpointing: true + enable_mixed_precision_training: true diff --git a/code/config/training/starvla_train_qwenpi.yaml b/code/config/training/starvla_train_qwenpi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..973e067ab78e241cfb6be57f1c43bce663985e8d --- /dev/null +++ b/code/config/training/starvla_train_qwenpi.yaml @@ -0,0 +1,97 @@ +run_id: qwenpi_oxe +run_root_dir: ./runs +seed: 42 +trackers: [jsonl, wandb] +wandb_entity: timsty +wandb_project: vla_jepa +is_debug: false + +framework: + name: QwenPI + qwenvl: + base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct + attn_implementation: flash_attention_2 + vl_hidden_dim: 2048 + num_data_tokens: 32 # dataset soft prompt tokens prepended to VLM input (0 = disabled) + + # QwenPI required action head config (LayerwiseFlowmatchingActionHead) + action_model: + dataset_vocab_size: 256 # number of distinct dataset IDs for soft prompt embedding + hidden_size: 1024 + add_pos_embed: true + max_seq_len: 1024 + action_dim: 37 + state_dim: 74 + future_action_window_size: 14 + action_horizon: 15 + past_action_window_size: 0 + noise_beta_alpha: 1.5 + noise_beta_beta: 1.0 + noise_s: 0.999 + num_timestep_buckets: 1000 + num_inference_timesteps: 10 + num_target_vision_tokens: 32 + diffusion_model_cfg: + dropout: 0.2 + final_dropout: true + interleave_self_attention: true + norm_type: "ada_norm" + output_dim: 1024 + positional_embeddings: null + +datasets: + vla_data: + dataset_py: lerobot_datasets + data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY + data_mix: cross_embodiedment_13tasks + CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+ default_image_resolution: [3, 224, 224] + per_device_batch_size: 64 + load_all_data_for_training: true + obs: ["image_0"] + image_size: [224, 224] + video_backend: torchcodec + load_video: true + chunk_size: 15 + state_use_action_chunk: false + num_history_steps: 0 + include_state: true + +trainer: + epochs: 100 + max_train_steps: 50000 + num_warmup_steps: 5000 + num_stable_steps: 0 + save_interval: 5000 + eval_interval: 50 + max_checkpoints_to_keep: 20 + + # Used in QwenPI.forward() to repeat diffusion training pairs + repeated_diffusion_steps: 1 + + learning_rate: + base: 5e-05 + qwen_vl_interface: 5e-05 + action_model: 5e-05 + lr_scheduler_type: warmup_stable_cosine + scheduler_specific_kwargs: + min_lr_ratio: 0.001 + + freeze_modules: '' + warmup_ratio: 0.1 + weight_decay: 0.0 + logging_frequency: 10 + gradient_clipping: 5.0 + gradient_accumulation_steps: 1 + + optimizer: + name: AdamW + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + + is_resume: false + resume_epoch: null + resume_step: null + enable_gradient_checkpointing: true + enable_mixed_precision_training: true diff --git a/code/dataloader/__init__.py b/code/dataloader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1870aebb219b7d28dafd434b4cdd767ebd90f166 --- /dev/null +++ b/code/dataloader/__init__.py @@ -0,0 +1,70 @@ +import json +from accelerate.logging import get_logger +import numpy as np +from torch.utils.data import DataLoader +import torch.distributed as dist +from pathlib import Path +from starVLA.dataloader.vlm_datasets import make_vlm_dataloader + +logger = get_logger(__name__) + + +def _is_main_process() -> bool: + return (not dist.is_initialized()) or dist.get_rank() == 0 + +def save_dataset_statistics(dataset_statistics, run_dir): + """Saves a `dataset_statistics.json` file.""" + out_path = run_dir / "dataset_statistics.json" + with open(out_path, "w") as f_json: + for _, stats in dataset_statistics.items(): + for k in stats["action"].keys(): + if isinstance(stats["action"][k], np.ndarray): + stats["action"][k] = stats["action"][k].tolist() + if "proprio" in stats: + for k in stats["proprio"].keys(): + if isinstance(stats["proprio"][k], np.ndarray): + stats["proprio"][k] = stats["proprio"][k].tolist() + if "num_trajectories" in stats: + if isinstance(stats["num_trajectories"], np.ndarray): + stats["num_trajectories"] = stats["num_trajectories"].item() + if "num_transitions" in stats: + if isinstance(stats["num_transitions"], np.ndarray): + stats["num_transitions"] = stats["num_transitions"].item() + json.dump(dataset_statistics, f_json, indent=2) + logger.info(f"Saved dataset statistics file at path {out_path}") + + + +def build_dataloader(cfg, dataset_py="lerobot_datasets_oxe"): # TODO: currently this only builds the dataset; move the DataLoader construction here as well + + if dataset_py == "lerobot_datasets": + from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn + vla_dataset_cfg = cfg.datasets.vla_data + + vla_dataset = get_vla_dataset(data_cfg=vla_dataset_cfg) + + vla_train_dataloader = DataLoader( + vla_dataset, + batch_size=cfg.datasets.vla_data.per_device_batch_size, + collate_fn=collate_fn, + num_workers=16, + prefetch_factor=20, + shuffle=True, + persistent_workers=True, # keep workers alive to avoid restart overhead + pin_memory=True, # speed up host-to-GPU transfers + drop_last=True, # drop the last incomplete batch so training never waits on it + timeout=30, # time out blocked workers instead of waiting indefinitely + ) + if _is_main_process(): + output_dir = Path(cfg.output_dir) + vla_dataset.save_dataset_statistics(output_dir / "dataset_statistics.json") + return
vla_train_dataloader + if dataset_py == "vlm_datasets": + vlm_data_module = make_vlm_dataloader(cfg) + vlm_train_dataloader = vlm_data_module["train_dataloader"] + return vlm_train_dataloader + + raise ValueError( + f"Unsupported dataset builder `{dataset_py}`. " + "Expected one of: `lerobot_datasets`, `vlm_datasets`." + ) diff --git a/code/dataloader/__pycache__/__init__.cpython-310.pyc b/code/dataloader/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fe6363e9f643179befd76ca55687423d12b6755 Binary files /dev/null and b/code/dataloader/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/dataloader/__pycache__/__init__.cpython-311.pyc b/code/dataloader/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba328397a494874bfabaef753ce14fe361dee9b5 Binary files /dev/null and b/code/dataloader/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/dataloader/__pycache__/lerobot_datasets.cpython-310.pyc b/code/dataloader/__pycache__/lerobot_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5eabd173911b64da3cdc32278aa10705fd624e Binary files /dev/null and b/code/dataloader/__pycache__/lerobot_datasets.cpython-310.pyc differ diff --git a/code/dataloader/__pycache__/lerobot_datasets.cpython-311.pyc b/code/dataloader/__pycache__/lerobot_datasets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61261bc902e737e6e6f26f633f880e8ee4d6c6b2 Binary files /dev/null and b/code/dataloader/__pycache__/lerobot_datasets.cpython-311.pyc differ diff --git a/code/dataloader/__pycache__/vlm_datasets.cpython-310.pyc b/code/dataloader/__pycache__/vlm_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9af99a899d1385e6d109471908748e7bdb294e43 Binary files /dev/null and b/code/dataloader/__pycache__/vlm_datasets.cpython-310.pyc differ diff --git a/code/dataloader/__pycache__/vlm_datasets.cpython-311.pyc b/code/dataloader/__pycache__/vlm_datasets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1748c26dc9b43618d2c9ac68a9baa8fc7d88562 Binary files /dev/null and b/code/dataloader/__pycache__/vlm_datasets.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/README.md b/code/dataloader/gr00t_lerobot/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/code/dataloader/gr00t_lerobot/__init__.py b/code/dataloader/gr00t_lerobot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-310.pyc b/code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b90b023ae561868dd0b9a758ea561ecf1139210 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-311.pyc b/code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..068ca0a3e6497367f78a9186bf6a70eb7f52268c Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-310.pyc 
b/code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c518480a91f6ee1aee9eb012b1cc81c622a9e7a Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-311.pyc b/code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c499b555fed97e3a41956e16391be4d03dbc745 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-310.pyc b/code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017f240703418ce614e4f2d85380ae2ea60aa3ce Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-311.pyc b/code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91668b00ee19cb121649c3d3232bce1829d76b75 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:953210fc54145a3fac1026888bb1106cdd6351cb8d4eb9e94161669f67db91d9 +size 105575 diff --git a/code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-310.pyc b/code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7876723fc75719e26a2ae5add22fef079092f77f Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-311.pyc b/code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5af0a0294f04149816a7c5b9135463e7bbf06b5d Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-310.pyc b/code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b138cd14fa236b016d6ae5837e78fb225127a955 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-311.pyc b/code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..399d2341ed2c0facf3fe61d05498a86569724ee8 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-310.pyc b/code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22e876ba93a5085a3e175340d0c5761c7995f380 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-311.pyc b/code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e0fa0082cd0e111e638e0d9a6ca2505c4efa3151 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/video.cpython-310.pyc b/code/dataloader/gr00t_lerobot/__pycache__/video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bd8b945c6cf9f84e8f432eeba3ed83407921831 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/video.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/__pycache__/video.cpython-311.pyc b/code/dataloader/gr00t_lerobot/__pycache__/video.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c97f017a1ae13acc807df2e6d981ee781a32b1 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/__pycache__/video.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/data_config.py b/code/dataloader/gr00t_lerobot/data_config.py new file mode 100644 index 0000000000000000000000000000000000000000..99433c60385b51fe10a47e2ea1069a39d27335e8 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/data_config.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
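+# Orientation note: each per-robot data config below pairs a modality_config() (which keys and delta indices to sample for video/state/action/language) with a transform() (tensor conversion plus normalization), and get_robot_type_config_map() at the bottom of the file exposes one configured instance per robot type.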
+ +from abc import ABC, abstractmethod + +from starVLA.dataloader.gr00t_lerobot.datasets import ModalityConfig +from starVLA.dataloader.gr00t_lerobot.transform.base import ComposedModalityTransform, ModalityTransform +from starVLA.dataloader.gr00t_lerobot.transform.state_action import ( + StateActionSinCosTransform, + StateActionToTensor, + StateActionTransform, +) + + +class BaseDataConfig(ABC): + @abstractmethod + def modality_config(self) -> dict[str, ModalityConfig]: + pass + + @abstractmethod + def transform(self) -> ModalityTransform: + pass + + +########################################################################################### + +class Libero4in1DataConfig: + video_keys = [ + "video.primary_image", + "video.wrist_image", + ] + + state_keys = [ + "state.x", + "state.y", + "state.z", + "state.roll", + "state.pitch", + "state.yaw", + "state.pad", + "state.gripper", + ] + action_keys = [ + "action.x", + "action.y", + "action.z", + "action.roll", + "action.pitch", + "action.yaw", + "action.gripper", + ] + + language_keys = ["annotation.human.action.task_description"] + + observation_indices = [0] + action_indices = list(range(16)) + + def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0): + self.chunk_size = chunk_size + self.action_indices = list(range(chunk_size)) + self.state_use_action_chunk = state_use_action_chunk + self.num_history_steps = int(num_history_steps or 0) + self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1] + + def modality_config(self): + video_modality = ModalityConfig( + delta_indices=self.video_observation_indices, + modality_keys=self.video_keys, + ) + state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices + state_modality = ModalityConfig( + delta_indices=state_delta, + modality_keys=self.state_keys, + ) + action_modality = ModalityConfig( + delta_indices=self.action_indices, + modality_keys=self.action_keys, + ) + language_modality = ModalityConfig( + delta_indices=self.observation_indices, + modality_keys=self.language_keys, + ) + modality_configs = { + "video": video_modality, + "state": state_modality, + "action": action_modality, + "language": language_modality, + } + return modality_configs + + def transform(self): + transforms = [ + # state transforms + StateActionToTensor(apply_to=self.state_keys), + StateActionTransform( + apply_to=self.state_keys, + normalization_modes={ + "state.x": "min_max", + "state.y": "min_max", + "state.z": "min_max", + "state.roll": "min_max", + "state.pitch": "min_max", + "state.yaw": "min_max", + "state.pad": "min_max", + # "state.gripper": "binary", + }, + ), + # action transforms + StateActionToTensor(apply_to=self.action_keys), + StateActionTransform( + apply_to=self.action_keys, + normalization_modes={ + "action.x": "min_max", + "action.y": "min_max", + "action.z": "min_max", + "action.roll": "min_max", + "action.pitch": "min_max", + "action.yaw": "min_max", + # "action.gripper": "binary", + }, + ), + ] + + return ComposedModalityTransform(transforms=transforms) + +########################################################################################### + +class RealWorldFrankaDataConfig: + """Real-world Panda robot: 7 joints + 1 gripper (8D), single-arm -> right slot [7:15].""" + video_keys = [ + "video.exterior_image_1_left", + "video.wrist_image_left", + ] + state_keys = [ + "state.joints", + "state.gripper", + ] + action_keys = [ + 
"action.joints", + "action.gripper", + ] + language_keys = ["annotation.human.action.task_description"] + observation_indices = [0] + action_indices = list(range(16)) + + def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0): + self.chunk_size = chunk_size + self.action_indices = list(range(chunk_size)) + self.state_use_action_chunk = state_use_action_chunk + self.num_history_steps = int(num_history_steps or 0) + self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1] + + def modality_config(self): + video_modality = ModalityConfig( + delta_indices=self.video_observation_indices, + modality_keys=self.video_keys, + ) + state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices + state_modality = ModalityConfig( + delta_indices=state_delta, + modality_keys=self.state_keys, + ) + action_modality = ModalityConfig( + delta_indices=self.action_indices, + modality_keys=self.action_keys, + ) + language_modality = ModalityConfig( + delta_indices=self.observation_indices, + modality_keys=self.language_keys, + ) + modality_configs = { + "video": video_modality, + "state": state_modality, + "action": action_modality, + "language": language_modality, + } + return modality_configs + + def transform(self): + transforms = [ + StateActionToTensor(apply_to=self.state_keys), + StateActionTransform( + apply_to=self.state_keys, + normalization_modes={ + "state.joints": "min_max", + # "state.gripper": "binary", + }, + ), + StateActionToTensor(apply_to=self.action_keys), + StateActionTransform( + apply_to=self.action_keys, + normalization_modes={ + "action.joints": "min_max", + # "action.gripper": "binary", + }, + ), + ] + return ComposedModalityTransform(transforms=transforms) + + +class AgilexDataConfig: + video_keys = [ + "video.cam_high", + "video.cam_left_wrist", + "video.cam_right_wrist", + ] + state_keys = [ + "state.left_joints", + "state.left_gripper", + "state.right_joints", + "state.right_gripper", + ] + action_keys = [ + "action.left_joints", + "action.left_gripper", + "action.right_joints", + "action.right_gripper", + ] + + language_keys = ["annotation.human.action.task_description"] + observation_indices = [0] + + def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0): + self.chunk_size = chunk_size + self.action_indices = list(range(chunk_size)) + self.state_use_action_chunk = state_use_action_chunk + self.num_history_steps = int(num_history_steps or 0) + self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1] + + def modality_config(self): + video_modality = ModalityConfig( + delta_indices=self.video_observation_indices, + modality_keys=self.video_keys, + ) + state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices + state_modality = ModalityConfig( + delta_indices=state_delta, + modality_keys=self.state_keys, + ) + action_modality = ModalityConfig( + delta_indices=self.action_indices, + modality_keys=self.action_keys, + ) + language_modality = ModalityConfig( + delta_indices=self.observation_indices, + modality_keys=self.language_keys, + ) + modality_configs = { + "video": video_modality, + "state": state_modality, + "action": action_modality, + "language": language_modality, + } + return modality_configs + + def transform(self): + transforms = [ + # state transforms + 
StateActionToTensor(apply_to=self.state_keys), + StateActionTransform( + apply_to=self.state_keys, + normalization_modes={ + "state.left_joints": "min_max", + "state.left_gripper": "binary", + "state.right_joints": "min_max", + "state.right_gripper": "binary", + }, + ), + # action transforms + StateActionToTensor(apply_to=self.action_keys), + StateActionTransform( + apply_to=self.action_keys, + normalization_modes={ + "action.left_joints": "min_max", + "action.left_gripper": "binary", + "action.right_joints": "min_max", + "action.right_gripper": "binary", + }, + ), + ] + return ComposedModalityTransform(transforms=transforms) + + +class FourierGr1ArmsWaistDataConfig: + video_keys = ["video.ego_view"] + state_keys = [ + "state.left_arm", + "state.right_arm", + "state.left_hand", + "state.right_hand", + "state.waist", + ] + action_keys = [ + "action.left_arm", + "action.right_arm", + "action.left_hand", + "action.right_hand", + "action.waist", + ] + language_keys = ["annotation.human.coarse_action"] + observation_indices = [0] + + def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0): + self.chunk_size = chunk_size + self.action_indices = list(range(chunk_size)) + self.state_use_action_chunk = state_use_action_chunk + self.num_history_steps = int(num_history_steps or 0) + self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1] + + def modality_config(self): + video_modality = ModalityConfig( + delta_indices=self.video_observation_indices, + modality_keys=self.video_keys, + ) + state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices + state_modality = ModalityConfig( + delta_indices=state_delta, + modality_keys=self.state_keys, + ) + action_modality = ModalityConfig( + delta_indices=self.action_indices, + modality_keys=self.action_keys, + ) + language_modality = ModalityConfig( + delta_indices=self.observation_indices, + modality_keys=self.language_keys, + ) + modality_configs = { + "video": video_modality, + "state": state_modality, + "action": action_modality, + "language": language_modality, + } + return modality_configs + + def transform(self) -> ModalityTransform: + transforms = [ + # state transforms + StateActionToTensor(apply_to=self.state_keys), + StateActionSinCosTransform(apply_to=self.state_keys), + # action transforms + StateActionToTensor(apply_to=self.action_keys), + StateActionTransform( + apply_to=self.action_keys, + normalization_modes={key: "min_max" for key in self.action_keys}, + ), + ] + return ComposedModalityTransform(transforms=transforms) + +########################################################################################### + + +def get_robot_type_config_map( + chunk_size: int = 15, + state_use_action_chunk: bool = True, + num_history_steps: int = 0, +) -> dict[str, BaseDataConfig]: + """state_use_action_chunk: when True, state uses action_indices so state has shape (L, state_dim) aligned with action chunk.""" + return { + "libero_franka": Libero4in1DataConfig( + chunk_size=chunk_size, + state_use_action_chunk=state_use_action_chunk, + num_history_steps=num_history_steps, + ), + "robotwin": AgilexDataConfig( + chunk_size=chunk_size, + state_use_action_chunk=state_use_action_chunk, + num_history_steps=num_history_steps, + ), + "fourier_gr1_arms_waist": FourierGr1ArmsWaistDataConfig( + chunk_size=chunk_size, + state_use_action_chunk=state_use_action_chunk, + num_history_steps=num_history_steps, + ), + 
"real_world_franka": RealWorldFrankaDataConfig( + chunk_size=chunk_size, + state_use_action_chunk=state_use_action_chunk, + num_history_steps=num_history_steps, + ), + } diff --git a/code/dataloader/gr00t_lerobot/datasets.py b/code/dataloader/gr00t_lerobot/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..de7275a30abf6d260defb55273b93a2bf8d9dade --- /dev/null +++ b/code/dataloader/gr00t_lerobot/datasets.py @@ -0,0 +1,2165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +In this file, we define 3 types of datasets: +1. LeRobotSingleDataset: a single dataset for a given embodiment tag +2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags +3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag, + with caching for the video frames + +See `scripts/load_dataset.py` for examples on how to use these datasets. +""" +import os +import hashlib +import json, torch +from collections import defaultdict +from pathlib import Path +from typing import Sequence +import os, random +import numpy as np +import pandas as pd +from pydantic import BaseModel, Field, ValidationError +from torch.utils.data import Dataset +from tqdm import tqdm +from PIL import Image + +from starVLA.dataloader.gr00t_lerobot.video import get_all_frames, get_frames_by_timestamps + +from starVLA.dataloader.gr00t_lerobot.embodiment_tags import EmbodimentTag, DATASET_NAME_TO_ID +from starVLA.dataloader.gr00t_lerobot.schema import ( + DatasetMetadata, + DatasetStatisticalValues, + LeRobotModalityMetadata, + LeRobotStateActionMetadata, +) +from starVLA.dataloader.gr00t_lerobot.transform import ComposedModalityTransform + +from functools import partial +from typing import Tuple, List +import pickle + +# LeRobot v2.0 dataset file names +LE_ROBOT_MODALITY_FILENAME = "meta/modality.json" +LE_ROBOT_EPISODE_FILENAME = "meta/episodes.jsonl" +LE_ROBOT_TASKS_FILENAME = "meta/tasks.jsonl" +LE_ROBOT_INFO_FILENAME = "meta/info.json" +LE_ROBOT_STATS_FILENAME = "meta/stats_gr00t.json" +LE_ROBOT_DATA_FILENAME = "data/*/*.parquet" +LE_ROBOT_STEPS_FILENAME = "meta/steps.pkl" +EPSILON = 5e-4 + +# LeRobot v3.0 dataset file names +LE_ROBOT3_TASKS_FILENAME = "meta/tasks.parquet" +LE_ROBOT3_EPISODE_FILENAME = "meta/episodes/*/*.parquet" + + +# ============================================================================= +# Unified Representation Layout & Helpers +# ============================================================================= + +STANDARD_ACTION_DIM = 37 +# +# Unified action representation layout (0-based indices, Python slice is [start, stop)): +# Keep only: libero_franka, gr1, real_world_franka. 
+# +# - 0:7 -> left_arm (7D): xyz, rpy/euler, gripper +# Used by: gr1 left_arm +# - 7:14 -> right_arm (7D): same structure +# Used by: libero_franka; gr1 right_arm +# - 14:20 -> left_hand (6D): gr1 only +# - 20:26 -> right_hand (6D): gr1 only +# - 26:29 -> waist (3D): gr1 only +# - 29:37 -> joints + gripper (8D): real_world_franka only +# +# Mapping: +# libero_franka (7D) -> [7:14] (right_arm slot) +# gr1 (29D) -> [0:29] +# real_world_franka (8D) -> [29:37] (joints + gripper) + +ACTION_REPRESENTATION_SLICES = { + # Single-arm (7D) -> right_arm slot [7:14] + "franka": slice(7, 14), + + # Humanoid (29D) -> full [0:29] + "gr1": slice(0, 29), + + # Real-world (8D) -> [29:37] (joints + gripper) + "real_world_franka": slice(29, 37), +} + +STANDARD_STATE_DIM = 74 +# Mapping: +# libero_franka (8D) -> [0:8] +# real_world_franka (8D) -> [8:16] +# gr1 (58D after sin/cos) -> [16:74] + +STATE_REPRESENTATION_SLICES = { + # Single-arm (8D) + "franka": slice(0, 8), + # Real-world (8D) + "real_world_franka": slice(8, 16), + # GR1 isolated (58D, has StateActionSinCosTransform - different pipeline) + "gr1": slice(16, 74), +} + + +def standardize_action_representation( + action: np.ndarray, embodiment_tag: str +) -> np.ndarray: + """Map per-robot action to a fixed-size standard action vector.""" + target_slice = ACTION_REPRESENTATION_SLICES.get(embodiment_tag) + + # Only allow explicitly configured embodiment tags. + if target_slice is None: + raise ValueError( + f"Unknown embodiment tag '{embodiment_tag}' for action mapping. " + f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES)}" + ) + + expected_dim = target_slice.stop - target_slice.start + if action.shape[-1] != expected_dim: + raise ValueError( + f"Action dim mismatch for tag '{embodiment_tag}': " + f"{action.shape[-1]=} vs expected {expected_dim}." + ) + + standard = np.zeros( + (*action.shape[:-1], STANDARD_ACTION_DIM), dtype=action.dtype + ) + standard[..., target_slice] = action + return standard + + +def standardize_state_representation( + state: np.ndarray, embodiment_tag: str +) -> np.ndarray: + """Map per-robot state to a fixed-size standard state vector.""" + + target_slice = STATE_REPRESENTATION_SLICES.get(embodiment_tag) + + # Only allow explicitly configured embodiment tags. + if target_slice is None: + raise ValueError( + f"Unknown embodiment tag '{embodiment_tag}' for state mapping. " + f"Known tags: {sorted(STATE_REPRESENTATION_SLICES)}" + ) + + expected_dim = target_slice.stop - target_slice.start + if state.shape[-1] != expected_dim: + raise ValueError( + f"State dim mismatch for tag '{embodiment_tag}': " + f"{state.shape[-1]=} vs expected {expected_dim}." 
+ ) + + standard = np.zeros( + (*state.shape[:-1], STANDARD_STATE_DIM), dtype=state.dtype + ) + standard[..., target_slice] = state + return standard + + +def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict: + """Calculate the dataset statistics of all columns for a list of parquet files.""" + # Dataset statistics + all_low_dim_data_list = [] + # Collect all the data + # parquet_paths = parquet_paths[:3] + for parquet_path in tqdm( + sorted(list(parquet_paths)), + desc="Collecting all parquet files...", + ): + # Load the parquet file + parquet_data = pd.read_parquet(parquet_path) + parquet_data = parquet_data + all_low_dim_data_list.append(parquet_data) + + all_low_dim_data = pd.concat(all_low_dim_data_list, axis=0) + # Compute dataset statistics + dataset_statistics = {} + for le_modality in all_low_dim_data.columns: + if le_modality.startswith("annotation."): + continue + print(f"Computing statistics for {le_modality}...") + np_data = np.vstack( + [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]] + ) + dataset_statistics[le_modality] = { + "mean": np.mean(np_data, axis=0).tolist(), + "std": np.std(np_data, axis=0).tolist(), + "min": np.min(np_data, axis=0).tolist(), + "max": np.max(np_data, axis=0).tolist(), + "q01": np.quantile(np_data, 0.01, axis=0).tolist(), + "q99": np.quantile(np_data, 0.99, axis=0).tolist(), + } + return dataset_statistics + + +class ModalityConfig(BaseModel): + """Configuration for a modality.""" + + delta_indices: list[int] + """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices.""" + modality_keys: list[str] + """The keys to load for the modality in the dataset.""" + + +class LeRobotSingleDataset(Dataset): + """ + Base dataset class for LeRobot that supports sharding. + """ + def __init__( + self, + dataset_path: Path | str, + modality_configs: dict[str, ModalityConfig], + embodiment_tag: str | EmbodimentTag, + video_backend: str = "decord", + video_backend_kwargs: dict | None = None, + transforms: ComposedModalityTransform | None = None, + delete_pause_frame: bool = False, + **kwargs, + ): + """ + Initialize the dataset. + + Args: + dataset_path (Path | str): The path to the dataset. + modality_configs (dict[str, ModalityConfig]): The configuration for each modality. The keys are the modality names, and the values are the modality configurations. + See `ModalityConfig` for more details. + video_backend (str): Backend for video reading. + video_backend_kwargs (dict): Keyword arguments for the video backend when initializing the video reader. + transforms (ComposedModalityTransform): The transforms to apply to the dataset. + embodiment_tag (EmbodimentTag): Overload the embodiment tag for the dataset. e.g. define it as "new_embodiment" + """ + # first check if the path directory exists + if not Path(dataset_path).exists(): + raise FileNotFoundError(f"Dataset path {dataset_path} does not exist") + data_cfg = kwargs.get("data_cfg", {}) or {} + # indict letobot version + self._lerobot_version = data_cfg.get("lerobot_version", "v2.0") #self._indict_lerobot_version(**kwargs) + self.load_video = data_cfg.get("load_video", True) + self.num_history_steps = int(data_cfg.get("num_history_steps", 0) or 0) + + self.delete_pause_frame = delete_pause_frame + + # If video loading is disabled, skip video modality end-to-end. 
+ if self.load_video: + self.modality_configs = modality_configs + else: + self.modality_configs = { + modality: config + for modality, config in modality_configs.items() + if modality != "video" + } + self.video_backend = video_backend + self.video_backend_kwargs = video_backend_kwargs if video_backend_kwargs is not None else {} + self.transforms = ( + transforms if transforms is not None else ComposedModalityTransform(transforms=[]) + ) + + self._dataset_path = Path(dataset_path) + self._dataset_name = self._dataset_path.name + self._dataset_id = DATASET_NAME_TO_ID.get(self._dataset_name) + if isinstance(embodiment_tag, EmbodimentTag): + self.tag = embodiment_tag.value + else: + self.tag = embodiment_tag + + self._metadata = self._get_metadata(EmbodimentTag(self.tag)) + + # LeRobot-specific config + self._lerobot_modality_meta = self._get_lerobot_modality_meta() + self._lerobot_info_meta = self._get_lerobot_info_meta() + self._data_path_pattern = self._get_data_path_pattern() + self._video_path_pattern = self._get_video_path_pattern() + self._chunk_size = self._get_chunk_size() + self._tasks = self._get_tasks() + self.curr_traj_data = None + self.curr_traj_id = None + + self._trajectory_ids, self._trajectory_lengths = self._get_trajectories() + self._modality_keys = self._get_modality_keys() + self._delta_indices = self._get_delta_indices() + self._all_steps = self._get_all_steps() + self.set_transforms_metadata(self.metadata) + self.set_epoch(0) + + print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}") + + + # Check if the dataset is valid + self._check_integrity() + + @property + def dataset_path(self) -> Path: + """The path to the dataset that contains the METADATA_FILENAME file.""" + return self._dataset_path + + @property + def metadata(self) -> DatasetMetadata: + """The metadata for the dataset, loaded from metadata.json in the dataset directory""" + return self._metadata + + @property + def trajectory_ids(self) -> np.ndarray: + """The trajectory IDs in the dataset, stored as a 1D numpy array of strings.""" + return self._trajectory_ids + + @property + def trajectory_lengths(self) -> np.ndarray: + """The trajectory lengths in the dataset, stored as a 1D numpy array of integers. + The order of the lengths is the same as the order of the trajectory IDs. + """ + return self._trajectory_lengths + + @property + def all_steps(self) -> list[tuple[int, int]]: + """The trajectory IDs and base indices for all steps in the dataset. + Example: + self.trajectory_ids: [0, 1, 2] + self.trajectory_lengths: [3, 2, 4] + return: [ + ("traj_0", 0), ("traj_0", 1), ("traj_0", 2), + ("traj_1", 0), ("traj_1", 1), + ("traj_2", 0), ("traj_2", 1), ("traj_2", 2), ("traj_2", 3) + ] + """ + return self._all_steps + + @property + def modality_keys(self) -> dict: + """The modality keys for the dataset. The keys are the modality names, and the values are the keys for each modality. + + Example: { + "video": ["video.image_side_0", "video.image_side_1"], + "state": ["state.eef_position", "state.eef_rotation"], + "action": ["action.eef_position", "action.eef_rotation"], + "language": ["language.human.task"], + "timestamp": ["timestamp"], + "reward": ["reward"], + } + """ + return self._modality_keys + + @property + def delta_indices(self) -> dict[str, np.ndarray]: + """The delta indices for the dataset. 
The keys are the modality.key, and the values are the delta indices for each modality.key.""" + return self._delta_indices + + @property + def dataset_name(self) -> str: + """The name of the dataset.""" + return self._dataset_name + + @property + def lerobot_modality_meta(self) -> LeRobotModalityMetadata: + """The metadata for the LeRobot dataset.""" + return self._lerobot_modality_meta + + @property + def lerobot_info_meta(self) -> dict: + """The metadata for the LeRobot dataset.""" + return self._lerobot_info_meta + + @property + def data_path_pattern(self) -> str: + """The path pattern for the LeRobot dataset.""" + return self._data_path_pattern + + @property + def video_path_pattern(self) -> str: + """The path pattern for the LeRobot dataset.""" + return self._video_path_pattern + + @property + def chunk_size(self) -> int: + """The chunk size for the LeRobot dataset.""" + return self._chunk_size + + @property + def tasks(self) -> pd.DataFrame: + """The tasks for the dataset.""" + return self._tasks + + def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata: + """Get the metadata for the dataset. + + Returns: + dict: The metadata for the dataset. + """ + + # 1. Modality metadata + modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME + assert ( + modality_meta_path.exists() + ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}" + # 1.1. State and action modalities + simplified_modality_meta: dict[str, dict] = {} + with open(modality_meta_path, "r") as f: + le_modality_meta = LeRobotModalityMetadata.model_validate(json.load(f)) + for modality in ["state", "action"]: + simplified_modality_meta[modality] = {} + le_state_action_meta: dict[str, LeRobotStateActionMetadata] = getattr( + le_modality_meta, modality + ) + for subkey in le_state_action_meta: + state_action_dtype = np.dtype(le_state_action_meta[subkey].dtype) + if np.issubdtype(state_action_dtype, np.floating): + continuous = True + else: + continuous = False + simplified_modality_meta[modality][subkey] = { + "absolute": le_state_action_meta[subkey].absolute, + "rotation_type": le_state_action_meta[subkey].rotation_type, + "shape": [ + le_state_action_meta[subkey].end - le_state_action_meta[subkey].start + ], + "continuous": continuous, + } + + # 1.2. Video modalities + le_info_path = self.dataset_path / LE_ROBOT_INFO_FILENAME + assert ( + le_info_path.exists() + ), f"Please provide a {LE_ROBOT_INFO_FILENAME} file in {self.dataset_path}" + with open(le_info_path, "r") as f: + le_info = json.load(f) + simplified_modality_meta["video"] = {} + for new_key in le_modality_meta.video: + original_key = le_modality_meta.video[new_key].original_key + if original_key is None: + original_key = new_key + le_video_meta = le_info["features"][original_key] + height = le_video_meta["shape"][le_video_meta["names"].index("height")] + width = le_video_meta["shape"][le_video_meta["names"].index("width")] + # NOTE(FH): different lerobot dataset versions have different keys for the number of channels and fps + try: + channels = le_video_meta["shape"][le_video_meta["names"].index("channel")] + fps = le_video_meta["video_info"]["video.fps"] + except (ValueError, KeyError): + # channels = le_video_meta["shape"][le_video_meta["names"].index("channels")] + channels = le_video_meta["info"]["video.channels"] + fps = le_video_meta["info"]["video.fps"] + simplified_modality_meta["video"][new_key] = { + "resolution": [width, height], + "channels": channels, + "fps": fps, + } + + # 2. 
Dataset statistics + stats_path = self.dataset_path / LE_ROBOT_STATS_FILENAME + try: + with open(stats_path, "r") as f: + le_statistics = json.load(f) + for stat in le_statistics.values(): + DatasetStatisticalValues.model_validate(stat) + except (FileNotFoundError, ValidationError) as e: + print(f"Failed to load dataset statistics: {e}") + print(f"Calculating dataset statistics for {self.dataset_name}") + # Get all parquet files in the dataset paths + parquet_files = list((self.dataset_path).glob(LE_ROBOT_DATA_FILENAME)) + parquet_files_filtered = [] + # parquet_files[0].name = "episode_033675.parquet" is broken file + for pf in parquet_files: + if "episode_033675.parquet" in pf.name: + continue + parquet_files_filtered.append(pf) + + le_statistics = calculate_dataset_statistics(parquet_files_filtered) + with open(stats_path, "w") as f: + json.dump(le_statistics, f, indent=4) + dataset_statistics = {} + for our_modality in ["state", "action"]: + dataset_statistics[our_modality] = {} + for subkey in simplified_modality_meta[our_modality]: + dataset_statistics[our_modality][subkey] = {} + state_action_meta = le_modality_meta.get_key_meta(f"{our_modality}.{subkey}") + assert isinstance(state_action_meta, LeRobotStateActionMetadata) + le_modality = state_action_meta.original_key + for stat_name in le_statistics[le_modality]: + indices = np.arange( + state_action_meta.start, + state_action_meta.end, + ) + stat = np.array(le_statistics[le_modality][stat_name]) + dataset_statistics[our_modality][subkey][stat_name] = stat[indices].tolist() + + # 3. Full dataset metadata + metadata = DatasetMetadata( + statistics=dataset_statistics, # type: ignore + modalities=simplified_modality_meta, # type: ignore + embodiment_tag=embodiment_tag, + ) + + return metadata + + def _get_trajectories(self) -> tuple[np.ndarray, np.ndarray]: + """Get the trajectories in the dataset.""" + # Get trajectory lengths, IDs, and whitelist from dataset metadata + # v2.0 + if self._lerobot_version == "v2.0": + file_path = self.dataset_path / LE_ROBOT_EPISODE_FILENAME + with open(file_path, "r") as f: + episode_metadata = [json.loads(line) for line in f] + trajectory_ids = [] + trajectory_lengths = [] + for episode in episode_metadata: + trajectory_ids.append(episode["episode_index"]) + trajectory_lengths.append(episode["length"]) + return np.array(trajectory_ids), np.array(trajectory_lengths) + # v3.0 + elif self._lerobot_version == "v3.0": + file_paths = list((self.dataset_path).glob(LE_ROBOT3_EPISODE_FILENAME)) + trajectory_ids = [] + trajectory_lengths = [] + # data_chunck_index = [] + # data_file_index = [] + # vido_from_index = [] + self.trajectory_ids_to_metadata = {} + for file_path in file_paths: + episodes_data = pd.read_parquet(file_path) + for index, episode in episodes_data.iterrows(): + trajectory_ids.append(episode["episode_index"]) + trajectory_lengths.append(episode["length"]) + + # TODO auto map key? 
just map to file_path and file_from_index + episode_meta = { + "data/chunk_index": episode["data/chunk_index"], + "data/file_index": episode["data/file_index"], + "data/file_from_index": index, + } + if self.load_video: + episode_meta["videos/observation.images.wrist/from_timestamp"] = episode[ + "videos/observation.images.wrist/from_timestamp" + ] + self.trajectory_ids_to_metadata[trajectory_ids[-1]] = episode_meta + + # 这里应该可以直接读取到 save index 信息 + return np.array(trajectory_ids), np.array(trajectory_lengths) + + def _get_all_steps(self) -> list[tuple[int, int]]: + """Get the trajectory IDs and base indices for all steps in the dataset. + + Returns: + list[tuple[str, int]]: A list of (trajectory_id, base_index) tuples. + """ + # Create a hash key based on configuration to ensure cache validity + config_key = self._get_steps_config_key() + + # Create a unique filename based on config_key + # steps_filename = f"steps_{config_key}.pkl" + # @BUG + # fast get static steps @fangjing --> don't use hash to dynamic sample + steps_filename = "steps_data_index.pkl" + + + steps_path = self.dataset_path / "meta" / steps_filename + + # Try to load cached steps first + try: + if steps_path.exists(): + with open(steps_path, "rb") as f: + cached_data = pickle.load(f) + return cached_data["steps"] + + except (FileNotFoundError, pickle.PickleError, KeyError) as e: + print(f"Failed to load cached steps: {e}") + print("Computing steps from scratch...") + + # Compute steps using single process + all_steps = self._get_all_steps_single_process() + + # Cache the computed steps with unique filename + try: + cache_data = { + "config_key": config_key, + "steps": all_steps, + "num_trajectories": len(self.trajectory_ids), + "total_steps": len(all_steps), + "computed_timestamp": pd.Timestamp.now().isoformat(), + "delete_pause_frame": self.delete_pause_frame, + } + + # Ensure the meta directory exists + steps_path.parent.mkdir(parents=True, exist_ok=True) + + with open(steps_path, "wb") as f: + pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL) + print(f"Cached steps saved to {steps_path}") + except Exception as e: + print(f"Failed to cache steps: {e}") + + return all_steps + + def _get_steps_config_key(self) -> str: + """Generate a configuration key for steps caching.""" + config_dict = { + "delete_pause_frame": self.delete_pause_frame, + "dataset_name": self.dataset_name, + } + # Create a hash of the configuration + config_str = str(sorted(config_dict.items())) + return hashlib.md5(config_str.encode()).hexdigest()[:12] # + + + def _get_all_steps_single_process(self) -> list[tuple[int, int]]: + """Original single-process implementation as fallback.""" + all_steps: list[tuple[int, int]] = [] + skipped_trajectories = 0 + processed_trajectories = 0 + + # Check if language modality is configured + has_language_modality = 'language' in self.modality_keys and len(self.modality_keys['language']) > 0 + # TODO why trajectory_length here, why not use data length? 
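        # Illustrative sketch only: the loop below flattens every readable trajectory
        # into (trajectory_id, base_index) tuples, matching the `all_steps` property
        # docstring. With toy values:
        #
        #     trajectory_ids     = [0, 1, 2]
        #     trajectory_lengths = [3, 2, 4]
        #     steps = [(tid, i) for tid, n in zip(trajectory_ids, trajectory_lengths) for i in range(n)]
        #     # -> [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (2, 3)]
        #
        # Trajectories that fail to load, or whose language instruction is empty, are
        # skipped entirely rather than partially enumerated.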
+ for trajectory_id, trajectory_length in tqdm(zip(self.trajectory_ids, self.trajectory_lengths), total=len(self.trajectory_ids), desc="Getting All Step"): + try: + if self._lerobot_version == "v2.0": + data = self.get_trajectory_data(trajectory_id) + elif self._lerobot_version == "v3.0": + data = self.get_trajectory_data_lerobot_v3(trajectory_id) + + trajectory_skipped = False + + # Check if trajectory has valid language instruction (if language modality is configured) + if has_language_modality: + self.curr_traj_data = data # Set current trajectory data for get_language to work + + language_instruction = self.get_language(trajectory_id, self.modality_keys['language'][0], 0) + if not language_instruction or language_instruction[0] == "": + print(f"Skipping trajectory {trajectory_id} due to empty language instruction") + skipped_trajectories += 1 + trajectory_skipped = True + continue + + except Exception as e: + print(f"Skipping trajectory {trajectory_id} due to read error: {e}") + skipped_trajectories += 1 + trajectory_skipped = True + continue + + if not trajectory_skipped: + processed_trajectories += 1 + + for base_index in range(trajectory_length): + all_steps.append((trajectory_id, base_index)) + + # Print summary statistics + print(f"Single-process summary: Processed {processed_trajectories} trajectories, skipped {skipped_trajectories} empty trajectories") + print(f"Total steps: {len(all_steps)} from {len(self.trajectory_ids)} trajectories") + + return all_steps + + def _get_position_and_gripper_values(self, data: pd.DataFrame) -> tuple[list, list]: + """Get position and gripper values based on available columns in the dataset.""" + # Get action keys from modality_keys + action_keys = self.modality_keys.get('action', []) + + # Extract position data + delta_position_values = None + position_candidates = ['delta_eef_position'] + coordinate_candidates = ['x', 'y', 'z'] + + # First try combined position fields + for pos_key in position_candidates: + full_key = f"action.{pos_key}" + if full_key in action_keys: + try: + # Get the lerobot key for this modality + le_action_cfg = self.lerobot_modality_meta.action + subkey = pos_key + if subkey in le_action_cfg: + le_key = le_action_cfg[subkey].original_key or subkey + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[subkey].start, le_action_cfg[subkey].end) + filtered_data = data_array[:, le_indices] + delta_position_values = filtered_data.tolist() + break + except Exception: + continue + + # If combined fields not found, try individual x,y,z coordinates + if delta_position_values is None: + x_data, y_data, z_data = None, None, None + for coord in coordinate_candidates: + full_key = f"action.{coord}" + if full_key in action_keys: + try: + le_action_cfg = self.lerobot_modality_meta.action + if coord in le_action_cfg: + le_key = le_action_cfg[coord].original_key or coord + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[coord].start, le_action_cfg[coord].end) + coord_data = data_array[:, le_indices].flatten() + if coord == 'x': + x_data = coord_data + elif coord == 'y': + y_data = coord_data + elif coord == 'z': + z_data = coord_data + except Exception: + continue + + if x_data is not None and y_data is not None and z_data is not None: + delta_position_values = np.column_stack((x_data, y_data, z_data)).tolist() + + if delta_position_values is None: + # Fallback to the old hardcoded approach if metadata approach fails + if 
'action.delta_eef_position' in data.columns: + delta_position_values = data['action.delta_eef_position'].to_numpy().tolist() + elif all(col in data.columns for col in ['action.x', 'action.y', 'action.z']): + x_vals = data['action.x'].to_numpy() + y_vals = data['action.y'].to_numpy() + z_vals = data['action.z'].to_numpy() + delta_position_values = np.column_stack((x_vals, y_vals, z_vals)).tolist() + else: + raise ValueError(f"No suitable position columns found. Available columns: {data.columns.tolist()}") + + # Extract gripper data + gripper_values = None + gripper_candidates = ['gripper_close', 'gripper'] + + for grip_key in gripper_candidates: + full_key = f"action.{grip_key}" + if full_key in action_keys: + try: + le_action_cfg = self.lerobot_modality_meta.action + if grip_key in le_action_cfg: + le_key = le_action_cfg[grip_key].original_key or grip_key + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[grip_key].start, le_action_cfg[grip_key].end) + gripper_data = data_array[:, le_indices].flatten() + gripper_values = gripper_data.tolist() + break + except Exception: + continue + + if gripper_values is None: + # Fallback to the old hardcoded approach if metadata approach fails + if 'action.gripper_close' in data.columns: + gripper_values = data['action.gripper_close'].to_numpy().tolist() + elif 'action.gripper' in data.columns: + gripper_values = data['action.gripper'].to_numpy().tolist() + else: + raise ValueError(f"No suitable gripper columns found. Available columns: {data.columns.tolist()}") + + return delta_position_values, gripper_values + + def _get_modality_keys(self) -> dict: + """Get the modality keys for the dataset. + The keys are the modality names, and the values are the keys for each modality. + See property `modality_keys` for the expected format. 
+ """ + modality_keys = defaultdict(list) + for modality, config in self.modality_configs.items(): + modality_keys[modality] = config.modality_keys + return modality_keys + + def _get_delta_indices(self) -> dict[str, np.ndarray]: + """Restructure the delta indices to use modality.key as keys instead of just the modalities.""" + delta_indices: dict[str, np.ndarray] = {} + for config in self.modality_configs.values(): + for key in config.modality_keys: + delta_indices[key] = np.array(config.delta_indices) + return delta_indices + + def _get_lerobot_modality_meta(self) -> LeRobotModalityMetadata: + """Get the metadata for the LeRobot dataset.""" + modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME + assert ( + modality_meta_path.exists() + ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}" + with open(modality_meta_path, "r") as f: + modality_meta = LeRobotModalityMetadata.model_validate(json.load(f)) + return modality_meta + + def _get_lerobot_info_meta(self) -> dict: + """Get the metadata for the LeRobot dataset.""" + info_meta_path = self.dataset_path / LE_ROBOT_INFO_FILENAME + with open(info_meta_path, "r") as f: + info_meta = json.load(f) + return info_meta + + def _get_data_path_pattern(self) -> str: + """Get the data path pattern for the LeRobot dataset.""" + return self.lerobot_info_meta["data_path"] + + def _get_video_path_pattern(self) -> str: + """Get the video path pattern for the LeRobot dataset.""" + return self.lerobot_info_meta["video_path"] + + def _get_chunk_size(self) -> int: + """Get the chunk size for the LeRobot dataset.""" + return self.lerobot_info_meta["chunks_size"] + + def _get_tasks(self) -> pd.DataFrame: + """Get the tasks for the dataset.""" + if self._lerobot_version == "v2.0": + tasks_path = self.dataset_path / LE_ROBOT_TASKS_FILENAME + with open(tasks_path, "r") as f: + tasks = [json.loads(line) for line in f] + df = pd.DataFrame(tasks) + return df.set_index("task_index") + + elif self._lerobot_version == "v3.0": + tasks_path = self.dataset_path / LE_ROBOT3_TASKS_FILENAME + df = pd.read_parquet(tasks_path) + df = df.reset_index() # 把索引变成一列,列名通常为 'index' + df = df.rename(columns={'index': 'task'}) # 把 'index' 列重命名为 'task' + df = df[['task_index', 'task']] # 调整列顺序 + return df + def _check_integrity(self): + """Use the config to check if the keys are valid and detect silent data corruption.""" + ERROR_MSG_HEADER = f"Error occurred in initializing dataset {self.dataset_name}:\n" + + for modality_config in self.modality_configs.values(): + for key in modality_config.modality_keys: + if key == "lapa_action" or key == "dream_actions": + continue # no need for any metadata for lapa actions because it comes normalized + # Check if the key is valid + try: + self.lerobot_modality_meta.get_key_meta(key) + except Exception as e: + raise ValueError( + ERROR_MSG_HEADER + f"Unable to find key {key} in modality metadata:\n{e}" + ) + + def set_transforms_metadata(self, metadata: DatasetMetadata): + """Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values.""" + self.transforms.set_metadata(metadata) + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset. + + Args: + epoch (int): The epoch to set. + """ + self.epoch = epoch + + def __len__(self) -> int: + """Get the total number of data points in the dataset. + + Returns: + int: the total number of data points in the dataset. 
+ """ + return len(self.all_steps) + + def __str__(self) -> str: + """Get the description of the dataset.""" + return f"{self.dataset_name} ({len(self)} steps)" + + + def __getitem__(self, index: int) -> dict: + """Get the data for a single step in a trajectory. + + Args: + index (int): The index of the step to get. + + Returns: + dict: The data for the step. + """ + trajectory_id, base_index = self.all_steps[index] + data = self.get_step_data(trajectory_id, base_index) + + # Process all video keys dynamically + images = [] + mid_images = [] + for video_key in self.modality_keys.get("video", []): + video_frames = data[video_key] + image = video_frames[0] + image = Image.fromarray(image).resize((224, 224)) + images.append(image) + if self.num_history_steps != 0: + history_index = min(self.num_history_steps - 1, len(video_frames) - 1) + mid_image = video_frames[history_index] + mid_image = Image.fromarray(mid_image).resize((224, 224)) + mid_images.append(mid_image) + + # Get language and action data + language = data[self.modality_keys["language"][0]][0] + action = [] + for action_key in self.modality_keys["action"]: + action.append(data[action_key]) + action = np.concatenate(action, axis=1) + action = standardize_action_representation(action, self.tag) + + state = [] + for state_key in self.modality_keys["state"]: + state.append(data[state_key]) + state = np.concatenate(state, axis=1) + state = standardize_state_representation(state, self.tag) + + sample = dict(action=action, state=state, image=images, language=language, dataset_id=self._dataset_id) + if self.num_history_steps != 0: + sample["mid_image"] = mid_images + return sample + + def get_step_data(self, trajectory_id: int, base_index: int) -> dict: + """Get the RAW data for a single step in a trajectory. No transforms are applied. + + Args: + trajectory_id (int): The name of the trajectory. + base_index (int): The base step index in the trajectory. + + Returns: + dict: The RAW data for the step. + + Example return: + { + "video": { + "video.image_side_0": [B, T, H, W, C], + "video.image_side_1": [B, T, H, W, C], + }, + "state": { + "state.eef_position": [B, T, state_dim], + "state.eef_rotation": [B, T, state_dim], + }, + "action": { + "action.eef_position": [B, T, action_dim], + "action.eef_rotation": [B, T, action_dim], + }, + } + """ + data = {} + # Get the data for all modalities # just for action base data + self.curr_traj_data = self.get_trajectory_data(trajectory_id) + # TODO @JinhuiYE The logic below is poorly implemented. Data reading should be directly based on curr_traj_data. 
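        # Illustrative sketch only (toy values): each key is fetched by adding its
        # delta indices to `base_index`; out-of-range steps are handled by the helpers
        # further below.
        #
        #     delta_indices = np.array([0, 1, 2, 3])     # e.g. a 4-step action chunk
        #     base_index, traj_len = 58, 60
        #     step_indices = delta_indices + base_index  # [58, 59, 60, 61]
        #
        # The last two indices fall past the trajectory end: get_state_or_action pads
        # them with the final step (absolute data) or zeros (relative data), while
        # get_video clamps them to the last frame.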
+ for modality in self.modality_keys: + # Get the data corresponding to each key in the modality + for key in self.modality_keys[modality]: + data[key] = self.get_data_by_modality(trajectory_id, modality, key, base_index) + return data + + def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame: + """Get the data for a trajectory.""" + if self._lerobot_version == "v2.0": + + if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None: + return self.curr_traj_data + else: + chunk_index = self.get_episode_chunk(trajectory_id) + parquet_path = self.dataset_path / self.data_path_pattern.format( + episode_chunk=chunk_index, episode_index=trajectory_id + ) + assert parquet_path.exists(), f"Parquet file not found at {parquet_path}" + return pd.read_parquet(parquet_path) + elif self._lerobot_version == "v3.0": + return self.get_trajectory_data_lerobot_v3(trajectory_id) + + def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame: + """Get the data for a trajectory from lerobot v3.""" + if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None: + return self.curr_traj_data + else: #TODO check detail later + chunk_index = self.get_episode_chunk(trajectory_id) + + file_index = self.get_episode_file_index(trajectory_id) + # file_from_index = self.get_episode_file_from_index(trajectory_id) + + + parquet_path = self.dataset_path / self.data_path_pattern.format( + chunk_index=chunk_index, file_index=file_index + ) + assert parquet_path.exists(), f"Parquet file not found at {parquet_path}" + file_data = pd.read_parquet(parquet_path) + + # filter by trajectory_id + episode_data = file_data.loc[file_data["episode_index"] == trajectory_id].copy() + + # fix timestamp from epis index to file index for video alignment + if self.load_video: + from_timestamp = self.trajectory_ids_to_metadata[trajectory_id].get( + "videos/observation.images.wrist/from_timestamp", 0 + ) + episode_data["timestamp"] = episode_data["timestamp"] + from_timestamp + + return episode_data + + + def get_trajectory_index(self, trajectory_id: int) -> int: + """Get the index of the trajectory in the dataset by the trajectory ID. + This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID. + + Args: + trajectory_id (str): The ID of the trajectory. + + Returns: + int: The index of the trajectory in the dataset. + """ + trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0] + if len(trajectory_indices) != 1: + raise ValueError( + f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}" + ) + return trajectory_indices[0] + + def get_episode_chunk(self, ep_index: int) -> int: + """Get the chunk index for an episode index.""" + return ep_index // self.chunk_size + def get_episode_file_index(self, ep_index: int) -> int: + """Get the file index for an episode index.""" + episode_meta = self.trajectory_ids_to_metadata[ep_index] + return episode_meta["data/file_index"] + + def get_episode_file_from_index(self, ep_index: int) -> int: + """Get the file from index for an episode index.""" + episode_meta = self.trajectory_ids_to_metadata[ep_index] + return episode_meta["data/file_from_index"] + + + def retrieve_data_and_pad( + self, + array: np.ndarray, + step_indices: np.ndarray, + max_length: int, + padding_strategy: str = "first_last", + ) -> np.ndarray: + """Retrieve the data from the dataset and pad it if necessary. + Args: + array (np.ndarray): The array to retrieve the data from. 
+ step_indices (np.ndarray): The step indices to retrieve the data for. + max_length (int): The maximum length of the data. + padding_strategy (str): The padding strategy, either "first" or "last". + """ + # Get the padding indices + front_padding_indices = step_indices < 0 + end_padding_indices = step_indices >= max_length + padding_positions = np.logical_or(front_padding_indices, end_padding_indices) + # Retrieve the data with the non-padding indices + # If there exists some padding, Given T step_indices, the shape of the retrieved data will be (T', ...) where T' < T + raw_data = array[step_indices[~padding_positions]] + assert isinstance(raw_data, np.ndarray), f"{type(raw_data)=}" + # This is the shape of the output, (T, ...) + if raw_data.ndim == 1: + expected_shape = (len(step_indices),) + else: + expected_shape = (len(step_indices), *array.shape[1:]) + + # Pad the data + output = np.zeros(expected_shape) + # Assign the non-padded data + output[~padding_positions] = raw_data + # If there exists some padding, pad the data + if padding_positions.any(): + if padding_strategy == "first_last": + # Use first / last step data to pad + front_padding_data = array[0] + end_padding_data = array[-1] + output[front_padding_indices] = front_padding_data + output[end_padding_indices] = end_padding_data + elif padding_strategy == "zero": + # Use zero padding + output[padding_positions] = 0 + else: + raise ValueError(f"Invalid padding strategy: {padding_strategy}") + return output + + def get_video_path(self, trajectory_id: int, key: str) -> Path: + chunk_index = self.get_episode_chunk(trajectory_id) + original_key = self.lerobot_modality_meta.video[key].original_key + if original_key is None: + original_key = key + if self._lerobot_version == "v2.0": + video_filename = self.video_path_pattern.format( + episode_chunk=chunk_index, episode_index=trajectory_id, video_key=original_key + ) + elif self._lerobot_version == "v3.0": + episode_meta = self.trajectory_ids_to_metadata[trajectory_id] + video_filename = self.video_path_pattern.format( + video_key=original_key, + chunk_index=episode_meta["data/chunk_index"], + file_index=episode_meta["data/file_index"], + ) + return self.dataset_path / video_filename + + def get_video( + self, + trajectory_id: int, + key: str, + base_index: int, + ) -> np.ndarray: + """Get the video frames for a trajectory by a base index. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (str): The ID of the trajectory. + key (str): The key of the video. + base_index (int): The base index of the trajectory. + + Returns: + np.ndarray: The video frames for the trajectory and frame indices. 
Shape: (T, H, W, C) + """ + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # print(f"{step_indices=}") + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Ensure the indices are within the valid range + # This is equivalent to padding the video with extra frames at the beginning and end + step_indices = np.maximum(step_indices, 0) + step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1) + assert key.startswith("video."), f"Video key must start with 'video.', got {key}" + # Get the sub-key + key = key.replace("video.", "") + video_path = self.get_video_path(trajectory_id, key) + # Get the action/state timestamps for each frame in the video + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}" + timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy() + # Get the corresponding video timestamps from the step indices + video_timestamp = timestamp[step_indices] + + return get_frames_by_timestamps( + video_path.as_posix(), + video_timestamp, + video_backend=self.video_backend, # TODO + video_backend_kwargs=self.video_backend_kwargs, + ) + + def get_state_or_action( + self, + trajectory_id: int, + modality: str, + key: str, + base_index: int, + ) -> np.ndarray: + """Get the state or action data for a trajectory by a base index. + If the step indices are out of range, pad with the data: + if the data is stored in absolute format, pad with the first or last step data; + otherwise, pad with zero. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + modality (str): The modality of the data. + key (str): The key of the data. + base_index (int): The base index of the trajectory. + + Returns: + np.ndarray: The data for the trajectory and step indices. + """ + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Get the maximum length of the trajectory + max_length = self.trajectory_lengths[trajectory_index] + assert key.startswith(modality + "."), f"{key} must start with {modality + '.'}, got {key}" + # Get the sub-key, e.g. 
state.joint_angles -> joint_angles + key = key.replace(modality + ".", "") + # Get the lerobot key + le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality) + le_key = le_state_or_action_cfg[key].original_key + if le_key is None: + le_key = key + # Get the data array, shape: (T, D) + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}" + data_array: np.ndarray = np.stack(self.curr_traj_data[le_key]) # type: ignore + assert data_array.ndim == 2, f"Expected 2D array, got key {le_key} is{data_array.shape} array" + le_indices = np.arange( + le_state_or_action_cfg[key].start, + le_state_or_action_cfg[key].end, + ) + data_array = data_array[:, le_indices] + # Get the state or action configuration + state_or_action_cfg = getattr(self.metadata.modalities, modality)[key] + + # Pad the data + return self.retrieve_data_and_pad( + array=data_array, + step_indices=step_indices, + max_length=max_length, + padding_strategy="first_last" if state_or_action_cfg.absolute else "zero", + # padding_strategy="zero", # HACK for realdata + ) + + def get_language( + self, + trajectory_id: int, + key: str, + base_index: int, + ) -> list[str]: + """Get the language annotation data for a trajectory by step indices. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + key (str): The key of the annotation. + base_index (int): The base index of the trajectory. + + Returns: + list[str]: The annotation data for the trajectory and step indices. If no matching data is found, return empty strings. + """ + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Get the maximum length of the trajectory + max_length = self.trajectory_lengths[trajectory_index] + # Get the end times corresponding to the closest indices + step_indices = np.maximum(step_indices, 0) + step_indices = np.minimum(step_indices, max_length - 1) + # Get the annotations + task_indices: list[int] = [] + assert key.startswith( + "annotation." + ), f"Language key must start with 'annotation.', got {key}" + subkey = key.replace("annotation.", "") + annotation_meta = self.lerobot_modality_meta.annotation + assert annotation_meta is not None, f"Annotation metadata is None for {subkey}" + assert ( + subkey in annotation_meta + ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}" + subkey_meta = annotation_meta[subkey] + original_key = subkey_meta.original_key + if original_key is None: + original_key = key + for i in range(len(step_indices)): # + # task_indices.append(self.curr_traj_data[original_key][step_indices[i]].item()) + value = self.curr_traj_data[original_key].iloc[step_indices[i]] # TODO check v2.0 + task_indices.append(value if isinstance(value, (int, float)) else value.item()) + + return self.tasks.loc[task_indices]["task"].tolist() + + def get_data_by_modality( + self, + trajectory_id: int, + modality: str, + key: str, + base_index: int, + ): + """Get the data corresponding to the modality for a trajectory by a base index. + This method will call the corresponding helper method based on the modality. + See the helper methods for more details. 
+ NOTE: For the language modality, the data is padded with empty strings if no matching data is found. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + modality (str): The modality of the data. + key (str): The key of the data. + base_index (int): The base index of the trajectory. + """ + if modality == "video": + return self.get_video(trajectory_id, key, base_index) + elif modality == "state" or modality == "action": + return self.get_state_or_action(trajectory_id, modality, key, base_index) + elif modality == "language": + return self.get_language(trajectory_id, key, base_index) + else: + raise ValueError(f"Invalid modality: {modality}") + + def _save_dataset_statistics_(self, save_path: Path | str, format: str = "json") -> None: + """ + Save dataset statistics to specified path in the required format. + Only includes statistics for keys that are actually used in the dataset. + Key order follows modality config order. + + Args: + save_path (Path | str): Path to save the statistics file + format (str): Save format, currently only supports "json" + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Build the data structure to save + statistics_data = {} + + # Get used modality keys + used_action_keys, used_state_keys = get_used_modality_keys(self.modality_keys) + + # Organize statistics by tag + tag = self.tag + tag_stats = {} + + # Process action statistics (only for used keys, config order) + if hasattr(self.metadata.statistics, 'action') and self.metadata.statistics.action: + action_stats = self.metadata.statistics.action + filtered_action_stats = { + key: action_stats[key] + for key in used_action_keys + if key in action_stats + } + + if filtered_action_stats: + # Combine statistics from filtered action sub-keys + combined_action_stats = combine_modality_stats(filtered_action_stats) + + # Add mask field based on whether it's gripper or not + mask = generate_action_mask_for_used_keys( + self.metadata.modalities.action, filtered_action_stats.keys() + ) + combined_action_stats["mask"] = mask + + tag_stats["action"] = combined_action_stats + + # Process state statistics (only for used keys, config order) + if hasattr(self.metadata.statistics, 'state') and self.metadata.statistics.state: + state_stats = self.metadata.statistics.state + filtered_state_stats = { + key: state_stats[key] + for key in used_state_keys + if key in state_stats + } + + if filtered_state_stats: + combined_state_stats = combine_modality_stats(filtered_state_stats) + tag_stats["state"] = combined_state_stats + + # Add dataset counts + tag_stats["num_transitions"] = len(self) + tag_stats["num_trajectories"] = len(self.trajectory_ids) + + statistics_data[tag] = tag_stats + + # Save as JSON file + if format.lower() == "json": + if not str(save_path).endswith('.json'): + save_path = save_path.with_suffix('.json') + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(statistics_data, f, indent=2, ensure_ascii=False) + else: + raise ValueError(f"Unsupported format: {format}. 
Currently only 'json' is supported.") + + print(f"Single dataset statistics saved to: {save_path}") + print(f"Used action keys (reordered): {list(used_action_keys)}") + print(f"Used state keys (reordered): {list(used_state_keys)}") + + + +class MixtureSpecElement(BaseModel): + dataset_path: list[Path] | Path = Field(..., description="The path to the dataset.") + dataset_weight: float = Field(..., description="The weight of the dataset in the mixture.") + distribute_weights: bool = Field( + default=False, + description="Whether to distribute the weights of the dataset across all the paths. If True, the weights will be evenly distributed across all the paths.", + ) + + +# Helper functions for dataset statistics + +def combine_modality_stats(modality_stats: dict) -> dict: + """ + Combine statistics from all sub-keys under a modality. + + Args: + modality_stats (dict): Statistics for a modality, containing multiple sub-keys. + Each sub-key contains DatasetStatisticalValues object. + + Returns: + dict: Combined statistics + """ + combined_stats = { + "mean": [], + "std": [], + "max": [], + "min": [], + "q01": [], + "q99": [] + } + + # Combine statistics in sub-key order + for subkey in modality_stats.keys(): + subkey_stats = modality_stats[subkey] # This is a DatasetStatisticalValues object + + # Convert DatasetStatisticalValues to dict-like access + for stat_name in ["mean", "std", "max", "min", "q01", "q99"]: + stat_value = getattr(subkey_stats, stat_name) + if isinstance(stat_value, (list, tuple)): + combined_stats[stat_name].extend(stat_value) + else: + # Handle NDArray case - convert to list + if hasattr(stat_value, 'tolist'): + combined_stats[stat_name].extend(stat_value.tolist()) + else: + combined_stats[stat_name].append(float(stat_value)) + + return combined_stats + +def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]: + """ + Generate mask based on action modalities, but only for used keys. + All dimensions are set to True so every channel is de/normalized. + + Args: + action_modalities (dict): Configuration information for action modalities. + used_action_keys_ordered: Iterable of actually used action keys in the correct order. + + Returns: + list[bool]: List of mask values + """ + mask = [] + + # Generate mask in the same order as the statistics were combined + for subkey in used_action_keys_ordered: + if subkey in action_modalities: + subkey_config = action_modalities[subkey] + + # Get dimension count from shape + if hasattr(subkey_config, 'shape') and len(subkey_config.shape) > 0: + dim_count = subkey_config.shape[0] + else: + dim_count = 1 + + # Check if it's gripper-related + is_gripper = "gripper" in subkey.lower() + + # Generate mask value for each dimension + for _ in range(dim_count): + mask.append(not is_gripper) # gripper is False, others are True + + return mask + +def get_used_modality_keys(modality_keys: dict) -> tuple[set, set]: + """Extract used action and state keys from modality configuration.""" + used_action_keys = [] + used_state_keys = [] + + # Extract action keys (remove "action." prefix) + for action_key in modality_keys.get("action", []): + if action_key.startswith("action."): + clean_key = action_key.replace("action.", "") + used_action_keys.append(clean_key) + + # Extract state keys (remove "state." 
prefix) + for state_key in modality_keys.get("state", []): + if state_key.startswith("state."): + clean_key = state_key.replace("state.", "") + used_state_keys.append(clean_key) + + return used_action_keys, used_state_keys + + +def safe_hash(input_tuple): + # keep 128 bits of the hash + tuple_string = repr(input_tuple).encode("utf-8") + sha256 = hashlib.sha256() + sha256.update(tuple_string) + + seed = int(sha256.hexdigest(), 16) + + return seed & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + + +class LeRobotMixtureDataset(Dataset): + """ + A mixture of multiple datasets. This class samples a single dataset based on the dataset weights and then calls the `__getitem__` method of the sampled dataset. + It is recommended to modify the single dataset class instead of this class. + """ + + def __init__( + self, + data_mixture: Sequence[tuple[LeRobotSingleDataset, float]], + mode: str, + balance_dataset_weights: bool = True, + balance_trajectory_weights: bool = True, + seed: int = 42, + metadata_config: dict = { + "percentile_mixing_method": "min_max", + }, + **kwargs, + ): + """ + Initialize the mixture dataset. + + Args: + data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights. + mode (str): If "train", __getitem__ will return different samples every epoch; if "val" or "test", __getitem__ will return the same sample every epoch. + balance_dataset_weights (bool): If True, the weight of dataset will be multiplied by the total trajectory length of each dataset. + balance_trajectory_weights (bool): If True, sample trajectories within a dataset weighted by their length; otherwise, use equal weighting. + seed (int): Random seed for sampling. + """ + datasets: list[LeRobotSingleDataset] = [] + dataset_sampling_weights: list[float] = [] + for dataset, weight in data_mixture: + # Check if dataset is valid and has data + if len(dataset) == 0: + print(f"Warning: Skipping empty dataset {dataset.dataset_name}") + continue + datasets.append(dataset) + dataset_sampling_weights.append(weight) + + if len(datasets) == 0: + raise ValueError("No valid datasets found in the mixture. All datasets are empty.") + + self.datasets = datasets + self.balance_dataset_weights = balance_dataset_weights + self.balance_trajectory_weights = balance_trajectory_weights + self.seed = seed + self.mode = mode + + # Set properties for sampling + + # 1. Dataset lengths + self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets]) + print(f"Dataset lengths: {self._dataset_lengths}") + + # 2. Dataset sampling weights + self._dataset_sampling_weights = np.array(dataset_sampling_weights) + + if self.balance_dataset_weights: + self._dataset_sampling_weights *= self._dataset_lengths + + # Check for zero or negative weights before normalization + if np.any(self._dataset_sampling_weights <= 0): + print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}") + # Set minimum weight to prevent division issues + self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8) + + # Normalize weights + weights_sum = self._dataset_sampling_weights.sum() + if weights_sum == 0 or np.isnan(weights_sum): + print(f"Error: Invalid weights sum: {weights_sum}") + # Fallback to equal weights + self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets) + print(f"Fallback to equal weights") + else: + self._dataset_sampling_weights /= weights_sum + + # 3. 
Trajectory sampling weights + self._trajectory_sampling_weights: list[np.ndarray] = [] + for i, dataset in enumerate(self.datasets): + trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) + if self.balance_trajectory_weights: + trajectory_sampling_weights *= dataset.trajectory_lengths + + # Check for zero or negative weights before normalization + if np.any(trajectory_sampling_weights <= 0): + print(f"Warning: Dataset {i} has zero or negative trajectory weights") + trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8) + + # Normalize weights + weights_sum = trajectory_sampling_weights.sum() + if weights_sum == 0 or np.isnan(weights_sum): + print(f"Error: Dataset {i} has invalid trajectory weights sum: {weights_sum}") + # Fallback to equal weights + trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths) + else: + trajectory_sampling_weights /= weights_sum + + self._trajectory_sampling_weights.append(trajectory_sampling_weights) + + # 4. Primary dataset indices + self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0 + if not np.any(self._primary_dataset_indices): + print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}") + # Fallback: use the dataset(s) with maximum weight as primary + max_weight = max(dataset_sampling_weights) + self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight + print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}") + + if not np.any(self._primary_dataset_indices): + # This should never happen, but just in case + print("Error: Still no primary dataset found. Using first dataset as primary.") + self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool) + self._primary_dataset_indices[0] = True + + # Set the epoch and sample the first epoch + self.set_epoch(0) + + self.update_metadata(metadata_config) + + @property + def dataset_lengths(self) -> np.ndarray: + """The lengths of each dataset.""" + return self._dataset_lengths + + @property + def dataset_sampling_weights(self) -> np.ndarray: + """The sampling weights for each dataset.""" + return self._dataset_sampling_weights + + @property + def trajectory_sampling_weights(self) -> list[np.ndarray]: + """The sampling weights for each trajectory in each dataset.""" + return self._trajectory_sampling_weights + + @property + def primary_dataset_indices(self) -> np.ndarray: + """The indices of the primary datasets.""" + return self._primary_dataset_indices + + def __str__(self) -> str: + dataset_descriptions = [] + for dataset, weight in zip(self.datasets, self.dataset_sampling_weights): + dataset_description = { + "Dataset": str(dataset), + "Sampling weight": float(weight), + } + dataset_descriptions.append(dataset_description) + return json.dumps({"Mixture dataset": dataset_descriptions}, indent=2) + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset. + + Args: + epoch (int): The epoch to set. 
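
        Note:
            In "train" mode the epoch is folded into the per-index sampling seed
            (roughly ``seed = safe_hash((epoch, index, self.seed))``, see
            `sample_step`), so each epoch maps the same index to a freshly sampled
            dataset/trajectory/step; in "val"/"test" mode the seed is just the
            index and the epoch has no effect.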
+ """ + self.epoch = epoch + # self.sampled_steps = self.sample_epoch() + + def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]: + """Sample a single step from the dataset.""" + # return self.sampled_steps[index] + + # Set seed + seed = index if self.mode != "train" else safe_hash((self.epoch, index, self.seed)) + rng = np.random.default_rng(seed) + + # Sample dataset + dataset_index = rng.choice(len(self.datasets), p=self.dataset_sampling_weights) + dataset = self.datasets[dataset_index] + + # Sample trajectory + # trajectory_index = rng.choice( + # len(dataset.trajectory_ids), p=self.trajectory_sampling_weights[dataset_index] + # ) + # trajectory_id = dataset.trajectory_ids[trajectory_index] + + # # Sample step + # base_index = rng.choice(dataset.trajectory_lengths[trajectory_index]) + # return dataset, trajectory_id, base_index + single_step_index = rng.choice(len(dataset.all_steps)) + trajectory_id, base_index = dataset.all_steps[single_step_index] + return dataset, trajectory_id, base_index + + def __getitem__(self, index: int) -> dict: + """Get the data for a single trajectory and start index. + + Args: + index (int): The index of the trajectory to get. + + Returns: + dict: The data for the trajectory and start index. + """ + max_retries = 10 + last_exception = None + + for attempt in range(max_retries): + try: + dataset, trajectory_name, step = self.sample_step(index) + data_raw = dataset.get_step_data(trajectory_name, step) + data = dataset.transforms(data_raw) + + # Process all video keys dynamically + images = [] + mid_images = [] + num_history_steps = int(getattr(dataset, "num_history_steps", 0) or 0) + for video_key in dataset.modality_keys.get("video", []): + video_frames = data[video_key] + image = video_frames[0] + image = Image.fromarray(image).resize((224, 224)) #TODO check if this is ok + images.append(image) + if num_history_steps != 0: + history_index = min(num_history_steps - 1, len(video_frames) - 1) + mid_image = video_frames[history_index] + mid_image = Image.fromarray(mid_image).resize((224, 224)) + mid_images.append(mid_image) + + # Get language and action data + language = data[dataset.modality_keys["language"][0]][0] + action = [] + for action_key in dataset.modality_keys["action"]: + action.append(data[action_key]) + action = np.concatenate(action, axis=1).astype(np.float16) + action = standardize_action_representation(action, dataset.tag) + + state = [] + for state_key in dataset.modality_keys["state"]: + state.append(data[state_key]) + state = np.concatenate(state, axis=1).astype(np.float16) + state = standardize_state_representation(state, dataset.tag) + + sample = dict(action=action, state=state, image=images, lang=language, dataset_id=dataset._dataset_id) + if num_history_steps != 0: + sample["mid_image"] = mid_images + return sample + + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + # Log the error but continue trying + print(f"Attempt {attempt + 1}/{max_retries} failed for index {index}: {e}") + print(f"Retrying with new sample...") + # For retry, we can use a slightly different index to get a new sample + # This helps avoid getting stuck on the same problematic sample + index = random.randint(0, len(self) - 1) + else: + # All retries exhausted + print(f"All {max_retries} attempts failed for index {index}") + print(f"Last error: {last_exception}") + # Return a dummy sample or re-raise the exception + raise last_exception + + def __len__(self) -> int: + """Get the length of a single epoch in the 
mixture. + + Returns: + int: The length of a single epoch in the mixture. + """ + # Check for potential issues + if len(self.datasets) == 0: + return 0 + + # Check if any dataset lengths are 0 or NaN + if np.any(self.dataset_lengths == 0) or np.any(np.isnan(self.dataset_lengths)): + print(f"Warning: Found zero or NaN dataset lengths: {self.dataset_lengths}") + # Filter out zero/NaN length datasets + valid_indices = (self.dataset_lengths > 0) & (~np.isnan(self.dataset_lengths)) + if not np.any(valid_indices): + print("Error: All datasets have zero or NaN length") + return 0 + else: + valid_indices = np.ones(len(self.datasets), dtype=bool) + + # Check if any sampling weights are 0 or NaN + if np.any(self.dataset_sampling_weights == 0) or np.any(np.isnan(self.dataset_sampling_weights)): + print(f"Warning: Found zero or NaN sampling weights: {self.dataset_sampling_weights}") + # Use only valid weights + valid_weights = (self.dataset_sampling_weights > 0) & (~np.isnan(self.dataset_sampling_weights)) + valid_indices = valid_indices & valid_weights + if not np.any(valid_indices): + print("Error: All sampling weights are zero or NaN") + return 0 + + # Check primary dataset indices + primary_and_valid = self.primary_dataset_indices & valid_indices + if not np.any(primary_and_valid): + print(f"Warning: No valid primary datasets found. Primary indices: {self.primary_dataset_indices}, Valid indices: {valid_indices}") + # Fallback: use the largest valid dataset + if np.any(valid_indices): + max_length = self.dataset_lengths[valid_indices].max() + print(f"Fallback: Using maximum dataset length: {max_length}") + return int(max_length) + else: + return 0 + + # Calculate the ratio and get max + ratios = (self.dataset_lengths / self.dataset_sampling_weights)[primary_and_valid] + + # Check for NaN or inf in ratios + if np.any(np.isnan(ratios)) or np.any(np.isinf(ratios)): + print(f"Warning: Found NaN or inf in ratios: {ratios}") + print(f"Dataset lengths: {self.dataset_lengths[primary_and_valid]}") + print(f"Sampling weights: {self.dataset_sampling_weights[primary_and_valid]}") + # Filter out invalid ratios + valid_ratios = ratios[~np.isnan(ratios) & ~np.isinf(ratios)] + if len(valid_ratios) == 0: + print("Error: All ratios are NaN or inf") + return 0 + max_ratio = valid_ratios.max() + else: + max_ratio = ratios.max() + + result = int(max_ratio) + if result == 0: + print(f"Warning: Dataset mixture length is 0") + return result + + @staticmethod + def compute_overall_statistics( + per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]], + dataset_sampling_weights: list[float] | np.ndarray, + percentile_mixing_method: str = "weighted_average", + ) -> dict[str, dict[str, list[float]]]: + """ + Computes overall statistics from per-task statistics using dataset sample weights. + + Args: + per_task_stats: List of per-task statistics. + Example format of one element in the per-task statistics list: + { + "state.gripper": { + "min": [...], + "max": [...], + "mean": [...], + "std": [...], + "q01": [...], + "q99": [...], + }, + ... + } + dataset_sampling_weights: List of sample weights for each task. + percentile_mixing_method: The method to mix the percentiles, either "weighted_average" or "weighted_std". + + Returns: + A dict of overall statistics per modality. 
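
        Means and variances are combined with the standard mixture identity
        (weights normalized to sum to 1):

            mean = sum_i w_i * mean_i
            var  = sum_i w_i * (std_i**2 + mean_i**2) - mean**2
            std  = sqrt(var)

        Per-dimension min/max are taken as the elementwise min/max over tasks,
        and q01/q99 are mixed according to `percentile_mixing_method`.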
+ """ + # Normalize the sample weights to sum to 1 + dataset_sampling_weights = np.array(dataset_sampling_weights) + normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum() + + # Initialize overall statistics dict + overall_stats: dict[str, dict[str, list[float]]] = {} + + # Get the list of modality keys + modality_keys = per_task_stats[0].keys() + + for modality in modality_keys: + # Number of dimensions (assuming consistent across tasks) + num_dims = len(per_task_stats[0][modality]["mean"]) + + # Initialize accumulators for means and variances + weighted_means = np.zeros(num_dims) + weighted_squares = np.zeros(num_dims) + + # Collect min, max, q01, q99 from all tasks + min_list = [] + max_list = [] + q01_list = [] + q99_list = [] + + for task_idx, task_stats in enumerate(per_task_stats): + w_i = normalized_weights[task_idx] + stats = task_stats[modality] + means = np.array(stats["mean"]) + stds = np.array(stats["std"]) + + # Update weighted sums for mean and variance + weighted_means += w_i * means + weighted_squares += w_i * (stds**2 + means**2) + + # Collect min, max, q01, q99 + min_list.append(stats["min"]) + max_list.append(stats["max"]) + q01_list.append(stats["q01"]) + q99_list.append(stats["q99"]) + + # Compute overall mean + overall_mean = weighted_means.tolist() + + # Compute overall variance and std deviation + overall_variance = weighted_squares - weighted_means**2 + overall_std = np.sqrt(overall_variance).tolist() + + # Compute overall min and max per dimension + overall_min = np.min(np.array(min_list), axis=0).tolist() + overall_max = np.max(np.array(max_list), axis=0).tolist() + + # Compute overall q01 and q99 per dimension + # Use weighted average of per-task quantiles + q01_array = np.array(q01_list) + q99_array = np.array(q99_list) + if percentile_mixing_method == "weighted_average": + weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist() + weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist() + # std_q01 = np.std(q01_array, axis=0).tolist() + # std_q99 = np.std(q99_array, axis=0).tolist() + # print(modality) + # print(f"{std_q01=}, {std_q99=}") + # print(f"{weighted_q01=}, {weighted_q99=}") + elif percentile_mixing_method == "min_max": + weighted_q01 = np.min(q01_array, axis=0).tolist() + weighted_q99 = np.max(q99_array, axis=0).tolist() + else: + raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}") + + # Store the overall statistics for the modality + overall_stats[modality] = { + "min": overall_min, + "max": overall_max, + "mean": overall_mean, + "std": overall_std, + "q01": weighted_q01, + "q99": weighted_q99, + } + + return overall_stats + + @staticmethod + def merge_metadata( + metadatas: list[DatasetMetadata], + dataset_sampling_weights: list[float], + percentile_mixing_method: str, + ) -> DatasetMetadata: + """Merge multiple metadata into one.""" + # Convert to dicts + metadata_dicts = [metadata.model_dump(mode="json") for metadata in metadatas] + # Create a new metadata dict + merged_metadata = {} + + # Check all metadata have the same embodiment tag + assert all( + metadata.embodiment_tag == metadatas[0].embodiment_tag for metadata in metadatas + ), "All metadata must have the same embodiment tag" + merged_metadata["embodiment_tag"] = metadatas[0].embodiment_tag + + # Merge the dataset statistics + dataset_statistics = {} + dataset_statistics["state"] = LeRobotMixtureDataset.compute_overall_statistics( + per_task_stats=[m["statistics"]["state"] for m in 
metadata_dicts], + dataset_sampling_weights=dataset_sampling_weights, + percentile_mixing_method=percentile_mixing_method, + ) + dataset_statistics["action"] = LeRobotMixtureDataset.compute_overall_statistics( + per_task_stats=[m["statistics"]["action"] for m in metadata_dicts], + dataset_sampling_weights=dataset_sampling_weights, + percentile_mixing_method=percentile_mixing_method, + ) + merged_metadata["statistics"] = dataset_statistics + + # Merge the modality configs + modality_configs = defaultdict(set) + for metadata in metadata_dicts: + for modality, configs in metadata["modalities"].items(): + modality_configs[modality].add(json.dumps(configs)) + merged_metadata["modalities"] = {} + for modality, configs in modality_configs.items(): + # Check that all modality configs correspond to the same tag matches + assert ( + len(configs) == 1 + ), f"Multiple modality configs for modality {modality}: {list(configs)}" + merged_metadata["modalities"][modality] = json.loads(configs.pop()) + + return DatasetMetadata.model_validate(merged_metadata) + + def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None: + """ + Merge multiple metadatas into one and set the transforms with the merged metadata. + + Args: + metadata_config (dict): Configuration for the metadata. + "percentile_mixing_method": The method to mix the percentiles, either "weighted_average" or "min_max". + weighted_average: Use the weighted average of the percentiles using the weight used in sampling the datasets. + min_max: Use the min of the 1st percentile and max of the 99th percentile. + """ + # If cached path is provided, try to load and apply + if cached_statistics_path is not None: + try: + cached_stats = self.load_merged_statistics(cached_statistics_path) + self.apply_cached_statistics(cached_stats) + return + except (FileNotFoundError, KeyError, ValidationError) as e: + print(f"Failed to load cached statistics: {e}") + print("Falling back to computing statistics from scratch...") + + self.tag = EmbodimentTag.NEW_EMBODIMENT.value + self.merged_metadata: dict[str, DatasetMetadata] = {} + # Group metadata by tag + all_metadatas: dict[str, list[DatasetMetadata]] = {} + for dataset in self.datasets: + if dataset.tag not in all_metadatas: + all_metadatas[dataset.tag] = [] + all_metadatas[dataset.tag].append(dataset.metadata) + for tag, metadatas in all_metadatas.items(): + self.merged_metadata[tag] = self.merge_metadata( + metadatas=metadatas, + dataset_sampling_weights=self.dataset_sampling_weights.tolist(), + percentile_mixing_method=metadata_config["percentile_mixing_method"], + ) + for dataset in self.datasets: + dataset.set_transforms_metadata(self.merged_metadata[dataset.tag]) + + def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None: + """ + Save merged dataset statistics to specified path in the required format. + Only includes statistics for keys that are actually used in the datasets. + Key order follows each tag's modality config order. 
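        An illustrative sketch of the saved layout (the tag, key order, and counts below
        are made up; the real keys follow each tag's modality config):

            {
              "new_embodiment": {
                "action": {"mean": [...], "std": [...], "max": [...], "min": [...],
                           "q01": [...], "q99": [...], "mask": [true, true, false]},
                "state":  {"mean": [...], "std": [...], "max": [...], "min": [...],
                           "q01": [...], "q99": [...]},
                "num_transitions": 12345,
                "num_trajectories": 100
              }
            }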
+ + Args: + save_path (Path | str): Path to save the statistics file + format (str): Save format, currently only supports "json" + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Build the data structure to save + statistics_data = {} + + # Keep key orders per embodiment tag (from modality config order) + tag_to_used_action_keys = {} + tag_to_used_state_keys = {} + for dataset in self.datasets: + if dataset.tag in tag_to_used_action_keys: + continue + used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys) + tag_to_used_action_keys[dataset.tag] = used_action_keys + tag_to_used_state_keys[dataset.tag] = used_state_keys + + # Organize statistics by tag + for tag, merged_metadata in self.merged_metadata.items(): + tag_stats = {} + + # Process action statistics + if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action: + action_stats = merged_metadata.statistics.action + + used_action_keys = tag_to_used_action_keys.get(tag, []) + filtered_action_stats = { + key: action_stats[key] + for key in used_action_keys + if key in action_stats + } + + if filtered_action_stats: + combined_action_stats = combine_modality_stats(filtered_action_stats) + + mask = generate_action_mask_for_used_keys( + merged_metadata.modalities.action, filtered_action_stats.keys() + ) + combined_action_stats["mask"] = mask + + tag_stats["action"] = combined_action_stats + + # Process state statistics + if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state: + state_stats = merged_metadata.statistics.state + + used_state_keys = tag_to_used_state_keys.get(tag, []) + filtered_state_stats = { + key: state_stats[key] + for key in used_state_keys + if key in state_stats + } + + if filtered_state_stats: + combined_state_stats = combine_modality_stats(filtered_state_stats) + tag_stats["state"] = combined_state_stats + + # Add dataset counts + tag_stats.update(self._get_dataset_counts(tag)) + + statistics_data[tag] = tag_stats + + # Save file + if format.lower() == "json": + if not str(save_path).endswith('.json'): + save_path = save_path.with_suffix('.json') + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(statistics_data, f, indent=2, ensure_ascii=False) + else: + raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.") + + print(f"Merged dataset statistics saved to: {save_path}") + print(f"Used action keys by tag: {tag_to_used_action_keys}") + print(f"Used state keys by tag: {tag_to_used_state_keys}") + + + def _combine_modality_stats(self, modality_stats: dict) -> dict: + """Backward compatibility wrapper.""" + return combine_modality_stats(modality_stats) + + def _generate_action_mask_for_used_keys(self, action_modalities: dict, used_action_keys_ordered) -> list[bool]: + """Backward compatibility wrapper.""" + return generate_action_mask_for_used_keys(action_modalities, used_action_keys_ordered) + + def _get_dataset_counts(self, tag: str) -> dict: + """ + Get dataset count information for specified tag. 
+ + Args: + tag (str): embodiment tag + + Returns: + dict: Dictionary containing num_transitions and num_trajectories + """ + num_transitions = 0 + num_trajectories = 0 + + # Count dataset information belonging to this tag + for dataset in self.datasets: + if dataset.tag == tag: + num_transitions += len(dataset) + num_trajectories += len(dataset.trajectory_ids) + + return { + "num_transitions": num_transitions, + "num_trajectories": num_trajectories + } + + @classmethod + def load_merged_statistics(cls, load_path: Path | str) -> dict: + """ + Load merged dataset statistics from file. + + Args: + load_path (Path | str): Path to the statistics file + + Returns: + dict: Dictionary containing merged statistics + """ + load_path = Path(load_path) + if not load_path.exists(): + raise FileNotFoundError(f"Statistics file not found: {load_path}") + + if load_path.suffix.lower() == '.json': + with open(load_path, 'r', encoding='utf-8') as f: + return json.load(f) + elif load_path.suffix.lower() == '.pkl': + import pickle + with open(load_path, 'rb') as f: + return pickle.load(f) + else: + raise ValueError(f"Unsupported file format: {load_path.suffix}") + + def apply_cached_statistics(self, cached_statistics: dict) -> None: + """ + Apply cached statistics to avoid recomputation. + + Args: + cached_statistics (dict): Statistics loaded from file + """ + # Validate that cached statistics match current datasets + if "metadata" in cached_statistics: + cached_dataset_names = set(cached_statistics["metadata"]["dataset_names"]) + current_dataset_names = set(dataset.dataset_name for dataset in self.datasets) + + if cached_dataset_names != current_dataset_names: + print("Warning: Cached statistics dataset names don't match current datasets.") + print(f"Cached: {cached_dataset_names}") + print(f"Current: {current_dataset_names}") + return + + # Apply cached statistics + self.merged_metadata = {} + for tag, stats_data in cached_statistics.items(): + if tag == "metadata": # Skip metadata field + continue + + # Convert back to DatasetMetadata format + metadata_dict = { + "embodiment_tag": tag, + "statistics": { + "action": {}, + "state": {} + }, + "modalities": {} + } + + # Convert action statistics back + if "action" in stats_data: + action_data = stats_data["action"] + # This is simplified - you may need to split back to sub-keys + metadata_dict["statistics"]["action"] = action_data + + # Convert state statistics back + if "state" in stats_data: + state_data = stats_data["state"] + metadata_dict["statistics"]["state"] = state_data + + self.merged_metadata[tag] = DatasetMetadata.model_validate(metadata_dict) + + # Update transforms metadata for each dataset + for dataset in self.datasets: + if dataset.tag in self.merged_metadata: + dataset.set_transforms_metadata(self.merged_metadata[dataset.tag]) + + print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.") + diff --git a/code/dataloader/gr00t_lerobot/datasets_bak.py b/code/dataloader/gr00t_lerobot/datasets_bak.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a603bd28570a0e0adae01329486ddd63aa3996 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/datasets_bak.py @@ -0,0 +1,2175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +In this file, we define 3 types of datasets: +1. LeRobotSingleDataset: a single dataset for a given embodiment tag +2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags +3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag, + with caching for the video frames + +See `scripts/load_dataset.py` for examples on how to use these datasets. +""" +import os +import hashlib +import json, torch +from collections import defaultdict +from pathlib import Path +from typing import Sequence +import os, random +import numpy as np +import pandas as pd +from pydantic import BaseModel, Field, ValidationError +from torch.utils.data import Dataset +from tqdm import tqdm +from PIL import Image + +from starVLA.dataloader.gr00t_lerobot.video import get_all_frames, get_frames_by_timestamps + +from starVLA.dataloader.gr00t_lerobot.embodiment_tags import EmbodimentTag, DATASET_NAME_TO_ID +from starVLA.dataloader.gr00t_lerobot.schema import ( + DatasetMetadata, + DatasetStatisticalValues, + LeRobotModalityMetadata, + LeRobotStateActionMetadata, +) +from starVLA.dataloader.gr00t_lerobot.transform import ComposedModalityTransform + +from functools import partial +from typing import Tuple, List +import pickle + +# LeRobot v2.0 dataset file names +LE_ROBOT_MODALITY_FILENAME = "meta/modality.json" +LE_ROBOT_EPISODE_FILENAME = "meta/episodes.jsonl" +LE_ROBOT_TASKS_FILENAME = "meta/tasks.jsonl" +LE_ROBOT_INFO_FILENAME = "meta/info.json" +LE_ROBOT_STATS_FILENAME = "meta/stats_gr00t.json" +LE_ROBOT_DATA_FILENAME = "data/*/*.parquet" +LE_ROBOT_STEPS_FILENAME = "meta/steps.pkl" +EPSILON = 5e-4 + +# LeRobot v3.0 dataset file names +LE_ROBOT3_TASKS_FILENAME = "meta/tasks.parquet" +LE_ROBOT3_EPISODE_FILENAME = "meta/episodes/*/*.parquet" + + +# ============================================================================= +# Unified Representation Layout & Helpers +# ============================================================================= + +STANDARD_ACTION_DIM = 37 +# +# Unified action representation layout (0-based indices, Python slice is [start, stop)): +# TIGHT layout: all datasets share the same 29D space for better cross-embodiment transfer. 
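# As a quick illustrative sketch (not part of the original layout notes):
# standardize_action_representation(np.ones(7, dtype=np.float32), "libero_franka")
# returns a 37D vector that is 1.0 on dims 7..13 (the single-arm right_arm slot) and
# 0.0 everywhere else. The slot assignments are: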
+# +# - 0:7 -> left_arm (7D): xyz, rpy/euler, gripper +# Used by: robotwin left arm; gr1 left_arm +# - 7:14 -> right_arm (7D): same structure +# Used by: libero, bridge, fractal(rt1), oxe_droid (single-arm -> right slot); +# robotwin right arm; gr1 right_arm +# - 14:20 -> left_hand (6D): gr1 only +# - 20:26 -> right_hand (6D): gr1 only +# - 26:29 -> waist (3D): gr1 only +# - 29:37 -> joints + gripper (8D): real_world_franka only +# +# Mapping: +# libero/bridge/fractal/oxe_droid (7D) -> [7:14] (right_arm slot, single-arm default) +# robotwin (14D, left+right) -> [0:14] +# gr1/robocasa (29D) -> [0:29] +# real-world (8D) -> [29:37] (joints + gripper) + +ACTION_REPRESENTATION_SLICES = { + # Single-arm (7D) -> right_arm slot [7:14] (single-arm default to right hand) + "franka": slice(7, 14), + "libero_franka": slice(7, 14), + "oxe_droid": slice(7, 14), + "oxe_rt1": slice(7, 14), + "oxe_bridge": slice(7, 14), + + # Dual-arm (14D) -> left [0:7] + right [7:14] + "dual_arm_franka": slice(0, 14), + "robotwin": slice(0, 14), + + # Humanoid (29D) -> full [0:29], standard vector 30D (index 29 pad 0) + "gr1": slice(0, 29), + "fourier_gr1_arms_waist": slice(0, 29), + + # Real-world (8D) -> [29:37] (joints + gripper) + "real_world_franka": slice(29, 37), + + # Fallback (single-arm -> right slot) + "new_embodiment": slice(7, 14), +} + +STANDARD_STATE_DIM = 88 +# Mapping: +# robotwin (14D) -> [0:14] (left [0:7] + right [7:14]) +# libero/bridge/fractal (8D) -> [14:22] (right slot) +# real-world (8D) -> [22:30] (joints + gripper) +# gr1 (58D after sin/cos) -> [30:88] (isolated, different transform) + +STATE_REPRESENTATION_SLICES = { + # Dual-arm (14D) -> left [0:7] + right [7:14] + "dual_arm_franka": slice(0, 14), + "robotwin": slice(0, 14), + # Single-arm (8D) -> right slot [7:15] (aligned with action right [7:14]) + "franka": slice(14, 22), + "libero_franka": slice(14, 22), + "oxe_droid": slice(14, 22), + "oxe_rt1": slice(14, 22), + "oxe_bridge": slice(14, 22), + # Real-world (8D) -> [22:30] (joints + gripper) + "real_world_franka": slice(22, 30), + # GR1 isolated [30:88] (58D, has StateActionSinCosTransform - different pipeline) + "gr1": slice(30, 88), + # Fallback (single-arm -> right slot) + "new_embodiment": slice(14, 22), +} + + +def standardize_action_representation( + action: np.ndarray, embodiment_tag: str +) -> np.ndarray: + """Map per-robot action to a fixed-size standard action vector.""" + target_slice = ACTION_REPRESENTATION_SLICES.get(embodiment_tag) + + # Fallback to 'new_embodiment' if tag not found, or raise error + if target_slice is None: + if "new_embodiment" in ACTION_REPRESENTATION_SLICES: + target_slice = ACTION_REPRESENTATION_SLICES["new_embodiment"] + else: + raise ValueError( + f"Unknown embodiment tag '{embodiment_tag}' for action mapping. " + f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES)}" + ) + + expected_dim = target_slice.stop - target_slice.start + if action.shape[-1] != expected_dim: + raise ValueError( + f"Action dim mismatch for tag '{embodiment_tag}': " + f"{action.shape[-1]=} vs expected {expected_dim}." 
+ ) + + standard = np.zeros( + (*action.shape[:-1], STANDARD_ACTION_DIM), dtype=action.dtype + ) + standard[..., target_slice] = action + return standard + + +def standardize_state_representation( + state: np.ndarray, embodiment_tag: str +) -> np.ndarray: + """Map per-robot state to a fixed-size standard state vector.""" + + target_slice = STATE_REPRESENTATION_SLICES.get(embodiment_tag) + + # Fallback to 'new_embodiment' if tag not found, or raise error + if target_slice is None: + if "new_embodiment" in STATE_REPRESENTATION_SLICES: + target_slice = STATE_REPRESENTATION_SLICES["new_embodiment"] + else: + raise ValueError( + f"Unknown embodiment tag '{embodiment_tag}' for state mapping. " + f"Known tags: {sorted(STATE_REPRESENTATION_SLICES)}" + ) + + expected_dim = target_slice.stop - target_slice.start + if state.shape[-1] != expected_dim: + raise ValueError( + f"State dim mismatch for tag '{embodiment_tag}': " + f"{state.shape[-1]=} vs expected {expected_dim}." + ) + + standard = np.zeros( + (*state.shape[:-1], STANDARD_STATE_DIM), dtype=state.dtype + ) + standard[..., target_slice] = state + return standard + + +def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict: + """Calculate the dataset statistics of all columns for a list of parquet files.""" + # Dataset statistics + all_low_dim_data_list = [] + # Collect all the data + # parquet_paths = parquet_paths[:3] + for parquet_path in tqdm( + sorted(list(parquet_paths)), + desc="Collecting all parquet files...", + ): + # Load the parquet file + parquet_data = pd.read_parquet(parquet_path) + parquet_data = parquet_data + all_low_dim_data_list.append(parquet_data) + + all_low_dim_data = pd.concat(all_low_dim_data_list, axis=0) + # Compute dataset statistics + dataset_statistics = {} + for le_modality in all_low_dim_data.columns: + if le_modality.startswith("annotation."): + continue + print(f"Computing statistics for {le_modality}...") + np_data = np.vstack( + [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]] + ) + dataset_statistics[le_modality] = { + "mean": np.mean(np_data, axis=0).tolist(), + "std": np.std(np_data, axis=0).tolist(), + "min": np.min(np_data, axis=0).tolist(), + "max": np.max(np_data, axis=0).tolist(), + "q01": np.quantile(np_data, 0.01, axis=0).tolist(), + "q99": np.quantile(np_data, 0.99, axis=0).tolist(), + } + return dataset_statistics + + +class ModalityConfig(BaseModel): + """Configuration for a modality.""" + + delta_indices: list[int] + """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices.""" + modality_keys: list[str] + """The keys to load for the modality in the dataset.""" + + +class LeRobotSingleDataset(Dataset): + """ + Base dataset class for LeRobot that supports sharding. + """ + def __init__( + self, + dataset_path: Path | str, + modality_configs: dict[str, ModalityConfig], + embodiment_tag: str | EmbodimentTag, + video_backend: str = "decord", + video_backend_kwargs: dict | None = None, + transforms: ComposedModalityTransform | None = None, + delete_pause_frame: bool = False, + **kwargs, + ): + """ + Initialize the dataset. + + Args: + dataset_path (Path | str): The path to the dataset. + modality_configs (dict[str, ModalityConfig]): The configuration for each modality. The keys are the modality names, and the values are the modality configurations. + See `ModalityConfig` for more details. + video_backend (str): Backend for video reading. 
+ video_backend_kwargs (dict): Keyword arguments for the video backend when initializing the video reader. + transforms (ComposedModalityTransform): The transforms to apply to the dataset. + embodiment_tag (EmbodimentTag): Overload the embodiment tag for the dataset. e.g. define it as "new_embodiment" + """ + # first check if the path directory exists + if not Path(dataset_path).exists(): + raise FileNotFoundError(f"Dataset path {dataset_path} does not exist") + data_cfg = kwargs.get("data_cfg", {}) or {} + # indict letobot version + self._lerobot_version = data_cfg.get("lerobot_version", "v2.0") #self._indict_lerobot_version(**kwargs) + self.load_video = data_cfg.get("load_video", True) + + self.delete_pause_frame = delete_pause_frame + + # If video loading is disabled, skip video modality end-to-end. + if self.load_video: + self.modality_configs = modality_configs + else: + self.modality_configs = { + modality: config + for modality, config in modality_configs.items() + if modality != "video" + } + self.video_backend = video_backend + self.video_backend_kwargs = video_backend_kwargs if video_backend_kwargs is not None else {} + self.transforms = ( + transforms if transforms is not None else ComposedModalityTransform(transforms=[]) + ) + + self._dataset_path = Path(dataset_path) + self._dataset_name = self._dataset_path.name + self._dataset_id = DATASET_NAME_TO_ID.get(self._dataset_name) + if isinstance(embodiment_tag, EmbodimentTag): + self.tag = embodiment_tag.value + else: + self.tag = embodiment_tag + + self._metadata = self._get_metadata(EmbodimentTag(self.tag)) + + # LeRobot-specific config + self._lerobot_modality_meta = self._get_lerobot_modality_meta() + self._lerobot_info_meta = self._get_lerobot_info_meta() + self._data_path_pattern = self._get_data_path_pattern() + self._video_path_pattern = self._get_video_path_pattern() + self._chunk_size = self._get_chunk_size() + self._tasks = self._get_tasks() + self.curr_traj_data = None + self.curr_traj_id = None + + self._trajectory_ids, self._trajectory_lengths = self._get_trajectories() + self._modality_keys = self._get_modality_keys() + self._delta_indices = self._get_delta_indices() + self._all_steps = self._get_all_steps() + self.set_transforms_metadata(self.metadata) + self.set_epoch(0) + + print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}") + + + # Check if the dataset is valid + self._check_integrity() + + @property + def dataset_path(self) -> Path: + """The path to the dataset that contains the METADATA_FILENAME file.""" + return self._dataset_path + + @property + def metadata(self) -> DatasetMetadata: + """The metadata for the dataset, loaded from metadata.json in the dataset directory""" + return self._metadata + + @property + def trajectory_ids(self) -> np.ndarray: + """The trajectory IDs in the dataset, stored as a 1D numpy array of strings.""" + return self._trajectory_ids + + @property + def trajectory_lengths(self) -> np.ndarray: + """The trajectory lengths in the dataset, stored as a 1D numpy array of integers. + The order of the lengths is the same as the order of the trajectory IDs. + """ + return self._trajectory_lengths + + @property + def all_steps(self) -> list[tuple[int, int]]: + """The trajectory IDs and base indices for all steps in the dataset. 
+ Example: + self.trajectory_ids: [0, 1, 2] + self.trajectory_lengths: [3, 2, 4] + return: [ + ("traj_0", 0), ("traj_0", 1), ("traj_0", 2), + ("traj_1", 0), ("traj_1", 1), + ("traj_2", 0), ("traj_2", 1), ("traj_2", 2), ("traj_2", 3) + ] + """ + return self._all_steps + + @property + def modality_keys(self) -> dict: + """The modality keys for the dataset. The keys are the modality names, and the values are the keys for each modality. + + Example: { + "video": ["video.image_side_0", "video.image_side_1"], + "state": ["state.eef_position", "state.eef_rotation"], + "action": ["action.eef_position", "action.eef_rotation"], + "language": ["language.human.task"], + "timestamp": ["timestamp"], + "reward": ["reward"], + } + """ + return self._modality_keys + + @property + def delta_indices(self) -> dict[str, np.ndarray]: + """The delta indices for the dataset. The keys are the modality.key, and the values are the delta indices for each modality.key.""" + return self._delta_indices + + @property + def dataset_name(self) -> str: + """The name of the dataset.""" + return self._dataset_name + + @property + def lerobot_modality_meta(self) -> LeRobotModalityMetadata: + """The metadata for the LeRobot dataset.""" + return self._lerobot_modality_meta + + @property + def lerobot_info_meta(self) -> dict: + """The metadata for the LeRobot dataset.""" + return self._lerobot_info_meta + + @property + def data_path_pattern(self) -> str: + """The path pattern for the LeRobot dataset.""" + return self._data_path_pattern + + @property + def video_path_pattern(self) -> str: + """The path pattern for the LeRobot dataset.""" + return self._video_path_pattern + + @property + def chunk_size(self) -> int: + """The chunk size for the LeRobot dataset.""" + return self._chunk_size + + @property + def tasks(self) -> pd.DataFrame: + """The tasks for the dataset.""" + return self._tasks + + def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata: + """Get the metadata for the dataset. + + Returns: + dict: The metadata for the dataset. + """ + + # 1. Modality metadata + modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME + assert ( + modality_meta_path.exists() + ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}" + # 1.1. State and action modalities + simplified_modality_meta: dict[str, dict] = {} + with open(modality_meta_path, "r") as f: + le_modality_meta = LeRobotModalityMetadata.model_validate(json.load(f)) + for modality in ["state", "action"]: + simplified_modality_meta[modality] = {} + le_state_action_meta: dict[str, LeRobotStateActionMetadata] = getattr( + le_modality_meta, modality + ) + for subkey in le_state_action_meta: + state_action_dtype = np.dtype(le_state_action_meta[subkey].dtype) + if np.issubdtype(state_action_dtype, np.floating): + continuous = True + else: + continuous = False + simplified_modality_meta[modality][subkey] = { + "absolute": le_state_action_meta[subkey].absolute, + "rotation_type": le_state_action_meta[subkey].rotation_type, + "shape": [ + le_state_action_meta[subkey].end - le_state_action_meta[subkey].start + ], + "continuous": continuous, + } + + # 1.2. 
Video modalities + le_info_path = self.dataset_path / LE_ROBOT_INFO_FILENAME + assert ( + le_info_path.exists() + ), f"Please provide a {LE_ROBOT_INFO_FILENAME} file in {self.dataset_path}" + with open(le_info_path, "r") as f: + le_info = json.load(f) + simplified_modality_meta["video"] = {} + for new_key in le_modality_meta.video: + original_key = le_modality_meta.video[new_key].original_key + if original_key is None: + original_key = new_key + le_video_meta = le_info["features"][original_key] + height = le_video_meta["shape"][le_video_meta["names"].index("height")] + width = le_video_meta["shape"][le_video_meta["names"].index("width")] + # NOTE(FH): different lerobot dataset versions have different keys for the number of channels and fps + try: + channels = le_video_meta["shape"][le_video_meta["names"].index("channel")] + fps = le_video_meta["video_info"]["video.fps"] + except (ValueError, KeyError): + # channels = le_video_meta["shape"][le_video_meta["names"].index("channels")] + channels = le_video_meta["info"]["video.channels"] + fps = le_video_meta["info"]["video.fps"] + simplified_modality_meta["video"][new_key] = { + "resolution": [width, height], + "channels": channels, + "fps": fps, + } + + # 2. Dataset statistics + stats_path = self.dataset_path / LE_ROBOT_STATS_FILENAME + try: + with open(stats_path, "r") as f: + le_statistics = json.load(f) + for stat in le_statistics.values(): + DatasetStatisticalValues.model_validate(stat) + except (FileNotFoundError, ValidationError) as e: + print(f"Failed to load dataset statistics: {e}") + print(f"Calculating dataset statistics for {self.dataset_name}") + # Get all parquet files in the dataset paths + parquet_files = list((self.dataset_path).glob(LE_ROBOT_DATA_FILENAME)) + parquet_files_filtered = [] + # parquet_files[0].name = "episode_033675.parquet" is broken file + for pf in parquet_files: + if "episode_033675.parquet" in pf.name: + continue + parquet_files_filtered.append(pf) + + le_statistics = calculate_dataset_statistics(parquet_files_filtered) + with open(stats_path, "w") as f: + json.dump(le_statistics, f, indent=4) + dataset_statistics = {} + for our_modality in ["state", "action"]: + dataset_statistics[our_modality] = {} + for subkey in simplified_modality_meta[our_modality]: + dataset_statistics[our_modality][subkey] = {} + state_action_meta = le_modality_meta.get_key_meta(f"{our_modality}.{subkey}") + assert isinstance(state_action_meta, LeRobotStateActionMetadata) + le_modality = state_action_meta.original_key + for stat_name in le_statistics[le_modality]: + indices = np.arange( + state_action_meta.start, + state_action_meta.end, + ) + stat = np.array(le_statistics[le_modality][stat_name]) + dataset_statistics[our_modality][subkey][stat_name] = stat[indices].tolist() + + # 3. 
Full dataset metadata + metadata = DatasetMetadata( + statistics=dataset_statistics, # type: ignore + modalities=simplified_modality_meta, # type: ignore + embodiment_tag=embodiment_tag, + ) + + return metadata + + def _get_trajectories(self) -> tuple[np.ndarray, np.ndarray]: + """Get the trajectories in the dataset.""" + # Get trajectory lengths, IDs, and whitelist from dataset metadata + # v2.0 + if self._lerobot_version == "v2.0": + file_path = self.dataset_path / LE_ROBOT_EPISODE_FILENAME + with open(file_path, "r") as f: + episode_metadata = [json.loads(line) for line in f] + trajectory_ids = [] + trajectory_lengths = [] + for episode in episode_metadata: + trajectory_ids.append(episode["episode_index"]) + trajectory_lengths.append(episode["length"]) + return np.array(trajectory_ids), np.array(trajectory_lengths) + # v3.0 + elif self._lerobot_version == "v3.0": + file_paths = list((self.dataset_path).glob(LE_ROBOT3_EPISODE_FILENAME)) + trajectory_ids = [] + trajectory_lengths = [] + # data_chunk_index = [] + # data_file_index = [] + # video_from_index = [] + self.trajectory_ids_to_metadata = {} + for file_path in file_paths: + episodes_data = pd.read_parquet(file_path) + for index, episode in episodes_data.iterrows(): + trajectory_ids.append(episode["episode_index"]) + trajectory_lengths.append(episode["length"]) + + # TODO auto-map keys? For now, just map to file_path and file_from_index + episode_meta = { + "data/chunk_index": episode["data/chunk_index"], + "data/file_index": episode["data/file_index"], + "data/file_from_index": index, + } + if self.load_video: + episode_meta["videos/observation.images.wrist/from_timestamp"] = episode[ + "videos/observation.images.wrist/from_timestamp" + ] + self.trajectory_ids_to_metadata[trajectory_ids[-1]] = episode_meta + + # The saved index information should be directly readable here + return np.array(trajectory_ids), np.array(trajectory_lengths) + + def _get_all_steps(self) -> list[tuple[int, int]]: + """Get the trajectory IDs and base indices for all steps in the dataset. + + Returns: + list[tuple[int, int]]: A list of (trajectory_id, base_index) tuples.
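            For example (illustrative): two trajectories with lengths [2, 3] expand to
            [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)], minus any trajectories that are
            skipped because of read errors or empty language instructions.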
+ """ + # Create a hash key based on configuration to ensure cache validity + config_key = self._get_steps_config_key() + + # Create a unique filename based on config_key + # steps_filename = f"steps_{config_key}.pkl" + # @BUG + # fast get static steps @fangjing --> don't use hash to dynamic sample + steps_filename = "steps_data_index.pkl" + + + steps_path = self.dataset_path / "meta" / steps_filename + + # Try to load cached steps first + try: + if steps_path.exists(): + with open(steps_path, "rb") as f: + cached_data = pickle.load(f) + return cached_data["steps"] + + except (FileNotFoundError, pickle.PickleError, KeyError) as e: + print(f"Failed to load cached steps: {e}") + print("Computing steps from scratch...") + + # Compute steps using single process + all_steps = self._get_all_steps_single_process() + + # Cache the computed steps with unique filename + try: + cache_data = { + "config_key": config_key, + "steps": all_steps, + "num_trajectories": len(self.trajectory_ids), + "total_steps": len(all_steps), + "computed_timestamp": pd.Timestamp.now().isoformat(), + "delete_pause_frame": self.delete_pause_frame, + } + + # Ensure the meta directory exists + steps_path.parent.mkdir(parents=True, exist_ok=True) + + with open(steps_path, "wb") as f: + pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL) + print(f"Cached steps saved to {steps_path}") + except Exception as e: + print(f"Failed to cache steps: {e}") + + return all_steps + + def _get_steps_config_key(self) -> str: + """Generate a configuration key for steps caching.""" + config_dict = { + "delete_pause_frame": self.delete_pause_frame, + "dataset_name": self.dataset_name, + } + # Create a hash of the configuration + config_str = str(sorted(config_dict.items())) + return hashlib.md5(config_str.encode()).hexdigest()[:12] # + + + def _get_all_steps_single_process(self) -> list[tuple[int, int]]: + """Original single-process implementation as fallback.""" + all_steps: list[tuple[int, int]] = [] + skipped_trajectories = 0 + processed_trajectories = 0 + + # Check if language modality is configured + has_language_modality = 'language' in self.modality_keys and len(self.modality_keys['language']) > 0 + # TODO why trajectory_length here, why not use data length? 
+ for trajectory_id, trajectory_length in tqdm(zip(self.trajectory_ids, self.trajectory_lengths), total=len(self.trajectory_ids), desc="Getting All Step"): + try: + if self._lerobot_version == "v2.0": + data = self.get_trajectory_data(trajectory_id) + elif self._lerobot_version == "v3.0": + data = self.get_trajectory_data_lerobot_v3(trajectory_id) + + trajectory_skipped = False + + # Check if trajectory has valid language instruction (if language modality is configured) + if has_language_modality: + self.curr_traj_data = data # Set current trajectory data for get_language to work + + language_instruction = self.get_language(trajectory_id, self.modality_keys['language'][0], 0) + if not language_instruction or language_instruction[0] == "": + print(f"Skipping trajectory {trajectory_id} due to empty language instruction") + skipped_trajectories += 1 + trajectory_skipped = True + continue + + except Exception as e: + print(f"Skipping trajectory {trajectory_id} due to read error: {e}") + skipped_trajectories += 1 + trajectory_skipped = True + continue + + if not trajectory_skipped: + processed_trajectories += 1 + + for base_index in range(trajectory_length): + all_steps.append((trajectory_id, base_index)) + + # Print summary statistics + print(f"Single-process summary: Processed {processed_trajectories} trajectories, skipped {skipped_trajectories} empty trajectories") + print(f"Total steps: {len(all_steps)} from {len(self.trajectory_ids)} trajectories") + + return all_steps + + def _get_position_and_gripper_values(self, data: pd.DataFrame) -> tuple[list, list]: + """Get position and gripper values based on available columns in the dataset.""" + # Get action keys from modality_keys + action_keys = self.modality_keys.get('action', []) + + # Extract position data + delta_position_values = None + position_candidates = ['delta_eef_position'] + coordinate_candidates = ['x', 'y', 'z'] + + # First try combined position fields + for pos_key in position_candidates: + full_key = f"action.{pos_key}" + if full_key in action_keys: + try: + # Get the lerobot key for this modality + le_action_cfg = self.lerobot_modality_meta.action + subkey = pos_key + if subkey in le_action_cfg: + le_key = le_action_cfg[subkey].original_key or subkey + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[subkey].start, le_action_cfg[subkey].end) + filtered_data = data_array[:, le_indices] + delta_position_values = filtered_data.tolist() + break + except Exception: + continue + + # If combined fields not found, try individual x,y,z coordinates + if delta_position_values is None: + x_data, y_data, z_data = None, None, None + for coord in coordinate_candidates: + full_key = f"action.{coord}" + if full_key in action_keys: + try: + le_action_cfg = self.lerobot_modality_meta.action + if coord in le_action_cfg: + le_key = le_action_cfg[coord].original_key or coord + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[coord].start, le_action_cfg[coord].end) + coord_data = data_array[:, le_indices].flatten() + if coord == 'x': + x_data = coord_data + elif coord == 'y': + y_data = coord_data + elif coord == 'z': + z_data = coord_data + except Exception: + continue + + if x_data is not None and y_data is not None and z_data is not None: + delta_position_values = np.column_stack((x_data, y_data, z_data)).tolist() + + if delta_position_values is None: + # Fallback to the old hardcoded approach if metadata approach fails + if 
'action.delta_eef_position' in data.columns: + delta_position_values = data['action.delta_eef_position'].to_numpy().tolist() + elif all(col in data.columns for col in ['action.x', 'action.y', 'action.z']): + x_vals = data['action.x'].to_numpy() + y_vals = data['action.y'].to_numpy() + z_vals = data['action.z'].to_numpy() + delta_position_values = np.column_stack((x_vals, y_vals, z_vals)).tolist() + else: + raise ValueError(f"No suitable position columns found. Available columns: {data.columns.tolist()}") + + # Extract gripper data + gripper_values = None + gripper_candidates = ['gripper_close', 'gripper'] + + for grip_key in gripper_candidates: + full_key = f"action.{grip_key}" + if full_key in action_keys: + try: + le_action_cfg = self.lerobot_modality_meta.action + if grip_key in le_action_cfg: + le_key = le_action_cfg[grip_key].original_key or grip_key + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[grip_key].start, le_action_cfg[grip_key].end) + gripper_data = data_array[:, le_indices].flatten() + gripper_values = gripper_data.tolist() + break + except Exception: + continue + + if gripper_values is None: + # Fallback to the old hardcoded approach if metadata approach fails + if 'action.gripper_close' in data.columns: + gripper_values = data['action.gripper_close'].to_numpy().tolist() + elif 'action.gripper' in data.columns: + gripper_values = data['action.gripper'].to_numpy().tolist() + else: + raise ValueError(f"No suitable gripper columns found. Available columns: {data.columns.tolist()}") + + return delta_position_values, gripper_values + + def _get_modality_keys(self) -> dict: + """Get the modality keys for the dataset. + The keys are the modality names, and the values are the keys for each modality. + See property `modality_keys` for the expected format. 
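        For example (an illustrative config, not taken from the repo), passing
        {"state": ModalityConfig(delta_indices=[0], modality_keys=["state.eef_position"])}
        yields {"state": ["state.eef_position"]}.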
+ """ + modality_keys = defaultdict(list) + for modality, config in self.modality_configs.items(): + modality_keys[modality] = config.modality_keys + return modality_keys + + def _get_delta_indices(self) -> dict[str, np.ndarray]: + """Restructure the delta indices to use modality.key as keys instead of just the modalities.""" + delta_indices: dict[str, np.ndarray] = {} + for config in self.modality_configs.values(): + for key in config.modality_keys: + delta_indices[key] = np.array(config.delta_indices) + return delta_indices + + def _get_lerobot_modality_meta(self) -> LeRobotModalityMetadata: + """Get the metadata for the LeRobot dataset.""" + modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME + assert ( + modality_meta_path.exists() + ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}" + with open(modality_meta_path, "r") as f: + modality_meta = LeRobotModalityMetadata.model_validate(json.load(f)) + return modality_meta + + def _get_lerobot_info_meta(self) -> dict: + """Get the metadata for the LeRobot dataset.""" + info_meta_path = self.dataset_path / LE_ROBOT_INFO_FILENAME + with open(info_meta_path, "r") as f: + info_meta = json.load(f) + return info_meta + + def _get_data_path_pattern(self) -> str: + """Get the data path pattern for the LeRobot dataset.""" + return self.lerobot_info_meta["data_path"] + + def _get_video_path_pattern(self) -> str: + """Get the video path pattern for the LeRobot dataset.""" + return self.lerobot_info_meta["video_path"] + + def _get_chunk_size(self) -> int: + """Get the chunk size for the LeRobot dataset.""" + return self.lerobot_info_meta["chunks_size"] + + def _get_tasks(self) -> pd.DataFrame: + """Get the tasks for the dataset.""" + if self._lerobot_version == "v2.0": + tasks_path = self.dataset_path / LE_ROBOT_TASKS_FILENAME + with open(tasks_path, "r") as f: + tasks = [json.loads(line) for line in f] + df = pd.DataFrame(tasks) + return df.set_index("task_index") + + elif self._lerobot_version == "v3.0": + tasks_path = self.dataset_path / LE_ROBOT3_TASKS_FILENAME + df = pd.read_parquet(tasks_path) + df = df.reset_index() # 把索引变成一列,列名通常为 'index' + df = df.rename(columns={'index': 'task'}) # 把 'index' 列重命名为 'task' + df = df[['task_index', 'task']] # 调整列顺序 + return df + def _check_integrity(self): + """Use the config to check if the keys are valid and detect silent data corruption.""" + ERROR_MSG_HEADER = f"Error occurred in initializing dataset {self.dataset_name}:\n" + + for modality_config in self.modality_configs.values(): + for key in modality_config.modality_keys: + if key == "lapa_action" or key == "dream_actions": + continue # no need for any metadata for lapa actions because it comes normalized + # Check if the key is valid + try: + self.lerobot_modality_meta.get_key_meta(key) + except Exception as e: + raise ValueError( + ERROR_MSG_HEADER + f"Unable to find key {key} in modality metadata:\n{e}" + ) + + def set_transforms_metadata(self, metadata: DatasetMetadata): + """Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values.""" + self.transforms.set_metadata(metadata) + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset. + + Args: + epoch (int): The epoch to set. + """ + self.epoch = epoch + + def __len__(self) -> int: + """Get the total number of data points in the dataset. + + Returns: + int: the total number of data points in the dataset. 
+ """ + return len(self.all_steps) + + def __str__(self) -> str: + """Get the description of the dataset.""" + return f"{self.dataset_name} ({len(self)} steps)" + + + def __getitem__(self, index: int) -> dict: + """Get the data for a single step in a trajectory. + + Args: + index (int): The index of the step to get. + + Returns: + dict: The data for the step. + """ + trajectory_id, base_index = self.all_steps[index] + data = self.get_step_data(trajectory_id, base_index) + + # Process all video keys dynamically + images = [] + for video_key in self.modality_keys.get("video", []): + image = data[video_key][0] + + image = Image.fromarray(image).resize((224, 224)) + images.append(image) + + # Get language and action data + language = data[self.modality_keys["language"][0]][0] + action = [] + for action_key in self.modality_keys["action"]: + action.append(data[action_key]) + action = np.concatenate(action, axis=1) + action = standardize_action_representation(action, self.tag) + + state = [] + for state_key in self.modality_keys["state"]: + state.append(data[state_key]) + state = np.concatenate(state, axis=1) + state = standardize_state_representation(state, self.tag) + + return dict(action=action, state=state, image=images, language=language, dataset_id=self._dataset_id) + + def get_step_data(self, trajectory_id: int, base_index: int) -> dict: + """Get the RAW data for a single step in a trajectory. No transforms are applied. + + Args: + trajectory_id (int): The name of the trajectory. + base_index (int): The base step index in the trajectory. + + Returns: + dict: The RAW data for the step. + + Example return: + { + "video": { + "video.image_side_0": [B, T, H, W, C], + "video.image_side_1": [B, T, H, W, C], + }, + "state": { + "state.eef_position": [B, T, state_dim], + "state.eef_rotation": [B, T, state_dim], + }, + "action": { + "action.eef_position": [B, T, action_dim], + "action.eef_rotation": [B, T, action_dim], + }, + } + """ + data = {} + # Get the data for all modalities # just for action base data + self.curr_traj_data = self.get_trajectory_data(trajectory_id) + # TODO @JinhuiYE The logic below is poorly implemented. Data reading should be directly based on curr_traj_data. 
+ for modality in self.modality_keys: + # Get the data corresponding to each key in the modality + for key in self.modality_keys[modality]: + data[key] = self.get_data_by_modality(trajectory_id, modality, key, base_index) + return data + + def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame: + """Get the data for a trajectory.""" + if self._lerobot_version == "v2.0": + + if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None: + return self.curr_traj_data + else: + chunk_index = self.get_episode_chunk(trajectory_id) + parquet_path = self.dataset_path / self.data_path_pattern.format( + episode_chunk=chunk_index, episode_index=trajectory_id + ) + assert parquet_path.exists(), f"Parquet file not found at {parquet_path}" + return pd.read_parquet(parquet_path) + elif self._lerobot_version == "v3.0": + return self.get_trajectory_data_lerobot_v3(trajectory_id) + + def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame: + """Get the data for a trajectory from lerobot v3.""" + if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None: + return self.curr_traj_data + else: #TODO check detail later + chunk_index = self.get_episode_chunk(trajectory_id) + + file_index = self.get_episode_file_index(trajectory_id) + # file_from_index = self.get_episode_file_from_index(trajectory_id) + + + parquet_path = self.dataset_path / self.data_path_pattern.format( + chunk_index=chunk_index, file_index=file_index + ) + assert parquet_path.exists(), f"Parquet file not found at {parquet_path}" + file_data = pd.read_parquet(parquet_path) + + # filter by trajectory_id + episode_data = file_data.loc[file_data["episode_index"] == trajectory_id].copy() + + # fix timestamp from epis index to file index for video alignment + if self.load_video: + from_timestamp = self.trajectory_ids_to_metadata[trajectory_id].get( + "videos/observation.images.wrist/from_timestamp", 0 + ) + episode_data["timestamp"] = episode_data["timestamp"] + from_timestamp + + return episode_data + + + def get_trajectory_index(self, trajectory_id: int) -> int: + """Get the index of the trajectory in the dataset by the trajectory ID. + This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID. + + Args: + trajectory_id (str): The ID of the trajectory. + + Returns: + int: The index of the trajectory in the dataset. + """ + trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0] + if len(trajectory_indices) != 1: + raise ValueError( + f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}" + ) + return trajectory_indices[0] + + def get_episode_chunk(self, ep_index: int) -> int: + """Get the chunk index for an episode index.""" + return ep_index // self.chunk_size + def get_episode_file_index(self, ep_index: int) -> int: + """Get the file index for an episode index.""" + episode_meta = self.trajectory_ids_to_metadata[ep_index] + return episode_meta["data/file_index"] + + def get_episode_file_from_index(self, ep_index: int) -> int: + """Get the file from index for an episode index.""" + episode_meta = self.trajectory_ids_to_metadata[ep_index] + return episode_meta["data/file_from_index"] + + + def retrieve_data_and_pad( + self, + array: np.ndarray, + step_indices: np.ndarray, + max_length: int, + padding_strategy: str = "first_last", + ) -> np.ndarray: + """Retrieve the data from the dataset and pad it if necessary. + Args: + array (np.ndarray): The array to retrieve the data from. 
+ step_indices (np.ndarray): The step indices to retrieve the data for. + max_length (int): The maximum length of the data. + padding_strategy (str): The padding strategy, either "first" or "last". + """ + # Get the padding indices + front_padding_indices = step_indices < 0 + end_padding_indices = step_indices >= max_length + padding_positions = np.logical_or(front_padding_indices, end_padding_indices) + # Retrieve the data with the non-padding indices + # If there exists some padding, Given T step_indices, the shape of the retrieved data will be (T', ...) where T' < T + raw_data = array[step_indices[~padding_positions]] + assert isinstance(raw_data, np.ndarray), f"{type(raw_data)=}" + # This is the shape of the output, (T, ...) + if raw_data.ndim == 1: + expected_shape = (len(step_indices),) + else: + expected_shape = (len(step_indices), *array.shape[1:]) + + # Pad the data + output = np.zeros(expected_shape) + # Assign the non-padded data + output[~padding_positions] = raw_data + # If there exists some padding, pad the data + if padding_positions.any(): + if padding_strategy == "first_last": + # Use first / last step data to pad + front_padding_data = array[0] + end_padding_data = array[-1] + output[front_padding_indices] = front_padding_data + output[end_padding_indices] = end_padding_data + elif padding_strategy == "zero": + # Use zero padding + output[padding_positions] = 0 + else: + raise ValueError(f"Invalid padding strategy: {padding_strategy}") + return output + + def get_video_path(self, trajectory_id: int, key: str) -> Path: + chunk_index = self.get_episode_chunk(trajectory_id) + original_key = self.lerobot_modality_meta.video[key].original_key + if original_key is None: + original_key = key + if self._lerobot_version == "v2.0": + video_filename = self.video_path_pattern.format( + episode_chunk=chunk_index, episode_index=trajectory_id, video_key=original_key + ) + elif self._lerobot_version == "v3.0": + episode_meta = self.trajectory_ids_to_metadata[trajectory_id] + video_filename = self.video_path_pattern.format( + video_key=original_key, + chunk_index=episode_meta["data/chunk_index"], + file_index=episode_meta["data/file_index"], + ) + return self.dataset_path / video_filename + + def get_video( + self, + trajectory_id: int, + key: str, + base_index: int, + ) -> np.ndarray: + """Get the video frames for a trajectory by a base index. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (str): The ID of the trajectory. + key (str): The key of the video. + base_index (int): The base index of the trajectory. + + Returns: + np.ndarray: The video frames for the trajectory and frame indices. 
Shape: (T, H, W, C) + """ + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # print(f"{step_indices=}") + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Ensure the indices are within the valid range + # This is equivalent to padding the video with extra frames at the beginning and end + step_indices = np.maximum(step_indices, 0) + step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1) + assert key.startswith("video."), f"Video key must start with 'video.', got {key}" + # Get the sub-key + key = key.replace("video.", "") + video_path = self.get_video_path(trajectory_id, key) + # Get the action/state timestamps for each frame in the video + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}" + timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy() + # Get the corresponding video timestamps from the step indices + video_timestamp = timestamp[step_indices] + + return get_frames_by_timestamps( + video_path.as_posix(), + video_timestamp, + video_backend=self.video_backend, # TODO + video_backend_kwargs=self.video_backend_kwargs, + ) + + def get_state_or_action( + self, + trajectory_id: int, + modality: str, + key: str, + base_index: int, + ) -> np.ndarray: + """Get the state or action data for a trajectory by a base index. + If the step indices are out of range, pad with the data: + if the data is stored in absolute format, pad with the first or last step data; + otherwise, pad with zero. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + modality (str): The modality of the data. + key (str): The key of the data. + base_index (int): The base index of the trajectory. + + Returns: + np.ndarray: The data for the trajectory and step indices. + """ + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Get the maximum length of the trajectory + max_length = self.trajectory_lengths[trajectory_index] + assert key.startswith(modality + "."), f"{key} must start with {modality + '.'}, got {key}" + # Get the sub-key, e.g. 
state.joint_angles -> joint_angles + key = key.replace(modality + ".", "") + # Get the lerobot key + le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality) + le_key = le_state_or_action_cfg[key].original_key + if le_key is None: + le_key = key + # Get the data array, shape: (T, D) + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}" + data_array: np.ndarray = np.stack(self.curr_traj_data[le_key]) # type: ignore + assert data_array.ndim == 2, f"Expected 2D array, got key {le_key} is{data_array.shape} array" + le_indices = np.arange( + le_state_or_action_cfg[key].start, + le_state_or_action_cfg[key].end, + ) + data_array = data_array[:, le_indices] + # Get the state or action configuration + state_or_action_cfg = getattr(self.metadata.modalities, modality)[key] + + # Pad the data + return self.retrieve_data_and_pad( + array=data_array, + step_indices=step_indices, + max_length=max_length, + padding_strategy="first_last" if state_or_action_cfg.absolute else "zero", + # padding_strategy="zero", # HACK for realdata + ) + + def get_language( + self, + trajectory_id: int, + key: str, + base_index: int, + ) -> list[str]: + """Get the language annotation data for a trajectory by step indices. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + key (str): The key of the annotation. + base_index (int): The base index of the trajectory. + + Returns: + list[str]: The annotation data for the trajectory and step indices. If no matching data is found, return empty strings. + """ + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Get the maximum length of the trajectory + max_length = self.trajectory_lengths[trajectory_index] + # Get the end times corresponding to the closest indices + step_indices = np.maximum(step_indices, 0) + step_indices = np.minimum(step_indices, max_length - 1) + # Get the annotations + task_indices: list[int] = [] + assert key.startswith( + "annotation." + ), f"Language key must start with 'annotation.', got {key}" + subkey = key.replace("annotation.", "") + annotation_meta = self.lerobot_modality_meta.annotation + assert annotation_meta is not None, f"Annotation metadata is None for {subkey}" + assert ( + subkey in annotation_meta + ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}" + subkey_meta = annotation_meta[subkey] + original_key = subkey_meta.original_key + if original_key is None: + original_key = key + for i in range(len(step_indices)): # + # task_indices.append(self.curr_traj_data[original_key][step_indices[i]].item()) + value = self.curr_traj_data[original_key].iloc[step_indices[i]] # TODO check v2.0 + task_indices.append(value if isinstance(value, (int, float)) else value.item()) + + return self.tasks.loc[task_indices]["task"].tolist() + + def get_data_by_modality( + self, + trajectory_id: int, + modality: str, + key: str, + base_index: int, + ): + """Get the data corresponding to the modality for a trajectory by a base index. + This method will call the corresponding helper method based on the modality. + See the helper methods for more details. 
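        For example (illustrative), get_data_by_modality(0, "video", "video.image_side_0", 10)
        simply forwards to get_video(0, "video.image_side_0", 10), while the "state" and
        "action" modalities forward to get_state_or_action.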
+ NOTE: For the language modality, the data is padded with empty strings if no matching data is found. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + modality (str): The modality of the data. + key (str): The key of the data. + base_index (int): The base index of the trajectory. + """ + if modality == "video": + return self.get_video(trajectory_id, key, base_index) + elif modality == "state" or modality == "action": + return self.get_state_or_action(trajectory_id, modality, key, base_index) + elif modality == "language": + return self.get_language(trajectory_id, key, base_index) + else: + raise ValueError(f"Invalid modality: {modality}") + + def _save_dataset_statistics_(self, save_path: Path | str, format: str = "json") -> None: + """ + Save dataset statistics to specified path in the required format. + Only includes statistics for keys that are actually used in the dataset. + Key order follows modality config order. + + Args: + save_path (Path | str): Path to save the statistics file + format (str): Save format, currently only supports "json" + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Build the data structure to save + statistics_data = {} + + # Get used modality keys + used_action_keys, used_state_keys = get_used_modality_keys(self.modality_keys) + + # Organize statistics by tag + tag = self.tag + tag_stats = {} + + # Process action statistics (only for used keys, config order) + if hasattr(self.metadata.statistics, 'action') and self.metadata.statistics.action: + action_stats = self.metadata.statistics.action + filtered_action_stats = { + key: action_stats[key] + for key in used_action_keys + if key in action_stats + } + + if filtered_action_stats: + # Combine statistics from filtered action sub-keys + combined_action_stats = combine_modality_stats(filtered_action_stats) + + # Add mask field based on whether it's gripper or not + mask = generate_action_mask_for_used_keys( + self.metadata.modalities.action, filtered_action_stats.keys() + ) + combined_action_stats["mask"] = mask + + tag_stats["action"] = combined_action_stats + + # Process state statistics (only for used keys, config order) + if hasattr(self.metadata.statistics, 'state') and self.metadata.statistics.state: + state_stats = self.metadata.statistics.state + filtered_state_stats = { + key: state_stats[key] + for key in used_state_keys + if key in state_stats + } + + if filtered_state_stats: + combined_state_stats = combine_modality_stats(filtered_state_stats) + tag_stats["state"] = combined_state_stats + + # Add dataset counts + tag_stats["num_transitions"] = len(self) + tag_stats["num_trajectories"] = len(self.trajectory_ids) + + statistics_data[tag] = tag_stats + + # Save as JSON file + if format.lower() == "json": + if not str(save_path).endswith('.json'): + save_path = save_path.with_suffix('.json') + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(statistics_data, f, indent=2, ensure_ascii=False) + else: + raise ValueError(f"Unsupported format: {format}. 
Currently only 'json' is supported.")
+
+        print(f"Single dataset statistics saved to: {save_path}")
+        print(f"Used action keys (reordered): {list(used_action_keys)}")
+        print(f"Used state keys (reordered): {list(used_state_keys)}")
+
+
+
+class MixtureSpecElement(BaseModel):
+    dataset_path: list[Path] | Path = Field(..., description="The path to the dataset.")
+    dataset_weight: float = Field(..., description="The weight of the dataset in the mixture.")
+    distribute_weights: bool = Field(
+        default=False,
+        description="Whether to distribute the weights of the dataset across all the paths. If True, the weights will be evenly distributed across all the paths.",
+    )
+
+
+# Helper functions for dataset statistics
+
+def combine_modality_stats(modality_stats: dict) -> dict:
+    """
+    Combine statistics from all sub-keys under a modality.
+
+    Args:
+        modality_stats (dict): Statistics for a modality, containing multiple sub-keys.
+            Each sub-key contains a DatasetStatisticalValues object.
+
+    Returns:
+        dict: Combined statistics
+    """
+    combined_stats = {
+        "mean": [],
+        "std": [],
+        "max": [],
+        "min": [],
+        "q01": [],
+        "q99": []
+    }
+
+    # Combine statistics in sub-key order
+    for subkey in modality_stats.keys():
+        subkey_stats = modality_stats[subkey]  # This is a DatasetStatisticalValues object
+
+        # Convert DatasetStatisticalValues to dict-like access
+        for stat_name in ["mean", "std", "max", "min", "q01", "q99"]:
+            stat_value = getattr(subkey_stats, stat_name)
+            if isinstance(stat_value, (list, tuple)):
+                combined_stats[stat_name].extend(stat_value)
+            else:
+                # Handle NDArray case - convert to list
+                if hasattr(stat_value, 'tolist'):
+                    combined_stats[stat_name].extend(stat_value.tolist())
+                else:
+                    combined_stats[stat_name].append(float(stat_value))
+
+    return combined_stats
+
+def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]:
+    """
+    Generate a normalization mask for the used action keys, in the same order as the combined statistics.
+    Gripper-related dimensions are set to False (excluded from de/normalization); all other dimensions are set to True.
+
+    Args:
+        action_modalities (dict): Configuration information for action modalities.
+        used_action_keys_ordered: Iterable of actually used action keys in the correct order.
+
+    Returns:
+        list[bool]: List of mask values
+    """
+    mask = []
+
+    # Generate mask in the same order as the statistics were combined
+    for subkey in used_action_keys_ordered:
+        if subkey in action_modalities:
+            subkey_config = action_modalities[subkey]
+
+            # Get dimension count from shape
+            if hasattr(subkey_config, 'shape') and len(subkey_config.shape) > 0:
+                dim_count = subkey_config.shape[0]
+            else:
+                dim_count = 1
+
+            # Check if it's gripper-related
+            is_gripper = "gripper" in subkey.lower()
+
+            # Generate mask value for each dimension
+            for _ in range(dim_count):
+                mask.append(not is_gripper)  # gripper is False, others are True
+
+    return mask
+
+def get_used_modality_keys(modality_keys: dict) -> tuple[list, list]:
+    """Extract the used action and state keys (in config order) from the modality configuration."""
+    used_action_keys = []
+    used_state_keys = []
+
+    # Extract action keys (remove "action." prefix)
+    for action_key in modality_keys.get("action", []):
+        if action_key.startswith("action."):
+            clean_key = action_key.replace("action.", "")
+            used_action_keys.append(clean_key)
+
+    # Extract state keys (remove "state."
prefix) + for state_key in modality_keys.get("state", []): + if state_key.startswith("state."): + clean_key = state_key.replace("state.", "") + used_state_keys.append(clean_key) + + return used_action_keys, used_state_keys + + +def safe_hash(input_tuple): + # keep 128 bits of the hash + tuple_string = repr(input_tuple).encode("utf-8") + sha256 = hashlib.sha256() + sha256.update(tuple_string) + + seed = int(sha256.hexdigest(), 16) + + return seed & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + + +class LeRobotMixtureDataset(Dataset): + """ + A mixture of multiple datasets. This class samples a single dataset based on the dataset weights and then calls the `__getitem__` method of the sampled dataset. + It is recommended to modify the single dataset class instead of this class. + """ + + def __init__( + self, + data_mixture: Sequence[tuple[LeRobotSingleDataset, float]], + mode: str, + balance_dataset_weights: bool = True, + balance_trajectory_weights: bool = True, + seed: int = 42, + metadata_config: dict = { + "percentile_mixing_method": "min_max", + }, + **kwargs, + ): + """ + Initialize the mixture dataset. + + Args: + data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights. + mode (str): If "train", __getitem__ will return different samples every epoch; if "val" or "test", __getitem__ will return the same sample every epoch. + balance_dataset_weights (bool): If True, the weight of dataset will be multiplied by the total trajectory length of each dataset. + balance_trajectory_weights (bool): If True, sample trajectories within a dataset weighted by their length; otherwise, use equal weighting. + seed (int): Random seed for sampling. + """ + datasets: list[LeRobotSingleDataset] = [] + dataset_sampling_weights: list[float] = [] + for dataset, weight in data_mixture: + # Check if dataset is valid and has data + if len(dataset) == 0: + print(f"Warning: Skipping empty dataset {dataset.dataset_name}") + continue + datasets.append(dataset) + dataset_sampling_weights.append(weight) + + if len(datasets) == 0: + raise ValueError("No valid datasets found in the mixture. All datasets are empty.") + + self.datasets = datasets + self.balance_dataset_weights = balance_dataset_weights + self.balance_trajectory_weights = balance_trajectory_weights + self.seed = seed + self.mode = mode + + # Set properties for sampling + + # 1. Dataset lengths + self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets]) + print(f"Dataset lengths: {self._dataset_lengths}") + + # 2. Dataset sampling weights + self._dataset_sampling_weights = np.array(dataset_sampling_weights) + + if self.balance_dataset_weights: + self._dataset_sampling_weights *= self._dataset_lengths + + # Check for zero or negative weights before normalization + if np.any(self._dataset_sampling_weights <= 0): + print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}") + # Set minimum weight to prevent division issues + self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8) + + # Normalize weights + weights_sum = self._dataset_sampling_weights.sum() + if weights_sum == 0 or np.isnan(weights_sum): + print(f"Error: Invalid weights sum: {weights_sum}") + # Fallback to equal weights + self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets) + print(f"Fallback to equal weights") + else: + self._dataset_sampling_weights /= weights_sum + + # 3. 
Trajectory sampling weights + self._trajectory_sampling_weights: list[np.ndarray] = [] + for i, dataset in enumerate(self.datasets): + trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) + if self.balance_trajectory_weights: + trajectory_sampling_weights *= dataset.trajectory_lengths + + # Check for zero or negative weights before normalization + if np.any(trajectory_sampling_weights <= 0): + print(f"Warning: Dataset {i} has zero or negative trajectory weights") + trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8) + + # Normalize weights + weights_sum = trajectory_sampling_weights.sum() + if weights_sum == 0 or np.isnan(weights_sum): + print(f"Error: Dataset {i} has invalid trajectory weights sum: {weights_sum}") + # Fallback to equal weights + trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths) + else: + trajectory_sampling_weights /= weights_sum + + self._trajectory_sampling_weights.append(trajectory_sampling_weights) + + # 4. Primary dataset indices + self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0 + if not np.any(self._primary_dataset_indices): + print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}") + # Fallback: use the dataset(s) with maximum weight as primary + max_weight = max(dataset_sampling_weights) + self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight + print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}") + + if not np.any(self._primary_dataset_indices): + # This should never happen, but just in case + print("Error: Still no primary dataset found. Using first dataset as primary.") + self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool) + self._primary_dataset_indices[0] = True + + # Set the epoch and sample the first epoch + self.set_epoch(0) + + self.update_metadata(metadata_config) + + @property + def dataset_lengths(self) -> np.ndarray: + """The lengths of each dataset.""" + return self._dataset_lengths + + @property + def dataset_sampling_weights(self) -> np.ndarray: + """The sampling weights for each dataset.""" + return self._dataset_sampling_weights + + @property + def trajectory_sampling_weights(self) -> list[np.ndarray]: + """The sampling weights for each trajectory in each dataset.""" + return self._trajectory_sampling_weights + + @property + def primary_dataset_indices(self) -> np.ndarray: + """The indices of the primary datasets.""" + return self._primary_dataset_indices + + def __str__(self) -> str: + dataset_descriptions = [] + for dataset, weight in zip(self.datasets, self.dataset_sampling_weights): + dataset_description = { + "Dataset": str(dataset), + "Sampling weight": float(weight), + } + dataset_descriptions.append(dataset_description) + return json.dumps({"Mixture dataset": dataset_descriptions}, indent=2) + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset. + + Args: + epoch (int): The epoch to set. 
+ """ + self.epoch = epoch + # self.sampled_steps = self.sample_epoch() + + def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]: + """Sample a single step from the dataset.""" + # return self.sampled_steps[index] + + # Set seed + seed = index if self.mode != "train" else safe_hash((self.epoch, index, self.seed)) + rng = np.random.default_rng(seed) + + # Sample dataset + dataset_index = rng.choice(len(self.datasets), p=self.dataset_sampling_weights) + dataset = self.datasets[dataset_index] + + # Sample trajectory + # trajectory_index = rng.choice( + # len(dataset.trajectory_ids), p=self.trajectory_sampling_weights[dataset_index] + # ) + # trajectory_id = dataset.trajectory_ids[trajectory_index] + + # # Sample step + # base_index = rng.choice(dataset.trajectory_lengths[trajectory_index]) + # return dataset, trajectory_id, base_index + single_step_index = rng.choice(len(dataset.all_steps)) + trajectory_id, base_index = dataset.all_steps[single_step_index] + return dataset, trajectory_id, base_index + + def __getitem__(self, index: int) -> dict: + """Get the data for a single trajectory and start index. + + Args: + index (int): The index of the trajectory to get. + + Returns: + dict: The data for the trajectory and start index. + """ + max_retries = 10 + last_exception = None + + for attempt in range(max_retries): + try: + dataset, trajectory_name, step = self.sample_step(index) + data_raw = dataset.get_step_data(trajectory_name, step) + data = dataset.transforms(data_raw) + + # Process all video keys dynamically + images = [] + for video_key in dataset.modality_keys.get("video", []): + image = data[video_key][0] + + image = Image.fromarray(image).resize((224, 224)) #TODO check if this is ok + images.append(image) + + # Get language and action data + language = data[dataset.modality_keys["language"][0]][0] + action = [] + for action_key in dataset.modality_keys["action"]: + action.append(data[action_key]) + action = np.concatenate(action, axis=1).astype(np.float16) + action = standardize_action_representation(action, dataset.tag) + + state = [] + for state_key in dataset.modality_keys["state"]: + state.append(data[state_key]) + state = np.concatenate(state, axis=1).astype(np.float16) + state = standardize_state_representation(state, dataset.tag) + + return dict(action=action, state=state, image=images, lang=language, dataset_id=dataset._dataset_id) + + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + # Log the error but continue trying + print(f"Attempt {attempt + 1}/{max_retries} failed for index {index}: {e}") + print(f"Retrying with new sample...") + # For retry, we can use a slightly different index to get a new sample + # This helps avoid getting stuck on the same problematic sample + index = random.randint(0, len(self) - 1) + else: + # All retries exhausted + print(f"All {max_retries} attempts failed for index {index}") + print(f"Last error: {last_exception}") + # Return a dummy sample or re-raise the exception + raise last_exception + + def __len__(self) -> int: + """Get the length of a single epoch in the mixture. + + Returns: + int: The length of a single epoch in the mixture. 
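+
+        Note:
+            The length is derived from the "primary" datasets (those with sampling weight 1.0,
+            or the maximum weight as a fallback): it is the largest len(dataset) / sampling_weight
+            over those datasets, i.e. roughly the number of draws needed to cover each primary
+            dataset once in expectation.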
+ """ + # Check for potential issues + if len(self.datasets) == 0: + return 0 + + # Check if any dataset lengths are 0 or NaN + if np.any(self.dataset_lengths == 0) or np.any(np.isnan(self.dataset_lengths)): + print(f"Warning: Found zero or NaN dataset lengths: {self.dataset_lengths}") + # Filter out zero/NaN length datasets + valid_indices = (self.dataset_lengths > 0) & (~np.isnan(self.dataset_lengths)) + if not np.any(valid_indices): + print("Error: All datasets have zero or NaN length") + return 0 + else: + valid_indices = np.ones(len(self.datasets), dtype=bool) + + # Check if any sampling weights are 0 or NaN + if np.any(self.dataset_sampling_weights == 0) or np.any(np.isnan(self.dataset_sampling_weights)): + print(f"Warning: Found zero or NaN sampling weights: {self.dataset_sampling_weights}") + # Use only valid weights + valid_weights = (self.dataset_sampling_weights > 0) & (~np.isnan(self.dataset_sampling_weights)) + valid_indices = valid_indices & valid_weights + if not np.any(valid_indices): + print("Error: All sampling weights are zero or NaN") + return 0 + + # Check primary dataset indices + primary_and_valid = self.primary_dataset_indices & valid_indices + if not np.any(primary_and_valid): + print(f"Warning: No valid primary datasets found. Primary indices: {self.primary_dataset_indices}, Valid indices: {valid_indices}") + # Fallback: use the largest valid dataset + if np.any(valid_indices): + max_length = self.dataset_lengths[valid_indices].max() + print(f"Fallback: Using maximum dataset length: {max_length}") + return int(max_length) + else: + return 0 + + # Calculate the ratio and get max + ratios = (self.dataset_lengths / self.dataset_sampling_weights)[primary_and_valid] + + # Check for NaN or inf in ratios + if np.any(np.isnan(ratios)) or np.any(np.isinf(ratios)): + print(f"Warning: Found NaN or inf in ratios: {ratios}") + print(f"Dataset lengths: {self.dataset_lengths[primary_and_valid]}") + print(f"Sampling weights: {self.dataset_sampling_weights[primary_and_valid]}") + # Filter out invalid ratios + valid_ratios = ratios[~np.isnan(ratios) & ~np.isinf(ratios)] + if len(valid_ratios) == 0: + print("Error: All ratios are NaN or inf") + return 0 + max_ratio = valid_ratios.max() + else: + max_ratio = ratios.max() + + result = int(max_ratio) + if result == 0: + print(f"Warning: Dataset mixture length is 0") + return result + + @staticmethod + def compute_overall_statistics( + per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]], + dataset_sampling_weights: list[float] | np.ndarray, + percentile_mixing_method: str = "weighted_average", + ) -> dict[str, dict[str, list[float]]]: + """ + Computes overall statistics from per-task statistics using dataset sample weights. + + Args: + per_task_stats: List of per-task statistics. + Example format of one element in the per-task statistics list: + { + "state.gripper": { + "min": [...], + "max": [...], + "mean": [...], + "std": [...], + "q01": [...], + "q99": [...], + }, + ... + } + dataset_sampling_weights: List of sample weights for each task. + percentile_mixing_method: The method to mix the percentiles, either "weighted_average" or "weighted_std". + + Returns: + A dict of overall statistics per modality. 
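+
+        Note:
+            Means and variances are combined as a weighted mixture:
+            mean = sum_i w_i * mean_i and var = sum_i w_i * (std_i**2 + mean_i**2) - mean**2,
+            with the weights normalized to sum to 1. Min/max are taken elementwise across tasks,
+            and q01/q99 are combined according to `percentile_mixing_method`.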
+ """ + # Normalize the sample weights to sum to 1 + dataset_sampling_weights = np.array(dataset_sampling_weights) + normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum() + + # Initialize overall statistics dict + overall_stats: dict[str, dict[str, list[float]]] = {} + + # Get the list of modality keys + modality_keys = per_task_stats[0].keys() + + for modality in modality_keys: + # Number of dimensions (assuming consistent across tasks) + num_dims = len(per_task_stats[0][modality]["mean"]) + + # Initialize accumulators for means and variances + weighted_means = np.zeros(num_dims) + weighted_squares = np.zeros(num_dims) + + # Collect min, max, q01, q99 from all tasks + min_list = [] + max_list = [] + q01_list = [] + q99_list = [] + + for task_idx, task_stats in enumerate(per_task_stats): + w_i = normalized_weights[task_idx] + stats = task_stats[modality] + means = np.array(stats["mean"]) + stds = np.array(stats["std"]) + + # Update weighted sums for mean and variance + weighted_means += w_i * means + weighted_squares += w_i * (stds**2 + means**2) + + # Collect min, max, q01, q99 + min_list.append(stats["min"]) + max_list.append(stats["max"]) + q01_list.append(stats["q01"]) + q99_list.append(stats["q99"]) + + # Compute overall mean + overall_mean = weighted_means.tolist() + + # Compute overall variance and std deviation + overall_variance = weighted_squares - weighted_means**2 + overall_std = np.sqrt(overall_variance).tolist() + + # Compute overall min and max per dimension + overall_min = np.min(np.array(min_list), axis=0).tolist() + overall_max = np.max(np.array(max_list), axis=0).tolist() + + # Compute overall q01 and q99 per dimension + # Use weighted average of per-task quantiles + q01_array = np.array(q01_list) + q99_array = np.array(q99_list) + if percentile_mixing_method == "weighted_average": + weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist() + weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist() + # std_q01 = np.std(q01_array, axis=0).tolist() + # std_q99 = np.std(q99_array, axis=0).tolist() + # print(modality) + # print(f"{std_q01=}, {std_q99=}") + # print(f"{weighted_q01=}, {weighted_q99=}") + elif percentile_mixing_method == "min_max": + weighted_q01 = np.min(q01_array, axis=0).tolist() + weighted_q99 = np.max(q99_array, axis=0).tolist() + else: + raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}") + + # Store the overall statistics for the modality + overall_stats[modality] = { + "min": overall_min, + "max": overall_max, + "mean": overall_mean, + "std": overall_std, + "q01": weighted_q01, + "q99": weighted_q99, + } + + return overall_stats + + @staticmethod + def merge_metadata( + metadatas: list[DatasetMetadata], + dataset_sampling_weights: list[float], + percentile_mixing_method: str, + ) -> DatasetMetadata: + """Merge multiple metadata into one.""" + # Convert to dicts + metadata_dicts = [metadata.model_dump(mode="json") for metadata in metadatas] + # Create a new metadata dict + merged_metadata = {} + + # Check all metadata have the same embodiment tag + assert all( + metadata.embodiment_tag == metadatas[0].embodiment_tag for metadata in metadatas + ), "All metadata must have the same embodiment tag" + merged_metadata["embodiment_tag"] = metadatas[0].embodiment_tag + + # Merge the dataset statistics + dataset_statistics = {} + dataset_statistics["state"] = LeRobotMixtureDataset.compute_overall_statistics( + per_task_stats=[m["statistics"]["state"] for m in 
metadata_dicts], + dataset_sampling_weights=dataset_sampling_weights, + percentile_mixing_method=percentile_mixing_method, + ) + dataset_statistics["action"] = LeRobotMixtureDataset.compute_overall_statistics( + per_task_stats=[m["statistics"]["action"] for m in metadata_dicts], + dataset_sampling_weights=dataset_sampling_weights, + percentile_mixing_method=percentile_mixing_method, + ) + merged_metadata["statistics"] = dataset_statistics + + # Merge the modality configs + modality_configs = defaultdict(set) + for metadata in metadata_dicts: + for modality, configs in metadata["modalities"].items(): + modality_configs[modality].add(json.dumps(configs)) + merged_metadata["modalities"] = {} + for modality, configs in modality_configs.items(): + # Check that all modality configs correspond to the same tag matches + assert ( + len(configs) == 1 + ), f"Multiple modality configs for modality {modality}: {list(configs)}" + merged_metadata["modalities"][modality] = json.loads(configs.pop()) + + return DatasetMetadata.model_validate(merged_metadata) + + def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None: + """ + Merge multiple metadatas into one and set the transforms with the merged metadata. + + Args: + metadata_config (dict): Configuration for the metadata. + "percentile_mixing_method": The method to mix the percentiles, either "weighted_average" or "min_max". + weighted_average: Use the weighted average of the percentiles using the weight used in sampling the datasets. + min_max: Use the min of the 1st percentile and max of the 99th percentile. + """ + # If cached path is provided, try to load and apply + if cached_statistics_path is not None: + try: + cached_stats = self.load_merged_statistics(cached_statistics_path) + self.apply_cached_statistics(cached_stats) + return + except (FileNotFoundError, KeyError, ValidationError) as e: + print(f"Failed to load cached statistics: {e}") + print("Falling back to computing statistics from scratch...") + + self.tag = EmbodimentTag.NEW_EMBODIMENT.value + self.merged_metadata: dict[str, DatasetMetadata] = {} + # Group metadata by tag + all_metadatas: dict[str, list[DatasetMetadata]] = {} + for dataset in self.datasets: + if dataset.tag not in all_metadatas: + all_metadatas[dataset.tag] = [] + all_metadatas[dataset.tag].append(dataset.metadata) + for tag, metadatas in all_metadatas.items(): + self.merged_metadata[tag] = self.merge_metadata( + metadatas=metadatas, + dataset_sampling_weights=self.dataset_sampling_weights.tolist(), + percentile_mixing_method=metadata_config["percentile_mixing_method"], + ) + for dataset in self.datasets: + dataset.set_transforms_metadata(self.merged_metadata[dataset.tag]) + + def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None: + """ + Save merged dataset statistics to specified path in the required format. + Only includes statistics for keys that are actually used in the datasets. + Key order follows each tag's modality config order. 
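+
+        The resulting JSON has one entry per embodiment tag, for example:
+            {
+                "<tag>": {
+                    "action": {"mean": [...], "std": [...], "max": [...], "min": [...],
+                               "q01": [...], "q99": [...], "mask": [...]},
+                    "state": {"mean": [...], "std": [...], "max": [...], "min": [...],
+                              "q01": [...], "q99": [...]},
+                    "num_transitions": <int>,
+                    "num_trajectories": <int>
+                }
+            }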
+ + Args: + save_path (Path | str): Path to save the statistics file + format (str): Save format, currently only supports "json" + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Build the data structure to save + statistics_data = {} + + # Keep key orders per embodiment tag (from modality config order) + tag_to_used_action_keys = {} + tag_to_used_state_keys = {} + for dataset in self.datasets: + if dataset.tag in tag_to_used_action_keys: + continue + used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys) + tag_to_used_action_keys[dataset.tag] = used_action_keys + tag_to_used_state_keys[dataset.tag] = used_state_keys + + # Organize statistics by tag + for tag, merged_metadata in self.merged_metadata.items(): + tag_stats = {} + + # Process action statistics + if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action: + action_stats = merged_metadata.statistics.action + + used_action_keys = tag_to_used_action_keys.get(tag, []) + filtered_action_stats = { + key: action_stats[key] + for key in used_action_keys + if key in action_stats + } + + if filtered_action_stats: + combined_action_stats = combine_modality_stats(filtered_action_stats) + + mask = generate_action_mask_for_used_keys( + merged_metadata.modalities.action, filtered_action_stats.keys() + ) + combined_action_stats["mask"] = mask + + tag_stats["action"] = combined_action_stats + + # Process state statistics + if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state: + state_stats = merged_metadata.statistics.state + + used_state_keys = tag_to_used_state_keys.get(tag, []) + filtered_state_stats = { + key: state_stats[key] + for key in used_state_keys + if key in state_stats + } + + if filtered_state_stats: + combined_state_stats = combine_modality_stats(filtered_state_stats) + tag_stats["state"] = combined_state_stats + + # Add dataset counts + tag_stats.update(self._get_dataset_counts(tag)) + + statistics_data[tag] = tag_stats + + # Save file + if format.lower() == "json": + if not str(save_path).endswith('.json'): + save_path = save_path.with_suffix('.json') + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(statistics_data, f, indent=2, ensure_ascii=False) + else: + raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.") + + print(f"Merged dataset statistics saved to: {save_path}") + print(f"Used action keys by tag: {tag_to_used_action_keys}") + print(f"Used state keys by tag: {tag_to_used_state_keys}") + + + def _combine_modality_stats(self, modality_stats: dict) -> dict: + """Backward compatibility wrapper.""" + return combine_modality_stats(modality_stats) + + def _generate_action_mask_for_used_keys(self, action_modalities: dict, used_action_keys_ordered) -> list[bool]: + """Backward compatibility wrapper.""" + return generate_action_mask_for_used_keys(action_modalities, used_action_keys_ordered) + + def _get_dataset_counts(self, tag: str) -> dict: + """ + Get dataset count information for specified tag. 
+ + Args: + tag (str): embodiment tag + + Returns: + dict: Dictionary containing num_transitions and num_trajectories + """ + num_transitions = 0 + num_trajectories = 0 + + # Count dataset information belonging to this tag + for dataset in self.datasets: + if dataset.tag == tag: + num_transitions += len(dataset) + num_trajectories += len(dataset.trajectory_ids) + + return { + "num_transitions": num_transitions, + "num_trajectories": num_trajectories + } + + @classmethod + def load_merged_statistics(cls, load_path: Path | str) -> dict: + """ + Load merged dataset statistics from file. + + Args: + load_path (Path | str): Path to the statistics file + + Returns: + dict: Dictionary containing merged statistics + """ + load_path = Path(load_path) + if not load_path.exists(): + raise FileNotFoundError(f"Statistics file not found: {load_path}") + + if load_path.suffix.lower() == '.json': + with open(load_path, 'r', encoding='utf-8') as f: + return json.load(f) + elif load_path.suffix.lower() == '.pkl': + import pickle + with open(load_path, 'rb') as f: + return pickle.load(f) + else: + raise ValueError(f"Unsupported file format: {load_path.suffix}") + + def apply_cached_statistics(self, cached_statistics: dict) -> None: + """ + Apply cached statistics to avoid recomputation. + + Args: + cached_statistics (dict): Statistics loaded from file + """ + # Validate that cached statistics match current datasets + if "metadata" in cached_statistics: + cached_dataset_names = set(cached_statistics["metadata"]["dataset_names"]) + current_dataset_names = set(dataset.dataset_name for dataset in self.datasets) + + if cached_dataset_names != current_dataset_names: + print("Warning: Cached statistics dataset names don't match current datasets.") + print(f"Cached: {cached_dataset_names}") + print(f"Current: {current_dataset_names}") + return + + # Apply cached statistics + self.merged_metadata = {} + for tag, stats_data in cached_statistics.items(): + if tag == "metadata": # Skip metadata field + continue + + # Convert back to DatasetMetadata format + metadata_dict = { + "embodiment_tag": tag, + "statistics": { + "action": {}, + "state": {} + }, + "modalities": {} + } + + # Convert action statistics back + if "action" in stats_data: + action_data = stats_data["action"] + # This is simplified - you may need to split back to sub-keys + metadata_dict["statistics"]["action"] = action_data + + # Convert state statistics back + if "state" in stats_data: + state_data = stats_data["state"] + metadata_dict["statistics"]["state"] = state_data + + self.merged_metadata[tag] = DatasetMetadata.model_validate(metadata_dict) + + # Update transforms metadata for each dataset + for dataset in self.datasets: + if dataset.tag in self.merged_metadata: + dataset.set_transforms_metadata(self.merged_metadata[dataset.tag]) + + print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.") + diff --git a/code/dataloader/gr00t_lerobot/datasets_bak2.py b/code/dataloader/gr00t_lerobot/datasets_bak2.py new file mode 100644 index 0000000000000000000000000000000000000000..43da9dc9614fcc36b3794695dc6a0b0d36cf7162 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/datasets_bak2.py @@ -0,0 +1,2145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +In this file, we define 3 types of datasets: +1. LeRobotSingleDataset: a single dataset for a given embodiment tag +2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags +3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag, + with caching for the video frames + +See `scripts/load_dataset.py` for examples on how to use these datasets. +""" +import os +import hashlib +import json, torch +from collections import defaultdict +from pathlib import Path +from typing import Sequence +import os, random +import numpy as np +import pandas as pd +from pydantic import BaseModel, Field, ValidationError +from torch.utils.data import Dataset +from tqdm import tqdm +from PIL import Image + +from starVLA.dataloader.gr00t_lerobot.video import get_all_frames, get_frames_by_timestamps + +from starVLA.dataloader.gr00t_lerobot.embodiment_tags import EmbodimentTag, DATASET_NAME_TO_ID +from starVLA.dataloader.gr00t_lerobot.schema import ( + DatasetMetadata, + DatasetStatisticalValues, + LeRobotModalityMetadata, + LeRobotStateActionMetadata, +) +from starVLA.dataloader.gr00t_lerobot.transform import ComposedModalityTransform + +from functools import partial +from typing import Tuple, List +import pickle + +# LeRobot v2.0 dataset file names +LE_ROBOT_MODALITY_FILENAME = "meta/modality.json" +LE_ROBOT_EPISODE_FILENAME = "meta/episodes.jsonl" +LE_ROBOT_TASKS_FILENAME = "meta/tasks.jsonl" +LE_ROBOT_INFO_FILENAME = "meta/info.json" +LE_ROBOT_STATS_FILENAME = "meta/stats_gr00t.json" +LE_ROBOT_DATA_FILENAME = "data/*/*.parquet" +LE_ROBOT_STEPS_FILENAME = "meta/steps.pkl" +EPSILON = 5e-4 + +# LeRobot v3.0 dataset file names +LE_ROBOT3_TASKS_FILENAME = "meta/tasks.parquet" +LE_ROBOT3_EPISODE_FILENAME = "meta/episodes/*/*.parquet" + + +# ============================================================================= +# Unified Representation Layout & Helpers +# ============================================================================= + +STANDARD_ACTION_DIM = 37 +# +# Unified action representation layout (0-based indices, Python slice is [start, stop)): +# Keep only: libero_franka, gr1, real_world_franka. 
+# +# - 0:7 -> left_arm (7D): xyz, rpy/euler, gripper +# Used by: gr1 left_arm +# - 7:14 -> right_arm (7D): same structure +# Used by: libero_franka; gr1 right_arm +# - 14:20 -> left_hand (6D): gr1 only +# - 20:26 -> right_hand (6D): gr1 only +# - 26:29 -> waist (3D): gr1 only +# - 29:37 -> joints + gripper (8D): real_world_franka only +# +# Mapping: +# libero_franka (7D) -> [7:14] (right_arm slot) +# gr1 (29D) -> [0:29] +# real_world_franka (8D) -> [29:37] (joints + gripper) + +ACTION_REPRESENTATION_SLICES = { + # Single-arm (7D) -> right_arm slot [7:14] + "franka": slice(7, 14), + + # Humanoid (29D) -> full [0:29] + "gr1": slice(0, 29), + + # Real-world (8D) -> [29:37] (joints + gripper) + "real_world_franka": slice(29, 37), +} + +STANDARD_STATE_DIM = 74 +# Mapping: +# libero_franka (8D) -> [0:8] +# real_world_franka (8D) -> [8:16] +# gr1 (58D after sin/cos) -> [16:74] + +STATE_REPRESENTATION_SLICES = { + # Single-arm (8D) + "franka": slice(0, 8), + # Real-world (8D) + "real_world_franka": slice(8, 16), + # GR1 isolated (58D, has StateActionSinCosTransform - different pipeline) + "gr1": slice(16, 74), +} + + +def standardize_action_representation( + action: np.ndarray, embodiment_tag: str +) -> np.ndarray: + """Map per-robot action to a fixed-size standard action vector.""" + target_slice = ACTION_REPRESENTATION_SLICES.get(embodiment_tag) + + # Only allow explicitly configured embodiment tags. + if target_slice is None: + raise ValueError( + f"Unknown embodiment tag '{embodiment_tag}' for action mapping. " + f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES)}" + ) + + expected_dim = target_slice.stop - target_slice.start + if action.shape[-1] != expected_dim: + raise ValueError( + f"Action dim mismatch for tag '{embodiment_tag}': " + f"{action.shape[-1]=} vs expected {expected_dim}." + ) + + standard = np.zeros( + (*action.shape[:-1], STANDARD_ACTION_DIM), dtype=action.dtype + ) + standard[..., target_slice] = action + return standard + + +def standardize_state_representation( + state: np.ndarray, embodiment_tag: str +) -> np.ndarray: + """Map per-robot state to a fixed-size standard state vector.""" + + target_slice = STATE_REPRESENTATION_SLICES.get(embodiment_tag) + + # Only allow explicitly configured embodiment tags. + if target_slice is None: + raise ValueError( + f"Unknown embodiment tag '{embodiment_tag}' for state mapping. " + f"Known tags: {sorted(STATE_REPRESENTATION_SLICES)}" + ) + + expected_dim = target_slice.stop - target_slice.start + if state.shape[-1] != expected_dim: + raise ValueError( + f"State dim mismatch for tag '{embodiment_tag}': " + f"{state.shape[-1]=} vs expected {expected_dim}." 
+ ) + + standard = np.zeros( + (*state.shape[:-1], STANDARD_STATE_DIM), dtype=state.dtype + ) + standard[..., target_slice] = state + return standard + + +def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict: + """Calculate the dataset statistics of all columns for a list of parquet files.""" + # Dataset statistics + all_low_dim_data_list = [] + # Collect all the data + # parquet_paths = parquet_paths[:3] + for parquet_path in tqdm( + sorted(list(parquet_paths)), + desc="Collecting all parquet files...", + ): + # Load the parquet file + parquet_data = pd.read_parquet(parquet_path) + parquet_data = parquet_data + all_low_dim_data_list.append(parquet_data) + + all_low_dim_data = pd.concat(all_low_dim_data_list, axis=0) + # Compute dataset statistics + dataset_statistics = {} + for le_modality in all_low_dim_data.columns: + if le_modality.startswith("annotation."): + continue + print(f"Computing statistics for {le_modality}...") + np_data = np.vstack( + [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]] + ) + dataset_statistics[le_modality] = { + "mean": np.mean(np_data, axis=0).tolist(), + "std": np.std(np_data, axis=0).tolist(), + "min": np.min(np_data, axis=0).tolist(), + "max": np.max(np_data, axis=0).tolist(), + "q01": np.quantile(np_data, 0.01, axis=0).tolist(), + "q99": np.quantile(np_data, 0.99, axis=0).tolist(), + } + return dataset_statistics + + +class ModalityConfig(BaseModel): + """Configuration for a modality.""" + + delta_indices: list[int] + """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices.""" + modality_keys: list[str] + """The keys to load for the modality in the dataset.""" + + +class LeRobotSingleDataset(Dataset): + """ + Base dataset class for LeRobot that supports sharding. + """ + def __init__( + self, + dataset_path: Path | str, + modality_configs: dict[str, ModalityConfig], + embodiment_tag: str | EmbodimentTag, + video_backend: str = "decord", + video_backend_kwargs: dict | None = None, + transforms: ComposedModalityTransform | None = None, + delete_pause_frame: bool = False, + **kwargs, + ): + """ + Initialize the dataset. + + Args: + dataset_path (Path | str): The path to the dataset. + modality_configs (dict[str, ModalityConfig]): The configuration for each modality. The keys are the modality names, and the values are the modality configurations. + See `ModalityConfig` for more details. + video_backend (str): Backend for video reading. + video_backend_kwargs (dict): Keyword arguments for the video backend when initializing the video reader. + transforms (ComposedModalityTransform): The transforms to apply to the dataset. + embodiment_tag (EmbodimentTag): Overload the embodiment tag for the dataset. e.g. define it as "new_embodiment" + """ + # first check if the path directory exists + if not Path(dataset_path).exists(): + raise FileNotFoundError(f"Dataset path {dataset_path} does not exist") + data_cfg = kwargs.get("data_cfg", {}) or {} + # indict letobot version + self._lerobot_version = data_cfg.get("lerobot_version", "v2.0") #self._indict_lerobot_version(**kwargs) + self.load_video = data_cfg.get("load_video", True) + + self.delete_pause_frame = delete_pause_frame + + # If video loading is disabled, skip video modality end-to-end. 
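+        # (Dropping the "video" entry from modality_configs means no video keys are requested
+        # later, so no frames are decoded when samples are fetched.)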
+ if self.load_video: + self.modality_configs = modality_configs + else: + self.modality_configs = { + modality: config + for modality, config in modality_configs.items() + if modality != "video" + } + self.video_backend = video_backend + self.video_backend_kwargs = video_backend_kwargs if video_backend_kwargs is not None else {} + self.transforms = ( + transforms if transforms is not None else ComposedModalityTransform(transforms=[]) + ) + + self._dataset_path = Path(dataset_path) + self._dataset_name = self._dataset_path.name + self._dataset_id = DATASET_NAME_TO_ID.get(self._dataset_name) + if isinstance(embodiment_tag, EmbodimentTag): + self.tag = embodiment_tag.value + else: + self.tag = embodiment_tag + + self._metadata = self._get_metadata(EmbodimentTag(self.tag)) + + # LeRobot-specific config + self._lerobot_modality_meta = self._get_lerobot_modality_meta() + self._lerobot_info_meta = self._get_lerobot_info_meta() + self._data_path_pattern = self._get_data_path_pattern() + self._video_path_pattern = self._get_video_path_pattern() + self._chunk_size = self._get_chunk_size() + self._tasks = self._get_tasks() + self.curr_traj_data = None + self.curr_traj_id = None + + self._trajectory_ids, self._trajectory_lengths = self._get_trajectories() + self._modality_keys = self._get_modality_keys() + self._delta_indices = self._get_delta_indices() + self._all_steps = self._get_all_steps() + self.set_transforms_metadata(self.metadata) + self.set_epoch(0) + + print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}") + + + # Check if the dataset is valid + self._check_integrity() + + @property + def dataset_path(self) -> Path: + """The path to the dataset that contains the METADATA_FILENAME file.""" + return self._dataset_path + + @property + def metadata(self) -> DatasetMetadata: + """The metadata for the dataset, loaded from metadata.json in the dataset directory""" + return self._metadata + + @property + def trajectory_ids(self) -> np.ndarray: + """The trajectory IDs in the dataset, stored as a 1D numpy array of strings.""" + return self._trajectory_ids + + @property + def trajectory_lengths(self) -> np.ndarray: + """The trajectory lengths in the dataset, stored as a 1D numpy array of integers. + The order of the lengths is the same as the order of the trajectory IDs. + """ + return self._trajectory_lengths + + @property + def all_steps(self) -> list[tuple[int, int]]: + """The trajectory IDs and base indices for all steps in the dataset. + Example: + self.trajectory_ids: [0, 1, 2] + self.trajectory_lengths: [3, 2, 4] + return: [ + ("traj_0", 0), ("traj_0", 1), ("traj_0", 2), + ("traj_1", 0), ("traj_1", 1), + ("traj_2", 0), ("traj_2", 1), ("traj_2", 2), ("traj_2", 3) + ] + """ + return self._all_steps + + @property + def modality_keys(self) -> dict: + """The modality keys for the dataset. The keys are the modality names, and the values are the keys for each modality. + + Example: { + "video": ["video.image_side_0", "video.image_side_1"], + "state": ["state.eef_position", "state.eef_rotation"], + "action": ["action.eef_position", "action.eef_rotation"], + "language": ["language.human.task"], + "timestamp": ["timestamp"], + "reward": ["reward"], + } + """ + return self._modality_keys + + @property + def delta_indices(self) -> dict[str, np.ndarray]: + """The delta indices for the dataset. 
The keys are the modality.key, and the values are the delta indices for each modality.key.""" + return self._delta_indices + + @property + def dataset_name(self) -> str: + """The name of the dataset.""" + return self._dataset_name + + @property + def lerobot_modality_meta(self) -> LeRobotModalityMetadata: + """The metadata for the LeRobot dataset.""" + return self._lerobot_modality_meta + + @property + def lerobot_info_meta(self) -> dict: + """The metadata for the LeRobot dataset.""" + return self._lerobot_info_meta + + @property + def data_path_pattern(self) -> str: + """The path pattern for the LeRobot dataset.""" + return self._data_path_pattern + + @property + def video_path_pattern(self) -> str: + """The path pattern for the LeRobot dataset.""" + return self._video_path_pattern + + @property + def chunk_size(self) -> int: + """The chunk size for the LeRobot dataset.""" + return self._chunk_size + + @property + def tasks(self) -> pd.DataFrame: + """The tasks for the dataset.""" + return self._tasks + + def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata: + """Get the metadata for the dataset. + + Returns: + dict: The metadata for the dataset. + """ + + # 1. Modality metadata + modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME + assert ( + modality_meta_path.exists() + ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}" + # 1.1. State and action modalities + simplified_modality_meta: dict[str, dict] = {} + with open(modality_meta_path, "r") as f: + le_modality_meta = LeRobotModalityMetadata.model_validate(json.load(f)) + for modality in ["state", "action"]: + simplified_modality_meta[modality] = {} + le_state_action_meta: dict[str, LeRobotStateActionMetadata] = getattr( + le_modality_meta, modality + ) + for subkey in le_state_action_meta: + state_action_dtype = np.dtype(le_state_action_meta[subkey].dtype) + if np.issubdtype(state_action_dtype, np.floating): + continuous = True + else: + continuous = False + simplified_modality_meta[modality][subkey] = { + "absolute": le_state_action_meta[subkey].absolute, + "rotation_type": le_state_action_meta[subkey].rotation_type, + "shape": [ + le_state_action_meta[subkey].end - le_state_action_meta[subkey].start + ], + "continuous": continuous, + } + + # 1.2. Video modalities + le_info_path = self.dataset_path / LE_ROBOT_INFO_FILENAME + assert ( + le_info_path.exists() + ), f"Please provide a {LE_ROBOT_INFO_FILENAME} file in {self.dataset_path}" + with open(le_info_path, "r") as f: + le_info = json.load(f) + simplified_modality_meta["video"] = {} + for new_key in le_modality_meta.video: + original_key = le_modality_meta.video[new_key].original_key + if original_key is None: + original_key = new_key + le_video_meta = le_info["features"][original_key] + height = le_video_meta["shape"][le_video_meta["names"].index("height")] + width = le_video_meta["shape"][le_video_meta["names"].index("width")] + # NOTE(FH): different lerobot dataset versions have different keys for the number of channels and fps + try: + channels = le_video_meta["shape"][le_video_meta["names"].index("channel")] + fps = le_video_meta["video_info"]["video.fps"] + except (ValueError, KeyError): + # channels = le_video_meta["shape"][le_video_meta["names"].index("channels")] + channels = le_video_meta["info"]["video.channels"] + fps = le_video_meta["info"]["video.fps"] + simplified_modality_meta["video"][new_key] = { + "resolution": [width, height], + "channels": channels, + "fps": fps, + } + + # 2. 
Dataset statistics + stats_path = self.dataset_path / LE_ROBOT_STATS_FILENAME + try: + with open(stats_path, "r") as f: + le_statistics = json.load(f) + for stat in le_statistics.values(): + DatasetStatisticalValues.model_validate(stat) + except (FileNotFoundError, ValidationError) as e: + print(f"Failed to load dataset statistics: {e}") + print(f"Calculating dataset statistics for {self.dataset_name}") + # Get all parquet files in the dataset paths + parquet_files = list((self.dataset_path).glob(LE_ROBOT_DATA_FILENAME)) + parquet_files_filtered = [] + # parquet_files[0].name = "episode_033675.parquet" is broken file + for pf in parquet_files: + if "episode_033675.parquet" in pf.name: + continue + parquet_files_filtered.append(pf) + + le_statistics = calculate_dataset_statistics(parquet_files_filtered) + with open(stats_path, "w") as f: + json.dump(le_statistics, f, indent=4) + dataset_statistics = {} + for our_modality in ["state", "action"]: + dataset_statistics[our_modality] = {} + for subkey in simplified_modality_meta[our_modality]: + dataset_statistics[our_modality][subkey] = {} + state_action_meta = le_modality_meta.get_key_meta(f"{our_modality}.{subkey}") + assert isinstance(state_action_meta, LeRobotStateActionMetadata) + le_modality = state_action_meta.original_key + for stat_name in le_statistics[le_modality]: + indices = np.arange( + state_action_meta.start, + state_action_meta.end, + ) + stat = np.array(le_statistics[le_modality][stat_name]) + dataset_statistics[our_modality][subkey][stat_name] = stat[indices].tolist() + + # 3. Full dataset metadata + metadata = DatasetMetadata( + statistics=dataset_statistics, # type: ignore + modalities=simplified_modality_meta, # type: ignore + embodiment_tag=embodiment_tag, + ) + + return metadata + + def _get_trajectories(self) -> tuple[np.ndarray, np.ndarray]: + """Get the trajectories in the dataset.""" + # Get trajectory lengths, IDs, and whitelist from dataset metadata + # v2.0 + if self._lerobot_version == "v2.0": + file_path = self.dataset_path / LE_ROBOT_EPISODE_FILENAME + with open(file_path, "r") as f: + episode_metadata = [json.loads(line) for line in f] + trajectory_ids = [] + trajectory_lengths = [] + for episode in episode_metadata: + trajectory_ids.append(episode["episode_index"]) + trajectory_lengths.append(episode["length"]) + return np.array(trajectory_ids), np.array(trajectory_lengths) + # v3.0 + elif self._lerobot_version == "v3.0": + file_paths = list((self.dataset_path).glob(LE_ROBOT3_EPISODE_FILENAME)) + trajectory_ids = [] + trajectory_lengths = [] + # data_chunck_index = [] + # data_file_index = [] + # vido_from_index = [] + self.trajectory_ids_to_metadata = {} + for file_path in file_paths: + episodes_data = pd.read_parquet(file_path) + for index, episode in episodes_data.iterrows(): + trajectory_ids.append(episode["episode_index"]) + trajectory_lengths.append(episode["length"]) + + # TODO auto map key? 
just map to file_path and file_from_index
+                    episode_meta = {
+                        "data/chunk_index": episode["data/chunk_index"],
+                        "data/file_index": episode["data/file_index"],
+                        "data/file_from_index": index,
+                    }
+                    if self.load_video:
+                        episode_meta["videos/observation.images.wrist/from_timestamp"] = episode[
+                            "videos/observation.images.wrist/from_timestamp"
+                        ]
+                    self.trajectory_ids_to_metadata[trajectory_ids[-1]] = episode_meta
+
+            # The save-index information should be directly readable here.
+            return np.array(trajectory_ids), np.array(trajectory_lengths)
+
+    def _get_all_steps(self) -> list[tuple[int, int]]:
+        """Get the trajectory IDs and base indices for all steps in the dataset.
+
+        Returns:
+            list[tuple[int, int]]: A list of (trajectory_id, base_index) tuples.
+        """
+        # Create a hash key based on configuration to ensure cache validity
+        config_key = self._get_steps_config_key()
+
+        # Create a unique filename based on config_key
+        # steps_filename = f"steps_{config_key}.pkl"
+        # NOTE(@fangjing): use a fixed filename so the cached step index stays static;
+        # deriving the name from the hash would make the sampled steps dynamic.
+        steps_filename = "steps_data_index.pkl"
+
+        steps_path = self.dataset_path / "meta" / steps_filename
+
+        # Try to load cached steps first
+        try:
+            if steps_path.exists():
+                with open(steps_path, "rb") as f:
+                    cached_data = pickle.load(f)
+                    return cached_data["steps"]
+
+        except (FileNotFoundError, pickle.PickleError, KeyError) as e:
+            print(f"Failed to load cached steps: {e}")
+            print("Computing steps from scratch...")
+
+        # Compute steps using a single process
+        all_steps = self._get_all_steps_single_process()
+
+        # Cache the computed steps with a unique filename
+        try:
+            cache_data = {
+                "config_key": config_key,
+                "steps": all_steps,
+                "num_trajectories": len(self.trajectory_ids),
+                "total_steps": len(all_steps),
+                "computed_timestamp": pd.Timestamp.now().isoformat(),
+                "delete_pause_frame": self.delete_pause_frame,
+            }
+
+            # Ensure the meta directory exists
+            steps_path.parent.mkdir(parents=True, exist_ok=True)
+
+            with open(steps_path, "wb") as f:
+                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
+            print(f"Cached steps saved to {steps_path}")
+        except Exception as e:
+            print(f"Failed to cache steps: {e}")
+
+        return all_steps
+
+    def _get_steps_config_key(self) -> str:
+        """Generate a configuration key for steps caching."""
+        config_dict = {
+            "delete_pause_frame": self.delete_pause_frame,
+            "dataset_name": self.dataset_name,
+        }
+        # Create a hash of the configuration
+        config_str = str(sorted(config_dict.items()))
+        return hashlib.md5(config_str.encode()).hexdigest()[:12]
+
+    def _get_all_steps_single_process(self) -> list[tuple[int, int]]:
+        """Original single-process implementation as fallback."""
+        all_steps: list[tuple[int, int]] = []
+        skipped_trajectories = 0
+        processed_trajectories = 0
+
+        # Check if language modality is configured
+        has_language_modality = 'language' in self.modality_keys and len(self.modality_keys['language']) > 0
+        # TODO why use trajectory_length here instead of the actual data length?
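+        # Iterate over every trajectory; trajectories that fail to load, or whose language
+        # instruction is empty (when a language modality is configured), are skipped, and the
+        # remaining ones are expanded into (trajectory_id, base_index) steps.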
+ for trajectory_id, trajectory_length in tqdm(zip(self.trajectory_ids, self.trajectory_lengths), total=len(self.trajectory_ids), desc="Getting All Step"): + try: + if self._lerobot_version == "v2.0": + data = self.get_trajectory_data(trajectory_id) + elif self._lerobot_version == "v3.0": + data = self.get_trajectory_data_lerobot_v3(trajectory_id) + + trajectory_skipped = False + + # Check if trajectory has valid language instruction (if language modality is configured) + if has_language_modality: + self.curr_traj_data = data # Set current trajectory data for get_language to work + + language_instruction = self.get_language(trajectory_id, self.modality_keys['language'][0], 0) + if not language_instruction or language_instruction[0] == "": + print(f"Skipping trajectory {trajectory_id} due to empty language instruction") + skipped_trajectories += 1 + trajectory_skipped = True + continue + + except Exception as e: + print(f"Skipping trajectory {trajectory_id} due to read error: {e}") + skipped_trajectories += 1 + trajectory_skipped = True + continue + + if not trajectory_skipped: + processed_trajectories += 1 + + for base_index in range(trajectory_length): + all_steps.append((trajectory_id, base_index)) + + # Print summary statistics + print(f"Single-process summary: Processed {processed_trajectories} trajectories, skipped {skipped_trajectories} empty trajectories") + print(f"Total steps: {len(all_steps)} from {len(self.trajectory_ids)} trajectories") + + return all_steps + + def _get_position_and_gripper_values(self, data: pd.DataFrame) -> tuple[list, list]: + """Get position and gripper values based on available columns in the dataset.""" + # Get action keys from modality_keys + action_keys = self.modality_keys.get('action', []) + + # Extract position data + delta_position_values = None + position_candidates = ['delta_eef_position'] + coordinate_candidates = ['x', 'y', 'z'] + + # First try combined position fields + for pos_key in position_candidates: + full_key = f"action.{pos_key}" + if full_key in action_keys: + try: + # Get the lerobot key for this modality + le_action_cfg = self.lerobot_modality_meta.action + subkey = pos_key + if subkey in le_action_cfg: + le_key = le_action_cfg[subkey].original_key or subkey + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[subkey].start, le_action_cfg[subkey].end) + filtered_data = data_array[:, le_indices] + delta_position_values = filtered_data.tolist() + break + except Exception: + continue + + # If combined fields not found, try individual x,y,z coordinates + if delta_position_values is None: + x_data, y_data, z_data = None, None, None + for coord in coordinate_candidates: + full_key = f"action.{coord}" + if full_key in action_keys: + try: + le_action_cfg = self.lerobot_modality_meta.action + if coord in le_action_cfg: + le_key = le_action_cfg[coord].original_key or coord + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[coord].start, le_action_cfg[coord].end) + coord_data = data_array[:, le_indices].flatten() + if coord == 'x': + x_data = coord_data + elif coord == 'y': + y_data = coord_data + elif coord == 'z': + z_data = coord_data + except Exception: + continue + + if x_data is not None and y_data is not None and z_data is not None: + delta_position_values = np.column_stack((x_data, y_data, z_data)).tolist() + + if delta_position_values is None: + # Fallback to the old hardcoded approach if metadata approach fails + if 
'action.delta_eef_position' in data.columns: + delta_position_values = data['action.delta_eef_position'].to_numpy().tolist() + elif all(col in data.columns for col in ['action.x', 'action.y', 'action.z']): + x_vals = data['action.x'].to_numpy() + y_vals = data['action.y'].to_numpy() + z_vals = data['action.z'].to_numpy() + delta_position_values = np.column_stack((x_vals, y_vals, z_vals)).tolist() + else: + raise ValueError(f"No suitable position columns found. Available columns: {data.columns.tolist()}") + + # Extract gripper data + gripper_values = None + gripper_candidates = ['gripper_close', 'gripper'] + + for grip_key in gripper_candidates: + full_key = f"action.{grip_key}" + if full_key in action_keys: + try: + le_action_cfg = self.lerobot_modality_meta.action + if grip_key in le_action_cfg: + le_key = le_action_cfg[grip_key].original_key or grip_key + if le_key in data.columns: + data_array = np.stack(data[le_key]) + le_indices = np.arange(le_action_cfg[grip_key].start, le_action_cfg[grip_key].end) + gripper_data = data_array[:, le_indices].flatten() + gripper_values = gripper_data.tolist() + break + except Exception: + continue + + if gripper_values is None: + # Fallback to the old hardcoded approach if metadata approach fails + if 'action.gripper_close' in data.columns: + gripper_values = data['action.gripper_close'].to_numpy().tolist() + elif 'action.gripper' in data.columns: + gripper_values = data['action.gripper'].to_numpy().tolist() + else: + raise ValueError(f"No suitable gripper columns found. Available columns: {data.columns.tolist()}") + + return delta_position_values, gripper_values + + def _get_modality_keys(self) -> dict: + """Get the modality keys for the dataset. + The keys are the modality names, and the values are the keys for each modality. + See property `modality_keys` for the expected format. 
+        """
+        modality_keys = defaultdict(list)
+        for modality, config in self.modality_configs.items():
+            modality_keys[modality] = config.modality_keys
+        return modality_keys
+
+    def _get_delta_indices(self) -> dict[str, np.ndarray]:
+        """Restructure the delta indices to use modality.key as keys instead of just the modalities."""
+        delta_indices: dict[str, np.ndarray] = {}
+        for config in self.modality_configs.values():
+            for key in config.modality_keys:
+                delta_indices[key] = np.array(config.delta_indices)
+        return delta_indices
+
+    def _get_lerobot_modality_meta(self) -> LeRobotModalityMetadata:
+        """Get the metadata for the LeRobot dataset."""
+        modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
+        assert (
+            modality_meta_path.exists()
+        ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
+        with open(modality_meta_path, "r") as f:
+            modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))
+        return modality_meta
+
+    def _get_lerobot_info_meta(self) -> dict:
+        """Get the metadata for the LeRobot dataset."""
+        info_meta_path = self.dataset_path / LE_ROBOT_INFO_FILENAME
+        with open(info_meta_path, "r") as f:
+            info_meta = json.load(f)
+        return info_meta
+
+    def _get_data_path_pattern(self) -> str:
+        """Get the data path pattern for the LeRobot dataset."""
+        return self.lerobot_info_meta["data_path"]
+
+    def _get_video_path_pattern(self) -> str:
+        """Get the video path pattern for the LeRobot dataset."""
+        return self.lerobot_info_meta["video_path"]
+
+    def _get_chunk_size(self) -> int:
+        """Get the chunk size for the LeRobot dataset."""
+        return self.lerobot_info_meta["chunks_size"]
+
+    def _get_tasks(self) -> pd.DataFrame:
+        """Get the tasks for the dataset."""
+        if self._lerobot_version == "v2.0":
+            tasks_path = self.dataset_path / LE_ROBOT_TASKS_FILENAME
+            with open(tasks_path, "r") as f:
+                tasks = [json.loads(line) for line in f]
+            df = pd.DataFrame(tasks)
+            return df.set_index("task_index")
+
+        elif self._lerobot_version == "v3.0":
+            tasks_path = self.dataset_path / LE_ROBOT3_TASKS_FILENAME
+            df = pd.read_parquet(tasks_path)
+            df = df.reset_index()  # turn the index into a column, usually named 'index'
+            df = df.rename(columns={'index': 'task'})  # rename the 'index' column to 'task'
+            df = df[['task_index', 'task']]  # reorder the columns
+            return df
+
+    def _check_integrity(self):
+        """Use the config to check if the keys are valid and detect silent data corruption."""
+        ERROR_MSG_HEADER = f"Error occurred in initializing dataset {self.dataset_name}:\n"
+
+        for modality_config in self.modality_configs.values():
+            for key in modality_config.modality_keys:
+                if key == "lapa_action" or key == "dream_actions":
+                    continue  # no metadata needed for lapa actions because they come pre-normalized
+                # Check if the key is valid
+                try:
+                    self.lerobot_modality_meta.get_key_meta(key)
+                except Exception as e:
+                    raise ValueError(
+                        ERROR_MSG_HEADER + f"Unable to find key {key} in modality metadata:\n{e}"
+                    )
+
+    def set_transforms_metadata(self, metadata: DatasetMetadata):
+        """Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values."""
+        self.transforms.set_metadata(metadata)
+
+    def set_epoch(self, epoch: int):
+        """Set the epoch for the dataset.
+
+        Args:
+            epoch (int): The epoch to set.
+        """
+        self.epoch = epoch
+
+    def __len__(self) -> int:
+        """Get the total number of data points in the dataset.
+
+        Returns:
+            int: the total number of data points in the dataset.
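+                This equals len(self.all_steps), i.e. the summed length of all trajectories that were not skipped.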
+ """ + return len(self.all_steps) + + def __str__(self) -> str: + """Get the description of the dataset.""" + return f"{self.dataset_name} ({len(self)} steps)" + + + def __getitem__(self, index: int) -> dict: + """Get the data for a single step in a trajectory. + + Args: + index (int): The index of the step to get. + + Returns: + dict: The data for the step. + """ + trajectory_id, base_index = self.all_steps[index] + data = self.get_step_data(trajectory_id, base_index) + + # Process all video keys dynamically + images = [] + for video_key in self.modality_keys.get("video", []): + image = data[video_key][0] + + image = Image.fromarray(image).resize((224, 224)) + images.append(image) + + # Get language and action data + language = data[self.modality_keys["language"][0]][0] + action = [] + for action_key in self.modality_keys["action"]: + action.append(data[action_key]) + action = np.concatenate(action, axis=1) + action = standardize_action_representation(action, self.tag) + + state = [] + for state_key in self.modality_keys["state"]: + state.append(data[state_key]) + state = np.concatenate(state, axis=1) + state = standardize_state_representation(state, self.tag) + + return dict(action=action, state=state, image=images, language=language, dataset_id=self._dataset_id) + + def get_step_data(self, trajectory_id: int, base_index: int) -> dict: + """Get the RAW data for a single step in a trajectory. No transforms are applied. + + Args: + trajectory_id (int): The name of the trajectory. + base_index (int): The base step index in the trajectory. + + Returns: + dict: The RAW data for the step. + + Example return: + { + "video": { + "video.image_side_0": [B, T, H, W, C], + "video.image_side_1": [B, T, H, W, C], + }, + "state": { + "state.eef_position": [B, T, state_dim], + "state.eef_rotation": [B, T, state_dim], + }, + "action": { + "action.eef_position": [B, T, action_dim], + "action.eef_rotation": [B, T, action_dim], + }, + } + """ + data = {} + # Get the data for all modalities # just for action base data + self.curr_traj_data = self.get_trajectory_data(trajectory_id) + # TODO @JinhuiYE The logic below is poorly implemented. Data reading should be directly based on curr_traj_data. 
+ for modality in self.modality_keys: + # Get the data corresponding to each key in the modality + for key in self.modality_keys[modality]: + data[key] = self.get_data_by_modality(trajectory_id, modality, key, base_index) + return data + + def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame: + """Get the data for a trajectory.""" + if self._lerobot_version == "v2.0": + + if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None: + return self.curr_traj_data + else: + chunk_index = self.get_episode_chunk(trajectory_id) + parquet_path = self.dataset_path / self.data_path_pattern.format( + episode_chunk=chunk_index, episode_index=trajectory_id + ) + assert parquet_path.exists(), f"Parquet file not found at {parquet_path}" + return pd.read_parquet(parquet_path) + elif self._lerobot_version == "v3.0": + return self.get_trajectory_data_lerobot_v3(trajectory_id) + + def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame: + """Get the data for a trajectory from lerobot v3.""" + if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None: + return self.curr_traj_data + else: #TODO check detail later + chunk_index = self.get_episode_chunk(trajectory_id) + + file_index = self.get_episode_file_index(trajectory_id) + # file_from_index = self.get_episode_file_from_index(trajectory_id) + + + parquet_path = self.dataset_path / self.data_path_pattern.format( + chunk_index=chunk_index, file_index=file_index + ) + assert parquet_path.exists(), f"Parquet file not found at {parquet_path}" + file_data = pd.read_parquet(parquet_path) + + # filter by trajectory_id + episode_data = file_data.loc[file_data["episode_index"] == trajectory_id].copy() + + # fix timestamp from epis index to file index for video alignment + if self.load_video: + from_timestamp = self.trajectory_ids_to_metadata[trajectory_id].get( + "videos/observation.images.wrist/from_timestamp", 0 + ) + episode_data["timestamp"] = episode_data["timestamp"] + from_timestamp + + return episode_data + + + def get_trajectory_index(self, trajectory_id: int) -> int: + """Get the index of the trajectory in the dataset by the trajectory ID. + This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID. + + Args: + trajectory_id (str): The ID of the trajectory. + + Returns: + int: The index of the trajectory in the dataset. + """ + trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0] + if len(trajectory_indices) != 1: + raise ValueError( + f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}" + ) + return trajectory_indices[0] + + def get_episode_chunk(self, ep_index: int) -> int: + """Get the chunk index for an episode index.""" + return ep_index // self.chunk_size + def get_episode_file_index(self, ep_index: int) -> int: + """Get the file index for an episode index.""" + episode_meta = self.trajectory_ids_to_metadata[ep_index] + return episode_meta["data/file_index"] + + def get_episode_file_from_index(self, ep_index: int) -> int: + """Get the file from index for an episode index.""" + episode_meta = self.trajectory_ids_to_metadata[ep_index] + return episode_meta["data/file_from_index"] + + + def retrieve_data_and_pad( + self, + array: np.ndarray, + step_indices: np.ndarray, + max_length: int, + padding_strategy: str = "first_last", + ) -> np.ndarray: + """Retrieve the data from the dataset and pad it if necessary. + Args: + array (np.ndarray): The array to retrieve the data from. 
+ step_indices (np.ndarray): The step indices to retrieve the data for. + max_length (int): The maximum length of the data. + padding_strategy (str): The padding strategy, either "first" or "last". + """ + # Get the padding indices + front_padding_indices = step_indices < 0 + end_padding_indices = step_indices >= max_length + padding_positions = np.logical_or(front_padding_indices, end_padding_indices) + # Retrieve the data with the non-padding indices + # If there exists some padding, Given T step_indices, the shape of the retrieved data will be (T', ...) where T' < T + raw_data = array[step_indices[~padding_positions]] + assert isinstance(raw_data, np.ndarray), f"{type(raw_data)=}" + # This is the shape of the output, (T, ...) + if raw_data.ndim == 1: + expected_shape = (len(step_indices),) + else: + expected_shape = (len(step_indices), *array.shape[1:]) + + # Pad the data + output = np.zeros(expected_shape) + # Assign the non-padded data + output[~padding_positions] = raw_data + # If there exists some padding, pad the data + if padding_positions.any(): + if padding_strategy == "first_last": + # Use first / last step data to pad + front_padding_data = array[0] + end_padding_data = array[-1] + output[front_padding_indices] = front_padding_data + output[end_padding_indices] = end_padding_data + elif padding_strategy == "zero": + # Use zero padding + output[padding_positions] = 0 + else: + raise ValueError(f"Invalid padding strategy: {padding_strategy}") + return output + + def get_video_path(self, trajectory_id: int, key: str) -> Path: + chunk_index = self.get_episode_chunk(trajectory_id) + original_key = self.lerobot_modality_meta.video[key].original_key + if original_key is None: + original_key = key + if self._lerobot_version == "v2.0": + video_filename = self.video_path_pattern.format( + episode_chunk=chunk_index, episode_index=trajectory_id, video_key=original_key + ) + elif self._lerobot_version == "v3.0": + episode_meta = self.trajectory_ids_to_metadata[trajectory_id] + video_filename = self.video_path_pattern.format( + video_key=original_key, + chunk_index=episode_meta["data/chunk_index"], + file_index=episode_meta["data/file_index"], + ) + return self.dataset_path / video_filename + + def get_video( + self, + trajectory_id: int, + key: str, + base_index: int, + ) -> np.ndarray: + """Get the video frames for a trajectory by a base index. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (str): The ID of the trajectory. + key (str): The key of the video. + base_index (int): The base index of the trajectory. + + Returns: + np.ndarray: The video frames for the trajectory and frame indices. 
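A self-contained sketch of the padding behaviour implemented by retrieve_data_and_pad above: out-of-range step indices are filled with the first/last row (for absolute data) or with zeros (for relative data). The array and indices here are toy values.

import numpy as np

array = np.array([[0.0], [1.0], [2.0], [3.0]])    # (T=4, D=1) trajectory data
step_indices = np.array([-2, -1, 0, 1, 2, 3, 4])  # window reaching past both ends
max_length = len(array)

front = step_indices < 0
end = step_indices >= max_length
padding = front | end

out = np.zeros((len(step_indices), array.shape[1]))
out[~padding] = array[step_indices[~padding]]

# "first_last" strategy (used for absolute data): repeat the boundary rows.
out[front] = array[0]
out[end] = array[-1]
print(out.ravel())  # [0. 0. 0. 1. 2. 3. 3.]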
Shape: (T, H, W, C) + """ + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # print(f"{step_indices=}") + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Ensure the indices are within the valid range + # This is equivalent to padding the video with extra frames at the beginning and end + step_indices = np.maximum(step_indices, 0) + step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1) + assert key.startswith("video."), f"Video key must start with 'video.', got {key}" + # Get the sub-key + key = key.replace("video.", "") + video_path = self.get_video_path(trajectory_id, key) + # Get the action/state timestamps for each frame in the video + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}" + timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy() + # Get the corresponding video timestamps from the step indices + video_timestamp = timestamp[step_indices] + + return get_frames_by_timestamps( + video_path.as_posix(), + video_timestamp, + video_backend=self.video_backend, # TODO + video_backend_kwargs=self.video_backend_kwargs, + ) + + def get_state_or_action( + self, + trajectory_id: int, + modality: str, + key: str, + base_index: int, + ) -> np.ndarray: + """Get the state or action data for a trajectory by a base index. + If the step indices are out of range, pad with the data: + if the data is stored in absolute format, pad with the first or last step data; + otherwise, pad with zero. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + modality (str): The modality of the data. + key (str): The key of the data. + base_index (int): The base index of the trajectory. + + Returns: + np.ndarray: The data for the trajectory and step indices. + """ + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Get the maximum length of the trajectory + max_length = self.trajectory_lengths[trajectory_index] + assert key.startswith(modality + "."), f"{key} must start with {modality + '.'}, got {key}" + # Get the sub-key, e.g. 
state.joint_angles -> joint_angles + key = key.replace(modality + ".", "") + # Get the lerobot key + le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality) + le_key = le_state_or_action_cfg[key].original_key + if le_key is None: + le_key = key + # Get the data array, shape: (T, D) + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}" + data_array: np.ndarray = np.stack(self.curr_traj_data[le_key]) # type: ignore + assert data_array.ndim == 2, f"Expected 2D array, got key {le_key} is{data_array.shape} array" + le_indices = np.arange( + le_state_or_action_cfg[key].start, + le_state_or_action_cfg[key].end, + ) + data_array = data_array[:, le_indices] + # Get the state or action configuration + state_or_action_cfg = getattr(self.metadata.modalities, modality)[key] + + # Pad the data + return self.retrieve_data_and_pad( + array=data_array, + step_indices=step_indices, + max_length=max_length, + padding_strategy="first_last" if state_or_action_cfg.absolute else "zero", + # padding_strategy="zero", # HACK for realdata + ) + + def get_language( + self, + trajectory_id: int, + key: str, + base_index: int, + ) -> list[str]: + """Get the language annotation data for a trajectory by step indices. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + key (str): The key of the annotation. + base_index (int): The base index of the trajectory. + + Returns: + list[str]: The annotation data for the trajectory and step indices. If no matching data is found, return empty strings. + """ + assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}" + # Get the step indices + step_indices = self.delta_indices[key] + base_index + # Get the trajectory index + trajectory_index = self.get_trajectory_index(trajectory_id) + # Get the maximum length of the trajectory + max_length = self.trajectory_lengths[trajectory_index] + # Get the end times corresponding to the closest indices + step_indices = np.maximum(step_indices, 0) + step_indices = np.minimum(step_indices, max_length - 1) + # Get the annotations + task_indices: list[int] = [] + assert key.startswith( + "annotation." + ), f"Language key must start with 'annotation.', got {key}" + subkey = key.replace("annotation.", "") + annotation_meta = self.lerobot_modality_meta.annotation + assert annotation_meta is not None, f"Annotation metadata is None for {subkey}" + assert ( + subkey in annotation_meta + ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}" + subkey_meta = annotation_meta[subkey] + original_key = subkey_meta.original_key + if original_key is None: + original_key = key + for i in range(len(step_indices)): # + # task_indices.append(self.curr_traj_data[original_key][step_indices[i]].item()) + value = self.curr_traj_data[original_key].iloc[step_indices[i]] # TODO check v2.0 + task_indices.append(value if isinstance(value, (int, float)) else value.item()) + + return self.tasks.loc[task_indices]["task"].tolist() + + def get_data_by_modality( + self, + trajectory_id: int, + modality: str, + key: str, + base_index: int, + ): + """Get the data corresponding to the modality for a trajectory by a base index. + This method will call the corresponding helper method based on the modality. + See the helper methods for more details. 
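A small sketch of the column slicing done in get_state_or_action above: the LeRobot modality metadata maps a sub-key to a flat parquet column via original_key, and start/end pick out its slice. The column layout and index values here are hypothetical.

import numpy as np

# Toy stand-in for one parquet column holding a concatenated state vector per step:
# here 7 joint angles followed by 1 gripper value, over T=3 steps.
curr_traj_column = [np.arange(8, dtype=np.float32) + 10 * t for t in range(3)]
data_array = np.stack(curr_traj_column)  # (T, D) = (3, 8)

# Hypothetical metadata for "state.gripper": start/end select its slice of the column.
start, end = 7, 8
le_indices = np.arange(start, end)

gripper = data_array[:, le_indices]      # (T, 1), as in get_state_or_action()
print(gripper.ravel())                   # [ 7. 17. 27.]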
+ NOTE: For the language modality, the data is padded with empty strings if no matching data is found. + + Args: + dataset (BaseSingleDataset): The dataset to retrieve the data from. + trajectory_id (int): The ID of the trajectory. + modality (str): The modality of the data. + key (str): The key of the data. + base_index (int): The base index of the trajectory. + """ + if modality == "video": + return self.get_video(trajectory_id, key, base_index) + elif modality == "state" or modality == "action": + return self.get_state_or_action(trajectory_id, modality, key, base_index) + elif modality == "language": + return self.get_language(trajectory_id, key, base_index) + else: + raise ValueError(f"Invalid modality: {modality}") + + def _save_dataset_statistics_(self, save_path: Path | str, format: str = "json") -> None: + """ + Save dataset statistics to specified path in the required format. + Only includes statistics for keys that are actually used in the dataset. + Key order follows modality config order. + + Args: + save_path (Path | str): Path to save the statistics file + format (str): Save format, currently only supports "json" + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Build the data structure to save + statistics_data = {} + + # Get used modality keys + used_action_keys, used_state_keys = get_used_modality_keys(self.modality_keys) + + # Organize statistics by tag + tag = self.tag + tag_stats = {} + + # Process action statistics (only for used keys, config order) + if hasattr(self.metadata.statistics, 'action') and self.metadata.statistics.action: + action_stats = self.metadata.statistics.action + filtered_action_stats = { + key: action_stats[key] + for key in used_action_keys + if key in action_stats + } + + if filtered_action_stats: + # Combine statistics from filtered action sub-keys + combined_action_stats = combine_modality_stats(filtered_action_stats) + + # Add mask field based on whether it's gripper or not + mask = generate_action_mask_for_used_keys( + self.metadata.modalities.action, filtered_action_stats.keys() + ) + combined_action_stats["mask"] = mask + + tag_stats["action"] = combined_action_stats + + # Process state statistics (only for used keys, config order) + if hasattr(self.metadata.statistics, 'state') and self.metadata.statistics.state: + state_stats = self.metadata.statistics.state + filtered_state_stats = { + key: state_stats[key] + for key in used_state_keys + if key in state_stats + } + + if filtered_state_stats: + combined_state_stats = combine_modality_stats(filtered_state_stats) + tag_stats["state"] = combined_state_stats + + # Add dataset counts + tag_stats["num_transitions"] = len(self) + tag_stats["num_trajectories"] = len(self.trajectory_ids) + + statistics_data[tag] = tag_stats + + # Save as JSON file + if format.lower() == "json": + if not str(save_path).endswith('.json'): + save_path = save_path.with_suffix('.json') + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(statistics_data, f, indent=2, ensure_ascii=False) + else: + raise ValueError(f"Unsupported format: {format}. 
Currently only 'json' is supported.") + + print(f"Single dataset statistics saved to: {save_path}") + print(f"Used action keys (reordered): {list(used_action_keys)}") + print(f"Used state keys (reordered): {list(used_state_keys)}") + + + +class MixtureSpecElement(BaseModel): + dataset_path: list[Path] | Path = Field(..., description="The path to the dataset.") + dataset_weight: float = Field(..., description="The weight of the dataset in the mixture.") + distribute_weights: bool = Field( + default=False, + description="Whether to distribute the weights of the dataset across all the paths. If True, the weights will be evenly distributed across all the paths.", + ) + + +# Helper functions for dataset statistics + +def combine_modality_stats(modality_stats: dict) -> dict: + """ + Combine statistics from all sub-keys under a modality. + + Args: + modality_stats (dict): Statistics for a modality, containing multiple sub-keys. + Each sub-key contains DatasetStatisticalValues object. + + Returns: + dict: Combined statistics + """ + combined_stats = { + "mean": [], + "std": [], + "max": [], + "min": [], + "q01": [], + "q99": [] + } + + # Combine statistics in sub-key order + for subkey in modality_stats.keys(): + subkey_stats = modality_stats[subkey] # This is a DatasetStatisticalValues object + + # Convert DatasetStatisticalValues to dict-like access + for stat_name in ["mean", "std", "max", "min", "q01", "q99"]: + stat_value = getattr(subkey_stats, stat_name) + if isinstance(stat_value, (list, tuple)): + combined_stats[stat_name].extend(stat_value) + else: + # Handle NDArray case - convert to list + if hasattr(stat_value, 'tolist'): + combined_stats[stat_name].extend(stat_value.tolist()) + else: + combined_stats[stat_name].append(float(stat_value)) + + return combined_stats + +def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]: + """ + Generate mask based on action modalities, but only for used keys. + All dimensions are set to True so every channel is de/normalized. + + Args: + action_modalities (dict): Configuration information for action modalities. + used_action_keys_ordered: Iterable of actually used action keys in the correct order. + + Returns: + list[bool]: List of mask values + """ + mask = [] + + # Generate mask in the same order as the statistics were combined + for subkey in used_action_keys_ordered: + if subkey in action_modalities: + subkey_config = action_modalities[subkey] + + # Get dimension count from shape + if hasattr(subkey_config, 'shape') and len(subkey_config.shape) > 0: + dim_count = subkey_config.shape[0] + else: + dim_count = 1 + + # Check if it's gripper-related + is_gripper = "gripper" in subkey.lower() + + # Generate mask value for each dimension + for _ in range(dim_count): + mask.append(not is_gripper) # gripper is False, others are True + + return mask + +def get_used_modality_keys(modality_keys: dict) -> tuple[set, set]: + """Extract used action and state keys from modality configuration.""" + used_action_keys = [] + used_state_keys = [] + + # Extract action keys (remove "action." prefix) + for action_key in modality_keys.get("action", []): + if action_key.startswith("action."): + clean_key = action_key.replace("action.", "") + used_action_keys.append(clean_key) + + # Extract state keys (remove "state." 
prefix) + for state_key in modality_keys.get("state", []): + if state_key.startswith("state."): + clean_key = state_key.replace("state.", "") + used_state_keys.append(clean_key) + + return used_action_keys, used_state_keys + + +def safe_hash(input_tuple): + # keep 128 bits of the hash + tuple_string = repr(input_tuple).encode("utf-8") + sha256 = hashlib.sha256() + sha256.update(tuple_string) + + seed = int(sha256.hexdigest(), 16) + + return seed & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + + +class LeRobotMixtureDataset(Dataset): + """ + A mixture of multiple datasets. This class samples a single dataset based on the dataset weights and then calls the `__getitem__` method of the sampled dataset. + It is recommended to modify the single dataset class instead of this class. + """ + + def __init__( + self, + data_mixture: Sequence[tuple[LeRobotSingleDataset, float]], + mode: str, + balance_dataset_weights: bool = True, + balance_trajectory_weights: bool = True, + seed: int = 42, + metadata_config: dict = { + "percentile_mixing_method": "min_max", + }, + **kwargs, + ): + """ + Initialize the mixture dataset. + + Args: + data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights. + mode (str): If "train", __getitem__ will return different samples every epoch; if "val" or "test", __getitem__ will return the same sample every epoch. + balance_dataset_weights (bool): If True, the weight of dataset will be multiplied by the total trajectory length of each dataset. + balance_trajectory_weights (bool): If True, sample trajectories within a dataset weighted by their length; otherwise, use equal weighting. + seed (int): Random seed for sampling. + """ + datasets: list[LeRobotSingleDataset] = [] + dataset_sampling_weights: list[float] = [] + for dataset, weight in data_mixture: + # Check if dataset is valid and has data + if len(dataset) == 0: + print(f"Warning: Skipping empty dataset {dataset.dataset_name}") + continue + datasets.append(dataset) + dataset_sampling_weights.append(weight) + + if len(datasets) == 0: + raise ValueError("No valid datasets found in the mixture. All datasets are empty.") + + self.datasets = datasets + self.balance_dataset_weights = balance_dataset_weights + self.balance_trajectory_weights = balance_trajectory_weights + self.seed = seed + self.mode = mode + + # Set properties for sampling + + # 1. Dataset lengths + self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets]) + print(f"Dataset lengths: {self._dataset_lengths}") + + # 2. Dataset sampling weights + self._dataset_sampling_weights = np.array(dataset_sampling_weights) + + if self.balance_dataset_weights: + self._dataset_sampling_weights *= self._dataset_lengths + + # Check for zero or negative weights before normalization + if np.any(self._dataset_sampling_weights <= 0): + print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}") + # Set minimum weight to prevent division issues + self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8) + + # Normalize weights + weights_sum = self._dataset_sampling_weights.sum() + if weights_sum == 0 or np.isnan(weights_sum): + print(f"Error: Invalid weights sum: {weights_sum}") + # Fallback to equal weights + self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets) + print(f"Fallback to equal weights") + else: + self._dataset_sampling_weights /= weights_sum + + # 3. 
Trajectory sampling weights + self._trajectory_sampling_weights: list[np.ndarray] = [] + for i, dataset in enumerate(self.datasets): + trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) + if self.balance_trajectory_weights: + trajectory_sampling_weights *= dataset.trajectory_lengths + + # Check for zero or negative weights before normalization + if np.any(trajectory_sampling_weights <= 0): + print(f"Warning: Dataset {i} has zero or negative trajectory weights") + trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8) + + # Normalize weights + weights_sum = trajectory_sampling_weights.sum() + if weights_sum == 0 or np.isnan(weights_sum): + print(f"Error: Dataset {i} has invalid trajectory weights sum: {weights_sum}") + # Fallback to equal weights + trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths) + else: + trajectory_sampling_weights /= weights_sum + + self._trajectory_sampling_weights.append(trajectory_sampling_weights) + + # 4. Primary dataset indices + self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0 + if not np.any(self._primary_dataset_indices): + print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}") + # Fallback: use the dataset(s) with maximum weight as primary + max_weight = max(dataset_sampling_weights) + self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight + print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}") + + if not np.any(self._primary_dataset_indices): + # This should never happen, but just in case + print("Error: Still no primary dataset found. Using first dataset as primary.") + self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool) + self._primary_dataset_indices[0] = True + + # Set the epoch and sample the first epoch + self.set_epoch(0) + + self.update_metadata(metadata_config) + + @property + def dataset_lengths(self) -> np.ndarray: + """The lengths of each dataset.""" + return self._dataset_lengths + + @property + def dataset_sampling_weights(self) -> np.ndarray: + """The sampling weights for each dataset.""" + return self._dataset_sampling_weights + + @property + def trajectory_sampling_weights(self) -> list[np.ndarray]: + """The sampling weights for each trajectory in each dataset.""" + return self._trajectory_sampling_weights + + @property + def primary_dataset_indices(self) -> np.ndarray: + """The indices of the primary datasets.""" + return self._primary_dataset_indices + + def __str__(self) -> str: + dataset_descriptions = [] + for dataset, weight in zip(self.datasets, self.dataset_sampling_weights): + dataset_description = { + "Dataset": str(dataset), + "Sampling weight": float(weight), + } + dataset_descriptions.append(dataset_description) + return json.dumps({"Mixture dataset": dataset_descriptions}, indent=2) + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset. + + Args: + epoch (int): The epoch to set. 
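A tiny numeric sketch of the dataset-weight balancing done in the constructor above: with balance_dataset_weights enabled, each user weight is scaled by the dataset length before normalization, so sampling probability is proportional to weight times size. The weights and lengths below are made up.

import numpy as np

user_weights = np.array([1.0, 1.0, 0.5])
dataset_lengths = np.array([1000, 4000, 2000])

# balance_dataset_weights=True: scale by dataset size, then normalize to a distribution.
w = user_weights * dataset_lengths
w = w / w.sum()
print(w)  # [0.16666667 0.66666667 0.16666667]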
+ """ + self.epoch = epoch + # self.sampled_steps = self.sample_epoch() + + def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]: + """Sample a single step from the dataset.""" + # return self.sampled_steps[index] + + # Set seed + seed = index if self.mode != "train" else safe_hash((self.epoch, index, self.seed)) + rng = np.random.default_rng(seed) + + # Sample dataset + dataset_index = rng.choice(len(self.datasets), p=self.dataset_sampling_weights) + dataset = self.datasets[dataset_index] + + # Sample trajectory + # trajectory_index = rng.choice( + # len(dataset.trajectory_ids), p=self.trajectory_sampling_weights[dataset_index] + # ) + # trajectory_id = dataset.trajectory_ids[trajectory_index] + + # # Sample step + # base_index = rng.choice(dataset.trajectory_lengths[trajectory_index]) + # return dataset, trajectory_id, base_index + single_step_index = rng.choice(len(dataset.all_steps)) + trajectory_id, base_index = dataset.all_steps[single_step_index] + return dataset, trajectory_id, base_index + + def __getitem__(self, index: int) -> dict: + """Get the data for a single trajectory and start index. + + Args: + index (int): The index of the trajectory to get. + + Returns: + dict: The data for the trajectory and start index. + """ + max_retries = 10 + last_exception = None + + for attempt in range(max_retries): + try: + dataset, trajectory_name, step = self.sample_step(index) + data_raw = dataset.get_step_data(trajectory_name, step) + data = dataset.transforms(data_raw) + + # Process all video keys dynamically + images = [] + for video_key in dataset.modality_keys.get("video", []): + image = data[video_key][0] + + image = Image.fromarray(image).resize((224, 224)) #TODO check if this is ok + images.append(image) + + # Get language and action data + language = data[dataset.modality_keys["language"][0]][0] + action = [] + for action_key in dataset.modality_keys["action"]: + action.append(data[action_key]) + action = np.concatenate(action, axis=1).astype(np.float16) + action = standardize_action_representation(action, dataset.tag) + + state = [] + for state_key in dataset.modality_keys["state"]: + state.append(data[state_key]) + state = np.concatenate(state, axis=1).astype(np.float16) + state = standardize_state_representation(state, dataset.tag) + + return dict(action=action, state=state, image=images, lang=language, dataset_id=dataset._dataset_id) + + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + # Log the error but continue trying + print(f"Attempt {attempt + 1}/{max_retries} failed for index {index}: {e}") + print(f"Retrying with new sample...") + # For retry, we can use a slightly different index to get a new sample + # This helps avoid getting stuck on the same problematic sample + index = random.randint(0, len(self) - 1) + else: + # All retries exhausted + print(f"All {max_retries} attempts failed for index {index}") + print(f"Last error: {last_exception}") + # Return a dummy sample or re-raise the exception + raise last_exception + + def __len__(self) -> int: + """Get the length of a single epoch in the mixture. + + Returns: + int: The length of a single epoch in the mixture. 
+ """ + # Check for potential issues + if len(self.datasets) == 0: + return 0 + + # Check if any dataset lengths are 0 or NaN + if np.any(self.dataset_lengths == 0) or np.any(np.isnan(self.dataset_lengths)): + print(f"Warning: Found zero or NaN dataset lengths: {self.dataset_lengths}") + # Filter out zero/NaN length datasets + valid_indices = (self.dataset_lengths > 0) & (~np.isnan(self.dataset_lengths)) + if not np.any(valid_indices): + print("Error: All datasets have zero or NaN length") + return 0 + else: + valid_indices = np.ones(len(self.datasets), dtype=bool) + + # Check if any sampling weights are 0 or NaN + if np.any(self.dataset_sampling_weights == 0) or np.any(np.isnan(self.dataset_sampling_weights)): + print(f"Warning: Found zero or NaN sampling weights: {self.dataset_sampling_weights}") + # Use only valid weights + valid_weights = (self.dataset_sampling_weights > 0) & (~np.isnan(self.dataset_sampling_weights)) + valid_indices = valid_indices & valid_weights + if not np.any(valid_indices): + print("Error: All sampling weights are zero or NaN") + return 0 + + # Check primary dataset indices + primary_and_valid = self.primary_dataset_indices & valid_indices + if not np.any(primary_and_valid): + print(f"Warning: No valid primary datasets found. Primary indices: {self.primary_dataset_indices}, Valid indices: {valid_indices}") + # Fallback: use the largest valid dataset + if np.any(valid_indices): + max_length = self.dataset_lengths[valid_indices].max() + print(f"Fallback: Using maximum dataset length: {max_length}") + return int(max_length) + else: + return 0 + + # Calculate the ratio and get max + ratios = (self.dataset_lengths / self.dataset_sampling_weights)[primary_and_valid] + + # Check for NaN or inf in ratios + if np.any(np.isnan(ratios)) or np.any(np.isinf(ratios)): + print(f"Warning: Found NaN or inf in ratios: {ratios}") + print(f"Dataset lengths: {self.dataset_lengths[primary_and_valid]}") + print(f"Sampling weights: {self.dataset_sampling_weights[primary_and_valid]}") + # Filter out invalid ratios + valid_ratios = ratios[~np.isnan(ratios) & ~np.isinf(ratios)] + if len(valid_ratios) == 0: + print("Error: All ratios are NaN or inf") + return 0 + max_ratio = valid_ratios.max() + else: + max_ratio = ratios.max() + + result = int(max_ratio) + if result == 0: + print(f"Warning: Dataset mixture length is 0") + return result + + @staticmethod + def compute_overall_statistics( + per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]], + dataset_sampling_weights: list[float] | np.ndarray, + percentile_mixing_method: str = "weighted_average", + ) -> dict[str, dict[str, list[float]]]: + """ + Computes overall statistics from per-task statistics using dataset sample weights. + + Args: + per_task_stats: List of per-task statistics. + Example format of one element in the per-task statistics list: + { + "state.gripper": { + "min": [...], + "max": [...], + "mean": [...], + "std": [...], + "q01": [...], + "q99": [...], + }, + ... + } + dataset_sampling_weights: List of sample weights for each task. + percentile_mixing_method: The method to mix the percentiles, either "weighted_average" or "weighted_std". + + Returns: + A dict of overall statistics per modality. 
+ """ + # Normalize the sample weights to sum to 1 + dataset_sampling_weights = np.array(dataset_sampling_weights) + normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum() + + # Initialize overall statistics dict + overall_stats: dict[str, dict[str, list[float]]] = {} + + # Get the list of modality keys + modality_keys = per_task_stats[0].keys() + + for modality in modality_keys: + # Number of dimensions (assuming consistent across tasks) + num_dims = len(per_task_stats[0][modality]["mean"]) + + # Initialize accumulators for means and variances + weighted_means = np.zeros(num_dims) + weighted_squares = np.zeros(num_dims) + + # Collect min, max, q01, q99 from all tasks + min_list = [] + max_list = [] + q01_list = [] + q99_list = [] + + for task_idx, task_stats in enumerate(per_task_stats): + w_i = normalized_weights[task_idx] + stats = task_stats[modality] + means = np.array(stats["mean"]) + stds = np.array(stats["std"]) + + # Update weighted sums for mean and variance + weighted_means += w_i * means + weighted_squares += w_i * (stds**2 + means**2) + + # Collect min, max, q01, q99 + min_list.append(stats["min"]) + max_list.append(stats["max"]) + q01_list.append(stats["q01"]) + q99_list.append(stats["q99"]) + + # Compute overall mean + overall_mean = weighted_means.tolist() + + # Compute overall variance and std deviation + overall_variance = weighted_squares - weighted_means**2 + overall_std = np.sqrt(overall_variance).tolist() + + # Compute overall min and max per dimension + overall_min = np.min(np.array(min_list), axis=0).tolist() + overall_max = np.max(np.array(max_list), axis=0).tolist() + + # Compute overall q01 and q99 per dimension + # Use weighted average of per-task quantiles + q01_array = np.array(q01_list) + q99_array = np.array(q99_list) + if percentile_mixing_method == "weighted_average": + weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist() + weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist() + # std_q01 = np.std(q01_array, axis=0).tolist() + # std_q99 = np.std(q99_array, axis=0).tolist() + # print(modality) + # print(f"{std_q01=}, {std_q99=}") + # print(f"{weighted_q01=}, {weighted_q99=}") + elif percentile_mixing_method == "min_max": + weighted_q01 = np.min(q01_array, axis=0).tolist() + weighted_q99 = np.max(q99_array, axis=0).tolist() + else: + raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}") + + # Store the overall statistics for the modality + overall_stats[modality] = { + "min": overall_min, + "max": overall_max, + "mean": overall_mean, + "std": overall_std, + "q01": weighted_q01, + "q99": weighted_q99, + } + + return overall_stats + + @staticmethod + def merge_metadata( + metadatas: list[DatasetMetadata], + dataset_sampling_weights: list[float], + percentile_mixing_method: str, + ) -> DatasetMetadata: + """Merge multiple metadata into one.""" + # Convert to dicts + metadata_dicts = [metadata.model_dump(mode="json") for metadata in metadatas] + # Create a new metadata dict + merged_metadata = {} + + # Check all metadata have the same embodiment tag + assert all( + metadata.embodiment_tag == metadatas[0].embodiment_tag for metadata in metadatas + ), "All metadata must have the same embodiment tag" + merged_metadata["embodiment_tag"] = metadatas[0].embodiment_tag + + # Merge the dataset statistics + dataset_statistics = {} + dataset_statistics["state"] = LeRobotMixtureDataset.compute_overall_statistics( + per_task_stats=[m["statistics"]["state"] for m in 
metadata_dicts], + dataset_sampling_weights=dataset_sampling_weights, + percentile_mixing_method=percentile_mixing_method, + ) + dataset_statistics["action"] = LeRobotMixtureDataset.compute_overall_statistics( + per_task_stats=[m["statistics"]["action"] for m in metadata_dicts], + dataset_sampling_weights=dataset_sampling_weights, + percentile_mixing_method=percentile_mixing_method, + ) + merged_metadata["statistics"] = dataset_statistics + + # Merge the modality configs + modality_configs = defaultdict(set) + for metadata in metadata_dicts: + for modality, configs in metadata["modalities"].items(): + modality_configs[modality].add(json.dumps(configs)) + merged_metadata["modalities"] = {} + for modality, configs in modality_configs.items(): + # Check that all modality configs correspond to the same tag matches + assert ( + len(configs) == 1 + ), f"Multiple modality configs for modality {modality}: {list(configs)}" + merged_metadata["modalities"][modality] = json.loads(configs.pop()) + + return DatasetMetadata.model_validate(merged_metadata) + + def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None: + """ + Merge multiple metadatas into one and set the transforms with the merged metadata. + + Args: + metadata_config (dict): Configuration for the metadata. + "percentile_mixing_method": The method to mix the percentiles, either "weighted_average" or "min_max". + weighted_average: Use the weighted average of the percentiles using the weight used in sampling the datasets. + min_max: Use the min of the 1st percentile and max of the 99th percentile. + """ + # If cached path is provided, try to load and apply + if cached_statistics_path is not None: + try: + cached_stats = self.load_merged_statistics(cached_statistics_path) + self.apply_cached_statistics(cached_stats) + return + except (FileNotFoundError, KeyError, ValidationError) as e: + print(f"Failed to load cached statistics: {e}") + print("Falling back to computing statistics from scratch...") + + self.tag = EmbodimentTag.NEW_EMBODIMENT.value + self.merged_metadata: dict[str, DatasetMetadata] = {} + # Group metadata by tag + all_metadatas: dict[str, list[DatasetMetadata]] = {} + for dataset in self.datasets: + if dataset.tag not in all_metadatas: + all_metadatas[dataset.tag] = [] + all_metadatas[dataset.tag].append(dataset.metadata) + for tag, metadatas in all_metadatas.items(): + self.merged_metadata[tag] = self.merge_metadata( + metadatas=metadatas, + dataset_sampling_weights=self.dataset_sampling_weights.tolist(), + percentile_mixing_method=metadata_config["percentile_mixing_method"], + ) + for dataset in self.datasets: + dataset.set_transforms_metadata(self.merged_metadata[dataset.tag]) + + def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None: + """ + Save merged dataset statistics to specified path in the required format. + Only includes statistics for keys that are actually used in the datasets. + Key order follows each tag's modality config order. 
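A small numeric illustration of the two percentile_mixing_method options accepted by update_metadata above, using made-up per-task quantiles: "weighted_average" averages the per-task q01/q99 with the sampling weights, while "min_max" keeps the widest bounds across tasks.

import numpy as np

q01 = np.array([[-0.8], [-1.2]])  # per-task 1st percentiles
q99 = np.array([[0.7], [1.1]])    # per-task 99th percentiles
w = np.array([0.5, 0.5])

# "weighted_average"
print(np.average(q01, axis=0, weights=w), np.average(q99, axis=0, weights=w))  # [-1.] [0.9]
# "min_max"
print(np.min(q01, axis=0), np.max(q99, axis=0))                                # [-1.2] [1.1]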
+ + Args: + save_path (Path | str): Path to save the statistics file + format (str): Save format, currently only supports "json" + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Build the data structure to save + statistics_data = {} + + # Keep key orders per embodiment tag (from modality config order) + tag_to_used_action_keys = {} + tag_to_used_state_keys = {} + for dataset in self.datasets: + if dataset.tag in tag_to_used_action_keys: + continue + used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys) + tag_to_used_action_keys[dataset.tag] = used_action_keys + tag_to_used_state_keys[dataset.tag] = used_state_keys + + # Organize statistics by tag + for tag, merged_metadata in self.merged_metadata.items(): + tag_stats = {} + + # Process action statistics + if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action: + action_stats = merged_metadata.statistics.action + + used_action_keys = tag_to_used_action_keys.get(tag, []) + filtered_action_stats = { + key: action_stats[key] + for key in used_action_keys + if key in action_stats + } + + if filtered_action_stats: + combined_action_stats = combine_modality_stats(filtered_action_stats) + + mask = generate_action_mask_for_used_keys( + merged_metadata.modalities.action, filtered_action_stats.keys() + ) + combined_action_stats["mask"] = mask + + tag_stats["action"] = combined_action_stats + + # Process state statistics + if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state: + state_stats = merged_metadata.statistics.state + + used_state_keys = tag_to_used_state_keys.get(tag, []) + filtered_state_stats = { + key: state_stats[key] + for key in used_state_keys + if key in state_stats + } + + if filtered_state_stats: + combined_state_stats = combine_modality_stats(filtered_state_stats) + tag_stats["state"] = combined_state_stats + + # Add dataset counts + tag_stats.update(self._get_dataset_counts(tag)) + + statistics_data[tag] = tag_stats + + # Save file + if format.lower() == "json": + if not str(save_path).endswith('.json'): + save_path = save_path.with_suffix('.json') + with open(save_path, 'w', encoding='utf-8') as f: + json.dump(statistics_data, f, indent=2, ensure_ascii=False) + else: + raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.") + + print(f"Merged dataset statistics saved to: {save_path}") + print(f"Used action keys by tag: {tag_to_used_action_keys}") + print(f"Used state keys by tag: {tag_to_used_state_keys}") + + + def _combine_modality_stats(self, modality_stats: dict) -> dict: + """Backward compatibility wrapper.""" + return combine_modality_stats(modality_stats) + + def _generate_action_mask_for_used_keys(self, action_modalities: dict, used_action_keys_ordered) -> list[bool]: + """Backward compatibility wrapper.""" + return generate_action_mask_for_used_keys(action_modalities, used_action_keys_ordered) + + def _get_dataset_counts(self, tag: str) -> dict: + """ + Get dataset count information for specified tag. 
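For orientation, a sketch of the JSON layout save_dataset_statistics writes: one entry per embodiment tag, with the used action/state sub-keys combined in config order, a per-dimension mask that excludes gripper channels from normalization, and the dataset counts. All numbers here are placeholders, not real statistics.

# Illustrative structure only; values are placeholders.
example_statistics = {
    "franka": {
        "action": {
            "mean": [0.0, 0.0, 0.0, 0.5],
            "std": [1.0, 1.0, 1.0, 0.5],
            "min": [-1.0, -1.0, -1.0, 0.0],
            "max": [1.0, 1.0, 1.0, 1.0],
            "q01": [-0.9, -0.9, -0.9, 0.0],
            "q99": [0.9, 0.9, 0.9, 1.0],
            "mask": [True, True, True, False],  # gripper dimension is not normalized
        },
        "state": {
            "mean": [0.0, 0.0, 0.0],
            "std": [1.0, 1.0, 1.0],
            "min": [-1.0, -1.0, -1.0],
            "max": [1.0, 1.0, 1.0],
            "q01": [-0.9, -0.9, -0.9],
            "q99": [0.9, 0.9, 0.9],
        },
        "num_transitions": 1000,
        "num_trajectories": 10,
    },
}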
+ + Args: + tag (str): embodiment tag + + Returns: + dict: Dictionary containing num_transitions and num_trajectories + """ + num_transitions = 0 + num_trajectories = 0 + + # Count dataset information belonging to this tag + for dataset in self.datasets: + if dataset.tag == tag: + num_transitions += len(dataset) + num_trajectories += len(dataset.trajectory_ids) + + return { + "num_transitions": num_transitions, + "num_trajectories": num_trajectories + } + + @classmethod + def load_merged_statistics(cls, load_path: Path | str) -> dict: + """ + Load merged dataset statistics from file. + + Args: + load_path (Path | str): Path to the statistics file + + Returns: + dict: Dictionary containing merged statistics + """ + load_path = Path(load_path) + if not load_path.exists(): + raise FileNotFoundError(f"Statistics file not found: {load_path}") + + if load_path.suffix.lower() == '.json': + with open(load_path, 'r', encoding='utf-8') as f: + return json.load(f) + elif load_path.suffix.lower() == '.pkl': + import pickle + with open(load_path, 'rb') as f: + return pickle.load(f) + else: + raise ValueError(f"Unsupported file format: {load_path.suffix}") + + def apply_cached_statistics(self, cached_statistics: dict) -> None: + """ + Apply cached statistics to avoid recomputation. + + Args: + cached_statistics (dict): Statistics loaded from file + """ + # Validate that cached statistics match current datasets + if "metadata" in cached_statistics: + cached_dataset_names = set(cached_statistics["metadata"]["dataset_names"]) + current_dataset_names = set(dataset.dataset_name for dataset in self.datasets) + + if cached_dataset_names != current_dataset_names: + print("Warning: Cached statistics dataset names don't match current datasets.") + print(f"Cached: {cached_dataset_names}") + print(f"Current: {current_dataset_names}") + return + + # Apply cached statistics + self.merged_metadata = {} + for tag, stats_data in cached_statistics.items(): + if tag == "metadata": # Skip metadata field + continue + + # Convert back to DatasetMetadata format + metadata_dict = { + "embodiment_tag": tag, + "statistics": { + "action": {}, + "state": {} + }, + "modalities": {} + } + + # Convert action statistics back + if "action" in stats_data: + action_data = stats_data["action"] + # This is simplified - you may need to split back to sub-keys + metadata_dict["statistics"]["action"] = action_data + + # Convert state statistics back + if "state" in stats_data: + state_data = stats_data["state"] + metadata_dict["statistics"]["state"] = state_data + + self.merged_metadata[tag] = DatasetMetadata.model_validate(metadata_dict) + + # Update transforms metadata for each dataset + for dataset in self.datasets: + if dataset.tag in self.merged_metadata: + dataset.set_transforms_metadata(self.merged_metadata[dataset.tag]) + + print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.") + diff --git a/code/dataloader/gr00t_lerobot/embodiment_tags.py b/code/dataloader/gr00t_lerobot/embodiment_tags.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0376ba4316f0c4f6944f891750bfd58ec45e0a --- /dev/null +++ b/code/dataloader/gr00t_lerobot/embodiment_tags.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class EmbodimentTag(Enum): + GR1 = "gr1" + """ + The GR1 dataset. + """ + + OXE_DROID = "oxe_droid" + """ + The OxE Droid dataset. + """ + + OXE_BRIDGE = "oxe_bridge" + """ + The OxE Bridge dataset. + """ + + OXE_RT1 = "oxe_rt1" + """ + The OxE RT-1 dataset. + """ + + AGIBOT_GENIE1 = "agibot_genie1" + """ + The AgiBot Genie-1 with gripper dataset. + """ + + NEW_EMBODIMENT = "new_embodiment" + """ + Any new embodiment for finetuning. + """ + + FRANKA = 'franka' + """ + The Franka Emika Panda robot. + """ + + ROBOTWIN = "robotwin" + """ + RobotWin (dual-arm) datasets. + """ + + REAL_WORLD_FRANKA = "real_world_franka" + """ + The Real-World Franka robot. + """ + +# Embodiment tag string: to projector index in the Action Expert Module +# EMBODIMENT_TAG_MAPPING = { +# EmbodimentTag.NEW_EMBODIMENT.value: 31, +# EmbodimentTag.OXE_DROID.value: 17, +# EmbodimentTag.OXE_BRIDGE.value: 18, +# EmbodimentTag.OXE_RT1.value: 19, +# EmbodimentTag.AGIBOT_GENIE1.value: 26, +# EmbodimentTag.GR1.value: 24, +# EmbodimentTag.FRANKA.value: 25, +# EmbodimentTag.ROBOTWIN.value: 27, +# EmbodimentTag.REAL_WORLD_FRANKA.value: 28, +# } + +# Robot type to embodiment tag mapping +ROBOT_TYPE_TO_EMBODIMENT_TAG = { + "libero_franka": EmbodimentTag.FRANKA, + "oxe_droid": EmbodimentTag.OXE_DROID, + "oxe_bridge": EmbodimentTag.OXE_BRIDGE, + "oxe_rt1": EmbodimentTag.OXE_RT1, + "demo_sim_franka_delta_joints": EmbodimentTag.FRANKA, + "custom_robot_config": EmbodimentTag.NEW_EMBODIMENT, + "fourier_gr1_arms_waist": EmbodimentTag.GR1, + "robotwin": EmbodimentTag.ROBOTWIN, + "real_world_franka": EmbodimentTag.REAL_WORLD_FRANKA, + } + +DATASET_NAME_TO_ID = { + # Libero Datasets + "libero_object_no_noops_1.0.0_lerobot": 1, + "libero_goal_no_noops_1.0.0_lerobot": 1, + "libero_spatial_no_noops_1.0.0_lerobot": 1, + "libero_10_no_noops_1.0.0_lerobot": 1, + "libero_90_no_noops_lerobot": 1, + + # OXE Datasets + "bridge_orig_lerobot": 2, + "fractal20220817_data_lerobot": 3, + "droid_lerobot": 4, + "furniture_bench_dataset_lerobot": 5, + "taco_play_lerobot": 6, + + # RoboCasa Datasets + "gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PnPCanToDrawerClose_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PnPCupToDrawerClose_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PnPMilkToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PnPPotatoToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PnPWineToCabinetClose_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromCuttingboardToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromCuttingboardToPanSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromCuttingboardToPotSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromCuttingboardToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + 
"gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromPlacematToBowlSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromPlacematToPlateSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromPlacematToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromPlateToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromPlateToPanSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromPlateToPlateSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromTrayToPlateSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromTrayToPotSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromTrayToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PosttrainPnPNovelFromTrayToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000": 7, + "gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200": 7, + "gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200": 7, + "gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200": 7, + "gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200": 7, + "gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200": 7, + + # robotwin + "adjust_bottle": 8, + "beat_block_hammer": 8, + "blocks_ranking_rgb": 8, + "blocks_ranking_size": 8, + "click_alarmclock": 8, + "click_bell": 8, + "dump_bin_bigbin": 8, + "grab_roller": 8, + "handover_block": 8, + "handover_mic": 8, + "hanging_mug": 8, + "lift_pot": 8, + "move_can_pot": 8, + "move_pillbottle_pad": 8, + "move_playingcard_away": 8, + "move_stapler_pad": 8, + "open_laptop": 8, + "open_microwave": 8, + "pick_diverse_bottles": 8, + "pick_dual_bottles": 8, + "place_a2b_left": 8, + "place_a2b_right": 8, + "place_bread_basket": 8, + "place_bread_skillet": 8, + "place_burger_fries": 8, + "place_can_basket": 8, + "place_cans_plasticbox": 8, + "place_container_plate": 8, + "place_dual_shoes": 8, + "place_empty_cup": 8, + "place_fan": 8, + "place_mouse_pad": 8, + "place_object_basket": 8, + "place_object_scale": 8, + "place_object_stand": 8, + "place_phone_stand": 8, + "place_shoe": 8, + "press_stapler": 8, + "put_bottles_dustbin": 8, + "put_object_cabinet": 8, + "rotate_qrcode": 8, + "scan_object": 8, + "shake_bottle_horizontally": 8, + "shake_bottle": 8, + "stack_blocks_three": 8, + "stack_blocks_two": 8, + "stack_bowls_three": 8, + "stack_bowls_two": 8, + "stamp_seal": 8, + "turn_switch": 8, + + # real-world + "real_grasp_coke": 9, + "real_pick_up_cup_in_middle": 9, + "real_stack_cups": 9, + "real_put_apple_on_tray_and_then_put_banana_on_tray": 9, + "realworld_tasks_all": 9, + "realworld_4tasks": 9, + "realworld_collect": 9, + "realworld_pickplace_4tasks": 9, +} \ No newline at end of file diff --git a/code/dataloader/gr00t_lerobot/mixtures.py b/code/dataloader/gr00t_lerobot/mixtures.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd282a3e65440a3db58abd6baff45de5ce730d0 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/mixtures.py @@ -0,0 +1,241 @@ +""" +mixtures.py + 
+Defines a registry of dataset mixtures and weights for the Open-X Embodiment Datasets. Each dataset is associated with +a float "sampling weight" +""" + +from typing import Dict, List, Tuple + + +# Dataset mixture name mapped to a list of tuples containing: +## {nakename: [(data_name, sampling_weight, robot_type)] } +DATASET_NAMED_MIXTURES = { + + "custom_dataset": [ + ("custom_dataset_name", 1.0, "custom_robot_config"), + ], + "custom_dataset_2": [ + ("custom_dataset_name_1", 1.0, "custom_robot_config"), + ("custom_dataset_name_2", 1.0, "custom_robot_config"), + ], + + "libero_all": [ + ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), + ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), + ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), + ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), + # ("libero_90_no_noops_lerobot", 1.0, "libero_franka"), + ], + "bridge": [ + ("bridge_orig_1.0.0_lerobot", 1.0, "oxe_bridge"), + ], + "bridge_rt_1": [ + ("bridge_orig_1.0.0_lerobot", 1.0, "oxe_bridge"), + ("fractal20220817_data_0.1.0_lerobot", 1.0, "oxe_rt1"), + ], + + "demo_sim_pick_place": [ + ("sim_pick_place", 1.0, "demo_sim_franka_delta_joints"), + ], + + "custom_dataset": [ + ("custom_dataset_name", 1.0, "custom_robot_config"), + ], + "custom_dataset_2": [ + ("custom_dataset_name_1", 1.0, "custom_robot_config"), + ("custom_dataset_name_2", 1.0, "custom_robot_config"), + ], + + "fourier_gr1_unified_1000": [ + ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PnPCanToDrawerClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PnPCupToDrawerClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PnPMilkToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PnPPotatoToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PnPWineToCabinetClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToPanSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToPotSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromPlacematToBowlSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromPlacematToPlateSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromPlacematToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromPlateToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + 
("gr1_unified.PosttrainPnPNovelFromPlateToPanSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromPlateToPlateSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromTrayToPlateSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromTrayToPotSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromTrayToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ("gr1_unified.PosttrainPnPNovelFromTrayToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), + ], + + "BEHAVIOR_challenge": [ + ("BEHAVIOR_challenge", 1.0, "R1Pro"), + ], + + + "SO101_pick": [ + ("pick_dataset_name", 1.0, "SO101"), + ], + + "arx_x5": [ + ("arx_x5", 1.0, "arx_x5"), + ], + + "robotwin": [ + ("adjust_bottle", 1.0, "robotwin"), + ("beat_block_hammer", 1.0, "robotwin"), + ("blocks_ranking_rgb", 1.0, "robotwin"), + ("blocks_ranking_size", 1.0, "robotwin"), + ("click_alarmclock", 1.0, "robotwin"), + ("click_bell", 1.0, "robotwin"), + ("dump_bin_bigbin", 1.0, "robotwin"), + ("grab_roller", 1.0, "robotwin"), + ("handover_block", 1.0, "robotwin"), + ("handover_mic", 1.0, "robotwin"), + ("hanging_mug", 1.0, "robotwin"), + ("lift_pot", 1.0, "robotwin"), + ("move_can_pot", 1.0, "robotwin"), + ("move_pillbottle_pad", 1.0, "robotwin"), + ("move_playingcard_away", 1.0, "robotwin"), + ("move_stapler_pad", 1.0, "robotwin"), + ("open_laptop", 1.0, "robotwin"), + ("open_microwave", 1.0, "robotwin"), + ("pick_diverse_bottles", 1.0, "robotwin"), + ("pick_dual_bottles", 1.0, "robotwin"), + ("place_a2b_left", 1.0, "robotwin"), + ("place_a2b_right", 1.0, "robotwin"), + ("place_bread_basket", 1.0, "robotwin"), + ("place_bread_skillet", 1.0, "robotwin"), + ("place_burger_fries", 1.0, "robotwin"), + ("place_can_basket", 1.0, "robotwin"), + ("place_cans_plasticbox", 1.0, "robotwin"), + ("place_container_plate", 1.0, "robotwin"), + ("place_dual_shoes", 1.0, "robotwin"), + ("place_empty_cup", 1.0, "robotwin"), + ("place_fan", 1.0, "robotwin"), + ("place_mouse_pad", 1.0, "robotwin"), + ("place_object_basket", 1.0, "robotwin"), + ("place_object_scale", 1.0, "robotwin"), + ("place_object_stand", 1.0, "robotwin"), + ("place_phone_stand", 1.0, "robotwin"), + ("place_shoe", 1.0, "robotwin"), + ("press_stapler", 1.0, "robotwin"), + ("put_bottles_dustbin", 1.0, "robotwin"), + ("put_object_cabinet", 1.0, "robotwin"), + ("rotate_qrcode", 1.0, "robotwin"), + ("scan_object", 1.0, "robotwin"), + ("shake_bottle", 1.0, "robotwin"), + ("shake_bottle_horizontally", 1.0, "robotwin"), + ("stack_blocks_three", 1.0, "robotwin"), + ("stack_blocks_two", 1.0, "robotwin"), + ("stack_bowls_three", 1.0, "robotwin"), + ("stack_bowls_two", 1.0, "robotwin"), + ("stamp_seal", 1.0, "robotwin"), + ("turn_switch", 1.0, "robotwin"), + ], + "cross_embodiedment_17tasks": [ + # libero - 4 tasks + ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984 + ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042 + ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970 + ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469 + # robotwin - 8 tasks, selected by average trajectory length, 400, 500, 600, 700, 800, 900, 900, 
1200 + ("beat_block_hammer", 1.0, "robotwin"), # + ("place_shoe", 1.0, "robotwin"), # + ("dump_bin_bigbin", 1.0, "robotwin"), # + ("put_object_cabinet", 1.0, "robotwin"), # + ("stack_blocks_two", 1.0, "robotwin"), # + ("stack_bowls_two", 1.0, "robotwin"), # + ("shake_bottle", 1.0, "robotwin"), # + ("hanging_mug", 1.0, "robotwin"), # + # ("blocks_ranking_rgb", 1.0, "robotwin"), # + # gr1 - 5 tasks + ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341 + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282 + ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066 + ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518 + ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739 + ], + "cross_embodiedment_21tasks": [ + # libero - 4 tasks + ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984 + ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042 + ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970 + ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469 + # robotwin - 8 tasks, selected by average trajectory length, 400, 500, 600, 700, 800, 900, 900, 1200 + ("beat_block_hammer", 1.0, "robotwin"), # + ("place_shoe", 1.0, "robotwin"), # + ("dump_bin_bigbin", 1.0, "robotwin"), # + ("put_object_cabinet", 1.0, "robotwin"), # + ("stack_blocks_two", 1.0, "robotwin"), # + ("stack_bowls_two", 1.0, "robotwin"), # + ("shake_bottle", 1.0, "robotwin"), # + ("hanging_mug", 1.0, "robotwin"), # + # ("blocks_ranking_rgb", 1.0, "robotwin"), # + # gr1 - 5 tasks + ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341 + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282 + ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066 + ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518 + ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739 + # real-world - 4 tasks + ("realworld_4tasks", 1.0, "real_world_franka"), + ], + "real_world_4tasks": [ + ("realworld_4tasks", 1.0, "real_world_franka"), + ], + "realworld_tasks_all": [ + ("realworld_tasks_all", 1.0, "real_world_franka"), + ], + "realworld_collect": [ + ("realworld_collect", 1.0, "real_world_franka"), + ], + "cross_embodiedment_13tasks": [ + # libero - 4 tasks + ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984 + ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042 + ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970 + ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469 + # gr1 - 5 tasks + ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341 + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282 + 
("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066 + ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518 + ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739 + # real-world - 4 tasks + ("realworld_pickplace_4tasks", 1.0, "real_world_franka"), + ], + "cross_embodiedment_simulator": [ + # libero - 4 tasks + ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984 + ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042 + ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970 + ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469 + # gr1 - 5 tasks + ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341 + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282 + ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066 + ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518 + ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739 + ], + "cross_embodiedment_simulator_moredata": [ + # libero - 4 tasks + ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984 + ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042 + ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970 + ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469 + ("libero_90_no_noops_lerobot", 1.0, "libero_franka"), # 901020 + # gr1 - 5 tasks + ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 71341 x 5 + ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 48282 x 5 + ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 48066 x 5 + ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 41518 x 5 + ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 39739 x 5 + ], +} diff --git a/code/dataloader/gr00t_lerobot/schema.py b/code/dataloader/gr00t_lerobot/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..64519e56a3c59e5d08c8f8f6370f640061859b4d --- /dev/null +++ b/code/dataloader/gr00t_lerobot/schema.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Optional + +from numpydantic import NDArray +from pydantic import BaseModel, Field, field_serializer + +from .embodiment_tags import EmbodimentTag + +# Common schema + + +class RotationType(Enum): + """Type of rotation representation""" + + AXIS_ANGLE = "axis_angle" + QUATERNION = "quaternion" + ROTATION_6D = "rotation_6d" + MATRIX = "matrix" + EULER_ANGLES_RPY = "euler_angles_rpy" + EULER_ANGLES_RYP = "euler_angles_ryp" + EULER_ANGLES_PRY = "euler_angles_pry" + EULER_ANGLES_PYR = "euler_angles_pyr" + EULER_ANGLES_YRP = "euler_angles_yrp" + EULER_ANGLES_YPR = "euler_angles_ypr" + + +# LeRobot schema + + +class LeRobotModalityField(BaseModel): + """Metadata for a LeRobot modality field.""" + + original_key: Optional[str] = Field( + default=None, + description="The original key of the modality in the LeRobot dataset", + ) + + +class LeRobotStateActionMetadata(LeRobotModalityField): + """Metadata for a LeRobot modality.""" + + start: int = Field( + ..., + description="The start index of the modality in the concatenated state/action vector", + ) + end: int = Field( + ..., + description="The end index of the modality in the concatenated state/action vector", + ) + rotation_type: Optional[RotationType] = Field( + default=None, description="The type of rotation for the modality" + ) + absolute: bool = Field(default=True, description="Whether the modality is absolute") + dtype: str = Field( + default="float64", + description="The data type of the modality. Defaults to float64.", + ) + range: Optional[tuple[float, float]] = Field( + default=None, + description="The range of the modality, if applicable. Defaults to None.", + ) + original_key: Optional[str] = Field( + default=None, + description="The original key of the modality in the LeRobot dataset.", + ) + + +class LeRobotStateMetadata(LeRobotStateActionMetadata): + """Metadata for a LeRobot state modality.""" + + original_key: Optional[str] = Field( + default="observation.state", # LeRobot convention for states + description="The original key of the state modality in the LeRobot dataset", + ) + + +class LeRobotActionMetadata(LeRobotStateActionMetadata): + """Metadata for a LeRobot action modality.""" + + original_key: Optional[str] = Field( + default="action", # LeRobot convention for actions + description="The original key of the action modality in the LeRobot dataset", + ) + + +class LeRobotModalityMetadata(BaseModel): + """Metadata for a LeRobot modality.""" + + state: dict[str, LeRobotStateMetadata] = Field( + ..., + description="The metadata for the state modality. The keys are the names of each split of the state vector.", + ) + action: dict[str, LeRobotActionMetadata] = Field( + ..., + description="The metadata for the action modality. The keys are the names of each split of the action vector.", + ) + video: dict[str, LeRobotModalityField] = Field( + ..., + description="The metadata for the video modality. The keys are the new names of each video modality.", + ) + annotation: Optional[dict[str, LeRobotModalityField]] = Field( + default=None, + description="The metadata for the annotation modality. The keys are the new names of each annotation modality.", + ) + + def get_key_meta(self, key: str) -> LeRobotModalityField: + """Get the metadata for a key in the LeRobot modality metadata. + + Args: + key (str): The key to get the metadata for. 
+ + Returns: + LeRobotModalityField: The metadata for the key. + + Example: + lerobot_modality_meta = LeRobotModalityMetadata.model_validate(U.load_json(modality_meta_path)) + lerobot_modality_meta.get_key_meta("state.joint_shoulder_y") + lerobot_modality_meta.get_key_meta("video.main_camera") + lerobot_modality_meta.get_key_meta("annotation.human.action.task_description") + """ + split_key = key.split(".") + modality = split_key[0] + subkey = ".".join(split_key[1:]) + if modality == "state": + if subkey not in self.state: + raise ValueError( + f"Key: {key}, state key {subkey} not found in metadata, available state keys: {self.state.keys()}" + ) + return self.state[subkey] + elif modality == "action": + if subkey not in self.action: + raise ValueError( + f"Key: {key}, action key {subkey} not found in metadata, available action keys: {self.action.keys()}" + ) + return self.action[subkey] + elif modality == "video": + if subkey not in self.video: + raise ValueError( + f"Key: {key}, video key {subkey} not found in metadata, available video keys: {self.video.keys()}" + ) + return self.video[subkey] + elif modality == "annotation": + assert ( + self.annotation is not None + ), "Trying to get annotation metadata for a dataset with no annotations" + if subkey not in self.annotation: + raise ValueError( + f"Key: {key}, annotation key {subkey} not found in metadata, available annotation keys: {self.annotation.keys()}" + ) + return self.annotation[subkey] + else: + raise ValueError(f"Key: {key}, unexpected modality: {modality}") + + +# Dataset schema (parsed from LeRobot schema and simplified) + + +class DatasetStatisticalValues(BaseModel): + max: NDArray = Field(..., description="Maximum values") + min: NDArray = Field(..., description="Minimum values") + mean: NDArray = Field(..., description="Mean values") + std: NDArray = Field(..., description="Standard deviation") + q01: NDArray = Field(..., description="1st percentile values") + q99: NDArray = Field(..., description="99th percentile values") + + @field_serializer("*", when_used="json") + def serialize_ndarray(self, v: NDArray) -> list[float]: + return v.tolist() # type: ignore + + +class DatasetStatistics(BaseModel): + state: dict[str, DatasetStatisticalValues] = Field(..., description="Statistics of the state") + action: dict[str, DatasetStatisticalValues] = Field(..., description="Statistics of the action") + + +class VideoMetadata(BaseModel): + """Metadata of the video modality""" + + resolution: tuple[int, int] = Field(..., description="Resolution of the video") + channels: int = Field(..., description="Number of channels in the video", gt=0) + fps: float = Field(..., description="Frames per second", gt=0) + + +class StateActionMetadata(BaseModel): + absolute: bool = Field(..., description="Whether the state or action is absolute") + rotation_type: Optional[RotationType] = Field(None, description="Type of rotation, if any") + shape: tuple[int, ...] 
= Field(..., description="Shape of the state or action") + continuous: bool = Field(..., description="Whether the state or action is continuous") + + +class DatasetModalities(BaseModel): + video: dict[str, VideoMetadata] = Field(..., description="Metadata of the video") + state: dict[str, StateActionMetadata] = Field(..., description="Metadata of the state") + action: dict[str, StateActionMetadata] = Field(..., description="Metadata of the action") + + +class DatasetMetadata(BaseModel): + """Metadata of the trainable dataset + + Changes: + - Update to use the new RawCommitHashMetadataMetadata_V1_2 + """ + + statistics: DatasetStatistics = Field(..., description="Statistics of the dataset") + modalities: DatasetModalities = Field(..., description="Metadata of the modalities") + embodiment_tag: EmbodimentTag = Field(..., description="Embodiment tag of the dataset") \ No newline at end of file diff --git a/code/dataloader/gr00t_lerobot/transform/__init__.py b/code/dataloader/gr00t_lerobot/transform/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf22b5e4ca3d1c7937a25234cbf08e2644593587 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/transform/__init__.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
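schema.py above describes dataset metadata with pydantic models whose NDArray fields are serialized to plain lists for JSON. A minimal sketch of that round trip, with made-up statistics and an assumed import path:

import numpy as np

from schema import DatasetStatisticalValues  # actual import path depends on packaging

stats = DatasetStatisticalValues(
    max=np.array([1.0, 2.0]),
    min=np.array([-1.0, -2.0]),
    mean=np.array([0.0, 0.0]),
    std=np.array([0.5, 1.0]),
    q01=np.array([-0.9, -1.8]),
    q99=np.array([0.9, 1.8]),
)
# The field_serializer converts every array to a list when dumping in JSON mode.
print(stats.model_dump(mode="json"))  # {'max': [1.0, 2.0], 'min': [-1.0, -2.0], ...}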
+ +from .base import ( + ComposedModalityTransform, + InvertibleModalityTransform, + ModalityTransform, +) +from .concat import ConcatTransform +# from .state_action import ( +# StateActionDropout, +# StateActionPerturbation, +# StateActionSinCosTransform, +# StateActionToTensor, +# StateActionTransform, +# ) +from .video import ( + VideoColorJitter, + VideoCrop, + VideoGrayscale, + VideoHorizontalFlip, + VideoRandomGrayscale, + VideoRandomPosterize, + VideoRandomRotation, + VideoResize, + VideoToNumpy, + VideoToTensor, + VideoTransform, +) diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-310.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb6b78fd530deecd32d433f87df53df03795d79 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-311.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce43a586fe74874c6615489ebd4e023fde731959 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-310.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a3c6c6e0302d9bd17d0291b7e3d7e22b7f85d2b Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-311.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad0b45da5432f326c2d3baa1dbab7ccb961aa9a7 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/concat.cpython-310.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/concat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c029431c2bd59ddcff7e95c263646b6bb97023f Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/concat.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/concat.cpython-311.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/concat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac8ffb7df701e713c72e5792d43fcf79beef8a98 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/concat.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/state_action.cpython-310.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/state_action.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..434e5b2d121084471d1fc7868a168978869527cd Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/state_action.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/state_action.cpython-311.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/state_action.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..971f5db4ee3c7ed4593366370b177bcb88e57679 
Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/state_action.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/video.cpython-310.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eebb8c0e5de56c6f6f1301c506e368ca5986ff9 Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/video.cpython-310.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/__pycache__/video.cpython-311.pyc b/code/dataloader/gr00t_lerobot/transform/__pycache__/video.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b3a1366881b0af2248f142e90cf360dcfdedfce Binary files /dev/null and b/code/dataloader/gr00t_lerobot/transform/__pycache__/video.cpython-311.pyc differ diff --git a/code/dataloader/gr00t_lerobot/transform/base.py b/code/dataloader/gr00t_lerobot/transform/base.py new file mode 100644 index 0000000000000000000000000000000000000000..aac88559af98fa23f34fbb9135775d0819c281ef --- /dev/null +++ b/code/dataloader/gr00t_lerobot/transform/base.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr + +from ..schema import DatasetMetadata + + +class ModalityTransform(BaseModel, ABC): + """ + Abstract class for transforming data modalities, e.g. video frame augmentation or action normalization. + """ + + apply_to: list[str] = Field(..., description="The keys to apply the transform to.") + training: bool = Field( + default=True, description="Whether to apply the transform in training mode." + ) + _dataset_metadata: DatasetMetadata | None = PrivateAttr(default=None) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @property + def dataset_metadata(self) -> DatasetMetadata: + assert ( + self._dataset_metadata is not None + ), "Dataset metadata is not set. Please call set_metadata() before calling apply()." + return self._dataset_metadata + + @dataset_metadata.setter + def dataset_metadata(self, value: DatasetMetadata): + self._dataset_metadata = value + + def set_metadata(self, dataset_metadata: DatasetMetadata): + """ + Set the dataset metadata. This is useful for transforms that need to know the dataset metadata, e.g. to normalize actions. + Subclasses can override this method if they need to do something more complex. + """ + self.dataset_metadata = dataset_metadata + + def __call__(self, data: dict[str, Any]) -> dict[str, Any]: + """Apply the transformation to the data corresponding to target_keys and return the processed data. + + Args: + data (dict[str, Any]): The data to transform. + example: data = { + "video.image_side_0": np.ndarray, + "action.eef_position": np.ndarray, + ... 
+ } + + Returns: + dict[str, Any]: The transformed data. + example: transformed_data = { + "video.image_side_0": np.ndarray, + "action.eef_position": torch.Tensor, # Normalized and converted to tensor + ... + } + """ + return self.apply(data) + + @abstractmethod + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + """Apply the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data.""" + + def train(self): + self.training = True + + def eval(self): + self.training = False + + +class InvertibleModalityTransform(ModalityTransform): + @abstractmethod + def unapply(self, data: dict[str, Any]) -> dict[str, Any]: + """Reverse the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data.""" + + +class ComposedModalityTransform(ModalityTransform): + """Compose multiple modality transforms.""" + + transforms: list[ModalityTransform] = Field(..., description="The transforms to compose.") + apply_to: list[str] = Field( + default_factory=list, description="Will be ignored for composed transforms." + ) + training: bool = Field( + default=True, description="Whether to apply the transform in training mode." + ) + + model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) + + def set_metadata(self, dataset_metadata: DatasetMetadata): + for transform in self.transforms: + transform.set_metadata(dataset_metadata) + + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + for i, transform in enumerate(self.transforms): + try: + data = transform(data) + except Exception as e: + raise ValueError(f"Error applying transform {i} to data: {e}") from e + return data + + def unapply(self, data: dict[str, Any]) -> dict[str, Any]: + for i, transform in enumerate(reversed(self.transforms)): + if isinstance(transform, InvertibleModalityTransform): + try: + data = transform.unapply(data) + except Exception as e: + step = len(self.transforms) - i - 1 + raise ValueError(f"Error unapplying transform {step} to data: {e}") from e + return data + + def train(self): + for transform in self.transforms: + transform.train() + + def eval(self): + for transform in self.transforms: + transform.eval() diff --git a/code/dataloader/gr00t_lerobot/transform/concat.py b/code/dataloader/gr00t_lerobot/transform/concat.py new file mode 100644 index 0000000000000000000000000000000000000000..cf8eea4c77fc163ecdb0d25aeca26a2cde99f8c4 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/transform/concat.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
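base.py above fixes the transform interface used throughout the dataloader: subclasses implement apply(), ComposedModalityTransform chains them, and set_metadata() propagates dataset metadata. A minimal sketch with a toy ScaleTransform; the class, the key name and the import path are illustrative assumptions:

from typing import Any

import torch

from base import ComposedModalityTransform, ModalityTransform  # import path depends on packaging

class ScaleTransform(ModalityTransform):
    # Toy transform: multiply every key listed in apply_to by a constant factor.
    factor: float = 2.0

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        for key in self.apply_to:
            if key in data:
                data[key] = data[key] * self.factor
        return data

pipeline = ComposedModalityTransform(transforms=[ScaleTransform(apply_to=["state.position"])])
out = pipeline({"state.position": torch.ones(3)})
# out["state.position"] is now tensor([2., 2., 2.]); pipeline.eval() switches every
# member transform to evaluation mode.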
+ +from typing import Optional + +import numpy as np +import torch +from pydantic import Field + +from ..schema import DatasetMetadata, StateActionMetadata +from .base import InvertibleModalityTransform + + +class ConcatTransform(InvertibleModalityTransform): + """ + Concatenate the keys according to specified order. + """ + + # -- We inherit from ModalityTransform, so we keep apply_to as well -- + apply_to: list[str] = Field( + default_factory=list, description="Not used in this transform, kept for compatibility." + ) + + video_concat_order: list[str] = Field( + ..., + description="Concatenation order for each video modality. " + "Format: ['video.ego_view_pad_res224_freq20', ...]", + ) + + state_concat_order: Optional[list[str]] = Field( + default=None, + description="Concatenation order for each state modality. " + "Format: ['state.position', 'state.velocity', ...].", + ) + + action_concat_order: Optional[list[str]] = Field( + default=None, + description="Concatenation order for each action modality. " + "Format: ['action.position', 'action.velocity', ...].", + ) + + action_dims: dict[str, int] = Field( + default_factory=dict, + description="The dimensions of the action keys.", + ) + state_dims: dict[str, int] = Field( + default_factory=dict, + description="The dimensions of the state keys.", + ) + + def model_dump(self, *args, **kwargs): + if kwargs.get("mode", "python") == "json": + include = { + "apply_to", + "video_concat_order", + "state_concat_order", + "action_concat_order", + } + else: + include = kwargs.pop("include", None) + + return super().model_dump(*args, include=include, **kwargs) + + def apply(self, data: dict) -> dict: + grouped_keys = {} + for key in data.keys(): + try: + modality, _ = key.split(".") + except: # noqa: E722 + ### Handle language annotation special case + if "annotation" in key: + modality = "language" + else: + modality = "others" + if modality not in grouped_keys: + grouped_keys[modality] = [] + grouped_keys[modality].append(key) + + if "video" in grouped_keys: + # Check if keys in video_concat_order, state_concat_order, action_concat_order are + # ineed contained in the data. 
If not, then the keys are misspecified + video_keys = grouped_keys["video"] + assert self.video_concat_order is not None, f"{self.video_concat_order=}, {video_keys=}" + assert all( + item in video_keys for item in self.video_concat_order + ), f"keys in video_concat_order are misspecified, \n{video_keys=}, \n{self.video_concat_order=}" + + # Process each video view + unsqueezed_videos = [] + for video_key in self.video_concat_order: + video_data = data.pop(video_key) + unsqueezed_video = np.expand_dims( + video_data, axis=-4 + ) # [..., H, W, C] -> [..., 1, H, W, C] + unsqueezed_videos.append(unsqueezed_video) + # Concatenate along the new axis + unsqueezed_video = np.concatenate(unsqueezed_videos, axis=-4) # [..., V, H, W, C] + + # Video + data["video"] = unsqueezed_video + + # "state" + if "state" in grouped_keys: + state_keys = grouped_keys["state"] + assert self.state_concat_order is not None, f"{self.state_concat_order=}" + assert all( + item in state_keys for item in self.state_concat_order + ), f"keys in state_concat_order are misspecified, \n{state_keys=}, \n{self.state_concat_order=}" + # Check the state dims + for key in self.state_concat_order: + target_shapes = [self.state_dims[key]] + if self.is_rotation_key(key): + target_shapes.append(6) # Allow for rotation_6d + # if key in ["state.right_arm", "state.right_hand"]: + target_shapes.append(self.state_dims[key] * 2) # Allow for sin-cos transform + assert ( + data[key].shape[-1] in target_shapes + ), f"State dim mismatch for {key=}, {data[key].shape[-1]=}, {target_shapes=}" + # Concatenate the state keys + # We'll have StateActionToTensor before this transform, so here we use torch.cat + data["state"] = torch.cat( + [data.pop(key) for key in self.state_concat_order], dim=-1 + ) # [T, D_state] + + if "action" in grouped_keys: + action_keys = grouped_keys["action"] + assert self.action_concat_order is not None, f"{self.action_concat_order=}" + # Check if all keys in concat_order are present + assert set(self.action_concat_order) == set( + action_keys + ), f"{set(self.action_concat_order)=}, {set(action_keys)=}" + # Record the action dims + for key in self.action_concat_order: + target_shapes = [self.action_dims[key]] + if self.is_rotation_key(key): + target_shapes.append(3) # Allow for axis angle + assert ( + self.action_dims[key] == data[key].shape[-1] + ), f"Action dim mismatch for {key=}, {self.action_dims[key]=}, {data[key].shape[-1]=}" + # Concatenate the action keys + # We'll have StateActionToTensor before this transform, so here we use torch.cat + data["action"] = torch.cat( + [data.pop(key) for key in self.action_concat_order], dim=-1 + ) # [T, D_action] + + return data + + def unapply(self, data: dict) -> dict: + start_dim = 0 + assert "action" in data, f"{data.keys()=}" + # For those dataset without actions (LAPA), we'll never run unapply + assert self.action_concat_order is not None, f"{self.action_concat_order=}" + action_tensor = data.pop("action") + for key in self.action_concat_order: + if key not in self.action_dims: + raise ValueError(f"Action dim {key} not found in action_dims.") + end_dim = start_dim + self.action_dims[key] + data[key] = action_tensor[..., start_dim:end_dim] + start_dim = end_dim + if "state" in data: + assert self.state_concat_order is not None, f"{self.state_concat_order=}" + start_dim = 0 + state_tensor = data.pop("state") + for key in self.state_concat_order: + end_dim = start_dim + self.state_dims[key] + data[key] = state_tensor[..., start_dim:end_dim] + start_dim = end_dim + return data + 
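    # Worked illustration of the slicing performed by unapply() above; the key names and
    # dimensions are illustrative, not taken from a real modality config. With
    #   action_concat_order = ["action.eef_position", "action.gripper"]
    #   action_dims = {"action.eef_position": 3, "action.gripper": 1}
    # a concatenated action tensor of shape [T, 4] is split back into
    #   data["action.eef_position"] = action[..., 0:3]
    #   data["action.gripper"] = action[..., 3:4]
    # so apply() followed by unapply() recovers the original per-key tensors.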
+ def __call__(self, data: dict) -> dict: + return self.apply(data) + + def get_modality_metadata(self, key: str) -> StateActionMetadata: + modality, subkey = key.split(".") + assert self.dataset_metadata is not None, "Metadata not set" + modality_config = getattr(self.dataset_metadata.modalities, modality) + assert subkey in modality_config, f"{subkey=} not found in {modality_config=}" + assert isinstance( + modality_config[subkey], StateActionMetadata + ), f"Expected {StateActionMetadata} for {subkey=}, got {type(modality_config[subkey])=}" + return modality_config[subkey] + + def get_state_action_dims(self, key: str) -> int: + """Get the dimension of a state or action key from the dataset metadata.""" + modality_config = self.get_modality_metadata(key) + shape = modality_config.shape + assert len(shape) == 1, f"{shape=}" + return shape[0] + + def is_rotation_key(self, key: str) -> bool: + modality_config = self.get_modality_metadata(key) + return modality_config.rotation_type is not None + + def set_metadata(self, dataset_metadata: DatasetMetadata): + """Set the metadata and compute the dimensions of the state and action keys.""" + super().set_metadata(dataset_metadata) + # Pre-compute the dimensions of the state and action keys + if self.action_concat_order is not None: + for key in self.action_concat_order: + self.action_dims[key] = self.get_state_action_dims(key) + if self.state_concat_order is not None: + for key in self.state_concat_order: + self.state_dims[key] = self.get_state_action_dims(key) diff --git a/code/dataloader/gr00t_lerobot/transform/state_action.py b/code/dataloader/gr00t_lerobot/transform/state_action.py new file mode 100644 index 0000000000000000000000000000000000000000..a01d5f7c39903e3e78f4d92e6f901d93a99707e1 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/transform/state_action.py @@ -0,0 +1,606 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import random +from typing import Any, ClassVar + +import numpy as np +import pytorch3d.transforms as pt +import torch +from pydantic import Field, PrivateAttr, field_validator, model_validator + +from ..schema import DatasetMetadata, RotationType, StateActionMetadata +from .base import InvertibleModalityTransform, ModalityTransform + + +class RotationTransform: + """Adapted from https://github.com/real-stanford/diffusion_policy/blob/548a52bbb105518058e27bf34dcf90bf6f73681a/diffusion_policy/model/common/rotation_transformer.py""" + + valid_reps = ["axis_angle", "euler_angles", "quaternion", "rotation_6d", "matrix"] + + def __init__(self, from_rep="axis_angle", to_rep="rotation_6d"): + """ + Valid representations + + Always use matrix as intermediate representation. 
+ """ + if from_rep.startswith("euler_angles"): + from_convention = from_rep.split("_")[-1] + from_rep = "euler_angles" + from_convention = from_convention.replace("r", "X").replace("p", "Y").replace("y", "Z") + else: + from_convention = None + if to_rep.startswith("euler_angles"): + to_convention = to_rep.split("_")[-1] + to_rep = "euler_angles" + to_convention = to_convention.replace("r", "X").replace("p", "Y").replace("y", "Z") + else: + to_convention = None + assert from_rep != to_rep, f"from_rep and to_rep cannot be the same: {from_rep}" + assert from_rep in self.valid_reps, f"Invalid from_rep: {from_rep}" + assert to_rep in self.valid_reps, f"Invalid to_rep: {to_rep}" + + forward_funcs = list() + inverse_funcs = list() + + if from_rep != "matrix": + funcs = [getattr(pt, f"{from_rep}_to_matrix"), getattr(pt, f"matrix_to_{from_rep}")] + if from_convention is not None: + funcs = [functools.partial(func, convention=from_convention) for func in funcs] + forward_funcs.append(funcs[0]) + inverse_funcs.append(funcs[1]) + + if to_rep != "matrix": + funcs = [getattr(pt, f"matrix_to_{to_rep}"), getattr(pt, f"{to_rep}_to_matrix")] + if to_convention is not None: + funcs = [functools.partial(func, convention=to_convention) for func in funcs] + forward_funcs.append(funcs[0]) + inverse_funcs.append(funcs[1]) + + inverse_funcs = inverse_funcs[::-1] + + self.forward_funcs = forward_funcs + self.inverse_funcs = inverse_funcs + + @staticmethod + def _apply_funcs(x: torch.Tensor, funcs: list) -> torch.Tensor: + assert isinstance(x, torch.Tensor) + for func in funcs: + x = func(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert isinstance( + x, torch.Tensor + ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}" + return self._apply_funcs(x, self.forward_funcs) + + def inverse(self, x: torch.Tensor) -> torch.Tensor: + assert isinstance( + x, torch.Tensor + ), f"Unexpected input type: {type(x)}. Expected type: {torch.Tensor}" + return self._apply_funcs(x, self.inverse_funcs) + + +class Normalizer: + valid_modes = ["q99", "mean_std", "min_max", "binary"] + + def __init__(self, mode: str, statistics: dict): + self.mode = mode + self.statistics = statistics + for key, value in self.statistics.items(): + self.statistics[key] = torch.tensor(value) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert isinstance( + x, torch.Tensor + ), f"Unexpected input type: {type(x)}. 
Expected type: {torch.Tensor}" + + # Normalize the tensor + if self.mode == "q99": + # Range of q99 is [-1, 1] + q01 = self.statistics["q01"].to(x.dtype) + q99 = self.statistics["q99"].to(x.dtype) + + # In the case of q01 == q99, the normalization will be undefined + # So we set the normalized values to the original values + mask = q01 != q99 + normalized = torch.zeros_like(x) + + # Normalize the values where q01 != q99 + # Formula: 2 * (x - q01) / (q99 - q01) - 1 + normalized[..., mask] = (x[..., mask] - q01[..., mask]) / ( + q99[..., mask] - q01[..., mask] + ) + normalized[..., mask] = 2 * normalized[..., mask] - 1 + + # Set the normalized values to the original values where q01 == q99 + normalized[..., ~mask] = x[..., ~mask].to(x.dtype) + + # Clip the normalized values to be between -1 and 1 + normalized = torch.clamp(normalized, -1, 1) + + elif self.mode == "mean_std": + # Range of mean_std is not fixed, but can be positive or negative + mean = self.statistics["mean"].to(x.dtype) + std = self.statistics["std"].to(x.dtype) + + # In the case of std == 0, the normalization will be undefined + # So we set the normalized values to the original values + mask = std != 0 + normalized = torch.zeros_like(x) + + # Normalize the values where std != 0 + # Formula: (x - mean) / std + normalized[..., mask] = (x[..., mask] - mean[..., mask]) / std[..., mask] + + # Set the normalized values to the original values where std == 0 + normalized[..., ~mask] = x[..., ~mask].to(x.dtype) + + elif self.mode == "min_max": + # Range of min_max is [-1, 1] + min = self.statistics["min"].to(x.dtype) + max = self.statistics["max"].to(x.dtype) + + # In the case of min == max, the normalization will be undefined + # So we set the normalized values to the original values + mask = min != max + normalized = torch.zeros_like(x) + + # Normalize the values where min != max + # Formula: 2 * (x - min) / (max - min) - 1 + normalized[..., mask] = (x[..., mask] - min[..., mask]) / ( + max[..., mask] - min[..., mask] + ) + normalized[..., mask] = 2 * normalized[..., mask] - 1 + + # Set the normalized values to the original values where min == max + # normalized[..., ~mask] = x[..., ~mask].to(x.dtype) + # Set the normalized values to 0 where min == max + normalized[..., ~mask] = 0 + + elif self.mode == "scale": + # Range of scale is [0, 1] + min = self.statistics["min"].to(x.dtype) + max = self.statistics["max"].to(x.dtype) + abs_max = torch.max(torch.abs(min), torch.abs(max)) + mask = abs_max != 0 + normalized = torch.zeros_like(x) + normalized[..., mask] = x[..., mask] / abs_max[..., mask] + normalized[..., ~mask] = 0 + + elif self.mode == "binary": + # Range of binary is [0, 1] + normalized = (x > 0.5).to(x.dtype) + else: + raise ValueError(f"Invalid normalization mode: {self.mode}") + + return normalized + + def inverse(self, x: torch.Tensor) -> torch.Tensor: + assert isinstance( + x, torch.Tensor + ), f"Unexpected input type: {type(x)}. 
Expected type: {torch.Tensor}" + if self.mode == "q99": + q01 = self.statistics["q01"].to(x.dtype) + q99 = self.statistics["q99"].to(x.dtype) + return (x + 1) / 2 * (q99 - q01) + q01 + elif self.mode == "mean_std": + mean = self.statistics["mean"].to(x.dtype) + std = self.statistics["std"].to(x.dtype) + return x * std + mean + elif self.mode == "min_max": + min = self.statistics["min"].to(x.dtype) + max = self.statistics["max"].to(x.dtype) + return (x + 1) / 2 * (max - min) + min + elif self.mode == "binary": + return (x > 0.5).to(x.dtype) + else: + raise ValueError(f"Invalid normalization mode: {self.mode}") + + +class StateActionToTensor(InvertibleModalityTransform): + """ + Transforms states and actions to tensors. + """ + + input_dtypes: dict[str, np.dtype] = Field( + default_factory=dict, description="The input dtypes for each state key." + ) + output_dtypes: dict[str, torch.dtype] = Field( + default_factory=dict, description="The output dtypes for each state key." + ) + + def model_dump(self, *args, **kwargs): + if kwargs.get("mode", "python") == "json": + include = {"apply_to"} + else: + include = kwargs.pop("include", None) + + return super().model_dump(*args, include=include, **kwargs) + + @field_validator("input_dtypes", "output_dtypes", mode="before") + def validate_dtypes(cls, v): + for key, dtype in v.items(): + if isinstance(dtype, str): + if dtype.startswith("torch."): + dtype_split = dtype.split(".")[-1] + v[key] = getattr(torch, dtype_split) + elif dtype.startswith("np.") or dtype.startswith("numpy."): + dtype_split = dtype.split(".")[-1] + v[key] = np.dtype(dtype_split) + else: + raise ValueError(f"Invalid dtype: {dtype}") + return v + + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + for key in self.apply_to: + if key not in data: + continue + value = data[key] + assert isinstance( + value, np.ndarray + ), f"Unexpected input type: {type(value)}. Expected type: {np.ndarray}" + data[key] = torch.from_numpy(value) + if key in self.output_dtypes: + data[key] = data[key].to(self.output_dtypes[key]) + return data + + def unapply(self, data: dict[str, Any]) -> dict[str, Any]: + for key in self.apply_to: + if key not in data: + continue + value = data[key] + assert isinstance( + value, torch.Tensor + ), f"Unexpected input type: {type(value)}. Expected type: {torch.Tensor}" + data[key] = value.numpy() + if key in self.input_dtypes: + data[key] = data[key].astype(self.input_dtypes[key]) + return data + + +class StateActionTransform(InvertibleModalityTransform): + """ + Class for state or action transform. + + Args: + apply_to (list[str]): The keys in the modality to load and transform. + normalization_modes (dict[str, str]): The normalization modes for each state key. + If a state key in apply_to is not present in the dictionary, it will not be normalized. + target_rotations (dict[str, str]): The target representations for each state key. + If a state key in apply_to is not present in the dictionary, it will not be rotated. + """ + + # Configurable attributes + apply_to: list[str] = Field(..., description="The keys in the modality to load and transform.") + normalization_modes: dict[str, str] = Field( + default_factory=dict, description="The normalization modes for each state key." + ) + target_rotations: dict[str, str] = Field( + default_factory=dict, description="The target representations for each state key." + ) + normalization_statistics: dict[str, dict] = Field( + default_factory=dict, description="The statistics for each state key." 
+ ) + modality_metadata: dict[str, StateActionMetadata] = Field( + default_factory=dict, description="The modality metadata for each state key." + ) + + # Model variables + _rotation_transformers: dict[str, RotationTransform] = PrivateAttr(default_factory=dict) + _normalizers: dict[str, Normalizer] = PrivateAttr(default_factory=dict) + _input_dtypes: dict[str, np.dtype | torch.dtype] = PrivateAttr(default_factory=dict) + + # Model constants + _DEFAULT_MIN_MAX_STATISTICS: ClassVar[dict] = { + "rotation_6d": { + "min": [-1, -1, -1, -1, -1, -1], + "max": [1, 1, 1, 1, 1, 1], + }, + "euler_angles": { + "min": [-np.pi, -np.pi, -np.pi], + "max": [np.pi, np.pi, np.pi], + }, + "quaternion": { + "min": [-1, -1, -1, -1], + "max": [1, 1, 1, 1], + }, + "axis_angle": { + "min": [-np.pi, -np.pi, -np.pi], + "max": [np.pi, np.pi, np.pi], + }, + } + + def model_dump(self, *args, **kwargs): + if kwargs.get("mode", "python") == "json": + include = {"apply_to", "normalization_modes", "target_rotations"} + else: + include = kwargs.pop("include", None) + + return super().model_dump(*args, include=include, **kwargs) + + @field_validator("modality_metadata", mode="before") + def validate_modality_metadata(cls, v): + for modality_key, config in v.items(): + if isinstance(config, dict): + config = StateActionMetadata.model_validate(config) + else: + assert isinstance( + config, StateActionMetadata + ), f"Invalid source rotation config: {config}" + v[modality_key] = config + return v + + @model_validator(mode="after") + def validate_normalization_statistics(self): + for modality_key, normalization_statistics in self.normalization_statistics.items(): + if modality_key in self.normalization_modes: + normalization_mode = self.normalization_modes[modality_key] + if normalization_mode == "min_max": + assert ( + "min" in normalization_statistics and "max" in normalization_statistics + ), f"Min and max statistics are required for min_max normalization, but got {normalization_statistics}" + assert len(normalization_statistics["min"]) == len( + normalization_statistics["max"] + ), f"Min and max statistics must have the same length, but got {normalization_statistics['min']} and {normalization_statistics['max']}" + elif normalization_mode == "mean_std": + assert ( + "mean" in normalization_statistics and "std" in normalization_statistics + ), f"Mean and std statistics are required for mean_std normalization, but got {normalization_statistics}" + assert len(normalization_statistics["mean"]) == len( + normalization_statistics["std"] + ), f"Mean and std statistics must have the same length, but got {normalization_statistics['mean']} and {normalization_statistics['std']}" + elif normalization_mode == "q99": + assert ( + "q01" in normalization_statistics and "q99" in normalization_statistics + ), f"q01 and q99 statistics are required for q99 normalization, but got {normalization_statistics}" + assert len(normalization_statistics["q01"]) == len( + normalization_statistics["q99"] + ), f"q01 and q99 statistics must have the same length, but got {normalization_statistics['q01']} and {normalization_statistics['q99']}" + elif normalization_mode == "binary": + assert ( + len(normalization_statistics) == 1 + ), f"Binary normalization should only have one value, but got {normalization_statistics}" + assert normalization_statistics[0] in [ + 0, + 1, + ], f"Binary normalization should only have 0 or 1, but got {normalization_statistics[0]}" + else: + raise ValueError(f"Invalid normalization mode: {normalization_mode}") + return self + + def 
set_metadata(self, dataset_metadata: DatasetMetadata): + dataset_statistics = dataset_metadata.statistics + modality_metadata = dataset_metadata.modalities + + # Check that all state keys specified in apply_to have their modality_metadata + for key in self.apply_to: + split_key = key.split(".", 1) + assert len(split_key) == 2, "State keys should have two parts: 'modality.key'" + if key not in self.modality_metadata: + modality, state_key = split_key + assert hasattr(modality_metadata, modality), f"{modality} config not found" + assert state_key in getattr( + modality_metadata, modality + ), f"{state_key} config not found" + self.modality_metadata[key] = getattr(modality_metadata, modality)[state_key] + + # Check that all state keys specified in normalization_modes have their statistics in state_statistics + for key in self.normalization_modes: + split_key = key.split(".", 1) + assert len(split_key) == 2, "State keys should have two parts: 'modality.key'" + modality, state_key = split_key + assert hasattr(dataset_statistics, modality), f"{modality} statistics not found" + assert state_key in getattr( + dataset_statistics, modality + ), f"{state_key} statistics not found" + assert ( + len(getattr(modality_metadata, modality)[state_key].shape) == 1 + ), f"{getattr(modality_metadata, modality)[state_key].shape=}" + self.normalization_statistics[key] = getattr(dataset_statistics, modality)[ + state_key + ].model_dump() + + # Initialize the rotation transformers + for key in self.target_rotations: + # Get the original representation of the state + from_rep = self.modality_metadata[key].rotation_type + assert from_rep is not None, f"Source rotation type not found for {key}" + + # Get the target representation of the state, will raise an error if the target representation is not valid + to_rep = RotationType(self.target_rotations[key]) + + # If the original representation is not the same as the target representation, initialize the rotation transformer + if from_rep != to_rep: + self._rotation_transformers[key] = RotationTransform( + from_rep=from_rep.value, to_rep=to_rep.value + ) + + # Initialize the normalizers + for key in self.normalization_modes: + modality, state_key = key.split(".", 1) + # If the state has a nontrivial rotation, we need to handle it more carefully + # For absolute rotations, we need to convert them to the target representation and normalize them using min_max mode, + # since we can infer the bounds by the representation + # For relative rotations, we cannot normalize them as we don't know the bounds + if key in self._rotation_transformers: + # Case 1: Absolute rotation + if self.modality_metadata[key].absolute: + # Check that the normalization mode is valid + assert ( + self.normalization_modes[key] == "min_max" + ), "Absolute rotations that are converted to other formats must be normalized using `min_max` mode" + rotation_type = RotationType(self.target_rotations[key]).value + # If the target representation is euler angles, we need to parse the convention + if rotation_type.startswith("euler_angles"): + rotation_type = "euler_angles" + # Get the statistics for the target representation + statistics = self._DEFAULT_MIN_MAX_STATISTICS[rotation_type] + # Case 2: Relative rotation + else: + raise ValueError( + f"Cannot normalize relative rotations: {key} that's converted to {self.target_rotations[key]}" + ) + # If the state is not continuous, we should not use normalization modes other than binary + elif ( + not self.modality_metadata[key].continuous + and 
self.normalization_modes[key] != "binary" + ): + raise ValueError( + f"{key} is not continuous, so it should be normalized using `binary` mode" + ) + # Initialize the normalizer + else: + statistics = self.normalization_statistics[key] + self._normalizers[key] = Normalizer( + mode=self.normalization_modes[key], statistics=statistics + ) + + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + for key in self.apply_to: + if key not in data: + # We allow some keys to be missing in the data, and only process the keys that are present + continue + if key not in self._input_dtypes: + input_dtype = data[key].dtype + assert isinstance( + input_dtype, torch.dtype + ), f"Unexpected input dtype: {input_dtype}. Expected type: {torch.dtype}" + self._input_dtypes[key] = input_dtype + else: + assert ( + data[key].dtype == self._input_dtypes[key] + ), f"All states corresponding to the same key must be of the same dtype, input dtype: {data[key].dtype}, expected dtype: {self._input_dtypes[key]}" + # Rotate the state + state = data[key] + if key in self._rotation_transformers: + state = self._rotation_transformers[key].forward(state) + # Normalize the state + if key in self._normalizers: + state = self._normalizers[key].forward(state) + data[key] = state + return data + + def unapply(self, data: dict[str, Any]) -> dict[str, Any]: + for key in self.apply_to: + if key not in data: + continue + state = data[key] + assert isinstance( + state, torch.Tensor + ), f"Unexpected state type: {type(state)}. Expected type: {torch.Tensor}" + # Unnormalize the state + if key in self._normalizers: + state = self._normalizers[key].inverse(state) + # Change the state back to its original representation + if key in self._rotation_transformers: + state = self._rotation_transformers[key].inverse(state) + assert isinstance( + state, torch.Tensor + ), f"State should be tensor after unapplying transformations, but got {type(state)}" + # Only convert back to the original dtype if it's known, i.e. `apply` was called before + # If not, we don't know the original dtype, so we don't convert + if key in self._input_dtypes: + original_dtype = self._input_dtypes[key] + if isinstance(original_dtype, np.dtype): + state = state.numpy().astype(original_dtype) + elif isinstance(original_dtype, torch.dtype): + state = state.to(original_dtype) + else: + raise ValueError(f"Invalid input dtype: {original_dtype}") + data[key] = state + return data + + +class StateActionPerturbation(ModalityTransform): + """ + Class for state or action perturbation. + + Args: + apply_to (list[str]): The keys in the modality to load and transform. + std (float): Standard deviation of the noise to be added to the state or action. + """ + + # Configurable attributes + std: float = Field( + ..., description="Standard deviation of the noise to be added to the state or action." + ) + + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + if not self.training: + # Don't perturb the data in eval mode + return data + if self.std < 0: + # If the std is negative, we don't add any noise + return data + for key in self.apply_to: + state = data[key] + assert isinstance(state, torch.Tensor) + transformed_data_min = torch.min(state) + transformed_data_max = torch.max(state) + noise = torch.randn_like(state) * self.std + state += noise + # Clip to the original range + state = torch.clamp(state, transformed_data_min, transformed_data_max) + data[key] = state + return data + + +class StateActionDropout(ModalityTransform): + """ + Class for state or action dropout. 
+ + Args: + apply_to (list[str]): The keys in the modality to load and transform. + dropout_prob (float): Probability of dropping out a state or action. + """ + + # Configurable attributes + dropout_prob: float = Field(..., description="Probability of dropping out a state or action.") + + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + if not self.training: + # Don't drop out the data in eval mode + return data + if self.dropout_prob < 0: + # If the dropout probability is negative, we don't drop out any states + return data + if self.dropout_prob > 1e-9 and random.random() < self.dropout_prob: + for key in self.apply_to: + state = data[key] + assert isinstance(state, torch.Tensor) + state = torch.zeros_like(state) + data[key] = state + return data + + +class StateActionSinCosTransform(ModalityTransform): + """ + Class for state or action sin-cos transform. + + Args: + apply_to (list[str]): The keys in the modality to load and transform. + """ + + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + for key in self.apply_to: + state = data[key] + assert isinstance(state, torch.Tensor) + sin_state = torch.sin(state) + cos_state = torch.cos(state) + data[key] = torch.cat([sin_state, cos_state], dim=-1) + return data diff --git a/code/dataloader/gr00t_lerobot/transform/video.py b/code/dataloader/gr00t_lerobot/transform/video.py new file mode 100644 index 0000000000000000000000000000000000000000..15310f697259d7ee6ed8eac7e5abaca211b81dc9 --- /dev/null +++ b/code/dataloader/gr00t_lerobot/transform/video.py @@ -0,0 +1,612 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
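# --- Illustrative aside (not part of the original diff) ---
# The state/action transforms added above in transform/state_action.py reduce to
# three tensor operations: StateActionPerturbation adds Gaussian noise clipped back
# into the tensor's original min/max range, StateActionDropout zeroes an entire key
# with probability `dropout_prob`, and StateActionSinCosTransform doubles the last
# dimension with sin/cos features. A minimal standalone sketch of those operations
# (helper names are ours, only `torch` and `random` assumed):
import random

import torch


def perturb(state: torch.Tensor, std: float) -> torch.Tensor:
    # Add Gaussian noise, then clamp back into the original value range.
    lo, hi = state.min(), state.max()
    return torch.clamp(state + torch.randn_like(state) * std, lo, hi)


def dropout(state: torch.Tensor, p: float) -> torch.Tensor:
    # All-or-nothing: the whole state/action tensor is either kept or zeroed.
    return torch.zeros_like(state) if random.random() < p else state


def sincos(state: torch.Tensor) -> torch.Tensor:
    # Bounded, continuous encoding of joint angles; the last dimension doubles.
    return torch.cat([torch.sin(state), torch.cos(state)], dim=-1)


if __name__ == "__main__":
    chunk = torch.randn(16, 7)  # e.g. a 16-step chunk of 7-DoF actions
    out = sincos(dropout(perturb(chunk, std=0.01), p=0.1))
    print(out.shape)  # torch.Size([16, 14])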
+ +from typing import Any, Callable, ClassVar, Literal + +import albumentations as A +import cv2 +import numpy as np +import torch +import torchvision.transforms.v2 as T +from einops import rearrange +from pydantic import Field, PrivateAttr, field_validator +from PIL import Image + +from ..schema import DatasetMetadata +from .base import ModalityTransform + + +class VideoTransform(ModalityTransform): + # Configurable attributes + backend: str = Field( + default="torchvision", description="The backend to use for the transformations" + ) + + # Model variables + _train_transform: Callable | None = PrivateAttr(default=None) + _eval_transform: Callable | None = PrivateAttr(default=None) + _original_resolutions: dict[str, tuple[int, int]] = PrivateAttr(default_factory=dict) + + # Model constants + _INTERPOLATION_MAP: ClassVar[dict[str, dict[str, Any]]] = PrivateAttr( + { + "nearest": { + "albumentations": cv2.INTER_NEAREST, + "torchvision": T.InterpolationMode.NEAREST, + }, + "linear": { + "albumentations": cv2.INTER_LINEAR, + "torchvision": T.InterpolationMode.BILINEAR, + }, + "cubic": { + "albumentations": cv2.INTER_CUBIC, + "torchvision": T.InterpolationMode.BICUBIC, + }, + "area": { + "albumentations": cv2.INTER_AREA, + "torchvision": None, # Torchvision does not support this interpolation mode + }, + "lanczos4": { + "albumentations": cv2.INTER_LANCZOS4, # Lanczos with a 4x4 filter + "torchvision": T.InterpolationMode.LANCZOS, # Torchvision does not specify filter size, might be different from 4x4 + }, + "linear_exact": { + "albumentations": cv2.INTER_LINEAR_EXACT, + "torchvision": None, # Torchvision does not support this interpolation mode + }, + "nearest_exact": { + "albumentations": cv2.INTER_NEAREST_EXACT, + "torchvision": T.InterpolationMode.NEAREST_EXACT, + }, + "max": { + "albumentations": cv2.INTER_MAX, + "torchvision": None, + }, + } + ) + + @property + def train_transform(self) -> Callable: + assert ( + self._train_transform is not None + ), "Transform is not set. Please call set_metadata() before calling apply()." + return self._train_transform + + @train_transform.setter + def train_transform(self, value: Callable): + self._train_transform = value + + @property + def eval_transform(self) -> Callable | None: + return self._eval_transform + + @eval_transform.setter + def eval_transform(self, value: Callable | None): + self._eval_transform = value + + @property + def original_resolutions(self) -> dict[str, tuple[int, int]]: + assert ( + self._original_resolutions is not None + ), "Original resolutions are not set. Please call set_metadata() before calling apply()." 
+ return self._original_resolutions + + @original_resolutions.setter + def original_resolutions(self, value: dict[str, tuple[int, int]]): + self._original_resolutions = value + + def check_input(self, data: dict[str, Any]): + if self.backend == "torchvision": + for key in self.apply_to: + assert isinstance(data[key], torch.Tensor), f"Video {key} is not a torch tensor" + assert data[key].ndim in [ + 4, + 5, + ], f"Expected video {key} to have 4 or 5 dimensions (T, C, H, W or T, B, C, H, W), got {data[key].ndim}" + elif self.backend == "albumentations": + for key in self.apply_to: + assert isinstance(data[key], np.ndarray), f"Video {key} is not a numpy array" + assert data[key].ndim in [ + 4, + 5, + ], f"Expected video {key} to have 4 or 5 dimensions (T, C, H, W or T, B, C, H, W), got {data[key].ndim}" + else: + raise ValueError(f"Backend {self.backend} not supported") + + def set_metadata(self, dataset_metadata: DatasetMetadata): + super().set_metadata(dataset_metadata) + self.original_resolutions = {} + for key in self.apply_to: + split_keys = key.split(".") + assert len(split_keys) == 2, f"Invalid key: {key}. Expected format: modality.key" + sub_key = split_keys[1] + if sub_key in dataset_metadata.modalities.video: + self.original_resolutions[key] = dataset_metadata.modalities.video[ + sub_key + ].resolution + else: + raise ValueError( + f"Video key {sub_key} not found in dataset metadata. Available keys: {dataset_metadata.modalities.video.keys()}" + ) + train_transform = self.get_transform(mode="train") + eval_transform = self.get_transform(mode="eval") + if self.backend == "albumentations": + self.train_transform = A.ReplayCompose(transforms=[train_transform]) # type: ignore + if eval_transform is not None: + self.eval_transform = A.ReplayCompose(transforms=[eval_transform]) # type: ignore + else: + assert train_transform is not None, "Train transform must be set" + self.train_transform = train_transform + self.eval_transform = eval_transform + + def apply(self, data: dict[str, Any]) -> dict[str, Any]: + if self.training: + transform = self.train_transform + else: + transform = self.eval_transform + if transform is None: + return data + assert ( + transform is not None + ), "Transform is not set. Please call set_metadata() before calling apply()." 
+ try: + self.check_input(data) + except AssertionError as e: + raise ValueError( + f"Input data does not match the expected format for {self.__class__.__name__}: {e}" + ) from e + + # Concatenate views + views = [data[key] for key in self.apply_to] + num_views = len(views) + is_batched = views[0].ndim == 5 + bs = views[0].shape[0] if is_batched else 1 + if isinstance(views[0], torch.Tensor): + views = torch.cat(views, 0) + elif isinstance(views[0], np.ndarray): + views = np.concatenate(views, 0) + else: + raise ValueError(f"Unsupported view type: {type(views[0])}") + if is_batched: + views = rearrange(views, "(v b) t c h w -> (v b t) c h w", v=num_views, b=bs) + # Apply the transform + if self.backend == "torchvision": + views = transform(views) + elif self.backend == "albumentations": + assert isinstance(transform, A.ReplayCompose), "Transform must be a ReplayCompose" + first_frame = views[0] + transformed = transform(image=first_frame) + replay_data = transformed["replay"] + transformed_first_frame = transformed["image"] + + if len(views) > 1: + # Apply the same transformations to the rest of the frames + transformed_frames = [ + transform.replay(replay_data, image=frame)["image"] for frame in views[1:] + ] + # Add the first frame back + transformed_frames = [transformed_first_frame] + transformed_frames + else: + # If there is only one frame, just make a list with one frame + transformed_frames = [transformed_first_frame] + + # Delete the replay data to save memory + del replay_data + views = np.stack(transformed_frames, 0) + + else: + raise ValueError(f"Backend {self.backend} not supported") + # Split views + if is_batched: + views = rearrange(views, "(v b t) c h w -> v b t c h w", v=num_views, b=bs) + else: + views = rearrange(views, "(v t) c h w -> v t c h w", v=num_views) + for key, view in zip(self.apply_to, views): + data[key] = view + return data + + @classmethod + def _validate_interpolation(cls, interpolation: str): + if interpolation not in cls._INTERPOLATION_MAP: + raise ValueError(f"Interpolation mode {interpolation} not supported") + + def _get_interpolation(self, interpolation: str, backend: str = "torchvision"): + """ + Get the interpolation mode for the given backend. + + Args: + interpolation (str): The interpolation mode. + backend (str): The backend to use. + + Returns: + Any: The interpolation mode for the given backend. + """ + return self._INTERPOLATION_MAP[interpolation][backend] + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None: + raise NotImplementedError( + "set_transform is not implemented for VideoTransform. Please implement this function to set the transforms." + ) + + +class VideoCrop(VideoTransform): + height: int | None = Field(default=None, description="The height of the input image") + width: int | None = Field(default=None, description="The width of the input image") + scale: float = Field( + ..., + description="The scale of the crop. The crop size is (width * scale, height * scale)", + ) + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable: + """Get the transform for the given mode. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable: If mode is "train", return a random crop transform. If mode is "eval", return a center crop transform. + """ + # 1. 
Check the input resolution + assert ( + len(set(self.original_resolutions.values())) == 1 + ), f"All video keys must have the same resolution, got: {self.original_resolutions}" + if self.height is None: + assert self.width is None, "Height and width must be either both provided or both None" + self.width, self.height = self.original_resolutions[self.apply_to[0]] + else: + assert ( + self.width is not None + ), "Height and width must be either both provided or both None" + # 2. Create the transform + size = (int(self.height * self.scale), int(self.width * self.scale)) + if self.backend == "torchvision": + if mode == "train": + return T.RandomCrop(size) + elif mode == "eval": + return T.CenterCrop(size) + else: + raise ValueError(f"Crop mode {mode} not supported") + elif self.backend == "albumentations": + if mode == "train": + return A.RandomCrop(height=size[0], width=size[1], p=1) + elif mode == "eval": + return A.CenterCrop(height=size[0], width=size[1], p=1) + else: + raise ValueError(f"Crop mode {mode} not supported") + else: + raise ValueError(f"Backend {self.backend} not supported") + + def check_input(self, data: dict[str, Any]): + super().check_input(data) + # Check the input resolution + for key in self.apply_to: + if self.backend == "torchvision": + height, width = data[key].shape[-2:] + elif self.backend == "albumentations": + height, width = data[key].shape[-3:-1] + else: + raise ValueError(f"Backend {self.backend} not supported") + assert ( + height == self.height and width == self.width + ), f"Video {key} has invalid shape {height, width}, expected {self.height, self.width}" + + +class VideoResize(VideoTransform): + height: int = Field(..., description="The height of the resize") + width: int = Field(..., description="The width of the resize") + interpolation: str = Field(default="linear", description="The interpolation mode") + antialias: bool = Field(default=True, description="Whether to apply antialiasing") + + @field_validator("interpolation") + def validate_interpolation(cls, v): + cls._validate_interpolation(v) + return v + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable: + """Get the resize transform. Same transform for both train and eval. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable: The resize transform. + """ + interpolation = self._get_interpolation(self.interpolation, self.backend) + if interpolation is None: + raise ValueError( + f"Interpolation mode {self.interpolation} not supported for torchvision" + ) + if self.backend == "torchvision": + size = (self.height, self.width) + return T.Resize(size, interpolation=interpolation, antialias=self.antialias) + elif self.backend == "albumentations": + return A.Resize( + height=self.height, + width=self.width, + interpolation=interpolation, + p=1, + ) + else: + raise ValueError(f"Backend {self.backend} not supported") + + +class VideoRandomRotation(VideoTransform): + degrees: float | tuple[float, float] = Field( + ..., description="The degrees of the random rotation" + ) + interpolation: str = Field("linear", description="The interpolation mode") + + @field_validator("interpolation") + def validate_interpolation(cls, v): + cls._validate_interpolation(v) + return v + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None: + """Get the random rotation transform, only used in train mode. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. 
+ + Returns: + Callable | None: The random rotation transform. None for eval mode. + """ + if mode == "eval": + return None + interpolation = self._get_interpolation(self.interpolation, self.backend) + if interpolation is None: + raise ValueError( + f"Interpolation mode {self.interpolation} not supported for torchvision" + ) + if self.backend == "torchvision": + return T.RandomRotation(self.degrees, interpolation=interpolation) # type: ignore + elif self.backend == "albumentations": + return A.Rotate(limit=self.degrees, interpolation=interpolation, p=1) + else: + raise ValueError(f"Backend {self.backend} not supported") + + +class VideoHorizontalFlip(VideoTransform): + p: float = Field(..., description="The probability of the horizontal flip") + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None: + """Get the horizontal flip transform, only used in train mode. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable | None: If mode is "train", return a horizontal flip transform. If mode is "eval", return None. + """ + if mode == "eval": + return None + if self.backend == "torchvision": + return T.RandomHorizontalFlip(self.p) + elif self.backend == "albumentations": + return A.HorizontalFlip(p=self.p) + else: + raise ValueError(f"Backend {self.backend} not supported") + + +class VideoGrayscale(VideoTransform): + p: float = Field(..., description="The probability of the grayscale transformation") + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None: + """Get the grayscale transform, only used in train mode. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None. + """ + if mode == "eval": + return None + if self.backend == "torchvision": + return T.RandomGrayscale(self.p) + elif self.backend == "albumentations": + return A.ToGray(p=self.p) + else: + raise ValueError(f"Backend {self.backend} not supported") + + +class VideoColorJitter(VideoTransform): + brightness: float | tuple[float, float] = Field( + ..., description="The brightness of the color jitter" + ) + contrast: float | tuple[float, float] = Field( + ..., description="The contrast of the color jitter" + ) + saturation: float | tuple[float, float] = Field( + ..., description="The saturation of the color jitter" + ) + hue: float | tuple[float, float] = Field(..., description="The hue of the color jitter") + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None: + """Get the color jitter transform, only used in train mode. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable | None: If mode is "train", return a color jitter transform. If mode is "eval", return None. 
+ """ + if mode == "eval": + return None + if self.backend == "torchvision": + return T.ColorJitter( + brightness=self.brightness, + contrast=self.contrast, + saturation=self.saturation, + hue=self.hue, + ) + elif self.backend == "albumentations": + return A.ColorJitter( + brightness=self.brightness, + contrast=self.contrast, + saturation=self.saturation, + hue=self.hue, + p=1, + ) + else: + raise ValueError(f"Backend {self.backend} not supported") + + +class VideoRandomGrayscale(VideoTransform): + p: float = Field(..., description="The probability of the grayscale transformation") + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None: + """Get the grayscale transform, only used in train mode. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None. + """ + if mode == "eval": + return None + if self.backend == "torchvision": + return T.RandomGrayscale(self.p) + elif self.backend == "albumentations": + return A.ToGray(p=self.p) + else: + raise ValueError(f"Backend {self.backend} not supported") + + +class VideoRandomPosterize(VideoTransform): + bits: int = Field(..., description="The number of bits to posterize the image") + p: float = Field(..., description="The probability of the posterize transformation") + + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None: + """Get the posterize transform, only used in train mode. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable | None: If mode is "train", return a posterize transform. If mode is "eval", return None. + """ + if mode == "eval": + return None + if self.backend == "torchvision": + return T.RandomPosterize(bits=self.bits, p=self.p) + elif self.backend == "albumentations": + return A.Posterize(num_bits=self.bits, p=self.p) + else: + raise ValueError(f"Backend {self.backend} not supported") + + +class VideoToTensor(VideoTransform): + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable: + """Get the to tensor transform. Same transform for both train and eval. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable: The to tensor transform. + """ + if self.backend == "torchvision": + return self.__class__.to_tensor + else: + raise ValueError(f"Backend {self.backend} not supported") + + def check_input(self, data: dict): + """Check if the input data has the correct shape. + Expected video shape: [T, H, W, C], dtype np.uint8 + """ + for key in self.apply_to: + assert key in data, f"Key {key} not found in data. Available keys: {data.keys()}" + assert data[key].ndim in [ + 4, + 5, + ], f"Video {key} must have 4 or 5 dimensions, got {data[key].ndim}" + assert ( + data[key].dtype == np.uint8 + ), f"Video {key} must have dtype uint8, got {data[key].dtype}" + input_resolution = data[key].shape[-3:-1][::-1] + if key in self.original_resolutions: + expected_resolution = self.original_resolutions[key] + else: + expected_resolution = input_resolution + assert ( + input_resolution == expected_resolution + ), f"Video {key} has invalid resolution {input_resolution}, expected {expected_resolution}. Full shape: {data[key].shape}" + + @staticmethod + def to_tensor(frames: np.ndarray) -> torch.Tensor: + """Convert numpy array to tensor efficiently. 
+ + Args: + frames: numpy array of shape [T, H, W, C] in uint8 format + Returns: + tensor of shape [T, C, H, W] in range [0, 1] + """ + frames_tensor = torch.from_numpy(frames).to(torch.float32) / 255.0 + return frames_tensor.permute(0, 3, 1, 2) # [T, C, H, W] + + +class VideoToNumpy(VideoTransform): + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable: + """Get the to numpy transform. Same transform for both train and eval. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable: The to numpy transform. + """ + if self.backend == "torchvision": + return self.__class__.to_numpy + else: + raise ValueError(f"Backend {self.backend} not supported") + + @staticmethod + def to_numpy(frames: torch.Tensor) -> np.ndarray: + """Convert tensor back to numpy array efficiently. + + Args: + frames: tensor of shape [T, C, H, W] in range [0, 1] + Returns: + numpy array of shape [T, H, W, C] in uint8 format + """ + return (frames.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy() + +class VideoToPIL(VideoTransform): + def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable: + """Get the to PIL transform. Same transform for both train and eval. + + Args: + mode (Literal["train", "eval"]): The mode to get the transform for. + + Returns: + Callable: The to PIL transform. + """ + if self.backend == "torchvision": + return self.__class__.to_pil + else: + raise ValueError(f"Backend {self.backend} not supported") + + @staticmethod + def to_pil(frames: torch.Tensor) -> Image.Image: + """Convert tensor back to PIL Image. + + Args: + frames: tensor of shape [T, C, H, W] in range [0, 1] + Returns: + PIL Image of shape [T, H, W, C] in uint8 format + """ + # video PIL format? + return Image.fromarray((frames.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()) \ No newline at end of file diff --git a/code/dataloader/gr00t_lerobot/video.py b/code/dataloader/gr00t_lerobot/video.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1bf2db18e81f223dc0f489411d42832c008f6e --- /dev/null +++ b/code/dataloader/gr00t_lerobot/video.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
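# --- Illustrative aside (not part of the original diff) ---
# The decord path of get_frames_by_timestamps() below maps each requested timestamp
# to the index of the closest decoded frame via a broadcasted argmin over per-frame
# start times. A standalone numpy sketch of that lookup (the 30 fps clip and the
# variable names here are ours, only numpy assumed):
import numpy as np

fps = 30.0
frame_start_ts = np.arange(300)[:, None] / fps   # (num_frames, 1) start time of each frame
requested_ts = np.array([0.0, 0.52, 4.99])       # (num_requests,) query times in seconds

# |frame_start_ts - requested_ts| broadcasts to (num_frames, num_requests);
# argmin over axis 0 picks the nearest frame index for every requested timestamp.
indices = np.abs(frame_start_ts - requested_ts).argmin(axis=0)
print(indices)  # e.g. [  0  16 150]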
+ + +import av +import cv2 +import numpy as np + +import torch # noqa: F401 # isort: skip +import torchvision # noqa: F401 # isort: skip + +# Import decord with graceful fallback +try: + import decord # noqa: F401 + + DECORD_AVAILABLE = True +except ImportError: + DECORD_AVAILABLE = False + +try: + import torchcodec + + TORCHCODEC_AVAILABLE = True +except (ImportError, RuntimeError): + TORCHCODEC_AVAILABLE = False + + +def get_frames_by_indices( + video_path: str, + indices: list[int] | np.ndarray, + video_backend: str = "decord", + video_backend_kwargs: dict = {}, +) -> np.ndarray: + if video_backend == "decord": + if not DECORD_AVAILABLE: + raise ImportError("decord is not available.") + vr = decord.VideoReader(video_path, **video_backend_kwargs) + frames = vr.get_batch(indices) + return frames.asnumpy() + elif video_backend == "torchcodec": + if not TORCHCODEC_AVAILABLE: + raise ImportError("torchcodec is not available.") + decoder = torchcodec.decoders.VideoDecoder( + video_path, device="cpu", dimension_order="NHWC", num_ffmpeg_threads=0 + ) + return decoder.get_frames_at(indices=indices).data.numpy() + elif video_backend == "opencv": + frames = [] + cap = cv2.VideoCapture(video_path, **video_backend_kwargs) + for idx in indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if not ret: + raise ValueError(f"Unable to read frame at index {idx}") + frames.append(frame) + cap.release() + frames = np.array(frames) + return frames + else: + raise NotImplementedError + + +def get_frames_by_timestamps( + video_path: str, + timestamps: list[float] | np.ndarray, + video_backend: str = "decord", + video_backend_kwargs: dict = {}, +) -> np.ndarray: + """Get frames from a video at specified timestamps. + Args: + video_path (str): Path to the video file. + timestamps (list[int] | np.ndarray): Timestamps to retrieve frames for, in seconds. + video_backend (str, optional): Video backend to use. Defaults to "decord". + Returns: + np.ndarray: Frames at the specified timestamps. 
+ """ + if video_backend == "decord": + # For some GPUs, AV format data cannot be read + if not DECORD_AVAILABLE: + raise ImportError("decord is not available.") + vr = decord.VideoReader(video_path, **video_backend_kwargs) + num_frames = len(vr) + # Retrieve the timestamps for each frame in the video + frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames)) + # Map each requested timestamp to the closest frame index + # Only take the first element of the frame_ts array which corresponds to start_seconds + indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0) + frames = vr.get_batch(indices) + return frames.asnumpy() + elif video_backend == "torchcodec": + if not TORCHCODEC_AVAILABLE: + raise ImportError("torchcodec is not available.") + decoder = torchcodec.decoders.VideoDecoder( + video_path, device="cpu", dimension_order="NHWC", num_ffmpeg_threads=0 + ) + return decoder.get_frames_played_at(seconds=timestamps).data.numpy() + elif video_backend == "opencv": + # Open the video file + cap = cv2.VideoCapture(video_path, **video_backend_kwargs) + if not cap.isOpened(): + raise ValueError(f"Unable to open video file: {video_path}") + # Retrieve the total number of frames + num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + # Calculate timestamps for each frame + fps = cap.get(cv2.CAP_PROP_FPS) + frame_ts = np.arange(num_frames) / fps + frame_ts = frame_ts[:, np.newaxis] # Reshape to (num_frames, 1) for broadcasting + # Map each requested timestamp to the closest frame index + indices = np.abs(frame_ts - timestamps).argmin(axis=0) + frames = [] + for idx in indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if not ret: + raise ValueError(f"Unable to read frame at index {idx}") + frames.append(frame) + cap.release() + frames = np.array(frames) + return frames + elif video_backend == "torchvision_av": + torchvision.set_video_backend("pyav") + loaded_frames = [] + loaded_ts = [] + + reader = None + try: + reader = torchvision.io.VideoReader(video_path, "video") + + for target_ts in timestamps: + # Reset reader state + reader.seek(target_ts, keyframes_only=True) + + closest_frame = None + closest_ts_diff = float('inf') + + for frame in reader: + current_ts = frame["pts"] + current_diff = abs(current_ts - target_ts) + + if closest_frame is None: + closest_frame = frame + + if current_diff < closest_ts_diff: + # Release the previous frame + if closest_frame is not None: + del closest_frame + closest_ts_diff = current_diff + closest_frame = frame + else: + # The time difference starts to increase, stop searching + break + + if closest_frame is not None: + frame_data = closest_frame["data"] + if isinstance(frame_data, torch.Tensor): + frame_data = frame_data.cpu().numpy() + loaded_frames.append(frame_data) + loaded_ts.append(closest_frame["pts"]) + + # Immediately release frame reference + del closest_frame + + finally: + # Thoroughly clean resources + if reader is not None: + if hasattr(reader, '_c'): + reader._c = None + if hasattr(reader, 'container'): + reader.container.close() + reader.container = None + # Force garbage collection + import gc + gc.collect() + + frames = np.array(loaded_frames) + return frames.transpose(0, 2, 3, 1) + else: + raise NotImplementedError + + +def get_all_frames( + video_path: str, + video_backend: str = "decord", + video_backend_kwargs: dict = {}, + resize_size: tuple[int, int] | None = None, +) -> np.ndarray: + """Get all frames from a video. + Args: + video_path (str): Path to the video file. 
+ video_backend (str, optional): Video backend to use. Defaults to "decord". + video_backend_kwargs (dict, optional): Keyword arguments for the video backend. + resize_size (tuple[int, int], optional): Resize size for the frames. Defaults to None. + """ + if video_backend == "decord": + if not DECORD_AVAILABLE: + raise ImportError("decord is not available.") + vr = decord.VideoReader(video_path, **video_backend_kwargs) + frames = vr.get_batch(range(len(vr))).asnumpy() + elif video_backend == "torchcodec": + if not TORCHCODEC_AVAILABLE: + raise ImportError("torchcodec is not available.") + decoder = torchcodec.decoders.VideoDecoder( + video_path, device="cpu", dimension_order="NHWC", num_ffmpeg_threads=0 + ) + frames = decoder.get_frames_at(indices=range(len(decoder))) + return frames.data.numpy(), frames.pts_seconds.numpy() + elif video_backend == "pyav": + container = av.open(video_path) + frames = [] + for frame in container.decode(video=0): + frame = frame.to_ndarray(format="rgb24") + frames.append(frame) + frames = np.array(frames) + elif video_backend == "torchvision_av": + # set backend and reader + torchvision.set_video_backend("pyav") + reader = torchvision.io.VideoReader(video_path, "video") + frames = [] + for frame in reader: + frames.append(frame["data"].numpy()) + frames = np.array(frames) + frames = frames.transpose(0, 2, 3, 1) + else: + raise NotImplementedError(f"Video backend {video_backend} not implemented") + # resize frames if specified + if resize_size is not None: + frames = [cv2.resize(frame, resize_size) for frame in frames] + frames = np.array(frames) + return frames \ No newline at end of file diff --git a/code/dataloader/lerobot_datasets.py b/code/dataloader/lerobot_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..954d9bdeaaa1d88328d5bc8b5bd4716fce1903e9 --- /dev/null +++ b/code/dataloader/lerobot_datasets.py @@ -0,0 +1,145 @@ +# Copyright 2025 NVIDIA Corp. and affiliates. All rights reserved. +# Modified by [Fangjing Wang/ SUST University] in [2025]. +# Modification: [return raw data and suport multi-dataset mixture]. +# Modified by [Jinhui YE/ HKUST University] in [2025]. +# Modification: [suport topdowm processing, suport param from config]. + +from pathlib import Path +from typing import Sequence +from omegaconf import OmegaConf + +from starVLA.dataloader.gr00t_lerobot.datasets import LeRobotSingleDataset, LeRobotMixtureDataset +from starVLA.dataloader.gr00t_lerobot.mixtures import DATASET_NAMED_MIXTURES +from starVLA.dataloader.gr00t_lerobot.data_config import get_robot_type_config_map +from starVLA.dataloader.gr00t_lerobot.embodiment_tags import ROBOT_TYPE_TO_EMBODIMENT_TAG, EmbodimentTag + +def collate_fn(batch): + return batch + +def make_LeRobotSingleDataset( + data_root_dir: Path | str, + data_name: str, + robot_type: str, + delete_pause_frame: bool = False, + data_cfg: dict | None = None, +) -> LeRobotSingleDataset: + """ + Make a LeRobotSingleDataset object. + + :param data_root_dir: The root directory of the dataset. + :param data_name: The name of the dataset. + :param robot_type: The robot type config to use. + :param crop_obs_camera: Whether to crop the observation camera images. + :return: A LeRobotSingleDataset object. 
+ """ + chunk_size = data_cfg.get("chunk_size") + state_use_action_chunk = data_cfg.get("state_use_action_chunk") + num_history_steps = data_cfg.get("num_history_steps", 0) + data_config = get_robot_type_config_map( + chunk_size=chunk_size, + state_use_action_chunk=state_use_action_chunk, + num_history_steps=num_history_steps, + )[robot_type] + modality_config = data_config.modality_config() + transforms = data_config.transform() + dataset_path = data_root_dir / data_name + if robot_type not in ROBOT_TYPE_TO_EMBODIMENT_TAG: + print(f"Warning: Robot type {robot_type} not found in ROBOT_TYPE_TO_EMBODIMENT_TAG, using {EmbodimentTag.NEW_EMBODIMENT} as default") + embodiment_tag = EmbodimentTag.NEW_EMBODIMENT + else: + embodiment_tag = ROBOT_TYPE_TO_EMBODIMENT_TAG[robot_type] + + video_backend = data_cfg.get("video_backend", "decord") if data_cfg else "decord" + + return LeRobotSingleDataset( + dataset_path=dataset_path, + modality_configs=modality_config, + transforms=transforms, + embodiment_tag=embodiment_tag, + video_backend=video_backend, # decord is more efficiency | torchvision_av for video.av1 + delete_pause_frame=delete_pause_frame, + data_cfg=data_cfg, + ) + +def get_vla_dataset( + data_cfg: dict, + mode: str = "train", + balance_dataset_weights: bool = False, + balance_trajectory_weights: bool = False, + seed: int = 42, + delete_pause_frame: bool = True, + **kwargs: dict, +) -> LeRobotMixtureDataset: + """ + Get a LeRobotMixtureDataset object. + """ + data_root_dir = data_cfg.data_root_dir + data_mix = data_cfg.data_mix + mixture_spec = DATASET_NAMED_MIXTURES[data_mix] + included_datasets, filtered_mixture_spec = set(), [] + for d_name, d_weight, robot_type in mixture_spec: + dataset_key = (d_name, robot_type) + if dataset_key in included_datasets: + print(f"Skipping Duplicate Dataset: `{(d_name, d_weight, robot_type)}`") + continue + + included_datasets.add(dataset_key) + filtered_mixture_spec.append((d_name, d_weight, robot_type)) + + dataset_mixture = [] + for d_name, d_weight, robot_type in filtered_mixture_spec: + dataset_mixture.append((make_LeRobotSingleDataset(Path(data_root_dir), d_name, robot_type, delete_pause_frame=delete_pause_frame, data_cfg=data_cfg), d_weight)) + + return LeRobotMixtureDataset( + dataset_mixture, + mode=mode, + balance_dataset_weights=balance_dataset_weights, + balance_trajectory_weights=balance_trajectory_weights, + seed=seed, + data_cfg=data_cfg, + **kwargs, + ) + + + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--config_yaml", type=str, default="./starVLA/config/training/starvla_cotrain_behavior.yaml", help="Path to YAML config") + args, clipargs = parser.parse_known_args() + + args.config_yaml = "examples/LIBERO/train_files/starvla_cotrain_libero.yaml" + cfg = OmegaConf.load(args.config_yaml) + + vla_dataset_cfg = cfg.datasets.vla_data + # vla_dataset_cfg.data_root_dir = "./playground/Datasets/behavior-1k" + # vla_dataset_cfg.include_state = True + # vla_dataset_cfg.data_mix = "BEHAVIOR_dual_base_depth" + vla_dataset_cfg.task_id = 1 + for task_id in ["all"]: + # 11,26,36,37 + # 5,11,13,26,36,27,43,44,45,46 + # 2,3,5,11,13,25,26,27, + # 3,5,11,13, / 14,15,16,17, / 19,20,23,25, / 26,27,30,34, / 36,37,38,39, 41,42,43,44,45,46,47,49 + vla_dataset_cfg.task_id = task_id + print(f"Testing Task ID: {task_id}") + dataset = get_vla_dataset(data_cfg=vla_dataset_cfg) + # dataset + from torch.utils.data import DataLoader + train_dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=1, # 
For Debug + collate_fn=collate_fn, + ) + + from tqdm import tqdm + count = 1 + for batch in tqdm(train_dataloader, desc="Processing Batches"): + # print(batch) + # print(1) + if count > 1: + break + count += 1 + pass \ No newline at end of file diff --git a/code/dataloader/qwenvl_llavajson/__pycache__/qwen_data_config.cpython-310.pyc b/code/dataloader/qwenvl_llavajson/__pycache__/qwen_data_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0343b52aeb3caf567010067354f61e471a4089e0 Binary files /dev/null and b/code/dataloader/qwenvl_llavajson/__pycache__/qwen_data_config.cpython-310.pyc differ diff --git a/code/dataloader/qwenvl_llavajson/__pycache__/qwen_data_config.cpython-311.pyc b/code/dataloader/qwenvl_llavajson/__pycache__/qwen_data_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c4cab9e5e946a8565b9cf4e0c11e995bbc1c1f6 Binary files /dev/null and b/code/dataloader/qwenvl_llavajson/__pycache__/qwen_data_config.cpython-311.pyc differ diff --git a/code/dataloader/qwenvl_llavajson/__pycache__/rope2d.cpython-310.pyc b/code/dataloader/qwenvl_llavajson/__pycache__/rope2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84b233b414798b0b9d42e91c2797d4e28ef148c1 Binary files /dev/null and b/code/dataloader/qwenvl_llavajson/__pycache__/rope2d.cpython-310.pyc differ diff --git a/code/dataloader/qwenvl_llavajson/__pycache__/rope2d.cpython-311.pyc b/code/dataloader/qwenvl_llavajson/__pycache__/rope2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bb02c6ae52ef35de67ed7b90531aba9a5f16c30 Binary files /dev/null and b/code/dataloader/qwenvl_llavajson/__pycache__/rope2d.cpython-311.pyc differ diff --git a/code/dataloader/qwenvl_llavajson/qwen_data_config.py b/code/dataloader/qwenvl_llavajson/qwen_data_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c690f7ba14e9b6ec5f5bfaab310abaf867505b96 --- /dev/null +++ b/code/dataloader/qwenvl_llavajson/qwen_data_config.py @@ -0,0 +1,44 @@ +import re + +from pathlib import Path + +# You can add multimodal datasets here and register a short nickname to ${data_dict}. 
+# The data format should follow the general multimodal VLM format, for example: +# https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-finetune/README.md + +json_root = f"./playground/Datasets/LLaVA-OneVision-COCO/llava_jsons" +image_root = f"./playground/Datasets/LLaVA-OneVision-COCO/images" + +SHAREGPT4V_COCO = { + "annotation_path": f"{json_root}/sharegpt4v_coco.json", + "data_path": f"{image_root}/", +} + +data_dict = { + "sharegpt4v_coco": SHAREGPT4V_COCO, +} + +def parse_sampling_rate(dataset_name): + match = re.search(r"%(\d+)$", dataset_name) + if match: + return int(match.group(1)) / 100.0 + return 1.0 + +def data_list(dataset_names): + if dataset_names == ["all"]: + dataset_names = list(data_dict.keys()) + config_list = [] + for dataset_name in dataset_names: + sampling_rate = parse_sampling_rate(dataset_name) + dataset_name = re.sub(r"%(\d+)$", "", dataset_name) + if dataset_name in data_dict.keys(): + config = data_dict[dataset_name].copy() + config["sampling_rate"] = sampling_rate + config_list.append(config) + else: + raise ValueError(f"do not find {dataset_name}") + return config_list + +if __name__ == "__main__": + print(data_list) + diff --git a/code/dataloader/qwenvl_llavajson/rope2d.py b/code/dataloader/qwenvl_llavajson/rope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4da85174dcc7da97bc27264858965628773296 --- /dev/null +++ b/code/dataloader/qwenvl_llavajson/rope2d.py @@ -0,0 +1,351 @@ +import os +import copy +import json +import random +import logging +import re +import time +import math +import ast +from dataclasses import dataclass, field +from typing import Dict, Optional, Sequence, List, Tuple +from io import BytesIO +import base64 + +import numpy as np +import torch +from torch.utils.data import Dataset +from PIL import Image +from decord import VideoReader +import transformers + + +def get_rope_index_25( + spatial_merge_size: Optional[int] = 2, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. 
It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + image_token_id = 151655 + video_token_id = 151656 + vision_start_token_id = 151652 + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * 2 + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, 
device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + +def get_rope_index_2( + spatial_merge_size: Optional[int] = 2, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embeddin for text part. + Examples: + Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [3, 4, 5, 6, 7] + text height position_ids: [3, 4, 5, 6, 7] + text width position_ids: [3, 4, 5, 6, 7] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + image_token_id = 151655 + video_token_id = 151656 + vision_start_token_id = 151652 + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, 
keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas diff --git a/code/dataloader/vlm_datasets.py b/code/dataloader/vlm_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..422aeed3ec5352f8e85e4ab64e19747c24ef8aa8 --- /dev/null +++ b/code/dataloader/vlm_datasets.py @@ -0,0 +1,658 @@ +import os +import copy +import json +import random +import logging +import re +import time +import math +import itertools +import ast +from dataclasses import dataclass +from typing import Dict, Optional, Sequence, List, Tuple +from io import BytesIO +import base64 +from collections.abc import Sequence +from types import SimpleNamespace +import numpy as np +import torch +from torch.utils.data import Dataset +from PIL import Image +from decord import VideoReader +import transformers +from omegaconf import OmegaConf +from starVLA.dataloader.qwenvl_llavajson.qwen_data_config import data_list +from starVLA.dataloader.qwenvl_llavajson.rope2d import get_rope_index_25, get_rope_index_2 + +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = 151655 +VIDEO_TOKEN_INDEX = 151656 +DEFAULT_IMAGE_TOKEN = "\n" +DEFAULT_VIDEO_TOKEN = "