Spaces: Running on Zero
Add theia
This view is limited to 50 files because it contains too many changes. See raw diff.
- LICENSE +24 -0
- theia/__init__.py +1 -0
- theia/configs/dataset/ego4d.yaml +5 -0
- theia/configs/dataset/epic_kitchen.yaml +5 -0
- theia/configs/dataset/image_video_default.yaml +7 -0
- theia/configs/dataset/image_video_mix.yaml +8 -0
- theia/configs/dataset/imagenet.yaml +5 -0
- theia/configs/dataset/oxe_octo_mix.yaml +12 -0
- theia/configs/dataset/ssv2.yaml +5 -0
- theia/configs/logging/default.yaml +6 -0
- theia/configs/model/backbone/deit.yaml +2 -0
- theia/configs/model/backbone/deit_nocls.yaml +2 -0
- theia/configs/model/backbone/deit_reg.yaml +3 -0
- theia/configs/model/translator/conv.yaml +3 -0
- theia/configs/model/translator/lconv.yaml +3 -0
- theia/configs/model/translator/mlp.yaml +4 -0
- theia/configs/model/translator/transformer.yaml +5 -0
- theia/configs/train_rvfm_imagenet.yaml +9 -0
- theia/configs/training/frame_level.yaml +35 -0
- theia/configs/training/target_models/cdds.yaml +6 -0
- theia/configs/training/target_models/cddsv.yaml +7 -0
- theia/configs/training/target_models/cddv.yaml +6 -0
- theia/configs/training/target_models/cdesv.yaml +6 -0
- theia/configs/training/target_models/cdis.yaml +5 -0
- theia/configs/training/target_models/cdisv.yaml +6 -0
- theia/configs/training/target_models/cdiv.yaml +5 -0
- theia/configs/training/target_models/clip.yaml +3 -0
- theia/configs/training/target_models/ddsv.yaml +6 -0
- theia/configs/training/target_models/depth_anything.yaml +3 -0
- theia/configs/training/target_models/dinov2.yaml +3 -0
- theia/configs/training/target_models/sam.yaml +3 -0
- theia/configs/training/target_models/vit.yaml +3 -0
- theia/dataset/__init__.py +5 -0
- theia/dataset/data_utils.py +591 -0
- theia/dataset/image/__init__.py +3 -0
- theia/dataset/image/image_common.py +5 -0
- theia/dataset/oxe/__init__.py +1 -0
- theia/dataset/oxe/oxe_common.py +430 -0
- theia/dataset/oxe/oxe_mixes.py +139 -0
- theia/dataset/oxe/oxe_transforms.py +15 -0
- theia/dataset/video/__init__.py +3 -0
- theia/dataset/video/video_common.py +11 -0
- theia/decoding/__init__.py +5 -0
- theia/decoding/decode.py +198 -0
- theia/decoding/depth_anything.py +57 -0
- theia/decoding/dinov2.py +69 -0
- theia/decoding/sam.py +191 -0
- theia/example/decode_to_vfms.ipynb +69 -0
- theia/foundation_models/__init__.py +9 -0
- theia/foundation_models/common.py +87 -0
LICENSE
ADDED
@@ -0,0 +1,24 @@
Copyright (c) 2024 Boston Dynamics AI Institute LLC

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the copyright notice included
with the software, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the copyright notice, this
list of conditions and the following disclaimer in the documentation and/or
other materials provided with the distribution.
3. Modified versions of the software must be conspicuously marked as such.
4. The software may only be used for non-commercial research purposes.
For profit enterprises may use the software, subject to this limitation.

THIS SOFTWARE IS PROVIDED BY THE AI INSTITUTE AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, NON-
INFRINGEMENT,TITLE, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE AI INSTITUTE OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, DAMAGES ARISING OUT OF CLAIMS OF
INTELLECTUAL PROPERTY RIGHTS INFRINGEMENT; PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

theia/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

theia/configs/dataset/ego4d.yaml
ADDED
@@ -0,0 +1,5 @@
defaults:
  - image_video_default

dataset_mix:
  - "ego4d_1in150"

theia/configs/dataset/epic_kitchen.yaml
ADDED
@@ -0,0 +1,5 @@
defaults:
  - image_video_default

dataset_mix:
  - "epic_kitchen_1in60"

theia/configs/dataset/image_video_default.yaml
ADDED
@@ -0,0 +1,7 @@
return_metadata: False
shuffle: True
shuffle_buffer_size: 1024
feature_norm: True
dataset_root: "/storage/nfs/datasets/jshang/"
dataset_ratio: 0.1
load_action: False

theia/configs/dataset/image_video_mix.yaml
ADDED
@@ -0,0 +1,8 @@
defaults:
  - image_video_default

dataset_mix:
  - "ego4d_1in150"
  - "ssv2_1in32"
  - "epic_kitchen_1in60"
  - "imagenet"

theia/configs/dataset/imagenet.yaml
ADDED
@@ -0,0 +1,5 @@
defaults:
  - image_video_default

dataset_mix:
  - "imagenet"

theia/configs/dataset/oxe_octo_mix.yaml
ADDED
@@ -0,0 +1,12 @@
_target_: dataset.oxe.oxe_data_utils.OXEDataset
dataset_mix: "oxe_magic_soup"
image_action_set_root: "/storage/nfs/datasets/jshang/oxe_image_action"
feature_set_root: "/storage/nfs/datasets/jshang/oxe_vfm_features"
image_views: null
split: "train"
data_portion: 0.01
load_action: False
bf16: True
safe_tensors: True
trajectory_subsample_len: 32
return_metadata: False

theia/configs/dataset/ssv2.yaml
ADDED
@@ -0,0 +1,5 @@
defaults:
  - image_video_default

dataset_mix:
  - "ssv2_1in32"

theia/configs/logging/default.yaml
ADDED
@@ -0,0 +1,6 @@
model_path: "/storage/nfs/jshang/trained_models"
log_path: "/storage/nfs/jshang/logs"
save_ckpt_interval: 20000
notes: ""
run_identifier_prefix: ""
project: "theia"

theia/configs/model/backbone/deit.yaml
ADDED
@@ -0,0 +1,2 @@
backbone: facebook/deit-small-patch16-224
pretrained: False

theia/configs/model/backbone/deit_nocls.yaml
ADDED
@@ -0,0 +1,2 @@
backbone: nocls-facebook/deit-tiny-patch16-224
pretrained: False

theia/configs/model/backbone/deit_reg.yaml
ADDED
@@ -0,0 +1,3 @@
backbone: reg-facebook/deit-tiny-patch16-224
pretrained: False
num_reg_tokens: 7

theia/configs/model/translator/conv.yaml
ADDED
@@ -0,0 +1,3 @@
type: "conv"
kwargs:
  translator_hidden_size: 1024

theia/configs/model/translator/lconv.yaml
ADDED
@@ -0,0 +1,3 @@
type: "lconv"
kwargs:
  hidden_size_factor: 1.0

theia/configs/model/translator/mlp.yaml
ADDED
@@ -0,0 +1,4 @@
type: "mlp"
kwargs:
  translator_n_layer: 3
  hidden_size: 1024

theia/configs/model/translator/transformer.yaml
ADDED
@@ -0,0 +1,5 @@
type: "transformer"
kwargs:
  translator_n_layers: 2
  translator_n_heads: 8
  translator_hidden_size: 1024

theia/configs/train_rvfm_imagenet.yaml
ADDED
@@ -0,0 +1,9 @@
defaults:
  - dataset: imagenet
  - model/backbone: deit
  - model/translator: lconv
  - training: frame_level
  - logging: default
  - _self_

seed: 0

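A hypothetical sketch (not part of this diff) of how these Hydra-style configs could be composed, assuming hydra-core is installed and theia/configs is the config root; the override value is illustrative only:

    # Compose the top-level training config and inspect a few resolved fields.
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="theia/configs"):
        cfg = compose(config_name="train_rvfm_imagenet", overrides=["dataset=image_video_mix"])
    print(cfg.seed, cfg.dataset.dataset_mix, cfg.training.base_lr)
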
theia/configs/training/frame_level.yaml
ADDED
@@ -0,0 +1,35 @@
defaults:
  - target_models: cdiv

epochs: 50
warm_up_steps_ratio: 0.1

base_lr: 2e-3
batch_size: 16
random_target_models: -1
num_workers: 8
# base training settings to scale lr, rarely changed
base_batch_size: 64
base_world_size: 8

weight_decay: 0.01


optimizer:
  _target_: torch.optim.AdamW
  betas: [0.9, 0.999]

lr_scheduler:
  _target_: theia.lr_schedulers.get_constant_lrs_with_linear_warm_up
  warm_up_lr_start_factor: 1e-2


grad_clip: False
grad_clip_norm_warmup: 10.0
grad_clip_norm: 1.0

freeze_translator: False
freeze_translator_start_steps_ratio: 0.2
translator_lr_factor: 1.0

main_loss: cos_l1

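One common interpretation (an assumption, not shown in this diff) is that base_batch_size and base_world_size feed a linear learning-rate scaling rule, along these lines:

    def scaled_lr(base_lr: float, batch_size: int, world_size: int,
                  base_batch_size: int = 64, base_world_size: int = 8) -> float:
        # Linear scaling: multiply base_lr by the ratio of effective global batch sizes.
        return base_lr * (batch_size * world_size) / (base_batch_size * base_world_size)

    # e.g. batch_size=16 per rank on 8 ranks: 2e-3 * (16 * 8) / (64 * 8) = 5e-4
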
theia/configs/training/target_models/cdds.yaml
ADDED
@@ -0,0 +1,6 @@
target_model_names:
  - "facebook/dinov2-large"
  - "openai/clip-vit-large-patch14"
  - "facebook/sam-vit-huge"
  - "LiheYoung/depth-anything-large-hf"
target_model_weights: null

theia/configs/training/target_models/cddsv.yaml
ADDED
@@ -0,0 +1,7 @@
target_model_names:
  - "google/vit-huge-patch14-224-in21k"
  - "facebook/dinov2-large"
  - "openai/clip-vit-large-patch14"
  - "facebook/sam-vit-huge"
  - "LiheYoung/depth-anything-large-hf"
target_model_weights: null

theia/configs/training/target_models/cddv.yaml
ADDED
@@ -0,0 +1,6 @@
target_model_names:
  - "google/vit-huge-patch14-224-in21k"
  - "facebook/dinov2-large"
  - "openai/clip-vit-large-patch14"
  - "LiheYoung/depth-anything-large-hf"
target_model_weights: null

theia/configs/training/target_models/cdesv.yaml
ADDED
@@ -0,0 +1,6 @@
target_model_names:
  - "google/vit-huge-patch14-224-in21k"
  - "openai/clip-vit-large-patch14"
  - "facebook/sam-vit-huge"
  - "LiheYoung/depth-anything-large-hf"
target_model_weights: null

theia/configs/training/target_models/cdis.yaml
ADDED
@@ -0,0 +1,5 @@
target_model_names:
  - "facebook/dinov2-large"
  - "openai/clip-vit-large-patch14"
  - "facebook/sam-vit-huge"
target_model_weights: null

theia/configs/training/target_models/cdisv.yaml
ADDED
@@ -0,0 +1,6 @@
target_model_names:
  - "google/vit-huge-patch14-224-in21k"
  - "facebook/dinov2-large"
  - "openai/clip-vit-large-patch14"
  - "facebook/sam-vit-huge"
target_model_weights: null

theia/configs/training/target_models/cdiv.yaml
ADDED
@@ -0,0 +1,5 @@
target_model_names:
  - "google/vit-huge-patch14-224-in21k"
  - "facebook/dinov2-large"
  - "openai/clip-vit-large-patch14"
target_model_weights: null

theia/configs/training/target_models/clip.yaml
ADDED
@@ -0,0 +1,3 @@
target_model_names:
  - "openai/clip-vit-large-patch14"
target_model_weights: null

theia/configs/training/target_models/ddsv.yaml
ADDED
@@ -0,0 +1,6 @@
target_model_names:
  - "google/vit-huge-patch14-224-in21k"
  - "facebook/dinov2-large"
  - "facebook/sam-vit-huge"
  - "LiheYoung/depth-anything-large-hf"
target_model_weights: null

theia/configs/training/target_models/depth_anything.yaml
ADDED
@@ -0,0 +1,3 @@
target_model_names:
  - "LiheYoung/depth-anything-large-hf"
target_model_weights: null

theia/configs/training/target_models/dinov2.yaml
ADDED
@@ -0,0 +1,3 @@
target_model_names:
  - "facebook/dinov2-large"
target_model_weights: null

theia/configs/training/target_models/sam.yaml
ADDED
@@ -0,0 +1,3 @@
target_model_names:
  - "facebook/sam-vit-huge"
target_model_weights: null

theia/configs/training/target_models/vit.yaml
ADDED
@@ -0,0 +1,3 @@
target_model_names:
  - "google/vit-huge-patch14-224-in21k"
target_model_weights: null

theia/dataset/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .image.image_common import ALL_IMAGE_DATASETS
from .oxe.oxe_common import ALL_OXE_DATASETS
from .video.video_common import ALL_VIDEO_DATASETS

theia/dataset/data_utils.py
ADDED
@@ -0,0 +1,591 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

"""Defines PyTorch datasets and dataloaders for multiple image, video, and OXE datasets.
Should use with webdataset >= 0.2.90. See https://github.com/webdataset/webdataset/pull/347"""

import glob
import json
import math
import os.path as osp
from collections import OrderedDict
from functools import partial
from io import BytesIO
from typing import Any, Callable, Generator, Iterator, Literal, Optional

import cv2
import numpy as np
import omegaconf
import torch
import webdataset as wds
from datasets.combine import DatasetType
from einops import rearrange
from numpy.typing import NDArray
from safetensors.torch import load as sft_load
from torch import default_generator
from torch.utils.data import DataLoader, Dataset, IterableDataset, default_collate

from theia.foundation_models.common import MODELS
from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS
from theia.dataset.oxe.oxe_mixes import OXE_NAMED_MIXES

PACKED_FEATURES = [model_name for model_name in MODELS if "llava" not in model_name]


def normalize_ds_weights_by_ds_len(weights: list[float], lengths: list[int]) -> tuple[list[float], float | Literal[0]]:
    """Normalize dataset weights by dataset lengths (frames).

    Args:
        weights (list[float]): assigned weights.
        lengths (list[int]): lengths of datasets.

    Returns:
        tuple[list[float], int]: normalized weights, and sum of the expected lengths of datasets.
    """
    expected_lengths = [weight * length for weight, length in zip(weights, lengths, strict=False)]
    sum_expected_lengths = sum(expected_lengths)
    if sum_expected_lengths == 0:
        raise ValueError("Sum of dataset length is 0.")
    normalized_weights = [length * 1.0 / sum_expected_lengths for length in expected_lengths]
    return normalized_weights, sum_expected_lengths


def get_vo_keys(dataset_name: str, image_views: Optional[list | str | dict[str, str | list[str]]] = None) -> list[str]:
    """Get visual observation keys of datasets (to be compatible with OXE).

    Args:
        dataset_name (str): name of the dataset.
        image_views (Optional[dict[str, str | list[str]]], optional): keys of selected views.
            Defaults to None.

    Returns:
        list[str]: keys to the views in the dataset.
    """
    default_visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"][:1]
    visual_observation_keys = []
    if image_views is None:
        visual_observation_keys = default_visual_observation_keys
    elif isinstance(image_views, list):
        visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
    elif isinstance(image_views, str):
        if image_views == "static":
            visual_observation_keys = [
                k
                for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
                if "wrist" not in k and "hand" not in k
            ]
        elif image_views == "wrist":
            visual_observation_keys = [
                k for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] if "wrist" in k or "hand" in k
            ]
    if len(visual_observation_keys) == 0:
        visual_observation_keys = default_visual_observation_keys
    return visual_observation_keys


class RandomMix(IterableDataset):
    """A random interleave of multiple iterable datasets."""

    def __init__(
        self,
        datasets: list[IterableDataset],
        probs: list[float] | NDArray | None = None,
        stopping_strategy: str = "all_exhausted",
        seed: Optional[int | str] = 0,
    ) -> None:
        """Initialization of a random interleave dataset.

        Args:
            datasets (list[IterableDataset]): datasets to be interleaved.
            probs (list[float] | NDArray, optional): probability of each dataset. Defaults to None.
            stopping_strategy (str, optional): when to end the sampling for one epoch. Defaults to `all_exhausted`.
                `all_exhausted`: each sample in the dataset will be sampled at least once.
                `first_exhausted`: when the first dataset runs out, the epoch ends.
                See also https://huggingface.co/docs/datasets/en/stream#interleave for definitions.
            seed (Optional[int | str]): seed. Defaults to 0.
        """
        self.datasets = datasets
        if probs is None:
            self.probs = [1.0] * len(self.datasets)
        elif isinstance(probs, np.ndarray):
            self.probs = probs.tolist()
        else:
            self.probs = probs
        self.stopping_strategy = stopping_strategy
        self.seed = seed

    def __iter__(self) -> Generator:
        """Return an iterator over the sources."""
        sources = [iter(d) for d in self.datasets]
        probs = self.probs[:]
        seed_gen = torch.Generator()
        seed_gen.manual_seed(self.seed)
        cum = (np.array(probs) / np.sum(probs)).cumsum()
        while len(sources) > 0:
            r = torch.rand(1, generator=seed_gen).item()
            i = np.searchsorted(cum, r)
            try:
                yield next(sources[i])
            except StopIteration:
                if self.stopping_strategy == "all_exhausted":
                    del sources[i]
                    del probs[i]
                    cum = (np.array(probs) / np.sum(probs)).cumsum()
                elif self.stopping_strategy == "first_exhausted":
                    break


def decode_sample(
    key: str, data: bytes, image_transform: Optional[Callable] = None, feature_transform: Optional[Callable] = None
) -> Any:
    """Decode a sample from bytes with optional image and feature transforms.

    Args:
        key (str): key of an attribute (a column) of the sample.
        data (bytes): original data bytes.
        image_transform (Optional[Callable], optional): image transform. Defaults to None.
        feature_transform (Optional[Callable], optional): feature transform. Defaults to None.

    Returns:
        Any: decoded data.
    """
    if ".safetensors" in key:
        sft = sft_load(data)
        embedding = rearrange(sft["embedding"], "c h w -> (h w) c")
        if feature_transform is not None:
            embedding = feature_transform(embedding)
        if "cls_token" in sft:
            cls = sft["cls_token"]
            if feature_transform is not None:
                cls = feature_transform(cls)
            return {"embedding": embedding, "cls": cls}
        return {"embedding": embedding}
    elif key == ".image":
        image = np.load(BytesIO(data))
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif len(image.shape) == 3 and image.shape[-1] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        if image_transform is not None:
            return image_transform(image)
        return image
    else:
        return data


def get_oxe_frame_dataset(
    dataset_root: str,
    dataset_mix: Optional[str | dict[str, float] | list] = "oxe_magic_soup",
    feature_models: Optional[list[str]] = None,
    split: str = "train",
    dataset_ratio: float = 1.0,
    image_views: Optional[dict[str, str | list[str]]] = None,
    image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
    seed: Optional[int | str] = 0,
    shuffle: bool = False,
    world_size: int = 1,
) -> tuple[dict[str, DatasetType], float | Literal[0]]:
    """Get OXE datasets at frame level.

    Args:
        dataset_root (str): root dir of the datasets.
        dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets.
            Defaults to "oxe_magic_soup".
        feature_models (Optional[list[str]], optional): models to load their features. Defaults to None.
        split (str, optional): split "train" or "val" or "test". Defaults to "train".
        dataset_ratio (float, optional): how much data to use for the (combined) dataset. Defaults to 1.0.
        image_views (Optional[dict[str, str | list[str]]], optional): image views to select. Defaults to None.
        image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
            Defaults to None.
        seed (Optional[int | str], optional): seed. Defaults to 0.
        shuffle (bool, optional): shuffle or not. Defaults to False.
        world_size (int, optional): world size of DDP training. Defaults to 1.

    Returns:
        tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}.
    """
    # read dataset mix from any acceptable form
    if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
        dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
    elif isinstance(dataset_mix, dict):
        dataset_mix = OrderedDict(**dataset_mix)
    elif isinstance(dataset_mix, list):
        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
    else:
        raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")

    if split == "eval" or split == "val":
        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})

    # note down the dataset weights
    dataset_weights: list[float] = []
    # get frame level length
    dataset_lens: list[int] = []

    all_feature_datasets: dict[str, DatasetType] = {}
    for dataset in dataset_mix:
        visual_observation_keys = get_vo_keys(dataset_name=dataset, image_views=image_views)

        if feature_models is None:
            feature_models = PACKED_FEATURES

        with open(osp.join(dataset_root, dataset, "splits.json"), "r") as splitf:
            dataset_len = json.load(splitf)[split]
        # if the length is 0, skip
        # this may happen for small datasets with very few shards
        if dataset_len == 0:
            continue

        for vo_key in visual_observation_keys:
            for model_name in feature_models:
                if model_name not in PACKED_FEATURES:
                    feature_set_name = model_name
                    path_pattern = osp.join(
                        dataset_root, dataset, vo_key + f"_{model_name.replace('/', '_')}", f"*-{split}*.tar"
                    )
                    rename_kw = {model_name: model_name.replace("/", "_") + ".safetensors"}  # replace v by k
                elif "packed" in all_feature_datasets:
                    continue
                else:
                    feature_set_name = "packed"
                    path_pattern = osp.join(dataset_root, dataset, vo_key, f"*-{split}*.tar")
                    rename_kw = {
                        name: name.replace("/", "_") + ".safetensors" for name in PACKED_FEATURES
                    }  # replace v by k
                    rename_kw["image"] = "image"

                if feature_set_name not in all_feature_datasets:
                    all_feature_datasets[feature_set_name] = []

                shard_paths = sorted(glob.glob(path_pattern))
                num_shards = len(shard_paths)
                if num_shards < world_size * 8:
                    shard_paths *= math.ceil(world_size * 8 / num_shards)
                ds = (
                    wds.WebDataset(
                        shard_paths,
                        nodesplitter=wds.split_by_node,
                        workersplitter=wds.split_by_worker,
                        detshuffle=True,
                        shardshuffle=shuffle,
                        seed=seed,
                    )
                    .decode(partial(decode_sample, image_transform=image_transform))
                    .rename(keep=False, **rename_kw)
                )
                all_feature_datasets[feature_set_name].append(ds)

        dataset_weights.append(dataset_mix[dataset])
        dataset_lens.append(math.ceil(dataset_len * dataset_ratio))

    normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)

    combined_feature_datasets: dict[str, Dataset] = {}
    for feature_set_name, fds in all_feature_datasets.items():
        ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted")
        combined_feature_datasets[feature_set_name] = ds

    return combined_feature_datasets, sum_expected_lengths


def get_oxe_frame_dataloader(
    datasets: dict[str, DatasetType], batch_size: Optional[int] = None, shuffle_buffer_size: int = 1_000, **kwargs: Any
) -> dict[str, DataLoader]:
    """Get dataloaders of OXE datasets. Corresponding to `get_oxe_frame_dataset()`.

    Args:
        datasets (dict[str, DatasetType]): OXE datasets from `get_oxe_frame_dataset()`.
        batch_size (Optional[int], optional): batch size. Defaults to None.
        shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000.

    Returns:
        dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
    """
    loaders = {
        k: (
            wds.WebLoader(datasets[k], batch_size=None, **kwargs)
            .shuffle(shuffle_buffer_size)  # shuffle after mix
            .batched(batch_size, collation_fn=default_collate)
        )
        for k in datasets
    }
    return loaders


def get_oxe_frame_iterator(
    data_loaders: dict[str, DataLoader],
) -> Iterator[dict[str, Any]]:
    """Get iterator from dataloaders. Corresponding to `get_oxe_frame_dataloader()`.

    Args:
        data_loaders (dict[str, DataLoader]): dataloaders from `get_oxe_frame_dataloader()`.

    Yields:
        Iterator[dict[str, Any]]: data sample.
    """
    packed_loader = data_loaders.get("packed", None)
    # place packed_loader first
    if packed_loader is not None:
        loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
    else:
        loaders = list(data_loaders.values())

    # merge dicts
    for data in zip(*loaders, strict=False):
        # yield data
        for i in range(1, len(loaders)):
            for k in data[i]:
                if k not in data[0]:
                    data[0][k] = data[i][k]
        yield data[0]


def normalize_feature(
    x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Normalize the feature given mean and std.

    Args:
        x (torch.Tensor): input features.
        mean (Optional[torch.Tensor], optional): mean values. Defaults to None.
        std (Optional[torch.Tensor], optional): std values. Defaults to None.

    Returns:
        torch.Tensor: feature after normalization.
    """
    return x if mean is None or std is None else (x - mean) / std


def load_feature_stats(
    dataset_root: str, feature_models: list[str]
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Load feature statistics (mean and variance).

    Args:
        dataset_root (str): root dir of the dataset (or where to hold the statistics).
        feature_models (list[str]): names of the models/features.

    Returns:
        tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances. Keys are model names.
    """
    feature_means: dict[str, torch.Tensor] = {}
    feature_vars: dict[str, torch.Tensor] = {}
    for model in feature_models:
        model_name = model.replace("/", "_")
        feature_means[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_mean_{model_name}.npy"))).to(
            torch.bfloat16
        )
        feature_vars[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_var_{model_name}.npy"))).to(
            torch.bfloat16
        )
    return feature_means, feature_vars


def pad_shard_paths(shard_paths: list[str], num_shards: int, num_parts: int) -> list[str]:
    """Pad shard paths to be divisible by the number of partitions (ranks * nodes).

    Args:
        shard_paths (list[str]): paths of dataset shards.
        num_shards (int): number of shards.
        num_parts (int): number of partitions.

    Returns:
        list[str]: shard paths padded.
    """
    final_shard_paths = shard_paths
    if num_shards % num_parts != 0:
        if num_shards < num_parts - num_shards:
            for _ in range(math.floor((num_parts - num_shards) / num_shards)):
                final_shard_paths += shard_paths[:]
            final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
        else:
            final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
    return final_shard_paths


def get_image_video_dataset(
    dataset_root: str,
    feature_models: list[str],
    dataset_mix: Optional[str | dict[str, float] | list] = None,
    split: str = "train",
    dataset_ratio: float = 1.0,
    image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
    feature_norm: bool = False,
    seed: Optional[int | str] = 0,
    shuffle: bool = False,
    world_size: int = 1,
    **kwargs: Any,
) -> tuple[dict[str, DatasetType], float | Literal[0]]:
    """Get image and video datasets at frame level.

    Args:
        dataset_root (str): root dir of the datasets.
        feature_models (list[str]): models to load their features.
        dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets.
        split (str, optional): split "train" or "val" or "test". Defaults to "train".
        dataset_ratio (float, optional): how much data to use for the (combined) dataset. Defaults to 1.0.
        image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
            Defaults to None.
        feature_norm (bool, optional): whether to normalize the feature. Defaults to False.
        seed (Optional[int | str], optional): seed. Defaults to 0.
        shuffle (bool, optional): shuffle or not. Defaults to False.
        world_size (int, optional): world size of DDP training. Defaults to 1.
        kwargs (Any): arguments to pass through.

    Returns:
        tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}.
    """
    # read dataset mix from any acceptable form
    if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
        dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
    elif isinstance(dataset_mix, dict):
        dataset_mix = OrderedDict(**dataset_mix)
    elif isinstance(dataset_mix, list) or isinstance(dataset_mix, omegaconf.listconfig.ListConfig):
        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
    else:
        raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")

    if split == "eval" or split == "val":
        dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})

    # note down the dataset weights
    dataset_weights: list[float] = []
    # get frame level length
    dataset_lens: list[int] = []

    all_feature_datasets: dict[str, DatasetType] = {}

    if feature_norm:
        feature_means, feature_vars = load_feature_stats(dataset_root, feature_models)

    for d in dataset_mix:

        with open(osp.join(dataset_root, d, "splits.json"), "r") as splitf:
            dataset_len = json.load(splitf)[split]

        # if the length is 0, skip
        # this may happen for small datasets with very few shards
        if dataset_len == 0:
            continue

        path_pattern = osp.join(dataset_root, d, "images", f"*-{split}.tar")
        if "image" not in all_feature_datasets:
            all_feature_datasets["image"] = []
        shard_paths = sorted(glob.glob(path_pattern))
        num_shards = len(shard_paths)
        num_parts = world_size
        final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
        ds = wds.WebDataset(
            final_shard_paths,
            nodesplitter=wds.split_by_node,
            workersplitter=wds.split_by_worker,
            detshuffle=True,
            shardshuffle=shuffle,
            seed=seed,
        ).decode(partial(decode_sample, image_transform=image_transform))
        all_feature_datasets["image"].append(ds)

        for model_name in feature_models:
            path_pattern = osp.join(dataset_root, d, f"{model_name.replace('/', '_')}", f"*-{split}.tar")
            rename_kw = {model_name: model_name.replace("/", "_").lower() + ".safetensors"}  # replace v by k

            if model_name not in all_feature_datasets:
                all_feature_datasets[model_name] = []

            shard_paths = sorted(glob.glob(path_pattern))
            num_shards = len(shard_paths)
            num_parts = world_size
            final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
            if feature_norm:
                feature_transform = partial(
                    normalize_feature, mean=feature_means[model_name], std=feature_vars[model_name]
                )
            else:
                feature_transform = None
            ds = (
                wds.WebDataset(
                    final_shard_paths,
                    nodesplitter=wds.split_by_node,
                    workersplitter=wds.split_by_worker,
                    detshuffle=True,
                    shardshuffle=shuffle,
                    seed=seed,
                )
                .decode(partial(decode_sample, image_transform=image_transform, feature_transform=feature_transform))
                .rename(keep=False, **rename_kw)
            )
            all_feature_datasets[model_name].append(ds)

        dataset_weights.append(dataset_mix[d])
        dataset_lens.append(math.ceil(dataset_len * dataset_ratio))

    normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)

    combined_feature_datasets: dict[str, Dataset] = {}
    for feature_set_name, fds in all_feature_datasets.items():
        ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted", seed=seed)
        combined_feature_datasets[feature_set_name] = ds

    return combined_feature_datasets, sum_expected_lengths


def get_frame_dataloader(
    datasets: dict[str, DatasetType],
    batch_size: Optional[int] = None,
    shuffle: bool = False,
    shuffle_buffer_size: int = 1_000,
    seed: Optional[int] = 0,
    **kwargs: Any,
) -> dict[str, DataLoader]:
    """Get dataloaders of image and video datasets. Corresponding to `get_image_video_dataset()`.

    Args:
        datasets (dict[str, DatasetType]): image and video datasets from `get_image_video_dataset()`.
        batch_size (Optional[int], optional): batch size. Defaults to None.
        shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000.

    Returns:
        dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
    """
    loaders = {}
    for k in datasets:
        loader = wds.WebLoader(datasets[k], batch_size=None, generator=default_generator, **kwargs)
        if shuffle:
            loader = loader.shuffle(shuffle_buffer_size, seed=seed)  # shuffle after mix
        loader = loader.batched(batch_size, collation_fn=default_collate)
        loaders[k] = loader
    return loaders


def get_frame_iterator(
    data_loaders: dict[str, DataLoader],
) -> Iterator[dict[str, Any]]:
    """Get iterator from image and video dataset dataloaders. Corresponding to `get_frame_dataloader()`.

    Args:
        data_loaders (dict[str, DataLoader]): dataloaders from `get_frame_dataloader()`.

    Yields:
        Iterator[dict[str, Any]]: data sample.
    """
    packed_loader = data_loaders.get("packed", None)
    # place packed_loader first
    if packed_loader is not None:
        loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
    else:
        loaders = list(data_loaders.values())

    # merge dicts
    # this is to accommodate the old organization of datasets (each shard contains one or more columns,
    # and images are duplicated columns).
    # In the new (current) dataset organization (columns are completely separated),
    # column keys are all different except some "built-in" keys added by webdataset,
    # but they are not related to any data or training.
    # During the transition from old to new, where the two organizations exist at the same time,
    # this is to ignore the extra "image" field in datasets loaded.
    for data in zip(*loaders, strict=False):
        # yield data
        for i in range(1, len(loaders)):
            for k in data[i]:
                if k not in data[0]:
                    data[0][k] = data[i][k]
        yield data[0]

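A hypothetical end-to-end usage sketch of the loaders defined above (the root path, mix, model name, and batch settings are placeholders, not values from this diff):

    datasets, expected_frames = get_image_video_dataset(
        dataset_root="/path/to/datasets",              # placeholder root
        feature_models=["facebook/dinov2-large"],      # one of the target models listed in the configs
        dataset_mix=["imagenet"],
        split="train",
        shuffle=True,
        world_size=1,
    )
    loaders = get_frame_dataloader(datasets, batch_size=16, shuffle=True, num_workers=4)
    for sample in get_frame_iterator(loaders):
        images = sample["image"]                       # decoded (and optionally transformed) frames
        dino = sample["facebook/dinov2-large"]         # {"embedding": ..., optionally "cls": ...}
        break
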
theia/dataset/image/__init__.py
ADDED
@@ -0,0 +1,3 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .image_common import ALL_IMAGE_DATASETS

theia/dataset/image/image_common.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from collections import OrderedDict

ALL_IMAGE_DATASETS = OrderedDict({"imagenet": {"steps": 1_281_167}})

theia/dataset/oxe/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

theia/dataset/oxe/oxe_common.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
|
2 |
+
|
3 |
+
from collections import OrderedDict
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
"""
|
7 |
+
This ALL_OXE_DATASETS below records metadata of all subsets of OXE dataset.
|
8 |
+
The datasets are in alphabetical order.
|
9 |
+
|
10 |
+
versions (list[str]): available and usable versions, sorted from older to newer.
|
11 |
+
Usually use the last one.
|
12 |
+
episodes (int): total episodes in the dataset.
|
13 |
+
steps (int): total steps in the dataset.
|
14 |
+
visual_observation_keys (list[str]): keys to specify image observations.
|
15 |
+
"""
|
16 |
+
ALL_OXE_DATASETS: OrderedDict = OrderedDict(
|
17 |
+
{
|
18 |
+
"agent_aware_affordances": {
|
19 |
+
"versions": ["1.0.0"],
|
20 |
+
"episodes": 118,
|
21 |
+
"steps": 151628,
|
22 |
+
"visual_observation_keys": ["image"],
|
23 |
+
},
|
24 |
+
"asu_table_top_converted_externally_to_rlds": {
|
25 |
+
"versions": ["0.1.0"],
|
26 |
+
"episodes": 110,
|
27 |
+
"steps": 26113,
|
28 |
+
"visual_observation_keys": ["image"],
|
29 |
+
},
|
30 |
+
"austin_buds_dataset_converted_externally_to_rlds": {
|
31 |
+
"versions": ["0.1.0"],
|
32 |
+
"episodes": 50,
|
33 |
+
"steps": 34112,
|
34 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
35 |
+
},
|
36 |
+
"austin_sailor_dataset_converted_externally_to_rlds": {
|
37 |
+
"versions": ["0.1.0"],
|
38 |
+
"episodes": 240,
|
39 |
+
"steps": 353094,
|
40 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
41 |
+
},
|
42 |
+
"austin_sirius_dataset_converted_externally_to_rlds": {
|
43 |
+
"versions": ["0.1.0"],
|
44 |
+
"episodes": 559,
|
45 |
+
"steps": 279939,
|
46 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
47 |
+
},
|
48 |
+
"bc_z": {
|
49 |
+
"versions": [
|
50 |
+
"0.1.0", # "1.0.0", "old1.0.1", and "1.0.1" are not usable
|
51 |
+
],
|
52 |
+
"episodes": 39350,
|
53 |
+
"steps": 5471693,
|
54 |
+
"visual_observation_keys": ["image"],
|
55 |
+
},
|
56 |
+
"berkeley_autolab_ur5": {
|
57 |
+
"versions": ["0.1.0"],
|
58 |
+
"episodes": 896,
|
59 |
+
"steps": 87783,
|
60 |
+
"visual_observation_keys": ["image", "hand_image"],
|
61 |
+
},
|
62 |
+
"berkeley_cable_routing": {
|
63 |
+
"versions": ["0.1.0"],
|
64 |
+
"episodes": 1482,
|
65 |
+
"steps": 38240,
|
66 |
+
"visual_observation_keys": ["image", "top_image", "wrist225_image", "wrist45_image"],
|
67 |
+
},
|
68 |
+
"berkeley_fanuc_manipulation": {
|
69 |
+
"versions": ["0.1.0"],
|
70 |
+
"episodes": 415,
|
71 |
+
"steps": 62613,
|
72 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
73 |
+
},
|
74 |
+
"berkeley_gnm_cory_hall": {
|
75 |
+
"versions": ["0.1.0"],
|
76 |
+
"episodes": 7331,
|
77 |
+
"steps": 156012,
|
78 |
+
"visual_observation_keys": ["image"],
|
79 |
+
},
|
80 |
+
"berkeley_gnm_recon": {
|
81 |
+
"versions": ["0.1.0"],
|
82 |
+
"episodes": 11834,
|
83 |
+
"steps": 610907,
|
84 |
+
"visual_observation_keys": ["image"],
|
85 |
+
},
|
86 |
+
"berkeley_gnm_sac_son": {
|
87 |
+
"versions": ["0.1.0"],
|
88 |
+
"episodes": 2955,
|
89 |
+
"steps": 241059,
|
90 |
+
"visual_observation_keys": ["image"],
|
91 |
+
},
|
92 |
+
"berkeley_mvp_converted_externally_to_rlds": {
|
93 |
+
"versions": ["0.1.0"],
|
94 |
+
"episodes": 480,
|
95 |
+
"steps": 45308,
|
96 |
+
"visual_observation_keys": ["hand_image"],
|
97 |
+
},
|
98 |
+
"berkeley_rpt_converted_externally_to_rlds": {
|
99 |
+
"versions": ["0.1.0"],
|
100 |
+
"episodes": 908,
|
101 |
+
"steps": 392578,
|
102 |
+
"visual_observation_keys": ["hand_image"],
|
103 |
+
},
|
104 |
+
"bridge": {"versions": ["0.1.0"], "episodes": 25460, "steps": 864292, "visual_observation_keys": ["image"]},
|
105 |
+
"cmu_franka_exploration_dataset_converted_externally_to_rlds": {
|
106 |
+
"versions": ["0.1.0"],
|
107 |
+
"episodes": 199,
|
108 |
+
"steps": 1990,
|
109 |
+
"visual_observation_keys": ["image"],
|
110 |
+
},
|
111 |
+
"cmu_play_fusion": {
|
112 |
+
"versions": ["0.1.0"],
|
113 |
+
"episodes": 576,
|
114 |
+
"steps": 235922,
|
115 |
+
"visual_observation_keys": ["image"],
|
116 |
+
},
|
117 |
+
"cmu_playing_with_food": { # this dataset seems to be corrupted
|
118 |
+
"versions": ["1.0.0"],
|
119 |
+
"episodes": 4200,
|
120 |
+
"steps": 83240,
|
121 |
+
"visual_observation_keys": ["image"],
|
122 |
+
},
|
123 |
+
"cmu_stretch": {"versions": ["0.1.0"], "episodes": 135, "steps": 25016, "visual_observation_keys": ["image"]},
|
124 |
+
"columbia_cairlab_pusht_real": {
|
125 |
+
"versions": ["0.1.0"],
|
126 |
+
"episodes": 122,
|
127 |
+
"steps": 24924,
|
128 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
129 |
+
},
|
130 |
+
"dlr_edan_shared_control_converted_externally_to_rlds": {
|
131 |
+
"versions": ["0.1.0"],
|
132 |
+
"episodes": 104,
|
133 |
+
"steps": 8928,
|
134 |
+
"visual_observation_keys": ["image"],
|
135 |
+
},
|
136 |
+
"dlr_sara_grid_clamp_converted_externally_to_rlds": {
|
137 |
+
"versions": ["0.1.0"],
|
138 |
+
"episodes": 107,
|
139 |
+
"steps": 7622,
|
140 |
+
"visual_observation_keys": ["image"],
|
141 |
+
},
|
142 |
+
"dlr_sara_pour_converted_externally_to_rlds": {
|
143 |
+
"versions": ["0.1.0"],
|
144 |
+
"episodes": 100,
|
145 |
+
"steps": 12971,
|
146 |
+
"visual_observation_keys": ["image"],
|
147 |
+
},
|
148 |
+
"eth_agent_affordances": {
|
149 |
+
"versions": ["0.1.0"],
|
150 |
+
"episodes": 118,
|
151 |
+
"steps": 151628,
|
152 |
+
"visual_observation_keys": ["image"],
|
153 |
+
},
|
154 |
+
"fanuc_manipulation_v2": {
|
155 |
+
"versions": ["1.0.0"],
|
156 |
+
"episodes": 415,
|
157 |
+
"steps": 62613,
|
158 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
159 |
+
},
|
160 |
+
"fractal20220817_data": {
|
161 |
+
"versions": ["0.1.0"],
|
162 |
+
"episodes": 87212,
|
163 |
+
"steps": 3786400,
|
164 |
+
"visual_observation_keys": ["image"],
|
165 |
+
},
|
166 |
+
"furniture_bench_dataset_converted_externally_to_rlds": {
|
167 |
+
"versions": ["0.1.0"],
|
168 |
+
"episodes": 5100,
|
169 |
+
"steps": 3948057,
|
170 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
171 |
+
},
|
172 |
+
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
|
173 |
+
"versions": ["0.1.0"],
|
174 |
+
"episodes": 631,
|
175 |
+
"steps": 146241,
|
176 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
177 |
+
},
|
178 |
+
"imperial_wrist_dataset": {
|
179 |
+
"versions": ["1.0.0"],
|
180 |
+
"episodes": 170,
|
181 |
+
"steps": 7148,
|
182 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
183 |
+
},
|
184 |
+
"imperialcollege_sawyer_wrist_cam": {
|
185 |
+
"versions": ["0.1.0"],
|
186 |
+
"episodes": 170,
|
187 |
+
"steps": 7148,
|
188 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
189 |
+
},
|
190 |
+
"jaco_play": {
|
191 |
+
"versions": ["0.1.0"],
|
192 |
+
"episodes": 976,
|
193 |
+
"steps": 70127,
|
194 |
+
"visual_observation_keys": ["image", "image_wrist"],
|
195 |
+
},
|
196 |
+
"kaist_nonprehensile_converted_externally_to_rlds": {
|
197 |
+
"versions": ["0.1.0"],
|
198 |
+
"episodes": 201,
|
199 |
+
"steps": 32429,
|
200 |
+
"visual_observation_keys": ["image"],
|
201 |
+
},
|
202 |
+
"kuka": {"versions": ["0.1.0"], "episodes": 580392, "steps": 8583978, "visual_observation_keys": ["image"]},
|
203 |
+
"language_table": {
|
204 |
+
"versions": ["0.0.1", "0.1.0"],
|
205 |
+
"episodes": 442226,
|
206 |
+
"steps": 7045476,
|
207 |
+
"visual_observation_keys": ["rgb"],
|
208 |
+
},
|
209 |
+
"language_table_blocktoabsolute_oracle_sim": {
|
210 |
+
"versions": ["0.0.1"],
|
211 |
+
"episodes": 200000,
|
212 |
+
"steps": 15866385,
|
213 |
+
"visual_observation_keys": ["rgb"],
|
214 |
+
},
|
215 |
+
"language_table_blocktoblock_4block_sim": {
|
216 |
+
"versions": ["0.0.1"],
|
217 |
+
"episodes": 8298,
|
218 |
+
"steps": 326768,
|
219 |
+
"visual_observation_keys": ["rgb"],
|
220 |
+
},
|
221 |
+
"language_table_blocktoblock_oracle_sim": {
|
222 |
+
"versions": ["0.0.1"],
|
223 |
+
"episodes": 200000,
|
224 |
+
"steps": 12970620,
|
225 |
+
"visual_observation_keys": ["rgb"],
|
226 |
+
},
|
227 |
+
"language_table_blocktoblock_sim": {
|
228 |
+
"versions": ["0.0.1"],
|
229 |
+
"episodes": 8000,
|
230 |
+
"steps": 351688,
|
231 |
+
"visual_observation_keys": ["rgb"],
|
232 |
+
},
|
233 |
+
"language_table_blocktoblockrelative_oracle_sim": {
|
234 |
+
"versions": ["0.0.1"],
|
235 |
+
"episodes": 200000,
|
236 |
+
"steps": 13016749,
|
237 |
+
"visual_observation_keys": ["rgb"],
|
238 |
+
},
|
239 |
+
"language_table_blocktorelative_oracle_sim": {
|
240 |
+
"versions": ["0.0.1"],
|
241 |
+
"episodes": 200000,
|
242 |
+
"steps": 8655815,
|
243 |
+
"visual_observation_keys": ["rgb"],
|
244 |
+
},
|
245 |
+
"language_table_separate_oracle_sim": {
|
246 |
+
"versions": ["0.0.1"],
|
247 |
+
"episodes": 200000,
|
248 |
+
"steps": 3196661,
|
249 |
+
"visual_observation_keys": ["rgb"],
|
250 |
+
},
|
251 |
+
"language_table_sim": {
|
252 |
+
"versions": ["0.0.1"],
|
253 |
+
"episodes": 181020,
|
254 |
+
"steps": 4665423,
|
255 |
+
"visual_observation_keys": ["rgb"],
|
256 |
+
},
|
257 |
+
"maniskill_dataset_converted_externally_to_rlds": {
|
258 |
+
"versions": ["0.1.0"],
|
259 |
+
"episodes": 30213,
|
260 |
+
"steps": 4537402,
|
261 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
262 |
+
},
|
263 |
+
"mutex_dataset": {
|
264 |
+
"versions": ["1.0.0"],
|
265 |
+
"episodes": 1500,
|
266 |
+
"steps": 361883,
|
267 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
268 |
+
},
|
269 |
+
"nyu_door_opening_surprising_effectiveness": {
|
270 |
+
"versions": ["0.1.0"],
|
271 |
+
"episodes": 435,
|
272 |
+
"steps": 18196,
|
273 |
+
"visual_observation_keys": ["image"],
|
274 |
+
},
|
275 |
+
"nyu_franka_play_dataset_converted_externally_to_rlds": {
|
276 |
+
"versions": ["0.1.0"],
|
277 |
+
"episodes": 365,
|
278 |
+
"steps": 34448,
|
279 |
+
"visual_observation_keys": ["image", "image_additional_view"],
|
280 |
+
},
|
281 |
+
"nyu_rot_dataset_converted_externally_to_rlds": {
|
282 |
+
"versions": ["0.1.0"],
|
283 |
+
"episodes": 14,
|
284 |
+
"steps": 440,
|
285 |
+
"visual_observation_keys": ["image"],
|
286 |
+
},
|
287 |
+
"qut_dexterous_manpulation": {
|
288 |
+
"versions": ["0.1.0"],
|
289 |
+
"episodes": 200,
|
290 |
+
"steps": 176278,
|
291 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
292 |
+
},
|
293 |
+
"robo_net": {
|
294 |
+
"versions": ["0.1.0", "1.0.0"],
|
295 |
+
"episodes": 82775,
|
296 |
+
"steps": 2483250,
|
297 |
+
"visual_observation_keys": ["image", "image1", "image2"],
|
298 |
+
},
|
299 |
+
"robot_vqa": {
|
300 |
+
"versions": ["0.1.0"],
|
301 |
+
"episodes": 3331523,
|
302 |
+
"steps": 3331523,
|
303 |
+
"visual_observation_keys": ["images"],
|
304 |
+
},
|
305 |
+
"roboturk": {
|
306 |
+
"versions": ["0.1.0"],
|
307 |
+
"episodes": 1796,
|
308 |
+
"steps": 168423,
|
309 |
+
"visual_observation_keys": ["front_rgb"],
|
310 |
+
},
|
311 |
+
"stanford_hydra_dataset_converted_externally_to_rlds": {
|
312 |
+
"versions": ["0.1.0"],
|
313 |
+
"episodes": 570,
|
314 |
+
"steps": 358234,
|
315 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
316 |
+
},
|
317 |
+
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
|
318 |
+
"versions": ["0.1.0"],
|
319 |
+
"episodes": 3000,
|
320 |
+
"steps": 149985,
|
321 |
+
"visual_observation_keys": ["image"],
|
322 |
+
},
|
323 |
+
"stanford_mask_vit_converted_externally_to_rlds": {
|
324 |
+
"versions": ["0.1.0"],
|
325 |
+
"episodes": 9109,
|
326 |
+
"steps": 282379,
|
327 |
+
"visual_observation_keys": ["image"],
|
328 |
+
},
|
329 |
+
"stanford_robocook_converted_externally_to_rlds": {
|
330 |
+
"versions": ["0.1.0"],
|
331 |
+
"episodes": 2460,
|
332 |
+
"steps": 112980,
|
333 |
+
"visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
|
334 |
+
},
|
335 |
+
"taco_play": {
|
336 |
+
"versions": ["0.1.0"],
|
337 |
+
"episodes": 3242,
|
338 |
+
"steps": 213972,
|
339 |
+
"visual_observation_keys": ["rgb_static", "rgb_gripper"],
|
340 |
+
},
|
341 |
+
"tokyo_u_lsmo_converted_externally_to_rlds": {
|
342 |
+
"versions": ["0.1.0"],
|
343 |
+
"episodes": 50,
|
344 |
+
"steps": 11925,
|
345 |
+
"visual_observation_keys": ["image"],
|
346 |
+
},
|
347 |
+
"toto": {"versions": ["0.1.0"], "episodes": 902, "steps": 294139, "visual_observation_keys": ["image"]},
|
348 |
+
"ucsd_kitchen_dataset_converted_externally_to_rlds": {
|
349 |
+
"versions": ["0.1.0"],
|
350 |
+
"episodes": 150,
|
351 |
+
"steps": 3970,
|
352 |
+
"visual_observation_keys": ["image"],
|
353 |
+
},
|
354 |
+
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
|
355 |
+
"versions": ["0.1.0"],
|
356 |
+
"episodes": 1355,
|
357 |
+
"steps": 67750,
|
358 |
+
"visual_observation_keys": ["image"],
|
359 |
+
},
|
360 |
+
"uiuc_d3field": { # this dataset seems to be corrupted
|
361 |
+
"versions": ["0.1.0", "1.1.2"],
|
362 |
+
"episodes": 196,
|
363 |
+
"steps": 13384,
|
364 |
+
"visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
|
365 |
+
},
|
366 |
+
"usc_cloth_sim_converted_externally_to_rlds": {
|
367 |
+
"versions": ["0.1.0"],
|
368 |
+
"episodes": 800,
|
369 |
+
"steps": 80000,
|
370 |
+
"visual_observation_keys": ["image"],
|
371 |
+
},
|
372 |
+
"utaustin_mutex": {
|
373 |
+
"versions": ["0.1.0"],
|
374 |
+
"episodes": 1500,
|
375 |
+
"steps": 361883,
|
376 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
377 |
+
},
|
378 |
+
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
|
379 |
+
"versions": ["0.1.0"],
|
380 |
+
"episodes": 64,
|
381 |
+
"steps": 9140,
|
382 |
+
"visual_observation_keys": ["image"],
|
383 |
+
},
|
384 |
+
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
|
385 |
+
"versions": ["0.1.0"],
|
386 |
+
"episodes": 192,
|
387 |
+
"steps": 26346,
|
388 |
+
"visual_observation_keys": ["image"],
|
389 |
+
},
|
390 |
+
"utokyo_saytap_converted_externally_to_rlds": {
|
391 |
+
"versions": ["0.1.0"],
|
392 |
+
"episodes": 20,
|
393 |
+
"steps": 22937,
|
394 |
+
"visual_observation_keys": ["image", "wrist_image"],
|
395 |
+
},
|
396 |
+
"utokyo_xarm_bimanual_converted_externally_to_rlds": {
|
397 |
+
"versions": ["0.1.0"],
|
398 |
+
"episodes": 64,
|
399 |
+
"steps": 1388,
|
400 |
+
"visual_observation_keys": ["image"],
|
401 |
+
},
|
402 |
+
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
|
403 |
+
"versions": ["0.1.0"],
|
404 |
+
"episodes": 92,
|
405 |
+
"steps": 6789,
|
406 |
+
"visual_observation_keys": ["image", "hand_image", "image2"],
|
407 |
+
},
|
408 |
+
"viola": {
|
409 |
+
"versions": ["0.1.0"],
|
410 |
+
"episodes": 135,
|
411 |
+
"steps": 68913,
|
412 |
+
"visual_observation_keys": ["agentview_rgb", "eye_in_hand_rgb"],
|
413 |
+
},
|
414 |
+
}
|
415 |
+
)
|
416 |
+
|
417 |
+
|
418 |
+
def oxe_dsname2path(dataset_name: str, version: Optional[str] = None) -> str:
|
419 |
+
"""From dataset name to remote google clound path to the dataset.
|
420 |
+
|
421 |
+
Args:
|
422 |
+
dataset_name (str): dataset name.
|
423 |
+
version (Optional[str]): version string.
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
str: google clound path
|
427 |
+
"""
|
428 |
+
if version is None:
|
429 |
+
version = ALL_OXE_DATASETS[dataset_name]["versions"][-1]
|
430 |
+
return f"gs://gresearch/robotics/{dataset_name}/{version}"
|
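Usage note (not part of the diff): a minimal sketch of querying the registry and `oxe_dsname2path` above, assuming the `theia` package is importable and the public `gs://gresearch/robotics` bucket layout encoded in the function.

# Illustrative only: look up a few registry entries and resolve their remote paths.
from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS, oxe_dsname2path

for name in ["taco_play", "robo_net", "viola"]:
    info = ALL_OXE_DATASETS[name]
    # With no version given, the last listed version is used.
    print(f"{name}: {info['episodes']} episodes, {info['steps']} steps -> {oxe_dsname2path(name)}")

# Pinning an explicit version:
print(oxe_dsname2path("robo_net", version="0.1.0"))  # gs://gresearch/robotics/robo_net/0.1.0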
theia/dataset/oxe/oxe_mixes.py
ADDED
@@ -0,0 +1,139 @@
# File modified. Modifications Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

"""MIT License Copyright (c) 2023 Robotic AI & Learning Lab Berkeley

From Octo https://github.com/octo-models/octo/blob/main/octo/data/oxe/oxe_dataset_mixes.py
"""

BRIDGE_MIX = [
    ("bridge_dataset", 1.0),
]

RT_X_MIX = [
    ("fractal20220817_data", 0.54087122203),
    ("kuka", 0.8341046294),
    ("bridge_dataset", 1.0),
    ("taco_play", 2.0),
    ("jaco_play", 2.0),
    ("berkeley_cable_routing", 3.0),
    ("roboturk", 1.0),
    ("nyu_door_opening_surprising_effectiveness", 5.0),
    ("viola", 2.0),
    ("berkeley_autolab_ur5", 1.0),
    ("toto", 1.0),
]


OXE_FRANKA_MIX = [
    ("taco_play", 1.0),
    ("berkeley_cable_routing", 1.0),
    ("viola", 1.0),
    ("toto", 1.0),
    ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
    ("austin_buds_dataset_converted_externally_to_rlds", 3.0),
    ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
    ("maniskill_dataset_converted_externally_to_rlds", 0.1),
    ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
    ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0),
    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
    ("berkeley_rpt_converted_externally_to_rlds", 1.0),
    ("kaist_nonprehensile_converted_externally_to_rlds", 3.0),
    ("stanford_robocook_converted_externally_to_rlds", 1.0),
    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
    ("utaustin_mutex", 1.0),
    # ("cmu_playing_with_food", 1.0),
    ("cmu_play_fusion", 1.0),
]

OXE_MAGIC_SOUP = [
    ("fractal20220817_data", 0.54087122203),
    ("kuka", 0.8341046294),
    ("bridge", 1.0),
    ("taco_play", 2.0),
    ("jaco_play", 1.0),
    ("berkeley_cable_routing", 1.0),
    ("roboturk", 2.0),
    ("nyu_door_opening_surprising_effectiveness", 1.0),
    ("viola", 2.0),
    ("berkeley_autolab_ur5", 2.0),
    ("toto", 1.0),
    ("language_table", 0.1),
    ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0),
    ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
    ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
    ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
    ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0),
    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
    ("bc_z", 0.2),
    ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
    # ("uiuc_d3field", 1.0), --> somehow raw data is broken
    ("utaustin_mutex", 1.0),
    ("berkeley_fanuc_manipulation", 2.0),
    ("cmu_stretch", 1.0),
]


OXE_FULL_MIX = [
    ("fractal20220817_data", 1.0),
    ("kuka", 1.0),
    ("bridge_dataset", 1),
    ("taco_play", 1.0),
    ("jaco_play", 1.0),
    ("berkeley_cable_routing", 1.0),
    ("roboturk", 1.0),
    ("nyu_door_opening_surprising_effectiveness", 1.0),
    ("viola", 1.0),
    ("berkeley_autolab_ur5", 1.0),
    ("toto", 1.0),
    ("language_table", 1.0),
    ("columbia_cairlab_pusht_real", 1.0),
    ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0),
    ("nyu_rot_dataset_converted_externally_to_rlds", 1.0),
    ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
    ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
    ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0),
    ("maniskill_dataset_converted_externally_to_rlds", 1.0),
    ("furniture_bench_dataset_converted_externally_to_rlds", 1.0),
    ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0),
    ("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0),
    ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
    ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
    ("bc_z", 1.0),
    ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0),
    ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0),
    ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0),
    ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0),
    ("robo_net", 1.0),
    ("berkeley_mvp_converted_externally_to_rlds", 1.0),
    ("berkeley_rpt_converted_externally_to_rlds", 1.0),
    ("kaist_nonprehensile_converted_externally_to_rlds", 1.0),
    ("stanford_mask_vit_converted_externally_to_rlds", 1.0),
    ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0),
    ("dlr_sara_pour_converted_externally_to_rlds", 1.0),
    ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0),
    ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
    ("asu_table_top_converted_externally_to_rlds", 1.0),
    ("stanford_robocook_converted_externally_to_rlds", 1.0),
    ("imperialcollege_sawyer_wrist_cam", 1.0),
    ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
    ("uiuc_d3field", 1.0),
    ("utaustin_mutex", 1.0),
    ("berkeley_fanuc_manipulation", 1.0),
    ("cmu_playing_with_food", 1.0),
    ("cmu_play_fusion", 1.0),
    ("cmu_stretch", 1.0),
    ("berkeley_gnm_recon", 1.0),
    ("berkeley_gnm_cory_hall", 1.0),
    ("berkeley_gnm_sac_son", 1.0),
]

OXE_NAMED_MIXES = {
    "bridge": BRIDGE_MIX,
    "rtx": RT_X_MIX,
    "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX,
    "oxe_magic_soup": OXE_MAGIC_SOUP,
}
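Usage note (not part of the diff): a short sketch of expanding a named mix into (remote path, sampling weight) pairs by combining `OXE_NAMED_MIXES` with `oxe_dsname2path` from `oxe_common.py`. The helper `expand_mix` is hypothetical, and every dataset name in the chosen mix is assumed to be present in `ALL_OXE_DATASETS`.

# Illustrative only: resolve a named mix to concrete Google Cloud Storage paths.
from theia.dataset.oxe.oxe_common import oxe_dsname2path
from theia.dataset.oxe.oxe_mixes import OXE_NAMED_MIXES


def expand_mix(mix_name: str) -> list[tuple[str, float]]:
    """Turn each (dataset_name, weight) entry of a named mix into (remote_path, weight)."""
    return [(oxe_dsname2path(name), weight) for name, weight in OXE_NAMED_MIXES[mix_name]]


for path, weight in expand_mix("rtx"):
    print(f"{weight:>5.2f}  {path}")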
theia/dataset/oxe/oxe_transforms.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import torch
from numpy.typing import NDArray
from torchvision.transforms.v2 import Compose, Normalize, ToDtype, ToImage


def totensor(arr: NDArray) -> torch.Tensor:
    """Convert ndarray to tensor."""
    return torch.from_numpy(arr)


oxe_image_transform = Compose(
    [ToImage(), ToDtype(torch.float32, scale=True), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)  # ImageNet statistics normalization
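Usage note (not part of the diff): a minimal sketch of what `oxe_image_transform` produces for a single HWC uint8 frame, assuming a torchvision version that provides `transforms.v2`.

# Illustrative only: apply the ImageNet-normalized transform to a dummy frame.
import numpy as np
from theia.dataset.oxe.oxe_transforms import oxe_image_transform, totensor

frame = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)  # HWC uint8
out = oxe_image_transform(frame)  # CHW float32: scaled to [0, 1], then ImageNet-normalized
print(out.shape, out.dtype)       # torch.Size([3, 224, 224]) torch.float32
print(totensor(frame).shape)      # plain conversion keeps the HWC layout: torch.Size([224, 224, 3])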
theia/dataset/video/__init__.py
ADDED
@@ -0,0 +1,3 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .video_common import ALL_VIDEO_DATASETS
theia/dataset/video/video_common.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from collections import OrderedDict

ALL_VIDEO_DATASETS = OrderedDict(
    {
        "ego4d_1in150": {"steps": 2_800_871},
        "epic_kitchen_1in60": {"steps": 333_117},
        "ssv2_1in32": {"steps": 312_772},
    }
)
theia/decoding/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .decode import decode_everything, load_feature_stats
from .depth_anything import prepare_depth_decoder
from .sam import prepare_mask_generator
theia/decoding/decode.py
ADDED
@@ -0,0 +1,198 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import os
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from numpy.typing import NDArray
from PIL import Image
from sklearn.decomposition import PCA
from transformers import SamModel, SamProcessor
from transformers.pipelines import MaskGenerationPipeline

from theia.decoding.depth_anything import decode_depth_anything
from theia.decoding.dinov2 import decode_dinov2
from theia.decoding.sam import decode_sam
from theia.preprocessing.feature_extraction_core import (
    get_feature_outputs,
    get_model,
)


def denormalize_feature(
    x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Denormalize the features using mean and std.

    Args:
        x (torch.Tensor): features to be denormalized.
        mean (Optional[torch.Tensor], optional): mean value of the features. Defaults to None.
        std (Optional[torch.Tensor], optional): std value of the features. Defaults to None.

    Returns:
        torch.Tensor: denormalized features.
    """
    if mean is None and std is None:
        return x
    elif mean is None and std is not None:
        return x * std
    elif mean is not None and std is None:
        return x + mean
    return x * std + mean


def load_feature_stats(
    feature_models: list[str], stat_file_root: str
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Load the statistics (mean and variance) of the features, per model.

    Args:
        feature_models (list[str]): names of the models. Note: the names contain `/`.
        stat_file_root (str): directory that holds the feature stat files.

    Returns:
        tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances.
    """
    feature_means: dict[str, torch.Tensor] = {}
    feature_vars: dict[str, torch.Tensor] = {}
    for model in feature_models:
        model_name = model.replace("/", "_")
        feature_means[model] = torch.from_numpy(
            np.load(os.path.join(stat_file_root, f"imagenet_mean_{model_name}.npy"))
        )
        feature_vars[model] = torch.from_numpy(np.load(os.path.join(stat_file_root, f"imagenet_var_{model_name}.npy")))
    return feature_means, feature_vars


def decode_everything(
    theia_model: nn.Module,
    feature_means: dict[str, torch.Tensor],
    feature_vars: dict[str, torch.Tensor],
    images: list[Image.Image],
    mask_generator: MaskGenerationPipeline,
    sam_model: SamModel,
    depth_anything_decoder: nn.Module,
    pred_iou_thresh: float = 0.9,
    stability_score_thresh: float = 0.9,
    gt: bool = False,
    pca: Optional[PCA] = None,
    device: int | str | torch.device = 0,
) -> tuple[list[NDArray], Optional[list[NDArray]]]:
    """Decode features from the given `theia_model` into outputs corresponding to upstream models, including
    DINOv2, SAM, and Depth-Anything.

    Args:
        theia_model (nn.Module): theia model.
        feature_means (dict[str, torch.Tensor]): means of the features for denormalization.
        feature_vars (dict[str, torch.Tensor]): variances of the features for denormalization.
        images (list[Image.Image]): input images.
        mask_generator (MaskGenerationPipeline): mask generation pipeline.
        sam_model (SamModel): SAM model.
        depth_anything_decoder (nn.Module): Depth-Anything decoder.
        pred_iou_thresh (float, optional): IoU threshold for mask generation.
            See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
        stability_score_thresh (float, optional): stability score threshold for mask generation.
            See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
        gt (bool): whether to attach the ground-truth result in the visualization. Defaults to False.
        pca (Optional[PCA]): PCA for DINOv2 decoding. If provided, this particular PCA is used. Defaults to None.
        device (int | str | torch.device, optional): device for decoding. Defaults to 0.

    Returns:
        tuple[list[NDArray], Optional[list[NDArray]]]: decoding results from the given model,
            and ground truth (if `gt=True`).
    """
    features: dict[str, torch.Tensor] = {}
    with torch.no_grad():
        for im in images:
            feature = theia_model([im])
            if len(features) == 0:
                features = {k: [] for k in feature}
            for k in feature:
                features[k].append(feature[k].detach().cpu())
        for k in features:
            features[k] = torch.cat(features[k], dim=0)
        for m in features:
            features[m] = denormalize_feature(features[m], feature_means[m], feature_vars[m])

    dino_model_name = "facebook/dinov2-large"
    sam_model_name = "facebook/sam-vit-huge"
    depth_anything_model_name = "LiheYoung/depth-anything-large-hf"

    pca = None
    # gt
    gt_decode_results = None
    if gt:

        def legit_model_name(model_name: str) -> str:
            return model_name.replace("/", "_")

        dino_model, dino_processor = get_model(dino_model_name, device=device)
        dino_gt_feature = []
        for im in images:
            dino_gt_feature.append(
                get_feature_outputs(
                    legit_model_name(dino_model_name), dino_model, dino_processor, [im], dtype=torch.float
                )[legit_model_name(dino_model_name)]["embedding"]
                .detach()
                .cpu()
            )
        dino_gt_feature = torch.cat(dino_gt_feature, dim=0)
        dino_gt_feature = rearrange(dino_gt_feature, "b c h w -> b (h w) c")
        dino_gt_dec, pca = decode_dinov2(dino_gt_feature, pca=pca)
        sam_processor = SamProcessor.from_pretrained(sam_model_name)
        sam_gt_feature = []
        for im in images:
            sam_inputs = sam_processor(images=[im], return_tensors="pt").to(device)
            with torch.no_grad():
                sam_gt_feature.append(sam_model.get_image_embeddings(sam_inputs["pixel_values"]).detach().cpu())
        sam_gt_feature = torch.cat(sam_gt_feature, dim=0)
        sam_gt_feature = rearrange(sam_gt_feature, "b c h w -> b (h w) c")
        sam_gt_dec = decode_sam(
            sam_gt_feature, images, mask_generator, pred_iou_thresh=0.9, stability_score_thresh=0.9, device=device
        )
        depth_anything_model, depth_anything_processor = get_model(depth_anything_model_name, device=device)
        depth_anything_gt_feature = []
        for im in images:
            depth_anything_gt_feature.append(
                get_feature_outputs(
                    legit_model_name(depth_anything_model_name),
                    depth_anything_model,
                    depth_anything_processor,
                    [im],
                    dtype=torch.float,
                )[legit_model_name(depth_anything_model_name)]["embedding"]
                .detach()
                .cpu()
            )
        depth_anything_gt_feature = torch.cat(depth_anything_gt_feature, dim=0)
        depth_anything_gt_feature = rearrange(depth_anything_gt_feature, "b c h w -> b (h w) c")
        depth_gt_dec = decode_depth_anything(depth_anything_gt_feature, depth_anything_decoder, device=device)

        gt_decode_results = [
            np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_gt_dec[i], sam_gt_dec[i], depth_gt_dec[i]])
            for i in range(len(images))
        ]

    dino_dec, _ = decode_dinov2(features[dino_model_name], pca=pca)

    try:
        sam_dec = decode_sam(
            features[sam_model_name],
            images,
            mask_generator,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            device=device,
        )
    except IndexError:
        sam_dec = np.zeros_like(dino_dec)
    depth_dec = decode_depth_anything(features[depth_anything_model_name], depth_anything_decoder, device=device)

    theia_decode_results = [
        np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_dec[i], sam_dec[i], depth_dec[i]])
        for i in range(len(images))
    ]

    return theia_decode_results, gt_decode_results
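Usage note (not part of the diff): a minimal sketch of `load_feature_stats` and `denormalize_feature` on dummy data. The `feature_stats` directory is a hypothetical path; it must contain the `imagenet_mean_<model>.npy` / `imagenet_var_<model>.npy` files the loader expects.

# Illustrative only: load per-model statistics and undo feature normalization.
import torch
from theia.decoding import load_feature_stats
from theia.decoding.decode import denormalize_feature

model = "facebook/dinov2-large"
means, variances = load_feature_stats([model], stat_file_root="feature_stats")  # hypothetical path

normalized = torch.randn(1, 256, 1024)  # [batch, tokens, latent_dim], dummy features
restored = denormalize_feature(normalized, means[model], variances[model])
print(restored.shape)  # torch.Size([1, 256, 1024])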
theia/decoding/depth_anything.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import torch
import torch.nn as nn
from einops import rearrange
from theia.foundation_models.vision_models.depth_anything import DepthAnythingForDepthEstimation
from numpy.typing import NDArray
from torch.nn.functional import interpolate


def prepare_depth_decoder(model_name: str, device: int | str | torch.device = 0) -> tuple[nn.Module, int]:
    """Prepare a depth decoder using DepthAnythingForDepthEstimation.

    Args:
        model_name (str): name of the depth anything model.
        device (int | str | torch.device, optional): device to put the model on. Defaults to 0.

    Returns:
        tuple[nn.Module, int]: the decoder, and the patch size of the depth anything model.
    """
    decoder_head = DepthAnythingForDepthEstimation.from_pretrained(model_name)
    patch_size = decoder_head.config.patch_size
    decoder_head = decoder_head.head
    decoder_head = decoder_head.to(device)
    return decoder_head, patch_size


def decode_depth_anything(features: torch.Tensor, decoder: nn.Module, device: int | str | torch.device = 0) -> NDArray:
    """Decode features to predicted depth using depth anything.

    Args:
        features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
        decoder (nn.Module): depth anything decoder.
        device (int | str | torch.device, optional): device to perform the decoding. Defaults to 0.

    Returns:
        NDArray: decoded depth in image format, represented by an NDArray in size [batch_size, height, width, channels]
            with values between [0, 1]. The depth values are min-max normalized to [0, 1] to generate images.
    """
    with torch.no_grad():
        P = int(features.size(1) ** 0.5)
        features = rearrange(features, "b (h w) c -> b c h w", h=P, w=P)
        features = interpolate(features, (224, 224))
        predicted_depths = []
        for feature in features:
            feature = feature.unsqueeze(0).to(device)

            predicted_depth = decoder.activation1(feature)
            predicted_depth = decoder.conv3(predicted_depth)
            predicted_depth = decoder.activation2(predicted_depth)
            predicted_depth = predicted_depth.squeeze(dim=1)  # shape (batch_size, height, width)
            for i in range(len(predicted_depth)):
                min_depth, max_depth = predicted_depth[i].min(), predicted_depth[i].max()
                predicted_depth[i] = (predicted_depth[i] - min_depth) / (max_depth - min_depth)
            predicted_depths.append(predicted_depth.detach().cpu())
        predicted_depths = torch.cat(predicted_depths, dim=0)
    return predicted_depths.unsqueeze(-1).repeat((1, 1, 1, 3)).numpy()  # type: ignore [attr-defined]
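Usage note (not part of the diff): a minimal sketch of decoding Depth-Anything-style features, assuming the (latent_dim=32, 64x64 tokens) feature layout listed in theia/foundation_models/common.py later in this diff; running it downloads LiheYoung/depth-anything-large-hf.

# Illustrative only: decode dummy features into a depth visualization on CPU.
import torch
from theia.decoding.depth_anything import decode_depth_anything, prepare_depth_decoder

decoder, patch_size = prepare_depth_decoder("LiheYoung/depth-anything-large-hf", device="cpu")
features = torch.randn(2, 64 * 64, 32)  # [batch, tokens, latent_dim]; real features come from Theia
depth_images = decode_depth_anything(features, decoder, device="cpu")
print(depth_images.shape)  # (2, 224, 224, 3), each frame min-max normalized to [0, 1]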
theia/decoding/dinov2.py
ADDED
@@ -0,0 +1,69 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from typing import Optional

import cv2
import numpy as np
from numpy.typing import NDArray
from sklearn.decomposition import PCA
from sklearn.preprocessing import minmax_scale


def decode_dinov2(
    features: NDArray, threshold: int | float = -100, interpolation: bool = False, pca: Optional[PCA] = None
) -> tuple[NDArray, PCA]:
    """
    Decode the input `features` in DINOv2 style using PCA.

    Args:
        features (NDArray): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
        threshold (int | float): threshold of the foreground-background split in the PCA visualization.
            Defaults to -100 (all patches are included).
        interpolation (bool): whether to interpolate the 16x16 PCA map to the original image size.
        pca (Optional[PCA]): if provided, use the provided PCA. This keeps visualizations stable across samples.

    Returns:
        tuple[NDArray, PCA]: the rendered image of this visualization, as an NDArray in size
            [batch_size, height, width, channels] with values in [0, 1], and the PCA used in this visualization.
    """
    features = features.numpy()
    batch_size, spatial_size, latent_dim = features.shape
    h = w = int(spatial_size**0.5)

    features = features.reshape(-1, latent_dim)

    if pca is None:
        pca = PCA(n_components=3)
        pca.fit(features)

    pca_features = pca.transform(features)

    # segment using the first component
    bg_mask = pca_features[:, 0] < threshold
    fg_mask = ~bg_mask

    # PCA for only foreground patches
    # pca.fit(features[fg_mask])
    pca_features_fg = pca.transform(features[fg_mask])
    for i in range(3):
        pca_features_fg[:, i] = minmax_scale(pca_features_fg[:, i])

    pca_features_rgb = pca_features.copy()
    pca_features_rgb[bg_mask] = 0
    pca_features_rgb[fg_mask] = pca_features_fg

    pca_features_rgb = pca_features_rgb.reshape(batch_size, h, w, 3)
    if not interpolation:
        H = W = 224
        scale = H // h
        interpolated_pca_features = np.zeros((batch_size, H, W, 3), dtype=pca_features_rgb.dtype)
        for i in range(len(pca_features_rgb)):
            for j in range(h):
                for k in range(w):
                    interpolated_pca_features[i, scale * j : scale * (j + 1), scale * k : scale * (k + 1)] = (
                        pca_features_rgb[i, j, k]
                    )
        pca_features_rgb = interpolated_pca_features
    else:
        pca_features_rgb = np.stack([cv2.resize(p, (224, 224)) for p in pca_features_rgb])
    return pca_features_rgb, pca
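Usage note (not part of the diff): a minimal sketch of `decode_dinov2` on dummy features; reusing the returned PCA keeps colors comparable across batches.

# Illustrative only: PCA-decode DINOv2-style features into RGB maps.
import torch
from theia.decoding.dinov2 import decode_dinov2

features = torch.randn(4, 16 * 16, 1024)  # [batch, tokens, latent_dim], dummy data
rgb, pca = decode_dinov2(features)        # fits a fresh 3-component PCA
rgb_next, _ = decode_dinov2(torch.randn(4, 16 * 16, 1024), pca=pca)  # reuse the PCA for stable colors
print(rgb.shape)  # (4, 224, 224, 3), values in [0, 1]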
theia/decoding/sam.py
ADDED
@@ -0,0 +1,191 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from typing import Any, Generator, Optional

import numpy as np
import torch
from einops import rearrange
from numpy.typing import NDArray
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers.image_utils import load_image
from transformers.pipelines import MaskGenerationPipeline


class MaskGenerationPipelineWithEmbeddings(MaskGenerationPipeline):
    """
    The wrapper class for huggingface transformers.pipelines.MaskGenerationPipeline
    that can decode from intermediate SAM embeddings.
    """

    def _sanitize_parameters(self, **kwargs: Any) -> tuple[dict[str, Any], ...]:
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        forward_params = {}
        # preprocess args
        if "embeddings" in kwargs:  # inject embeddings here
            preprocess_kwargs["embeddings"] = kwargs["embeddings"]
        if "points_per_batch" in kwargs:
            preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
        if "points_per_crop" in kwargs:
            preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
        if "crops_n_layers" in kwargs:
            preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
        if "crop_overlap_ratio" in kwargs:
            preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
        if "crop_n_points_downscale_factor" in kwargs:
            preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
        if "timeout" in kwargs:
            preprocess_kwargs["timeout"] = kwargs["timeout"]
        # forward args
        if "pred_iou_thresh" in kwargs:
            forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
        if "stability_score_offset" in kwargs:
            forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
        if "mask_threshold" in kwargs:
            forward_params["mask_threshold"] = kwargs["mask_threshold"]
        if "stability_score_thresh" in kwargs:
            forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
        # postprocess args
        if "crops_nms_thresh" in kwargs:
            postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
        if "output_rle_mask" in kwargs:
            postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
        if "output_bboxes_mask" in kwargs:
            postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
        return preprocess_kwargs, forward_params, postprocess_kwargs

    def preprocess(
        self,
        image: list[Image.Image],
        points_per_batch: int = 64,
        crops_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        points_per_crop: int = 32,
        crop_n_points_downscale_factor: int = 1,
        timeout: Optional[float] = None,
        embeddings: Optional[torch.Tensor] = None,
    ) -> Generator[Any, Any, Any]:
        image = load_image(image, timeout=timeout)
        target_size = self.image_processor.size["longest_edge"]
        crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
            image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
        )
        model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")

        with self.device_placement():
            if self.framework == "pt":
                inference_context = self.get_inference_context()
                with inference_context():
                    model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                    if embeddings is None:
                        image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
                    else:
                        model_inputs.pop("pixel_values")
                        image_embeddings = embeddings
                    model_inputs["image_embeddings"] = image_embeddings

        n_points = grid_points.shape[1]
        points_per_batch = points_per_batch if points_per_batch is not None else n_points

        if points_per_batch <= 0:
            raise ValueError(
                "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
                "To return all points at once, set points_per_batch to None"
            )

        for i in range(0, n_points, points_per_batch):
            batched_points = grid_points[:, i : i + points_per_batch, :, :]
            labels = input_labels[:, i : i + points_per_batch]
            is_last = i == n_points - points_per_batch
            yield {
                "input_points": batched_points,
                "input_labels": labels,
                "input_boxes": crop_boxes,
                "is_last": is_last,
                **model_inputs,
            }


def draw_mask(mask: NDArray, random_color: bool = False) -> NDArray:
    """Draw the mask on an image.

    Args:
        mask (NDArray): mask in shape [height, width].
        random_color (bool): whether to use a random color. Defaults to False.

    Returns:
        NDArray: NDArray format of the image.
    """
    if random_color:
        color = np.concatenate([np.random.random(3)], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    return mask_image


def decode_sam(
    features: torch.Tensor,
    images: list[Image.Image],
    mask_generator: Any,
    points_per_batch: int = 64,
    pred_iou_thresh: float = 0.5,
    stability_score_thresh: float = 0.6,
    random_color: bool = True,
    device: int | str | torch.device = 0,
) -> NDArray:
    """Decode features using the SAM (auto-prompting) mask generation pipeline.

    Args:
        features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
        images (list[Image.Image]): images corresponding to these features.
        mask_generator (Any): mask generation pipeline.
        points_per_batch (int): points per batch for auto-prompting. Defaults to 64.
            See transformers.pipelines.MaskGenerationPipeline for more details. Same below.
        pred_iou_thresh (float): IoU threshold. Defaults to 0.5.
        stability_score_thresh (float): stability threshold. Defaults to 0.6.
        random_color (bool): whether to use a random color. Defaults to True.
        device (int | str | torch.device): device to perform the decoding. Defaults to 0.

    Returns:
        NDArray: decoded masks rendered in image format, represented by an NDArray in size
            [batch_size, height, width, channels] with values between [0, 1].
    """
    masks_rgbs = []
    num_patches = int(features.size(1) ** 0.5)
    features = rearrange(features, "b (h w) c -> b c h w", h=num_patches, w=num_patches)
    with torch.no_grad():
        for im, feature in zip(images, features, strict=False):
            predicted_outputs = mask_generator(
                im,
                points_per_batch=points_per_batch,
                embeddings=feature.unsqueeze(0).to(device),
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
            )
            predicted_masks = predicted_outputs["masks"]
            masks_rgb = np.zeros((224, 224, 3), dtype=np.float32)
            for mask in predicted_masks:
                masks_rgb += draw_mask(mask, random_color=random_color)
            # masks_rgb = cv2.cvtColor(masks_rgb, cv2.COLOR_RGBA2RGB)
            masks_rgbs.append(masks_rgb)
    return np.stack(masks_rgbs)


def prepare_mask_generator(device: int | str | torch.device = 0) -> tuple[MaskGenerationPipeline, SamModel]:
    """Prepare a mask generation pipeline on device `device`.

    Args:
        device (int | str | torch.device): device to perform mask generation. Defaults to 0.

    Returns:
        tuple[MaskGenerationPipeline, SamModel]: the mask generator and the underlying SAM model.
    """
    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    sam_model.eval()
    mask_generator = MaskGenerationPipelineWithEmbeddings(
        task="mask_generation", model=sam_model, image_processor=processor.image_processor, device=device
    )
    return mask_generator, sam_model
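Usage note (not part of the diff): a minimal sketch of `draw_mask`, the helper the SAM decoding path uses to render each predicted mask. `decode_sam` itself additionally needs the mask generator and SAM model returned by `prepare_mask_generator`.

# Illustrative only: render a boolean mask as a colored RGB overlay.
import numpy as np
from theia.decoding.sam import draw_mask

mask = np.zeros((224, 224), dtype=bool)
mask[64:160, 64:160] = True                   # a dummy square "object"
overlay = draw_mask(mask, random_color=True)  # (224, 224, 3) float array, colored where mask is True
print(overlay.shape, float(overlay.max()) <= 1.0)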
theia/example/decode_to_vfms.ipynb
ADDED
@@ -0,0 +1,69 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import torch\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from transformers import AutoModel\n",
    "from torchvision.io import read_video, write_video\n",
    "from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything\n",
    "\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "theia_model = AutoModel.from_pretrained(\"theaiinstitute/theia-base-patch16-224-cdiv\", trust_remote_code=True)\n",
    "theia_model = theia_model.to(device)\n",
    "target_model_names = [\n",
    "    \"google/vit-huge-patch14-224-in21k\",\n",
    "    \"facebook/dinov2-large\",\n",
    "    \"openai/clip-vit-large-patch14\",\n",
    "    \"facebook/sam-vit-huge\",\n",
    "    \"LiheYoung/depth-anything-large-hf\",\n",
    "]\n",
    "feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root=\"../../../feature_stats\")\n",
    "\n",
    "mask_generator, sam_model = prepare_mask_generator(device)\n",
    "depth_anything_model_name = \"LiheYoung/depth-anything-large-hf\"\n",
    "depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)\n",
    "\n",
    "example_video_path = \"../../../media/example_video_to_visualize.mp4\"\n",
    "video, _, _ = read_video(example_video_path, pts_unit=\"sec\", output_format=\"THWC\")\n",
    "video = video.numpy()\n",
    "images = [Image.fromarray(cv2.resize(im, (224, 224))) for im in video]\n",
    "\n",
    "theia_decode_results, gt_decode_results = decode_everything(\n",
    "    theia_model=theia_model,\n",
    "    feature_means=feature_means,\n",
    "    feature_vars=feature_vars,\n",
    "    images=images,\n",
    "    mask_generator=mask_generator,\n",
    "    sam_model=sam_model,\n",
    "    depth_anything_decoder=depth_anything_decoder,\n",
    "    pred_iou_thresh=0.5,\n",
    "    stability_score_thresh=0.7,\n",
    "    gt=True,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "vis_video = np.stack(\n",
    "    [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]\n",
    ")\n",
    "vis_video = torch.from_numpy(vis_video * 255.0).to(torch.uint8)\n",
    "vis_save_path = \"./visualized.mp4\"\n",
    "write_video(vis_save_path, vis_video, fps=10)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
theia/foundation_models/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from .vision_language_models.clip import get_clip_feature, get_clip_model
from .vision_language_models.llava import get_llava_vision_model, get_llava_visual_feature
from .vision_models.deit import get_deit_feature, get_deit_model
from .vision_models.depth_anything import get_depth_anything_feature, get_depth_anything_model
from .vision_models.dinov2 import get_dinov2_feature, get_dinov2_model
from .vision_models.sam import get_sam_feature, get_sam_model
from .vision_models.vit import get_vit_feature, get_vit_model
theia/foundation_models/common.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

import math

import torch

MODELS = [
    "facebook/dinov2-large",
    "facebook/sam-vit-huge",
    "google/vit-huge-patch14-224-in21k",
    "llava-hf/llava-1.5-7b-hf",
    "openai/clip-vit-large-patch14",
    "LiheYoung/depth-anything-large-hf",
]

# handy model feature size constants
# in the format of (latent_dim, width, height)
MODEL_FEATURE_SIZES = {
    "facebook/dinov2-large": (1024, 16, 16),
    "facebook/sam-vit-huge": (256, 64, 64),
    "google/vit-huge-patch14-224-in21k": (1280, 16, 16),
    "llava-hf/llava-1.5-7b-hf": (1024, 24, 24),
    "openai/clip-vit-large-patch14": (1024, 16, 16),
    "LiheYoung/depth-anything-large-hf": (32, 64, 64),
}


def get_model_feature_size(
    model_name: str, keep_spatial: bool = False, return_torch_size: bool = False
) -> tuple[int, ...] | torch.Size:
    """
    Get the size of the queried model's feature.

    Args:
        model_name (str): name of the model.
        keep_spatial (bool): whether to preserve the spatial dims. Defaults to False.
        return_torch_size (bool): return torch.Size instead of a python tuple. Defaults to False.

    Returns:
        tuple[int, ...] | torch.Size: the size of the feature.
    """
    size: tuple[int, ...] = MODEL_FEATURE_SIZES[model_name]

    if not keep_spatial:
        size = (size[0], math.prod(size[1:]))

    if return_torch_size:
        size = torch.Size(size)

    return size


def get_max_model_spatial_size(
    keep_spatial: bool = True,
    return_torch_size: bool = False,
    return_model_name: bool = False,
) -> tuple[int, ...] | tuple[tuple[int, ...], str]:
    """Get the maximal spatial dimensions among the available models.

    Args:
        keep_spatial (bool): whether to preserve the spatial dims. Defaults to True.
        return_torch_size (bool): return torch.Size instead of a python tuple. Defaults to False.
        return_model_name (bool): whether to also return the name of the model with the maximal size.
            Defaults to False.

    Returns:
        tuple[int, ...] | tuple[tuple[int, ...], str]: the maximal size and, optionally, the model name.
    """
    max_flatten_size = -1
    max_size: tuple[int, ...] = ()
    max_size_model_name: str = ""
    for model, size in MODEL_FEATURE_SIZES.items():
        flatten_size = math.prod(size[1:])
        if flatten_size > max_flatten_size:
            max_flatten_size = flatten_size
            max_size = size[1:]
            max_size_model_name = model

    if not keep_spatial:
        max_size = (max_flatten_size,)

    if return_torch_size:
        max_size = torch.Size(max_size)

    if return_model_name:
        return max_size, max_size_model_name
    else:
        return max_size
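Usage note (not part of the diff): a minimal sketch of the feature-size helpers; the expected outputs follow directly from MODEL_FEATURE_SIZES above.

# Illustrative only: query feature sizes of the target models.
from theia.foundation_models.common import get_max_model_spatial_size, get_model_feature_size

print(get_model_feature_size("facebook/dinov2-large"))                     # (1024, 256)
print(get_model_feature_size("facebook/sam-vit-huge", keep_spatial=True))  # (256, 64, 64)
print(get_max_model_spatial_size(return_model_name=True))                  # ((64, 64), 'facebook/sam-vit-huge')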