TinkerYard committed on
Commit b70b4db · verified · 1 Parent(s): 4723fef

Upload 10 files

Files changed (10)
  1. .gitignore +184 -0
  2. LICENSE +21 -0
  3. README.md +366 -3
  4. finetune.sh +54 -0
  5. finetune_maniskill.sh +48 -0
  6. inference.sh +5 -0
  7. main.py +300 -0
  8. pretrain.sh +47 -0
  9. requirements.txt +11 -0
  10. requirements_data.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,184 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # Some encoder paths
163
+ facebook/
164
+ openai/
165
+ google/
166
+
167
+ # Log
168
+ logs/
169
+
170
+ # Output
171
+ outs/
172
+
173
+ # Checkpoints
174
+ checkpoints/
175
+
176
+ # VSC
177
+ .vscode/
178
+
179
+ # Wandb
180
+ wandb/
181
+
182
+ # Distributed learning
183
+ hostfile.txt
184
+ .deepspeed_env
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 TSAIL group
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,366 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ # RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation
2
+
3
+ ### 📝[Paper](https://arxiv.org/pdf/2410.07864) | 🌍[Project Page](https://rdt-robotics.github.io/rdt-robotics/) | 🤗[Model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) | 🛢️[Data](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data) | 🏞️[Poster](./assets/iclr2025_poster.png)
4
+
5
+ ![](./assets/head.png)
6
+
7
+ RDT-1B is a **1B**-parameter (*largest* to date) imitation learning **Diffusion Transformer** pre-trained on **1M+** (*largest* to date) multi-robot episodes. Given a language instruction and RGB images from up to three views, RDT can predict the next $64$ robot actions. RDT is inherently compatible with **almost all kinds of modern mobile manipulators**, from single-arm to dual-arm, joint to EEF, position to velocity, and even wheeled locomotion.
8
+
9
+ We have fine-tuned RDT on **6K+** (one of the *largest*) self-collected bimanual episodes and deployed it on the ALOHA **dual-arm** robot. It has achieved state-of-the-art performance in terms of dexterity, zero-shot generalizability, and few-shot learning. You can find demo videos on our [project page](https://rdt-robotics.github.io/rdt-robotics/).
10
+
11
+ This repo is an official PyTorch implementation of RDT, containing:
12
+
13
+ - 🛠️Model [implementation](models/rdt_runner.py) of RDT
14
+ - 🤗1M-step [checkpoint](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) of RDT-1B pre-trained on multi-robot data
15
+ - 🤗500K-step [checkpoint](https://huggingface.co/robotics-diffusion-transformer/rdt-170m) of RDT-170M (RDT(small) in [ablation](https://arxiv.org/pdf/2410.07864))
16
+ - 📈Training and sampling [scripts](train/train.py) (with DeepSpeed)
17
+ - 🤖An [example](scripts/agilex_inference.py) of real-robot deployment
18
+ - 🕹️Simulation benchmark from [Maniskill](https://github.com/haosulab/ManiSkill) environment
19
+
20
+ The following guides include the [installation](#installation), [fine-tuning](#fine-tuning-on-your-own-dataset), and [deployment](#deployment-on-real-robots). Please refer to [pre-training](docs/pretrain.md) for a detailed list of pre-training datasets and a pre-training guide.
21
+
22
+ ## 📰 News
23
+ - [2025/04/04] [Poster](./assets/iclr2025_poster.png) is uploaded.
24
+ - [2024/12/17] 🔥 [Scripts](#simulation-benchmark) for evaluating RDT on the ManiSkill simulation benchmark are released!
25
+ - [2024/10/23] 🔥 The smaller **RDT-170M** model is released, a more VRAM-friendly solution 🚀💻.
26
+
27
+ ## Installation
28
+
29
+ 1. Clone this repo and install prerequisites:
30
+
31
+ ```bash
32
+ # Clone this repo
33
+ git clone git@github.com:thu-ml/RoboticsDiffusionTransformer.git
34
+ cd RoboticsDiffusionTransformer
35
+
36
+ # Create a Conda environment
37
+ conda create -n rdt python=3.10.0
38
+ conda activate rdt
39
+
40
+ # Install pytorch
41
+ # Look up https://pytorch.org/get-started/previous-versions/ with your cuda version for a correct command
42
+ pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
43
+
44
+ # Install packaging
45
+ pip install packaging==24.0
46
+
47
+ # Install flash-attn
48
+ pip install flash-attn --no-build-isolation
49
+
50
+ # Install other prerequisites
51
+ pip install -r requirements.txt
52
+ ```
53
+
54
+ 2. Download off-the-shelf multi-modal encoders:
55
+
56
+ You can download the encoders from the following links:
57
+
58
+ - `t5-v1_1-xxl`: [link](https://huggingface.co/google/t5-v1_1-xxl/tree/main)🤗
59
+ - `siglip`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)🤗
60
+
61
+ And link the encoders to the repo directory:
62
+
63
+ ```bash
64
+ # Under the root directory of this repo
65
+ mkdir -p google
66
+
67
+ # Link the downloaded encoders to this repo
68
+ ln -s /path/to/t5-v1_1-xxl google/t5-v1_1-xxl
69
+ ln -s /path/to/siglip-so400m-patch14-384 google/siglip-so400m-patch14-384
70
+ ```
71
+ 3. Fill the missing argument in [this file](configs/base.yaml#L22):
72
+
73
+ Note that this buffer will only be used during pre-training. See [this doc](docs/pretrain.md) for more details.
74
+ ```
75
+ # ...
76
+
77
+ dataset:
78
+ # ...
79
+ # ADD YOUR buf_path: the path to the buffer (at least 400GB)
80
+ buf_path: /path/to/buffer
81
+ # ...
82
+ ```
83
+
84
+ ## Fine-Tuning on Your Own Dataset
85
+
86
+ If your fine-tuning dataset is in [Open X-Embodiment](https://robotics-transformer-x.github.io/) or the collection of our pre-training datasets (see [this doc](docs/pretrain.md#download-and-prepare-datasets)), you can also fine-tune RDT through the pre-training pipeline. You only need to remove the other, redundant datasets from the parameters; we refer you to [this guide](docs/pretrain.md) (pre-training).
87
+
88
+ 1. Prepare your dataset:
89
+
90
+ You need to download your dataset to the disk and give it a name `my_cool_dataset`.
91
+
92
+ Then, you can link your dataset to the repo directory:
93
+
94
+ ```bash
95
+ # Under the root directory of this repo
96
+ cd data
97
+ mkdir -p datasets
98
+
99
+ # Link the downloaded dataset to this repo
100
+ ln -s /path/to/my_cool_dataset datasets/my_cool_dataset
101
+ ```
102
+
103
+ 2. Implement the dataset loader:
104
+
105
+ You need to:
106
+
107
+ 1. Register the configuration of `my_cool_dataset`:
108
+
109
+ Append the control frequency of `my_cool_dataset` in [this file](configs/dataset_control_freq.json). Write the name of `my_cool_dataset` in [this file](configs/finetune_datasets.json) and [this file](configs/finetune_sample_weights.json), where the value of the sampling weight doesn't matter since you only have one dataset. In these two files, we leave a placeholder of `agilex`; you can simply replace it with `my_cool_dataset`.
110
+
111
+ 2. Re-implement the `HDF5VLADataset` class:
112
+
113
+ You can find this class in [this file](data/hdf5_vla_dataset.py). In this file, we provide an example of loading the fine-tuning dataset used in our paper (see [this link](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data)).
114
+
115
+ To adapt it to your dataset, you need to: (a) modify `HDF5_DIR` (the directory of `my_cool_dataset`) and `DATASET_NAME` (should be `"my_cool_dataset"`) in L21 and L22; (b) implement the two functions `parse_hdf5_file()` and `parse_hdf5_file_state_only()`. Please take a look at the original file for detailed comments and examples.
116
+
117
+ Note 1: Despite its name, you don't necessarily need to use HDF5 to store your data. Just make sure that the class is correctly implemented.
118
+
119
+ Note 2: During implementation, you may need to fill your robot action into the unified action vector (L180-194). Please refer to [this file](configs/state_vec.py) for an explanation of each element in the unified vector. We have reserved enough slots for each physical quantity. For example, we have reserved ten slots for joint angles. If your robot arm has six degrees of freedom, you only need to fill in the first six.
120
+
121
+ **IMPORTANT 1:** If your robot is single-arm, please fill its action into the *right-arm* portion of the unified action vector, aligning with our pre-training datasets.
122
+
123
+ **IMPORTANT 2:** We use [6D representation](https://arxiv.org/pdf/1812.07035) for EEF rotation. If your action space contains EEF rotation (angle or quaternion), please refer to [this file](docs/test_6drot.py) for conversion. We note that this mapping is not reversible. Different Euler angles may be equivalent and correspond to the same 6D representation.
124
+
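For intuition, here is a minimal sketch of the 6D rotation representation (the first two columns of the rotation matrix, recovered via Gram-Schmidt). It is illustrative only: first convert your Euler angles or quaternion to a rotation matrix (e.g., with `scipy.spatial.transform.Rotation`), and use [this file](docs/test_6drot.py) for the repo's actual conversion utilities.

```python
import numpy as np

# Illustrative helpers only; the repo's conversion code lives in docs/test_6drot.py.

def rotmat_to_6d(R: np.ndarray) -> np.ndarray:
    """6D representation = the first two columns of a 3x3 rotation matrix, concatenated."""
    return np.concatenate([R[:, 0], R[:, 1]])

def rot6d_to_rotmat(d6: np.ndarray) -> np.ndarray:
    """Recover a rotation matrix from the 6D representation via Gram-Schmidt."""
    a1, a2 = d6[:3], d6[3:]
    b1 = a1 / np.linalg.norm(a1)
    b2 = a2 - np.dot(b1, a2) * b1
    b2 = b2 / np.linalg.norm(b2)
    b3 = np.cross(b1, b2)
    return np.stack([b1, b2, b3], axis=1)  # columns b1, b2, b3

# Round trip on the identity rotation:
assert np.allclose(rot6d_to_rotmat(rotmat_to_6d(np.eye(3))), np.eye(3))
```
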
125
+ **IMPORTANT 3:** No physical quantities (except the gripper width) are normalized during pre-training. This can preserve each physical quantity's meaning, thereby promoting generalization across robots. Therefore, we encourage you not to normalize any physical quantities but to choose appropriate units for them. Generally, we use the International System of Units, which ensures that most values fall within [-1,1]. As an exception, we perform min-max normalization on the gripper width to [0,1].
126
+
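Putting Note 2 and IMPORTANT 1-3 together, a minimal sketch of filling a single 6-DoF (right) arm into the unified vector might look like the following. The vector length and slot indices below are placeholders for illustration; the real layout is defined in [this file](configs/state_vec.py).

```python
import numpy as np

UNI_VEC_DIM = 128                      # placeholder length; see configs/state_vec.py
RIGHT_ARM_JOINT_SLOTS = np.arange(10)  # placeholder indices for the 10 reserved joint-angle slots
RIGHT_GRIPPER_SLOT = 10                # placeholder index for the right gripper width

def fill_right_arm(joint_pos_rad: np.ndarray, gripper_width: float, gripper_max_width: float) -> np.ndarray:
    """Fill a single (right) arm into the unified vector; unused slots stay zero."""
    vec = np.zeros(UNI_VEC_DIM, dtype=np.float32)
    # Joint angles are written unnormalized (in radians); a 6-DoF arm uses only the first 6 of the 10 slots.
    vec[RIGHT_ARM_JOINT_SLOTS[:len(joint_pos_rad)]] = joint_pos_rad
    # The gripper width is the one exception: min-max normalize it into [0, 1].
    vec[RIGHT_GRIPPER_SLOT] = np.clip(gripper_width / gripper_max_width, 0.0, 1.0)
    return vec

state = fill_right_arm(np.deg2rad([10, -20, 30, 0, 45, -15]), gripper_width=0.04, gripper_max_width=0.08)
```
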
127
+ **IMPORTANT 4:** If you use an RTX 4090 (or a lower-end GPU), the GPU memory may be insufficient to load the `t5-v1_1-xxl` encoder. Instead, we recommend that you precompute the language embeddings (see [this file](scripts/encode_lang_batch.py) for an example script) and load them during training. In this case, you need to specify the path to the embeddings in the `HDF5VLADataset` (see L148) rather than the natural language.
128
+
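As a rough sketch of loading such a precomputed embedding (the exact file layout is whatever `scripts/encode_lang_batch.py` saved, so check its output format before relying on this):

```python
import torch

# Hypothetical path to an embedding produced by your precomputation run.
LANG_EMBED_PATH = "outs/lang_embeddings/my_cool_instruction.pt"

def load_lang_embedding(path: str = LANG_EMBED_PATH) -> torch.Tensor:
    """Load a precomputed T5 embedding so t5-v1_1-xxl never has to sit in GPU memory during training."""
    obj = torch.load(path, map_location="cpu")
    # Depending on how it was saved, the file may hold a raw tensor or a dict wrapping one (assumption).
    if isinstance(obj, dict):
        obj = next(iter(obj.values()))
    return obj
```
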
129
+ 3. Compute the dataset statistics information for `my_cool_dataset`:
130
+
131
+ ```bash
132
+ # Under the root directory of this repo
133
+ # Use -h to see the full usage
134
+ python -m data.compute_dataset_stat_hdf5
135
+ ```
136
+
137
+ 3. Start fine-tuning:
138
+
139
+ Configurations relevant to model architecture and data processing are in [this file](configs/base.yaml). Normally, you do not need to modify these configurations; changing them may cause errors when loading the pre-training checkpoint. Configurations relevant to training are passed through *command-line arguments*. Use `python main.py -h` to see their descriptions. We provide an example fine-tuning script in [this file](finetune.sh) (`finetune.sh`). You may need to modify some of the parameters in this file, such as `CUTLASS_PATH` and `WANDB_PROJECT`.
140
+
141
+ Use this to start fine-tuning:
142
+
143
+ ```bash
144
+ source finetune.sh
145
+ ```
146
+
147
+ with `finetune.sh` detailed below:
148
+
149
+ ```bash
150
+ deepspeed --hostfile=hostfile.txt main.py \
151
+ --deepspeed="./configs/zero2.json" \ # If you want to use DeepSpeed, which is strongly recommended
152
+ --pretrained_model_name_or_path=<MODEL ID | DIRECTORY OF MODEL WEIGHTS | PATH TO MODEL CHECKPOINT> \
153
+ --pretrained_text_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY > \ # e.g., google/t5-v1_1-xxl
154
+ --pretrained_vision_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY> \ # e.g., google/siglip-so400m-patch14-384
155
+ --output_dir=<DIRECTORY to SAVE CHECKPOINTS> \ # e.g., checkpoints/rdt-1b-agilex
156
+ --train_batch_size=32 \
157
+ --sample_batch_size=64 \ # batch size for diffusion sampling in validation
158
+ --max_train_steps=200000 \
159
+ --checkpointing_period=1000 \
160
+ --sample_period=500 \ # sample period for validation
161
+ --checkpoints_total_limit=40 \
162
+ --lr_scheduler="constant" \
163
+ --learning_rate=1e-4 \
164
+ --mixed_precision="bf16" \ # If you want to use mixed precision, bf16 is recommended
165
+ --dataloader_num_workers=8 \
166
+ --image_aug \ # If you want to use image augmentation
167
+ --dataset_type="finetune" \
168
+ --state_noise_snr=40 \ # If you want to add noise to the state
169
+ --load_from_hdf5 \ # If you use HDF5 to store your data
170
+ --report_to=wandb
171
+ ```
172
+
173
+ **IMPORTANT**: If you have already chosen to precompute the language embeddings, please specify `--precomp_lang_embed` in the `finetune.sh`.
174
+
175
+ Note 1: `pretrained_model_name_or_path` can be one of:
176
+
177
+ - a string, the *model id* of a pre-trained model hosted inside a model repo on HuggingFace. Please fill with `"robotics-diffusion-transformer/rdt-1b"`, which is the officially-released [RDT-1B model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗 at HuggingFace. (recommended)
178
+ - a string, the path to a *directory* containing the manually downloaded model weights from HuggingFace, e.g., `"/path/to/rdt-1b"`. You should first manually download the `rdt-1b` directory from this [link](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗.
179
+ - a string, the path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method. This can be either:
180
+ - `"checkpoints/rdt-pretrain-1b/checkpoint-<STEP NUMBER>"`: This is the path to the checkpoint saved in the `<STEP NUMBE>` iteration during pre-training. Refer to [this file](docs/pretrain.md) for a tutorial on how to start your own pre-training.
181
+ - `"checkpoints/rdt-pretrain-1b"`: If the pre-training completes normally without any exception, you can specify this path to load the last checkpoint.
182
+ - a string, the path to a model checkpoint (`*.pt`) saved by DeepSpeed, e.g., `"checkpoints/rdt-pretrain-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt"` (verified)
183
+ - `None` if you want to randomly initialize the model using configuration at `config_path`.
184
+
185
+ Note 2: You can monitor the training process by observing `loss` (through a long window moving average) and `overall_avg_sample_mse` in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We empirically found that the lower the `overall_avg_sample_mse`, the better the model performs. Usually, fine-tuning is over when this value converges.
186
+
187
+ Note 3: If the training oscillates, you can increase the batch size by adding more GPUs or setting a larger `--gradient_accumulation_steps`.
188
+
189
+ Note 4: Please specify `--load_from_hdf5` in your script when finetuning with an HDF5 dataset.
190
+
191
+ ## Deployment on Real-Robots
192
+
193
+ We have encapsulated the inference of the model into a class named `RoboticDiffusionTransformerModel` (see [this file](scripts/agilex_model.py#L38)). You can call this class's `step()` method for inference. However, you may need to re-implement some parts according to your specific robot. You should at least modify the `_format_joint_to_state()` (L164) and `_unformat_action_to_joint()` (L196) to convert between robot raw actions and unified action vectors that RDT accepts. You may also specify the control frequency of your robot (L49).
194
+
195
+ **IMPORTANT**: When you feed the images into `step()`, remember the order MUST be `[ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, ext_{t}, right_wrist_{t}, left_wrist_{t}]`.
196
+
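A minimal sketch of assembling that list before calling `step()` (the camera variable names and the commented call are illustrative; check [this file](scripts/agilex_model.py) for the actual method signature):

```python
import numpy as np

# Hypothetical placeholder frames; in practice these come from your camera drivers.
H, W = 480, 640

def blank_frame() -> np.ndarray:
    return np.zeros((H, W, 3), dtype=np.uint8)

ext_prev, right_wrist_prev, left_wrist_prev = (blank_frame() for _ in range(3))  # observations at t-1
ext_curr, right_wrist_curr, left_wrist_curr = (blank_frame() for _ in range(3))  # observations at t

# The order below is mandatory:
# [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, ext_{t}, right_wrist_{t}, left_wrist_{t}]
images = [ext_prev, right_wrist_prev, left_wrist_prev,
          ext_curr, right_wrist_curr, left_wrist_curr]

# Hypothetical call for illustration; see scripts/agilex_model.py for the real signature.
# actions = policy.step(proprio=state_vec, images=images, text_embeds=lang_embed)
```
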
197
+ We provide example hardware code in [this file](scripts/agilex_inference.py) for deployment on Mobile ALOHA, and the corresponding running script in [this file](inference.sh) (`inference.sh`), which is detailed below:
198
+
199
+ ```bash
200
+ python -m scripts.agilex_inference \
201
+ --use_actions_interpolation \
202
+ --pretrained_model_name_or_path=<PATH TO MODEL CHECKPOINT> \ # your finetuned checkpoint: e.g., checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>, checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt, the same as before
203
+ --lang_embeddings_path=<PATH TO YOUR INSTRUCTION EMBEDDINGS> \ # e.g., outs/lang_embeddings/your_instr.pt
204
+ --ctrl_freq=25 # your control frequency
205
+ ```
206
+
207
+ **IMPORTANT**: If your on-board GPU memory is not enough to encode the language instruction, please refer to [this file](scripts/encode_lang.py) for precomputation and specify the language embedding path in `inference.sh`. Detailed instructions are provided below:
208
+
209
+ 1. Set Required Parameters in `scripts/encode_lang.py`
210
+
211
+ ```python
212
+ # ...
213
+
214
+ GPU = 0
215
+ MODEL_PATH = "google/t5-v1_1-xxl"
216
+ CONFIG_PATH = "configs/base.yaml"
217
+ SAVE_DIR = "outs/" # output directory
218
+
219
+ # Modify this to your task name and instruction
220
+ TASK_NAME = "handover_pan"
221
+ INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left."
222
+
223
+ # Note: if your GPU VRAM is less than 24GB,
224
+ # it is recommended to enable offloading by specifying an offload directory.
225
+ OFFLOAD_DIR = None # Specify your offload directory here, ensuring the directory exists.
226
+
227
+ # ...
228
+ ```
229
+
230
+ 2. Run the script
231
+ ```
232
+ python -m scripts.encode_lang
233
+ ```
234
+
235
+ Note: If you want to deploy on the Mobile ALOHA robot, don't forget to install the hardware prerequisites (see [this repo](https://github.com/MarkFzp/mobile-aloha)).
236
+
237
+ ## Simulation Benchmark
238
+
239
+ We comprehensively evaluate RDT against baseline methods using the ManiSkill simulation benchmark. Specifically, we focus on five benchmark tasks: `PegInsertionSide`, `PickCube`, `StackCube`, `PlugCharger`, and `PushCube`. Here's a brief overview of the evaluation setup:
240
+
241
+ **Evaluation Setup:**
242
+
243
+ 1. **Install ManiSkill:**
244
+ Within the [RDT environment](#installation), install ManiSkill as follows:
245
+ ```bash
246
+ conda activate rdt
247
+ pip install --upgrade mani_skill
248
+ ```
249
+
250
+ 2. **Configure Vulkan:**
251
+ Follow the [ManiSkill documentation](https://maniskill.readthedocs.io/en/latest/user_guide/getting_started/installation.html#vulkan) to properly set up Vulkan.
252
+
253
+ 3. **Obtain Model Weights:**
254
+ Download the fine-tuned model weights from [this Hugging Face repository](https://huggingface.co/robotics-diffusion-transformer/maniskill-model/tree/main/rdt). Download the precomputed language embeddings from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model/tree/main/lang_embeds) to the root directory of this repo.
255
+
256
+ 4. **Run Evaluation Scripts:**
257
+ After completing the setup steps, execute the provided evaluation scripts to assess RDT on the selected tasks.
258
+
259
+ ```
260
+ conda activate rdt
261
+ python -m eval_sim.eval_rdt_maniskill \
262
+ --pretrained_path PATH_TO_PRETRAINED_MODEL
263
+ ```
264
+
265
+ ### Implementation Details
266
+
267
+ #### Data
268
+
269
+ Utilizing the [official ManiSkill repository](https://github.com/haosulab/ManiSkill), we generated 5,000 trajectories through motion planning. The initial action mode of these trajectories is absolute joint position control; we subsequently converted them to delta end-effector pose control to align with the pre-training action space of OpenVLA and Octo. We strictly adhered to the official codebases of OpenVLA and Octo, modifying only the dataset-loading scripts, and fine-tuned OpenVLA and Octo on the delta end-effector pose data. For RDT and Diffusion-Policy, we use joint position control data for training, which is also aligned with our pre-training stage.
270
+
271
+ #### Training
272
+ - OpenVLA is fine-tuned from the officially released pre-trained checkpoint with LoRA rank 32 until convergence.
273
+ - Octo is fine-tuned from the officially released pre-trained checkpoint for 1M iterations until convergence.
274
+ - Diffusion-Policy is trained from scratch for 1,000 epochs. We select the checkpoint at epoch 700, which has the lowest validation sample loss (1e-3).
275
+ - RDT is fine-tuned from our released pre-trained checkpoint for 300K iterations.
276
+
277
+ #### Results
278
+
279
+ Each method is evaluated over 250 trials (10 random seeds with 25 trials per seed). The quantitative results, including the mean and standard deviation of the success rate across the 10 random seeds, are presented below:
280
+
281
+
282
+ ||PegInsertionSide|PickCube|StackCube|PlugCharger|PushCube|Mean|
283
+ |---|---|---|---|---|---|---|
284
+ |RDT|**13.2±0.29%**|**77.2±0.48%**|74.0±0.30%|**1.2±0.07%**|**100±0.00%**|**53.6±0.52%**|
285
+ |OpenVLA|0.0±0.00%|8.0±0.00%|8.0±0.00%|0.0±0.00%|8.0±0.00%|4.8±0.00%|
286
+ |Octo|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|
287
+ |Diffusion-Policy|0.0±0.00%|40.0±0.00%|**80.0±0.00%**|0.0±0.00%|88.0±0.00%|30.2±0.00%|
288
+
289
+ #### Fine-tune RDT with ManiSkill Data
290
+
291
+ To fine-tune RDT with ManiSkill data, first download the ManiSkill data from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model) and extract it to `data/datasets/rdt-ft-data`. Then copy the code in `data/hdf5_vla_dataset.py` to `data/hdf5_maniskill_dataset.py` and run the following script:
292
+
293
+ ```
294
+ bash finetune_maniskill.sh
295
+ ```
296
+
297
+ #### Reproducing Baseline Results
298
+
299
+ Download and extract the fine-tuned model weights from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model) to `eval_sim/`.
300
+
301
+ - OpenVLA: Clone [OpenVLA repo](https://github.com/openvla/openvla) in `./eval_sim/` and install its environment & ManiSkill. Then run the following script:
302
+ ```
303
+ python -m eval_sim.eval_openvla --pretrained_path PATH_TO_PRETRAINED_MODEL
304
+ ```
305
+ - Octo: Clone [Octo repo](https://github.com/octo-models/octo.git) in `./eval_sim/` and install its environment & ManiSkill. Then run the following script:
306
+ ```
307
+ python -m eval_sim.eval_octo --pretrained_path PATH_TO_PRETRAINED_MODEL
308
+ ```
309
+ - Diffusion-Policy: Clone our simplified [Diffusion-Policy repo](https://github.com/LBG21/RDT-Eval-Diffusion-Policy) in `./eval_sim/` and run:
310
+ ```
311
+ python -m eval_sim.eval_dp --pretrained_path PATH_TO_PRETRAINED_MODEL
312
+ ```
313
+
314
+ ### RDT on [RoboTwin Benchmark](https://robotwin-platform.github.io/)
315
+
316
+ RoboTwin is a benchmark simulator based on Sapien, comprising 50 common dual-arm tasks. For each task, all policies are trained on 50 trajectories collected in a clean environment, and then deployed and tested for success rate under either easy (clean) or hard (randomized) configurations of the same 50 tasks. You can refer to [https://robotwin-platform.github.io/doc/usage/RDT.html#1-environment-setup](https://robotwin-platform.github.io/doc/usage/RDT.html#1-environment-setup) to set up the environment required for RoboTwin and RDT. According to the official test results of RoboTwin ([RoboTwin 2.0 Leaderboard](https://robotwin-platform.github.io/leaderboard)), RDT ranks second only to Pi0 (excluding the DP3 policy, which utilizes ground-truth point clouds).
317
+ <img width="2048" height="1224" alt="image" src="https://github.com/user-attachments/assets/e721ffab-3dde-42f0-b36e-1593aa964a99" />
318
+
319
+
320
+ ## FAQ
321
+
322
+ ### 1. How can I fine-tune RDTs with limited VRAM?
323
+
324
+ - **Use a Smaller Model**: Opt for the [RDT-170M model](https://huggingface.co/robotics-diffusion-transformer/rdt-170m), which requires less VRAM.
325
+
326
+ - **Select a Memory-Efficient ZeRO Stage**: Choose a more memory-efficient ZeRO stage based on your needs:
327
+ - **ZeRO-3 with Offload** > **ZeRO-3** > **ZeRO-2 with Offload** > **ZeRO-2** > **ZeRO-1**
328
+ - By default, we use [ZeRO-2](https://github.com/thu-ml/RoboticsDiffusionTransformer/blob/c68398ed526733faca4eec52cc1a7d15a9f8fea7/finetune.sh#L29) for a balance between speed and memory efficiency. Find more details on ZeRO stages [here](https://huggingface.co/docs/transformers/main/deepspeed#select-a-zero-stage) and [here](https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training).
329
+
330
+ - **Enable 8-bit Adam Optimization**: Activate 8-bit Adam by passing [`--use_8bit_adam`](https://github.com/thu-ml/RoboticsDiffusionTransformer/blob/c68398ed526733faca4eec52cc1a7d15a9f8fea7/main.py#L195) for reduced memory usage during training.
331
+
332
+ - **Apply 4-bit or 8-bit Quantization**: Quantizing model weights can significantly reduce VRAM requirements.
333
+
334
+ - **Use [XFormers](https://github.com/facebookresearch/xformers)**: This library provides optimized transformers with efficient memory usage.
335
+
336
+ - **Enable Gradient Checkpointing**: Implement `gradient_checkpointing` manually to save memory during backpropagation. See [here](https://deepspeed.readthedocs.io/en/latest/activation-checkpointing.html) for instructions. Once you have successfully implemented this feature, we welcome you to submit a PR👏.
337
+ - **Gradient Accumulation**: Set a larger `--gradient_accumulation_steps=<num_steps>`. This will accumulate the gradients of `<num_steps>` batches for backpropagation. Equivalently, this will increase the batch size by `<num_steps>` times, at the cost of `<num_steps>` times the running time.
338
+
339
+ ### 2. How many steps are recommended for fine-tuning RDT?
340
+
341
+ Regardless of the batch size you select, it is recommended to train for at least 150K steps to achieve optimal results.
342
+
343
+ ### 3. What should I do if T5-XXL is too large to fit in GPU memory?
344
+
345
+ 1. Do not load T5-XXL in your GPU memory when training. Pre-compute language embeddings in advance.
346
+ 2. Set `OFFLOAD_DIR` to enable CPU offloading in `scripts/encode_lang_batch.py` and `scripts/encode_lang.py`.
347
+ 3. Use a smaller version of T5, such as T5-Base, instead of T5-XXL.
348
+
349
+ ## Citation
350
+
351
+ If you find our work helpful, please cite us:
352
+
353
+ ```bibtex
354
+ @article{liu2024rdt,
355
+ title={RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation},
356
+ author={Liu, Songming and Wu, Lingxuan and Li, Bangguo and Tan, Hengkai and Chen, Huayu and Wang, Zhengyi and Xu, Ke and Su, Hang and Zhu, Jun},
357
+ journal={arXiv preprint arXiv:2410.07864},
358
+ year={2024}
359
+ }
360
+ ```
361
+
362
+ Thank you!
363
+
364
+ ## License
365
+
366
+ All the code, model weights, and data are licensed under [MIT license](./LICENSE).
finetune.sh ADDED
@@ -0,0 +1,54 @@
1
+ export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
2
+ export NCCL_IB_DISABLE=0
3
+ export NCCL_SOCKET_IFNAME=bond0
4
+ export NCCL_DEBUG=INFO
5
+ export NCCL_NVLS_ENABLE=0
6
+
7
+ export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
8
+ export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
9
+ export OUTPUT_DIR="./checkpoints/rdt-finetune-1b"
10
+ export CFLAGS="-I/usr/include"
11
+ export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
12
+ export CUTLASS_PATH="/path/to/cutlass"
13
+
14
+ export WANDB_PROJECT="robotics_diffusion_transformer"
15
+
16
+ if [ ! -d "$OUTPUT_DIR" ]; then
17
+ mkdir "$OUTPUT_DIR"
18
+ echo "Folder '$OUTPUT_DIR' created"
19
+ else
20
+ echo "Folder '$OUTPUT_DIR' already exists"
21
+ fi
22
+
23
+ # To run on a single node/machine
24
+ # accelerate launch main.py \
25
+ # --deepspeed="./configs/zero2.json" \
26
+ # ...
27
+
28
+ deepspeed --hostfile=hostfile.txt main.py \
29
+ --deepspeed="./configs/zero2.json" \
30
+ --pretrained_model_name_or_path="robotics-diffusion-transformer/rdt-1b" \
31
+ --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
32
+ --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
33
+ --output_dir=$OUTPUT_DIR \
34
+ --train_batch_size=32 \
35
+ --sample_batch_size=64 \
36
+ --max_train_steps=200000 \
37
+ --checkpointing_period=1000 \
38
+ --sample_period=500 \
39
+ --checkpoints_total_limit=40 \
40
+ --lr_scheduler="constant" \
41
+ --learning_rate=1e-4 \
42
+ --mixed_precision="bf16" \
43
+ --dataloader_num_workers=8 \
44
+ --image_aug \
45
+ --dataset_type="finetune" \
46
+ --state_noise_snr=40 \
47
+ --load_from_hdf5 \
48
+ --report_to=wandb
49
+
50
+ # Use this to resume training from some previous checkpoint
51
+ # --resume_from_checkpoint="checkpoint-36000" \
52
+ # Use this to load from saved language instruction embeddings,
53
+ # instead of computing them during training
54
+ # --precomp_lang_embed \
finetune_maniskill.sh ADDED
@@ -0,0 +1,48 @@
1
+ export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
2
+ export NCCL_IB_DISABLE=0
3
+ export NCCL_SOCKET_IFNAME=bond0
4
+ export NCCL_DEBUG=INFO
5
+ export NCCL_NVLS_ENABLE=0
6
+
7
+ export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
8
+ export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
9
+ export OUTPUT_DIR="./checkpoints/rdt-finetune-1b-sim"
10
+ export CFLAGS="-I/usr/include"
11
+ export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
12
+ export CUTLASS_PATH="/data/lingxuan/cutlass"
13
+
14
+ export WANDB_PROJECT="robotic_diffusion_transformer"
15
+
16
+ if [ ! -d "$OUTPUT_DIR" ]; then
17
+ mkdir "$OUTPUT_DIR"
18
+ echo "Folder '$OUTPUT_DIR' created"
19
+ else
20
+ echo "Folder '$OUTPUT_DIR' already exists"
21
+ fi
22
+ # To run on a single node/machine
23
+ # accelerate launch main.py \
24
+ # --deepspeed="./configs/zero2.json" \
25
+ # ...
26
+
27
+ accelerate launch main.py \
28
+ --deepspeed="./configs/zero2.json" \
29
+ --pretrained_model_name_or_path="robotics-diffusion-transformer/rdt-1b" \
30
+ --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
31
+ --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
32
+ --output_dir=$OUTPUT_DIR \
33
+ --train_batch_size=24 \
34
+ --sample_batch_size=32 \
35
+ --max_train_steps=400000 \
36
+ --checkpointing_period=10000 \
37
+ --sample_period=500 \
38
+ --checkpoints_total_limit=40 \
39
+ --lr_scheduler="constant" \
40
+ --learning_rate=1e-4 \
41
+ --mixed_precision="bf16" \
42
+ --dataloader_num_workers=8 \
43
+ --image_aug \
44
+ --dataset_type="finetune" \
45
+ --state_noise_snr=40 \
46
+ --load_from_hdf5 \
47
+ --report_to=wandb
48
+
inference.sh ADDED
@@ -0,0 +1,5 @@
1
+ python -m scripts.agilex_inference \
2
+ --use_actions_interpolation \
3
+ --pretrained_model_name_or_path="checkpoints/your_finetuned_ckpt.pt" \ # your finetuned checkpoint: e.g., checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>, checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt,
4
+ --lang_embeddings_path="outs/lang_embeddings/your_instr.pt" \
5
+ --ctrl_freq=25 # your control frequency
main.py ADDED
@@ -0,0 +1,300 @@
1
+ import argparse
2
+ import os
3
+ from train.train import train
4
+
5
+ from accelerate.logging import get_logger
6
+
7
+
8
+ def parse_args(input_args=None):
9
+ parser = argparse.ArgumentParser(description="Main script for training RDT.")
10
+ parser.add_argument(
11
+ "--config_path",
12
+ type=str,
13
+ default="configs/base.yaml",
14
+ help="Path to the configuration file. Default is `configs/base.yaml`.",
15
+ )
16
+ parser.add_argument(
17
+ "--deepspeed",
18
+ type=str,
19
+ default=None,
20
+ help="Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary",
21
+ )
22
+ parser.add_argument(
23
+ "--pretrained_text_encoder_name_or_path",
24
+ type=str,
25
+ default=None,
26
+ help="Pretrained text encoder name or path if not the same as model_name",
27
+ )
28
+ parser.add_argument(
29
+ "--pretrained_vision_encoder_name_or_path",
30
+ type=str,
31
+ default=None,
32
+ help="Pretrained vision encoder name or path if not the same as model_name",
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--output_dir",
37
+ type=str,
38
+ default="checkpoints",
39
+ help="The output directory where the model predictions and checkpoints will be written.",
40
+ )
41
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
42
+
43
+ parser.add_argument(
44
+ "--load_from_hdf5",
45
+ action="store_true",
46
+ default=False,
47
+ help=(
48
+ "Whether to load the dataset directly from HDF5 files. "
49
+ "If False, the dataset will be loaded using producer-consumer pattern, "
50
+ "where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."
51
+ )
52
+ )
53
+ parser.add_argument(
54
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
55
+ )
56
+ parser.add_argument(
57
+ "--sample_batch_size", type=int, default=8, help="Batch size (per device) for the sampling dataloader."
58
+ )
59
+ parser.add_argument(
60
+ "--num_sample_batches", type=int, default=2, help="Number of batches to sample from the dataset."
61
+ )
62
+ parser.add_argument("--num_train_epochs", type=int, default=1)
63
+ parser.add_argument(
64
+ "--max_train_steps",
65
+ type=int,
66
+ default=None,
67
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
68
+ )
69
+ parser.add_argument(
70
+ "--checkpointing_period",
71
+ type=int,
72
+ default=500,
73
+ help=(
74
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
75
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
76
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
77
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
78
+ "instructions."
79
+ ),
80
+ )
81
+ parser.add_argument(
82
+ "--checkpoints_total_limit",
83
+ type=int,
84
+ default=None,
85
+ help=(
86
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
87
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
88
+ " for more details"
89
+ ),
90
+ )
91
+ parser.add_argument(
92
+ "--resume_from_checkpoint",
93
+ type=str,
94
+ default=None,
95
+ help=(
96
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
97
+ ' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'
98
+ ),
99
+ )
100
+ parser.add_argument(
101
+ "--pretrained_model_name_or_path",
102
+ type=str,
103
+ default=None,
104
+ help=(
105
+ "Path or name of a pretrained checkpoint to load the model from.\n"
106
+ " This can be either:\n"
107
+ " - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n"
108
+ " - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n"
109
+ " - a path to model checkpoint (*.pt), .e.g, `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`"
110
+ " - `None` if you are randomly initializing model using configuration at `config_path`."
111
+ )
112
+ )
113
+ parser.add_argument(
114
+ "--gradient_accumulation_steps",
115
+ type=int,
116
+ default=1,
117
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
118
+ )
119
+ parser.add_argument(
120
+ "--gradient_checkpointing",
121
+ action="store_true",
122
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
123
+ )
124
+ parser.add_argument(
125
+ "--learning_rate",
126
+ type=float,
127
+ default=5e-6,
128
+ help="Initial learning rate (after the potential warmup period) to use.",
129
+ )
130
+ parser.add_argument(
131
+ "--cond_mask_prob",
132
+ type=float,
133
+ default=0.1,
134
+ help=(
135
+ "The probability to randomly mask the conditions (except states) during training. "
136
+ "If set to 0, the conditions are not masked."
137
+ ),
138
+ )
139
+ parser.add_argument(
140
+ "--cam_ext_mask_prob",
141
+ type=float,
142
+ default=-1.0,
143
+ help=(
144
+ "The probability to randomly mask the external camera image during training. "
145
+ "If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."
146
+ ),
147
+ )
148
+ parser.add_argument(
149
+ "--state_noise_snr",
150
+ type=float,
151
+ default=None,
152
+ help=(
153
+ "The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. "
154
+ "Default is None, which means no noise is added."
155
+ ),
156
+ )
157
+ parser.add_argument(
158
+ "--image_aug",
159
+ action="store_true",
160
+ default=False,
161
+ help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.",
162
+ )
163
+ parser.add_argument(
164
+ "--precomp_lang_embed",
165
+ action="store_true",
166
+ default=False,
167
+ help="Whether or not to use precomputed language embeddings.",
168
+ )
169
+ parser.add_argument(
170
+ "--scale_lr",
171
+ action="store_true",
172
+ default=False,
173
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
174
+ )
175
+ parser.add_argument(
176
+ "--lr_scheduler",
177
+ type=str,
178
+ default="constant",
179
+ help=(
180
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
181
+ ' "constant", "constant_with_warmup"]'
182
+ ),
183
+ )
184
+ parser.add_argument(
185
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
186
+ )
187
+ parser.add_argument(
188
+ "--lr_num_cycles",
189
+ type=int,
190
+ default=1,
191
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
192
+ )
193
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
194
+ parser.add_argument(
195
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
196
+ )
197
+ parser.add_argument(
198
+ "--dataloader_num_workers",
199
+ type=int,
200
+ default=0,
201
+ help=(
202
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
203
+ ),
204
+ )
205
+ parser.add_argument("--alpha", type=float, default=0.9, help="The moving average coefficient for each dataset's loss.")
206
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
207
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
208
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
209
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
210
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
211
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
212
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
213
+ parser.add_argument(
214
+ "--hub_model_id",
215
+ type=str,
216
+ default=None,
217
+ help="The name of the repository to keep in sync with the local `output_dir`.",
218
+ )
219
+ parser.add_argument(
220
+ "--logging_dir",
221
+ type=str,
222
+ default="logs",
223
+ help=(
224
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
225
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
226
+ ),
227
+ )
228
+ parser.add_argument(
229
+ "--allow_tf32",
230
+ action="store_true",
231
+ help=(
232
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
233
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
234
+ ),
235
+ )
236
+ parser.add_argument(
237
+ "--report_to",
238
+ type=str,
239
+ default="tensorboard",
240
+ help=(
241
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
242
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
243
+ ),
244
+ )
245
+ parser.add_argument(
246
+ "--sample_period",
247
+ type=int,
248
+ default=-1,
249
+ help=(
250
+ "Run sampling every X steps. During the sampling phase, the model will sample a trajectory"
251
+ " and report the error between the sampled trajectory and groud-truth trajectory"
252
+ " in the training batch."
253
+ ),
254
+ )
255
+ parser.add_argument(
256
+ "--mixed_precision",
257
+ type=str,
258
+ default=None,
259
+ choices=["no", "fp16", "bf16"],
260
+ help=(
261
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
262
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
263
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
264
+ ),
265
+ )
266
+
267
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
268
+ parser.add_argument(
269
+ "--set_grads_to_none",
270
+ action="store_true",
271
+ help=(
272
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
273
+ " behaviors, so disable this argument if it causes any problems. More info:"
274
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
275
+ ),
276
+ )
277
+
278
+ parser.add_argument('--dataset_type',
279
+ type=str,
280
+ default="pretrain",
281
+ required=False,
282
+ help="Whether to load the pretrain dataset or finetune dataset."
283
+ )
284
+
285
+ if input_args is not None:
286
+ args = parser.parse_args(input_args)
287
+ else:
288
+ args = parser.parse_args()
289
+
290
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
291
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
292
+ args.local_rank = env_local_rank
293
+
294
+ return args
295
+
296
+
297
+ if __name__ == "__main__":
298
+ logger = get_logger(__name__)
299
+ args = parse_args()
300
+ train(args, logger)
pretrain.sh ADDED
@@ -0,0 +1,47 @@
1
+ export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
2
+ export NCCL_IB_DISABLE=0
3
+ export NCCL_SOCKET_IFNAME=bond0
4
+ export NCCL_DEBUG=INFO
5
+ export NCCL_NVLS_ENABLE=0
6
+
7
+ export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
8
+ export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
9
+ export OUTPUT_DIR="./checkpoints/rdt-pretrain-1b"
10
+ export CFLAGS="-I/usr/include"
11
+ export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
12
+ export CUTLASS_PATH="/path/to/cutlass"
13
+
14
+ export WANDB_PROJECT="robotics_diffusion_transformer"
15
+
16
+ if [ ! -d "$OUTPUT_DIR" ]; then
17
+ mkdir "$OUTPUT_DIR"
18
+ echo "Folder '$OUTPUT_DIR' created"
19
+ else
20
+ echo "Folder '$OUTPUT_DIR' already exists"
21
+ fi
22
+
23
+ # To run on a single node/machine
24
+ # accelerate launch main.py \
25
+ # --deepspeed="./configs/zero2.json" \
26
+ # ...
27
+
28
+ deepspeed --hostfile=hostfile.txt main.py \
29
+ --deepspeed="./configs/zero2.json" \
30
+ --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
31
+ --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
32
+ --output_dir=$OUTPUT_DIR \
33
+ --train_batch_size=32 \
34
+ --sample_batch_size=64 \
35
+ --max_train_steps=1000000 \
36
+ --checkpointing_period=1000 \
37
+ --sample_period=500 \
38
+ --checkpoints_total_limit=40 \
39
+ --lr_scheduler="constant" \
40
+ --learning_rate=1e-4 \
41
+ --mixed_precision="bf16" \
42
+ --dataloader_num_workers=8 \
43
+ --dataset_type="pretrain" \
44
+ --report_to=wandb
45
+
46
+ # Use this to resume training from some previous checkpoint
47
+ # --resume_from_checkpoint="checkpoint-1000" \
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ packaging==24.0
2
+ wandb==0.17.0
3
+ deepspeed==0.14.2
4
+ accelerate==0.30.1
5
+ diffusers==0.27.2
6
+ timm==1.0.3
7
+ transformers==4.41.0
8
+ sentencepiece==0.2.0
9
+ h5py==3.11.0
10
+ opencv-python==4.9.0.80
11
+ imgaug==0.4.0
requirements_data.txt ADDED
@@ -0,0 +1,9 @@
1
+ tfds-nightly==4.9.4.dev202402070044
2
+ gsutil==5.27
3
+ tensorflow==2.15.0.post1
4
+ pillow==10.2.0
5
+ pyyaml==6.0.1
6
+ opencv-python==4.9.0.80
7
+ tensorflow-graphics==2021.12.3
8
+ imageio==2.34.0
9
+ imageio-ffmpeg==0.4.9