Timsty committed on
Commit e94400c · verified · 1 Parent(s): c7e95ad

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. code/__init__.py +0 -0
  3. code/__pycache__/__init__.cpython-310.pyc +0 -0
  4. code/__pycache__/__init__.cpython-311.pyc +0 -0
  5. code/config/deepseeds/deepspeed_zero2.yaml +9 -0
  6. code/config/deepseeds/deepspeed_zero3.yaml +7 -0
  7. code/config/deepseeds/ds_config.yaml +23 -0
  8. code/config/deepseeds/zero2.yaml +21 -0
  9. code/config/deepseeds/zero3.yaml +28 -0
  10. code/config/training/starvla_train_actionmodel_oxe.yaml +85 -0
  11. code/config/training/starvla_train_pi0.yaml +104 -0
  12. code/config/training/starvla_train_qwengr00t.yaml +99 -0
  13. code/config/training/starvla_train_qwenlatent_history_naive_oxe.yaml +106 -0
  14. code/config/training/starvla_train_qwenlatent_history_oxe.yaml +102 -0
  15. code/config/training/starvla_train_qwenlatent_oxe.yaml +103 -0
  16. code/config/training/starvla_train_qwenpi.yaml +97 -0
  17. code/dataloader/__init__.py +70 -0
  18. code/dataloader/__pycache__/__init__.cpython-310.pyc +0 -0
  19. code/dataloader/__pycache__/__init__.cpython-311.pyc +0 -0
  20. code/dataloader/__pycache__/lerobot_datasets.cpython-310.pyc +0 -0
  21. code/dataloader/__pycache__/lerobot_datasets.cpython-311.pyc +0 -0
  22. code/dataloader/__pycache__/vlm_datasets.cpython-310.pyc +0 -0
  23. code/dataloader/__pycache__/vlm_datasets.cpython-311.pyc +0 -0
  24. code/dataloader/gr00t_lerobot/README.md +0 -0
  25. code/dataloader/gr00t_lerobot/__init__.py +0 -0
  26. code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-310.pyc +0 -0
  27. code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-311.pyc +0 -0
  28. code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-310.pyc +0 -0
  29. code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-311.pyc +0 -0
  30. code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-310.pyc +0 -0
  31. code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-311.pyc +3 -0
  32. code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-310.pyc +0 -0
  33. code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-311.pyc +0 -0
  34. code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-310.pyc +0 -0
  35. code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-311.pyc +0 -0
  36. code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-310.pyc +0 -0
  37. code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-311.pyc +0 -0
  38. code/dataloader/gr00t_lerobot/__pycache__/video.cpython-310.pyc +0 -0
  39. code/dataloader/gr00t_lerobot/__pycache__/video.cpython-311.pyc +0 -0
  40. code/dataloader/gr00t_lerobot/data_config.py +392 -0
  41. code/dataloader/gr00t_lerobot/datasets.py +2165 -0
  42. code/dataloader/gr00t_lerobot/datasets_bak.py +2175 -0
  43. code/dataloader/gr00t_lerobot/datasets_bak2.py +2145 -0
  44. code/dataloader/gr00t_lerobot/embodiment_tags.py +198 -0
  45. code/dataloader/gr00t_lerobot/mixtures.py +241 -0
  46. code/dataloader/gr00t_lerobot/schema.py +221 -0
  47. code/dataloader/gr00t_lerobot/transform/__init__.py +41 -0
  48. code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-310.pyc +0 -0
  49. code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-311.pyc +0 -0
  50. code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -2407,3 +2407,4 @@ wandb/wandb/run-20260419_111433-oh7yfg1j/run-oh7yfg1j.wandb filter=lfs diff=lfs
  0422_QwenLatent_13tasks_stateactionprior_50k/videos/pytorch_model/n_action_steps_10_max_episode_steps_720_n_envs_1_gr1_unified/PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_Env/eb75ba54-de01-4932-90a2-36f3cdeaefaa_success0.mp4 filter=lfs diff=lfs merge=lfs -text
  0422_QwenLatent_13tasks_stateactionprior_50k/videos/pytorch_model/n_action_steps_10_max_episode_steps_720_n_envs_1_gr1_unified/PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_Env/fb67ee40-aa3d-4616-b022-a36af1f3d7d0_success1.mp4 filter=lfs diff=lfs merge=lfs -text
  0422_QwenLatent_13tasks_stateactionprior_50k/videos/pytorch_model/n_action_steps_10_max_episode_steps_720_n_envs_1_gr1_unified/PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_Env/fcdb3d2a-4bc1-4a9b-ae8e-7d9900cea828_success1.mp4 filter=lfs diff=lfs merge=lfs -text
+ code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
code/__init__.py ADDED
File without changes
code/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (145 Bytes).

code/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (161 Bytes).
 
code/config/deepseeds/deepspeed_zero2.yaml ADDED
@@ -0,0 +1,9 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_config_file: "./starVLA/config/deepseeds/ds_config.yaml"
+   deepspeed_multinode_launcher: standard
+   zero3_init_flag: false
+ distributed_type: DEEPSPEED
+ num_machines: 1
+ num_processes: 8
code/config/deepseeds/deepspeed_zero3.yaml ADDED
@@ -0,0 +1,7 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_config_file: "./starVLA/config/deepseeds/zero3.yaml"
+   deepspeed_multinode_launcher: standard
+   zero3_init_flag: false
+ distributed_type: DEEPSPEED
code/config/deepseeds/ds_config.yaml ADDED
@@ -0,0 +1,23 @@
+ {
+   "fp16": {
+     "enabled": false
+   },
+   "bf16": {
+     "enabled": true
+   },
+   "train_micro_batch_size_per_gpu": "auto",
+   "train_batch_size": "auto",
+   "gradient_accumulation_steps": 1,
+   "zero_optimization": {
+     "stage": 2,
+     "allgather_partitions": true,
+     "allgather_bucket_size": 5e8,
+     "reduce_scatter": true,
+     "reduce_bucket_size": 5e8,
+     "overlap_comm": true,
+     "contiguous_gradients": true,
+     "cpu_offload": false
+   },
+   "gradient_clipping": 1.0,
+   "steps_per_print": 10
+ }
code/config/deepseeds/zero2.yaml ADDED
@@ -0,0 +1,21 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+   deepspeed_multinode_launcher: standard
+   offload_optimizer_device: none
+   offload_param_device: none
+   zero3_init_flag: false
+   zero_stage: 2
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
code/config/deepseeds/zero3.yaml ADDED
@@ -0,0 +1,28 @@
+ {
+   "fp16": {
+     "enabled": "auto",
+     "loss_scale": 0,
+     "loss_scale_window": 1000,
+     "initial_scale_power": 16,
+     "hysteresis": 2,
+     "min_loss_scale": 1
+   },
+   "bf16": {
+     "enabled": "auto"
+   },
+   "train_micro_batch_size_per_gpu": "auto",
+   "train_batch_size": "auto",
+   "gradient_accumulation_steps": "auto",
+   "zero_optimization": {
+     "stage": 3,
+     "overlap_comm": true,
+     "contiguous_gradients": true,
+     "sub_group_size": 1e9,
+     "reduce_bucket_size": 5e8,
+     "stage3_prefetch_bucket_size": 5e8,
+     "stage3_param_persistence_threshold": 1e6,
+     "stage3_max_live_parameters": 1e9,
+     "stage3_max_reuse_distance": 1e9,
+     "stage3_gather_16bit_weights_on_model_save": true
+   }
+ }
code/config/training/starvla_train_actionmodel_oxe.yaml ADDED
@@ -0,0 +1,85 @@
+ run_id: vla_jepa_temp
+ run_root_dir: ./runs
+ seed: 21
+ trackers: [jsonl, wandb]
+ wandb_entity: timsty
+ wandb_project: vla_jepa
+ is_debug: false
+
+ framework:
+   name: ActionModelFM
+   action_model:
+     action_size: 37
+     state_size: 74
+     use_state: ${datasets.vla_data.state_use_action_chunk}
+     hidden_size: 1024
+     intermediate_size: 3072
+     dataset_vocab_size: 256
+     num_data_tokens: 32
+     mask_ratio_mode: "uniform_per_traj"  # one mask ratio per trajectory; see the sketch after this file
+     mask_ratio_min: 0.25
+     mask_ratio_max: 0.75
+     min_action_len: 5
+     num_encoder_layers: 28
+     num_decoder_layers: 28
+     num_attention_heads: 16
+     num_key_value_heads: 8
+     head_dim: 128
+     max_position_embeddings: 4096
+     max_action_chunk_size: 50
+     rms_norm_eps: 1.0e-6
+     attention_dropout: 0.0
+     # --- Action model loss mode (choose one combination) ---
+     use_masked_action_recon: false  # true = add reconstruction loss for the masked-action view (two-view training)
+     qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B
+
+ datasets:
+   vla_data:
+     dataset_py: lerobot_datasets
+     data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY
+     data_mix: cross_embodiedment_13tasks
+     require_language: false
+     # action_type: delta_ee
+     default_image_resolution: [3, 224, 224]
+     per_device_batch_size: 256
+     load_all_data_for_training: true
+     load_video: false
+     obs: ["image_0"]
+     image_size: [224, 224]
+     video_backend: torchcodec
+     chunk_size: 15
+     # state chunk aligned with action: state has shape (L, state_dim), like action (L, action_dim)
+     state_use_action_chunk: true
+
+ trainer:
+   epochs: 1000
+   max_train_steps: 5000
+   num_warmup_steps: 1000
+   save_interval: 5000
+   eval_interval: 50
+   learning_rate:
+     base: 1e-04
+   lr_scheduler_type: cosine_with_min_lr
+   scheduler_specific_kwargs:
+     min_lr: 5.0e-07
+   freeze_modules: ''
+   loss_scale:
+     vla: 1.0
+   warmup_ratio: 0.1
+   weight_decay: 0.0
+   logging_frequency: 10
+   gradient_clipping: 5
+   gradient_accumulation_steps: 1
+
+   optimizer:
+     name: AdamW
+     betas: [0.9, 0.95]
+     eps: 1.0e-08
+     weight_decay: 1.0e-08
+
+   # parameters to be determined
+   is_resume: false
+   resume_epoch: null
+   resume_step: null
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
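
The `mask_ratio_mode: uniform_per_traj` setting above, together with `mask_ratio_min`/`mask_ratio_max`, implies drawing one mask ratio per trajectory from U(0.25, 0.75) and masking that fraction of the action chunk. A minimal sketch of that sampling, assuming step-level masking (the helper name and the masking granularity are illustrative assumptions, not the repo's implementation):

```python
import torch

def sample_action_mask(batch_size: int, chunk_len: int,
                       ratio_min: float = 0.25, ratio_max: float = 0.75) -> torch.Tensor:
    """Illustrative sketch: one mask ratio per trajectory, uniform in [ratio_min, ratio_max].

    Returns a (batch_size, chunk_len) boolean mask where True = masked step.
    """
    # "uniform_per_traj": one ratio per trajectory, not per step.
    ratios = torch.empty(batch_size).uniform_(ratio_min, ratio_max)
    num_masked = (ratios * chunk_len).round().long()          # steps to mask per trajectory
    perm = torch.rand(batch_size, chunk_len).argsort(dim=1)   # random permutation per row
    return perm < num_masked.unsqueeze(1)                     # exactly num_masked random True entries per row

mask = sample_action_mask(batch_size=4, chunk_len=15)
print(mask.float().mean(dim=1))  # per-trajectory mask fraction, roughly in [0.25, 0.75]
```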
code/config/training/starvla_train_pi0.yaml ADDED
@@ -0,0 +1,104 @@
+ # PI0 training config - uses the unified 37D action representation.
+ # The action/state projection layers (hardcoded to 32D in upstream openpi) are automatically
+ # replaced with 37D versions when PI0Framework is initialized; the corresponding 32D parameters
+ # in the checkpoint are skipped on load, and the remaining backbone parameters are reused as
+ # usual (a loading sketch follows this file).
+
+ run_id: pi0_unified_37d
+ run_root_dir: ./runs
+ seed: 42
+ trackers: [jsonl, wandb]
+ wandb_entity: timsty
+ wandb_project: vla_jepa
+ is_debug: false
+
+ framework:
+   name: PI0
+   # PI0 model config.
+   # action_dim follows this project (the unified 37D action representation).
+   # In the PI0Pytorch source, action_in_proj / action_out_proj / state_proj are hardcoded to 32D;
+   # PI0Framework.__init__ calls _replace_pi0_projection_layers to replace them with 37D versions,
+   # and when the checkpoint is loaded these layers are skipped because of the shape mismatch
+   # (they keep their random initialization). The remaining VLM backbone layers (PaliGemma, the
+   # action expert transformer, etc.) still load from the checkpoint as usual.
+   pi0:
+     paligemma_variant: "gemma_2b"
+     action_expert_variant: "gemma_300m"
+     pi05: false
+     action_dim: 37  # unified project-wide dimension; projection layers are replaced automatically, and the old 32D checkpoint parameters are skipped on load
+     state_dim: 74  # unified state dimension; state_proj is replaced with Linear(74, width), independent of action_dim
+     action_horizon: 15  # aligned with chunk_size
+     dtype: "bfloat16"
+
+   # Pretrained weight path (pi05_libero etc.; on an action_dim mismatch the checkpoint is partially loaded with strict=False)
+   pi0_checkpoint: /mnt/data/fangyu/model/openpi/openpi-assets/checkpoints/pi0_base_torch/model.pt
+
+   # PaliGemma tokenizer
+   tokenizer_path: /root/.cache/openpi/big_vision/paligemma_tokenizer.model
+
+   # Image key names matching openpi's three-view format; for single-view gr1, use together with replicate_single_view
+   image_keys:
+     - "base_0_rgb"
+     - "left_wrist_0_rgb"
+     - "right_wrist_0_rgb"
+
+   # When the dataset provides only one image, replicate it to all three views (e.g. fourier_gr1 video.ego_view)
+   replicate_single_view: true
+
+   use_state: true
+
+   # If true, dynamically use the first N image_keys based on the actual number of images; otherwise use all keys and zero-pad missing views
+   dynamic_image_keys: false
+
+   num_inference_steps: 10
+
+   # Output truncation dimension; null outputs the full action_dim
+   effective_action_dim: null
+
+ datasets:
+   vla_data:
+     dataset_py: lerobot_datasets
+     data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY
+     data_mix: cross_embodiedment_simulator
+     default_image_resolution: [3, 224, 224]
+     per_device_batch_size: 32
+     load_all_data_for_training: true
+     obs: ["image_0"]
+     image_size: [224, 224]
+     video_backend: torchcodec
+     load_video: true
+     chunk_size: 15
+     state_use_action_chunk: false
+     num_history_steps: 0
+     include_state: false  # state is not used when training PI0
+
+ trainer:
+   epochs: 100
+   max_train_steps: 20000
+   num_warmup_steps: 5000
+   num_stable_steps: 0
+   save_interval: 5000
+   max_checkpoints_to_keep: 20
+
+   learning_rate:
+     base: 2.5e-5
+     pi0_model: 2.5e-5
+
+   lr_scheduler_type: warmup_stable_cosine
+   scheduler_specific_kwargs:
+     min_lr_ratio: 0.001
+
+   freeze_modules: ""
+   warmup_ratio: 0.1
+   weight_decay: 0.0
+   logging_frequency: 10
+   gradient_clipping: 5.0
+   gradient_accumulation_steps: 1
+
+   optimizer:
+     name: AdamW
+     betas: [0.9, 0.95]
+     eps: 1.0e-08
+     weight_decay: 1.0e-08
+
+   is_resume: false
+   pretrained_checkpoint: null
+   enable_gradient_checkpointing: false
+   enable_mixed_precision_training: true
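
A minimal sketch of the "skip mismatched shapes" loading behavior the comments above describe. This is not the repo's `_replace_pi0_projection_layers` implementation, just an illustration of filtering a state dict by shape before a non-strict load:

```python
import torch
import torch.nn as nn

def load_skipping_mismatches(model: nn.Module, ckpt_path: str) -> None:
    """Illustrative only: load checkpoint weights whose shapes match, skip the rest.

    Mirrors the behavior described above: after the 32D projection layers are
    replaced with 37D ones, their old checkpoint tensors no longer fit and are
    dropped, while all shape-compatible backbone weights load normally.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")  # assumes a flat state dict
    own = model.state_dict()
    kept = {k: v for k, v in ckpt.items()
            if k in own and own[k].shape == v.shape}
    skipped = [k for k in ckpt if k not in kept]
    model.load_state_dict(kept, strict=False)  # replaced layers keep random init
    print(f"loaded {len(kept)} tensors, skipped {len(skipped)} (name/shape mismatch)")
```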
code/config/training/starvla_train_qwengr00t.yaml ADDED
@@ -0,0 +1,99 @@
+ run_id: qwengr00t_oxe
+ run_root_dir: ./runs
+ seed: 42
+ trackers: [jsonl, wandb]
+ wandb_entity: timsty
+ wandb_project: vla_jepa
+ is_debug: false
+
+ framework:
+   name: QwenGR00T
+   qwenvl:
+     base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct
+     attn_implementation: flash_attention_2
+     vl_hidden_dim: 2048
+     num_data_tokens: 32  # dataset soft prompt tokens prepended to VLM input (0 = disabled)
+
+   # QwenGR00T required action head config
+   action_model:
+     dataset_vocab_size: 256  # number of distinct dataset IDs for soft prompt embedding
+     action_model_type: DiT-B
+     hidden_size: 1024
+     add_pos_embed: true
+     max_seq_len: 1024
+     action_dim: 37
+     state_dim: 74
+     future_action_window_size: 14
+     action_horizon: 15
+     past_action_window_size: 0
+     noise_beta_alpha: 1.5
+     noise_beta_beta: 1.0
+     noise_s: 0.999
+     num_timestep_buckets: 1000
+     num_inference_timesteps: 10
+     num_target_vision_tokens: 32
+     diffusion_model_cfg:
+       cross_attention_dim: 2048
+       dropout: 0.2
+       final_dropout: true
+       interleave_self_attention: true
+       norm_type: "ada_norm"
+       num_layers: 16
+       output_dim: 1024
+       positional_embeddings: null
+
+ datasets:
+   vla_data:
+     dataset_py: lerobot_datasets
+     data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY
+     data_mix: cross_embodiedment_13tasks
+     CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+     default_image_resolution: [3, 224, 224]
+     per_device_batch_size: 32
+     load_all_data_for_training: true
+     obs: ["image_0"]
+     image_size: [224, 224]
+     video_backend: torchcodec
+     load_video: true
+     chunk_size: 15
+     state_use_action_chunk: false
+     num_history_steps: 0
+     include_state: true
+
+ trainer:
+   epochs: 100
+   max_train_steps: 50000
+   num_warmup_steps: 5000
+   num_stable_steps: 0
+   save_interval: 5000
+   eval_interval: 50
+   max_checkpoints_to_keep: 20
+
+   # Used in QwenGR00T.forward() to repeat diffusion training pairs
+   repeated_diffusion_steps: 1
+
+   learning_rate:
+     base: 5e-05
+     qwen_vl_interface: 5e-05
+     action_model: 5e-05
+   lr_scheduler_type: warmup_stable_cosine  # see the schedule sketch after this file
+   scheduler_specific_kwargs:
+     min_lr_ratio: 0.001
+
+   freeze_modules: ''
+   warmup_ratio: 0.1
+   logging_frequency: 10
+   gradient_clipping: 5.0
+   gradient_accumulation_steps: 4
+
+   optimizer:
+     name: AdamW
+     betas: [0.9, 0.95]
+     eps: 1.0e-08
+     weight_decay: 1.0e-08
+
+   is_resume: false
+   resume_epoch: null
+   resume_step: null
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
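
Several of these configs use `lr_scheduler_type: warmup_stable_cosine` with `num_warmup_steps`, `num_stable_steps`, and `min_lr_ratio`. A minimal sketch of what such a schedule typically computes (an illustration, not the repo's scheduler):

```python
import math

def warmup_stable_cosine_lr(step: int, base_lr: float, warmup_steps: int,
                            stable_steps: int, total_steps: int,
                            min_lr_ratio: float = 0.001) -> float:
    """Illustrative schedule: linear warmup, optional hold at base_lr, cosine decay to base_lr * min_lr_ratio."""
    min_lr = base_lr * min_lr_ratio
    if step < warmup_steps:                 # linear warmup
        return base_lr * step / max(1, warmup_steps)
    if step < warmup_steps + stable_steps:  # hold at max lr
        return base_lr
    # cosine decay over the remaining steps
    decay_steps = max(1, total_steps - warmup_steps - stable_steps)
    progress = min(1.0, (step - warmup_steps - stable_steps) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

# With base 5e-5, 5000 warmup steps, 0 stable steps, 50000 total (as above):
print(warmup_stable_cosine_lr(2500, 5e-5, 5000, 0, 50000))   # mid-warmup: 2.5e-05
print(warmup_stable_cosine_lr(50000, 5e-5, 5000, 0, 50000))  # end: 5e-08 = base * min_lr_ratio
```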
code/config/training/starvla_train_qwenlatent_history_naive_oxe.yaml ADDED
@@ -0,0 +1,106 @@
+ run_id: vla_jepa_temp
+ run_root_dir: ./runs
+ seed: 42
+ trackers: [jsonl, wandb]
+ wandb_entity: timsty
+ wandb_project: vla_jepa
+ is_debug: false
+
+ framework:
+   # Naive baseline: history tokens are projected directly via two-layer MLPs
+   # (history_action_projector + history_state_projector) without any action
+   # encoder. Directly comparable to QwenLatent_history, which uses the full
+   # action-encoder path for history encoding.
+   name: QwenLatent_history_naive
+   qwenvl:
+     base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct
+     attn_implementation: flash_attention_2
+     vl_hidden_dim: 2048
+     num_data_tokens: 32
+   action_model:
+     ckpt_path: /mnt/data/fangyu/code/reward_new/runs/0417_Action_9tasks_actionstate_fixchunk15/final_model/pytorch_model.pt
+     # ckpt_path: null
+     action_size: 37
+     state_size: 74  # matches the action model; 0 means state is not used
+     use_state: ${datasets.vla_data.state_use_action_chunk}
+     hidden_size: 1024
+     intermediate_size: 3072
+     dataset_vocab_size: 256
+     num_data_tokens: 32
+     min_action_len: 5
+     num_encoder_layers: 28
+     num_decoder_layers: 28
+     num_attention_heads: 16
+     num_key_value_heads: 8
+     head_dim: 128
+     max_position_embeddings: 2048
+     max_action_chunk_size: 50
+     rms_norm_eps: 1.0e-6
+     attention_dropout: 0.0
+     use_vae_reparameterization: false
+     use_ema: false  # whether to use EMA; if false, freeze the encoder and train only the VLM and decoder
+     chunk_size: ${datasets.vla_data.chunk_size}
+     loss_mode: full  # full, predict_only
+     qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B
+ datasets:
+   vla_data:
+     dataset_py: lerobot_datasets
+     data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY
+     data_mix: cross_embodiedment_simulator  # bridge_rt_1
+     # action_type: delta_ee
+     CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+     default_image_resolution: [3, 224, 224]
+     per_device_batch_size: 32
+     load_all_data_for_training: true
+     obs: ["image_0"]
+     image_size: [224, 224]
+     video_backend: torchcodec
+     load_video: true
+     chunk_size: 30
+     # align the state chunk with the action chunk (consistent with action model training)
+     state_use_action_chunk: true
+     # number of history state/action steps; when > 0, each sample additionally returns state_history and action_history
+     num_history_steps: 15
+     include_state: ${datasets.vla_data.state_use_action_chunk}
+
+ trainer:
+   epochs: 100
+   max_train_steps: 30000
+   num_warmup_steps: 3000
+   num_stable_steps: 0  # number of steps to hold max_lr (after warmup)
+   mode: freeze_action_encoder_decay_aux_loss  # freeze_action_encoder_decay_aux_loss
+   loss_weights_decay_steps: 5000
+
+   save_interval: 5000
+   eval_interval: 50
+   max_checkpoints_to_keep: 10  # maximum number of checkpoints to keep; the oldest are deleted beyond this limit
+   learning_rate:
+     base: 2.5e-05
+     qwen_vl_interface: 2.5e-05
+     action_model: 2.5e-05
+   lr_scheduler_type: warmup_stable_cosine  # options: warmup_stable_cosine (default), onecycle
+   scheduler_specific_kwargs:
+     min_lr_ratio: 0.001  # final lr = base_lr * min_lr_ratio
+   freeze_modules: ''
+   loss_scale:
+     align_loss: 1.0
+     recon_loss: 1.0
+     predict_loss: 1.0
+   warmup_ratio: 0.1
+   weight_decay: 0.0
+   logging_frequency: 10
+   gradient_clipping: 5.0
+   gradient_accumulation_steps: 1
+
+   optimizer:
+     name: AdamW
+     betas: [0.9, 0.95]
+     eps: 1.0e-08
+     weight_decay: 1.0e-08
+
+   # parameters to be determined
+   is_resume: false
+   resume_epoch: null
+   resume_step: null
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
code/config/training/starvla_train_qwenlatent_history_oxe.yaml ADDED
@@ -0,0 +1,102 @@
+ run_id: vla_jepa_temp
+ run_root_dir: ./runs
+ seed: 42
+ trackers: [jsonl, wandb]
+ wandb_entity: timsty
+ wandb_project: vla_jepa
+ is_debug: false
+
+ framework:
+   name: QwenLatent_history
+   qwenvl:
+     base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct
+     attn_implementation: flash_attention_2
+     vl_hidden_dim: 2048
+     num_data_tokens: 32
+   action_model:
+     ckpt_path: /mnt/data/fangyu/code/reward_new/runs/0418_Action_13tasks_actionstate_fixchunk15/final_model/pytorch_model.pt
+     # ckpt_path: null
+     action_size: 37
+     state_size: 74  # matches the action model; 0 means state is not used
+     use_state: ${datasets.vla_data.state_use_action_chunk}
+     hidden_size: 1024
+     intermediate_size: 3072
+     dataset_vocab_size: 256
+     num_data_tokens: 32
+     min_action_len: 5
+     num_encoder_layers: 28
+     num_decoder_layers: 28
+     num_attention_heads: 16
+     num_key_value_heads: 8
+     head_dim: 128
+     max_position_embeddings: 2048
+     max_action_chunk_size: 50
+     rms_norm_eps: 1.0e-6
+     attention_dropout: 0.0
+     use_vae_reparameterization: false
+     use_ema: false  # whether to use EMA; if false, freeze the encoder and train only the VLM and decoder
+     chunk_size: ${datasets.vla_data.chunk_size}
+     loss_mode: full  # full, predict_only
+     qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B
+ datasets:
+   vla_data:
+     dataset_py: lerobot_datasets
+     data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY
+     data_mix: cross_embodiedment_simulator  # bridge_rt_1
+     # action_type: delta_ee
+     CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+     default_image_resolution: [3, 224, 224]
+     per_device_batch_size: 32
+     load_all_data_for_training: true
+     obs: ["image_0"]
+     image_size: [224, 224]
+     video_backend: torchcodec
+     load_video: true
+     chunk_size: 30
+     # align the state chunk with the action chunk (consistent with action model training)
+     state_use_action_chunk: true
+     # number of history state/action steps; when > 0, each sample additionally returns state_history and action_history
+     num_history_steps: 15
+     include_state: ${datasets.vla_data.state_use_action_chunk}
+
+ trainer:
+   epochs: 100
+   max_train_steps: 50000
+   num_warmup_steps: 5000
+   num_stable_steps: 0  # number of steps to hold max_lr (after warmup)
+   mode: freeze_action_encoder_decay_aux_loss  # freeze_action_encoder_decay_aux_loss
+   loss_weights_decay_steps: 5000
+
+   save_interval: 5000
+   eval_interval: 50
+   max_checkpoints_to_keep: 10  # maximum number of checkpoints to keep; the oldest are deleted beyond this limit
+   learning_rate:
+     base: 3e-05
+     qwen_vl_interface: 3e-05
+     action_model: 3e-05
+   lr_scheduler_type: warmup_stable_cosine  # options: warmup_stable_cosine (default), onecycle
+   scheduler_specific_kwargs:
+     min_lr_ratio: 0.001  # final lr = base_lr * min_lr_ratio
+   freeze_modules: ''
+   loss_scale:
+     align_loss: 1.0
+     recon_loss: 1.0
+     predict_loss: 1.0
+   warmup_ratio: 0.1
+   weight_decay: 0.0
+   logging_frequency: 10
+   gradient_clipping: 5.0
+   gradient_accumulation_steps: 1
+
+   optimizer:
+     name: AdamW
+     betas: [0.9, 0.95]
+     eps: 1.0e-08
+     weight_decay: 1.0e-08
+
+   # parameters to be determined
+   is_resume: false
+   resume_epoch: null
+   resume_step: null
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
code/config/training/starvla_train_qwenlatent_oxe.yaml ADDED
@@ -0,0 +1,103 @@
+ run_id: vla_jepa_temp
+ run_root_dir: ./runs
+ seed: 21
+ trackers: [jsonl, wandb]
+ wandb_entity: timsty
+ wandb_project: vla_jepa
+ is_debug: false
+
+ framework:
+   name: QwenLatent
+   qwenvl:
+     base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct
+     attn_implementation: flash_attention_2
+     vl_hidden_dim: 2048
+     num_data_tokens: 32
+   action_model:
+     ckpt_path: /mnt/data/fangyu/code/reward_new/runs/0418_Action_13tasks_actionstate_fixchunk15/final_model/pytorch_model.pt
+     # ckpt_path: null
+     action_size: 37
+     state_size: 74  # matches the action model; 0 means state is not used
+     use_state: ${datasets.vla_data.state_use_action_chunk}
+     hidden_size: 1024
+     intermediate_size: 3072
+     dataset_vocab_size: 256
+     num_data_tokens: 32
+     num_t_samples: 4
+     min_action_len: 5
+     num_encoder_layers: 28
+     num_decoder_layers: 28
+     num_attention_heads: 16
+     num_key_value_heads: 8
+     head_dim: 128
+     max_position_embeddings: 2048
+     max_action_chunk_size: 50
+     rms_norm_eps: 1.0e-6
+     attention_dropout: 0.0
+     use_vae_reparameterization: false
+     use_ema: false  # whether to use EMA; if false, freeze the encoder and train only the VLM and decoder
+     chunk_size: ${datasets.vla_data.chunk_size}
+     loss_mode: full  # full, predict_only
+     qwen3_pretrained_name_or_path: /mnt/data/fangyu/model/Qwen/Qwen3-0.6B
+ datasets:
+   vla_data:
+     dataset_py: lerobot_datasets
+     data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY
+     data_mix: cross_embodiedment_13tasks  # bridge_rt_1
+     # action_type: delta_ee
+     CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+     default_image_resolution: [3, 224, 224]
+     per_device_batch_size: 32
+     load_all_data_for_training: true
+     obs: ["image_0"]
+     image_size: [224, 224]
+     video_backend: torchcodec
+     load_video: true
+     chunk_size: 15
+     # align the state chunk with the action chunk (consistent with action model training)
+     state_use_action_chunk: true
+     # number of history state/action steps; when > 0, each sample additionally returns state_history and action_history
+     num_history_steps: 0
+     include_state: ${datasets.vla_data.state_use_action_chunk}
+
+ trainer:
+   epochs: 100
+   max_train_steps: 50000
+   num_warmup_steps: 5000
+   num_stable_steps: 0  # number of steps to hold max_lr (after warmup)
+   mode: decay_aux_loss  # freeze_action_encoder_decay_aux_loss
+   loss_weights_decay_steps: 5000  # see the decay sketch after this file
+
+   save_interval: 5000
+   eval_interval: 50
+   max_checkpoints_to_keep: 20  # maximum number of checkpoints to keep; the oldest are deleted beyond this limit
+   learning_rate:
+     base: 5e-05
+     qwen_vl_interface: 5e-05
+     action_model: 5e-05
+   lr_scheduler_type: warmup_stable_cosine  # options: warmup_stable_cosine (default), onecycle
+   scheduler_specific_kwargs:
+     min_lr_ratio: 0.001  # final lr = base_lr * min_lr_ratio
+   freeze_modules: ''
+   loss_scale:
+     align_loss: 1.0
+     recon_loss: 1.0
+     predict_loss: 1.0
+   warmup_ratio: 0.1
+   weight_decay: 0.0
+   logging_frequency: 10
+   gradient_clipping: 5.0
+   gradient_accumulation_steps: 1
+
+   optimizer:
+     name: AdamW
+     betas: [0.9, 0.95]
+     eps: 1.0e-08
+     weight_decay: 1.0e-08
+
+   # parameters to be determined
+   is_resume: false
+   resume_epoch: null
+   resume_step: null
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
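
The `mode: decay_aux_loss` / `loss_weights_decay_steps: 5000` pair above suggests the auxiliary loss weights (align_loss, recon_loss) are annealed over the first 5000 steps while predict_loss stays fixed. A minimal sketch of one plausible linear-decay rule (the exact schedule is an assumption, not taken from the repo):

```python
def aux_loss_weight(step: int, decay_steps: int = 5000,
                    start: float = 1.0, end: float = 0.0) -> float:
    """Hypothetical linear decay of an auxiliary loss weight over decay_steps."""
    if step >= decay_steps:
        return end
    return start + (end - start) * (step / decay_steps)

# total = predict_loss + w * (align_loss + recon_loss), with w decaying 1.0 -> 0.0
for s in (0, 2500, 5000, 10000):
    print(s, aux_loss_weight(s))  # 1.0, 0.5, 0.0, 0.0
```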
code/config/training/starvla_train_qwenpi.yaml ADDED
@@ -0,0 +1,97 @@
+ run_id: qwenpi_oxe
+ run_root_dir: ./runs
+ seed: 42
+ trackers: [jsonl, wandb]
+ wandb_entity: timsty
+ wandb_project: vla_jepa
+ is_debug: false
+
+ framework:
+   name: QwenPI
+   qwenvl:
+     base_vlm: /mnt/data/fangyu/model/Qwen/Qwen3-VL-2B-Instruct
+     attn_implementation: flash_attention_2
+     vl_hidden_dim: 2048
+     num_data_tokens: 32  # dataset soft prompt tokens prepended to VLM input (0 = disabled)
+
+   # QwenPI required action head config (LayerwiseFlowmatchingActionHead)
+   action_model:
+     dataset_vocab_size: 256  # number of distinct dataset IDs for soft prompt embedding
+     hidden_size: 1024
+     add_pos_embed: true
+     max_seq_len: 1024
+     action_dim: 37
+     state_dim: 74
+     future_action_window_size: 14
+     action_horizon: 15
+     past_action_window_size: 0
+     noise_beta_alpha: 1.5
+     noise_beta_beta: 1.0
+     noise_s: 0.999
+     num_timestep_buckets: 1000
+     num_inference_timesteps: 10
+     num_target_vision_tokens: 32
+     diffusion_model_cfg:
+       dropout: 0.2
+       final_dropout: true
+       interleave_self_attention: true
+       norm_type: "ada_norm"
+       output_dim: 1024
+       positional_embeddings: null
+
+ datasets:
+   vla_data:
+     dataset_py: lerobot_datasets
+     data_root_dir: /mnt/data/fangyu/dataset/IPEC-COMMUNITY
+     data_mix: cross_embodiedment_13tasks
+     CoT_prompt: "Task: {instruction}. What are the next 15 actions to take?"
+     default_image_resolution: [3, 224, 224]
+     per_device_batch_size: 64
+     load_all_data_for_training: true
+     obs: ["image_0"]
+     image_size: [224, 224]
+     video_backend: torchcodec
+     load_video: true
+     chunk_size: 15
+     state_use_action_chunk: false
+     num_history_steps: 0
+     include_state: true
+
+ trainer:
+   epochs: 100
+   max_train_steps: 50000
+   num_warmup_steps: 5000
+   num_stable_steps: 0
+   save_interval: 5000
+   eval_interval: 50
+   max_checkpoints_to_keep: 20
+
+   # Used in QwenPI.forward() to repeat diffusion training pairs
+   repeated_diffusion_steps: 1
+
+   learning_rate:
+     base: 5e-05
+     qwen_vl_interface: 5e-05
+     action_model: 5e-05
+   lr_scheduler_type: warmup_stable_cosine
+   scheduler_specific_kwargs:
+     min_lr_ratio: 0.001
+
+   freeze_modules: ''
+   warmup_ratio: 0.1
+   weight_decay: 0.0
+   logging_frequency: 10
+   gradient_clipping: 5.0
+   gradient_accumulation_steps: 1
+
+   optimizer:
+     name: AdamW
+     betas: [0.9, 0.95]
+     eps: 1.0e-08
+     weight_decay: 1.0e-08
+
+   is_resume: false
+   resume_epoch: null
+   resume_step: null
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
code/dataloader/__init__.py ADDED
@@ -0,0 +1,70 @@
+ import json
+ from accelerate.logging import get_logger
+ import numpy as np
+ from torch.utils.data import DataLoader
+ import torch.distributed as dist
+ from pathlib import Path
+ from starVLA.dataloader.vlm_datasets import make_vlm_dataloader
+
+ logger = get_logger(__name__)
+
+
+ def _is_main_process() -> bool:
+     return (not dist.is_initialized()) or dist.get_rank() == 0
+
+ def save_dataset_statistics(dataset_statistics, run_dir):
+     """Saves a `dataset_statistics.json` file, converting numpy values to JSON-serializable types."""
+     out_path = run_dir / "dataset_statistics.json"
+     with open(out_path, "w") as f_json:
+         for _, stats in dataset_statistics.items():
+             for k in stats["action"].keys():
+                 if isinstance(stats["action"][k], np.ndarray):
+                     stats["action"][k] = stats["action"][k].tolist()
+             if "proprio" in stats:
+                 for k in stats["proprio"].keys():
+                     if isinstance(stats["proprio"][k], np.ndarray):
+                         stats["proprio"][k] = stats["proprio"][k].tolist()
+             if "num_trajectories" in stats:
+                 if isinstance(stats["num_trajectories"], np.ndarray):
+                     stats["num_trajectories"] = stats["num_trajectories"].item()
+             if "num_transitions" in stats:
+                 if isinstance(stats["num_transitions"], np.ndarray):
+                     stats["num_transitions"] = stats["num_transitions"].item()
+         json.dump(dataset_statistics, f_json, indent=2)
+     logger.info(f"Saved dataset statistics file at path {out_path}")
+
+
+ def build_dataloader(cfg, dataset_py="lerobot_datasets_oxe"):  # TODO: currently this only builds the dataset; move the dataloader logic here
+     if dataset_py == "lerobot_datasets":
+         from starVLA.dataloader.lerobot_datasets import get_vla_dataset, collate_fn
+         vla_dataset_cfg = cfg.datasets.vla_data
+
+         vla_dataset = get_vla_dataset(data_cfg=vla_dataset_cfg)
+
+         vla_train_dataloader = DataLoader(
+             vla_dataset,
+             batch_size=cfg.datasets.vla_data.per_device_batch_size,
+             collate_fn=collate_fn,
+             num_workers=16,
+             prefetch_factor=20,
+             shuffle=True,
+             persistent_workers=True,  # keep workers alive to avoid restart overhead
+             pin_memory=True,  # speed up transfers to the GPU
+             drop_last=True,  # drop the last incomplete batch to avoid stalls
+             timeout=30,  # time out so a blocked worker cannot hang the loader indefinitely
+         )
+         if _is_main_process():
+             output_dir = Path(cfg.output_dir)
+             vla_dataset.save_dataset_statistics(output_dir / "dataset_statistics.json")
+         return vla_train_dataloader
+     if dataset_py == "vlm_datasets":
+         vlm_data_module = make_vlm_dataloader(cfg)
+         vlm_train_dataloader = vlm_data_module["train_dataloader"]
+         return vlm_train_dataloader
+
+     raise ValueError(
+         f"Unsupported dataset builder `{dataset_py}`. "
+         "Expected one of: `lerobot_datasets`, `vlm_datasets`."
+     )
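
A minimal usage sketch for `build_dataloader` above, assuming an OmegaConf config shaped like the training YAMLs in this commit (the config path is illustrative):

```python
from omegaconf import OmegaConf
from starVLA.dataloader import build_dataloader

# Hypothetical path; any of the training YAMLs above has the same shape.
cfg = OmegaConf.load("code/config/training/starvla_train_qwenlatent_oxe.yaml")
cfg.output_dir = "./runs/debug"  # build_dataloader writes dataset_statistics.json here

loader = build_dataloader(cfg, dataset_py=cfg.datasets.vla_data.dataset_py)
batch = next(iter(loader))  # dict of collated tensors for one training step
```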
code/dataloader/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.37 kB).

code/dataloader/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.28 kB).

code/dataloader/__pycache__/lerobot_datasets.cpython-310.pyc ADDED
Binary file (3.74 kB).

code/dataloader/__pycache__/lerobot_datasets.cpython-311.pyc ADDED
Binary file (5.74 kB).

code/dataloader/__pycache__/vlm_datasets.cpython-310.pyc ADDED
Binary file (19.6 kB).

code/dataloader/__pycache__/vlm_datasets.cpython-311.pyc ADDED
Binary file (39.1 kB).
 
code/dataloader/gr00t_lerobot/README.md ADDED
File without changes
code/dataloader/gr00t_lerobot/__init__.py ADDED
File without changes
code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes).

code/dataloader/gr00t_lerobot/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (186 Bytes).

code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-310.pyc ADDED
Binary file (6.78 kB).

code/dataloader/gr00t_lerobot/__pycache__/data_config.cpython-311.pyc ADDED
Binary file (11.4 kB).

code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-310.pyc ADDED
Binary file (59.2 kB).
 
code/dataloader/gr00t_lerobot/__pycache__/datasets.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:953210fc54145a3fac1026888bb1106cdd6351cb8d4eb9e94161669f67db91d9
+ size 105575
code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-310.pyc ADDED
Binary file (5.75 kB).

code/dataloader/gr00t_lerobot/__pycache__/embodiment_tags.cpython-311.pyc ADDED
Binary file (7 kB).

code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-310.pyc ADDED
Binary file (6.47 kB).

code/dataloader/gr00t_lerobot/__pycache__/mixtures.cpython-311.pyc ADDED
Binary file (6.75 kB).

code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-310.pyc ADDED
Binary file (8.09 kB).

code/dataloader/gr00t_lerobot/__pycache__/schema.cpython-311.pyc ADDED
Binary file (12.5 kB).

code/dataloader/gr00t_lerobot/__pycache__/video.cpython-310.pyc ADDED
Binary file (5.26 kB).

code/dataloader/gr00t_lerobot/__pycache__/video.cpython-311.pyc ADDED
Binary file (10.8 kB).
 
code/dataloader/gr00t_lerobot/data_config.py ADDED
@@ -0,0 +1,392 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+
+ from starVLA.dataloader.gr00t_lerobot.datasets import ModalityConfig
+ from starVLA.dataloader.gr00t_lerobot.transform.base import ComposedModalityTransform, ModalityTransform
+ from starVLA.dataloader.gr00t_lerobot.transform.state_action import (
+     StateActionSinCosTransform,
+     StateActionToTensor,
+     StateActionTransform,
+ )
+
+
+ class BaseDataConfig(ABC):
+     @abstractmethod
+     def modality_config(self) -> dict[str, ModalityConfig]:
+         pass
+
+     @abstractmethod
+     def transform(self) -> ModalityTransform:
+         pass
+
+
+ ###########################################################################################
+
+ class Libero4in1DataConfig:
+     video_keys = [
+         "video.primary_image",
+         "video.wrist_image",
+     ]
+
+     state_keys = [
+         "state.x",
+         "state.y",
+         "state.z",
+         "state.roll",
+         "state.pitch",
+         "state.yaw",
+         "state.pad",
+         "state.gripper",
+     ]
+     action_keys = [
+         "action.x",
+         "action.y",
+         "action.z",
+         "action.roll",
+         "action.pitch",
+         "action.yaw",
+         "action.gripper",
+     ]
+
+     language_keys = ["annotation.human.action.task_description"]
+
+     observation_indices = [0]
+     action_indices = list(range(16))
+
+     def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
+         self.chunk_size = chunk_size
+         self.action_indices = list(range(chunk_size))
+         self.state_use_action_chunk = state_use_action_chunk
+         self.num_history_steps = int(num_history_steps or 0)
+         self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1]
+
+     def modality_config(self):
+         video_modality = ModalityConfig(
+             delta_indices=self.video_observation_indices,
+             modality_keys=self.video_keys,
+         )
+         state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices
+         state_modality = ModalityConfig(
+             delta_indices=state_delta,
+             modality_keys=self.state_keys,
+         )
+         action_modality = ModalityConfig(
+             delta_indices=self.action_indices,
+             modality_keys=self.action_keys,
+         )
+         language_modality = ModalityConfig(
+             delta_indices=self.observation_indices,
+             modality_keys=self.language_keys,
+         )
+         modality_configs = {
+             "video": video_modality,
+             "state": state_modality,
+             "action": action_modality,
+             "language": language_modality,
+         }
+         return modality_configs
+
+     def transform(self):
+         transforms = [
+             # state transforms
+             StateActionToTensor(apply_to=self.state_keys),
+             StateActionTransform(
+                 apply_to=self.state_keys,
+                 normalization_modes={
+                     "state.x": "min_max",
+                     "state.y": "min_max",
+                     "state.z": "min_max",
+                     "state.roll": "min_max",
+                     "state.pitch": "min_max",
+                     "state.yaw": "min_max",
+                     "state.pad": "min_max",
+                     # "state.gripper": "binary",
+                 },
+             ),
+             # action transforms
+             StateActionToTensor(apply_to=self.action_keys),
+             StateActionTransform(
+                 apply_to=self.action_keys,
+                 normalization_modes={
+                     "action.x": "min_max",
+                     "action.y": "min_max",
+                     "action.z": "min_max",
+                     "action.roll": "min_max",
+                     "action.pitch": "min_max",
+                     "action.yaw": "min_max",
+                     # "action.gripper": "binary",
+                 },
+             ),
+         ]
+
+         return ComposedModalityTransform(transforms=transforms)
+
+ ###########################################################################################
+
+ class RealWorldFrankaDataConfig:
+     """Real-world Panda robot: 7 joints + 1 gripper (8D), single-arm -> right slot [7:15]."""
+     video_keys = [
+         "video.exterior_image_1_left",
+         "video.wrist_image_left",
+     ]
+     state_keys = [
+         "state.joints",
+         "state.gripper",
+     ]
+     action_keys = [
+         "action.joints",
+         "action.gripper",
+     ]
+     language_keys = ["annotation.human.action.task_description"]
+     observation_indices = [0]
+     action_indices = list(range(16))
+
+     def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
+         self.chunk_size = chunk_size
+         self.action_indices = list(range(chunk_size))
+         self.state_use_action_chunk = state_use_action_chunk
+         self.num_history_steps = int(num_history_steps or 0)
+         self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1]
+
+     def modality_config(self):
+         video_modality = ModalityConfig(
+             delta_indices=self.video_observation_indices,
+             modality_keys=self.video_keys,
+         )
+         state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices
+         state_modality = ModalityConfig(
+             delta_indices=state_delta,
+             modality_keys=self.state_keys,
+         )
+         action_modality = ModalityConfig(
+             delta_indices=self.action_indices,
+             modality_keys=self.action_keys,
+         )
+         language_modality = ModalityConfig(
+             delta_indices=self.observation_indices,
+             modality_keys=self.language_keys,
+         )
+         modality_configs = {
+             "video": video_modality,
+             "state": state_modality,
+             "action": action_modality,
+             "language": language_modality,
+         }
+         return modality_configs
+
+     def transform(self):
+         transforms = [
+             StateActionToTensor(apply_to=self.state_keys),
+             StateActionTransform(
+                 apply_to=self.state_keys,
+                 normalization_modes={
+                     "state.joints": "min_max",
+                     # "state.gripper": "binary",
+                 },
+             ),
+             StateActionToTensor(apply_to=self.action_keys),
+             StateActionTransform(
+                 apply_to=self.action_keys,
+                 normalization_modes={
+                     "action.joints": "min_max",
+                     # "action.gripper": "binary",
+                 },
+             ),
+         ]
+         return ComposedModalityTransform(transforms=transforms)
+
+
+ class AgilexDataConfig:
+     video_keys = [
+         "video.cam_high",
+         "video.cam_left_wrist",
+         "video.cam_right_wrist",
+     ]
+     state_keys = [
+         "state.left_joints",
+         "state.left_gripper",
+         "state.right_joints",
+         "state.right_gripper",
+     ]
+     action_keys = [
+         "action.left_joints",
+         "action.left_gripper",
+         "action.right_joints",
+         "action.right_gripper",
+     ]
+
+     language_keys = ["annotation.human.action.task_description"]
+     observation_indices = [0]
+
+     def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
+         self.chunk_size = chunk_size
+         self.action_indices = list(range(chunk_size))
+         self.state_use_action_chunk = state_use_action_chunk
+         self.num_history_steps = int(num_history_steps or 0)
+         self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1]
+
+     def modality_config(self):
+         video_modality = ModalityConfig(
+             delta_indices=self.video_observation_indices,
+             modality_keys=self.video_keys,
+         )
+         state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices
+         state_modality = ModalityConfig(
+             delta_indices=state_delta,
+             modality_keys=self.state_keys,
+         )
+         action_modality = ModalityConfig(
+             delta_indices=self.action_indices,
+             modality_keys=self.action_keys,
+         )
+         language_modality = ModalityConfig(
+             delta_indices=self.observation_indices,
+             modality_keys=self.language_keys,
+         )
+         modality_configs = {
+             "video": video_modality,
+             "state": state_modality,
+             "action": action_modality,
+             "language": language_modality,
+         }
+         return modality_configs
+
+     def transform(self):
+         transforms = [
+             # state transforms
+             StateActionToTensor(apply_to=self.state_keys),
+             StateActionTransform(
+                 apply_to=self.state_keys,
+                 normalization_modes={
+                     "state.left_joints": "min_max",
+                     "state.left_gripper": "binary",
+                     "state.right_joints": "min_max",
+                     "state.right_gripper": "binary",
+                 },
+             ),
+             # action transforms
+             StateActionToTensor(apply_to=self.action_keys),
+             StateActionTransform(
+                 apply_to=self.action_keys,
+                 normalization_modes={
+                     "action.left_joints": "min_max",
+                     "action.left_gripper": "binary",
+                     "action.right_joints": "min_max",
+                     "action.right_gripper": "binary",
+                 },
+             ),
+         ]
+         return ComposedModalityTransform(transforms=transforms)
+
+
+ class FourierGr1ArmsWaistDataConfig:
+     video_keys = ["video.ego_view"]
+     state_keys = [
+         "state.left_arm",
+         "state.right_arm",
+         "state.left_hand",
+         "state.right_hand",
+         "state.waist",
+     ]
+     action_keys = [
+         "action.left_arm",
+         "action.right_arm",
+         "action.left_hand",
+         "action.right_hand",
+         "action.waist",
+     ]
+     language_keys = ["annotation.human.coarse_action"]
+     observation_indices = [0]
+
+     def __init__(self, chunk_size: int = 16, state_use_action_chunk: bool = False, num_history_steps: int = 0):
+         self.chunk_size = chunk_size
+         self.action_indices = list(range(chunk_size))
+         self.state_use_action_chunk = state_use_action_chunk
+         self.num_history_steps = int(num_history_steps or 0)
+         self.video_observation_indices = [0] if self.num_history_steps == 0 else [0, self.num_history_steps - 1]
+
+     def modality_config(self):
+         video_modality = ModalityConfig(
+             delta_indices=self.video_observation_indices,
+             modality_keys=self.video_keys,
+         )
+         state_delta = self.action_indices if getattr(self, "state_use_action_chunk", False) else self.observation_indices
+         state_modality = ModalityConfig(
+             delta_indices=state_delta,
+             modality_keys=self.state_keys,
+         )
+         action_modality = ModalityConfig(
+             delta_indices=self.action_indices,
+             modality_keys=self.action_keys,
+         )
+         language_modality = ModalityConfig(
+             delta_indices=self.observation_indices,
+             modality_keys=self.language_keys,
+         )
+         modality_configs = {
+             "video": video_modality,
+             "state": state_modality,
+             "action": action_modality,
+             "language": language_modality,
+         }
+         return modality_configs
+
+     def transform(self) -> ModalityTransform:
+         transforms = [
+             # state transforms
+             StateActionToTensor(apply_to=self.state_keys),
+             StateActionSinCosTransform(apply_to=self.state_keys),
+             # action transforms
+             StateActionToTensor(apply_to=self.action_keys),
+             StateActionTransform(
+                 apply_to=self.action_keys,
+                 normalization_modes={key: "min_max" for key in self.action_keys},
+             ),
+         ]
+         return ComposedModalityTransform(transforms=transforms)
+
+ ###########################################################################################
+
+
+ def get_robot_type_config_map(
+     chunk_size: int = 15,
+     state_use_action_chunk: bool = True,
+     num_history_steps: int = 0,
+ ) -> dict[str, BaseDataConfig]:
+     """state_use_action_chunk: when True, state uses action_indices so state has shape (L, state_dim) aligned with the action chunk."""
+     return {
+         "libero_franka": Libero4in1DataConfig(
+             chunk_size=chunk_size,
+             state_use_action_chunk=state_use_action_chunk,
+             num_history_steps=num_history_steps,
+         ),
+         "robotwin": AgilexDataConfig(
+             chunk_size=chunk_size,
+             state_use_action_chunk=state_use_action_chunk,
+             num_history_steps=num_history_steps,
+         ),
+         "fourier_gr1_arms_waist": FourierGr1ArmsWaistDataConfig(
+             chunk_size=chunk_size,
+             state_use_action_chunk=state_use_action_chunk,
+             num_history_steps=num_history_steps,
+         ),
+         "real_world_franka": RealWorldFrankaDataConfig(
+             chunk_size=chunk_size,
+             state_use_action_chunk=state_use_action_chunk,
+             num_history_steps=num_history_steps,
+         ),
+     }
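
A small usage sketch of the config map above (illustrative; assumes the gr00t_lerobot package from this commit is importable):

```python
from starVLA.dataloader.gr00t_lerobot.data_config import get_robot_type_config_map

# Matches the training YAMLs above: chunk_size 15, state aligned with the action chunk.
config_map = get_robot_type_config_map(chunk_size=15, state_use_action_chunk=True)
libero_cfg = config_map["libero_franka"]

modalities = libero_cfg.modality_config()
print(modalities["action"].delta_indices)  # [0, 1, ..., 14] -> a 15-step action chunk
print(modalities["state"].delta_indices)   # same indices, since state follows the action chunk
transform = libero_cfg.transform()          # ComposedModalityTransform with min_max normalization
```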
code/dataloader/gr00t_lerobot/datasets.py ADDED
@@ -0,0 +1,2165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
In this file, we define 3 types of datasets:
1. LeRobotSingleDataset: a single dataset for a given embodiment tag
2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags
3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag,
    with caching for the video frames

See `scripts/load_dataset.py` for examples on how to use these datasets.
"""
import hashlib
import json
import os
import pickle
import random
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image
from pydantic import BaseModel, Field, ValidationError
from torch.utils.data import Dataset
from tqdm import tqdm

from starVLA.dataloader.gr00t_lerobot.video import get_all_frames, get_frames_by_timestamps
from starVLA.dataloader.gr00t_lerobot.embodiment_tags import EmbodimentTag, DATASET_NAME_TO_ID
from starVLA.dataloader.gr00t_lerobot.schema import (
    DatasetMetadata,
    DatasetStatisticalValues,
    LeRobotModalityMetadata,
    LeRobotStateActionMetadata,
)
from starVLA.dataloader.gr00t_lerobot.transform import ComposedModalityTransform

# LeRobot v2.0 dataset file names
LE_ROBOT_MODALITY_FILENAME = "meta/modality.json"
LE_ROBOT_EPISODE_FILENAME = "meta/episodes.jsonl"
LE_ROBOT_TASKS_FILENAME = "meta/tasks.jsonl"
LE_ROBOT_INFO_FILENAME = "meta/info.json"
LE_ROBOT_STATS_FILENAME = "meta/stats_gr00t.json"
LE_ROBOT_DATA_FILENAME = "data/*/*.parquet"
LE_ROBOT_STEPS_FILENAME = "meta/steps.pkl"
EPSILON = 5e-4

# LeRobot v3.0 dataset file names
LE_ROBOT3_TASKS_FILENAME = "meta/tasks.parquet"
LE_ROBOT3_EPISODE_FILENAME = "meta/episodes/*/*.parquet"


# =============================================================================
# Unified Representation Layout & Helpers
# =============================================================================

STANDARD_ACTION_DIM = 37
#
# Unified action representation layout (0-based indices, Python slice is [start, stop)):
# Keep only: libero_franka, gr1, real_world_franka.
#
# - 0:7   -> left_arm (7D): xyz, rpy/euler, gripper
#            Used by: gr1 left_arm
# - 7:14  -> right_arm (7D): same structure
#            Used by: libero_franka; gr1 right_arm
# - 14:20 -> left_hand (6D): gr1 only
# - 20:26 -> right_hand (6D): gr1 only
# - 26:29 -> waist (3D): gr1 only
# - 29:37 -> joints + gripper (8D): real_world_franka only
#
# Mapping:
#   libero_franka (7D)     -> [7:14] (right_arm slot)
#   gr1 (29D)              -> [0:29]
#   real_world_franka (8D) -> [29:37] (joints + gripper)

ACTION_REPRESENTATION_SLICES = {
    # Single-arm (7D) -> right_arm slot [7:14]
    "franka": slice(7, 14),
    # Humanoid (29D) -> full [0:29]
    "gr1": slice(0, 29),
    # Real-world (8D) -> [29:37] (joints + gripper)
    "real_world_franka": slice(29, 37),
}

STANDARD_STATE_DIM = 74
# Mapping:
#   libero_franka (8D)      -> [0:8]
#   real_world_franka (8D)  -> [8:16]
#   gr1 (58D after sin/cos) -> [16:74]

STATE_REPRESENTATION_SLICES = {
    # Single-arm (8D)
    "franka": slice(0, 8),
    # Real-world (8D)
    "real_world_franka": slice(8, 16),
    # GR1 isolated (58D, has StateActionSinCosTransform - different pipeline)
    "gr1": slice(16, 74),
}


def standardize_action_representation(
    action: np.ndarray, embodiment_tag: str
) -> np.ndarray:
    """Map a per-robot action to a fixed-size standard action vector."""
    target_slice = ACTION_REPRESENTATION_SLICES.get(embodiment_tag)

    # Only allow explicitly configured embodiment tags.
    if target_slice is None:
        raise ValueError(
            f"Unknown embodiment tag '{embodiment_tag}' for action mapping. "
            f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES)}"
        )

    expected_dim = target_slice.stop - target_slice.start
    if action.shape[-1] != expected_dim:
        raise ValueError(
            f"Action dim mismatch for tag '{embodiment_tag}': "
            f"{action.shape[-1]=} vs expected {expected_dim}."
        )

    standard = np.zeros(
        (*action.shape[:-1], STANDARD_ACTION_DIM), dtype=action.dtype
    )
    standard[..., target_slice] = action
    return standard


def standardize_state_representation(
    state: np.ndarray, embodiment_tag: str
) -> np.ndarray:
    """Map a per-robot state to a fixed-size standard state vector."""
    target_slice = STATE_REPRESENTATION_SLICES.get(embodiment_tag)

    # Only allow explicitly configured embodiment tags.
    if target_slice is None:
        raise ValueError(
            f"Unknown embodiment tag '{embodiment_tag}' for state mapping. "
            f"Known tags: {sorted(STATE_REPRESENTATION_SLICES)}"
        )

    expected_dim = target_slice.stop - target_slice.start
    if state.shape[-1] != expected_dim:
        raise ValueError(
            f"State dim mismatch for tag '{embodiment_tag}': "
            f"{state.shape[-1]=} vs expected {expected_dim}."
        )

    standard = np.zeros(
        (*state.shape[:-1], STANDARD_STATE_DIM), dtype=state.dtype
    )
    standard[..., target_slice] = state
    return standard
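

# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): a minimal check of the
# unified layout above. A 7D "franka" action should land in the right_arm
# slot [7:14] of the 37D standard vector and leave every other channel zero.
# The `_demo_` function name is hypothetical.
# ---------------------------------------------------------------------------
def _demo_unified_layout():
    action = np.arange(1, 8, dtype=np.float32)  # xyz, rpy, gripper
    std = standardize_action_representation(action, "franka")
    assert std.shape == (STANDARD_ACTION_DIM,)
    assert np.allclose(std[7:14], action)  # right_arm slot
    assert std[:7].sum() == 0 and std[14:].sum() == 0  # all other slots stay zero

    state = np.zeros(8, dtype=np.float32)  # 8D franka state -> slot [0:8]
    assert standardize_state_representation(state, "franka").shape == (STANDARD_STATE_DIM,)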


def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict:
    """Calculate the dataset statistics of all columns for a list of parquet files."""
    # Collect all the low-dimensional data
    all_low_dim_data_list = []
    for parquet_path in tqdm(
        sorted(parquet_paths),
        desc="Collecting all parquet files...",
    ):
        # Load the parquet file
        parquet_data = pd.read_parquet(parquet_path)
        all_low_dim_data_list.append(parquet_data)

    all_low_dim_data = pd.concat(all_low_dim_data_list, axis=0)
    # Compute dataset statistics
    dataset_statistics = {}
    for le_modality in all_low_dim_data.columns:
        if le_modality.startswith("annotation."):
            continue
        print(f"Computing statistics for {le_modality}...")
        np_data = np.vstack(
            [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]]
        )
        dataset_statistics[le_modality] = {
            "mean": np.mean(np_data, axis=0).tolist(),
            "std": np.std(np_data, axis=0).tolist(),
            "min": np.min(np_data, axis=0).tolist(),
            "max": np.max(np_data, axis=0).tolist(),
            "q01": np.quantile(np_data, 0.01, axis=0).tolist(),
            "q99": np.quantile(np_data, 0.99, axis=0).tolist(),
        }
    return dataset_statistics
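

# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): the per-column layout
# produced by calculate_dataset_statistics, reproduced on a tiny in-memory
# column so the output shape is concrete. q01/q99 are the percentiles used
# downstream for quantile normalization.
# ---------------------------------------------------------------------------
def _demo_statistics_layout():
    column = [np.array([0.0, 1.0]), np.array([2.0, 3.0]), np.array([4.0, 5.0])]
    np_data = np.vstack([np.asarray(x, dtype=np.float32) for x in column])
    stats = {
        "mean": np.mean(np_data, axis=0).tolist(),  # [2.0, 3.0]
        "q01": np.quantile(np_data, 0.01, axis=0).tolist(),
        "q99": np.quantile(np_data, 0.99, axis=0).tolist(),
    }
    assert len(stats["mean"]) == np_data.shape[1]  # one value per dimension
    return stats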


class ModalityConfig(BaseModel):
    """Configuration for a modality."""

    delta_indices: list[int]
    """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices."""
    modality_keys: list[str]
    """The keys to load for the modality in the dataset."""
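

# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): typical ModalityConfig
# usage. With base index t, delta_indices=[-1, 0] loads steps t-1 and t, and
# [0, 1, 2, 3] loads a 4-step action chunk. The keys are illustrative only.
# ---------------------------------------------------------------------------
_demo_modality_configs = {
    "video": ModalityConfig(delta_indices=[0], modality_keys=["video.image_side_0"]),
    "state": ModalityConfig(delta_indices=[0], modality_keys=["state.eef_position"]),
    "action": ModalityConfig(delta_indices=[0, 1, 2, 3], modality_keys=["action.eef_position"]),
    "language": ModalityConfig(delta_indices=[0], modality_keys=["annotation.human.task"]),
}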


class LeRobotSingleDataset(Dataset):
    """
    Base dataset class for LeRobot that supports sharding.
    """

    def __init__(
        self,
        dataset_path: Path | str,
        modality_configs: dict[str, ModalityConfig],
        embodiment_tag: str | EmbodimentTag,
        video_backend: str = "decord",
        video_backend_kwargs: dict | None = None,
        transforms: ComposedModalityTransform | None = None,
        delete_pause_frame: bool = False,
        **kwargs,
    ):
        """
        Initialize the dataset.

        Args:
            dataset_path (Path | str): The path to the dataset.
            modality_configs (dict[str, ModalityConfig]): The configuration for each modality. The keys are the modality names, and the values are the modality configurations.
                See `ModalityConfig` for more details.
            embodiment_tag (str | EmbodimentTag): Overload the embodiment tag for the dataset, e.g. define it as "new_embodiment".
            video_backend (str): Backend for video reading.
            video_backend_kwargs (dict): Keyword arguments for the video backend when initializing the video reader.
            transforms (ComposedModalityTransform): The transforms to apply to the dataset.
            delete_pause_frame (bool): Whether to delete pause frames from the dataset.
        """
        # First check that the path directory exists
        if not Path(dataset_path).exists():
            raise FileNotFoundError(f"Dataset path {dataset_path} does not exist")
        data_cfg = kwargs.get("data_cfg", {}) or {}
        # Indicates which LeRobot dataset layout is used ("v2.0" or "v3.0")
        self._lerobot_version = data_cfg.get("lerobot_version", "v2.0")
        self.load_video = data_cfg.get("load_video", True)
        self.num_history_steps = int(data_cfg.get("num_history_steps", 0) or 0)

        self.delete_pause_frame = delete_pause_frame

        # If video loading is disabled, skip the video modality end-to-end.
        if self.load_video:
            self.modality_configs = modality_configs
        else:
            self.modality_configs = {
                modality: config
                for modality, config in modality_configs.items()
                if modality != "video"
            }
        self.video_backend = video_backend
        self.video_backend_kwargs = video_backend_kwargs if video_backend_kwargs is not None else {}
        self.transforms = (
            transforms if transforms is not None else ComposedModalityTransform(transforms=[])
        )

        self._dataset_path = Path(dataset_path)
        self._dataset_name = self._dataset_path.name
        self._dataset_id = DATASET_NAME_TO_ID.get(self._dataset_name)
        if isinstance(embodiment_tag, EmbodimentTag):
            self.tag = embodiment_tag.value
        else:
            self.tag = embodiment_tag

        self._metadata = self._get_metadata(EmbodimentTag(self.tag))

        # LeRobot-specific config
        self._lerobot_modality_meta = self._get_lerobot_modality_meta()
        self._lerobot_info_meta = self._get_lerobot_info_meta()
        self._data_path_pattern = self._get_data_path_pattern()
        self._video_path_pattern = self._get_video_path_pattern()
        self._chunk_size = self._get_chunk_size()
        self._tasks = self._get_tasks()
        self.curr_traj_data = None
        self.curr_traj_id = None

        self._trajectory_ids, self._trajectory_lengths = self._get_trajectories()
        self._modality_keys = self._get_modality_keys()
        self._delta_indices = self._get_delta_indices()
        self._all_steps = self._get_all_steps()
        self.set_transforms_metadata(self.metadata)
        self.set_epoch(0)

        print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}")

        # Check if the dataset is valid
        self._check_integrity()

    @property
    def dataset_path(self) -> Path:
        """The path to the dataset that contains the METADATA_FILENAME file."""
        return self._dataset_path

    @property
    def metadata(self) -> DatasetMetadata:
        """The metadata for the dataset, loaded from metadata.json in the dataset directory."""
        return self._metadata

    @property
    def trajectory_ids(self) -> np.ndarray:
        """The trajectory IDs in the dataset, stored as a 1D numpy array of strings."""
        return self._trajectory_ids

    @property
    def trajectory_lengths(self) -> np.ndarray:
        """The trajectory lengths in the dataset, stored as a 1D numpy array of integers.
        The order of the lengths is the same as the order of the trajectory IDs.
        """
        return self._trajectory_lengths

    @property
    def all_steps(self) -> list[tuple[int, int]]:
        """The trajectory IDs and base indices for all steps in the dataset.
        Example:
            self.trajectory_ids: [0, 1, 2]
            self.trajectory_lengths: [3, 2, 4]
            return: [
                ("traj_0", 0), ("traj_0", 1), ("traj_0", 2),
                ("traj_1", 0), ("traj_1", 1),
                ("traj_2", 0), ("traj_2", 1), ("traj_2", 2), ("traj_2", 3)
            ]
        """
        return self._all_steps

    @property
    def modality_keys(self) -> dict:
        """The modality keys for the dataset. The keys are the modality names, and the values are the keys for each modality.

        Example: {
            "video": ["video.image_side_0", "video.image_side_1"],
            "state": ["state.eef_position", "state.eef_rotation"],
            "action": ["action.eef_position", "action.eef_rotation"],
            "language": ["language.human.task"],
            "timestamp": ["timestamp"],
            "reward": ["reward"],
        }
        """
        return self._modality_keys

    @property
    def delta_indices(self) -> dict[str, np.ndarray]:
        """The delta indices for the dataset. The keys are the modality.key, and the values are the delta indices for each modality.key."""
        return self._delta_indices

    @property
    def dataset_name(self) -> str:
        """The name of the dataset."""
        return self._dataset_name

    @property
    def lerobot_modality_meta(self) -> LeRobotModalityMetadata:
        """The modality metadata for the LeRobot dataset."""
        return self._lerobot_modality_meta

    @property
    def lerobot_info_meta(self) -> dict:
        """The info metadata for the LeRobot dataset."""
        return self._lerobot_info_meta

    @property
    def data_path_pattern(self) -> str:
        """The data path pattern for the LeRobot dataset."""
        return self._data_path_pattern

    @property
    def video_path_pattern(self) -> str:
        """The video path pattern for the LeRobot dataset."""
        return self._video_path_pattern

    @property
    def chunk_size(self) -> int:
        """The chunk size for the LeRobot dataset."""
        return self._chunk_size

    @property
    def tasks(self) -> pd.DataFrame:
        """The tasks for the dataset."""
        return self._tasks

    def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata:
        """Get the metadata for the dataset.

        Returns:
            DatasetMetadata: The metadata for the dataset.
        """

        # 1. Modality metadata
        modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
        assert (
            modality_meta_path.exists()
        ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
        # 1.1. State and action modalities
        simplified_modality_meta: dict[str, dict] = {}
        with open(modality_meta_path, "r") as f:
            le_modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))
        for modality in ["state", "action"]:
            simplified_modality_meta[modality] = {}
            le_state_action_meta: dict[str, LeRobotStateActionMetadata] = getattr(
                le_modality_meta, modality
            )
            for subkey in le_state_action_meta:
                state_action_dtype = np.dtype(le_state_action_meta[subkey].dtype)
                if np.issubdtype(state_action_dtype, np.floating):
                    continuous = True
                else:
                    continuous = False
                simplified_modality_meta[modality][subkey] = {
                    "absolute": le_state_action_meta[subkey].absolute,
                    "rotation_type": le_state_action_meta[subkey].rotation_type,
                    "shape": [
                        le_state_action_meta[subkey].end - le_state_action_meta[subkey].start
                    ],
                    "continuous": continuous,
                }

        # 1.2. Video modalities
        le_info_path = self.dataset_path / LE_ROBOT_INFO_FILENAME
        assert (
            le_info_path.exists()
        ), f"Please provide a {LE_ROBOT_INFO_FILENAME} file in {self.dataset_path}"
        with open(le_info_path, "r") as f:
            le_info = json.load(f)
        simplified_modality_meta["video"] = {}
        for new_key in le_modality_meta.video:
            original_key = le_modality_meta.video[new_key].original_key
            if original_key is None:
                original_key = new_key
            le_video_meta = le_info["features"][original_key]
            height = le_video_meta["shape"][le_video_meta["names"].index("height")]
            width = le_video_meta["shape"][le_video_meta["names"].index("width")]
            # NOTE(FH): different lerobot dataset versions have different keys for the number of channels and fps
            try:
                channels = le_video_meta["shape"][le_video_meta["names"].index("channel")]
                fps = le_video_meta["video_info"]["video.fps"]
            except (ValueError, KeyError):
                channels = le_video_meta["info"]["video.channels"]
                fps = le_video_meta["info"]["video.fps"]
            simplified_modality_meta["video"][new_key] = {
                "resolution": [width, height],
                "channels": channels,
                "fps": fps,
            }

        # 2. Dataset statistics
        stats_path = self.dataset_path / LE_ROBOT_STATS_FILENAME
        try:
            with open(stats_path, "r") as f:
                le_statistics = json.load(f)
            for stat in le_statistics.values():
                DatasetStatisticalValues.model_validate(stat)
        except (FileNotFoundError, ValidationError) as e:
            print(f"Failed to load dataset statistics: {e}")
            print(f"Calculating dataset statistics for {self.dataset_name}")
            # Get all parquet files in the dataset paths
            parquet_files = list((self.dataset_path).glob(LE_ROBOT_DATA_FILENAME))
            parquet_files_filtered = []
            # "episode_033675.parquet" is a known broken file; skip it.
            for pf in parquet_files:
                if "episode_033675.parquet" in pf.name:
                    continue
                parquet_files_filtered.append(pf)

            le_statistics = calculate_dataset_statistics(parquet_files_filtered)
            with open(stats_path, "w") as f:
                json.dump(le_statistics, f, indent=4)
        dataset_statistics = {}
        for our_modality in ["state", "action"]:
            dataset_statistics[our_modality] = {}
            for subkey in simplified_modality_meta[our_modality]:
                dataset_statistics[our_modality][subkey] = {}
                state_action_meta = le_modality_meta.get_key_meta(f"{our_modality}.{subkey}")
                assert isinstance(state_action_meta, LeRobotStateActionMetadata)
                le_modality = state_action_meta.original_key
                for stat_name in le_statistics[le_modality]:
                    indices = np.arange(
                        state_action_meta.start,
                        state_action_meta.end,
                    )
                    stat = np.array(le_statistics[le_modality][stat_name])
                    dataset_statistics[our_modality][subkey][stat_name] = stat[indices].tolist()

        # 3. Full dataset metadata
        metadata = DatasetMetadata(
            statistics=dataset_statistics,  # type: ignore
            modalities=simplified_modality_meta,  # type: ignore
            embodiment_tag=embodiment_tag,
        )

        return metadata

    def _get_trajectories(self) -> tuple[np.ndarray, np.ndarray]:
        """Get the trajectories in the dataset."""
        # Get trajectory lengths, IDs, and whitelist from dataset metadata
        # v2.0
        if self._lerobot_version == "v2.0":
            file_path = self.dataset_path / LE_ROBOT_EPISODE_FILENAME
            with open(file_path, "r") as f:
                episode_metadata = [json.loads(line) for line in f]
            trajectory_ids = []
            trajectory_lengths = []
            for episode in episode_metadata:
                trajectory_ids.append(episode["episode_index"])
                trajectory_lengths.append(episode["length"])
            return np.array(trajectory_ids), np.array(trajectory_lengths)
        # v3.0
        elif self._lerobot_version == "v3.0":
            file_paths = list((self.dataset_path).glob(LE_ROBOT3_EPISODE_FILENAME))
            trajectory_ids = []
            trajectory_lengths = []
            self.trajectory_ids_to_metadata = {}
            for file_path in file_paths:
                episodes_data = pd.read_parquet(file_path)
                for index, episode in episodes_data.iterrows():
                    trajectory_ids.append(episode["episode_index"])
                    trajectory_lengths.append(episode["length"])

                    # TODO: auto-map keys? For now, record the file location directly.
                    episode_meta = {
                        "data/chunk_index": episode["data/chunk_index"],
                        "data/file_index": episode["data/file_index"],
                        "data/file_from_index": index,
                    }
                    if self.load_video:
                        episode_meta["videos/observation.images.wrist/from_timestamp"] = episode[
                            "videos/observation.images.wrist/from_timestamp"
                        ]
                    self.trajectory_ids_to_metadata[trajectory_ids[-1]] = episode_meta

            # The save-index information can be read directly from the episode metadata here.
            return np.array(trajectory_ids), np.array(trajectory_lengths)
        else:
            raise ValueError(f"Unsupported lerobot_version: {self._lerobot_version}")

    def _get_all_steps(self) -> list[tuple[int, int]]:
        """Get the trajectory IDs and base indices for all steps in the dataset.

        Returns:
            list[tuple[int, int]]: A list of (trajectory_id, base_index) tuples.
        """
        # Create a hash key based on the configuration to ensure cache validity
        config_key = self._get_steps_config_key()

        # BUG workaround (@fangjing): use a static filename instead of the
        # config hash (f"steps_{config_key}.pkl") so the cached step index is
        # not recomputed for every sampling configuration.
        steps_filename = "steps_data_index.pkl"

        steps_path = self.dataset_path / "meta" / steps_filename

        # Try to load cached steps first
        try:
            if steps_path.exists():
                with open(steps_path, "rb") as f:
                    cached_data = pickle.load(f)
                return cached_data["steps"]

        except (FileNotFoundError, pickle.PickleError, KeyError) as e:
            print(f"Failed to load cached steps: {e}")
            print("Computing steps from scratch...")

        # Compute the steps using a single process
        all_steps = self._get_all_steps_single_process()

        # Cache the computed steps
        try:
            cache_data = {
                "config_key": config_key,
                "steps": all_steps,
                "num_trajectories": len(self.trajectory_ids),
                "total_steps": len(all_steps),
                "computed_timestamp": pd.Timestamp.now().isoformat(),
                "delete_pause_frame": self.delete_pause_frame,
            }

            # Ensure the meta directory exists
            steps_path.parent.mkdir(parents=True, exist_ok=True)

            with open(steps_path, "wb") as f:
                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
            print(f"Cached steps saved to {steps_path}")
        except Exception as e:
            print(f"Failed to cache steps: {e}")

        return all_steps

    def _get_steps_config_key(self) -> str:
        """Generate a configuration key for steps caching."""
        config_dict = {
            "delete_pause_frame": self.delete_pause_frame,
            "dataset_name": self.dataset_name,
        }
        # Create a hash of the configuration
        config_str = str(sorted(config_dict.items()))
        return hashlib.md5(config_str.encode()).hexdigest()[:12]

    def _get_all_steps_single_process(self) -> list[tuple[int, int]]:
        """Original single-process implementation, used as a fallback."""
        all_steps: list[tuple[int, int]] = []
        skipped_trajectories = 0
        processed_trajectories = 0

        # Check if the language modality is configured
        has_language_modality = 'language' in self.modality_keys and len(self.modality_keys['language']) > 0
        # TODO: why use trajectory_length here instead of the actual data length?
        for trajectory_id, trajectory_length in tqdm(
            zip(self.trajectory_ids, self.trajectory_lengths),
            total=len(self.trajectory_ids),
            desc="Getting all steps",
        ):
            try:
                if self._lerobot_version == "v2.0":
                    data = self.get_trajectory_data(trajectory_id)
                elif self._lerobot_version == "v3.0":
                    data = self.get_trajectory_data_lerobot_v3(trajectory_id)

                # Check that the trajectory has a valid language instruction (if the language modality is configured)
                if has_language_modality:
                    self.curr_traj_data = data  # set the current trajectory data so get_language works

                    language_instruction = self.get_language(trajectory_id, self.modality_keys['language'][0], 0)
                    if not language_instruction or language_instruction[0] == "":
                        print(f"Skipping trajectory {trajectory_id} due to empty language instruction")
                        skipped_trajectories += 1
                        continue

            except Exception as e:
                print(f"Skipping trajectory {trajectory_id} due to read error: {e}")
                skipped_trajectories += 1
                continue

            processed_trajectories += 1

            for base_index in range(trajectory_length):
                all_steps.append((trajectory_id, base_index))

        # Print summary statistics
        print(f"Single-process summary: processed {processed_trajectories} trajectories, skipped {skipped_trajectories} trajectories")
        print(f"Total steps: {len(all_steps)} from {len(self.trajectory_ids)} trajectories")

        return all_steps

    def _get_position_and_gripper_values(self, data: pd.DataFrame) -> tuple[list, list]:
        """Get position and gripper values based on the available columns in the dataset."""
        # Get action keys from modality_keys
        action_keys = self.modality_keys.get('action', [])

        # Extract position data
        delta_position_values = None
        position_candidates = ['delta_eef_position']
        coordinate_candidates = ['x', 'y', 'z']

        # First try combined position fields
        for pos_key in position_candidates:
            full_key = f"action.{pos_key}"
            if full_key in action_keys:
                try:
                    # Get the lerobot key for this modality
                    le_action_cfg = self.lerobot_modality_meta.action
                    subkey = pos_key
                    if subkey in le_action_cfg:
                        le_key = le_action_cfg[subkey].original_key or subkey
                        if le_key in data.columns:
                            data_array = np.stack(data[le_key])
                            le_indices = np.arange(le_action_cfg[subkey].start, le_action_cfg[subkey].end)
                            filtered_data = data_array[:, le_indices]
                            delta_position_values = filtered_data.tolist()
                            break
                except Exception:
                    continue

        # If the combined fields are not found, try individual x, y, z coordinates
        if delta_position_values is None:
            x_data, y_data, z_data = None, None, None
            for coord in coordinate_candidates:
                full_key = f"action.{coord}"
                if full_key in action_keys:
                    try:
                        le_action_cfg = self.lerobot_modality_meta.action
                        if coord in le_action_cfg:
                            le_key = le_action_cfg[coord].original_key or coord
                            if le_key in data.columns:
                                data_array = np.stack(data[le_key])
                                le_indices = np.arange(le_action_cfg[coord].start, le_action_cfg[coord].end)
                                coord_data = data_array[:, le_indices].flatten()
                                if coord == 'x':
                                    x_data = coord_data
                                elif coord == 'y':
                                    y_data = coord_data
                                elif coord == 'z':
                                    z_data = coord_data
                    except Exception:
                        continue

            if x_data is not None and y_data is not None and z_data is not None:
                delta_position_values = np.column_stack((x_data, y_data, z_data)).tolist()

        if delta_position_values is None:
            # Fall back to the old hardcoded approach if the metadata approach fails
            if 'action.delta_eef_position' in data.columns:
                delta_position_values = data['action.delta_eef_position'].to_numpy().tolist()
            elif all(col in data.columns for col in ['action.x', 'action.y', 'action.z']):
                x_vals = data['action.x'].to_numpy()
                y_vals = data['action.y'].to_numpy()
                z_vals = data['action.z'].to_numpy()
                delta_position_values = np.column_stack((x_vals, y_vals, z_vals)).tolist()
            else:
                raise ValueError(f"No suitable position columns found. Available columns: {data.columns.tolist()}")

        # Extract gripper data
        gripper_values = None
        gripper_candidates = ['gripper_close', 'gripper']

        for grip_key in gripper_candidates:
            full_key = f"action.{grip_key}"
            if full_key in action_keys:
                try:
                    le_action_cfg = self.lerobot_modality_meta.action
                    if grip_key in le_action_cfg:
                        le_key = le_action_cfg[grip_key].original_key or grip_key
                        if le_key in data.columns:
                            data_array = np.stack(data[le_key])
                            le_indices = np.arange(le_action_cfg[grip_key].start, le_action_cfg[grip_key].end)
                            gripper_data = data_array[:, le_indices].flatten()
                            gripper_values = gripper_data.tolist()
                            break
                except Exception:
                    continue

        if gripper_values is None:
            # Fall back to the old hardcoded approach if the metadata approach fails
            if 'action.gripper_close' in data.columns:
                gripper_values = data['action.gripper_close'].to_numpy().tolist()
            elif 'action.gripper' in data.columns:
                gripper_values = data['action.gripper'].to_numpy().tolist()
            else:
                raise ValueError(f"No suitable gripper columns found. Available columns: {data.columns.tolist()}")

        return delta_position_values, gripper_values

    def _get_modality_keys(self) -> dict:
        """Get the modality keys for the dataset.
        The keys are the modality names, and the values are the keys for each modality.
        See property `modality_keys` for the expected format.
        """
        modality_keys = defaultdict(list)
        for modality, config in self.modality_configs.items():
            modality_keys[modality] = config.modality_keys
        return modality_keys

    def _get_delta_indices(self) -> dict[str, np.ndarray]:
        """Restructure the delta indices to use modality.key as keys instead of just the modalities."""
        delta_indices: dict[str, np.ndarray] = {}
        for config in self.modality_configs.values():
            for key in config.modality_keys:
                delta_indices[key] = np.array(config.delta_indices)
        return delta_indices

    def _get_lerobot_modality_meta(self) -> LeRobotModalityMetadata:
        """Get the modality metadata for the LeRobot dataset."""
        modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
        assert (
            modality_meta_path.exists()
        ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
        with open(modality_meta_path, "r") as f:
            modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))
        return modality_meta

    def _get_lerobot_info_meta(self) -> dict:
        """Get the info metadata for the LeRobot dataset."""
        info_meta_path = self.dataset_path / LE_ROBOT_INFO_FILENAME
        with open(info_meta_path, "r") as f:
            info_meta = json.load(f)
        return info_meta

    def _get_data_path_pattern(self) -> str:
        """Get the data path pattern for the LeRobot dataset."""
        return self.lerobot_info_meta["data_path"]

    def _get_video_path_pattern(self) -> str:
        """Get the video path pattern for the LeRobot dataset."""
        return self.lerobot_info_meta["video_path"]

    def _get_chunk_size(self) -> int:
        """Get the chunk size for the LeRobot dataset."""
        return self.lerobot_info_meta["chunks_size"]

    def _get_tasks(self) -> pd.DataFrame:
        """Get the tasks for the dataset."""
        if self._lerobot_version == "v2.0":
            tasks_path = self.dataset_path / LE_ROBOT_TASKS_FILENAME
            with open(tasks_path, "r") as f:
                tasks = [json.loads(line) for line in f]
            df = pd.DataFrame(tasks)
            return df.set_index("task_index")

        elif self._lerobot_version == "v3.0":
            tasks_path = self.dataset_path / LE_ROBOT3_TASKS_FILENAME
            df = pd.read_parquet(tasks_path)
            df = df.reset_index()  # turn the index into a column (usually named 'index')
            df = df.rename(columns={'index': 'task'})  # rename the 'index' column to 'task'
            df = df[['task_index', 'task']]  # reorder the columns
            return df

    def _check_integrity(self):
        """Use the config to check if the keys are valid and detect silent data corruption."""
        ERROR_MSG_HEADER = f"Error occurred in initializing dataset {self.dataset_name}:\n"

        for modality_config in self.modality_configs.values():
            for key in modality_config.modality_keys:
                if key == "lapa_action" or key == "dream_actions":
                    continue  # no metadata needed for lapa actions because they come pre-normalized
                # Check if the key is valid
                try:
                    self.lerobot_modality_meta.get_key_meta(key)
                except Exception as e:
                    raise ValueError(
                        ERROR_MSG_HEADER + f"Unable to find key {key} in modality metadata:\n{e}"
                    )

    def set_transforms_metadata(self, metadata: DatasetMetadata):
        """Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values."""
        self.transforms.set_metadata(metadata)

    def set_epoch(self, epoch: int):
        """Set the epoch for the dataset.

        Args:
            epoch (int): The epoch to set.
        """
        self.epoch = epoch

    def __len__(self) -> int:
        """Get the total number of data points in the dataset.

        Returns:
            int: the total number of data points in the dataset.
        """
        return len(self.all_steps)

    def __str__(self) -> str:
        """Get the description of the dataset."""
        return f"{self.dataset_name} ({len(self)} steps)"

    def __getitem__(self, index: int) -> dict:
        """Get the data for a single step in a trajectory.

        Args:
            index (int): The index of the step to get.

        Returns:
            dict: The data for the step, with keys "action" (np.ndarray), "state" (np.ndarray),
                "image" (list[PIL.Image]), "language" (str), "dataset_id",
                and optionally "mid_image" when num_history_steps != 0.
        """
        trajectory_id, base_index = self.all_steps[index]
        data = self.get_step_data(trajectory_id, base_index)

        # Process all video keys dynamically
        images = []
        mid_images = []
        for video_key in self.modality_keys.get("video", []):
            video_frames = data[video_key]
            image = video_frames[0]
            image = Image.fromarray(image).resize((224, 224))
            images.append(image)
            if self.num_history_steps != 0:
                history_index = min(self.num_history_steps - 1, len(video_frames) - 1)
                mid_image = video_frames[history_index]
                mid_image = Image.fromarray(mid_image).resize((224, 224))
                mid_images.append(mid_image)

        # Get the language and action data
        language = data[self.modality_keys["language"][0]][0]
        action = []
        for action_key in self.modality_keys["action"]:
            action.append(data[action_key])
        action = np.concatenate(action, axis=1)
        action = standardize_action_representation(action, self.tag)

        state = []
        for state_key in self.modality_keys["state"]:
            state.append(data[state_key])
        state = np.concatenate(state, axis=1)
        state = standardize_state_representation(state, self.tag)

        sample = dict(action=action, state=state, image=images, language=language, dataset_id=self._dataset_id)
        if self.num_history_steps != 0:
            sample["mid_image"] = mid_images
        return sample
911
+
912
+ def get_step_data(self, trajectory_id: int, base_index: int) -> dict:
913
+ """Get the RAW data for a single step in a trajectory. No transforms are applied.
914
+
915
+ Args:
916
+ trajectory_id (int): The name of the trajectory.
917
+ base_index (int): The base step index in the trajectory.
918
+
919
+ Returns:
920
+ dict: The RAW data for the step.
921
+
922
+ Example return:
923
+ {
924
+ "video": {
925
+ "video.image_side_0": [B, T, H, W, C],
926
+ "video.image_side_1": [B, T, H, W, C],
927
+ },
928
+ "state": {
929
+ "state.eef_position": [B, T, state_dim],
930
+ "state.eef_rotation": [B, T, state_dim],
931
+ },
932
+ "action": {
933
+ "action.eef_position": [B, T, action_dim],
934
+ "action.eef_rotation": [B, T, action_dim],
935
+ },
936
+ }
937
+ """
938
+ data = {}
939
+ # Get the data for all modalities # just for action base data
940
+ self.curr_traj_data = self.get_trajectory_data(trajectory_id)
941
+ # TODO @JinhuiYE The logic below is poorly implemented. Data reading should be directly based on curr_traj_data.
942
+ for modality in self.modality_keys:
943
+ # Get the data corresponding to each key in the modality
944
+ for key in self.modality_keys[modality]:
945
+ data[key] = self.get_data_by_modality(trajectory_id, modality, key, base_index)
946
+ return data
947
+
948
+ def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame:
949
+ """Get the data for a trajectory."""
950
+ if self._lerobot_version == "v2.0":
951
+
952
+ if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
953
+ return self.curr_traj_data
954
+ else:
955
+ chunk_index = self.get_episode_chunk(trajectory_id)
956
+ parquet_path = self.dataset_path / self.data_path_pattern.format(
957
+ episode_chunk=chunk_index, episode_index=trajectory_id
958
+ )
959
+ assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
960
+ return pd.read_parquet(parquet_path)
961
+ elif self._lerobot_version == "v3.0":
962
+ return self.get_trajectory_data_lerobot_v3(trajectory_id)
963
+
964
+ def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame:
965
+ """Get the data for a trajectory from lerobot v3."""
966
+ if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
967
+ return self.curr_traj_data
968
+ else: #TODO check detail later
969
+ chunk_index = self.get_episode_chunk(trajectory_id)
970
+
971
+ file_index = self.get_episode_file_index(trajectory_id)
972
+ # file_from_index = self.get_episode_file_from_index(trajectory_id)
973
+
974
+
975
+ parquet_path = self.dataset_path / self.data_path_pattern.format(
976
+ chunk_index=chunk_index, file_index=file_index
977
+ )
978
+ assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
979
+ file_data = pd.read_parquet(parquet_path)
980
+
981
+ # filter by trajectory_id
982
+ episode_data = file_data.loc[file_data["episode_index"] == trajectory_id].copy()
983
+
984
+ # fix timestamp from epis index to file index for video alignment
985
+ if self.load_video:
986
+ from_timestamp = self.trajectory_ids_to_metadata[trajectory_id].get(
987
+ "videos/observation.images.wrist/from_timestamp", 0
988
+ )
989
+ episode_data["timestamp"] = episode_data["timestamp"] + from_timestamp
990
+
991
+ return episode_data
992
+
993
+
994
+ def get_trajectory_index(self, trajectory_id: int) -> int:
995
+ """Get the index of the trajectory in the dataset by the trajectory ID.
996
+ This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID.
997
+
998
+ Args:
999
+ trajectory_id (str): The ID of the trajectory.
1000
+
1001
+ Returns:
1002
+ int: The index of the trajectory in the dataset.
1003
+ """
1004
+ trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0]
1005
+ if len(trajectory_indices) != 1:
1006
+ raise ValueError(
1007
+ f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}"
1008
+ )
1009
+ return trajectory_indices[0]
1010
+
1011
+ def get_episode_chunk(self, ep_index: int) -> int:
1012
+ """Get the chunk index for an episode index."""
1013
+ return ep_index // self.chunk_size
1014
+ def get_episode_file_index(self, ep_index: int) -> int:
1015
+ """Get the file index for an episode index."""
1016
+ episode_meta = self.trajectory_ids_to_metadata[ep_index]
1017
+ return episode_meta["data/file_index"]
1018
+
1019
+ def get_episode_file_from_index(self, ep_index: int) -> int:
1020
+ """Get the file from index for an episode index."""
1021
+ episode_meta = self.trajectory_ids_to_metadata[ep_index]
1022
+ return episode_meta["data/file_from_index"]
1023
+
1024
+
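
    # -----------------------------------------------------------------------
    # Editorial sketch (not part of the original class): with the typical
    # LeRobot v2.0 pattern
    # "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
    # and chunk_size=1000, episode 1234 resolves to chunk 1234 // 1000 = 1 and
    # the path data/chunk-001/episode_001234.parquet.
    # -----------------------------------------------------------------------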

    def retrieve_data_and_pad(
        self,
        array: np.ndarray,
        step_indices: np.ndarray,
        max_length: int,
        padding_strategy: str = "first_last",
    ) -> np.ndarray:
        """Retrieve the data from the dataset and pad it if necessary.

        Args:
            array (np.ndarray): The array to retrieve the data from.
            step_indices (np.ndarray): The step indices to retrieve the data for.
            max_length (int): The maximum length of the data.
            padding_strategy (str): The padding strategy, either "first_last" or "zero".

        Returns:
            np.ndarray: The retrieved data, padded to len(step_indices).
        """
        # Get the padding indices
        front_padding_indices = step_indices < 0
        end_padding_indices = step_indices >= max_length
        padding_positions = np.logical_or(front_padding_indices, end_padding_indices)
        # Retrieve the data with the non-padding indices.
        # If there is any padding, given T step_indices the retrieved data has shape (T', ...) with T' < T.
        raw_data = array[step_indices[~padding_positions]]
        assert isinstance(raw_data, np.ndarray), f"{type(raw_data)=}"
        # This is the shape of the output, (T, ...)
        if raw_data.ndim == 1:
            expected_shape = (len(step_indices),)
        else:
            expected_shape = (len(step_indices), *array.shape[1:])

        # Pad the data
        output = np.zeros(expected_shape)
        # Assign the non-padded data
        output[~padding_positions] = raw_data
        # If there is any padding, fill it in
        if padding_positions.any():
            if padding_strategy == "first_last":
                # Use the first / last step data to pad
                front_padding_data = array[0]
                end_padding_data = array[-1]
                output[front_padding_indices] = front_padding_data
                output[end_padding_indices] = end_padding_data
            elif padding_strategy == "zero":
                # Use zero padding
                output[padding_positions] = 0
            else:
                raise ValueError(f"Invalid padding strategy: {padding_strategy}")
        return output
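
    # -----------------------------------------------------------------------
    # Editorial sketch (not part of the original class): what the padding in
    # retrieve_data_and_pad does. With delta indices [-1, 0, 1] at base index
    # 0 (so step_indices = [-1, 0, 1]) and max_length = 4:
    #   "first_last" -> [array[0], array[0], array[1]]  (clamp to episode ends)
    #   "zero"       -> [0,        array[0], array[1]]  (relative data pads with 0)
    # -----------------------------------------------------------------------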

    def get_video_path(self, trajectory_id: int, key: str) -> Path:
        chunk_index = self.get_episode_chunk(trajectory_id)
        original_key = self.lerobot_modality_meta.video[key].original_key
        if original_key is None:
            original_key = key
        if self._lerobot_version == "v2.0":
            video_filename = self.video_path_pattern.format(
                episode_chunk=chunk_index, episode_index=trajectory_id, video_key=original_key
            )
        elif self._lerobot_version == "v3.0":
            episode_meta = self.trajectory_ids_to_metadata[trajectory_id]
            video_filename = self.video_path_pattern.format(
                video_key=original_key,
                chunk_index=episode_meta["data/chunk_index"],
                file_index=episode_meta["data/file_index"],
            )
        return self.dataset_path / video_filename

    def get_video(
        self,
        trajectory_id: int,
        key: str,
        base_index: int,
    ) -> np.ndarray:
        """Get the video frames for a trajectory by a base index.

        Args:
            trajectory_id (int): The ID of the trajectory.
            key (str): The key of the video.
            base_index (int): The base index of the trajectory.

        Returns:
            np.ndarray: The video frames for the trajectory and frame indices. Shape: (T, H, W, C)
        """
        # Get the step indices
        step_indices = self.delta_indices[key] + base_index
        # Get the trajectory index
        trajectory_index = self.get_trajectory_index(trajectory_id)
        # Ensure the indices are within the valid range.
        # This is equivalent to padding the video with extra frames at the beginning and end.
        step_indices = np.maximum(step_indices, 0)
        step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1)
        assert key.startswith("video."), f"Video key must start with 'video.', got {key}"
        # Get the sub-key
        key = key.replace("video.", "")
        video_path = self.get_video_path(trajectory_id, key)
        # Get the action/state timestamps for each frame in the video
        assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
        assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}"
        timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy()
        # Get the corresponding video timestamps from the step indices
        video_timestamp = timestamp[step_indices]

        return get_frames_by_timestamps(
            video_path.as_posix(),
            video_timestamp,
            video_backend=self.video_backend,
            video_backend_kwargs=self.video_backend_kwargs,
        )

    def get_state_or_action(
        self,
        trajectory_id: int,
        modality: str,
        key: str,
        base_index: int,
    ) -> np.ndarray:
        """Get the state or action data for a trajectory by a base index.
        If the step indices are out of range, pad the data:
        if the data is stored in absolute format, pad with the first or last step data;
        otherwise, pad with zero.

        Args:
            trajectory_id (int): The ID of the trajectory.
            modality (str): The modality of the data.
            key (str): The key of the data.
            base_index (int): The base index of the trajectory.

        Returns:
            np.ndarray: The data for the trajectory and step indices.
        """
        # Get the step indices
        step_indices = self.delta_indices[key] + base_index
        # Get the trajectory index
        trajectory_index = self.get_trajectory_index(trajectory_id)
        # Get the maximum length of the trajectory
        max_length = self.trajectory_lengths[trajectory_index]
        assert key.startswith(modality + "."), f"Key must start with '{modality}.', got {key}"
        # Get the sub-key, e.g. state.joint_angles -> joint_angles
        key = key.replace(modality + ".", "")
        # Get the lerobot key
        le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality)
        le_key = le_state_or_action_cfg[key].original_key
        if le_key is None:
            le_key = key
        # Get the data array, shape: (T, D)
        assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
        assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}"
        data_array: np.ndarray = np.stack(self.curr_traj_data[le_key])  # type: ignore
        assert data_array.ndim == 2, f"Expected a 2D array for key {le_key}, got shape {data_array.shape}"
        le_indices = np.arange(
            le_state_or_action_cfg[key].start,
            le_state_or_action_cfg[key].end,
        )
        data_array = data_array[:, le_indices]
        # Get the state or action configuration
        state_or_action_cfg = getattr(self.metadata.modalities, modality)[key]

        # Pad the data
        return self.retrieve_data_and_pad(
            array=data_array,
            step_indices=step_indices,
            max_length=max_length,
            padding_strategy="first_last" if state_or_action_cfg.absolute else "zero",
            # padding_strategy="zero",  # HACK for real data
        )

    def get_language(
        self,
        trajectory_id: int,
        key: str,
        base_index: int,
    ) -> list[str]:
        """Get the language annotation data for a trajectory by step indices.

        Args:
            trajectory_id (int): The ID of the trajectory.
            key (str): The key of the annotation.
            base_index (int): The base index of the trajectory.

        Returns:
            list[str]: The annotation data for the trajectory and step indices. If no matching data is found, return empty strings.
        """
        assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
        # Get the step indices
        step_indices = self.delta_indices[key] + base_index
        # Get the trajectory index
        trajectory_index = self.get_trajectory_index(trajectory_id)
        # Get the maximum length of the trajectory
        max_length = self.trajectory_lengths[trajectory_index]
        # Clamp the step indices to the valid range
        step_indices = np.maximum(step_indices, 0)
        step_indices = np.minimum(step_indices, max_length - 1)
        # Get the annotations
        task_indices: list[int] = []
        assert key.startswith(
            "annotation."
        ), f"Language key must start with 'annotation.', got {key}"
        subkey = key.replace("annotation.", "")
        annotation_meta = self.lerobot_modality_meta.annotation
        assert annotation_meta is not None, f"Annotation metadata is None for {subkey}"
        assert (
            subkey in annotation_meta
        ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}"
        subkey_meta = annotation_meta[subkey]
        original_key = subkey_meta.original_key
        if original_key is None:
            original_key = key
        for i in range(len(step_indices)):
            value = self.curr_traj_data[original_key].iloc[step_indices[i]]  # TODO: check v2.0
            task_indices.append(value if isinstance(value, (int, float)) else value.item())

        return self.tasks.loc[task_indices]["task"].tolist()

    def get_data_by_modality(
        self,
        trajectory_id: int,
        modality: str,
        key: str,
        base_index: int,
    ):
        """Get the data corresponding to the modality for a trajectory by a base index.
        This method will call the corresponding helper method based on the modality.
        See the helper methods for more details.
        NOTE: For the language modality, the data is padded with empty strings if no matching data is found.

        Args:
            trajectory_id (int): The ID of the trajectory.
            modality (str): The modality of the data.
            key (str): The key of the data.
            base_index (int): The base index of the trajectory.
        """
        if modality == "video":
            return self.get_video(trajectory_id, key, base_index)
        elif modality == "state" or modality == "action":
            return self.get_state_or_action(trajectory_id, modality, key, base_index)
        elif modality == "language":
            return self.get_language(trajectory_id, key, base_index)
        else:
            raise ValueError(f"Invalid modality: {modality}")

    def _save_dataset_statistics_(self, save_path: Path | str, format: str = "json") -> None:
        """
        Save dataset statistics to the specified path in the required format.
        Only includes statistics for keys that are actually used in the dataset.
        Key order follows the modality config order.

        Args:
            save_path (Path | str): Path to save the statistics file
            format (str): Save format, currently only supports "json"
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Build the data structure to save
        statistics_data = {}

        # Get the used modality keys
        used_action_keys, used_state_keys = get_used_modality_keys(self.modality_keys)

        # Organize statistics by tag
        tag = self.tag
        tag_stats = {}

        # Process action statistics (only for used keys, config order)
        if hasattr(self.metadata.statistics, 'action') and self.metadata.statistics.action:
            action_stats = self.metadata.statistics.action
            filtered_action_stats = {
                key: action_stats[key]
                for key in used_action_keys
                if key in action_stats
            }

            if filtered_action_stats:
                # Combine statistics from the filtered action sub-keys
                combined_action_stats = combine_modality_stats(filtered_action_stats)

                # Add a mask field based on whether each dimension is a gripper or not
                mask = generate_action_mask_for_used_keys(
                    self.metadata.modalities.action, filtered_action_stats.keys()
                )
                combined_action_stats["mask"] = mask

                tag_stats["action"] = combined_action_stats

        # Process state statistics (only for used keys, config order)
        if hasattr(self.metadata.statistics, 'state') and self.metadata.statistics.state:
            state_stats = self.metadata.statistics.state
            filtered_state_stats = {
                key: state_stats[key]
                for key in used_state_keys
                if key in state_stats
            }

            if filtered_state_stats:
                combined_state_stats = combine_modality_stats(filtered_state_stats)
                tag_stats["state"] = combined_state_stats

        # Add dataset counts
        tag_stats["num_transitions"] = len(self)
        tag_stats["num_trajectories"] = len(self.trajectory_ids)

        statistics_data[tag] = tag_stats

        # Save as a JSON file
        if format.lower() == "json":
            if not str(save_path).endswith('.json'):
                save_path = save_path.with_suffix('.json')
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(statistics_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")

        print(f"Single dataset statistics saved to: {save_path}")
        print(f"Used action keys (reordered): {list(used_action_keys)}")
        print(f"Used state keys (reordered): {list(used_state_keys)}")
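
# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): the shape of the JSON
# written by _save_dataset_statistics_ (tag and values illustrative):
#
#   {
#     "franka": {
#       "action": {"mean": [...], "std": [...], "q01": [...], "q99": [...],
#                  "mask": [true, true, ..., false]},
#       "state":  {"mean": [...], "std": [...], ...},
#       "num_transitions": 12345,
#       "num_trajectories": 200
#     }
#   }
# ---------------------------------------------------------------------------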


class MixtureSpecElement(BaseModel):
    dataset_path: list[Path] | Path = Field(..., description="The path to the dataset.")
    dataset_weight: float = Field(..., description="The weight of the dataset in the mixture.")
    distribute_weights: bool = Field(
        default=False,
        description="Whether to distribute the weights of the dataset across all the paths. If True, the weights will be evenly distributed across all the paths.",
    )


# Helper functions for dataset statistics

def combine_modality_stats(modality_stats: dict) -> dict:
    """
    Combine statistics from all sub-keys under a modality.

    Args:
        modality_stats (dict): Statistics for a modality, containing multiple sub-keys.
            Each sub-key maps to a DatasetStatisticalValues object.

    Returns:
        dict: Combined statistics
    """
    combined_stats = {
        "mean": [],
        "std": [],
        "max": [],
        "min": [],
        "q01": [],
        "q99": []
    }

    # Combine statistics in sub-key order
    for subkey in modality_stats.keys():
        subkey_stats = modality_stats[subkey]  # a DatasetStatisticalValues object

        # Convert DatasetStatisticalValues to dict-like access
        for stat_name in ["mean", "std", "max", "min", "q01", "q99"]:
            stat_value = getattr(subkey_stats, stat_name)
            if isinstance(stat_value, (list, tuple)):
                combined_stats[stat_name].extend(stat_value)
            else:
                # Handle the NDArray case - convert to a list
                if hasattr(stat_value, 'tolist'):
                    combined_stats[stat_name].extend(stat_value.tolist())
                else:
                    combined_stats[stat_name].append(float(stat_value))

    return combined_stats


def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]:
    """
    Generate a mask based on the action modalities, but only for used keys.
    Gripper dimensions are masked out (False); all other dimensions are True,
    so every non-gripper channel is de/normalized.

    Args:
        action_modalities (dict): Configuration information for action modalities.
        used_action_keys_ordered: Iterable of actually used action keys in the correct order.

    Returns:
        list[bool]: List of mask values
    """
    mask = []

    # Generate the mask in the same order as the statistics were combined
    for subkey in used_action_keys_ordered:
        if subkey in action_modalities:
            subkey_config = action_modalities[subkey]

            # Get the dimension count from the shape
            if hasattr(subkey_config, 'shape') and len(subkey_config.shape) > 0:
                dim_count = subkey_config.shape[0]
            else:
                dim_count = 1

            # Check whether the key is gripper-related
            is_gripper = "gripper" in subkey.lower()

            # Generate a mask value for each dimension
            for _ in range(dim_count):
                mask.append(not is_gripper)  # gripper is False, others are True

    return mask
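

# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original module): the mask layout for a
# 3D position key plus a 1D gripper key. `_FakeCfg` is a hypothetical
# stand-in for the real modality config objects, which only need a `.shape`
# attribute here.
# ---------------------------------------------------------------------------
def _demo_action_mask():
    class _FakeCfg:
        def __init__(self, shape):
            self.shape = shape

    modalities = {"delta_eef_position": _FakeCfg([3]), "gripper_close": _FakeCfg([1])}
    mask = generate_action_mask_for_used_keys(modalities, ["delta_eef_position", "gripper_close"])
    assert mask == [True, True, True, False]  # only the gripper dimension is masked out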


def get_used_modality_keys(modality_keys: dict) -> tuple[list[str], list[str]]:
    """Extract the used action and state keys from the modality configuration."""
    used_action_keys = []
    used_state_keys = []

    # Extract action keys (remove the "action." prefix)
    for action_key in modality_keys.get("action", []):
        if action_key.startswith("action."):
            clean_key = action_key.replace("action.", "")
            used_action_keys.append(clean_key)

    # Extract state keys (remove the "state." prefix)
    for state_key in modality_keys.get("state", []):
        if state_key.startswith("state."):
            clean_key = state_key.replace("state.", "")
            used_state_keys.append(clean_key)

    return used_action_keys, used_state_keys


def safe_hash(input_tuple):
    """Deterministic hash of a tuple, independent of Python's per-process hash seed."""
    tuple_string = repr(input_tuple).encode("utf-8")
    sha256 = hashlib.sha256()
    sha256.update(tuple_string)

    seed = int(sha256.hexdigest(), 16)

    # Keep only the low 128 bits of the hash
    return seed & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
1460
+
1461
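
# --- Determinism note (illustrative) ---
# Unlike the built-in hash(), which is salted per interpreter process, safe_hash
# is stable across dataloader workers and restarts, so the same (epoch, index,
# seed) triple always maps to the same sample:
#
# assert safe_hash((0, 123, 42)) == safe_hash((0, 123, 42))
# rng = np.random.default_rng(safe_hash((0, 123, 42)))  # reproducible generator
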
class LeRobotMixtureDataset(Dataset):
    """
    A mixture of multiple datasets. This class samples a single dataset based on the dataset
    weights and then calls the `__getitem__` method of the sampled dataset.
    It is recommended to modify the single dataset class instead of this class.
    """

    def __init__(
        self,
        data_mixture: Sequence[tuple[LeRobotSingleDataset, float]],
        mode: str,
        balance_dataset_weights: bool = True,
        balance_trajectory_weights: bool = True,
        seed: int = 42,
        metadata_config: dict = {
            "percentile_mixing_method": "min_max",
        },
        **kwargs,
    ):
        """
        Initialize the mixture dataset.

        Args:
            data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights.
            mode (str): If "train", `__getitem__` returns different samples every epoch; if "val" or "test", it returns the same sample every epoch.
            balance_dataset_weights (bool): If True, the weight of each dataset is multiplied by its total trajectory length.
            balance_trajectory_weights (bool): If True, trajectories within a dataset are sampled weighted by their length; otherwise, they are weighted equally.
            seed (int): Random seed for sampling.
        """
        datasets: list[LeRobotSingleDataset] = []
        dataset_sampling_weights: list[float] = []
        for dataset, weight in data_mixture:
            # Check that the dataset is valid and has data
            if len(dataset) == 0:
                print(f"Warning: Skipping empty dataset {dataset.dataset_name}")
                continue
            datasets.append(dataset)
            dataset_sampling_weights.append(weight)

        if len(datasets) == 0:
            raise ValueError("No valid datasets found in the mixture. All datasets are empty.")

        self.datasets = datasets
        self.balance_dataset_weights = balance_dataset_weights
        self.balance_trajectory_weights = balance_trajectory_weights
        self.seed = seed
        self.mode = mode

        # Set properties for sampling

        # 1. Dataset lengths
        self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets])
        print(f"Dataset lengths: {self._dataset_lengths}")

        # 2. Dataset sampling weights
        self._dataset_sampling_weights = np.array(dataset_sampling_weights)

        if self.balance_dataset_weights:
            self._dataset_sampling_weights *= self._dataset_lengths

        # Check for zero or negative weights before normalization
        if np.any(self._dataset_sampling_weights <= 0):
            print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}")
            # Clamp to a small positive value to prevent division issues
            self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8)

        # Normalize the weights
        weights_sum = self._dataset_sampling_weights.sum()
        if weights_sum == 0 or np.isnan(weights_sum):
            print(f"Error: Invalid weights sum: {weights_sum}")
            # Fall back to equal weights
            self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets)
            print("Falling back to equal weights")
        else:
            self._dataset_sampling_weights /= weights_sum
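
        # --- Worked example (illustrative) ---
        # With balance_dataset_weights=True, raw weights [1.0, 1.0] and dataset
        # lengths [9000, 1000] become [9000, 1000] -> normalized [0.9, 0.1],
        # i.e. sampling proportional to dataset size rather than uniform per dataset.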

        # 3. Trajectory sampling weights
        self._trajectory_sampling_weights: list[np.ndarray] = []
        for i, dataset in enumerate(self.datasets):
            trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths))
            if self.balance_trajectory_weights:
                trajectory_sampling_weights *= dataset.trajectory_lengths

            # Check for zero or negative weights before normalization
            if np.any(trajectory_sampling_weights <= 0):
                print(f"Warning: Dataset {i} has zero or negative trajectory weights")
                trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8)

            # Normalize the weights
            weights_sum = trajectory_sampling_weights.sum()
            if weights_sum == 0 or np.isnan(weights_sum):
                print(f"Error: Dataset {i} has an invalid trajectory weights sum: {weights_sum}")
                # Fall back to equal weights
                trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths)
            else:
                trajectory_sampling_weights /= weights_sum

            self._trajectory_sampling_weights.append(trajectory_sampling_weights)

        # 4. Primary dataset indices
        # NOTE: the float equality against 1.0 is intentional; weights are user-provided
        # literals, and the fallbacks below cover the case where none matches exactly.
        self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0
        if not np.any(self._primary_dataset_indices):
            print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}")
            # Fallback: use the dataset(s) with the maximum weight as primary
            max_weight = max(dataset_sampling_weights)
            self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight
            print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}")

        if not np.any(self._primary_dataset_indices):
            # This should never happen, but just in case
            print("Error: Still no primary dataset found. Using the first dataset as primary.")
            self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool)
            self._primary_dataset_indices[0] = True

        # Set the epoch and sample the first epoch
        self.set_epoch(0)

        self.update_metadata(metadata_config)

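    # --- Construction sketch (illustrative; `libero_ds` and `bridge_ds` are
    # hypothetical LeRobotSingleDataset instances, not names from this repo) ---
    # mixture = LeRobotMixtureDataset(
    #     data_mixture=[(libero_ds, 1.0), (bridge_ds, 0.5)],
    #     mode="train",
    #     balance_dataset_weights=True,
    #     balance_trajectory_weights=True,
    #     seed=42,
    # )
    # len(mixture)   # epoch length, driven by the primary (weight 1.0) dataset
    # mixture[0]     # dict with action / state / image / lang / dataset_id
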
    @property
    def dataset_lengths(self) -> np.ndarray:
        """The lengths of each dataset."""
        return self._dataset_lengths

    @property
    def dataset_sampling_weights(self) -> np.ndarray:
        """The sampling weights for each dataset."""
        return self._dataset_sampling_weights

    @property
    def trajectory_sampling_weights(self) -> list[np.ndarray]:
        """The sampling weights for each trajectory in each dataset."""
        return self._trajectory_sampling_weights

    @property
    def primary_dataset_indices(self) -> np.ndarray:
        """The indices of the primary datasets."""
        return self._primary_dataset_indices

    def __str__(self) -> str:
        dataset_descriptions = []
        for dataset, weight in zip(self.datasets, self.dataset_sampling_weights):
            dataset_description = {
                "Dataset": str(dataset),
                "Sampling weight": float(weight),
            }
            dataset_descriptions.append(dataset_description)
        return json.dumps({"Mixture dataset": dataset_descriptions}, indent=2)

    def set_epoch(self, epoch: int):
        """Set the epoch for the dataset.

        Args:
            epoch (int): The epoch to set.
        """
        self.epoch = epoch
        # self.sampled_steps = self.sample_epoch()

    def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]:
        """Sample a single step from the dataset."""
        # return self.sampled_steps[index]

        # Set the seed: deterministic per index in eval, epoch-dependent in train
        seed = index if self.mode != "train" else safe_hash((self.epoch, index, self.seed))
        rng = np.random.default_rng(seed)

        # Sample a dataset
        dataset_index = rng.choice(len(self.datasets), p=self.dataset_sampling_weights)
        dataset = self.datasets[dataset_index]

        # Sample a trajectory
        # trajectory_index = rng.choice(
        #     len(dataset.trajectory_ids), p=self.trajectory_sampling_weights[dataset_index]
        # )
        # trajectory_id = dataset.trajectory_ids[trajectory_index]

        # # Sample a step
        # base_index = rng.choice(dataset.trajectory_lengths[trajectory_index])
        # return dataset, trajectory_id, base_index

        # NOTE: steps are sampled uniformly over all (trajectory, step) pairs, which
        # implicitly weights trajectories by their length; the commented-out two-stage
        # sampling above used trajectory_sampling_weights explicitly instead.
        single_step_index = rng.choice(len(dataset.all_steps))
        trajectory_id, base_index = dataset.all_steps[single_step_index]
        return dataset, trajectory_id, base_index

    def __getitem__(self, index: int) -> dict:
        """Get the data for a single trajectory and start index.

        Args:
            index (int): The index of the trajectory to get.

        Returns:
            dict: The data for the trajectory and start index.
        """
        max_retries = 10
        last_exception = None

        for attempt in range(max_retries):
            try:
                dataset, trajectory_name, step = self.sample_step(index)
                data_raw = dataset.get_step_data(trajectory_name, step)
                data = dataset.transforms(data_raw)

                # Process all video keys dynamically
                images = []
                mid_images = []
                num_history_steps = int(getattr(dataset, "num_history_steps", 0) or 0)
                for video_key in dataset.modality_keys.get("video", []):
                    video_frames = data[video_key]
                    image = video_frames[0]
                    image = Image.fromarray(image).resize((224, 224))  # TODO: check if this is ok
                    images.append(image)
                    if num_history_steps != 0:
                        history_index = min(num_history_steps - 1, len(video_frames) - 1)
                        mid_image = video_frames[history_index]
                        mid_image = Image.fromarray(mid_image).resize((224, 224))
                        mid_images.append(mid_image)

                # Get the language and action data
                language = data[dataset.modality_keys["language"][0]][0]
                action = []
                for action_key in dataset.modality_keys["action"]:
                    action.append(data[action_key])
                action = np.concatenate(action, axis=1).astype(np.float16)
                action = standardize_action_representation(action, dataset.tag)

                state = []
                for state_key in dataset.modality_keys["state"]:
                    state.append(data[state_key])
                state = np.concatenate(state, axis=1).astype(np.float16)
                state = standardize_state_representation(state, dataset.tag)

                sample = dict(action=action, state=state, image=images, lang=language, dataset_id=dataset._dataset_id)
                if num_history_steps != 0:
                    sample["mid_image"] = mid_images
                return sample

            except Exception as e:
                last_exception = e
                if attempt < max_retries - 1:
                    # Log the error but keep trying
                    print(f"Attempt {attempt + 1}/{max_retries} failed for index {index}: {e}")
                    print("Retrying with a new sample...")
                    # Retry with a different index to avoid getting stuck
                    # on the same problematic sample
                    index = random.randint(0, len(self) - 1)
                else:
                    # All retries exhausted: re-raise the last exception
                    print(f"All {max_retries} attempts failed for index {index}")
                    print(f"Last error: {last_exception}")
                    raise last_exception

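    # --- Returned sample layout (illustrative; shapes depend on the configured
    # delta indices, the 37/88 widths come from the standard vectors) ---
    # {
    #     "action":  np.float16 array of shape (horizon, 37),   # standardized action
    #     "state":   np.float16 array of shape (T_state, 88),   # standardized state
    #     "image":   list[PIL.Image] (224x224), one per video key,
    #     "lang":    str instruction,
    #     "dataset_id": int | None,
    #     "mid_image": list[PIL.Image],  # only when num_history_steps != 0
    # }
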
    def __len__(self) -> int:
        """Get the length of a single epoch in the mixture.

        Returns:
            int: The length of a single epoch in the mixture.
        """
        # Check for potential issues
        if len(self.datasets) == 0:
            return 0

        # Check if any dataset lengths are 0 or NaN
        if np.any(self.dataset_lengths == 0) or np.any(np.isnan(self.dataset_lengths)):
            print(f"Warning: Found zero or NaN dataset lengths: {self.dataset_lengths}")
            # Filter out zero/NaN length datasets
            valid_indices = (self.dataset_lengths > 0) & (~np.isnan(self.dataset_lengths))
            if not np.any(valid_indices):
                print("Error: All datasets have zero or NaN length")
                return 0
        else:
            valid_indices = np.ones(len(self.datasets), dtype=bool)

        # Check if any sampling weights are 0 or NaN
        if np.any(self.dataset_sampling_weights == 0) or np.any(np.isnan(self.dataset_sampling_weights)):
            print(f"Warning: Found zero or NaN sampling weights: {self.dataset_sampling_weights}")
            # Use only the valid weights
            valid_weights = (self.dataset_sampling_weights > 0) & (~np.isnan(self.dataset_sampling_weights))
            valid_indices = valid_indices & valid_weights
            if not np.any(valid_indices):
                print("Error: All sampling weights are zero or NaN")
                return 0

        # Check the primary dataset indices
        primary_and_valid = self.primary_dataset_indices & valid_indices
        if not np.any(primary_and_valid):
            print(f"Warning: No valid primary datasets found. Primary indices: {self.primary_dataset_indices}, Valid indices: {valid_indices}")
            # Fallback: use the largest valid dataset
            if np.any(valid_indices):
                max_length = self.dataset_lengths[valid_indices].max()
                print(f"Fallback: Using maximum dataset length: {max_length}")
                return int(max_length)
            else:
                return 0

        # Epoch length = max over primary datasets of (length / sampling weight),
        # i.e. long enough that each primary dataset is seen in full in expectation
        ratios = (self.dataset_lengths / self.dataset_sampling_weights)[primary_and_valid]

        # Check for NaN or inf in the ratios
        if np.any(np.isnan(ratios)) or np.any(np.isinf(ratios)):
            print(f"Warning: Found NaN or inf in ratios: {ratios}")
            print(f"Dataset lengths: {self.dataset_lengths[primary_and_valid]}")
            print(f"Sampling weights: {self.dataset_sampling_weights[primary_and_valid]}")
            # Filter out the invalid ratios
            valid_ratios = ratios[~np.isnan(ratios) & ~np.isinf(ratios)]
            if len(valid_ratios) == 0:
                print("Error: All ratios are NaN or inf")
                return 0
            max_ratio = valid_ratios.max()
        else:
            max_ratio = ratios.max()

        result = int(max_ratio)
        if result == 0:
            print("Warning: Dataset mixture length is 0")
        return result

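    # --- Worked example (illustrative) ---
    # Primary dataset: length 10_000, normalized sampling weight 0.8
    # -> epoch length = int(10_000 / 0.8) = 12_500 draws, so the primary
    #    dataset is expected to be covered once per epoch on average.
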
    @staticmethod
    def compute_overall_statistics(
        per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]],
        dataset_sampling_weights: list[float] | np.ndarray,
        percentile_mixing_method: str = "weighted_average",
    ) -> dict[str, dict[str, list[float]]]:
        """
        Compute overall statistics from per-task statistics using the dataset sampling weights.

        Args:
            per_task_stats: List of per-task statistics.
                Example format of one element in the per-task statistics list:
                {
                    "state.gripper": {
                        "min": [...],
                        "max": [...],
                        "mean": [...],
                        "std": [...],
                        "q01": [...],
                        "q99": [...],
                    },
                    ...
                }
            dataset_sampling_weights: List of sampling weights for each task.
            percentile_mixing_method: The method used to mix the percentiles, either "weighted_average" or "min_max".

        Returns:
            A dict of overall statistics per modality.
        """
        # Normalize the sampling weights to sum to 1
        dataset_sampling_weights = np.array(dataset_sampling_weights)
        normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum()

        # Initialize the overall statistics dict
        overall_stats: dict[str, dict[str, list[float]]] = {}

        # Get the list of modality keys
        modality_keys = per_task_stats[0].keys()

        for modality in modality_keys:
            # Number of dimensions (assumed consistent across tasks)
            num_dims = len(per_task_stats[0][modality]["mean"])

            # Initialize accumulators for the means and variances
            weighted_means = np.zeros(num_dims)
            weighted_squares = np.zeros(num_dims)

            # Collect min, max, q01, q99 from all tasks
            min_list = []
            max_list = []
            q01_list = []
            q99_list = []

            for task_idx, task_stats in enumerate(per_task_stats):
                w_i = normalized_weights[task_idx]
                stats = task_stats[modality]
                means = np.array(stats["mean"])
                stds = np.array(stats["std"])

                # Update the weighted sums for the mean and variance
                weighted_means += w_i * means
                weighted_squares += w_i * (stds**2 + means**2)

                # Collect min, max, q01, q99
                min_list.append(stats["min"])
                max_list.append(stats["max"])
                q01_list.append(stats["q01"])
                q99_list.append(stats["q99"])

            # Compute the overall mean
            overall_mean = weighted_means.tolist()

            # Compute the overall variance and standard deviation
            # via the law of total variance: Var = E[std^2 + mean^2] - E[mean]^2
            overall_variance = weighted_squares - weighted_means**2
            overall_std = np.sqrt(overall_variance).tolist()

            # Compute the overall min and max per dimension
            overall_min = np.min(np.array(min_list), axis=0).tolist()
            overall_max = np.max(np.array(max_list), axis=0).tolist()

            # Compute the overall q01 and q99 per dimension
            q01_array = np.array(q01_list)
            q99_array = np.array(q99_list)
            if percentile_mixing_method == "weighted_average":
                # Weighted average of the per-task quantiles
                weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist()
                weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist()
            elif percentile_mixing_method == "min_max":
                # Min of the 1st percentiles and max of the 99th percentiles
                weighted_q01 = np.min(q01_array, axis=0).tolist()
                weighted_q99 = np.max(q99_array, axis=0).tolist()
            else:
                raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}")

            # Store the overall statistics for the modality
            overall_stats[modality] = {
                "min": overall_min,
                "max": overall_max,
                "mean": overall_mean,
                "std": overall_std,
                "q01": weighted_q01,
                "q99": weighted_q99,
            }

        return overall_stats

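    # --- Sanity check for the variance merge (illustrative) ---
    # Two tasks with equal weight, means 0 and 2, stds 1 and 1:
    # E[mean] = 1, E[std^2 + mean^2] = (1 + 0 + 1 + 4) / 2 = 3,
    # so the overall variance = 3 - 1^2 = 2 and the overall std = sqrt(2),
    # matching the pooled statistics of the two populations.
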
    @staticmethod
    def merge_metadata(
        metadatas: list[DatasetMetadata],
        dataset_sampling_weights: list[float],
        percentile_mixing_method: str,
    ) -> DatasetMetadata:
        """Merge multiple metadata objects into one."""
        # Convert to dicts
        metadata_dicts = [metadata.model_dump(mode="json") for metadata in metadatas]
        # Create a new metadata dict
        merged_metadata = {}

        # Check that all metadata have the same embodiment tag
        assert all(
            metadata.embodiment_tag == metadatas[0].embodiment_tag for metadata in metadatas
        ), "All metadata must have the same embodiment tag"
        merged_metadata["embodiment_tag"] = metadatas[0].embodiment_tag

        # Merge the dataset statistics
        dataset_statistics = {}
        dataset_statistics["state"] = LeRobotMixtureDataset.compute_overall_statistics(
            per_task_stats=[m["statistics"]["state"] for m in metadata_dicts],
            dataset_sampling_weights=dataset_sampling_weights,
            percentile_mixing_method=percentile_mixing_method,
        )
        dataset_statistics["action"] = LeRobotMixtureDataset.compute_overall_statistics(
            per_task_stats=[m["statistics"]["action"] for m in metadata_dicts],
            dataset_sampling_weights=dataset_sampling_weights,
            percentile_mixing_method=percentile_mixing_method,
        )
        merged_metadata["statistics"] = dataset_statistics

        # Merge the modality configs
        modality_configs = defaultdict(set)
        for metadata in metadata_dicts:
            for modality, configs in metadata["modalities"].items():
                modality_configs[modality].add(json.dumps(configs))
        merged_metadata["modalities"] = {}
        for modality, configs in modality_configs.items():
            # Check that all datasets agree on the config for this modality
            assert (
                len(configs) == 1
            ), f"Multiple modality configs for modality {modality}: {list(configs)}"
            merged_metadata["modalities"][modality] = json.loads(configs.pop())

        return DatasetMetadata.model_validate(merged_metadata)

    def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None:
        """
        Merge the metadata of all datasets into one and set the transforms with the merged metadata.

        Args:
            metadata_config (dict): Configuration for the metadata.
                "percentile_mixing_method": The method used to mix the percentiles, either "weighted_average" or "min_max".
                    weighted_average: Weighted average of the percentiles, using the dataset sampling weights.
                    min_max: Min of the 1st percentiles and max of the 99th percentiles.
            cached_statistics_path (Path | str | None): Optional path to precomputed statistics;
                if loading fails, the statistics are recomputed from scratch.
        """
        # If a cached path is provided, try to load and apply it
        if cached_statistics_path is not None:
            try:
                cached_stats = self.load_merged_statistics(cached_statistics_path)
                self.apply_cached_statistics(cached_stats)
                return
            except (FileNotFoundError, KeyError, ValidationError) as e:
                print(f"Failed to load cached statistics: {e}")
                print("Falling back to computing statistics from scratch...")

        self.tag = EmbodimentTag.NEW_EMBODIMENT.value
        self.merged_metadata: dict[str, DatasetMetadata] = {}
        # Group the metadata by tag
        all_metadatas: dict[str, list[DatasetMetadata]] = {}
        for dataset in self.datasets:
            if dataset.tag not in all_metadatas:
                all_metadatas[dataset.tag] = []
            all_metadatas[dataset.tag].append(dataset.metadata)
        for tag, metadatas in all_metadatas.items():
            self.merged_metadata[tag] = self.merge_metadata(
                metadatas=metadatas,
                dataset_sampling_weights=self.dataset_sampling_weights.tolist(),
                percentile_mixing_method=metadata_config["percentile_mixing_method"],
            )
        for dataset in self.datasets:
            dataset.set_transforms_metadata(self.merged_metadata[dataset.tag])

    def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None:
        """
        Save the merged dataset statistics to the specified path in the required format.
        Only includes statistics for keys that are actually used in the datasets.
        The key order follows each tag's modality config order.

        Args:
            save_path (Path | str): Path to save the statistics file.
            format (str): Save format; currently only "json" is supported.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Build the data structure to save
        statistics_data = {}

        # Keep the key orders per embodiment tag (from the modality config order)
        tag_to_used_action_keys = {}
        tag_to_used_state_keys = {}
        for dataset in self.datasets:
            if dataset.tag in tag_to_used_action_keys:
                continue
            used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys)
            tag_to_used_action_keys[dataset.tag] = used_action_keys
            tag_to_used_state_keys[dataset.tag] = used_state_keys

        # Organize the statistics by tag
        for tag, merged_metadata in self.merged_metadata.items():
            tag_stats = {}

            # Process the action statistics
            if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action:
                action_stats = merged_metadata.statistics.action

                used_action_keys = tag_to_used_action_keys.get(tag, [])
                filtered_action_stats = {
                    key: action_stats[key]
                    for key in used_action_keys
                    if key in action_stats
                }

                if filtered_action_stats:
                    combined_action_stats = combine_modality_stats(filtered_action_stats)

                    mask = generate_action_mask_for_used_keys(
                        merged_metadata.modalities.action, filtered_action_stats.keys()
                    )
                    combined_action_stats["mask"] = mask

                    tag_stats["action"] = combined_action_stats

            # Process the state statistics
            if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state:
                state_stats = merged_metadata.statistics.state

                used_state_keys = tag_to_used_state_keys.get(tag, [])
                filtered_state_stats = {
                    key: state_stats[key]
                    for key in used_state_keys
                    if key in state_stats
                }

                if filtered_state_stats:
                    combined_state_stats = combine_modality_stats(filtered_state_stats)
                    tag_stats["state"] = combined_state_stats

            # Add the dataset counts
            tag_stats.update(self._get_dataset_counts(tag))

            statistics_data[tag] = tag_stats

        # Save the file
        if format.lower() == "json":
            if not str(save_path).endswith('.json'):
                save_path = save_path.with_suffix('.json')
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(statistics_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")

        print(f"Merged dataset statistics saved to: {save_path}")
        print(f"Used action keys by tag: {tag_to_used_action_keys}")
        print(f"Used state keys by tag: {tag_to_used_state_keys}")

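    # --- Saved JSON layout (illustrative) ---
    # {
    #   "<embodiment_tag>": {
    #     "action": {"mean": [...], "std": [...], "min": [...], "max": [...],
    #                "q01": [...], "q99": [...], "mask": [true, ..., false]},
    #     "state":  {"mean": [...], "std": [...], ...},
    #     "num_transitions": 12345,
    #     "num_trajectories": 678
    #   },
    #   ...
    # }
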
    def _combine_modality_stats(self, modality_stats: dict) -> dict:
        """Backward-compatibility wrapper around `combine_modality_stats`."""
        return combine_modality_stats(modality_stats)

    def _generate_action_mask_for_used_keys(self, action_modalities: dict, used_action_keys_ordered) -> list[bool]:
        """Backward-compatibility wrapper around `generate_action_mask_for_used_keys`."""
        return generate_action_mask_for_used_keys(action_modalities, used_action_keys_ordered)

    def _get_dataset_counts(self, tag: str) -> dict:
        """
        Get dataset count information for the specified tag.

        Args:
            tag (str): The embodiment tag.

        Returns:
            dict: Dictionary containing num_transitions and num_trajectories.
        """
        num_transitions = 0
        num_trajectories = 0

        # Count the datasets belonging to this tag
        for dataset in self.datasets:
            if dataset.tag == tag:
                num_transitions += len(dataset)
                num_trajectories += len(dataset.trajectory_ids)

        return {
            "num_transitions": num_transitions,
            "num_trajectories": num_trajectories
        }

    @classmethod
    def load_merged_statistics(cls, load_path: Path | str) -> dict:
        """
        Load merged dataset statistics from a file.

        Args:
            load_path (Path | str): Path to the statistics file.

        Returns:
            dict: Dictionary containing the merged statistics.
        """
        load_path = Path(load_path)
        if not load_path.exists():
            raise FileNotFoundError(f"Statistics file not found: {load_path}")

        if load_path.suffix.lower() == '.json':
            with open(load_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        elif load_path.suffix.lower() == '.pkl':
            import pickle
            with open(load_path, 'rb') as f:
                return pickle.load(f)
        else:
            raise ValueError(f"Unsupported file format: {load_path.suffix}")

    def apply_cached_statistics(self, cached_statistics: dict) -> None:
        """
        Apply cached statistics to avoid recomputation.

        Args:
            cached_statistics (dict): Statistics loaded from a file.
        """
        # Validate that the cached statistics match the current datasets
        if "metadata" in cached_statistics:
            cached_dataset_names = set(cached_statistics["metadata"]["dataset_names"])
            current_dataset_names = set(dataset.dataset_name for dataset in self.datasets)

            if cached_dataset_names != current_dataset_names:
                print("Warning: Cached statistics dataset names don't match the current datasets.")
                print(f"Cached: {cached_dataset_names}")
                print(f"Current: {current_dataset_names}")
                return

        # Apply the cached statistics
        self.merged_metadata = {}
        for tag, stats_data in cached_statistics.items():
            if tag == "metadata":  # Skip the metadata field
                continue

            # Convert back to the DatasetMetadata format
            metadata_dict = {
                "embodiment_tag": tag,
                "statistics": {
                    "action": {},
                    "state": {}
                },
                "modalities": {}
            }

            # Convert the action statistics back
            if "action" in stats_data:
                action_data = stats_data["action"]
                # This is simplified - the combined stats may need to be split back into sub-keys
                metadata_dict["statistics"]["action"] = action_data

            # Convert the state statistics back
            if "state" in stats_data:
                state_data = stats_data["state"]
                metadata_dict["statistics"]["state"] = state_data

            self.merged_metadata[tag] = DatasetMetadata.model_validate(metadata_dict)

        # Update the transforms metadata for each dataset
        for dataset in self.datasets:
            if dataset.tag in self.merged_metadata:
                dataset.set_transforms_metadata(self.merged_metadata[dataset.tag])

        print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.")

code/dataloader/gr00t_lerobot/datasets_bak.py ADDED
@@ -0,0 +1,2175 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
In this file, we define 3 types of datasets:
1. LeRobotSingleDataset: a single dataset for a given embodiment tag
2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags
3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag,
   with caching for the video frames

See `scripts/load_dataset.py` for examples on how to use these datasets.
"""
import os
import random
import hashlib
import json
import torch
from collections import defaultdict
from pathlib import Path
from typing import Sequence
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field, ValidationError
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image

from starVLA.dataloader.gr00t_lerobot.video import get_all_frames, get_frames_by_timestamps

from starVLA.dataloader.gr00t_lerobot.embodiment_tags import EmbodimentTag, DATASET_NAME_TO_ID
from starVLA.dataloader.gr00t_lerobot.schema import (
    DatasetMetadata,
    DatasetStatisticalValues,
    LeRobotModalityMetadata,
    LeRobotStateActionMetadata,
)
from starVLA.dataloader.gr00t_lerobot.transform import ComposedModalityTransform

from functools import partial
from typing import Tuple, List
import pickle

# LeRobot v2.0 dataset file names
LE_ROBOT_MODALITY_FILENAME = "meta/modality.json"
LE_ROBOT_EPISODE_FILENAME = "meta/episodes.jsonl"
LE_ROBOT_TASKS_FILENAME = "meta/tasks.jsonl"
LE_ROBOT_INFO_FILENAME = "meta/info.json"
LE_ROBOT_STATS_FILENAME = "meta/stats_gr00t.json"
LE_ROBOT_DATA_FILENAME = "data/*/*.parquet"
LE_ROBOT_STEPS_FILENAME = "meta/steps.pkl"
EPSILON = 5e-4

# LeRobot v3.0 dataset file names
LE_ROBOT3_TASKS_FILENAME = "meta/tasks.parquet"
LE_ROBOT3_EPISODE_FILENAME = "meta/episodes/*/*.parquet"

# =============================================================================
# Unified Representation Layout & Helpers
# =============================================================================

STANDARD_ACTION_DIM = 37
#
# Unified action representation layout (0-based indices; a Python slice is [start, stop)):
# TIGHT layout: all simulated datasets share the same 29D space for better
# cross-embodiment transfer, with an extra 8D block [29:37] reserved for the
# real-world franka joints.
#
# - 0:7   -> left_arm (7D): xyz, rpy/euler, gripper
#            Used by: robotwin left arm; gr1 left_arm
# - 7:14  -> right_arm (7D): same structure
#            Used by: libero, bridge, fractal(rt1), oxe_droid (single-arm -> right slot);
#            robotwin right arm; gr1 right_arm
# - 14:20 -> left_hand (6D): gr1 only
# - 20:26 -> right_hand (6D): gr1 only
# - 26:29 -> waist (3D): gr1 only
# - 29:37 -> joints + gripper (8D): real_world_franka only
#
# Mapping:
#   libero/bridge/fractal/oxe_droid (7D) -> [7:14] (right_arm slot, single-arm default)
#   robotwin (14D, left+right)           -> [0:14]
#   gr1/robocasa (29D)                   -> [0:29]
#   real-world (8D)                      -> [29:37] (joints + gripper)

ACTION_REPRESENTATION_SLICES = {
    # Single-arm (7D) -> right_arm slot [7:14] (single-arm defaults to the right hand)
    "franka": slice(7, 14),
    "libero_franka": slice(7, 14),
    "oxe_droid": slice(7, 14),
    "oxe_rt1": slice(7, 14),
    "oxe_bridge": slice(7, 14),

    # Dual-arm (14D) -> left [0:7] + right [7:14]
    "dual_arm_franka": slice(0, 14),
    "robotwin": slice(0, 14),

    # Humanoid (29D) -> full [0:29]
    "gr1": slice(0, 29),
    "fourier_gr1_arms_waist": slice(0, 29),

    # Real-world (8D) -> [29:37] (joints + gripper)
    "real_world_franka": slice(29, 37),

    # Fallback (single-arm -> right slot)
    "new_embodiment": slice(7, 14),
}

STANDARD_STATE_DIM = 88
# Mapping:
#   robotwin (14D)              -> [0:14]  (left [0:7] + right [7:14])
#   libero/bridge/fractal (8D)  -> [14:22] (right slot)
#   real-world (8D)             -> [22:30] (joints + gripper)
#   gr1 (58D after sin/cos)     -> [30:88] (isolated, different transform)

STATE_REPRESENTATION_SLICES = {
    # Dual-arm (14D) -> left [0:7] + right [7:14]
    "dual_arm_franka": slice(0, 14),
    "robotwin": slice(0, 14),
    # Single-arm (8D) -> right slot [14:22] (aligned with the action right slot [7:14])
    "franka": slice(14, 22),
    "libero_franka": slice(14, 22),
    "oxe_droid": slice(14, 22),
    "oxe_rt1": slice(14, 22),
    "oxe_bridge": slice(14, 22),
    # Real-world (8D) -> [22:30] (joints + gripper)
    "real_world_franka": slice(22, 30),
    # GR1 isolated [30:88] (58D; has StateActionSinCosTransform - a different pipeline)
    "gr1": slice(30, 88),
    # Fallback (single-arm -> right slot)
    "new_embodiment": slice(14, 22),
}


def standardize_action_representation(
    action: np.ndarray, embodiment_tag: str
) -> np.ndarray:
    """Map a per-robot action to a fixed-size standard action vector."""
    target_slice = ACTION_REPRESENTATION_SLICES.get(embodiment_tag)

    # Fall back to 'new_embodiment' if the tag is not found, otherwise raise an error
    if target_slice is None:
        if "new_embodiment" in ACTION_REPRESENTATION_SLICES:
            target_slice = ACTION_REPRESENTATION_SLICES["new_embodiment"]
        else:
            raise ValueError(
                f"Unknown embodiment tag '{embodiment_tag}' for action mapping. "
                f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES)}"
            )

    expected_dim = target_slice.stop - target_slice.start
    if action.shape[-1] != expected_dim:
        raise ValueError(
            f"Action dim mismatch for tag '{embodiment_tag}': "
            f"{action.shape[-1]=} vs expected {expected_dim}."
        )

    standard = np.zeros(
        (*action.shape[:-1], STANDARD_ACTION_DIM), dtype=action.dtype
    )
    standard[..., target_slice] = action
    return standard



def standardize_state_representation(
    state: np.ndarray, embodiment_tag: str
) -> np.ndarray:
    """Map a per-robot state to a fixed-size standard state vector."""

    target_slice = STATE_REPRESENTATION_SLICES.get(embodiment_tag)

    # Fall back to 'new_embodiment' if the tag is not found, otherwise raise an error
    if target_slice is None:
        if "new_embodiment" in STATE_REPRESENTATION_SLICES:
            target_slice = STATE_REPRESENTATION_SLICES["new_embodiment"]
        else:
            raise ValueError(
                f"Unknown embodiment tag '{embodiment_tag}' for state mapping. "
                f"Known tags: {sorted(STATE_REPRESENTATION_SLICES)}"
            )

    expected_dim = target_slice.stop - target_slice.start
    if state.shape[-1] != expected_dim:
        raise ValueError(
            f"State dim mismatch for tag '{embodiment_tag}': "
            f"{state.shape[-1]=} vs expected {expected_dim}."
        )

    standard = np.zeros(
        (*state.shape[:-1], STANDARD_STATE_DIM), dtype=state.dtype
    )
    standard[..., target_slice] = state
    return standard


def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict:
    """Calculate the dataset statistics of all columns for a list of parquet files."""
    # Dataset statistics
    all_low_dim_data_list = []
    # Collect all the data
    for parquet_path in tqdm(
        sorted(list(parquet_paths)),
        desc="Collecting all parquet files...",
    ):
        # Load the parquet file
        parquet_data = pd.read_parquet(parquet_path)
        all_low_dim_data_list.append(parquet_data)

    all_low_dim_data = pd.concat(all_low_dim_data_list, axis=0)
    # Compute the dataset statistics
    dataset_statistics = {}
    for le_modality in all_low_dim_data.columns:
        if le_modality.startswith("annotation."):
            continue
        print(f"Computing statistics for {le_modality}...")
        np_data = np.vstack(
            [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]]
        )
        dataset_statistics[le_modality] = {
            "mean": np.mean(np_data, axis=0).tolist(),
            "std": np.std(np_data, axis=0).tolist(),
            "min": np.min(np_data, axis=0).tolist(),
            "max": np.max(np_data, axis=0).tolist(),
            "q01": np.quantile(np_data, 0.01, axis=0).tolist(),
            "q99": np.quantile(np_data, 0.99, axis=0).tolist(),
        }
    return dataset_statistics


class ModalityConfig(BaseModel):
    """Configuration for a modality."""

    delta_indices: list[int]
    """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices."""
    modality_keys: list[str]
    """The keys to load for the modality in the dataset."""

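
# --- Usage sketch (illustrative key names) ---
# video_config = ModalityConfig(
#     delta_indices=[0],                       # current frame only
#     modality_keys=["video.image_side_0"],
# )
# action_config = ModalityConfig(
#     delta_indices=list(range(16)),           # a 16-step action chunk
#     modality_keys=["action.eef_position", "action.eef_rotation"],
# )
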
class LeRobotSingleDataset(Dataset):
    """
    Base dataset class for LeRobot that supports sharding.
    """

    def __init__(
        self,
        dataset_path: Path | str,
        modality_configs: dict[str, ModalityConfig],
        embodiment_tag: str | EmbodimentTag,
        video_backend: str = "decord",
        video_backend_kwargs: dict | None = None,
        transforms: ComposedModalityTransform | None = None,
        delete_pause_frame: bool = False,
        **kwargs,
    ):
        """
        Initialize the dataset.

        Args:
            dataset_path (Path | str): The path to the dataset.
            modality_configs (dict[str, ModalityConfig]): The configuration for each modality. The keys are the modality names, and the values are the modality configurations.
                See `ModalityConfig` for more details.
            video_backend (str): Backend for video reading.
            video_backend_kwargs (dict): Keyword arguments for the video backend when initializing the video reader.
            transforms (ComposedModalityTransform): The transforms to apply to the dataset.
            embodiment_tag (EmbodimentTag): Overload the embodiment tag for the dataset, e.g. define it as "new_embodiment".
        """
        # First check that the path directory exists
        if not Path(dataset_path).exists():
            raise FileNotFoundError(f"Dataset path {dataset_path} does not exist")
        data_cfg = kwargs.get("data_cfg", {}) or {}
        # Determine the LeRobot dataset version
        self._lerobot_version = data_cfg.get("lerobot_version", "v2.0")
        self.load_video = data_cfg.get("load_video", True)

        self.delete_pause_frame = delete_pause_frame

        # If video loading is disabled, skip the video modality end-to-end.
        if self.load_video:
            self.modality_configs = modality_configs
        else:
            self.modality_configs = {
                modality: config
                for modality, config in modality_configs.items()
                if modality != "video"
            }
        self.video_backend = video_backend
        self.video_backend_kwargs = video_backend_kwargs if video_backend_kwargs is not None else {}
        self.transforms = (
            transforms if transforms is not None else ComposedModalityTransform(transforms=[])
        )

        self._dataset_path = Path(dataset_path)
        self._dataset_name = self._dataset_path.name
        self._dataset_id = DATASET_NAME_TO_ID.get(self._dataset_name)
        if isinstance(embodiment_tag, EmbodimentTag):
            self.tag = embodiment_tag.value
        else:
            self.tag = embodiment_tag

        self._metadata = self._get_metadata(EmbodimentTag(self.tag))

        # LeRobot-specific config
        self._lerobot_modality_meta = self._get_lerobot_modality_meta()
        self._lerobot_info_meta = self._get_lerobot_info_meta()
        self._data_path_pattern = self._get_data_path_pattern()
        self._video_path_pattern = self._get_video_path_pattern()
        self._chunk_size = self._get_chunk_size()
        self._tasks = self._get_tasks()
        self.curr_traj_data = None
        self.curr_traj_id = None

        self._trajectory_ids, self._trajectory_lengths = self._get_trajectories()
        self._modality_keys = self._get_modality_keys()
        self._delta_indices = self._get_delta_indices()
        self._all_steps = self._get_all_steps()
        self.set_transforms_metadata(self.metadata)
        self.set_epoch(0)

        print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}")

        # Check that the dataset is valid
        self._check_integrity()

    @property
    def dataset_path(self) -> Path:
        """The path to the dataset that contains the METADATA_FILENAME file."""
        return self._dataset_path

    @property
    def metadata(self) -> DatasetMetadata:
        """The metadata for the dataset, loaded from metadata.json in the dataset directory."""
        return self._metadata

    @property
    def trajectory_ids(self) -> np.ndarray:
        """The trajectory IDs in the dataset, stored as a 1D numpy array of strings."""
        return self._trajectory_ids

    @property
    def trajectory_lengths(self) -> np.ndarray:
        """The trajectory lengths in the dataset, stored as a 1D numpy array of integers.
        The order of the lengths is the same as the order of the trajectory IDs.
        """
        return self._trajectory_lengths

    @property
    def all_steps(self) -> list[tuple[int, int]]:
        """The trajectory IDs and base indices for all steps in the dataset.

        Example:
            self.trajectory_ids: [0, 1, 2]
            self.trajectory_lengths: [3, 2, 4]
            return: [
                ("traj_0", 0), ("traj_0", 1), ("traj_0", 2),
                ("traj_1", 0), ("traj_1", 1),
                ("traj_2", 0), ("traj_2", 1), ("traj_2", 2), ("traj_2", 3)
            ]
        """
        return self._all_steps

    @property
    def modality_keys(self) -> dict:
        """The modality keys for the dataset. The keys are the modality names, and the values are the keys for each modality.

        Example: {
            "video": ["video.image_side_0", "video.image_side_1"],
            "state": ["state.eef_position", "state.eef_rotation"],
            "action": ["action.eef_position", "action.eef_rotation"],
            "language": ["language.human.task"],
            "timestamp": ["timestamp"],
            "reward": ["reward"],
        }
        """
        return self._modality_keys

    @property
    def delta_indices(self) -> dict[str, np.ndarray]:
        """The delta indices for the dataset. The keys are modality.key, and the values are the delta indices for each modality.key."""
        return self._delta_indices

    @property
    def dataset_name(self) -> str:
        """The name of the dataset."""
        return self._dataset_name

    @property
    def lerobot_modality_meta(self) -> LeRobotModalityMetadata:
        """The modality metadata for the LeRobot dataset."""
        return self._lerobot_modality_meta

    @property
    def lerobot_info_meta(self) -> dict:
        """The info metadata for the LeRobot dataset."""
        return self._lerobot_info_meta

    @property
    def data_path_pattern(self) -> str:
        """The data path pattern for the LeRobot dataset."""
        return self._data_path_pattern

    @property
    def video_path_pattern(self) -> str:
        """The video path pattern for the LeRobot dataset."""
        return self._video_path_pattern

    @property
    def chunk_size(self) -> int:
        """The chunk size for the LeRobot dataset."""
        return self._chunk_size

    @property
    def tasks(self) -> pd.DataFrame:
        """The tasks for the dataset."""
        return self._tasks

    def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata:
        """Get the metadata for the dataset.

        Returns:
            DatasetMetadata: The metadata for the dataset.
        """

        # 1. Modality metadata
        modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
        assert (
            modality_meta_path.exists()
        ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
        # 1.1. State and action modalities
        simplified_modality_meta: dict[str, dict] = {}
        with open(modality_meta_path, "r") as f:
            le_modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))
        for modality in ["state", "action"]:
            simplified_modality_meta[modality] = {}
            le_state_action_meta: dict[str, LeRobotStateActionMetadata] = getattr(
                le_modality_meta, modality
            )
            for subkey in le_state_action_meta:
                state_action_dtype = np.dtype(le_state_action_meta[subkey].dtype)
                if np.issubdtype(state_action_dtype, np.floating):
                    continuous = True
                else:
                    continuous = False
                simplified_modality_meta[modality][subkey] = {
                    "absolute": le_state_action_meta[subkey].absolute,
                    "rotation_type": le_state_action_meta[subkey].rotation_type,
                    "shape": [
                        le_state_action_meta[subkey].end - le_state_action_meta[subkey].start
                    ],
                    "continuous": continuous,
                }

        # 1.2. Video modalities
        le_info_path = self.dataset_path / LE_ROBOT_INFO_FILENAME
        assert (
            le_info_path.exists()
        ), f"Please provide a {LE_ROBOT_INFO_FILENAME} file in {self.dataset_path}"
        with open(le_info_path, "r") as f:
            le_info = json.load(f)
        simplified_modality_meta["video"] = {}
        for new_key in le_modality_meta.video:
            original_key = le_modality_meta.video[new_key].original_key
            if original_key is None:
                original_key = new_key
            le_video_meta = le_info["features"][original_key]
            height = le_video_meta["shape"][le_video_meta["names"].index("height")]
            width = le_video_meta["shape"][le_video_meta["names"].index("width")]
            # NOTE(FH): different LeRobot dataset versions use different keys for the number of channels and fps
            try:
                channels = le_video_meta["shape"][le_video_meta["names"].index("channel")]
                fps = le_video_meta["video_info"]["video.fps"]
            except (ValueError, KeyError):
                channels = le_video_meta["info"]["video.channels"]
                fps = le_video_meta["info"]["video.fps"]
            simplified_modality_meta["video"][new_key] = {
                "resolution": [width, height],
                "channels": channels,
                "fps": fps,
            }

        # 2. Dataset statistics
        stats_path = self.dataset_path / LE_ROBOT_STATS_FILENAME
        try:
            with open(stats_path, "r") as f:
                le_statistics = json.load(f)
            for stat in le_statistics.values():
                DatasetStatisticalValues.model_validate(stat)
        except (FileNotFoundError, ValidationError) as e:
            print(f"Failed to load dataset statistics: {e}")
            print(f"Calculating dataset statistics for {self.dataset_name}")
            # Get all parquet files in the dataset paths
            parquet_files = list((self.dataset_path).glob(LE_ROBOT_DATA_FILENAME))
            parquet_files_filtered = []
            # "episode_033675.parquet" is a known broken file; skip it
            for pf in parquet_files:
                if "episode_033675.parquet" in pf.name:
                    continue
                parquet_files_filtered.append(pf)

            le_statistics = calculate_dataset_statistics(parquet_files_filtered)
            with open(stats_path, "w") as f:
                json.dump(le_statistics, f, indent=4)
        dataset_statistics = {}
        for our_modality in ["state", "action"]:
            dataset_statistics[our_modality] = {}
            for subkey in simplified_modality_meta[our_modality]:
                dataset_statistics[our_modality][subkey] = {}
                state_action_meta = le_modality_meta.get_key_meta(f"{our_modality}.{subkey}")
                assert isinstance(state_action_meta, LeRobotStateActionMetadata)
                le_modality = state_action_meta.original_key
                for stat_name in le_statistics[le_modality]:
                    indices = np.arange(
                        state_action_meta.start,
                        state_action_meta.end,
                    )
                    stat = np.array(le_statistics[le_modality][stat_name])
                    dataset_statistics[our_modality][subkey][stat_name] = stat[indices].tolist()

        # 3. Full dataset metadata
        metadata = DatasetMetadata(
            statistics=dataset_statistics,  # type: ignore
            modalities=simplified_modality_meta,  # type: ignore
            embodiment_tag=embodiment_tag,
        )

        return metadata

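    # --- meta/modality.json shape (illustrative; exact fields follow
    # LeRobotStateActionMetadata, the key names here are hypothetical) ---
    # {
    #   "state":  {"eef_position": {"start": 0, "end": 3, "dtype": "float32",
    #              "absolute": true, "rotation_type": null,
    #              "original_key": "observation.state"}},
    #   "action": {"eef_position": {"start": 0, "end": 3, "dtype": "float32", ...}},
    #   "video":  {"image_side_0": {"original_key": "observation.images.image_side_0"}}
    # }
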
    def _get_trajectories(self) -> tuple[np.ndarray, np.ndarray]:
        """Get the trajectories in the dataset."""
        # Get trajectory lengths, IDs, and whitelist from the dataset metadata
        # v2.0
        if self._lerobot_version == "v2.0":
            file_path = self.dataset_path / LE_ROBOT_EPISODE_FILENAME
            with open(file_path, "r") as f:
                episode_metadata = [json.loads(line) for line in f]
            trajectory_ids = []
            trajectory_lengths = []
            for episode in episode_metadata:
                trajectory_ids.append(episode["episode_index"])
                trajectory_lengths.append(episode["length"])
            return np.array(trajectory_ids), np.array(trajectory_lengths)
        # v3.0
        elif self._lerobot_version == "v3.0":
            file_paths = list((self.dataset_path).glob(LE_ROBOT3_EPISODE_FILENAME))
            trajectory_ids = []
            trajectory_lengths = []
            self.trajectory_ids_to_metadata = {}
            for file_path in file_paths:
                episodes_data = pd.read_parquet(file_path)
                for index, episode in episodes_data.iterrows():
                    trajectory_ids.append(episode["episode_index"])
                    trajectory_lengths.append(episode["length"])

                    # TODO: auto-map keys? For now just map to the file path and file_from_index
                    episode_meta = {
                        "data/chunk_index": episode["data/chunk_index"],
                        "data/file_index": episode["data/file_index"],
                        "data/file_from_index": index,
                    }
                    if self.load_video:
                        episode_meta["videos/observation.images.wrist/from_timestamp"] = episode[
                            "videos/observation.images.wrist/from_timestamp"
                        ]
                    self.trajectory_ids_to_metadata[trajectory_ids[-1]] = episode_meta

            # The save-index information should be directly readable here
            return np.array(trajectory_ids), np.array(trajectory_lengths)

+ def _get_all_steps(self) -> list[tuple[int, int]]:
583
+ """Get the trajectory IDs and base indices for all steps in the dataset.
584
+
585
+ Returns:
586
+ list[tuple[str, int]]: A list of (trajectory_id, base_index) tuples.
587
+ """
588
+ # Create a hash key based on configuration to ensure cache validity
589
+ config_key = self._get_steps_config_key()
590
+
591
+ # Create a unique filename based on config_key
592
+ # steps_filename = f"steps_{config_key}.pkl"
593
+ # @BUG
594
+ # fast static-steps lookup (@fangjing): use a fixed filename instead of the config hash, so sampling stays static
595
+ steps_filename = "steps_data_index.pkl"
596
+
597
+
598
+ steps_path = self.dataset_path / "meta" / steps_filename
599
+
600
+ # Try to load cached steps first
601
+ try:
602
+ if steps_path.exists():
603
+ with open(steps_path, "rb") as f:
604
+ cached_data = pickle.load(f)
605
+ return cached_data["steps"]
606
+
607
+ except (FileNotFoundError, pickle.PickleError, KeyError) as e:
608
+ print(f"Failed to load cached steps: {e}")
609
+ print("Computing steps from scratch...")
610
+
611
+ # Compute steps using single process
612
+ all_steps = self._get_all_steps_single_process()
613
+
614
+ # Cache the computed steps with unique filename
615
+ try:
616
+ cache_data = {
617
+ "config_key": config_key,
618
+ "steps": all_steps,
619
+ "num_trajectories": len(self.trajectory_ids),
620
+ "total_steps": len(all_steps),
621
+ "computed_timestamp": pd.Timestamp.now().isoformat(),
622
+ "delete_pause_frame": self.delete_pause_frame,
623
+ }
624
+
625
+ # Ensure the meta directory exists
626
+ steps_path.parent.mkdir(parents=True, exist_ok=True)
627
+
628
+ with open(steps_path, "wb") as f:
629
+ pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
630
+ print(f"Cached steps saved to {steps_path}")
631
+ except Exception as e:
632
+ print(f"Failed to cache steps: {e}")
633
+
634
+ return all_steps
635
+
636
+ def _get_steps_config_key(self) -> str:
637
+ """Generate a configuration key for steps caching."""
638
+ config_dict = {
639
+ "delete_pause_frame": self.delete_pause_frame,
640
+ "dataset_name": self.dataset_name,
641
+ }
642
+ # Create a hash of the configuration
643
+ config_str = str(sorted(config_dict.items()))
644
+ return hashlib.md5(config_str.encode()).hexdigest()[:12]
645
+
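+ # Cache-key sketch: the key is the first 12 hex characters of the MD5 digest
+ # of the sorted config items, so any change to delete_pause_frame or
+ # dataset_name yields a different key, while the same (hypothetical) config
+ # {"delete_pause_frame": False, "dataset_name": "demo"} always maps to the
+ # same 12-character hex string across runs and processes.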
646
+
647
+ def _get_all_steps_single_process(self) -> list[tuple[int, int]]:
648
+ """Original single-process implementation as fallback."""
649
+ all_steps: list[tuple[int, int]] = []
650
+ skipped_trajectories = 0
651
+ processed_trajectories = 0
652
+
653
+ # Check if language modality is configured
654
+ has_language_modality = 'language' in self.modality_keys and len(self.modality_keys['language']) > 0
655
+ # TODO: why use trajectory_length here instead of the actual data length?
656
+ for trajectory_id, trajectory_length in tqdm(zip(self.trajectory_ids, self.trajectory_lengths), total=len(self.trajectory_ids), desc="Getting all steps"):
657
+ try:
658
+ if self._lerobot_version == "v2.0":
659
+ data = self.get_trajectory_data(trajectory_id)
660
+ elif self._lerobot_version == "v3.0":
661
+ data = self.get_trajectory_data_lerobot_v3(trajectory_id)
662
+
663
+ trajectory_skipped = False
664
+
665
+ # Check if trajectory has valid language instruction (if language modality is configured)
666
+ if has_language_modality:
667
+ self.curr_traj_data = data # Set current trajectory data for get_language to work
668
+
669
+ language_instruction = self.get_language(trajectory_id, self.modality_keys['language'][0], 0)
670
+ if not language_instruction or language_instruction[0] == "":
671
+ print(f"Skipping trajectory {trajectory_id} due to empty language instruction")
672
+ skipped_trajectories += 1
673
+ trajectory_skipped = True
674
+ continue
675
+
676
+ except Exception as e:
677
+ print(f"Skipping trajectory {trajectory_id} due to read error: {e}")
678
+ skipped_trajectories += 1
679
+ trajectory_skipped = True
680
+ continue
681
+
682
+ if not trajectory_skipped:
683
+ processed_trajectories += 1
684
+
685
+ for base_index in range(trajectory_length):
686
+ all_steps.append((trajectory_id, base_index))
687
+
688
+ # Print summary statistics
689
+ print(f"Single-process summary: Processed {processed_trajectories} trajectories, skipped {skipped_trajectories} empty trajectories")
690
+ print(f"Total steps: {len(all_steps)} from {len(self.trajectory_ids)} trajectories")
691
+
692
+ return all_steps
693
+
694
+ def _get_position_and_gripper_values(self, data: pd.DataFrame) -> tuple[list, list]:
695
+ """Get position and gripper values based on available columns in the dataset."""
696
+ # Get action keys from modality_keys
697
+ action_keys = self.modality_keys.get('action', [])
698
+
699
+ # Extract position data
700
+ delta_position_values = None
701
+ position_candidates = ['delta_eef_position']
702
+ coordinate_candidates = ['x', 'y', 'z']
703
+
704
+ # First try combined position fields
705
+ for pos_key in position_candidates:
706
+ full_key = f"action.{pos_key}"
707
+ if full_key in action_keys:
708
+ try:
709
+ # Get the lerobot key for this modality
710
+ le_action_cfg = self.lerobot_modality_meta.action
711
+ subkey = pos_key
712
+ if subkey in le_action_cfg:
713
+ le_key = le_action_cfg[subkey].original_key or subkey
714
+ if le_key in data.columns:
715
+ data_array = np.stack(data[le_key])
716
+ le_indices = np.arange(le_action_cfg[subkey].start, le_action_cfg[subkey].end)
717
+ filtered_data = data_array[:, le_indices]
718
+ delta_position_values = filtered_data.tolist()
719
+ break
720
+ except Exception:
721
+ continue
722
+
723
+ # If combined fields not found, try individual x,y,z coordinates
724
+ if delta_position_values is None:
725
+ x_data, y_data, z_data = None, None, None
726
+ for coord in coordinate_candidates:
727
+ full_key = f"action.{coord}"
728
+ if full_key in action_keys:
729
+ try:
730
+ le_action_cfg = self.lerobot_modality_meta.action
731
+ if coord in le_action_cfg:
732
+ le_key = le_action_cfg[coord].original_key or coord
733
+ if le_key in data.columns:
734
+ data_array = np.stack(data[le_key])
735
+ le_indices = np.arange(le_action_cfg[coord].start, le_action_cfg[coord].end)
736
+ coord_data = data_array[:, le_indices].flatten()
737
+ if coord == 'x':
738
+ x_data = coord_data
739
+ elif coord == 'y':
740
+ y_data = coord_data
741
+ elif coord == 'z':
742
+ z_data = coord_data
743
+ except Exception:
744
+ continue
745
+
746
+ if x_data is not None and y_data is not None and z_data is not None:
747
+ delta_position_values = np.column_stack((x_data, y_data, z_data)).tolist()
748
+
749
+ if delta_position_values is None:
750
+ # Fallback to the old hardcoded approach if metadata approach fails
751
+ if 'action.delta_eef_position' in data.columns:
752
+ delta_position_values = data['action.delta_eef_position'].to_numpy().tolist()
753
+ elif all(col in data.columns for col in ['action.x', 'action.y', 'action.z']):
754
+ x_vals = data['action.x'].to_numpy()
755
+ y_vals = data['action.y'].to_numpy()
756
+ z_vals = data['action.z'].to_numpy()
757
+ delta_position_values = np.column_stack((x_vals, y_vals, z_vals)).tolist()
758
+ else:
759
+ raise ValueError(f"No suitable position columns found. Available columns: {data.columns.tolist()}")
760
+
761
+ # Extract gripper data
762
+ gripper_values = None
763
+ gripper_candidates = ['gripper_close', 'gripper']
764
+
765
+ for grip_key in gripper_candidates:
766
+ full_key = f"action.{grip_key}"
767
+ if full_key in action_keys:
768
+ try:
769
+ le_action_cfg = self.lerobot_modality_meta.action
770
+ if grip_key in le_action_cfg:
771
+ le_key = le_action_cfg[grip_key].original_key or grip_key
772
+ if le_key in data.columns:
773
+ data_array = np.stack(data[le_key])
774
+ le_indices = np.arange(le_action_cfg[grip_key].start, le_action_cfg[grip_key].end)
775
+ gripper_data = data_array[:, le_indices].flatten()
776
+ gripper_values = gripper_data.tolist()
777
+ break
778
+ except Exception:
779
+ continue
780
+
781
+ if gripper_values is None:
782
+ # Fallback to the old hardcoded approach if metadata approach fails
783
+ if 'action.gripper_close' in data.columns:
784
+ gripper_values = data['action.gripper_close'].to_numpy().tolist()
785
+ elif 'action.gripper' in data.columns:
786
+ gripper_values = data['action.gripper'].to_numpy().tolist()
787
+ else:
788
+ raise ValueError(f"No suitable gripper columns found. Available columns: {data.columns.tolist()}")
789
+
790
+ return delta_position_values, gripper_values
791
+
792
+ def _get_modality_keys(self) -> dict:
793
+ """Get the modality keys for the dataset.
794
+ The keys are the modality names, and the values are the keys for each modality.
795
+ See property `modality_keys` for the expected format.
796
+ """
797
+ modality_keys = defaultdict(list)
798
+ for modality, config in self.modality_configs.items():
799
+ modality_keys[modality] = config.modality_keys
800
+ return modality_keys
801
+
802
+ def _get_delta_indices(self) -> dict[str, np.ndarray]:
803
+ """Restructure the delta indices to use modality.key as keys instead of just the modalities."""
804
+ delta_indices: dict[str, np.ndarray] = {}
805
+ for config in self.modality_configs.values():
806
+ for key in config.modality_keys:
807
+ delta_indices[key] = np.array(config.delta_indices)
808
+ return delta_indices
809
+
810
+ def _get_lerobot_modality_meta(self) -> LeRobotModalityMetadata:
811
+ """Get the metadata for the LeRobot dataset."""
812
+ modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
813
+ assert (
814
+ modality_meta_path.exists()
815
+ ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
816
+ with open(modality_meta_path, "r") as f:
817
+ modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))
818
+ return modality_meta
819
+
820
+ def _get_lerobot_info_meta(self) -> dict:
821
+ """Get the metadata for the LeRobot dataset."""
822
+ info_meta_path = self.dataset_path / LE_ROBOT_INFO_FILENAME
823
+ with open(info_meta_path, "r") as f:
824
+ info_meta = json.load(f)
825
+ return info_meta
826
+
827
+ def _get_data_path_pattern(self) -> str:
828
+ """Get the data path pattern for the LeRobot dataset."""
829
+ return self.lerobot_info_meta["data_path"]
830
+
831
+ def _get_video_path_pattern(self) -> str:
832
+ """Get the video path pattern for the LeRobot dataset."""
833
+ return self.lerobot_info_meta["video_path"]
834
+
835
+ def _get_chunk_size(self) -> int:
836
+ """Get the chunk size for the LeRobot dataset."""
837
+ return self.lerobot_info_meta["chunks_size"]
838
+
839
+ def _get_tasks(self) -> pd.DataFrame:
840
+ """Get the tasks for the dataset."""
841
+ if self._lerobot_version == "v2.0":
842
+ tasks_path = self.dataset_path / LE_ROBOT_TASKS_FILENAME
843
+ with open(tasks_path, "r") as f:
844
+ tasks = [json.loads(line) for line in f]
845
+ df = pd.DataFrame(tasks)
846
+ return df.set_index("task_index")
847
+
848
+ elif self._lerobot_version == "v3.0":
849
+ tasks_path = self.dataset_path / LE_ROBOT3_TASKS_FILENAME
850
+ df = pd.read_parquet(tasks_path)
851
+ df = df.reset_index() # turn the index into a column, typically named 'index'
852
+ df = df.rename(columns={'index': 'task'}) # rename the 'index' column to 'task'
853
+ df = df[['task_index', 'task']] # reorder the columns
854
+ return df
855
+ def _check_integrity(self):
856
+ """Use the config to check if the keys are valid and detect silent data corruption."""
857
+ ERROR_MSG_HEADER = f"Error occurred in initializing dataset {self.dataset_name}:\n"
858
+
859
+ for modality_config in self.modality_configs.values():
860
+ for key in modality_config.modality_keys:
861
+ if key == "lapa_action" or key == "dream_actions":
862
+ continue # no need for any metadata for lapa actions because it comes normalized
863
+ # Check if the key is valid
864
+ try:
865
+ self.lerobot_modality_meta.get_key_meta(key)
866
+ except Exception as e:
867
+ raise ValueError(
868
+ ERROR_MSG_HEADER + f"Unable to find key {key} in modality metadata:\n{e}"
869
+ )
870
+
871
+ def set_transforms_metadata(self, metadata: DatasetMetadata):
872
+ """Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values."""
873
+ self.transforms.set_metadata(metadata)
874
+
875
+ def set_epoch(self, epoch: int):
876
+ """Set the epoch for the dataset.
877
+
878
+ Args:
879
+ epoch (int): The epoch to set.
880
+ """
881
+ self.epoch = epoch
882
+
883
+ def __len__(self) -> int:
884
+ """Get the total number of data points in the dataset.
885
+
886
+ Returns:
887
+ int: the total number of data points in the dataset.
888
+ """
889
+ return len(self.all_steps)
890
+
891
+ def __str__(self) -> str:
892
+ """Get the description of the dataset."""
893
+ return f"{self.dataset_name} ({len(self)} steps)"
894
+
895
+
896
+ def __getitem__(self, index: int) -> dict:
897
+ """Get the data for a single step in a trajectory.
898
+
899
+ Args:
900
+ index (int): The index of the step to get.
901
+
902
+ Returns:
903
+ dict: The data for the step.
904
+ """
905
+ trajectory_id, base_index = self.all_steps[index]
906
+ data = self.get_step_data(trajectory_id, base_index)
907
+
908
+ # Process all video keys dynamically
909
+ images = []
910
+ for video_key in self.modality_keys.get("video", []):
911
+ image = data[video_key][0]
912
+
913
+ image = Image.fromarray(image).resize((224, 224))
914
+ images.append(image)
915
+
916
+ # Get language and action data
917
+ language = data[self.modality_keys["language"][0]][0]
918
+ action = []
919
+ for action_key in self.modality_keys["action"]:
920
+ action.append(data[action_key])
921
+ action = np.concatenate(action, axis=1)
922
+ action = standardize_action_representation(action, self.tag)
923
+
924
+ state = []
925
+ for state_key in self.modality_keys["state"]:
926
+ state.append(data[state_key])
927
+ state = np.concatenate(state, axis=1)
928
+ state = standardize_state_representation(state, self.tag)
929
+
930
+ return dict(action=action, state=state, image=images, language=language, dataset_id=self._dataset_id)
931
+
932
+ def get_step_data(self, trajectory_id: int, base_index: int) -> dict:
933
+ """Get the RAW data for a single step in a trajectory. No transforms are applied.
934
+
935
+ Args:
936
+ trajectory_id (int): The name of the trajectory.
937
+ base_index (int): The base step index in the trajectory.
938
+
939
+ Returns:
940
+ dict: The RAW data for the step.
941
+
942
+ Example return:
943
+ {
944
+ "video": {
945
+ "video.image_side_0": [B, T, H, W, C],
946
+ "video.image_side_1": [B, T, H, W, C],
947
+ },
948
+ "state": {
949
+ "state.eef_position": [B, T, state_dim],
950
+ "state.eef_rotation": [B, T, state_dim],
951
+ },
952
+ "action": {
953
+ "action.eef_position": [B, T, action_dim],
954
+ "action.eef_rotation": [B, T, action_dim],
955
+ },
956
+ }
957
+ """
958
+ data = {}
959
+ # Get the data for all modalities # just for action base data
960
+ self.curr_traj_data = self.get_trajectory_data(trajectory_id)
961
+ # TODO @JinhuiYE The logic below is poorly implemented. Data reading should be directly based on curr_traj_data.
962
+ for modality in self.modality_keys:
963
+ # Get the data corresponding to each key in the modality
964
+ for key in self.modality_keys[modality]:
965
+ data[key] = self.get_data_by_modality(trajectory_id, modality, key, base_index)
966
+ return data
967
+
968
+ def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame:
969
+ """Get the data for a trajectory."""
970
+ if self._lerobot_version == "v2.0":
971
+
972
+ if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
973
+ return self.curr_traj_data
974
+ else:
975
+ chunk_index = self.get_episode_chunk(trajectory_id)
976
+ parquet_path = self.dataset_path / self.data_path_pattern.format(
977
+ episode_chunk=chunk_index, episode_index=trajectory_id
978
+ )
979
+ assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
980
+ return pd.read_parquet(parquet_path)
981
+ elif self._lerobot_version == "v3.0":
982
+ return self.get_trajectory_data_lerobot_v3(trajectory_id)
983
+
984
+ def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame:
985
+ """Get the data for a trajectory from lerobot v3."""
986
+ if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
987
+ return self.curr_traj_data
988
+ else: # TODO: check details later
989
+ chunk_index = self.get_episode_chunk(trajectory_id)
990
+
991
+ file_index = self.get_episode_file_index(trajectory_id)
992
+ # file_from_index = self.get_episode_file_from_index(trajectory_id)
993
+
994
+
995
+ parquet_path = self.dataset_path / self.data_path_pattern.format(
996
+ chunk_index=chunk_index, file_index=file_index
997
+ )
998
+ assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
999
+ file_data = pd.read_parquet(parquet_path)
1000
+
1001
+ # filter by trajectory_id
1002
+ episode_data = file_data.loc[file_data["episode_index"] == trajectory_id].copy()
1003
+
1004
+ # fix timestamp from epis index to file index for video alignment
1005
+ if self.load_video:
1006
+ from_timestamp = self.trajectory_ids_to_metadata[trajectory_id].get(
1007
+ "videos/observation.images.wrist/from_timestamp", 0
1008
+ )
1009
+ episode_data["timestamp"] = episode_data["timestamp"] + from_timestamp
1010
+
1011
+ return episode_data
1012
+
1013
+
1014
+ def get_trajectory_index(self, trajectory_id: int) -> int:
1015
+ """Get the index of the trajectory in the dataset by the trajectory ID.
1016
+ This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID.
1017
+
1018
+ Args:
1019
+ trajectory_id (str): The ID of the trajectory.
1020
+
1021
+ Returns:
1022
+ int: The index of the trajectory in the dataset.
1023
+ """
1024
+ trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0]
1025
+ if len(trajectory_indices) != 1:
1026
+ raise ValueError(
1027
+ f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}"
1028
+ )
1029
+ return trajectory_indices[0]
1030
+
1031
+ def get_episode_chunk(self, ep_index: int) -> int:
1032
+ """Get the chunk index for an episode index."""
1033
+ return ep_index // self.chunk_size
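+ # e.g. with chunks_size = 1000 (a common LeRobot default; the actual value
+ # comes from info.json), episode 1234 falls in chunk 1234 // 1000 == 1 and
+ # episode 999 in chunk 0.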
1034
+ def get_episode_file_index(self, ep_index: int) -> int:
1035
+ """Get the file index for an episode index."""
1036
+ episode_meta = self.trajectory_ids_to_metadata[ep_index]
1037
+ return episode_meta["data/file_index"]
1038
+
1039
+ def get_episode_file_from_index(self, ep_index: int) -> int:
1040
+ """Get the file from index for an episode index."""
1041
+ episode_meta = self.trajectory_ids_to_metadata[ep_index]
1042
+ return episode_meta["data/file_from_index"]
1043
+
1044
+
1045
+ def retrieve_data_and_pad(
1046
+ self,
1047
+ array: np.ndarray,
1048
+ step_indices: np.ndarray,
1049
+ max_length: int,
1050
+ padding_strategy: str = "first_last",
1051
+ ) -> np.ndarray:
1052
+ """Retrieve the data from the dataset and pad it if necessary.
1053
+ Args:
1054
+ array (np.ndarray): The array to retrieve the data from.
1055
+ step_indices (np.ndarray): The step indices to retrieve the data for.
1056
+ max_length (int): The maximum length of the data.
1057
+ padding_strategy (str): The padding strategy, either "first_last" or "zero".
1058
+ """
1059
+ # Get the padding indices
1060
+ front_padding_indices = step_indices < 0
1061
+ end_padding_indices = step_indices >= max_length
1062
+ padding_positions = np.logical_or(front_padding_indices, end_padding_indices)
1063
+ # Retrieve the data with the non-padding indices
1064
+ # If there is padding, then given T step_indices the retrieved data has shape (T', ...) with T' < T
1065
+ raw_data = array[step_indices[~padding_positions]]
1066
+ assert isinstance(raw_data, np.ndarray), f"{type(raw_data)=}"
1067
+ # This is the shape of the output, (T, ...)
1068
+ if raw_data.ndim == 1:
1069
+ expected_shape = (len(step_indices),)
1070
+ else:
1071
+ expected_shape = (len(step_indices), *array.shape[1:])
1072
+
1073
+ # Pad the data
1074
+ output = np.zeros(expected_shape)
1075
+ # Assign the non-padded data
1076
+ output[~padding_positions] = raw_data
1077
+ # If there exists some padding, pad the data
1078
+ if padding_positions.any():
1079
+ if padding_strategy == "first_last":
1080
+ # Use first / last step data to pad
1081
+ front_padding_data = array[0]
1082
+ end_padding_data = array[-1]
1083
+ output[front_padding_indices] = front_padding_data
1084
+ output[end_padding_indices] = end_padding_data
1085
+ elif padding_strategy == "zero":
1086
+ # Use zero padding
1087
+ output[padding_positions] = 0
1088
+ else:
1089
+ raise ValueError(f"Invalid padding strategy: {padding_strategy}")
1090
+ return output
1091
+
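+ # Worked example (hypothetical 1-D input): with array = np.arange(1, 6),
+ # step_indices = np.array([-2, -1, 0, 1]) and max_length = 5, the two front
+ # indices are out of range, so "first_last" pads them with array[0] and
+ # returns [1, 1, 1, 2], while "zero" returns [0, 0, 1, 2].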
1092
+ def get_video_path(self, trajectory_id: int, key: str) -> Path:
1093
+ chunk_index = self.get_episode_chunk(trajectory_id)
1094
+ original_key = self.lerobot_modality_meta.video[key].original_key
1095
+ if original_key is None:
1096
+ original_key = key
1097
+ if self._lerobot_version == "v2.0":
1098
+ video_filename = self.video_path_pattern.format(
1099
+ episode_chunk=chunk_index, episode_index=trajectory_id, video_key=original_key
1100
+ )
1101
+ elif self._lerobot_version == "v3.0":
1102
+ episode_meta = self.trajectory_ids_to_metadata[trajectory_id]
1103
+ video_filename = self.video_path_pattern.format(
1104
+ video_key=original_key,
1105
+ chunk_index=episode_meta["data/chunk_index"],
1106
+ file_index=episode_meta["data/file_index"],
1107
+ )
1108
+ return self.dataset_path / video_filename
1109
+
1110
+ def get_video(
1111
+ self,
1112
+ trajectory_id: int,
1113
+ key: str,
1114
+ base_index: int,
1115
+ ) -> np.ndarray:
1116
+ """Get the video frames for a trajectory by a base index.
1117
+
1118
+ Args:
1119
+ dataset (BaseSingleDataset): The dataset to retrieve the data from.
1120
+ trajectory_id (str): The ID of the trajectory.
1121
+ key (str): The key of the video.
1122
+ base_index (int): The base index of the trajectory.
1123
+
1124
+ Returns:
1125
+ np.ndarray: The video frames for the trajectory and frame indices. Shape: (T, H, W, C)
1126
+ """
1127
+ # Get the step indices
1128
+ step_indices = self.delta_indices[key] + base_index
1129
+ # print(f"{step_indices=}")
1130
+ # Get the trajectory index
1131
+ trajectory_index = self.get_trajectory_index(trajectory_id)
1132
+ # Ensure the indices are within the valid range
1133
+ # This is equivalent to padding the video with extra frames at the beginning and end
1134
+ step_indices = np.maximum(step_indices, 0)
1135
+ step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1)
1136
+ assert key.startswith("video."), f"Video key must start with 'video.', got {key}"
1137
+ # Get the sub-key
1138
+ key = key.replace("video.", "")
1139
+ video_path = self.get_video_path(trajectory_id, key)
1140
+ # Get the action/state timestamps for each frame in the video
1141
+ assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
1142
+ assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}"
1143
+ timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy()
1144
+ # Get the corresponding video timestamps from the step indices
1145
+ video_timestamp = timestamp[step_indices]
1146
+
1147
+ return get_frames_by_timestamps(
1148
+ video_path.as_posix(),
1149
+ video_timestamp,
1150
+ video_backend=self.video_backend, # TODO
1151
+ video_backend_kwargs=self.video_backend_kwargs,
1152
+ )
1153
+
1154
+ def get_state_or_action(
1155
+ self,
1156
+ trajectory_id: int,
1157
+ modality: str,
1158
+ key: str,
1159
+ base_index: int,
1160
+ ) -> np.ndarray:
1161
+ """Get the state or action data for a trajectory by a base index.
1162
+ If the step indices are out of range, pad with the data:
1163
+ if the data is stored in absolute format, pad with the first or last step data;
1164
+ otherwise, pad with zero.
1165
+
1166
+ Args:
1167
+ dataset (BaseSingleDataset): The dataset to retrieve the data from.
1168
+ trajectory_id (int): The ID of the trajectory.
1169
+ modality (str): The modality of the data.
1170
+ key (str): The key of the data.
1171
+ base_index (int): The base index of the trajectory.
1172
+
1173
+ Returns:
1174
+ np.ndarray: The data for the trajectory and step indices.
1175
+ """
1176
+ # Get the step indices
1177
+ step_indices = self.delta_indices[key] + base_index
1178
+ # Get the trajectory index
1179
+ trajectory_index = self.get_trajectory_index(trajectory_id)
1180
+ # Get the maximum length of the trajectory
1181
+ max_length = self.trajectory_lengths[trajectory_index]
1182
+ assert key.startswith(modality + "."), f"Key must start with '{modality}.', got {key}"
1183
+ # Get the sub-key, e.g. state.joint_angles -> joint_angles
1184
+ key = key.replace(modality + ".", "")
1185
+ # Get the lerobot key
1186
+ le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality)
1187
+ le_key = le_state_or_action_cfg[key].original_key
1188
+ if le_key is None:
1189
+ le_key = key
1190
+ # Get the data array, shape: (T, D)
1191
+ assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
1192
+ assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}"
1193
+ data_array: np.ndarray = np.stack(self.curr_traj_data[le_key]) # type: ignore
1194
+ assert data_array.ndim == 2, f"Expected 2D array for key {le_key}, got shape {data_array.shape}"
1195
+ le_indices = np.arange(
1196
+ le_state_or_action_cfg[key].start,
1197
+ le_state_or_action_cfg[key].end,
1198
+ )
1199
+ data_array = data_array[:, le_indices]
1200
+ # Get the state or action configuration
1201
+ state_or_action_cfg = getattr(self.metadata.modalities, modality)[key]
1202
+
1203
+ # Pad the data
1204
+ return self.retrieve_data_and_pad(
1205
+ array=data_array,
1206
+ step_indices=step_indices,
1207
+ max_length=max_length,
1208
+ padding_strategy="first_last" if state_or_action_cfg.absolute else "zero",
1209
+ # padding_strategy="zero", # HACK for realdata
1210
+ )
1211
+
1212
+ def get_language(
1213
+ self,
1214
+ trajectory_id: int,
1215
+ key: str,
1216
+ base_index: int,
1217
+ ) -> list[str]:
1218
+ """Get the language annotation data for a trajectory by step indices.
1219
+
1220
+ Args:
1221
+ dataset (BaseSingleDataset): The dataset to retrieve the data from.
1222
+ trajectory_id (int): The ID of the trajectory.
1223
+ key (str): The key of the annotation.
1224
+ base_index (int): The base index of the trajectory.
1225
+
1226
+ Returns:
1227
+ list[str]: The annotation data for the trajectory and step indices. If no matching data is found, return empty strings.
1228
+ """
1229
+ assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
1230
+ # Get the step indices
1231
+ step_indices = self.delta_indices[key] + base_index
1232
+ # Get the trajectory index
1233
+ trajectory_index = self.get_trajectory_index(trajectory_id)
1234
+ # Get the maximum length of the trajectory
1235
+ max_length = self.trajectory_lengths[trajectory_index]
1236
+ # Get the end times corresponding to the closest indices
1237
+ step_indices = np.maximum(step_indices, 0)
1238
+ step_indices = np.minimum(step_indices, max_length - 1)
1239
+ # Get the annotations
1240
+ task_indices: list[int] = []
1241
+ assert key.startswith(
1242
+ "annotation."
1243
+ ), f"Language key must start with 'annotation.', got {key}"
1244
+ subkey = key.replace("annotation.", "")
1245
+ annotation_meta = self.lerobot_modality_meta.annotation
1246
+ assert annotation_meta is not None, f"Annotation metadata is None for {subkey}"
1247
+ assert (
1248
+ subkey in annotation_meta
1249
+ ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}"
1250
+ subkey_meta = annotation_meta[subkey]
1251
+ original_key = subkey_meta.original_key
1252
+ if original_key is None:
1253
+ original_key = key
1254
+ for i in range(len(step_indices)): #
1255
+ # task_indices.append(self.curr_traj_data[original_key][step_indices[i]].item())
1256
+ value = self.curr_traj_data[original_key].iloc[step_indices[i]] # TODO check v2.0
1257
+ task_indices.append(value if isinstance(value, (int, float)) else value.item())
1258
+
1259
+ return self.tasks.loc[task_indices]["task"].tolist()
1260
+
1261
+ def get_data_by_modality(
1262
+ self,
1263
+ trajectory_id: int,
1264
+ modality: str,
1265
+ key: str,
1266
+ base_index: int,
1267
+ ):
1268
+ """Get the data corresponding to the modality for a trajectory by a base index.
1269
+ This method will call the corresponding helper method based on the modality.
1270
+ See the helper methods for more details.
1271
+ NOTE: For the language modality, the data is padded with empty strings if no matching data is found.
1272
+
1273
+ Args:
1274
+ dataset (BaseSingleDataset): The dataset to retrieve the data from.
1275
+ trajectory_id (int): The ID of the trajectory.
1276
+ modality (str): The modality of the data.
1277
+ key (str): The key of the data.
1278
+ base_index (int): The base index of the trajectory.
1279
+ """
1280
+ if modality == "video":
1281
+ return self.get_video(trajectory_id, key, base_index)
1282
+ elif modality == "state" or modality == "action":
1283
+ return self.get_state_or_action(trajectory_id, modality, key, base_index)
1284
+ elif modality == "language":
1285
+ return self.get_language(trajectory_id, key, base_index)
1286
+ else:
1287
+ raise ValueError(f"Invalid modality: {modality}")
1288
+
1289
+ def _save_dataset_statistics_(self, save_path: Path | str, format: str = "json") -> None:
1290
+ """
1291
+ Save dataset statistics to specified path in the required format.
1292
+ Only includes statistics for keys that are actually used in the dataset.
1293
+ Key order follows modality config order.
1294
+
1295
+ Args:
1296
+ save_path (Path | str): Path to save the statistics file
1297
+ format (str): Save format, currently only supports "json"
1298
+ """
1299
+ save_path = Path(save_path)
1300
+ save_path.parent.mkdir(parents=True, exist_ok=True)
1301
+
1302
+ # Build the data structure to save
1303
+ statistics_data = {}
1304
+
1305
+ # Get used modality keys
1306
+ used_action_keys, used_state_keys = get_used_modality_keys(self.modality_keys)
1307
+
1308
+ # Organize statistics by tag
1309
+ tag = self.tag
1310
+ tag_stats = {}
1311
+
1312
+ # Process action statistics (only for used keys, config order)
1313
+ if hasattr(self.metadata.statistics, 'action') and self.metadata.statistics.action:
1314
+ action_stats = self.metadata.statistics.action
1315
+ filtered_action_stats = {
1316
+ key: action_stats[key]
1317
+ for key in used_action_keys
1318
+ if key in action_stats
1319
+ }
1320
+
1321
+ if filtered_action_stats:
1322
+ # Combine statistics from filtered action sub-keys
1323
+ combined_action_stats = combine_modality_stats(filtered_action_stats)
1324
+
1325
+ # Add mask field based on whether it's gripper or not
1326
+ mask = generate_action_mask_for_used_keys(
1327
+ self.metadata.modalities.action, filtered_action_stats.keys()
1328
+ )
1329
+ combined_action_stats["mask"] = mask
1330
+
1331
+ tag_stats["action"] = combined_action_stats
1332
+
1333
+ # Process state statistics (only for used keys, config order)
1334
+ if hasattr(self.metadata.statistics, 'state') and self.metadata.statistics.state:
1335
+ state_stats = self.metadata.statistics.state
1336
+ filtered_state_stats = {
1337
+ key: state_stats[key]
1338
+ for key in used_state_keys
1339
+ if key in state_stats
1340
+ }
1341
+
1342
+ if filtered_state_stats:
1343
+ combined_state_stats = combine_modality_stats(filtered_state_stats)
1344
+ tag_stats["state"] = combined_state_stats
1345
+
1346
+ # Add dataset counts
1347
+ tag_stats["num_transitions"] = len(self)
1348
+ tag_stats["num_trajectories"] = len(self.trajectory_ids)
1349
+
1350
+ statistics_data[tag] = tag_stats
1351
+
1352
+ # Save as JSON file
1353
+ if format.lower() == "json":
1354
+ if not str(save_path).endswith('.json'):
1355
+ save_path = save_path.with_suffix('.json')
1356
+ with open(save_path, 'w', encoding='utf-8') as f:
1357
+ json.dump(statistics_data, f, indent=2, ensure_ascii=False)
1358
+ else:
1359
+ raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")
1360
+
1361
+ print(f"Single dataset statistics saved to: {save_path}")
1362
+ print(f"Used action keys (reordered): {list(used_action_keys)}")
1363
+ print(f"Used state keys (reordered): {list(used_state_keys)}")
1364
+
1365
+
1366
+
1367
+ class MixtureSpecElement(BaseModel):
1368
+ dataset_path: list[Path] | Path = Field(..., description="The path to the dataset.")
1369
+ dataset_weight: float = Field(..., description="The weight of the dataset in the mixture.")
1370
+ distribute_weights: bool = Field(
1371
+ default=False,
1372
+ description="Whether to distribute the weights of the dataset across all the paths. If True, the weights will be evenly distributed across all the paths.",
1373
+ )
1374
+
1375
+
1376
+ # Helper functions for dataset statistics
1377
+
1378
+ def combine_modality_stats(modality_stats: dict) -> dict:
1379
+ """
1380
+ Combine statistics from all sub-keys under a modality.
1381
+
1382
+ Args:
1383
+ modality_stats (dict): Statistics for a modality, containing multiple sub-keys.
1384
+ Each sub-key contains DatasetStatisticalValues object.
1385
+
1386
+ Returns:
1387
+ dict: Combined statistics
1388
+ """
1389
+ combined_stats = {
1390
+ "mean": [],
1391
+ "std": [],
1392
+ "max": [],
1393
+ "min": [],
1394
+ "q01": [],
1395
+ "q99": []
1396
+ }
1397
+
1398
+ # Combine statistics in sub-key order
1399
+ for subkey in modality_stats.keys():
1400
+ subkey_stats = modality_stats[subkey] # This is a DatasetStatisticalValues object
1401
+
1402
+ # Convert DatasetStatisticalValues to dict-like access
1403
+ for stat_name in ["mean", "std", "max", "min", "q01", "q99"]:
1404
+ stat_value = getattr(subkey_stats, stat_name)
1405
+ if isinstance(stat_value, (list, tuple)):
1406
+ combined_stats[stat_name].extend(stat_value)
1407
+ else:
1408
+ # Handle NDArray case - convert to list
1409
+ if hasattr(stat_value, 'tolist'):
1410
+ combined_stats[stat_name].extend(stat_value.tolist())
1411
+ else:
1412
+ combined_stats[stat_name].append(float(stat_value))
1413
+
1414
+ return combined_stats
1415
+
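+ # Concatenation sketch (hypothetical stats): two sub-keys whose means are
+ # [0.1, 0.2, 0.3] and [0.9] combine, in dict order, into
+ # {"mean": [0.1, 0.2, 0.3, 0.9], ...}, giving one flat vector per statistic
+ # whose layout matches the sub-key order of the modality config.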
1416
+ def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]:
1417
+ """
1418
+ Generate mask based on action modalities, but only for used keys.
1419
+ Gripper dimensions are masked out (False); all other dimensions are True and are de/normalized.
1420
+
1421
+ Args:
1422
+ action_modalities (dict): Configuration information for action modalities.
1423
+ used_action_keys_ordered: Iterable of actually used action keys in the correct order.
1424
+
1425
+ Returns:
1426
+ list[bool]: List of mask values
1427
+ """
1428
+ mask = []
1429
+
1430
+ # Generate mask in the same order as the statistics were combined
1431
+ for subkey in used_action_keys_ordered:
1432
+ if subkey in action_modalities:
1433
+ subkey_config = action_modalities[subkey]
1434
+
1435
+ # Get dimension count from shape
1436
+ if hasattr(subkey_config, 'shape') and len(subkey_config.shape) > 0:
1437
+ dim_count = subkey_config.shape[0]
1438
+ else:
1439
+ dim_count = 1
1440
+
1441
+ # Check if it's gripper-related
1442
+ is_gripper = "gripper" in subkey.lower()
1443
+
1444
+ # Generate mask value for each dimension
1445
+ for _ in range(dim_count):
1446
+ mask.append(not is_gripper) # gripper is False, others are True
1447
+
1448
+ return mask
1449
+
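+ # Mask sketch (hypothetical modalities): with "delta_eef_position" of shape
+ # (3,) followed by "gripper_close" of shape (1,), the mask is
+ # [True, True, True, False] -- gripper channels are excluded from
+ # de/normalization, everything else is included.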
1450
+ def get_used_modality_keys(modality_keys: dict) -> tuple[list, list]:
1451
+ """Extract used action and state keys from modality configuration."""
1452
+ used_action_keys = []
1453
+ used_state_keys = []
1454
+
1455
+ # Extract action keys (remove "action." prefix)
1456
+ for action_key in modality_keys.get("action", []):
1457
+ if action_key.startswith("action."):
1458
+ clean_key = action_key.replace("action.", "")
1459
+ used_action_keys.append(clean_key)
1460
+
1461
+ # Extract state keys (remove "state." prefix)
1462
+ for state_key in modality_keys.get("state", []):
1463
+ if state_key.startswith("state."):
1464
+ clean_key = state_key.replace("state.", "")
1465
+ used_state_keys.append(clean_key)
1466
+
1467
+ return used_action_keys, used_state_keys
1468
+
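+ # e.g. modality_keys = {"action": ["action.delta_eef_position",
+ # "action.gripper_close"], "state": ["state.eef_position"]} yields
+ # (["delta_eef_position", "gripper_close"], ["eef_position"]) --
+ # prefixes stripped, config order preserved.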
1469
+
1470
+ def safe_hash(input_tuple):
1471
+ # keep 128 bits of the hash
1472
+ tuple_string = repr(input_tuple).encode("utf-8")
1473
+ sha256 = hashlib.sha256()
1474
+ sha256.update(tuple_string)
1475
+
1476
+ seed = int(sha256.hexdigest(), 16)
1477
+
1478
+ return seed & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
1479
+
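+ # Determinism sketch: the SHA-256 digest of repr(input_tuple) is stable
+ # across processes and Python versions (unlike the built-in hash(), which is
+ # salted per interpreter for str/bytes), so safe_hash((epoch, index, seed))
+ # gives every worker the same value; the 128-bit result is a valid seed for
+ # np.random.default_rng.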
1480
+
1481
+ class LeRobotMixtureDataset(Dataset):
1482
+ """
1483
+ A mixture of multiple datasets. This class samples a single dataset based on the dataset weights and then calls the `__getitem__` method of the sampled dataset.
1484
+ It is recommended to modify the single dataset class instead of this class.
1485
+ """
1486
+
1487
+ def __init__(
1488
+ self,
1489
+ data_mixture: Sequence[tuple[LeRobotSingleDataset, float]],
1490
+ mode: str,
1491
+ balance_dataset_weights: bool = True,
1492
+ balance_trajectory_weights: bool = True,
1493
+ seed: int = 42,
1494
+ metadata_config: dict = {
1495
+ "percentile_mixing_method": "min_max",
1496
+ },
1497
+ **kwargs,
1498
+ ):
1499
+ """
1500
+ Initialize the mixture dataset.
1501
+
1502
+ Args:
1503
+ data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights.
1504
+ mode (str): If "train", __getitem__ will return different samples every epoch; if "val" or "test", __getitem__ will return the same sample every epoch.
1505
+ balance_dataset_weights (bool): If True, the weight of dataset will be multiplied by the total trajectory length of each dataset.
1506
+ balance_trajectory_weights (bool): If True, sample trajectories within a dataset weighted by their length; otherwise, use equal weighting.
1507
+ seed (int): Random seed for sampling.
1508
+ """
1509
+ datasets: list[LeRobotSingleDataset] = []
1510
+ dataset_sampling_weights: list[float] = []
1511
+ for dataset, weight in data_mixture:
1512
+ # Check if dataset is valid and has data
1513
+ if len(dataset) == 0:
1514
+ print(f"Warning: Skipping empty dataset {dataset.dataset_name}")
1515
+ continue
1516
+ datasets.append(dataset)
1517
+ dataset_sampling_weights.append(weight)
1518
+
1519
+ if len(datasets) == 0:
1520
+ raise ValueError("No valid datasets found in the mixture. All datasets are empty.")
1521
+
1522
+ self.datasets = datasets
1523
+ self.balance_dataset_weights = balance_dataset_weights
1524
+ self.balance_trajectory_weights = balance_trajectory_weights
1525
+ self.seed = seed
1526
+ self.mode = mode
1527
+
1528
+ # Set properties for sampling
1529
+
1530
+ # 1. Dataset lengths
1531
+ self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets])
1532
+ print(f"Dataset lengths: {self._dataset_lengths}")
1533
+
1534
+ # 2. Dataset sampling weights
1535
+ self._dataset_sampling_weights = np.array(dataset_sampling_weights)
1536
+
1537
+ if self.balance_dataset_weights:
1538
+ self._dataset_sampling_weights *= self._dataset_lengths
1539
+
1540
+ # Check for zero or negative weights before normalization
1541
+ if np.any(self._dataset_sampling_weights <= 0):
1542
+ print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}")
1543
+ # Set minimum weight to prevent division issues
1544
+ self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8)
1545
+
1546
+ # Normalize weights
1547
+ weights_sum = self._dataset_sampling_weights.sum()
1548
+ if weights_sum == 0 or np.isnan(weights_sum):
1549
+ print(f"Error: Invalid weights sum: {weights_sum}")
1550
+ # Fallback to equal weights
1551
+ self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets)
1552
+ print(f"Fallback to equal weights")
1553
+ else:
1554
+ self._dataset_sampling_weights /= weights_sum
1555
+
1556
+ # 3. Trajectory sampling weights
1557
+ self._trajectory_sampling_weights: list[np.ndarray] = []
1558
+ for i, dataset in enumerate(self.datasets):
1559
+ trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths))
1560
+ if self.balance_trajectory_weights:
1561
+ trajectory_sampling_weights *= dataset.trajectory_lengths
1562
+
1563
+ # Check for zero or negative weights before normalization
1564
+ if np.any(trajectory_sampling_weights <= 0):
1565
+ print(f"Warning: Dataset {i} has zero or negative trajectory weights")
1566
+ trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8)
1567
+
1568
+ # Normalize weights
1569
+ weights_sum = trajectory_sampling_weights.sum()
1570
+ if weights_sum == 0 or np.isnan(weights_sum):
1571
+ print(f"Error: Dataset {i} has invalid trajectory weights sum: {weights_sum}")
1572
+ # Fallback to equal weights
1573
+ trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths)
1574
+ else:
1575
+ trajectory_sampling_weights /= weights_sum
1576
+
1577
+ self._trajectory_sampling_weights.append(trajectory_sampling_weights)
1578
+
1579
+ # 4. Primary dataset indices
1580
+ self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0
1581
+ if not np.any(self._primary_dataset_indices):
1582
+ print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}")
1583
+ # Fallback: use the dataset(s) with maximum weight as primary
1584
+ max_weight = max(dataset_sampling_weights)
1585
+ self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight
1586
+ print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}")
1587
+
1588
+ if not np.any(self._primary_dataset_indices):
1589
+ # This should never happen, but just in case
1590
+ print("Error: Still no primary dataset found. Using first dataset as primary.")
1591
+ self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool)
1592
+ self._primary_dataset_indices[0] = True
1593
+
1594
+ # Set the epoch and sample the first epoch
1595
+ self.set_epoch(0)
1596
+
1597
+ self.update_metadata(metadata_config)
1598
+
1599
+ @property
1600
+ def dataset_lengths(self) -> np.ndarray:
1601
+ """The lengths of each dataset."""
1602
+ return self._dataset_lengths
1603
+
1604
+ @property
1605
+ def dataset_sampling_weights(self) -> np.ndarray:
1606
+ """The sampling weights for each dataset."""
1607
+ return self._dataset_sampling_weights
1608
+
1609
+ @property
1610
+ def trajectory_sampling_weights(self) -> list[np.ndarray]:
1611
+ """The sampling weights for each trajectory in each dataset."""
1612
+ return self._trajectory_sampling_weights
1613
+
1614
+ @property
1615
+ def primary_dataset_indices(self) -> np.ndarray:
1616
+ """The indices of the primary datasets."""
1617
+ return self._primary_dataset_indices
1618
+
1619
+ def __str__(self) -> str:
1620
+ dataset_descriptions = []
1621
+ for dataset, weight in zip(self.datasets, self.dataset_sampling_weights):
1622
+ dataset_description = {
1623
+ "Dataset": str(dataset),
1624
+ "Sampling weight": float(weight),
1625
+ }
1626
+ dataset_descriptions.append(dataset_description)
1627
+ return json.dumps({"Mixture dataset": dataset_descriptions}, indent=2)
1628
+
1629
+ def set_epoch(self, epoch: int):
1630
+ """Set the epoch for the dataset.
1631
+
1632
+ Args:
1633
+ epoch (int): The epoch to set.
1634
+ """
1635
+ self.epoch = epoch
1636
+ # self.sampled_steps = self.sample_epoch()
1637
+
1638
+ def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]:
1639
+ """Sample a single step from the dataset."""
1640
+ # return self.sampled_steps[index]
1641
+
1642
+ # Set seed
1643
+ seed = index if self.mode != "train" else safe_hash((self.epoch, index, self.seed))
1644
+ rng = np.random.default_rng(seed)
1645
+
1646
+ # Sample dataset
1647
+ dataset_index = rng.choice(len(self.datasets), p=self.dataset_sampling_weights)
1648
+ dataset = self.datasets[dataset_index]
1649
+
1650
+ # Sample trajectory
1651
+ # trajectory_index = rng.choice(
1652
+ # len(dataset.trajectory_ids), p=self.trajectory_sampling_weights[dataset_index]
1653
+ # )
1654
+ # trajectory_id = dataset.trajectory_ids[trajectory_index]
1655
+
1656
+ # # Sample step
1657
+ # base_index = rng.choice(dataset.trajectory_lengths[trajectory_index])
1658
+ # return dataset, trajectory_id, base_index
1659
+ single_step_index = rng.choice(len(dataset.all_steps))
1660
+ trajectory_id, base_index = dataset.all_steps[single_step_index]
1661
+ return dataset, trajectory_id, base_index
1662
+
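+ # Deterministic-sampling sketch: in "train" mode the RNG seed mixes
+ # (epoch, index, self.seed), so the same index yields a new sample each
+ # epoch yet is reproducible for a fixed (epoch, seed); in eval modes the
+ # seed is just the index, so every epoch returns the identical sample.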
1663
+ def __getitem__(self, index: int) -> dict:
1664
+ """Get the data for a single trajectory and start index.
1665
+
1666
+ Args:
1667
+ index (int): The index of the trajectory to get.
1668
+
1669
+ Returns:
1670
+ dict: The data for the trajectory and start index.
1671
+ """
1672
+ max_retries = 10
1673
+ last_exception = None
1674
+
1675
+ for attempt in range(max_retries):
1676
+ try:
1677
+ dataset, trajectory_name, step = self.sample_step(index)
1678
+ data_raw = dataset.get_step_data(trajectory_name, step)
1679
+ data = dataset.transforms(data_raw)
1680
+
1681
+ # Process all video keys dynamically
1682
+ images = []
1683
+ for video_key in dataset.modality_keys.get("video", []):
1684
+ image = data[video_key][0]
1685
+
1686
+ image = Image.fromarray(image).resize((224, 224)) # TODO: check if this is OK
1687
+ images.append(image)
1688
+
1689
+ # Get language and action data
1690
+ language = data[dataset.modality_keys["language"][0]][0]
1691
+ action = []
1692
+ for action_key in dataset.modality_keys["action"]:
1693
+ action.append(data[action_key])
1694
+ action = np.concatenate(action, axis=1).astype(np.float16)
1695
+ action = standardize_action_representation(action, dataset.tag)
1696
+
1697
+ state = []
1698
+ for state_key in dataset.modality_keys["state"]:
1699
+ state.append(data[state_key])
1700
+ state = np.concatenate(state, axis=1).astype(np.float16)
1701
+ state = standardize_state_representation(state, dataset.tag)
1702
+
1703
+ return dict(action=action, state=state, image=images, lang=language, dataset_id=dataset._dataset_id)
1704
+
1705
+ except Exception as e:
1706
+ last_exception = e
1707
+ if attempt < max_retries - 1:
1708
+ # Log the error but continue trying
1709
+ print(f"Attempt {attempt + 1}/{max_retries} failed for index {index}: {e}")
1710
+ print(f"Retrying with new sample...")
1711
+ # For retry, we can use a slightly different index to get a new sample
1712
+ # This helps avoid getting stuck on the same problematic sample
1713
+ index = random.randint(0, len(self) - 1)
1714
+ else:
1715
+ # All retries exhausted
1716
+ print(f"All {max_retries} attempts failed for index {index}")
1717
+ print(f"Last error: {last_exception}")
1718
+ # Return a dummy sample or re-raise the exception
1719
+ raise last_exception
1720
+
1721
+ def __len__(self) -> int:
1722
+ """Get the length of a single epoch in the mixture.
1723
+
1724
+ Returns:
1725
+ int: The length of a single epoch in the mixture.
1726
+ """
1727
+ # Check for potential issues
1728
+ if len(self.datasets) == 0:
1729
+ return 0
1730
+
1731
+ # Check if any dataset lengths are 0 or NaN
1732
+ if np.any(self.dataset_lengths == 0) or np.any(np.isnan(self.dataset_lengths)):
1733
+ print(f"Warning: Found zero or NaN dataset lengths: {self.dataset_lengths}")
1734
+ # Filter out zero/NaN length datasets
1735
+ valid_indices = (self.dataset_lengths > 0) & (~np.isnan(self.dataset_lengths))
1736
+ if not np.any(valid_indices):
1737
+ print("Error: All datasets have zero or NaN length")
1738
+ return 0
1739
+ else:
1740
+ valid_indices = np.ones(len(self.datasets), dtype=bool)
1741
+
1742
+ # Check if any sampling weights are 0 or NaN
1743
+ if np.any(self.dataset_sampling_weights == 0) or np.any(np.isnan(self.dataset_sampling_weights)):
1744
+ print(f"Warning: Found zero or NaN sampling weights: {self.dataset_sampling_weights}")
1745
+ # Use only valid weights
1746
+ valid_weights = (self.dataset_sampling_weights > 0) & (~np.isnan(self.dataset_sampling_weights))
1747
+ valid_indices = valid_indices & valid_weights
1748
+ if not np.any(valid_indices):
1749
+ print("Error: All sampling weights are zero or NaN")
1750
+ return 0
1751
+
1752
+ # Check primary dataset indices
1753
+ primary_and_valid = self.primary_dataset_indices & valid_indices
1754
+ if not np.any(primary_and_valid):
1755
+ print(f"Warning: No valid primary datasets found. Primary indices: {self.primary_dataset_indices}, Valid indices: {valid_indices}")
1756
+ # Fallback: use the largest valid dataset
1757
+ if np.any(valid_indices):
1758
+ max_length = self.dataset_lengths[valid_indices].max()
1759
+ print(f"Fallback: Using maximum dataset length: {max_length}")
1760
+ return int(max_length)
1761
+ else:
1762
+ return 0
1763
+
1764
+ # Calculate the ratio and get max
1765
+ ratios = (self.dataset_lengths / self.dataset_sampling_weights)[primary_and_valid]
1766
+
1767
+ # Check for NaN or inf in ratios
1768
+ if np.any(np.isnan(ratios)) or np.any(np.isinf(ratios)):
1769
+ print(f"Warning: Found NaN or inf in ratios: {ratios}")
1770
+ print(f"Dataset lengths: {self.dataset_lengths[primary_and_valid]}")
1771
+ print(f"Sampling weights: {self.dataset_sampling_weights[primary_and_valid]}")
1772
+ # Filter out invalid ratios
1773
+ valid_ratios = ratios[~np.isnan(ratios) & ~np.isinf(ratios)]
1774
+ if len(valid_ratios) == 0:
1775
+ print("Error: All ratios are NaN or inf")
1776
+ return 0
1777
+ max_ratio = valid_ratios.max()
1778
+ else:
1779
+ max_ratio = ratios.max()
1780
+
1781
+ result = int(max_ratio)
1782
+ if result == 0:
1783
+ print(f"Warning: Dataset mixture length is 0")
1784
+ return result
1785
+
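+ # Epoch-length sketch: len(mixture) = max over primary datasets of
+ # dataset_length / sampling_weight. E.g. (hypothetical) a primary dataset
+ # with 10_000 steps sampled at normalized weight 0.25 yields an epoch of
+ # 40_000 draws, so on average each of its steps is visited once per epoch.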
1786
+ @staticmethod
1787
+ def compute_overall_statistics(
1788
+ per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]],
1789
+ dataset_sampling_weights: list[float] | np.ndarray,
1790
+ percentile_mixing_method: str = "weighted_average",
1791
+ ) -> dict[str, dict[str, list[float]]]:
1792
+ """
1793
+ Computes overall statistics from per-task statistics using dataset sample weights.
1794
+
1795
+ Args:
1796
+ per_task_stats: List of per-task statistics.
1797
+ Example format of one element in the per-task statistics list:
1798
+ {
1799
+ "state.gripper": {
1800
+ "min": [...],
1801
+ "max": [...],
1802
+ "mean": [...],
1803
+ "std": [...],
1804
+ "q01": [...],
1805
+ "q99": [...],
1806
+ },
1807
+ ...
1808
+ }
1809
+ dataset_sampling_weights: List of sample weights for each task.
1810
+ percentile_mixing_method: The method to mix the percentiles, either "weighted_average" or "weighted_std".
1811
+
1812
+ Returns:
1813
+ A dict of overall statistics per modality.
1814
+ """
1815
+ # Normalize the sample weights to sum to 1
1816
+ dataset_sampling_weights = np.array(dataset_sampling_weights)
1817
+ normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum()
1818
+
1819
+ # Initialize overall statistics dict
1820
+ overall_stats: dict[str, dict[str, list[float]]] = {}
1821
+
1822
+ # Get the list of modality keys
1823
+ modality_keys = per_task_stats[0].keys()
1824
+
1825
+ for modality in modality_keys:
1826
+ # Number of dimensions (assuming consistent across tasks)
1827
+ num_dims = len(per_task_stats[0][modality]["mean"])
1828
+
1829
+ # Initialize accumulators for means and variances
1830
+ weighted_means = np.zeros(num_dims)
1831
+ weighted_squares = np.zeros(num_dims)
1832
+
1833
+ # Collect min, max, q01, q99 from all tasks
1834
+ min_list = []
1835
+ max_list = []
1836
+ q01_list = []
1837
+ q99_list = []
1838
+
1839
+ for task_idx, task_stats in enumerate(per_task_stats):
1840
+ w_i = normalized_weights[task_idx]
1841
+ stats = task_stats[modality]
1842
+ means = np.array(stats["mean"])
1843
+ stds = np.array(stats["std"])
1844
+
1845
+ # Update weighted sums for mean and variance
1846
+ weighted_means += w_i * means
1847
+ weighted_squares += w_i * (stds**2 + means**2)
1848
+
1849
+ # Collect min, max, q01, q99
1850
+ min_list.append(stats["min"])
1851
+ max_list.append(stats["max"])
1852
+ q01_list.append(stats["q01"])
1853
+ q99_list.append(stats["q99"])
1854
+
1855
+ # Compute overall mean
1856
+ overall_mean = weighted_means.tolist()
1857
+
1858
+ # Compute overall variance and std deviation
1859
+ overall_variance = weighted_squares - weighted_means**2
1860
+ overall_std = np.sqrt(overall_variance).tolist()
1861
+
1862
+ # Compute overall min and max per dimension
1863
+ overall_min = np.min(np.array(min_list), axis=0).tolist()
1864
+ overall_max = np.max(np.array(max_list), axis=0).tolist()
1865
+
1866
+ # Compute overall q01 and q99 per dimension
1867
+ # Use weighted average of per-task quantiles
1868
+ q01_array = np.array(q01_list)
1869
+ q99_array = np.array(q99_list)
1870
+ if percentile_mixing_method == "weighted_average":
1871
+ weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist()
1872
+ weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist()
1873
+ # std_q01 = np.std(q01_array, axis=0).tolist()
1874
+ # std_q99 = np.std(q99_array, axis=0).tolist()
1875
+ # print(modality)
1876
+ # print(f"{std_q01=}, {std_q99=}")
1877
+ # print(f"{weighted_q01=}, {weighted_q99=}")
1878
+ elif percentile_mixing_method == "min_max":
1879
+ weighted_q01 = np.min(q01_array, axis=0).tolist()
1880
+ weighted_q99 = np.max(q99_array, axis=0).tolist()
1881
+ else:
1882
+ raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}")
1883
+
1884
+ # Store the overall statistics for the modality
1885
+ overall_stats[modality] = {
1886
+ "min": overall_min,
1887
+ "max": overall_max,
1888
+ "mean": overall_mean,
1889
+ "std": overall_std,
1890
+ "q01": weighted_q01,
1891
+ "q99": weighted_q99,
1892
+ }
1893
+
1894
+ return overall_stats
1895
+
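+ # The mixture variance above follows the standard identity
+ # Var = sum_i w_i * (std_i**2 + mean_i**2) - (sum_i w_i * mean_i)**2.
+ # Sketch with two equally weighted tasks (hypothetical numbers):
+ # means 0 and 2, stds 1 and 1 -> overall mean 1,
+ # variance 0.5*(1+0) + 0.5*(1+4) - 1**2 = 2, std = sqrt(2).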
1896
+ @staticmethod
1897
+ def merge_metadata(
1898
+ metadatas: list[DatasetMetadata],
1899
+ dataset_sampling_weights: list[float],
1900
+ percentile_mixing_method: str,
1901
+ ) -> DatasetMetadata:
1902
+ """Merge multiple metadata into one."""
1903
+ # Convert to dicts
1904
+ metadata_dicts = [metadata.model_dump(mode="json") for metadata in metadatas]
1905
+ # Create a new metadata dict
1906
+ merged_metadata = {}
1907
+
1908
+ # Check all metadata have the same embodiment tag
1909
+ assert all(
1910
+ metadata.embodiment_tag == metadatas[0].embodiment_tag for metadata in metadatas
1911
+ ), "All metadata must have the same embodiment tag"
1912
+ merged_metadata["embodiment_tag"] = metadatas[0].embodiment_tag
1913
+
1914
+ # Merge the dataset statistics
1915
+ dataset_statistics = {}
1916
+ dataset_statistics["state"] = LeRobotMixtureDataset.compute_overall_statistics(
1917
+ per_task_stats=[m["statistics"]["state"] for m in metadata_dicts],
1918
+ dataset_sampling_weights=dataset_sampling_weights,
1919
+ percentile_mixing_method=percentile_mixing_method,
1920
+ )
1921
+ dataset_statistics["action"] = LeRobotMixtureDataset.compute_overall_statistics(
1922
+ per_task_stats=[m["statistics"]["action"] for m in metadata_dicts],
1923
+ dataset_sampling_weights=dataset_sampling_weights,
1924
+ percentile_mixing_method=percentile_mixing_method,
1925
+ )
1926
+ merged_metadata["statistics"] = dataset_statistics
1927
+
1928
+ # Merge the modality configs
1929
+ modality_configs = defaultdict(set)
1930
+ for metadata in metadata_dicts:
1931
+ for modality, configs in metadata["modalities"].items():
1932
+ modality_configs[modality].add(json.dumps(configs))
1933
+ merged_metadata["modalities"] = {}
1934
+ for modality, configs in modality_configs.items():
1935
+ # Check that all datasets sharing this tag declare identical modality configs
1936
+ assert (
1937
+ len(configs) == 1
1938
+ ), f"Multiple modality configs for modality {modality}: {list(configs)}"
1939
+ merged_metadata["modalities"][modality] = json.loads(configs.pop())
1940
+
1941
+ return DatasetMetadata.model_validate(merged_metadata)
1942
+
1943
+ def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None:
1944
+ """
1945
+ Merge multiple metadatas into one and set the transforms with the merged metadata.
1946
+
1947
+ Args:
1948
+ metadata_config (dict): Configuration for the metadata.
1949
+ "percentile_mixing_method": The method to mix the percentiles, either "weighted_average" or "min_max".
1950
+ weighted_average: Use the weighted average of the percentiles using the weight used in sampling the datasets.
1951
+ min_max: Use the min of the 1st percentile and max of the 99th percentile.
1952
+ cached_statistics_path (Path | str | None): Optional path to cached merged statistics; if loading fails, statistics are recomputed from scratch.
+ """
1953
+ # If cached path is provided, try to load and apply
1954
+ if cached_statistics_path is not None:
1955
+ try:
1956
+ cached_stats = self.load_merged_statistics(cached_statistics_path)
1957
+ self.apply_cached_statistics(cached_stats)
1958
+ return
1959
+ except (FileNotFoundError, KeyError, ValidationError) as e:
1960
+ print(f"Failed to load cached statistics: {e}")
1961
+ print("Falling back to computing statistics from scratch...")
1962
+
1963
+ self.tag = EmbodimentTag.NEW_EMBODIMENT.value
1964
+ self.merged_metadata: dict[str, DatasetMetadata] = {}
1965
+ # Group metadata by tag
1966
+ all_metadatas: dict[str, list[DatasetMetadata]] = {}
1967
+ for dataset in self.datasets:
1968
+ if dataset.tag not in all_metadatas:
1969
+ all_metadatas[dataset.tag] = []
1970
+ all_metadatas[dataset.tag].append(dataset.metadata)
1971
+ for tag, metadatas in all_metadatas.items():
1972
+ self.merged_metadata[tag] = self.merge_metadata(
1973
+ metadatas=metadatas,
1974
+ dataset_sampling_weights=self.dataset_sampling_weights.tolist(),
1975
+ percentile_mixing_method=metadata_config["percentile_mixing_method"],
1976
+ )
1977
+ for dataset in self.datasets:
1978
+ dataset.set_transforms_metadata(self.merged_metadata[dataset.tag])
1979
+
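+ # Typical usage sketch, assuming `mixture` is a LeRobotMixtureDataset and the cache
+ # path is a placeholder:
+ #
+ #     mixture.update_metadata(
+ #         {"percentile_mixing_method": "weighted_average"},
+ #         cached_statistics_path="stats/merged_stats.json",
+ #     )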
1980
+ def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None:
1981
+ """
1982
+ Save merged dataset statistics to specified path in the required format.
1983
+ Only includes statistics for keys that are actually used in the datasets.
1984
+ Key order follows each tag's modality config order.
1985
+
1986
+ Args:
1987
+ save_path (Path | str): Path to save the statistics file
1988
+ format (str): Save format, currently only supports "json"
1989
+ """
1990
+ save_path = Path(save_path)
1991
+ save_path.parent.mkdir(parents=True, exist_ok=True)
1992
+
1993
+ # Build the data structure to save
1994
+ statistics_data = {}
1995
+
1996
+ # Keep key orders per embodiment tag (from modality config order)
1997
+ tag_to_used_action_keys = {}
1998
+ tag_to_used_state_keys = {}
1999
+ for dataset in self.datasets:
2000
+ if dataset.tag in tag_to_used_action_keys:
2001
+ continue
2002
+ used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys)
2003
+ tag_to_used_action_keys[dataset.tag] = used_action_keys
2004
+ tag_to_used_state_keys[dataset.tag] = used_state_keys
2005
+
2006
+ # Organize statistics by tag
2007
+ for tag, merged_metadata in self.merged_metadata.items():
2008
+ tag_stats = {}
2009
+
2010
+ # Process action statistics
2011
+ if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action:
2012
+ action_stats = merged_metadata.statistics.action
2013
+
2014
+ used_action_keys = tag_to_used_action_keys.get(tag, [])
2015
+ filtered_action_stats = {
2016
+ key: action_stats[key]
2017
+ for key in used_action_keys
2018
+ if key in action_stats
2019
+ }
2020
+
2021
+ if filtered_action_stats:
2022
+ combined_action_stats = combine_modality_stats(filtered_action_stats)
2023
+
2024
+ mask = generate_action_mask_for_used_keys(
2025
+ merged_metadata.modalities.action, filtered_action_stats.keys()
2026
+ )
2027
+ combined_action_stats["mask"] = mask
2028
+
2029
+ tag_stats["action"] = combined_action_stats
2030
+
2031
+ # Process state statistics
2032
+ if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state:
2033
+ state_stats = merged_metadata.statistics.state
2034
+
2035
+ used_state_keys = tag_to_used_state_keys.get(tag, [])
2036
+ filtered_state_stats = {
2037
+ key: state_stats[key]
2038
+ for key in used_state_keys
2039
+ if key in state_stats
2040
+ }
2041
+
2042
+ if filtered_state_stats:
2043
+ combined_state_stats = combine_modality_stats(filtered_state_stats)
2044
+ tag_stats["state"] = combined_state_stats
2045
+
2046
+ # Add dataset counts
2047
+ tag_stats.update(self._get_dataset_counts(tag))
2048
+
2049
+ statistics_data[tag] = tag_stats
2050
+
2051
+ # Save file
2052
+ if format.lower() == "json":
2053
+ if not str(save_path).endswith('.json'):
2054
+ save_path = save_path.with_suffix('.json')
2055
+ with open(save_path, 'w', encoding='utf-8') as f:
2056
+ json.dump(statistics_data, f, indent=2, ensure_ascii=False)
2057
+ else:
2058
+ raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")
2059
+
2060
+ print(f"Merged dataset statistics saved to: {save_path}")
2061
+ print(f"Used action keys by tag: {tag_to_used_action_keys}")
2062
+ print(f"Used state keys by tag: {tag_to_used_state_keys}")
2063
+
2064
+
2065
+ def _combine_modality_stats(self, modality_stats: dict) -> dict:
2066
+ """Backward compatibility wrapper."""
2067
+ return combine_modality_stats(modality_stats)
2068
+
2069
+ def _generate_action_mask_for_used_keys(self, action_modalities: dict, used_action_keys_ordered) -> list[bool]:
2070
+ """Backward compatibility wrapper."""
2071
+ return generate_action_mask_for_used_keys(action_modalities, used_action_keys_ordered)
2072
+
2073
+ def _get_dataset_counts(self, tag: str) -> dict:
2074
+ """
2075
+ Get dataset count information for specified tag.
2076
+
2077
+ Args:
2078
+ tag (str): embodiment tag
2079
+
2080
+ Returns:
2081
+ dict: Dictionary containing num_transitions and num_trajectories
2082
+ """
2083
+ num_transitions = 0
2084
+ num_trajectories = 0
2085
+
2086
+ # Count dataset information belonging to this tag
2087
+ for dataset in self.datasets:
2088
+ if dataset.tag == tag:
2089
+ num_transitions += len(dataset)
2090
+ num_trajectories += len(dataset.trajectory_ids)
2091
+
2092
+ return {
2093
+ "num_transitions": num_transitions,
2094
+ "num_trajectories": num_trajectories
2095
+ }
2096
+
2097
+ @classmethod
2098
+ def load_merged_statistics(cls, load_path: Path | str) -> dict:
2099
+ """
2100
+ Load merged dataset statistics from file.
2101
+
2102
+ Args:
2103
+ load_path (Path | str): Path to the statistics file
2104
+
2105
+ Returns:
2106
+ dict: Dictionary containing merged statistics
2107
+ """
2108
+ load_path = Path(load_path)
2109
+ if not load_path.exists():
2110
+ raise FileNotFoundError(f"Statistics file not found: {load_path}")
2111
+
2112
+ if load_path.suffix.lower() == '.json':
2113
+ with open(load_path, 'r', encoding='utf-8') as f:
2114
+ return json.load(f)
2115
+ elif load_path.suffix.lower() == '.pkl':
2116
+ import pickle
2117
+ with open(load_path, 'rb') as f:
2118
+ return pickle.load(f)
2119
+ else:
2120
+ raise ValueError(f"Unsupported file format: {load_path.suffix}")
2121
+
2122
+ def apply_cached_statistics(self, cached_statistics: dict) -> None:
2123
+ """
2124
+ Apply cached statistics to avoid recomputation.
2125
+
2126
+ Args:
2127
+ cached_statistics (dict): Statistics loaded from file
2128
+ """
2129
+ # Validate that cached statistics match current datasets
2130
+ if "metadata" in cached_statistics:
2131
+ cached_dataset_names = set(cached_statistics["metadata"]["dataset_names"])
2132
+ current_dataset_names = set(dataset.dataset_name for dataset in self.datasets)
2133
+
2134
+ if cached_dataset_names != current_dataset_names:
2135
+ print("Warning: Cached statistics dataset names don't match current datasets.")
2136
+ print(f"Cached: {cached_dataset_names}")
2137
+ print(f"Current: {current_dataset_names}")
2138
+ return
2139
+
2140
+ # Apply cached statistics
2141
+ self.merged_metadata = {}
2142
+ for tag, stats_data in cached_statistics.items():
2143
+ if tag == "metadata": # Skip metadata field
2144
+ continue
2145
+
2146
+ # Convert back to DatasetMetadata format
2147
+ metadata_dict = {
2148
+ "embodiment_tag": tag,
2149
+ "statistics": {
2150
+ "action": {},
2151
+ "state": {}
2152
+ },
2153
+ "modalities": {}
2154
+ }
2155
+
2156
+ # Convert action statistics back
2157
+ if "action" in stats_data:
2158
+ action_data = stats_data["action"]
2159
+ # Simplified: the combined stats may need to be split back into per-sub-key entries
2160
+ metadata_dict["statistics"]["action"] = action_data
2161
+
2162
+ # Convert state statistics back
2163
+ if "state" in stats_data:
2164
+ state_data = stats_data["state"]
2165
+ metadata_dict["statistics"]["state"] = state_data
2166
+
2167
+ self.merged_metadata[tag] = DatasetMetadata.model_validate(metadata_dict)
2168
+
2169
+ # Update transforms metadata for each dataset
2170
+ for dataset in self.datasets:
2171
+ if dataset.tag in self.merged_metadata:
2172
+ dataset.set_transforms_metadata(self.merged_metadata[dataset.tag])
2173
+
2174
+ print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.")
2175
+
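+ # Save/load round-trip sketch (the paths and the `mixture` instance are placeholders):
+ #
+ #     mixture.save_dataset_statistics("stats/merged_stats.json")
+ #     stats = LeRobotMixtureDataset.load_merged_statistics("stats/merged_stats.json")
+ #     mixture.apply_cached_statistics(stats)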
code/dataloader/gr00t_lerobot/datasets_bak2.py ADDED
@@ -0,0 +1,2145 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ """
18
+ In this file, we define 3 types of datasets:
19
+ 1. LeRobotSingleDataset: a single dataset for a given embodiment tag
20
+ 2. LeRobotMixtureDataset: a mixture of datasets for a given list of embodiment tags
21
+ 3. CachedLeRobotSingleDataset: a single dataset for a given embodiment tag,
22
+ with caching for the video frames
23
+
24
+ See `scripts/load_dataset.py` for examples on how to use these datasets.
25
+ """
26
+ import os
27
+ import hashlib
28
+ import json, torch
29
+ from collections import defaultdict
30
+ from pathlib import Path
31
+ from typing import Sequence
32
+ import random
33
+ import numpy as np
34
+ import pandas as pd
35
+ from pydantic import BaseModel, Field, ValidationError
36
+ from torch.utils.data import Dataset
37
+ from tqdm import tqdm
38
+ from PIL import Image
39
+
40
+ from starVLA.dataloader.gr00t_lerobot.video import get_all_frames, get_frames_by_timestamps
41
+
42
+ from starVLA.dataloader.gr00t_lerobot.embodiment_tags import EmbodimentTag, DATASET_NAME_TO_ID
43
+ from starVLA.dataloader.gr00t_lerobot.schema import (
44
+ DatasetMetadata,
45
+ DatasetStatisticalValues,
46
+ LeRobotModalityMetadata,
47
+ LeRobotStateActionMetadata,
48
+ )
49
+ from starVLA.dataloader.gr00t_lerobot.transform import ComposedModalityTransform
50
+
51
+ from functools import partial
52
+ from typing import Tuple, List
53
+ import pickle
54
+
55
+ # LeRobot v2.0 dataset file names
56
+ LE_ROBOT_MODALITY_FILENAME = "meta/modality.json"
57
+ LE_ROBOT_EPISODE_FILENAME = "meta/episodes.jsonl"
58
+ LE_ROBOT_TASKS_FILENAME = "meta/tasks.jsonl"
59
+ LE_ROBOT_INFO_FILENAME = "meta/info.json"
60
+ LE_ROBOT_STATS_FILENAME = "meta/stats_gr00t.json"
61
+ LE_ROBOT_DATA_FILENAME = "data/*/*.parquet"
62
+ LE_ROBOT_STEPS_FILENAME = "meta/steps.pkl"
63
+ EPSILON = 5e-4
64
+
65
+ # LeRobot v3.0 dataset file names
66
+ LE_ROBOT3_TASKS_FILENAME = "meta/tasks.parquet"
67
+ LE_ROBOT3_EPISODE_FILENAME = "meta/episodes/*/*.parquet"
68
+
69
+
70
+ # =============================================================================
71
+ # Unified Representation Layout & Helpers
72
+ # =============================================================================
73
+
74
+ STANDARD_ACTION_DIM = 37
75
+ #
76
+ # Unified action representation layout (0-based indices, Python slice is [start, stop)):
77
+ # Keep only: libero_franka, gr1, real_world_franka.
78
+ #
79
+ # - 0:7 -> left_arm (7D): xyz, rpy/euler, gripper
80
+ # Used by: gr1 left_arm
81
+ # - 7:14 -> right_arm (7D): same structure
82
+ # Used by: libero_franka; gr1 right_arm
83
+ # - 14:20 -> left_hand (6D): gr1 only
84
+ # - 20:26 -> right_hand (6D): gr1 only
85
+ # - 26:29 -> waist (3D): gr1 only
86
+ # - 29:37 -> joints + gripper (8D): real_world_franka only
87
+ #
88
+ # Mapping:
89
+ # libero_franka (7D) -> [7:14] (right_arm slot)
90
+ # gr1 (29D) -> [0:29]
91
+ # real_world_franka (8D) -> [29:37] (joints + gripper)
92
+
93
+ ACTION_REPRESENTATION_SLICES = {
94
+ # Single-arm (7D) -> right_arm slot [7:14]
95
+ "franka": slice(7, 14),
96
+
97
+ # Humanoid (29D) -> full [0:29]
98
+ "gr1": slice(0, 29),
99
+
100
+ # Real-world (8D) -> [29:37] (joints + gripper)
101
+ "real_world_franka": slice(29, 37),
102
+ }
103
+
104
+ STANDARD_STATE_DIM = 74
105
+ # Mapping:
106
+ # libero_franka (8D) -> [0:8]
107
+ # real_world_franka (8D) -> [8:16]
108
+ # gr1 (58D after sin/cos) -> [16:74]
109
+
110
+ STATE_REPRESENTATION_SLICES = {
111
+ # Single-arm (8D)
112
+ "franka": slice(0, 8),
113
+ # Real-world (8D)
114
+ "real_world_franka": slice(8, 16),
115
+ # GR1 isolated (58D, has StateActionSinCosTransform - different pipeline)
116
+ "gr1": slice(16, 74),
117
+ }
118
+
119
+
120
+ def standardize_action_representation(
121
+ action: np.ndarray, embodiment_tag: str
122
+ ) -> np.ndarray:
123
+ """Map per-robot action to a fixed-size standard action vector."""
124
+ target_slice = ACTION_REPRESENTATION_SLICES.get(embodiment_tag)
125
+
126
+ # Only allow explicitly configured embodiment tags.
127
+ if target_slice is None:
128
+ raise ValueError(
129
+ f"Unknown embodiment tag '{embodiment_tag}' for action mapping. "
130
+ f"Known tags: {sorted(ACTION_REPRESENTATION_SLICES)}"
131
+ )
132
+
133
+ expected_dim = target_slice.stop - target_slice.start
134
+ if action.shape[-1] != expected_dim:
135
+ raise ValueError(
136
+ f"Action dim mismatch for tag '{embodiment_tag}': "
137
+ f"{action.shape[-1]=} vs expected {expected_dim}."
138
+ )
139
+
140
+ standard = np.zeros(
141
+ (*action.shape[:-1], STANDARD_ACTION_DIM), dtype=action.dtype
142
+ )
143
+ standard[..., target_slice] = action
144
+ return standard
145
+
146
+
147
+ def standardize_state_representation(
148
+ state: np.ndarray, embodiment_tag: str
149
+ ) -> np.ndarray:
150
+ """Map per-robot state to a fixed-size standard state vector."""
151
+
152
+ target_slice = STATE_REPRESENTATION_SLICES.get(embodiment_tag)
153
+
154
+ # Only allow explicitly configured embodiment tags.
155
+ if target_slice is None:
156
+ raise ValueError(
157
+ f"Unknown embodiment tag '{embodiment_tag}' for state mapping. "
158
+ f"Known tags: {sorted(STATE_REPRESENTATION_SLICES)}"
159
+ )
160
+
161
+ expected_dim = target_slice.stop - target_slice.start
162
+ if state.shape[-1] != expected_dim:
163
+ raise ValueError(
164
+ f"State dim mismatch for tag '{embodiment_tag}': "
165
+ f"{state.shape[-1]=} vs expected {expected_dim}."
166
+ )
167
+
168
+ standard = np.zeros(
169
+ (*state.shape[:-1], STANDARD_STATE_DIM), dtype=state.dtype
170
+ )
171
+ standard[..., target_slice] = state
172
+ return standard
173
+
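+ # Minimal sketch of the two mappers above (array shapes are illustrative):
+ #
+ #     a = np.zeros((16, 7), dtype=np.float32)              # franka action chunk
+ #     ua = standardize_action_representation(a, "franka")  # -> (16, 37), slot [7:14] filled
+ #     s = np.zeros((1, 8), dtype=np.float32)               # franka proprio state
+ #     us = standardize_state_representation(s, "franka")   # -> (1, 74), slot [0:8] filled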
174
+
175
+ def calculate_dataset_statistics(parquet_paths: list[Path]) -> dict:
176
+ """Calculate the dataset statistics of all columns for a list of parquet files."""
177
+ # Dataset statistics
178
+ all_low_dim_data_list = []
179
+ # Collect all the data
181
+ for parquet_path in tqdm(
182
+ sorted(list(parquet_paths)),
183
+ desc="Collecting all parquet files...",
184
+ ):
185
+ # Load the parquet file
186
+ parquet_data = pd.read_parquet(parquet_path)
187
+ parquet_data = parquet_data
188
+ all_low_dim_data_list.append(parquet_data)
189
+
190
+ all_low_dim_data = pd.concat(all_low_dim_data_list, axis=0)
191
+ # Compute dataset statistics
192
+ dataset_statistics = {}
193
+ for le_modality in all_low_dim_data.columns:
194
+ if le_modality.startswith("annotation."):
195
+ continue
196
+ print(f"Computing statistics for {le_modality}...")
197
+ np_data = np.vstack(
198
+ [np.asarray(x, dtype=np.float32) for x in all_low_dim_data[le_modality]]
199
+ )
200
+ dataset_statistics[le_modality] = {
201
+ "mean": np.mean(np_data, axis=0).tolist(),
202
+ "std": np.std(np_data, axis=0).tolist(),
203
+ "min": np.min(np_data, axis=0).tolist(),
204
+ "max": np.max(np_data, axis=0).tolist(),
205
+ "q01": np.quantile(np_data, 0.01, axis=0).tolist(),
206
+ "q99": np.quantile(np_data, 0.99, axis=0).tolist(),
207
+ }
208
+ return dataset_statistics
209
+
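+ # Resulting structure, assuming a hypothetical low-dim column "observation.state":
+ #
+ #     {"observation.state": {"mean": [...], "std": [...], "min": [...],
+ #                            "max": [...], "q01": [...], "q99": [...]}}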
210
+
211
+ class ModalityConfig(BaseModel):
212
+ """Configuration for a modality."""
213
+
214
+ delta_indices: list[int]
215
+ """Delta indices to sample relative to the current index. The returned data will correspond to the original data at a sampled base index + delta indices."""
216
+ modality_keys: list[str]
217
+ """The keys to load for the modality in the dataset."""
218
+
219
+
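+ # Hedged example of a modality_configs mapping (the keys and horizons below are
+ # dataset-specific assumptions, not values shipped with this file):
+ #
+ #     modality_configs = {
+ #         "video": ModalityConfig(delta_indices=[0], modality_keys=["video.image_side_0"]),
+ #         "action": ModalityConfig(delta_indices=list(range(16)), modality_keys=["action.eef_position"]),
+ #     }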
220
+ class LeRobotSingleDataset(Dataset):
221
+ """
222
+ Base dataset class for LeRobot that supports sharding.
223
+ """
224
+ def __init__(
225
+ self,
226
+ dataset_path: Path | str,
227
+ modality_configs: dict[str, ModalityConfig],
228
+ embodiment_tag: str | EmbodimentTag,
229
+ video_backend: str = "decord",
230
+ video_backend_kwargs: dict | None = None,
231
+ transforms: ComposedModalityTransform | None = None,
232
+ delete_pause_frame: bool = False,
233
+ **kwargs,
234
+ ):
235
+ """
236
+ Initialize the dataset.
237
+
238
+ Args:
239
+ dataset_path (Path | str): The path to the dataset.
240
+ modality_configs (dict[str, ModalityConfig]): The configuration for each modality. The keys are the modality names, and the values are the modality configurations.
241
+ See `ModalityConfig` for more details.
242
+ video_backend (str): Backend for video reading.
243
+ video_backend_kwargs (dict): Keyword arguments for the video backend when initializing the video reader.
244
+ transforms (ComposedModalityTransform): The transforms to apply to the dataset.
245
+ embodiment_tag (EmbodimentTag): Overload the embodiment tag for the dataset, e.g. define it as "new_embodiment".
+ delete_pause_frame (bool): Whether to drop pause frames when building the step index; also part of the step-cache configuration key.
246
+ """
247
+ # first check if the path directory exists
248
+ if not Path(dataset_path).exists():
249
+ raise FileNotFoundError(f"Dataset path {dataset_path} does not exist")
250
+ data_cfg = kwargs.get("data_cfg", {}) or {}
251
+ # determine the LeRobot dataset version ("v2.0" or "v3.0")
+ self._lerobot_version = data_cfg.get("lerobot_version", "v2.0")
253
+ self.load_video = data_cfg.get("load_video", True)
254
+
255
+ self.delete_pause_frame = delete_pause_frame
256
+
257
+ # If video loading is disabled, skip video modality end-to-end.
258
+ if self.load_video:
259
+ self.modality_configs = modality_configs
260
+ else:
261
+ self.modality_configs = {
262
+ modality: config
263
+ for modality, config in modality_configs.items()
264
+ if modality != "video"
265
+ }
266
+ self.video_backend = video_backend
267
+ self.video_backend_kwargs = video_backend_kwargs if video_backend_kwargs is not None else {}
268
+ self.transforms = (
269
+ transforms if transforms is not None else ComposedModalityTransform(transforms=[])
270
+ )
271
+
272
+ self._dataset_path = Path(dataset_path)
273
+ self._dataset_name = self._dataset_path.name
274
+ self._dataset_id = DATASET_NAME_TO_ID.get(self._dataset_name)
275
+ if isinstance(embodiment_tag, EmbodimentTag):
276
+ self.tag = embodiment_tag.value
277
+ else:
278
+ self.tag = embodiment_tag
279
+
280
+ self._metadata = self._get_metadata(EmbodimentTag(self.tag))
281
+
282
+ # LeRobot-specific config
283
+ self._lerobot_modality_meta = self._get_lerobot_modality_meta()
284
+ self._lerobot_info_meta = self._get_lerobot_info_meta()
285
+ self._data_path_pattern = self._get_data_path_pattern()
286
+ self._video_path_pattern = self._get_video_path_pattern()
287
+ self._chunk_size = self._get_chunk_size()
288
+ self._tasks = self._get_tasks()
289
+ self.curr_traj_data = None
290
+ self.curr_traj_id = None
291
+
292
+ self._trajectory_ids, self._trajectory_lengths = self._get_trajectories()
293
+ self._modality_keys = self._get_modality_keys()
294
+ self._delta_indices = self._get_delta_indices()
295
+ self._all_steps = self._get_all_steps()
296
+ self.set_transforms_metadata(self.metadata)
297
+ self.set_epoch(0)
298
+
299
+ print(f"Initialized dataset {self.dataset_name} with {embodiment_tag}")
300
+
301
+
302
+ # Check if the dataset is valid
303
+ self._check_integrity()
304
+
305
+ @property
306
+ def dataset_path(self) -> Path:
307
+ """The path to the dataset that contains the METADATA_FILENAME file."""
308
+ return self._dataset_path
309
+
310
+ @property
311
+ def metadata(self) -> DatasetMetadata:
312
+ """The metadata for the dataset, loaded from metadata.json in the dataset directory"""
313
+ return self._metadata
314
+
315
+ @property
316
+ def trajectory_ids(self) -> np.ndarray:
317
+ """The trajectory IDs in the dataset, stored as a 1D numpy array of strings."""
318
+ return self._trajectory_ids
319
+
320
+ @property
321
+ def trajectory_lengths(self) -> np.ndarray:
322
+ """The trajectory lengths in the dataset, stored as a 1D numpy array of integers.
323
+ The order of the lengths is the same as the order of the trajectory IDs.
324
+ """
325
+ return self._trajectory_lengths
326
+
327
+ @property
328
+ def all_steps(self) -> list[tuple[int, int]]:
329
+ """The trajectory IDs and base indices for all steps in the dataset.
330
+ Example:
331
+ self.trajectory_ids: [0, 1, 2]
332
+ self.trajectory_lengths: [3, 2, 4]
333
+ return: [
334
+ ("traj_0", 0), ("traj_0", 1), ("traj_0", 2),
335
+ ("traj_1", 0), ("traj_1", 1),
336
+ ("traj_2", 0), ("traj_2", 1), ("traj_2", 2), ("traj_2", 3)
337
+ ]
338
+ """
339
+ return self._all_steps
340
+
341
+ @property
342
+ def modality_keys(self) -> dict:
343
+ """The modality keys for the dataset. The keys are the modality names, and the values are the keys for each modality.
344
+
345
+ Example: {
346
+ "video": ["video.image_side_0", "video.image_side_1"],
347
+ "state": ["state.eef_position", "state.eef_rotation"],
348
+ "action": ["action.eef_position", "action.eef_rotation"],
349
+ "language": ["language.human.task"],
350
+ "timestamp": ["timestamp"],
351
+ "reward": ["reward"],
352
+ }
353
+ """
354
+ return self._modality_keys
355
+
356
+ @property
357
+ def delta_indices(self) -> dict[str, np.ndarray]:
358
+ """The delta indices for the dataset. The keys are the modality.key, and the values are the delta indices for each modality.key."""
359
+ return self._delta_indices
360
+
361
+ @property
362
+ def dataset_name(self) -> str:
363
+ """The name of the dataset."""
364
+ return self._dataset_name
365
+
366
+ @property
367
+ def lerobot_modality_meta(self) -> LeRobotModalityMetadata:
368
+ """The metadata for the LeRobot dataset."""
369
+ return self._lerobot_modality_meta
370
+
371
+ @property
372
+ def lerobot_info_meta(self) -> dict:
373
+ """The metadata for the LeRobot dataset."""
374
+ return self._lerobot_info_meta
375
+
376
+ @property
377
+ def data_path_pattern(self) -> str:
378
+ """The path pattern for the LeRobot dataset."""
379
+ return self._data_path_pattern
380
+
381
+ @property
382
+ def video_path_pattern(self) -> str:
383
+ """The path pattern for the LeRobot dataset."""
384
+ return self._video_path_pattern
385
+
386
+ @property
387
+ def chunk_size(self) -> int:
388
+ """The chunk size for the LeRobot dataset."""
389
+ return self._chunk_size
390
+
391
+ @property
392
+ def tasks(self) -> pd.DataFrame:
393
+ """The tasks for the dataset."""
394
+ return self._tasks
395
+
396
+ def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata:
397
+ """Get the metadata for the dataset.
398
+
399
+ Returns:
400
+ DatasetMetadata: The metadata for the dataset.
401
+ """
402
+
403
+ # 1. Modality metadata
404
+ modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
405
+ assert (
406
+ modality_meta_path.exists()
407
+ ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
408
+ # 1.1. State and action modalities
409
+ simplified_modality_meta: dict[str, dict] = {}
410
+ with open(modality_meta_path, "r") as f:
411
+ le_modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))
412
+ for modality in ["state", "action"]:
413
+ simplified_modality_meta[modality] = {}
414
+ le_state_action_meta: dict[str, LeRobotStateActionMetadata] = getattr(
415
+ le_modality_meta, modality
416
+ )
417
+ for subkey in le_state_action_meta:
418
+ state_action_dtype = np.dtype(le_state_action_meta[subkey].dtype)
419
+ if np.issubdtype(state_action_dtype, np.floating):
420
+ continuous = True
421
+ else:
422
+ continuous = False
423
+ simplified_modality_meta[modality][subkey] = {
424
+ "absolute": le_state_action_meta[subkey].absolute,
425
+ "rotation_type": le_state_action_meta[subkey].rotation_type,
426
+ "shape": [
427
+ le_state_action_meta[subkey].end - le_state_action_meta[subkey].start
428
+ ],
429
+ "continuous": continuous,
430
+ }
431
+
432
+ # 1.2. Video modalities
433
+ le_info_path = self.dataset_path / LE_ROBOT_INFO_FILENAME
434
+ assert (
435
+ le_info_path.exists()
436
+ ), f"Please provide a {LE_ROBOT_INFO_FILENAME} file in {self.dataset_path}"
437
+ with open(le_info_path, "r") as f:
438
+ le_info = json.load(f)
439
+ simplified_modality_meta["video"] = {}
440
+ for new_key in le_modality_meta.video:
441
+ original_key = le_modality_meta.video[new_key].original_key
442
+ if original_key is None:
443
+ original_key = new_key
444
+ le_video_meta = le_info["features"][original_key]
445
+ height = le_video_meta["shape"][le_video_meta["names"].index("height")]
446
+ width = le_video_meta["shape"][le_video_meta["names"].index("width")]
447
+ # NOTE(FH): different lerobot dataset versions have different keys for the number of channels and fps
448
+ try:
449
+ channels = le_video_meta["shape"][le_video_meta["names"].index("channel")]
450
+ fps = le_video_meta["video_info"]["video.fps"]
451
+ except (ValueError, KeyError):
452
+ # channels = le_video_meta["shape"][le_video_meta["names"].index("channels")]
453
+ channels = le_video_meta["info"]["video.channels"]
454
+ fps = le_video_meta["info"]["video.fps"]
455
+ simplified_modality_meta["video"][new_key] = {
456
+ "resolution": [width, height],
457
+ "channels": channels,
458
+ "fps": fps,
459
+ }
460
+
461
+ # 2. Dataset statistics
462
+ stats_path = self.dataset_path / LE_ROBOT_STATS_FILENAME
463
+ try:
464
+ with open(stats_path, "r") as f:
465
+ le_statistics = json.load(f)
466
+ for stat in le_statistics.values():
467
+ DatasetStatisticalValues.model_validate(stat)
468
+ except (FileNotFoundError, ValidationError) as e:
469
+ print(f"Failed to load dataset statistics: {e}")
470
+ print(f"Calculating dataset statistics for {self.dataset_name}")
471
+ # Get all parquet files in the dataset paths
472
+ parquet_files = list((self.dataset_path).glob(LE_ROBOT_DATA_FILENAME))
473
+ parquet_files_filtered = []
474
+ # Skip "episode_033675.parquet": this file is known to be corrupted
475
+ for pf in parquet_files:
476
+ if "episode_033675.parquet" in pf.name:
477
+ continue
478
+ parquet_files_filtered.append(pf)
479
+
480
+ le_statistics = calculate_dataset_statistics(parquet_files_filtered)
481
+ with open(stats_path, "w") as f:
482
+ json.dump(le_statistics, f, indent=4)
483
+ dataset_statistics = {}
484
+ for our_modality in ["state", "action"]:
485
+ dataset_statistics[our_modality] = {}
486
+ for subkey in simplified_modality_meta[our_modality]:
487
+ dataset_statistics[our_modality][subkey] = {}
488
+ state_action_meta = le_modality_meta.get_key_meta(f"{our_modality}.{subkey}")
489
+ assert isinstance(state_action_meta, LeRobotStateActionMetadata)
490
+ le_modality = state_action_meta.original_key
491
+ for stat_name in le_statistics[le_modality]:
492
+ indices = np.arange(
493
+ state_action_meta.start,
494
+ state_action_meta.end,
495
+ )
496
+ stat = np.array(le_statistics[le_modality][stat_name])
497
+ dataset_statistics[our_modality][subkey][stat_name] = stat[indices].tolist()
498
+
499
+ # 3. Full dataset metadata
500
+ metadata = DatasetMetadata(
501
+ statistics=dataset_statistics, # type: ignore
502
+ modalities=simplified_modality_meta, # type: ignore
503
+ embodiment_tag=embodiment_tag,
504
+ )
505
+
506
+ return metadata
507
+
508
+ def _get_trajectories(self) -> tuple[np.ndarray, np.ndarray]:
509
+ """Get the trajectories in the dataset."""
510
+ # Get trajectory lengths, IDs, and whitelist from dataset metadata
511
+ # v2.0
512
+ if self._lerobot_version == "v2.0":
513
+ file_path = self.dataset_path / LE_ROBOT_EPISODE_FILENAME
514
+ with open(file_path, "r") as f:
515
+ episode_metadata = [json.loads(line) for line in f]
516
+ trajectory_ids = []
517
+ trajectory_lengths = []
518
+ for episode in episode_metadata:
519
+ trajectory_ids.append(episode["episode_index"])
520
+ trajectory_lengths.append(episode["length"])
521
+ return np.array(trajectory_ids), np.array(trajectory_lengths)
522
+ # v3.0
523
+ elif self._lerobot_version == "v3.0":
524
+ file_paths = list((self.dataset_path).glob(LE_ROBOT3_EPISODE_FILENAME))
525
+ trajectory_ids = []
526
+ trajectory_lengths = []
530
+ self.trajectory_ids_to_metadata = {}
531
+ for file_path in file_paths:
532
+ episodes_data = pd.read_parquet(file_path)
533
+ for index, episode in episodes_data.iterrows():
534
+ trajectory_ids.append(episode["episode_index"])
535
+ trajectory_lengths.append(episode["length"])
536
+
537
+ # TODO: map keys automatically; for now only file_path and file_from_index are mapped
538
+ episode_meta = {
539
+ "data/chunk_index": episode["data/chunk_index"],
540
+ "data/file_index": episode["data/file_index"],
541
+ "data/file_from_index": index,
542
+ }
543
+ if self.load_video:
544
+ episode_meta["videos/observation.images.wrist/from_timestamp"] = episode[
545
+ "videos/observation.images.wrist/from_timestamp"
546
+ ]
547
+ self.trajectory_ids_to_metadata[trajectory_ids[-1]] = episode_meta
548
+
549
+ # The saved index info should be directly readable here.
550
+ return np.array(trajectory_ids), np.array(trajectory_lengths)
551
+
552
+ def _get_all_steps(self) -> list[tuple[int, int]]:
553
+ """Get the trajectory IDs and base indices for all steps in the dataset.
554
+
555
+ Returns:
556
+ list[tuple[str, int]]: A list of (trajectory_id, base_index) tuples.
557
+ """
558
+ # Create a hash key based on configuration to ensure cache validity
559
+ config_key = self._get_steps_config_key()
560
+
561
+ # Use a fixed filename rather than one derived from config_key (e.g. f"steps_{config_key}.pkl"):
+ # a config hash would invalidate the cache on every config change, and the step index is meant to be static.
+ steps_filename = "steps_data_index.pkl"
566
+
567
+
568
+ steps_path = self.dataset_path / "meta" / steps_filename
569
+
570
+ # Try to load cached steps first
571
+ try:
572
+ if steps_path.exists():
573
+ with open(steps_path, "rb") as f:
574
+ cached_data = pickle.load(f)
575
+ return cached_data["steps"]
576
+
577
+ except (FileNotFoundError, pickle.PickleError, KeyError) as e:
578
+ print(f"Failed to load cached steps: {e}")
579
+ print("Computing steps from scratch...")
580
+
581
+ # Compute steps using single process
582
+ all_steps = self._get_all_steps_single_process()
583
+
584
+ # Cache the computed steps with unique filename
585
+ try:
586
+ cache_data = {
587
+ "config_key": config_key,
588
+ "steps": all_steps,
589
+ "num_trajectories": len(self.trajectory_ids),
590
+ "total_steps": len(all_steps),
591
+ "computed_timestamp": pd.Timestamp.now().isoformat(),
592
+ "delete_pause_frame": self.delete_pause_frame,
593
+ }
594
+
595
+ # Ensure the meta directory exists
596
+ steps_path.parent.mkdir(parents=True, exist_ok=True)
597
+
598
+ with open(steps_path, "wb") as f:
599
+ pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
600
+ print(f"Cached steps saved to {steps_path}")
601
+ except Exception as e:
602
+ print(f"Failed to cache steps: {e}")
603
+
604
+ return all_steps
605
+
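+ # Sketch of the cached pickle written above (meta/steps_data_index.pkl); the field
+ # values are illustrative:
+ #
+ #     {"config_key": "a1b2c3d4e5f6", "steps": [(0, 0), (0, 1), ...],
+ #      "num_trajectories": 100, "total_steps": 25000,
+ #      "computed_timestamp": "...", "delete_pause_frame": False}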
606
+ def _get_steps_config_key(self) -> str:
607
+ """Generate a configuration key for steps caching."""
608
+ config_dict = {
609
+ "delete_pause_frame": self.delete_pause_frame,
610
+ "dataset_name": self.dataset_name,
611
+ }
612
+ # Create a hash of the configuration
613
+ config_str = str(sorted(config_dict.items()))
614
+ return hashlib.md5(config_str.encode()).hexdigest()[:12]
615
+
616
+
617
+ def _get_all_steps_single_process(self) -> list[tuple[int, int]]:
618
+ """Original single-process implementation as fallback."""
619
+ all_steps: list[tuple[int, int]] = []
620
+ skipped_trajectories = 0
621
+ processed_trajectories = 0
622
+
623
+ # Check if language modality is configured
624
+ has_language_modality = 'language' in self.modality_keys and len(self.modality_keys['language']) > 0
625
+ # TODO why trajectory_length here, why not use data length?
626
+ for trajectory_id, trajectory_length in tqdm(zip(self.trajectory_ids, self.trajectory_lengths), total=len(self.trajectory_ids), desc="Collecting all steps"):
627
+ try:
628
+ if self._lerobot_version == "v2.0":
629
+ data = self.get_trajectory_data(trajectory_id)
630
+ elif self._lerobot_version == "v3.0":
631
+ data = self.get_trajectory_data_lerobot_v3(trajectory_id)
632
+
633
+ trajectory_skipped = False
634
+
635
+ # Check if trajectory has valid language instruction (if language modality is configured)
636
+ if has_language_modality:
637
+ self.curr_traj_data = data # Set current trajectory data for get_language to work
638
+
639
+ language_instruction = self.get_language(trajectory_id, self.modality_keys['language'][0], 0)
640
+ if not language_instruction or language_instruction[0] == "":
641
+ print(f"Skipping trajectory {trajectory_id} due to empty language instruction")
642
+ skipped_trajectories += 1
643
+ trajectory_skipped = True
644
+ continue
645
+
646
+ except Exception as e:
647
+ print(f"Skipping trajectory {trajectory_id} due to read error: {e}")
648
+ skipped_trajectories += 1
649
+ trajectory_skipped = True
650
+ continue
651
+
652
+ if not trajectory_skipped:
653
+ processed_trajectories += 1
654
+
655
+ for base_index in range(trajectory_length):
656
+ all_steps.append((trajectory_id, base_index))
657
+
658
+ # Print summary statistics
659
+ print(f"Single-process summary: Processed {processed_trajectories} trajectories, skipped {skipped_trajectories} empty trajectories")
660
+ print(f"Total steps: {len(all_steps)} from {len(self.trajectory_ids)} trajectories")
661
+
662
+ return all_steps
663
+
664
+ def _get_position_and_gripper_values(self, data: pd.DataFrame) -> tuple[list, list]:
665
+ """Get position and gripper values based on available columns in the dataset."""
666
+ # Get action keys from modality_keys
667
+ action_keys = self.modality_keys.get('action', [])
668
+
669
+ # Extract position data
670
+ delta_position_values = None
671
+ position_candidates = ['delta_eef_position']
672
+ coordinate_candidates = ['x', 'y', 'z']
673
+
674
+ # First try combined position fields
675
+ for pos_key in position_candidates:
676
+ full_key = f"action.{pos_key}"
677
+ if full_key in action_keys:
678
+ try:
679
+ # Get the lerobot key for this modality
680
+ le_action_cfg = self.lerobot_modality_meta.action
681
+ subkey = pos_key
682
+ if subkey in le_action_cfg:
683
+ le_key = le_action_cfg[subkey].original_key or subkey
684
+ if le_key in data.columns:
685
+ data_array = np.stack(data[le_key])
686
+ le_indices = np.arange(le_action_cfg[subkey].start, le_action_cfg[subkey].end)
687
+ filtered_data = data_array[:, le_indices]
688
+ delta_position_values = filtered_data.tolist()
689
+ break
690
+ except Exception:
691
+ continue
692
+
693
+ # If combined fields not found, try individual x,y,z coordinates
694
+ if delta_position_values is None:
695
+ x_data, y_data, z_data = None, None, None
696
+ for coord in coordinate_candidates:
697
+ full_key = f"action.{coord}"
698
+ if full_key in action_keys:
699
+ try:
700
+ le_action_cfg = self.lerobot_modality_meta.action
701
+ if coord in le_action_cfg:
702
+ le_key = le_action_cfg[coord].original_key or coord
703
+ if le_key in data.columns:
704
+ data_array = np.stack(data[le_key])
705
+ le_indices = np.arange(le_action_cfg[coord].start, le_action_cfg[coord].end)
706
+ coord_data = data_array[:, le_indices].flatten()
707
+ if coord == 'x':
708
+ x_data = coord_data
709
+ elif coord == 'y':
710
+ y_data = coord_data
711
+ elif coord == 'z':
712
+ z_data = coord_data
713
+ except Exception:
714
+ continue
715
+
716
+ if x_data is not None and y_data is not None and z_data is not None:
717
+ delta_position_values = np.column_stack((x_data, y_data, z_data)).tolist()
718
+
719
+ if delta_position_values is None:
720
+ # Fallback to the old hardcoded approach if metadata approach fails
721
+ if 'action.delta_eef_position' in data.columns:
722
+ delta_position_values = data['action.delta_eef_position'].to_numpy().tolist()
723
+ elif all(col in data.columns for col in ['action.x', 'action.y', 'action.z']):
724
+ x_vals = data['action.x'].to_numpy()
725
+ y_vals = data['action.y'].to_numpy()
726
+ z_vals = data['action.z'].to_numpy()
727
+ delta_position_values = np.column_stack((x_vals, y_vals, z_vals)).tolist()
728
+ else:
729
+ raise ValueError(f"No suitable position columns found. Available columns: {data.columns.tolist()}")
730
+
731
+ # Extract gripper data
732
+ gripper_values = None
733
+ gripper_candidates = ['gripper_close', 'gripper']
734
+
735
+ for grip_key in gripper_candidates:
736
+ full_key = f"action.{grip_key}"
737
+ if full_key in action_keys:
738
+ try:
739
+ le_action_cfg = self.lerobot_modality_meta.action
740
+ if grip_key in le_action_cfg:
741
+ le_key = le_action_cfg[grip_key].original_key or grip_key
742
+ if le_key in data.columns:
743
+ data_array = np.stack(data[le_key])
744
+ le_indices = np.arange(le_action_cfg[grip_key].start, le_action_cfg[grip_key].end)
745
+ gripper_data = data_array[:, le_indices].flatten()
746
+ gripper_values = gripper_data.tolist()
747
+ break
748
+ except Exception:
749
+ continue
750
+
751
+ if gripper_values is None:
752
+ # Fallback to the old hardcoded approach if metadata approach fails
753
+ if 'action.gripper_close' in data.columns:
754
+ gripper_values = data['action.gripper_close'].to_numpy().tolist()
755
+ elif 'action.gripper' in data.columns:
756
+ gripper_values = data['action.gripper'].to_numpy().tolist()
757
+ else:
758
+ raise ValueError(f"No suitable gripper columns found. Available columns: {data.columns.tolist()}")
759
+
760
+ return delta_position_values, gripper_values
761
+
762
+ def _get_modality_keys(self) -> dict:
763
+ """Get the modality keys for the dataset.
764
+ The keys are the modality names, and the values are the keys for each modality.
765
+ See property `modality_keys` for the expected format.
766
+ """
767
+ modality_keys = defaultdict(list)
768
+ for modality, config in self.modality_configs.items():
769
+ modality_keys[modality] = config.modality_keys
770
+ return modality_keys
771
+
772
+ def _get_delta_indices(self) -> dict[str, np.ndarray]:
773
+ """Restructure the delta indices to use modality.key as keys instead of just the modalities."""
774
+ delta_indices: dict[str, np.ndarray] = {}
775
+ for config in self.modality_configs.values():
776
+ for key in config.modality_keys:
777
+ delta_indices[key] = np.array(config.delta_indices)
778
+ return delta_indices
779
+
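+ # For example (hypothetical configs): a video config with delta_indices=[0] and an
+ # action config with delta_indices=[0, 1, 2] yield
+ #     {"video.image_side_0": array([0]), "action.eef_position": array([0, 1, 2])}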
780
+ def _get_lerobot_modality_meta(self) -> LeRobotModalityMetadata:
781
+ """Get the metadata for the LeRobot dataset."""
782
+ modality_meta_path = self.dataset_path / LE_ROBOT_MODALITY_FILENAME
783
+ assert (
784
+ modality_meta_path.exists()
785
+ ), f"Please provide a {LE_ROBOT_MODALITY_FILENAME} file in {self.dataset_path}"
786
+ with open(modality_meta_path, "r") as f:
787
+ modality_meta = LeRobotModalityMetadata.model_validate(json.load(f))
788
+ return modality_meta
789
+
790
+ def _get_lerobot_info_meta(self) -> dict:
791
+ """Get the metadata for the LeRobot dataset."""
792
+ info_meta_path = self.dataset_path / LE_ROBOT_INFO_FILENAME
793
+ with open(info_meta_path, "r") as f:
794
+ info_meta = json.load(f)
795
+ return info_meta
796
+
797
+ def _get_data_path_pattern(self) -> str:
798
+ """Get the data path pattern for the LeRobot dataset."""
799
+ return self.lerobot_info_meta["data_path"]
800
+
801
+ def _get_video_path_pattern(self) -> str:
802
+ """Get the video path pattern for the LeRobot dataset."""
803
+ return self.lerobot_info_meta["video_path"]
804
+
805
+ def _get_chunk_size(self) -> int:
806
+ """Get the chunk size for the LeRobot dataset."""
807
+ return self.lerobot_info_meta["chunks_size"]
808
+
809
+ def _get_tasks(self) -> pd.DataFrame:
810
+ """Get the tasks for the dataset."""
811
+ if self._lerobot_version == "v2.0":
812
+ tasks_path = self.dataset_path / LE_ROBOT_TASKS_FILENAME
813
+ with open(tasks_path, "r") as f:
814
+ tasks = [json.loads(line) for line in f]
815
+ df = pd.DataFrame(tasks)
816
+ return df.set_index("task_index")
817
+
818
+ elif self._lerobot_version == "v3.0":
819
+ tasks_path = self.dataset_path / LE_ROBOT3_TASKS_FILENAME
820
+ df = pd.read_parquet(tasks_path)
821
+ df = df.reset_index() # move the index into a column (usually named 'index')
822
+ df = df.rename(columns={'index': 'task'}) # rename the 'index' column to 'task'
823
+ df = df[['task_index', 'task']] # reorder the columns
824
+ return df
825
+ def _check_integrity(self):
826
+ """Use the config to check if the keys are valid and detect silent data corruption."""
827
+ ERROR_MSG_HEADER = f"Error occurred in initializing dataset {self.dataset_name}:\n"
828
+
829
+ for modality_config in self.modality_configs.values():
830
+ for key in modality_config.modality_keys:
831
+ if key == "lapa_action" or key == "dream_actions":
832
+ continue # no need for any metadata for lapa actions because it comes normalized
833
+ # Check if the key is valid
834
+ try:
835
+ self.lerobot_modality_meta.get_key_meta(key)
836
+ except Exception as e:
837
+ raise ValueError(
838
+ ERROR_MSG_HEADER + f"Unable to find key {key} in modality metadata:\n{e}"
839
+ )
840
+
841
+ def set_transforms_metadata(self, metadata: DatasetMetadata):
842
+ """Set the metadata for the transforms. This is useful for transforms that need to know the metadata, such as the normalization values."""
843
+ self.transforms.set_metadata(metadata)
844
+
845
+ def set_epoch(self, epoch: int):
846
+ """Set the epoch for the dataset.
847
+
848
+ Args:
849
+ epoch (int): The epoch to set.
850
+ """
851
+ self.epoch = epoch
852
+
853
+ def __len__(self) -> int:
854
+ """Get the total number of data points in the dataset.
855
+
856
+ Returns:
857
+ int: the total number of data points in the dataset.
858
+ """
859
+ return len(self.all_steps)
860
+
861
+ def __str__(self) -> str:
862
+ """Get the description of the dataset."""
863
+ return f"{self.dataset_name} ({len(self)} steps)"
864
+
865
+
866
+ def __getitem__(self, index: int) -> dict:
867
+ """Get the data for a single step in a trajectory.
868
+
869
+ Args:
870
+ index (int): The index of the step to get.
871
+
872
+ Returns:
873
+ dict: The data for the step.
874
+ """
875
+ trajectory_id, base_index = self.all_steps[index]
876
+ data = self.get_step_data(trajectory_id, base_index)
877
+
878
+ # Process all video keys dynamically
879
+ images = []
880
+ for video_key in self.modality_keys.get("video", []):
881
+ image = data[video_key][0]
882
+
883
+ image = Image.fromarray(image).resize((224, 224))
884
+ images.append(image)
885
+
886
+ # Get language and action data
887
+ language = data[self.modality_keys["language"][0]][0]
888
+ action = []
889
+ for action_key in self.modality_keys["action"]:
890
+ action.append(data[action_key])
891
+ action = np.concatenate(action, axis=1)
892
+ action = standardize_action_representation(action, self.tag)
893
+
894
+ state = []
895
+ for state_key in self.modality_keys["state"]:
896
+ state.append(data[state_key])
897
+ state = np.concatenate(state, axis=1)
898
+ state = standardize_state_representation(state, self.tag)
899
+
900
+ return dict(action=action, state=state, image=images, language=language, dataset_id=self._dataset_id)
901
+
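+ # Each sample is a flat dict; under the unified layout above it looks roughly like:
+ #     {"action": (T_a, 37) array, "state": (T_s, 74) array,
+ #      "image": [224x224 PIL images], "language": str, "dataset_id": int or None}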
902
+ def get_step_data(self, trajectory_id: int, base_index: int) -> dict:
903
+ """Get the RAW data for a single step in a trajectory. No transforms are applied.
904
+
905
+ Args:
906
+ trajectory_id (int): The ID of the trajectory.
907
+ base_index (int): The base step index in the trajectory.
908
+
909
+ Returns:
910
+ dict: The RAW data for the step.
911
+
912
+ Example return (keys are flat "modality.subkey" strings):
+ {
+ "video.image_side_0": [T, H, W, C],
+ "video.image_side_1": [T, H, W, C],
+ "state.eef_position": [T, state_dim],
+ "state.eef_rotation": [T, state_dim],
+ "action.eef_position": [T, action_dim],
+ "action.eef_rotation": [T, action_dim],
+ }
927
+ """
928
+ data = {}
929
+ # Load the trajectory data once; the per-key reads below are based on it
930
+ self.curr_traj_data = self.get_trajectory_data(trajectory_id)
931
+ # TODO @JinhuiYE The logic below is poorly implemented. Data reading should be directly based on curr_traj_data.
932
+ for modality in self.modality_keys:
933
+ # Get the data corresponding to each key in the modality
934
+ for key in self.modality_keys[modality]:
935
+ data[key] = self.get_data_by_modality(trajectory_id, modality, key, base_index)
936
+ return data
937
+
938
+ def get_trajectory_data(self, trajectory_id: int) -> pd.DataFrame:
939
+ """Get the data for a trajectory."""
940
+ if self._lerobot_version == "v2.0":
941
+
942
+ if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
943
+ return self.curr_traj_data
944
+ else:
945
+ chunk_index = self.get_episode_chunk(trajectory_id)
946
+ parquet_path = self.dataset_path / self.data_path_pattern.format(
947
+ episode_chunk=chunk_index, episode_index=trajectory_id
948
+ )
949
+ assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
950
+ return pd.read_parquet(parquet_path)
951
+ elif self._lerobot_version == "v3.0":
952
+ return self.get_trajectory_data_lerobot_v3(trajectory_id)
953
+
954
+ def get_trajectory_data_lerobot_v3(self, trajectory_id: int) -> pd.DataFrame:
955
+ """Get the data for a trajectory from lerobot v3."""
956
+ if self.curr_traj_id == trajectory_id and self.curr_traj_data is not None:
957
+ return self.curr_traj_data
958
+ else: #TODO check detail later
959
+ chunk_index = self.get_episode_chunk(trajectory_id)
960
+
961
+ file_index = self.get_episode_file_index(trajectory_id)
962
+ # file_from_index = self.get_episode_file_from_index(trajectory_id)
963
+
964
+
965
+ parquet_path = self.dataset_path / self.data_path_pattern.format(
966
+ chunk_index=chunk_index, file_index=file_index
967
+ )
968
+ assert parquet_path.exists(), f"Parquet file not found at {parquet_path}"
969
+ file_data = pd.read_parquet(parquet_path)
970
+
971
+ # filter by trajectory_id
972
+ episode_data = file_data.loc[file_data["episode_index"] == trajectory_id].copy()
973
+
974
+ # fix timestamp from epis index to file index for video alignment
975
+ if self.load_video:
976
+ from_timestamp = self.trajectory_ids_to_metadata[trajectory_id].get(
977
+ "videos/observation.images.wrist/from_timestamp", 0
978
+ )
979
+ episode_data["timestamp"] = episode_data["timestamp"] + from_timestamp
980
+
981
+ return episode_data
982
+
983
+
984
+ def get_trajectory_index(self, trajectory_id: int) -> int:
985
+ """Get the index of the trajectory in the dataset by the trajectory ID.
986
+ This is useful when you need to get the trajectory length or sampling weight corresponding to the trajectory ID.
987
+
988
+ Args:
989
+ trajectory_id (str): The ID of the trajectory.
990
+
991
+ Returns:
992
+ int: The index of the trajectory in the dataset.
993
+ """
994
+ trajectory_indices = np.where(self.trajectory_ids == trajectory_id)[0]
995
+ if len(trajectory_indices) != 1:
996
+ raise ValueError(
997
+ f"Error finding trajectory index for {trajectory_id}, found {trajectory_indices=}"
998
+ )
999
+ return trajectory_indices[0]
1000
+
1001
+ def get_episode_chunk(self, ep_index: int) -> int:
1002
+ """Get the chunk index for an episode index."""
1003
+ return ep_index // self.chunk_size
1004
+ def get_episode_file_index(self, ep_index: int) -> int:
1005
+ """Get the file index for an episode index."""
1006
+ episode_meta = self.trajectory_ids_to_metadata[ep_index]
1007
+ return episode_meta["data/file_index"]
1008
+
1009
+ def get_episode_file_from_index(self, ep_index: int) -> int:
1010
+ """Get the file from index for an episode index."""
1011
+ episode_meta = self.trajectory_ids_to_metadata[ep_index]
1012
+ return episode_meta["data/file_from_index"]
1013
+
1014
+
1015
+ def retrieve_data_and_pad(
1016
+ self,
1017
+ array: np.ndarray,
1018
+ step_indices: np.ndarray,
1019
+ max_length: int,
1020
+ padding_strategy: str = "first_last",
1021
+ ) -> np.ndarray:
1022
+ """Retrieve the data from the dataset and pad it if necessary.
1023
+ Args:
1024
+ array (np.ndarray): The array to retrieve the data from.
1025
+ step_indices (np.ndarray): The step indices to retrieve the data for.
1026
+ max_length (int): The maximum length of the data.
1027
+ padding_strategy (str): The padding strategy, either "first" or "last".
1028
+ """
1029
+ # Get the padding indices
1030
+ front_padding_indices = step_indices < 0
1031
+ end_padding_indices = step_indices >= max_length
1032
+ padding_positions = np.logical_or(front_padding_indices, end_padding_indices)
1033
+ # Retrieve the data with the non-padding indices
1034
+ # If there exists some padding, Given T step_indices, the shape of the retrieved data will be (T', ...) where T' < T
1035
+ raw_data = array[step_indices[~padding_positions]]
1036
+ assert isinstance(raw_data, np.ndarray), f"{type(raw_data)=}"
1037
+ # This is the shape of the output, (T, ...)
1038
+ if raw_data.ndim == 1:
1039
+ expected_shape = (len(step_indices),)
1040
+ else:
1041
+ expected_shape = (len(step_indices), *array.shape[1:])
1042
+
1043
+ # Pad the data
1044
+ output = np.zeros(expected_shape)
1045
+ # Assign the non-padded data
1046
+ output[~padding_positions] = raw_data
1047
+ # If there exists some padding, pad the data
1048
+ if padding_positions.any():
1049
+ if padding_strategy == "first_last":
1050
+ # Use first / last step data to pad
1051
+ front_padding_data = array[0]
1052
+ end_padding_data = array[-1]
1053
+ output[front_padding_indices] = front_padding_data
1054
+ output[end_padding_indices] = end_padding_data
1055
+ elif padding_strategy == "zero":
1056
+ # Use zero padding
1057
+ output[padding_positions] = 0
1058
+ else:
1059
+ raise ValueError(f"Invalid padding strategy: {padding_strategy}")
1060
+ return output
1061
+
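+ # Padding sketch: with step_indices [-1, 0, 1] and max_length 3, index -1 is out of
+ # range; "first_last" repeats the boundary row while "zero" writes zeros instead.
+ # (`dataset` below is a hypothetical LeRobotSingleDataset instance.)
+ #
+ #     arr = np.arange(6).reshape(3, 2)
+ #     out = dataset.retrieve_data_and_pad(arr, np.array([-1, 0, 1]), max_length=3)
+ #     # "first_last" -> rows [arr[0], arr[0], arr[1]]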
1062
+ def get_video_path(self, trajectory_id: int, key: str) -> Path:
1063
+ chunk_index = self.get_episode_chunk(trajectory_id)
1064
+ original_key = self.lerobot_modality_meta.video[key].original_key
1065
+ if original_key is None:
1066
+ original_key = key
1067
+ if self._lerobot_version == "v2.0":
1068
+ video_filename = self.video_path_pattern.format(
1069
+ episode_chunk=chunk_index, episode_index=trajectory_id, video_key=original_key
1070
+ )
1071
+ elif self._lerobot_version == "v3.0":
1072
+ episode_meta = self.trajectory_ids_to_metadata[trajectory_id]
1073
+ video_filename = self.video_path_pattern.format(
1074
+ video_key=original_key,
1075
+ chunk_index=episode_meta["data/chunk_index"],
1076
+ file_index=episode_meta["data/file_index"],
1077
+ )
1078
+ return self.dataset_path / video_filename
1079
+
1080
+ def get_video(
1081
+ self,
1082
+ trajectory_id: int,
1083
+ key: str,
1084
+ base_index: int,
1085
+ ) -> np.ndarray:
1086
+ """Get the video frames for a trajectory by a base index.
1087
+
1088
+ Args:
1090
+ trajectory_id (str): The ID of the trajectory.
1091
+ key (str): The key of the video.
1092
+ base_index (int): The base index of the trajectory.
1093
+
1094
+ Returns:
1095
+ np.ndarray: The video frames for the trajectory and frame indices. Shape: (T, H, W, C)
1096
+ """
1097
+ # Get the step indices
1098
+ step_indices = self.delta_indices[key] + base_index
1100
+ # Get the trajectory index
1101
+ trajectory_index = self.get_trajectory_index(trajectory_id)
1102
+ # Ensure the indices are within the valid range
1103
+ # This is equivalent to padding the video with extra frames at the beginning and end
1104
+ step_indices = np.maximum(step_indices, 0)
1105
+ step_indices = np.minimum(step_indices, self.trajectory_lengths[trajectory_index] - 1)
1106
+ assert key.startswith("video."), f"Video key must start with 'video.', got {key}"
1107
+ # Get the sub-key
1108
+ key = key.replace("video.", "")
1109
+ video_path = self.get_video_path(trajectory_id, key)
1110
+ # Get the action/state timestamps for each frame in the video
1111
+ assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
1112
+ assert "timestamp" in self.curr_traj_data.columns, f"No timestamp found in {trajectory_id=}"
1113
+ timestamp: np.ndarray = self.curr_traj_data["timestamp"].to_numpy()
1114
+ # Get the corresponding video timestamps from the step indices
1115
+ video_timestamp = timestamp[step_indices]
1116
+
1117
+ return get_frames_by_timestamps(
1118
+ video_path.as_posix(),
1119
+ video_timestamp,
1120
+ video_backend=self.video_backend, # TODO
1121
+ video_backend_kwargs=self.video_backend_kwargs,
1122
+ )
1123
+
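# Sketch of the clamping above: delta_indices spread a window around base_index,
# and out-of-range steps are clamped to the trajectory boundaries, which is
# equivalent to repeating the first/last frame.
import numpy as np

delta_indices = np.array([-1, 0, 1])
base_index, traj_len = 0, 10
step_indices = np.clip(delta_indices + base_index, 0, traj_len - 1)
# step_indices -> [0, 0, 1]: the frame before t=0 falls back to frame 0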
1124
+ def get_state_or_action(
1125
+ self,
1126
+ trajectory_id: int,
1127
+ modality: str,
1128
+ key: str,
1129
+ base_index: int,
1130
+ ) -> np.ndarray:
1131
+ """Get the state or action data for a trajectory by a base index.
1132
+ If the step indices are out of range, pad with the data:
1133
+ if the data is stored in absolute format, pad with the first or last step data;
1134
+ otherwise, pad with zero.
1135
+
1136
+ Args:
1137
1138
+ trajectory_id (int): The ID of the trajectory.
1139
+ modality (str): The modality of the data.
1140
+ key (str): The key of the data.
1141
+ base_index (int): The base index of the trajectory.
1142
+
1143
+ Returns:
1144
+ np.ndarray: The data for the trajectory and step indices.
1145
+ """
1146
+ # Get the step indices
1147
+ step_indices = self.delta_indices[key] + base_index
1148
+ # Get the trajectory index
1149
+ trajectory_index = self.get_trajectory_index(trajectory_id)
1150
+ # Get the maximum length of the trajectory
1151
+ max_length = self.trajectory_lengths[trajectory_index]
1152
+ assert key.startswith(modality + "."), f"{key} must start with '{modality}.'"
1153
+ # Get the sub-key, e.g. state.joint_angles -> joint_angles
1154
+ key = key.replace(modality + ".", "")
1155
+ # Get the lerobot key
1156
+ le_state_or_action_cfg = getattr(self.lerobot_modality_meta, modality)
1157
+ le_key = le_state_or_action_cfg[key].original_key
1158
+ if le_key is None:
1159
+ le_key = key
1160
+ # Get the data array, shape: (T, D)
1161
+ assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
1162
+ assert le_key in self.curr_traj_data.columns, f"No {le_key} found in {trajectory_id=}"
1163
+ data_array: np.ndarray = np.stack(self.curr_traj_data[le_key]) # type: ignore
1164
+ assert data_array.ndim == 2, f"Expected 2D array for key {le_key}, got array of shape {data_array.shape}"
1165
+ le_indices = np.arange(
1166
+ le_state_or_action_cfg[key].start,
1167
+ le_state_or_action_cfg[key].end,
1168
+ )
1169
+ data_array = data_array[:, le_indices]
1170
+ # Get the state or action configuration
1171
+ state_or_action_cfg = getattr(self.metadata.modalities, modality)[key]
1172
+
1173
+ # Pad the data
1174
+ return self.retrieve_data_and_pad(
1175
+ array=data_array,
1176
+ step_indices=step_indices,
1177
+ max_length=max_length,
1178
+ padding_strategy="first_last" if state_or_action_cfg.absolute else "zero",
1179
+ # padding_strategy="zero", # HACK for realdata
1180
+ )
1181
+
1182
+ def get_language(
1183
+ self,
1184
+ trajectory_id: int,
1185
+ key: str,
1186
+ base_index: int,
1187
+ ) -> list[str]:
1188
+ """Get the language annotation data for a trajectory by step indices.
1189
+
1190
+ Args:
1191
1192
+ trajectory_id (int): The ID of the trajectory.
1193
+ key (str): The key of the annotation.
1194
+ base_index (int): The base index of the trajectory.
1195
+
1196
+ Returns:
1197
+ list[str]: The annotation data for the trajectory and step indices. If no matching data is found, return empty strings.
1198
+ """
1199
+ assert self.curr_traj_data is not None, f"No data found for {trajectory_id=}"
1200
+ # Get the step indices
1201
+ step_indices = self.delta_indices[key] + base_index
1202
+ # Get the trajectory index
1203
+ trajectory_index = self.get_trajectory_index(trajectory_id)
1204
+ # Get the maximum length of the trajectory
1205
+ max_length = self.trajectory_lengths[trajectory_index]
1206
+ # Get the end times corresponding to the closest indices
1207
+ step_indices = np.maximum(step_indices, 0)
1208
+ step_indices = np.minimum(step_indices, max_length - 1)
1209
+ # Get the annotations
1210
+ task_indices: list[int] = []
1211
+ assert key.startswith(
1212
+ "annotation."
1213
+ ), f"Language key must start with 'annotation.', got {key}"
1214
+ subkey = key.replace("annotation.", "")
1215
+ annotation_meta = self.lerobot_modality_meta.annotation
1216
+ assert annotation_meta is not None, f"Annotation metadata is None for {subkey}"
1217
+ assert (
1218
+ subkey in annotation_meta
1219
+ ), f"Annotation key {subkey} not found in metadata, available annotation keys: {annotation_meta.keys()}"
1220
+ subkey_meta = annotation_meta[subkey]
1221
+ original_key = subkey_meta.original_key
1222
+ if original_key is None:
1223
+ original_key = key
1224
+ for i in range(len(step_indices)):
1225
+ # task_indices.append(self.curr_traj_data[original_key][step_indices[i]].item())
1226
+ value = self.curr_traj_data[original_key].iloc[step_indices[i]] # TODO check v2.0
1227
+ task_indices.append(value if isinstance(value, (int, float)) else value.item())
1228
+
1229
+ return self.tasks.loc[task_indices]["task"].tolist()
1230
+
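# Sketch of the task lookup above, assuming the usual LeRobot convention of a
# tasks DataFrame indexed by task_index with a "task" column (values made up):
#   task_indices = [2, 2, 5]
#   self.tasks.loc[task_indices]["task"].tolist()
#   -> ["pick up the cup", "pick up the cup", "open the drawer"]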
1231
+ def get_data_by_modality(
1232
+ self,
1233
+ trajectory_id: int,
1234
+ modality: str,
1235
+ key: str,
1236
+ base_index: int,
1237
+ ):
1238
+ """Get the data corresponding to the modality for a trajectory by a base index.
1239
+ This method will call the corresponding helper method based on the modality.
1240
+ See the helper methods for more details.
1241
+ NOTE: For the language modality, the data is padded with empty strings if no matching data is found.
1242
+
1243
+ Args:
1244
1245
+ trajectory_id (int): The ID of the trajectory.
1246
+ modality (str): The modality of the data.
1247
+ key (str): The key of the data.
1248
+ base_index (int): The base index of the trajectory.
1249
+ """
1250
+ if modality == "video":
1251
+ return self.get_video(trajectory_id, key, base_index)
1252
+ elif modality == "state" or modality == "action":
1253
+ return self.get_state_or_action(trajectory_id, modality, key, base_index)
1254
+ elif modality == "language":
1255
+ return self.get_language(trajectory_id, key, base_index)
1256
+ else:
1257
+ raise ValueError(f"Invalid modality: {modality}")
1258
+
1259
+ def _save_dataset_statistics_(self, save_path: Path | str, format: str = "json") -> None:
1260
+ """
1261
+ Save dataset statistics to specified path in the required format.
1262
+ Only includes statistics for keys that are actually used in the dataset.
1263
+ Key order follows modality config order.
1264
+
1265
+ Args:
1266
+ save_path (Path | str): Path to save the statistics file
1267
+ format (str): Save format, currently only supports "json"
1268
+ """
1269
+ save_path = Path(save_path)
1270
+ save_path.parent.mkdir(parents=True, exist_ok=True)
1271
+
1272
+ # Build the data structure to save
1273
+ statistics_data = {}
1274
+
1275
+ # Get used modality keys
1276
+ used_action_keys, used_state_keys = get_used_modality_keys(self.modality_keys)
1277
+
1278
+ # Organize statistics by tag
1279
+ tag = self.tag
1280
+ tag_stats = {}
1281
+
1282
+ # Process action statistics (only for used keys, config order)
1283
+ if hasattr(self.metadata.statistics, 'action') and self.metadata.statistics.action:
1284
+ action_stats = self.metadata.statistics.action
1285
+ filtered_action_stats = {
1286
+ key: action_stats[key]
1287
+ for key in used_action_keys
1288
+ if key in action_stats
1289
+ }
1290
+
1291
+ if filtered_action_stats:
1292
+ # Combine statistics from filtered action sub-keys
1293
+ combined_action_stats = combine_modality_stats(filtered_action_stats)
1294
+
1295
+ # Add mask field based on whether it's gripper or not
1296
+ mask = generate_action_mask_for_used_keys(
1297
+ self.metadata.modalities.action, filtered_action_stats.keys()
1298
+ )
1299
+ combined_action_stats["mask"] = mask
1300
+
1301
+ tag_stats["action"] = combined_action_stats
1302
+
1303
+ # Process state statistics (only for used keys, config order)
1304
+ if hasattr(self.metadata.statistics, 'state') and self.metadata.statistics.state:
1305
+ state_stats = self.metadata.statistics.state
1306
+ filtered_state_stats = {
1307
+ key: state_stats[key]
1308
+ for key in used_state_keys
1309
+ if key in state_stats
1310
+ }
1311
+
1312
+ if filtered_state_stats:
1313
+ combined_state_stats = combine_modality_stats(filtered_state_stats)
1314
+ tag_stats["state"] = combined_state_stats
1315
+
1316
+ # Add dataset counts
1317
+ tag_stats["num_transitions"] = len(self)
1318
+ tag_stats["num_trajectories"] = len(self.trajectory_ids)
1319
+
1320
+ statistics_data[tag] = tag_stats
1321
+
1322
+ # Save as JSON file
1323
+ if format.lower() == "json":
1324
+ if not str(save_path).endswith('.json'):
1325
+ save_path = save_path.with_suffix('.json')
1326
+ with open(save_path, 'w', encoding='utf-8') as f:
1327
+ json.dump(statistics_data, f, indent=2, ensure_ascii=False)
1328
+ else:
1329
+ raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")
1330
+
1331
+ print(f"Single dataset statistics saved to: {save_path}")
1332
+ print(f"Used action keys (reordered): {list(used_action_keys)}")
1333
+ print(f"Used state keys (reordered): {list(used_state_keys)}")
1334
+
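# Rough shape of the JSON written above (values abbreviated): one top-level
# entry per embodiment tag, with per-dimension statistics concatenated in
# modality-config order and a boolean mask marking non-gripper dimensions:
# {
#   "new_embodiment": {
#     "action": {"mean": [...], "std": [...], "max": [...], "min": [...],
#                "q01": [...], "q99": [...], "mask": [true, ..., false]},
#     "state":  {"mean": [...], "std": [...], "max": [...], "min": [...],
#                "q01": [...], "q99": [...]},
#     "num_transitions": 12345,
#     "num_trajectories": 67
#   }
# }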
1335
+
1336
+
1337
+ class MixtureSpecElement(BaseModel):
1338
+ dataset_path: list[Path] | Path = Field(..., description="The path to the dataset.")
1339
+ dataset_weight: float = Field(..., description="The weight of the dataset in the mixture.")
1340
+ distribute_weights: bool = Field(
1341
+ default=False,
1342
+ description="Whether to distribute the weights of the dataset across all the paths. If True, the weights will be evenly distributed across all the paths.",
1343
+ )
1344
+
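# Example instantiation (paths are placeholders). Per the field description,
# distribute_weights=True means downstream code splits the 1.0 weight evenly
# across the two paths (0.5 each):
from pathlib import Path
spec = MixtureSpecElement(
    dataset_path=[Path("/data/ds_a"), Path("/data/ds_b")],
    dataset_weight=1.0,
    distribute_weights=True,
)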
1345
+
1346
+ # Helper functions for dataset statistics
1347
+
1348
+ def combine_modality_stats(modality_stats: dict) -> dict:
1349
+ """
1350
+ Combine statistics from all sub-keys under a modality.
1351
+
1352
+ Args:
1353
+ modality_stats (dict): Statistics for a modality, containing multiple sub-keys.
1354
+ Each sub-key contains DatasetStatisticalValues object.
1355
+
1356
+ Returns:
1357
+ dict: Combined statistics
1358
+ """
1359
+ combined_stats = {
1360
+ "mean": [],
1361
+ "std": [],
1362
+ "max": [],
1363
+ "min": [],
1364
+ "q01": [],
1365
+ "q99": []
1366
+ }
1367
+
1368
+ # Combine statistics in sub-key order
1369
+ for subkey in modality_stats.keys():
1370
+ subkey_stats = modality_stats[subkey] # This is a DatasetStatisticalValues object
1371
+
1372
+ # Convert DatasetStatisticalValues to dict-like access
1373
+ for stat_name in ["mean", "std", "max", "min", "q01", "q99"]:
1374
+ stat_value = getattr(subkey_stats, stat_name)
1375
+ if isinstance(stat_value, (list, tuple)):
1376
+ combined_stats[stat_name].extend(stat_value)
1377
+ else:
1378
+ # Handle NDArray case - convert to list
1379
+ if hasattr(stat_value, 'tolist'):
1380
+ combined_stats[stat_name].extend(stat_value.tolist())
1381
+ else:
1382
+ combined_stats[stat_name].append(float(stat_value))
1383
+
1384
+ return combined_stats
1385
+
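# Toy illustration of the concatenation above; SimpleNamespace stands in for
# the DatasetStatisticalValues objects (attribute access is all that is used):
from types import SimpleNamespace

stats = {
    "arm": SimpleNamespace(mean=[0.1, 0.2], std=[1.0, 1.0], max=[1.0, 1.0],
                           min=[0.0, 0.0], q01=[0.0, 0.0], q99=[0.9, 0.9]),
    "gripper": SimpleNamespace(mean=[0.5], std=[0.2], max=[1.0],
                               min=[0.0], q01=[0.0], q99=[1.0]),
}
# combine_modality_stats(stats)["mean"] -> [0.1, 0.2, 0.5] (key order preserved)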
1386
+ def generate_action_mask_for_used_keys(action_modalities: dict, used_action_keys_ordered) -> list[bool]:
1387
+ """
1388
+ Generate mask based on action modalities, but only for used keys.
1389
+ Gripper dimensions are set to False (excluded from de/normalization); all other dimensions are True.
1390
+
1391
+ Args:
1392
+ action_modalities (dict): Configuration information for action modalities.
1393
+ used_action_keys_ordered: Iterable of actually used action keys in the correct order.
1394
+
1395
+ Returns:
1396
+ list[bool]: List of mask values
1397
+ """
1398
+ mask = []
1399
+
1400
+ # Generate mask in the same order as the statistics were combined
1401
+ for subkey in used_action_keys_ordered:
1402
+ if subkey in action_modalities:
1403
+ subkey_config = action_modalities[subkey]
1404
+
1405
+ # Get dimension count from shape
1406
+ if hasattr(subkey_config, 'shape') and len(subkey_config.shape) > 0:
1407
+ dim_count = subkey_config.shape[0]
1408
+ else:
1409
+ dim_count = 1
1410
+
1411
+ # Check if it's gripper-related
1412
+ is_gripper = "gripper" in subkey.lower()
1413
+
1414
+ # Generate mask value for each dimension
1415
+ for _ in range(dim_count):
1416
+ mask.append(not is_gripper) # gripper is False, others are True
1417
+
1418
+ return mask
1419
+
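# Given the gripper rule above, a hypothetical action config with a 6-dof "eef"
# key (shape (6,)) and a 1-dof "gripper" key would yield:
#   eef     -> [True] * 6
#   gripper -> [False]
#   mask    -> [True, True, True, True, True, True, False]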
1420
+ def get_used_modality_keys(modality_keys: dict) -> tuple[list[str], list[str]]:
1421
+ """Extract used action and state keys from modality configuration."""
1422
+ used_action_keys = []
1423
+ used_state_keys = []
1424
+
1425
+ # Extract action keys (remove "action." prefix)
1426
+ for action_key in modality_keys.get("action", []):
1427
+ if action_key.startswith("action."):
1428
+ clean_key = action_key.replace("action.", "")
1429
+ used_action_keys.append(clean_key)
1430
+
1431
+ # Extract state keys (remove "state." prefix)
1432
+ for state_key in modality_keys.get("state", []):
1433
+ if state_key.startswith("state."):
1434
+ clean_key = state_key.replace("state.", "")
1435
+ used_state_keys.append(clean_key)
1436
+
1437
+ return used_action_keys, used_state_keys
1438
+
1439
+
1440
+ def safe_hash(input_tuple):
1441
+ # keep 128 bits of the hash
1442
+ tuple_string = repr(input_tuple).encode("utf-8")
1443
+ sha256 = hashlib.sha256()
1444
+ sha256.update(tuple_string)
1445
+
1446
+ seed = int(sha256.hexdigest(), 16)
1447
+
1448
+ return seed & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
1449
+
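# Usage sketch: unlike Python's per-process-randomized built-in hash(), the
# SHA-256 digest is stable across runs; the mask keeps 128 bits of it as the seed.
import numpy as np

seed = safe_hash((3, 17, 42))       # e.g. (epoch, index, base_seed)
rng = np.random.default_rng(seed)   # same tuple -> same sample stream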
1450
+
1451
+ class LeRobotMixtureDataset(Dataset):
1452
+ """
1453
+ A mixture of multiple datasets. This class samples a single dataset based on the dataset weights and then calls the `__getitem__` method of the sampled dataset.
1454
+ It is recommended to modify the single dataset class instead of this class.
1455
+ """
1456
+
1457
+ def __init__(
1458
+ self,
1459
+ data_mixture: Sequence[tuple[LeRobotSingleDataset, float]],
1460
+ mode: str,
1461
+ balance_dataset_weights: bool = True,
1462
+ balance_trajectory_weights: bool = True,
1463
+ seed: int = 42,
1464
+ metadata_config: dict = {
1465
+ "percentile_mixing_method": "min_max",
1466
+ },
1467
+ **kwargs,
1468
+ ):
1469
+ """
1470
+ Initialize the mixture dataset.
1471
+
1472
+ Args:
1473
+ data_mixture (list[tuple[LeRobotSingleDataset, float]]): Datasets and their corresponding weights.
1474
+ mode (str): If "train", __getitem__ will return different samples every epoch; if "val" or "test", __getitem__ will return the same sample every epoch.
1475
+ balance_dataset_weights (bool): If True, the weight of dataset will be multiplied by the total trajectory length of each dataset.
1476
+ balance_trajectory_weights (bool): If True, sample trajectories within a dataset weighted by their length; otherwise, use equal weighting.
1477
+ seed (int): Random seed for sampling.
1478
+ """
1479
+ datasets: list[LeRobotSingleDataset] = []
1480
+ dataset_sampling_weights: list[float] = []
1481
+ for dataset, weight in data_mixture:
1482
+ # Check if dataset is valid and has data
1483
+ if len(dataset) == 0:
1484
+ print(f"Warning: Skipping empty dataset {dataset.dataset_name}")
1485
+ continue
1486
+ datasets.append(dataset)
1487
+ dataset_sampling_weights.append(weight)
1488
+
1489
+ if len(datasets) == 0:
1490
+ raise ValueError("No valid datasets found in the mixture. All datasets are empty.")
1491
+
1492
+ self.datasets = datasets
1493
+ self.balance_dataset_weights = balance_dataset_weights
1494
+ self.balance_trajectory_weights = balance_trajectory_weights
1495
+ self.seed = seed
1496
+ self.mode = mode
1497
+
1498
+ # Set properties for sampling
1499
+
1500
+ # 1. Dataset lengths
1501
+ self._dataset_lengths = np.array([len(dataset) for dataset in self.datasets])
1502
+ print(f"Dataset lengths: {self._dataset_lengths}")
1503
+
1504
+ # 2. Dataset sampling weights
1505
+ self._dataset_sampling_weights = np.array(dataset_sampling_weights)
1506
+
1507
+ if self.balance_dataset_weights:
1508
+ self._dataset_sampling_weights *= self._dataset_lengths
1509
+
1510
+ # Check for zero or negative weights before normalization
1511
+ if np.any(self._dataset_sampling_weights <= 0):
1512
+ print(f"Warning: Found zero or negative sampling weights: {self._dataset_sampling_weights}")
1513
+ # Set minimum weight to prevent division issues
1514
+ self._dataset_sampling_weights = np.maximum(self._dataset_sampling_weights, 1e-8)
1515
+
1516
+ # Normalize weights
1517
+ weights_sum = self._dataset_sampling_weights.sum()
1518
+ if weights_sum == 0 or np.isnan(weights_sum):
1519
+ print(f"Error: Invalid weights sum: {weights_sum}")
1520
+ # Fallback to equal weights
1521
+ self._dataset_sampling_weights = np.ones(len(self.datasets)) / len(self.datasets)
1522
+ print(f"Fallback to equal weights")
1523
+ else:
1524
+ self._dataset_sampling_weights /= weights_sum
1525
+
1526
+ # 3. Trajectory sampling weights
1527
+ self._trajectory_sampling_weights: list[np.ndarray] = []
1528
+ for i, dataset in enumerate(self.datasets):
1529
+ trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths))
1530
+ if self.balance_trajectory_weights:
1531
+ trajectory_sampling_weights *= dataset.trajectory_lengths
1532
+
1533
+ # Check for zero or negative weights before normalization
1534
+ if np.any(trajectory_sampling_weights <= 0):
1535
+ print(f"Warning: Dataset {i} has zero or negative trajectory weights")
1536
+ trajectory_sampling_weights = np.maximum(trajectory_sampling_weights, 1e-8)
1537
+
1538
+ # Normalize weights
1539
+ weights_sum = trajectory_sampling_weights.sum()
1540
+ if weights_sum == 0 or np.isnan(weights_sum):
1541
+ print(f"Error: Dataset {i} has invalid trajectory weights sum: {weights_sum}")
1542
+ # Fallback to equal weights
1543
+ trajectory_sampling_weights = np.ones(len(dataset.trajectory_lengths)) / len(dataset.trajectory_lengths)
1544
+ else:
1545
+ trajectory_sampling_weights /= weights_sum
1546
+
1547
+ self._trajectory_sampling_weights.append(trajectory_sampling_weights)
1548
+
1549
+ # 4. Primary dataset indices
1550
+ self._primary_dataset_indices = np.array(dataset_sampling_weights) == 1.0
1551
+ if not np.any(self._primary_dataset_indices):
1552
+ print(f"Warning: No dataset with weight 1.0 found. Original weights: {dataset_sampling_weights}")
1553
+ # Fallback: use the dataset(s) with maximum weight as primary
1554
+ max_weight = max(dataset_sampling_weights)
1555
+ self._primary_dataset_indices = np.array(dataset_sampling_weights) == max_weight
1556
+ print(f"Using datasets with maximum weight {max_weight} as primary: {self._primary_dataset_indices}")
1557
+
1558
+ if not np.any(self._primary_dataset_indices):
1559
+ # This should never happen, but just in case
1560
+ print("Error: Still no primary dataset found. Using first dataset as primary.")
1561
+ self._primary_dataset_indices = np.zeros(len(self.datasets), dtype=bool)
1562
+ self._primary_dataset_indices[0] = True
1563
+
1564
+ # Set the epoch and sample the first epoch
1565
+ self.set_epoch(0)
1566
+
1567
+ self.update_metadata(metadata_config)
1568
+
1569
+ @property
1570
+ def dataset_lengths(self) -> np.ndarray:
1571
+ """The lengths of each dataset."""
1572
+ return self._dataset_lengths
1573
+
1574
+ @property
1575
+ def dataset_sampling_weights(self) -> np.ndarray:
1576
+ """The sampling weights for each dataset."""
1577
+ return self._dataset_sampling_weights
1578
+
1579
+ @property
1580
+ def trajectory_sampling_weights(self) -> list[np.ndarray]:
1581
+ """The sampling weights for each trajectory in each dataset."""
1582
+ return self._trajectory_sampling_weights
1583
+
1584
+ @property
1585
+ def primary_dataset_indices(self) -> np.ndarray:
1586
+ """The indices of the primary datasets."""
1587
+ return self._primary_dataset_indices
1588
+
1589
+ def __str__(self) -> str:
1590
+ dataset_descriptions = []
1591
+ for dataset, weight in zip(self.datasets, self.dataset_sampling_weights):
1592
+ dataset_description = {
1593
+ "Dataset": str(dataset),
1594
+ "Sampling weight": float(weight),
1595
+ }
1596
+ dataset_descriptions.append(dataset_description)
1597
+ return json.dumps({"Mixture dataset": dataset_descriptions}, indent=2)
1598
+
1599
+ def set_epoch(self, epoch: int):
1600
+ """Set the epoch for the dataset.
1601
+
1602
+ Args:
1603
+ epoch (int): The epoch to set.
1604
+ """
1605
+ self.epoch = epoch
1606
+ # self.sampled_steps = self.sample_epoch()
1607
+
1608
+ def sample_step(self, index: int) -> tuple[LeRobotSingleDataset, int, int]:
1609
+ """Sample a single step from the dataset."""
1610
+ # return self.sampled_steps[index]
1611
+
1612
+ # Set seed
1613
+ seed = index if self.mode != "train" else safe_hash((self.epoch, index, self.seed))
1614
+ rng = np.random.default_rng(seed)
1615
+
1616
+ # Sample dataset
1617
+ dataset_index = rng.choice(len(self.datasets), p=self.dataset_sampling_weights)
1618
+ dataset = self.datasets[dataset_index]
1619
+
1620
+ # Sample trajectory
1621
+ # trajectory_index = rng.choice(
1622
+ # len(dataset.trajectory_ids), p=self.trajectory_sampling_weights[dataset_index]
1623
+ # )
1624
+ # trajectory_id = dataset.trajectory_ids[trajectory_index]
1625
+
1626
+ # # Sample step
1627
+ # base_index = rng.choice(dataset.trajectory_lengths[trajectory_index])
1628
+ # return dataset, trajectory_id, base_index
1629
+ single_step_index = rng.choice(len(dataset.all_steps))
1630
+ trajectory_id, base_index = dataset.all_steps[single_step_index]
1631
+ return dataset, trajectory_id, base_index
1632
+
1633
+ def __getitem__(self, index: int) -> dict:
1634
+ """Get the data for a single trajectory and start index.
1635
+
1636
+ Args:
1637
+ index (int): The index of the trajectory to get.
1638
+
1639
+ Returns:
1640
+ dict: The data for the trajectory and start index.
1641
+ """
1642
+ max_retries = 10
1643
+ last_exception = None
1644
+
1645
+ for attempt in range(max_retries):
1646
+ try:
1647
+ dataset, trajectory_name, step = self.sample_step(index)
1648
+ data_raw = dataset.get_step_data(trajectory_name, step)
1649
+ data = dataset.transforms(data_raw)
1650
+
1651
+ # Process all video keys dynamically
1652
+ images = []
1653
+ for video_key in dataset.modality_keys.get("video", []):
1654
+ image = data[video_key][0]
1655
+
1656
+ image = Image.fromarray(image).resize((224, 224)) #TODO check if this is ok
1657
+ images.append(image)
1658
+
1659
+ # Get language and action data
1660
+ language = data[dataset.modality_keys["language"][0]][0]
1661
+ action = []
1662
+ for action_key in dataset.modality_keys["action"]:
1663
+ action.append(data[action_key])
1664
+ action = np.concatenate(action, axis=1).astype(np.float16)
1665
+ action = standardize_action_representation(action, dataset.tag)
1666
+
1667
+ state = []
1668
+ for state_key in dataset.modality_keys["state"]:
1669
+ state.append(data[state_key])
1670
+ state = np.concatenate(state, axis=1).astype(np.float16)
1671
+ state = standardize_state_representation(state, dataset.tag)
1672
+
1673
+ return dict(action=action, state=state, image=images, lang=language, dataset_id=dataset._dataset_id)
1674
+
1675
+ except Exception as e:
1676
+ last_exception = e
1677
+ if attempt < max_retries - 1:
1678
+ # Log the error but continue trying
1679
+ print(f"Attempt {attempt + 1}/{max_retries} failed for index {index}: {e}")
1680
+ print(f"Retrying with new sample...")
1681
+ # For retry, we can use a slightly different index to get a new sample
1682
+ # This helps avoid getting stuck on the same problematic sample
1683
+ index = random.randint(0, len(self) - 1)
1684
+ else:
1685
+ # All retries exhausted
1686
+ print(f"All {max_retries} attempts failed for index {index}")
1687
+ print(f"Last error: {last_exception}")
1688
+ # Return a dummy sample or re-raise the exception
1689
+ raise last_exception
1690
+
1691
+ def __len__(self) -> int:
1692
+ """Get the length of a single epoch in the mixture.
1693
+
1694
+ Returns:
1695
+ int: The length of a single epoch in the mixture.
1696
+ """
1697
+ # Check for potential issues
1698
+ if len(self.datasets) == 0:
1699
+ return 0
1700
+
1701
+ # Check if any dataset lengths are 0 or NaN
1702
+ if np.any(self.dataset_lengths == 0) or np.any(np.isnan(self.dataset_lengths)):
1703
+ print(f"Warning: Found zero or NaN dataset lengths: {self.dataset_lengths}")
1704
+ # Filter out zero/NaN length datasets
1705
+ valid_indices = (self.dataset_lengths > 0) & (~np.isnan(self.dataset_lengths))
1706
+ if not np.any(valid_indices):
1707
+ print("Error: All datasets have zero or NaN length")
1708
+ return 0
1709
+ else:
1710
+ valid_indices = np.ones(len(self.datasets), dtype=bool)
1711
+
1712
+ # Check if any sampling weights are 0 or NaN
1713
+ if np.any(self.dataset_sampling_weights == 0) or np.any(np.isnan(self.dataset_sampling_weights)):
1714
+ print(f"Warning: Found zero or NaN sampling weights: {self.dataset_sampling_weights}")
1715
+ # Use only valid weights
1716
+ valid_weights = (self.dataset_sampling_weights > 0) & (~np.isnan(self.dataset_sampling_weights))
1717
+ valid_indices = valid_indices & valid_weights
1718
+ if not np.any(valid_indices):
1719
+ print("Error: All sampling weights are zero or NaN")
1720
+ return 0
1721
+
1722
+ # Check primary dataset indices
1723
+ primary_and_valid = self.primary_dataset_indices & valid_indices
1724
+ if not np.any(primary_and_valid):
1725
+ print(f"Warning: No valid primary datasets found. Primary indices: {self.primary_dataset_indices}, Valid indices: {valid_indices}")
1726
+ # Fallback: use the largest valid dataset
1727
+ if np.any(valid_indices):
1728
+ max_length = self.dataset_lengths[valid_indices].max()
1729
+ print(f"Fallback: Using maximum dataset length: {max_length}")
1730
+ return int(max_length)
1731
+ else:
1732
+ return 0
1733
+
1734
+ # Calculate the ratio and get max
1735
+ ratios = (self.dataset_lengths / self.dataset_sampling_weights)[primary_and_valid]
1736
+
1737
+ # Check for NaN or inf in ratios
1738
+ if np.any(np.isnan(ratios)) or np.any(np.isinf(ratios)):
1739
+ print(f"Warning: Found NaN or inf in ratios: {ratios}")
1740
+ print(f"Dataset lengths: {self.dataset_lengths[primary_and_valid]}")
1741
+ print(f"Sampling weights: {self.dataset_sampling_weights[primary_and_valid]}")
1742
+ # Filter out invalid ratios
1743
+ valid_ratios = ratios[~np.isnan(ratios) & ~np.isinf(ratios)]
1744
+ if len(valid_ratios) == 0:
1745
+ print("Error: All ratios are NaN or inf")
1746
+ return 0
1747
+ max_ratio = valid_ratios.max()
1748
+ else:
1749
+ max_ratio = ratios.max()
1750
+
1751
+ result = int(max_ratio)
1752
+ if result == 0:
1753
+ print(f"Warning: Dataset mixture length is 0")
1754
+ return result
1755
+
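# Worked example of the length rule above: a primary dataset of 10,000 steps
# sampled with normalized weight 0.8 gives
#   len(mixture) = max(dataset_length / sampling_weight) = 10_000 / 0.8 = 12_500
# i.e. one "epoch" is sized so the primary data is expected to be seen about once.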
1756
+ @staticmethod
1757
+ def compute_overall_statistics(
1758
+ per_task_stats: list[dict[str, dict[str, list[float] | np.ndarray]]],
1759
+ dataset_sampling_weights: list[float] | np.ndarray,
1760
+ percentile_mixing_method: str = "weighted_average",
1761
+ ) -> dict[str, dict[str, list[float]]]:
1762
+ """
1763
+ Computes overall statistics from per-task statistics using dataset sample weights.
1764
+
1765
+ Args:
1766
+ per_task_stats: List of per-task statistics.
1767
+ Example format of one element in the per-task statistics list:
1768
+ {
1769
+ "state.gripper": {
1770
+ "min": [...],
1771
+ "max": [...],
1772
+ "mean": [...],
1773
+ "std": [...],
1774
+ "q01": [...],
1775
+ "q99": [...],
1776
+ },
1777
+ ...
1778
+ }
1779
+ dataset_sampling_weights: List of sample weights for each task.
1780
+ percentile_mixing_method: The method to mix the percentiles, either "weighted_average" or "weighted_std".
1781
+
1782
+ Returns:
1783
+ A dict of overall statistics per modality.
1784
+ """
1785
+ # Normalize the sample weights to sum to 1
1786
+ dataset_sampling_weights = np.array(dataset_sampling_weights)
1787
+ normalized_weights = dataset_sampling_weights / dataset_sampling_weights.sum()
1788
+
1789
+ # Initialize overall statistics dict
1790
+ overall_stats: dict[str, dict[str, list[float]]] = {}
1791
+
1792
+ # Get the list of modality keys
1793
+ modality_keys = per_task_stats[0].keys()
1794
+
1795
+ for modality in modality_keys:
1796
+ # Number of dimensions (assuming consistent across tasks)
1797
+ num_dims = len(per_task_stats[0][modality]["mean"])
1798
+
1799
+ # Initialize accumulators for means and variances
1800
+ weighted_means = np.zeros(num_dims)
1801
+ weighted_squares = np.zeros(num_dims)
1802
+
1803
+ # Collect min, max, q01, q99 from all tasks
1804
+ min_list = []
1805
+ max_list = []
1806
+ q01_list = []
1807
+ q99_list = []
1808
+
1809
+ for task_idx, task_stats in enumerate(per_task_stats):
1810
+ w_i = normalized_weights[task_idx]
1811
+ stats = task_stats[modality]
1812
+ means = np.array(stats["mean"])
1813
+ stds = np.array(stats["std"])
1814
+
1815
+ # Update weighted sums for mean and variance
1816
+ weighted_means += w_i * means
1817
+ weighted_squares += w_i * (stds**2 + means**2)
1818
+
1819
+ # Collect min, max, q01, q99
1820
+ min_list.append(stats["min"])
1821
+ max_list.append(stats["max"])
1822
+ q01_list.append(stats["q01"])
1823
+ q99_list.append(stats["q99"])
1824
+
1825
+ # Compute overall mean
1826
+ overall_mean = weighted_means.tolist()
1827
+
1828
+ # Compute overall variance and std deviation
1829
+ overall_variance = weighted_squares - weighted_means**2
1830
+ overall_std = np.sqrt(overall_variance).tolist()
1831
+
1832
+ # Compute overall min and max per dimension
1833
+ overall_min = np.min(np.array(min_list), axis=0).tolist()
1834
+ overall_max = np.max(np.array(max_list), axis=0).tolist()
1835
+
1836
+ # Compute overall q01 and q99 per dimension
1837
+ # Use weighted average of per-task quantiles
1838
+ q01_array = np.array(q01_list)
1839
+ q99_array = np.array(q99_list)
1840
+ if percentile_mixing_method == "weighted_average":
1841
+ weighted_q01 = np.average(q01_array, axis=0, weights=normalized_weights).tolist()
1842
+ weighted_q99 = np.average(q99_array, axis=0, weights=normalized_weights).tolist()
1843
+ # std_q01 = np.std(q01_array, axis=0).tolist()
1844
+ # std_q99 = np.std(q99_array, axis=0).tolist()
1845
+ # print(modality)
1846
+ # print(f"{std_q01=}, {std_q99=}")
1847
+ # print(f"{weighted_q01=}, {weighted_q99=}")
1848
+ elif percentile_mixing_method == "min_max":
1849
+ weighted_q01 = np.min(q01_array, axis=0).tolist()
1850
+ weighted_q99 = np.max(q99_array, axis=0).tolist()
1851
+ else:
1852
+ raise ValueError(f"Invalid percentile mixing method: {percentile_mixing_method}")
1853
+
1854
+ # Store the overall statistics for the modality
1855
+ overall_stats[modality] = {
1856
+ "min": overall_min,
1857
+ "max": overall_max,
1858
+ "mean": overall_mean,
1859
+ "std": overall_std,
1860
+ "q01": weighted_q01,
1861
+ "q99": weighted_q99,
1862
+ }
1863
+
1864
+ return overall_stats
1865
+
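# The mean/std mixing above follows the standard mixture identities, with
# normalized weights w_i and per-task means mu_i, stds s_i (per dimension):
#   mu  = sum_i w_i * mu_i
#   var = sum_i w_i * (s_i**2 + mu_i**2) - mu**2
# e.g. two tasks with w = (0.5, 0.5), mu = (0, 2), s = (1, 1):
#   mu = 1.0, var = 0.5*(1 + 0) + 0.5*(1 + 4) - 1 = 2.0, std ~= 1.414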
1866
+ @staticmethod
1867
+ def merge_metadata(
1868
+ metadatas: list[DatasetMetadata],
1869
+ dataset_sampling_weights: list[float],
1870
+ percentile_mixing_method: str,
1871
+ ) -> DatasetMetadata:
1872
+ """Merge multiple metadata into one."""
1873
+ # Convert to dicts
1874
+ metadata_dicts = [metadata.model_dump(mode="json") for metadata in metadatas]
1875
+ # Create a new metadata dict
1876
+ merged_metadata = {}
1877
+
1878
+ # Check all metadata have the same embodiment tag
1879
+ assert all(
1880
+ metadata.embodiment_tag == metadatas[0].embodiment_tag for metadata in metadatas
1881
+ ), "All metadata must have the same embodiment tag"
1882
+ merged_metadata["embodiment_tag"] = metadatas[0].embodiment_tag
1883
+
1884
+ # Merge the dataset statistics
1885
+ dataset_statistics = {}
1886
+ dataset_statistics["state"] = LeRobotMixtureDataset.compute_overall_statistics(
1887
+ per_task_stats=[m["statistics"]["state"] for m in metadata_dicts],
1888
+ dataset_sampling_weights=dataset_sampling_weights,
1889
+ percentile_mixing_method=percentile_mixing_method,
1890
+ )
1891
+ dataset_statistics["action"] = LeRobotMixtureDataset.compute_overall_statistics(
1892
+ per_task_stats=[m["statistics"]["action"] for m in metadata_dicts],
1893
+ dataset_sampling_weights=dataset_sampling_weights,
1894
+ percentile_mixing_method=percentile_mixing_method,
1895
+ )
1896
+ merged_metadata["statistics"] = dataset_statistics
1897
+
1898
+ # Merge the modality configs
1899
+ modality_configs = defaultdict(set)
1900
+ for metadata in metadata_dicts:
1901
+ for modality, configs in metadata["modalities"].items():
1902
+ modality_configs[modality].add(json.dumps(configs))
1903
+ merged_metadata["modalities"] = {}
1904
+ for modality, configs in modality_configs.items():
1905
+ # Check that all datasets share an identical config for this modality
1906
+ assert (
1907
+ len(configs) == 1
1908
+ ), f"Multiple modality configs for modality {modality}: {list(configs)}"
1909
+ merged_metadata["modalities"][modality] = json.loads(configs.pop())
1910
+
1911
+ return DatasetMetadata.model_validate(merged_metadata)
1912
+
1913
+ def update_metadata(self, metadata_config: dict, cached_statistics_path: Path | str | None = None) -> None:
1914
+ """
1915
+ Merge multiple metadatas into one and set the transforms with the merged metadata.
1916
+
1917
+ Args:
1918
+ metadata_config (dict): Configuration for the metadata.
1919
+ "percentile_mixing_method": The method to mix the percentiles, either "weighted_average" or "min_max".
1920
+ weighted_average: Use the weighted average of the percentiles using the weight used in sampling the datasets.
1921
+ min_max: Use the min of the 1st percentile and max of the 99th percentile.
1922
+ """
1923
+ # If cached path is provided, try to load and apply
1924
+ if cached_statistics_path is not None:
1925
+ try:
1926
+ cached_stats = self.load_merged_statistics(cached_statistics_path)
1927
+ self.apply_cached_statistics(cached_stats)
1928
+ return
1929
+ except (FileNotFoundError, KeyError, ValidationError) as e:
1930
+ print(f"Failed to load cached statistics: {e}")
1931
+ print("Falling back to computing statistics from scratch...")
1932
+
1933
+ self.tag = EmbodimentTag.NEW_EMBODIMENT.value
1934
+ self.merged_metadata: dict[str, DatasetMetadata] = {}
1935
+ # Group metadata by tag
1936
+ all_metadatas: dict[str, list[DatasetMetadata]] = {}
1937
+ for dataset in self.datasets:
1938
+ if dataset.tag not in all_metadatas:
1939
+ all_metadatas[dataset.tag] = []
1940
+ all_metadatas[dataset.tag].append(dataset.metadata)
1941
+ for tag, metadatas in all_metadatas.items():
1942
+ self.merged_metadata[tag] = self.merge_metadata(
1943
+ metadatas=metadatas,
1944
+ dataset_sampling_weights=self.dataset_sampling_weights.tolist(),
1945
+ percentile_mixing_method=metadata_config["percentile_mixing_method"],
1946
+ )
1947
+ for dataset in self.datasets:
1948
+ dataset.set_transforms_metadata(self.merged_metadata[dataset.tag])
1949
+
1950
+ def save_dataset_statistics(self, save_path: Path | str, format: str = "json") -> None:
1951
+ """
1952
+ Save merged dataset statistics to specified path in the required format.
1953
+ Only includes statistics for keys that are actually used in the datasets.
1954
+ Key order follows each tag's modality config order.
1955
+
1956
+ Args:
1957
+ save_path (Path | str): Path to save the statistics file
1958
+ format (str): Save format, currently only supports "json"
1959
+ """
1960
+ save_path = Path(save_path)
1961
+ save_path.parent.mkdir(parents=True, exist_ok=True)
1962
+
1963
+ # Build the data structure to save
1964
+ statistics_data = {}
1965
+
1966
+ # Keep key orders per embodiment tag (from modality config order)
1967
+ tag_to_used_action_keys = {}
1968
+ tag_to_used_state_keys = {}
1969
+ for dataset in self.datasets:
1970
+ if dataset.tag in tag_to_used_action_keys:
1971
+ continue
1972
+ used_action_keys, used_state_keys = get_used_modality_keys(dataset.modality_keys)
1973
+ tag_to_used_action_keys[dataset.tag] = used_action_keys
1974
+ tag_to_used_state_keys[dataset.tag] = used_state_keys
1975
+
1976
+ # Organize statistics by tag
1977
+ for tag, merged_metadata in self.merged_metadata.items():
1978
+ tag_stats = {}
1979
+
1980
+ # Process action statistics
1981
+ if hasattr(merged_metadata.statistics, 'action') and merged_metadata.statistics.action:
1982
+ action_stats = merged_metadata.statistics.action
1983
+
1984
+ used_action_keys = tag_to_used_action_keys.get(tag, [])
1985
+ filtered_action_stats = {
1986
+ key: action_stats[key]
1987
+ for key in used_action_keys
1988
+ if key in action_stats
1989
+ }
1990
+
1991
+ if filtered_action_stats:
1992
+ combined_action_stats = combine_modality_stats(filtered_action_stats)
1993
+
1994
+ mask = generate_action_mask_for_used_keys(
1995
+ merged_metadata.modalities.action, filtered_action_stats.keys()
1996
+ )
1997
+ combined_action_stats["mask"] = mask
1998
+
1999
+ tag_stats["action"] = combined_action_stats
2000
+
2001
+ # Process state statistics
2002
+ if hasattr(merged_metadata.statistics, 'state') and merged_metadata.statistics.state:
2003
+ state_stats = merged_metadata.statistics.state
2004
+
2005
+ used_state_keys = tag_to_used_state_keys.get(tag, [])
2006
+ filtered_state_stats = {
2007
+ key: state_stats[key]
2008
+ for key in used_state_keys
2009
+ if key in state_stats
2010
+ }
2011
+
2012
+ if filtered_state_stats:
2013
+ combined_state_stats = combine_modality_stats(filtered_state_stats)
2014
+ tag_stats["state"] = combined_state_stats
2015
+
2016
+ # Add dataset counts
2017
+ tag_stats.update(self._get_dataset_counts(tag))
2018
+
2019
+ statistics_data[tag] = tag_stats
2020
+
2021
+ # Save file
2022
+ if format.lower() == "json":
2023
+ if not str(save_path).endswith('.json'):
2024
+ save_path = save_path.with_suffix('.json')
2025
+ with open(save_path, 'w', encoding='utf-8') as f:
2026
+ json.dump(statistics_data, f, indent=2, ensure_ascii=False)
2027
+ else:
2028
+ raise ValueError(f"Unsupported format: {format}. Currently only 'json' is supported.")
2029
+
2030
+ print(f"Merged dataset statistics saved to: {save_path}")
2031
+ print(f"Used action keys by tag: {tag_to_used_action_keys}")
2032
+ print(f"Used state keys by tag: {tag_to_used_state_keys}")
2033
+
2034
+
2035
+ def _combine_modality_stats(self, modality_stats: dict) -> dict:
2036
+ """Backward compatibility wrapper."""
2037
+ return combine_modality_stats(modality_stats)
2038
+
2039
+ def _generate_action_mask_for_used_keys(self, action_modalities: dict, used_action_keys_ordered) -> list[bool]:
2040
+ """Backward compatibility wrapper."""
2041
+ return generate_action_mask_for_used_keys(action_modalities, used_action_keys_ordered)
2042
+
2043
+ def _get_dataset_counts(self, tag: str) -> dict:
2044
+ """
2045
+ Get dataset count information for specified tag.
2046
+
2047
+ Args:
2048
+ tag (str): embodiment tag
2049
+
2050
+ Returns:
2051
+ dict: Dictionary containing num_transitions and num_trajectories
2052
+ """
2053
+ num_transitions = 0
2054
+ num_trajectories = 0
2055
+
2056
+ # Count dataset information belonging to this tag
2057
+ for dataset in self.datasets:
2058
+ if dataset.tag == tag:
2059
+ num_transitions += len(dataset)
2060
+ num_trajectories += len(dataset.trajectory_ids)
2061
+
2062
+ return {
2063
+ "num_transitions": num_transitions,
2064
+ "num_trajectories": num_trajectories
2065
+ }
2066
+
2067
+ @classmethod
2068
+ def load_merged_statistics(cls, load_path: Path | str) -> dict:
2069
+ """
2070
+ Load merged dataset statistics from file.
2071
+
2072
+ Args:
2073
+ load_path (Path | str): Path to the statistics file
2074
+
2075
+ Returns:
2076
+ dict: Dictionary containing merged statistics
2077
+ """
2078
+ load_path = Path(load_path)
2079
+ if not load_path.exists():
2080
+ raise FileNotFoundError(f"Statistics file not found: {load_path}")
2081
+
2082
+ if load_path.suffix.lower() == '.json':
2083
+ with open(load_path, 'r', encoding='utf-8') as f:
2084
+ return json.load(f)
2085
+ elif load_path.suffix.lower() == '.pkl':
2086
+ import pickle
2087
+ with open(load_path, 'rb') as f:
2088
+ return pickle.load(f)
2089
+ else:
2090
+ raise ValueError(f"Unsupported file format: {load_path.suffix}")
2091
+
2092
+ def apply_cached_statistics(self, cached_statistics: dict) -> None:
2093
+ """
2094
+ Apply cached statistics to avoid recomputation.
2095
+
2096
+ Args:
2097
+ cached_statistics (dict): Statistics loaded from file
2098
+ """
2099
+ # Validate that cached statistics match current datasets
2100
+ if "metadata" in cached_statistics:
2101
+ cached_dataset_names = set(cached_statistics["metadata"]["dataset_names"])
2102
+ current_dataset_names = set(dataset.dataset_name for dataset in self.datasets)
2103
+
2104
+ if cached_dataset_names != current_dataset_names:
2105
+ print("Warning: Cached statistics dataset names don't match current datasets.")
2106
+ print(f"Cached: {cached_dataset_names}")
2107
+ print(f"Current: {current_dataset_names}")
2108
+ return
2109
+
2110
+ # Apply cached statistics
2111
+ self.merged_metadata = {}
2112
+ for tag, stats_data in cached_statistics.items():
2113
+ if tag == "metadata": # Skip metadata field
2114
+ continue
2115
+
2116
+ # Convert back to DatasetMetadata format
2117
+ metadata_dict = {
2118
+ "embodiment_tag": tag,
2119
+ "statistics": {
2120
+ "action": {},
2121
+ "state": {}
2122
+ },
2123
+ "modalities": {}
2124
+ }
2125
+
2126
+ # Convert action statistics back
2127
+ if "action" in stats_data:
2128
+ action_data = stats_data["action"]
2129
+ # This is simplified - you may need to split back to sub-keys
2130
+ metadata_dict["statistics"]["action"] = action_data
2131
+
2132
+ # Convert state statistics back
2133
+ if "state" in stats_data:
2134
+ state_data = stats_data["state"]
2135
+ metadata_dict["statistics"]["state"] = state_data
2136
+
2137
+ self.merged_metadata[tag] = DatasetMetadata.model_validate(metadata_dict)
2138
+
2139
+ # Update transforms metadata for each dataset
2140
+ for dataset in self.datasets:
2141
+ if dataset.tag in self.merged_metadata:
2142
+ dataset.set_transforms_metadata(self.merged_metadata[dataset.tag])
2143
+
2144
+ print(f"Applied cached statistics for {len(self.merged_metadata)} embodiment tags.")
2145
+
code/dataloader/gr00t_lerobot/embodiment_tags.py ADDED
@@ -0,0 +1,198 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum
17
+
18
+
19
+ class EmbodimentTag(Enum):
20
+ GR1 = "gr1"
21
+ """
22
+ The GR1 dataset.
23
+ """
24
+
25
+ OXE_DROID = "oxe_droid"
26
+ """
27
+ The OxE Droid dataset.
28
+ """
29
+
30
+ OXE_BRIDGE = "oxe_bridge"
31
+ """
32
+ The OxE Bridge dataset.
33
+ """
34
+
35
+ OXE_RT1 = "oxe_rt1"
36
+ """
37
+ The OxE RT-1 dataset.
38
+ """
39
+
40
+ AGIBOT_GENIE1 = "agibot_genie1"
41
+ """
42
+ The AgiBot Genie-1 with gripper dataset.
43
+ """
44
+
45
+ NEW_EMBODIMENT = "new_embodiment"
46
+ """
47
+ Any new embodiment for finetuning.
48
+ """
49
+
50
+ FRANKA = 'franka'
51
+ """
52
+ The Franka Emika Panda robot.
53
+ """
54
+
55
+ ROBOTWIN = "robotwin"
56
+ """
57
+ RobotWin (dual-arm) datasets.
58
+ """
59
+
60
+ REAL_WORLD_FRANKA = "real_world_franka"
61
+ """
62
+ The Real-World Franka robot.
63
+ """
64
+
65
+ # Embodiment tag string: to projector index in the Action Expert Module
66
+ # EMBODIMENT_TAG_MAPPING = {
67
+ # EmbodimentTag.NEW_EMBODIMENT.value: 31,
68
+ # EmbodimentTag.OXE_DROID.value: 17,
69
+ # EmbodimentTag.OXE_BRIDGE.value: 18,
70
+ # EmbodimentTag.OXE_RT1.value: 19,
71
+ # EmbodimentTag.AGIBOT_GENIE1.value: 26,
72
+ # EmbodimentTag.GR1.value: 24,
73
+ # EmbodimentTag.FRANKA.value: 25,
74
+ # EmbodimentTag.ROBOTWIN.value: 27,
75
+ # EmbodimentTag.REAL_WORLD_FRANKA.value: 28,
76
+ # }
77
+
78
+ # Robot type to embodiment tag mapping
79
+ ROBOT_TYPE_TO_EMBODIMENT_TAG = {
80
+ "libero_franka": EmbodimentTag.FRANKA,
81
+ "oxe_droid": EmbodimentTag.OXE_DROID,
82
+ "oxe_bridge": EmbodimentTag.OXE_BRIDGE,
83
+ "oxe_rt1": EmbodimentTag.OXE_RT1,
84
+ "demo_sim_franka_delta_joints": EmbodimentTag.FRANKA,
85
+ "custom_robot_config": EmbodimentTag.NEW_EMBODIMENT,
86
+ "fourier_gr1_arms_waist": EmbodimentTag.GR1,
87
+ "robotwin": EmbodimentTag.ROBOTWIN,
88
+ "real_world_franka": EmbodimentTag.REAL_WORLD_FRANKA,
89
+ }
90
+
91
+ DATASET_NAME_TO_ID = {
92
+ # Libero Datasets
93
+ "libero_object_no_noops_1.0.0_lerobot": 1,
94
+ "libero_goal_no_noops_1.0.0_lerobot": 1,
95
+ "libero_spatial_no_noops_1.0.0_lerobot": 1,
96
+ "libero_10_no_noops_1.0.0_lerobot": 1,
97
+ "libero_90_no_noops_lerobot": 1,
98
+
99
+ # OXE Datasets
100
+ "bridge_orig_lerobot": 2,
101
+ "fractal20220817_data_lerobot": 3,
102
+ "droid_lerobot": 4,
103
+ "furniture_bench_dataset_lerobot": 5,
104
+ "taco_play_lerobot": 6,
105
+
106
+ # RoboCasa Datasets
107
+ "gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_1000": 7,
108
+ "gr1_unified.PnPCanToDrawerClose_GR1ArmsAndWaistFourierHands_1000": 7,
109
+ "gr1_unified.PnPCupToDrawerClose_GR1ArmsAndWaistFourierHands_1000": 7,
110
+ "gr1_unified.PnPMilkToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000": 7,
111
+ "gr1_unified.PnPPotatoToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000": 7,
112
+ "gr1_unified.PnPWineToCabinetClose_GR1ArmsAndWaistFourierHands_1000": 7,
113
+ "gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
114
+ "gr1_unified.PosttrainPnPNovelFromCuttingboardToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
115
+ "gr1_unified.PosttrainPnPNovelFromCuttingboardToPanSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
116
+ "gr1_unified.PosttrainPnPNovelFromCuttingboardToPotSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
117
+ "gr1_unified.PosttrainPnPNovelFromCuttingboardToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
118
+ "gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
119
+ "gr1_unified.PosttrainPnPNovelFromPlacematToBowlSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
120
+ "gr1_unified.PosttrainPnPNovelFromPlacematToPlateSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
121
+ "gr1_unified.PosttrainPnPNovelFromPlacematToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
122
+ "gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
123
+ "gr1_unified.PosttrainPnPNovelFromPlateToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
124
+ "gr1_unified.PosttrainPnPNovelFromPlateToPanSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
125
+ "gr1_unified.PosttrainPnPNovelFromPlateToPlateSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
126
+ "gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
127
+ "gr1_unified.PosttrainPnPNovelFromTrayToPlateSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
128
+ "gr1_unified.PosttrainPnPNovelFromTrayToPotSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
129
+ "gr1_unified.PosttrainPnPNovelFromTrayToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
130
+ "gr1_unified.PosttrainPnPNovelFromTrayToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000": 7,
131
+ "gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200": 7,
132
+ "gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200": 7,
133
+ "gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200": 7,
134
+ "gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200": 7,
135
+ "gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200": 7,
136
+
137
+ # robotwin
138
+ "adjust_bottle": 8,
139
+ "beat_block_hammer": 8,
140
+ "blocks_ranking_rgb": 8,
141
+ "blocks_ranking_size": 8,
142
+ "click_alarmclock": 8,
143
+ "click_bell": 8,
144
+ "dump_bin_bigbin": 8,
145
+ "grab_roller": 8,
146
+ "handover_block": 8,
147
+ "handover_mic": 8,
148
+ "hanging_mug": 8,
149
+ "lift_pot": 8,
150
+ "move_can_pot": 8,
151
+ "move_pillbottle_pad": 8,
152
+ "move_playingcard_away": 8,
153
+ "move_stapler_pad": 8,
154
+ "open_laptop": 8,
155
+ "open_microwave": 8,
156
+ "pick_diverse_bottles": 8,
157
+ "pick_dual_bottles": 8,
158
+ "place_a2b_left": 8,
159
+ "place_a2b_right": 8,
160
+ "place_bread_basket": 8,
161
+ "place_bread_skillet": 8,
162
+ "place_burger_fries": 8,
163
+ "place_can_basket": 8,
164
+ "place_cans_plasticbox": 8,
165
+ "place_container_plate": 8,
166
+ "place_dual_shoes": 8,
167
+ "place_empty_cup": 8,
168
+ "place_fan": 8,
169
+ "place_mouse_pad": 8,
170
+ "place_object_basket": 8,
171
+ "place_object_scale": 8,
172
+ "place_object_stand": 8,
173
+ "place_phone_stand": 8,
174
+ "place_shoe": 8,
175
+ "press_stapler": 8,
176
+ "put_bottles_dustbin": 8,
177
+ "put_object_cabinet": 8,
178
+ "rotate_qrcode": 8,
179
+ "scan_object": 8,
180
+ "shake_bottle_horizontally": 8,
181
+ "shake_bottle": 8,
182
+ "stack_blocks_three": 8,
183
+ "stack_blocks_two": 8,
184
+ "stack_bowls_three": 8,
185
+ "stack_bowls_two": 8,
186
+ "stamp_seal": 8,
187
+ "turn_switch": 8,
188
+
189
+ # real-world
190
+ "real_grasp_coke": 9,
191
+ "real_pick_up_cup_in_middle": 9,
192
+ "real_stack_cups": 9,
193
+ "real_put_apple_on_tray_and_then_put_banana_on_tray": 9,
194
+ "realworld_tasks_all": 9,
195
+ "realworld_4tasks": 9,
196
+ "realworld_collect": 9,
197
+ "realworld_pickplace_4tasks": 9,
198
+ }
code/dataloader/gr00t_lerobot/mixtures.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ mixtures.py
3
+
4
+ Defines a registry of dataset mixtures and weights for the Open-X Embodiment Datasets. Each dataset is associated with
5
+ a float "sampling weight"
6
+ """
7
+
8
+ from typing import Dict, List, Tuple
9
+
10
+
11
+ # Dataset mixture name mapped to a list of tuples containing:
12
+ ## {mixture_name: [(data_name, sampling_weight, robot_type)]}
13
+ DATASET_NAMED_MIXTURES = {
14
+
15
+ "custom_dataset": [
16
+ ("custom_dataset_name", 1.0, "custom_robot_config"),
17
+ ],
18
+ "custom_dataset_2": [
19
+ ("custom_dataset_name_1", 1.0, "custom_robot_config"),
20
+ ("custom_dataset_name_2", 1.0, "custom_robot_config"),
21
+ ],
22
+
23
+ "libero_all": [
24
+ ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"),
25
+ ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"),
26
+ ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"),
27
+ ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"),
28
+ # ("libero_90_no_noops_lerobot", 1.0, "libero_franka"),
29
+ ],
30
+ "bridge": [
31
+ ("bridge_orig_1.0.0_lerobot", 1.0, "oxe_bridge"),
32
+ ],
33
+ "bridge_rt_1": [
34
+ ("bridge_orig_1.0.0_lerobot", 1.0, "oxe_bridge"),
35
+ ("fractal20220817_data_0.1.0_lerobot", 1.0, "oxe_rt1"),
36
+ ],
37
+
38
+ "demo_sim_pick_place": [
39
+ ("sim_pick_place", 1.0, "demo_sim_franka_delta_joints"),
40
+ ],
41
+
42
+ "custom_dataset": [
43
+ ("custom_dataset_name", 1.0, "custom_robot_config"),
44
+ ],
45
+ "custom_dataset_2": [
46
+ ("custom_dataset_name_1", 1.0, "custom_robot_config"),
47
+ ("custom_dataset_name_2", 1.0, "custom_robot_config"),
48
+ ],
49
+
50
+ "fourier_gr1_unified_1000": [
51
+ ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
52
+ ("gr1_unified.PnPCanToDrawerClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
53
+ ("gr1_unified.PnPCupToDrawerClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
54
+ ("gr1_unified.PnPMilkToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
55
+ ("gr1_unified.PnPPotatoToMicrowaveClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
56
+ ("gr1_unified.PnPWineToCabinetClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
57
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
58
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
59
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToPanSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
60
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToPotSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
61
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
62
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
63
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToBowlSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
64
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToPlateSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
65
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
66
+ ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
67
+ ("gr1_unified.PosttrainPnPNovelFromPlateToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
68
+ ("gr1_unified.PosttrainPnPNovelFromPlateToPanSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
69
+ ("gr1_unified.PosttrainPnPNovelFromPlateToPlateSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
70
+ ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
71
+ ("gr1_unified.PosttrainPnPNovelFromTrayToPlateSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
72
+ ("gr1_unified.PosttrainPnPNovelFromTrayToPotSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
73
+ ("gr1_unified.PosttrainPnPNovelFromTrayToTieredbasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
74
+ ("gr1_unified.PosttrainPnPNovelFromTrayToTieredshelfSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"),
75
+ ],
76
+
77
+ "BEHAVIOR_challenge": [
78
+ ("BEHAVIOR_challenge", 1.0, "R1Pro"),
79
+ ],
80
+
81
+
82
+ "SO101_pick": [
83
+ ("pick_dataset_name", 1.0, "SO101"),
84
+ ],
85
+
86
+ "arx_x5": [
87
+ ("arx_x5", 1.0, "arx_x5"),
88
+ ],
89
+
90
+ "robotwin": [
91
+ ("adjust_bottle", 1.0, "robotwin"),
92
+ ("beat_block_hammer", 1.0, "robotwin"),
93
+ ("blocks_ranking_rgb", 1.0, "robotwin"),
94
+ ("blocks_ranking_size", 1.0, "robotwin"),
95
+ ("click_alarmclock", 1.0, "robotwin"),
96
+ ("click_bell", 1.0, "robotwin"),
97
+ ("dump_bin_bigbin", 1.0, "robotwin"),
98
+ ("grab_roller", 1.0, "robotwin"),
99
+ ("handover_block", 1.0, "robotwin"),
100
+ ("handover_mic", 1.0, "robotwin"),
101
+ ("hanging_mug", 1.0, "robotwin"),
102
+ ("lift_pot", 1.0, "robotwin"),
103
+ ("move_can_pot", 1.0, "robotwin"),
104
+ ("move_pillbottle_pad", 1.0, "robotwin"),
105
+ ("move_playingcard_away", 1.0, "robotwin"),
106
+ ("move_stapler_pad", 1.0, "robotwin"),
107
+ ("open_laptop", 1.0, "robotwin"),
108
+ ("open_microwave", 1.0, "robotwin"),
109
+ ("pick_diverse_bottles", 1.0, "robotwin"),
110
+ ("pick_dual_bottles", 1.0, "robotwin"),
111
+ ("place_a2b_left", 1.0, "robotwin"),
112
+ ("place_a2b_right", 1.0, "robotwin"),
113
+ ("place_bread_basket", 1.0, "robotwin"),
114
+ ("place_bread_skillet", 1.0, "robotwin"),
115
+ ("place_burger_fries", 1.0, "robotwin"),
116
+ ("place_can_basket", 1.0, "robotwin"),
117
+ ("place_cans_plasticbox", 1.0, "robotwin"),
118
+ ("place_container_plate", 1.0, "robotwin"),
119
+ ("place_dual_shoes", 1.0, "robotwin"),
120
+ ("place_empty_cup", 1.0, "robotwin"),
121
+ ("place_fan", 1.0, "robotwin"),
122
+ ("place_mouse_pad", 1.0, "robotwin"),
123
+ ("place_object_basket", 1.0, "robotwin"),
124
+ ("place_object_scale", 1.0, "robotwin"),
125
+ ("place_object_stand", 1.0, "robotwin"),
126
+ ("place_phone_stand", 1.0, "robotwin"),
127
+ ("place_shoe", 1.0, "robotwin"),
128
+ ("press_stapler", 1.0, "robotwin"),
129
+ ("put_bottles_dustbin", 1.0, "robotwin"),
130
+ ("put_object_cabinet", 1.0, "robotwin"),
131
+ ("rotate_qrcode", 1.0, "robotwin"),
132
+ ("scan_object", 1.0, "robotwin"),
133
+ ("shake_bottle", 1.0, "robotwin"),
134
+ ("shake_bottle_horizontally", 1.0, "robotwin"),
135
+ ("stack_blocks_three", 1.0, "robotwin"),
136
+ ("stack_blocks_two", 1.0, "robotwin"),
137
+ ("stack_bowls_three", 1.0, "robotwin"),
138
+ ("stack_bowls_two", 1.0, "robotwin"),
139
+ ("stamp_seal", 1.0, "robotwin"),
140
+ ("turn_switch", 1.0, "robotwin"),
141
+ ],
142
+ "cross_embodiedment_17tasks": [
143
+ # libero - 4 tasks
144
+ ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984
145
+ ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042
146
+ ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970
147
+ ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469
148
+ # robotwin - 8 tasks, selected by average trajectory length, 400, 500, 600, 700, 800, 900, 900, 1200
149
+ ("beat_block_hammer", 1.0, "robotwin"), #
150
+ ("place_shoe", 1.0, "robotwin"), #
151
+ ("dump_bin_bigbin", 1.0, "robotwin"), #
152
+ ("put_object_cabinet", 1.0, "robotwin"), #
153
+ ("stack_blocks_two", 1.0, "robotwin"), #
154
+ ("stack_bowls_two", 1.0, "robotwin"), #
155
+ ("shake_bottle", 1.0, "robotwin"), #
156
+ ("hanging_mug", 1.0, "robotwin"), #
157
+ # ("blocks_ranking_rgb", 1.0, "robotwin"), #
158
+ # gr1 - 5 tasks
159
+ ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341
160
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282
161
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066
162
+ ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518
163
+ ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739
164
+ ],
165
+ "cross_embodiedment_21tasks": [
166
+ # libero - 4 tasks
167
+ ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984
168
+ ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042
169
+ ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970
170
+ ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469
171
+ # robotwin - 8 tasks, selected by average trajectory length, 400, 500, 600, 700, 800, 900, 900, 1200
172
+ ("beat_block_hammer", 1.0, "robotwin"), #
173
+ ("place_shoe", 1.0, "robotwin"), #
174
+ ("dump_bin_bigbin", 1.0, "robotwin"), #
175
+ ("put_object_cabinet", 1.0, "robotwin"), #
176
+ ("stack_blocks_two", 1.0, "robotwin"), #
177
+ ("stack_bowls_two", 1.0, "robotwin"), #
178
+ ("shake_bottle", 1.0, "robotwin"), #
179
+ ("hanging_mug", 1.0, "robotwin"), #
180
+ # ("blocks_ranking_rgb", 1.0, "robotwin"), #
181
+ # gr1 - 5 tasks
182
+ ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341
183
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282
184
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066
185
+ ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518
186
+ ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739
187
+ # real-world - 4 tasks
188
+ ("realworld_4tasks", 1.0, "real_world_franka"),
189
+ ],
190
+ "real_world_4tasks": [
191
+ ("realworld_4tasks", 1.0, "real_world_franka"),
192
+ ],
193
+ "realworld_tasks_all": [
194
+ ("realworld_tasks_all", 1.0, "real_world_franka"),
195
+ ],
196
+ "realworld_collect": [
197
+ ("realworld_collect", 1.0, "real_world_franka"),
198
+ ],
199
+ "cross_embodiedment_13tasks": [
200
+ # libero - 4 tasks
201
+ ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984
202
+ ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042
203
+ ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970
204
+ ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469
205
+ # gr1 - 5 tasks
206
+ ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341
207
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282
208
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066
209
+ ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518
210
+ ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739
211
+ # real-world - 4 tasks
212
+ ("realworld_pickplace_4tasks", 1.0, "real_world_franka"),
213
+ ],
214
+ "cross_embodiedment_simulator": [
215
+ # libero - 4 tasks
216
+ ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984
217
+ ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042
218
+ ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970
219
+ ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469
220
+ # gr1 - 5 tasks
221
+ ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 71341
222
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48282
223
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 48066
224
+ ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 41518
225
+ ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_200", 1.0, "fourier_gr1_arms_waist"), # 39739
226
+ ],
227
+ "cross_embodiedment_simulator_moredata": [
228
+ # libero - 4 tasks
229
+ ("libero_object_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 66984
230
+ ("libero_goal_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52042
231
+ ("libero_spatial_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 52970
232
+ ("libero_10_no_noops_1.0.0_lerobot", 1.0, "libero_franka"), # 101469
233
+ ("libero_90_no_noops_lerobot", 1.0, "libero_franka"), # 901020
234
+ # gr1 - 5 tasks
235
+ ("gr1_unified.PnPBottleToCabinetClose_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 71341 x 5
236
+ ("gr1_unified.PosttrainPnPNovelFromCuttingboardToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 48282 x 5
237
+ ("gr1_unified.PosttrainPnPNovelFromPlacematToBasketSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 48066 x 5
238
+ ("gr1_unified.PosttrainPnPNovelFromPlateToBowlSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 41518 x 5
239
+ ("gr1_unified.PosttrainPnPNovelFromTrayToCardboardboxSplitA_GR1ArmsAndWaistFourierHands_1000", 1.0, "fourier_gr1_arms_waist"), # 39739 x 5
240
+ ],
241
+ }
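Each mixture above maps a name to a list of (dataset_name, sampling_weight, embodiment_tag) triples. A minimal sketch of how such a table could be resolved into per-dataset sampling probabilities follows; the dict name DATASET_MIXTURES and the helper resolve_mixture are hypothetical, for illustration only, not the repo's actual loader:

def resolve_mixture(mixtures: dict, name: str) -> list[dict]:
    # Normalize one mixture's per-dataset weights into sampling probabilities.
    entries = mixtures[name]  # [(dataset_name, weight, embodiment_tag), ...]
    total = sum(weight for _, weight, _ in entries)
    return [
        {"dataset": ds, "prob": weight / total, "embodiment": tag}
        for ds, weight, tag in entries
    ]

# e.g. resolve_mixture(DATASET_MIXTURES, "cross_embodiedment_17tasks") assigns
# each of the 17 datasets probability 1/17, since every weight there is 1.0.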
code/dataloader/gr00t_lerobot/schema.py ADDED
@@ -0,0 +1,221 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from enum import Enum
+ from typing import Optional
+
+ from numpydantic import NDArray
+ from pydantic import BaseModel, Field, field_serializer
+
+ from .embodiment_tags import EmbodimentTag
+
+ # Common schema
+
+
+ class RotationType(Enum):
+     """Type of rotation representation"""
+
+     AXIS_ANGLE = "axis_angle"
+     QUATERNION = "quaternion"
+     ROTATION_6D = "rotation_6d"
+     MATRIX = "matrix"
+     EULER_ANGLES_RPY = "euler_angles_rpy"
+     EULER_ANGLES_RYP = "euler_angles_ryp"
+     EULER_ANGLES_PRY = "euler_angles_pry"
+     EULER_ANGLES_PYR = "euler_angles_pyr"
+     EULER_ANGLES_YRP = "euler_angles_yrp"
+     EULER_ANGLES_YPR = "euler_angles_ypr"
+
+
+ # LeRobot schema
+
+
+ class LeRobotModalityField(BaseModel):
+     """Metadata for a LeRobot modality field."""
+
+     original_key: Optional[str] = Field(
+         default=None,
+         description="The original key of the modality in the LeRobot dataset",
+     )
+
+
+ class LeRobotStateActionMetadata(LeRobotModalityField):
+     """Metadata for a LeRobot state or action modality."""
+
+     start: int = Field(
+         ...,
+         description="The start index of the modality in the concatenated state/action vector",
+     )
+     end: int = Field(
+         ...,
+         description="The end index of the modality in the concatenated state/action vector",
+     )
+     rotation_type: Optional[RotationType] = Field(
+         default=None, description="The type of rotation for the modality"
+     )
+     absolute: bool = Field(default=True, description="Whether the modality is absolute")
+     dtype: str = Field(
+         default="float64",
+         description="The data type of the modality. Defaults to float64.",
+     )
+     range: Optional[tuple[float, float]] = Field(
+         default=None,
+         description="The range of the modality, if applicable. Defaults to None.",
+     )
+     original_key: Optional[str] = Field(
+         default=None,
+         description="The original key of the modality in the LeRobot dataset.",
+     )
+
+
+ class LeRobotStateMetadata(LeRobotStateActionMetadata):
+     """Metadata for a LeRobot state modality."""
+
+     original_key: Optional[str] = Field(
+         default="observation.state",  # LeRobot convention for states
+         description="The original key of the state modality in the LeRobot dataset",
+     )
+
+
+ class LeRobotActionMetadata(LeRobotStateActionMetadata):
+     """Metadata for a LeRobot action modality."""
+
+     original_key: Optional[str] = Field(
+         default="action",  # LeRobot convention for actions
+         description="The original key of the action modality in the LeRobot dataset",
+     )
+
+
+ class LeRobotModalityMetadata(BaseModel):
+     """Modality metadata for a LeRobot dataset."""
+
+     state: dict[str, LeRobotStateMetadata] = Field(
+         ...,
+         description="The metadata for the state modality. The keys are the names of each split of the state vector.",
+     )
+     action: dict[str, LeRobotActionMetadata] = Field(
+         ...,
+         description="The metadata for the action modality. The keys are the names of each split of the action vector.",
+     )
+     video: dict[str, LeRobotModalityField] = Field(
+         ...,
+         description="The metadata for the video modality. The keys are the new names of each video modality.",
+     )
+     annotation: Optional[dict[str, LeRobotModalityField]] = Field(
+         default=None,
+         description="The metadata for the annotation modality. The keys are the new names of each annotation modality.",
+     )
+
+     def get_key_meta(self, key: str) -> LeRobotModalityField:
+         """Get the metadata for a key in the LeRobot modality metadata.
+
+         Args:
+             key (str): The key to get the metadata for.
+
+         Returns:
+             LeRobotModalityField: The metadata for the key.
+
+         Example:
+             lerobot_modality_meta = LeRobotModalityMetadata.model_validate(U.load_json(modality_meta_path))
+             lerobot_modality_meta.get_key_meta("state.joint_shoulder_y")
+             lerobot_modality_meta.get_key_meta("video.main_camera")
+             lerobot_modality_meta.get_key_meta("annotation.human.action.task_description")
+         """
+         split_key = key.split(".")
+         modality = split_key[0]
+         subkey = ".".join(split_key[1:])
+         if modality == "state":
+             if subkey not in self.state:
+                 raise ValueError(
+                     f"Key: {key}, state key {subkey} not found in metadata, available state keys: {self.state.keys()}"
+                 )
+             return self.state[subkey]
+         elif modality == "action":
+             if subkey not in self.action:
+                 raise ValueError(
+                     f"Key: {key}, action key {subkey} not found in metadata, available action keys: {self.action.keys()}"
+                 )
+             return self.action[subkey]
+         elif modality == "video":
+             if subkey not in self.video:
+                 raise ValueError(
+                     f"Key: {key}, video key {subkey} not found in metadata, available video keys: {self.video.keys()}"
+                 )
+             return self.video[subkey]
+         elif modality == "annotation":
+             assert (
+                 self.annotation is not None
+             ), "Trying to get annotation metadata for a dataset with no annotations"
+             if subkey not in self.annotation:
+                 raise ValueError(
+                     f"Key: {key}, annotation key {subkey} not found in metadata, available annotation keys: {self.annotation.keys()}"
+                 )
+             return self.annotation[subkey]
+         else:
+             raise ValueError(f"Key: {key}, unexpected modality: {modality}")
+
+
+ # Dataset schema (parsed from LeRobot schema and simplified)
+
+
+ class DatasetStatisticalValues(BaseModel):
+     max: NDArray = Field(..., description="Maximum values")
+     min: NDArray = Field(..., description="Minimum values")
+     mean: NDArray = Field(..., description="Mean values")
+     std: NDArray = Field(..., description="Standard deviation")
+     q01: NDArray = Field(..., description="1st percentile values")
+     q99: NDArray = Field(..., description="99th percentile values")
+
+     @field_serializer("*", when_used="json")
+     def serialize_ndarray(self, v: NDArray) -> list[float]:
+         return v.tolist()  # type: ignore
+
+
+ class DatasetStatistics(BaseModel):
+     state: dict[str, DatasetStatisticalValues] = Field(..., description="Statistics of the state")
+     action: dict[str, DatasetStatisticalValues] = Field(..., description="Statistics of the action")
+
+
+ class VideoMetadata(BaseModel):
+     """Metadata of the video modality"""
+
+     resolution: tuple[int, int] = Field(..., description="Resolution of the video")
+     channels: int = Field(..., description="Number of channels in the video", gt=0)
+     fps: float = Field(..., description="Frames per second", gt=0)
+
+
+ class StateActionMetadata(BaseModel):
+     absolute: bool = Field(..., description="Whether the state or action is absolute")
+     rotation_type: Optional[RotationType] = Field(None, description="Type of rotation, if any")
+     shape: tuple[int, ...] = Field(..., description="Shape of the state or action")
+     continuous: bool = Field(..., description="Whether the state or action is continuous")
+
+
+ class DatasetModalities(BaseModel):
+     video: dict[str, VideoMetadata] = Field(..., description="Metadata of the video")
+     state: dict[str, StateActionMetadata] = Field(..., description="Metadata of the state")
+     action: dict[str, StateActionMetadata] = Field(..., description="Metadata of the action")
+
+
+ class DatasetMetadata(BaseModel):
+     """Metadata of the trainable dataset
+
+     Changes:
+         - Update to use the new RawCommitHashMetadataMetadata_V1_2
+     """
+
+     statistics: DatasetStatistics = Field(..., description="Statistics of the dataset")
+     modalities: DatasetModalities = Field(..., description="Metadata of the modalities")
+     embodiment_tag: EmbodimentTag = Field(..., description="Embodiment tag of the dataset")
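The pydantic models above mirror the modality metadata JSON that accompanies a GR00T-style LeRobot dataset. A minimal usage sketch follows, assuming such a file exists at the path shown; the path and the "state.arm" key are illustrative, not fixed by the schema:

import json
from pathlib import Path

# Hypothetical dataset layout; point this at the actual modality file.
meta_path = Path("my_dataset/meta/modality.json")
meta = LeRobotModalityMetadata.model_validate(json.loads(meta_path.read_text()))

# get_key_meta resolves a dotted key to its metadata; for a state key this is
# a LeRobotStateMetadata whose start/end indices slice the concatenated state.
arm = meta.get_key_meta("state.arm")
# state_vector[arm.start:arm.end] then selects the arm portion of the state.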
code/dataloader/gr00t_lerobot/transform/__init__.py ADDED
@@ -0,0 +1,41 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .base import (
+     ComposedModalityTransform,
+     InvertibleModalityTransform,
+     ModalityTransform,
+ )
+ from .concat import ConcatTransform
+ # from .state_action import (
+ #     StateActionDropout,
+ #     StateActionPerturbation,
+ #     StateActionSinCosTransform,
+ #     StateActionToTensor,
+ #     StateActionTransform,
+ # )
+ from .video import (
+     VideoColorJitter,
+     VideoCrop,
+     VideoGrayscale,
+     VideoHorizontalFlip,
+     VideoRandomGrayscale,
+     VideoRandomPosterize,
+     VideoRandomRotation,
+     VideoResize,
+     VideoToNumpy,
+     VideoToTensor,
+     VideoTransform,
+ )
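This package re-exports the transform building blocks so a pipeline can be composed from one import. A hedged sketch of such a composition follows; the constructor arguments (transforms=, apply_to=, height=, width=, and the jitter parameters) are assumptions inferred from the class names, not verified against the implementations in .base and .video:

video_keys = ["video.main_camera"]  # hypothetical modality key

pipeline = ComposedModalityTransform(
    transforms=[
        VideoToTensor(apply_to=video_keys),  # raw frames -> tensors
        VideoResize(apply_to=video_keys, height=224, width=224),
        VideoColorJitter(apply_to=video_keys, brightness=0.3, contrast=0.4,
                         saturation=0.5, hue=0.08),
    ]
)
# transformed = pipeline(sample)  # sample: dict of modality arrays keyed like video_keys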
code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (653 Bytes).
 
code/dataloader/gr00t_lerobot/transform/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (865 Bytes).
 
code/dataloader/gr00t_lerobot/transform/__pycache__/base.cpython-310.pyc ADDED
Binary file (5.53 kB).