bmay commited on
Commit
26791f7
1 Parent(s): 77b08da
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +24 -0
  2. theia/__init__.py +1 -0
  3. theia/configs/dataset/ego4d.yaml +5 -0
  4. theia/configs/dataset/epic_kitchen.yaml +5 -0
  5. theia/configs/dataset/image_video_default.yaml +7 -0
  6. theia/configs/dataset/image_video_mix.yaml +8 -0
  7. theia/configs/dataset/imagenet.yaml +5 -0
  8. theia/configs/dataset/oxe_octo_mix.yaml +12 -0
  9. theia/configs/dataset/ssv2.yaml +5 -0
  10. theia/configs/logging/default.yaml +6 -0
  11. theia/configs/model/backbone/deit.yaml +2 -0
  12. theia/configs/model/backbone/deit_nocls.yaml +2 -0
  13. theia/configs/model/backbone/deit_reg.yaml +3 -0
  14. theia/configs/model/translator/conv.yaml +3 -0
  15. theia/configs/model/translator/lconv.yaml +3 -0
  16. theia/configs/model/translator/mlp.yaml +4 -0
  17. theia/configs/model/translator/transformer.yaml +5 -0
  18. theia/configs/train_rvfm_imagenet.yaml +9 -0
  19. theia/configs/training/frame_level.yaml +35 -0
  20. theia/configs/training/target_models/cdds.yaml +6 -0
  21. theia/configs/training/target_models/cddsv.yaml +7 -0
  22. theia/configs/training/target_models/cddv.yaml +6 -0
  23. theia/configs/training/target_models/cdesv.yaml +6 -0
  24. theia/configs/training/target_models/cdis.yaml +5 -0
  25. theia/configs/training/target_models/cdisv.yaml +6 -0
  26. theia/configs/training/target_models/cdiv.yaml +5 -0
  27. theia/configs/training/target_models/clip.yaml +3 -0
  28. theia/configs/training/target_models/ddsv.yaml +6 -0
  29. theia/configs/training/target_models/depth_anything.yaml +3 -0
  30. theia/configs/training/target_models/dinov2.yaml +3 -0
  31. theia/configs/training/target_models/sam.yaml +3 -0
  32. theia/configs/training/target_models/vit.yaml +3 -0
  33. theia/dataset/__init__.py +5 -0
  34. theia/dataset/data_utils.py +591 -0
  35. theia/dataset/image/__init__.py +3 -0
  36. theia/dataset/image/image_common.py +5 -0
  37. theia/dataset/oxe/__init__.py +1 -0
  38. theia/dataset/oxe/oxe_common.py +430 -0
  39. theia/dataset/oxe/oxe_mixes.py +139 -0
  40. theia/dataset/oxe/oxe_transforms.py +15 -0
  41. theia/dataset/video/__init__.py +3 -0
  42. theia/dataset/video/video_common.py +11 -0
  43. theia/decoding/__init__.py +5 -0
  44. theia/decoding/decode.py +198 -0
  45. theia/decoding/depth_anything.py +57 -0
  46. theia/decoding/dinov2.py +69 -0
  47. theia/decoding/sam.py +191 -0
  48. theia/example/decode_to_vfms.ipynb +69 -0
  49. theia/foundation_models/__init__.py +9 -0
  50. theia/foundation_models/common.py +87 -0
LICENSE ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2024 Boston Dynamics AI Institute LLC
2
+
3
+ Redistribution and use in source and binary forms, with or without
4
+ modification, are permitted provided that the following conditions are met:
5
+ 1. Redistributions of source code must retain the copyright notice included
6
+ with the software, this list of conditions and the following disclaimer.
7
+ 2. Redistributions in binary form must reproduce the copyright notice, this
8
+ list of conditions and the following disclaimer in the documentation and/or
9
+ other materials provided with the distribution.
10
+ 3. Modified versions of the software must be conspicuously marked as such.
11
+ 4. The software may only be used for non-commercial research purposes.
12
+ For profit enterprises may use the software, subject to this limitation.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE AI INSTITUTE AND CONTRIBUTORS "AS IS" AND
15
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, NON-
16
+ INFRINGEMENT,TITLE, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE AI INSTITUTE OR CONTRIBUTORS BE LIABLE FOR
18
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, DAMAGES ARISING OUT OF CLAIMS OF
20
+ INTELLECTUAL PROPERTY RIGHTS INFRINGEMENT; PROCUREMENT OF SUBSTITUTE GOODS OR
21
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
theia/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
theia/configs/dataset/ego4d.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - image_video_default
3
+
4
+ dataset_mix:
5
+ - "ego4d_1in150"
theia/configs/dataset/epic_kitchen.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - image_video_default
3
+
4
+ dataset_mix:
5
+ - "epic_kitchen_1in60"
theia/configs/dataset/image_video_default.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ return_metadata: False
2
+ shuffle: True
3
+ shuffle_buffer_size: 1024
4
+ feature_norm: True
5
+ dataset_root: "/storage/nfs/datasets/jshang/"
6
+ dataset_ratio: 0.1
7
+ load_action: False
theia/configs/dataset/image_video_mix.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - image_video_default
3
+
4
+ dataset_mix:
5
+ - "ego4d_1in150"
6
+ - "ssv2_1in32"
7
+ - "epic_kitchen_1in60"
8
+ - "imagenet"
theia/configs/dataset/imagenet.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - image_video_default
3
+
4
+ dataset_mix:
5
+ - "imagenet"
theia/configs/dataset/oxe_octo_mix.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: dataset.oxe.oxe_data_utils.OXEDataset
2
+ dataset_mix: "oxe_magic_soup"
3
+ image_action_set_root: "/storage/nfs/datasets/jshang/oxe_image_action"
4
+ feature_set_root: "/storage/nfs/datasets/jshang/oxe_vfm_features"
5
+ image_views: null
6
+ split: "train"
7
+ data_portion: 0.01
8
+ load_action: False
9
+ bf16: True
10
+ safe_tensors: True
11
+ trajectory_subsample_len: 32
12
+ return_metadata: False
theia/configs/dataset/ssv2.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ defaults:
2
+ - image_video_default
3
+
4
+ dataset_mix:
5
+ - "ssv2_1in32"
theia/configs/logging/default.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ model_path: "/storage/nfs/jshang/trained_models"
2
+ log_path: "/storage/nfs/jshang/logs"
3
+ save_ckpt_interval: 20000
4
+ notes: ""
5
+ run_identifier_prefix: ""
6
+ project: "theia"
theia/configs/model/backbone/deit.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ backbone: facebook/deit-small-patch16-224
2
+ pretrained: False
theia/configs/model/backbone/deit_nocls.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ backbone: nocls-facebook/deit-tiny-patch16-224
2
+ pretrained: False
theia/configs/model/backbone/deit_reg.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ backbone: reg-facebook/deit-tiny-patch16-224
2
+ pretrained: False
3
+ num_reg_tokens: 7
theia/configs/model/translator/conv.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ type: "conv"
2
+ kwargs:
3
+ translator_hidden_size: 1024
theia/configs/model/translator/lconv.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ type: "lconv"
2
+ kwargs:
3
+ hidden_size_factor: 1.0
theia/configs/model/translator/mlp.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ type: "mlp"
2
+ kwargs:
3
+ translator_n_layer: 3
4
+ hidden_size: 1024
theia/configs/model/translator/transformer.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ type: "transformer"
2
+ kwargs:
3
+ translator_n_layers: 2
4
+ translator_n_heads: 8
5
+ translator_hidden_size: 1024
theia/configs/train_rvfm_imagenet.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - dataset: imagenet
3
+ - model/backbone: deit
4
+ - model/translator: lconv
5
+ - training: frame_level
6
+ - logging: default
7
+ - _self_
8
+
9
+ seed: 0
theia/configs/training/frame_level.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - target_models: cdiv
3
+
4
+ epochs: 50
5
+ warm_up_steps_ratio: 0.1
6
+
7
+ base_lr: 2e-3
8
+ batch_size: 16
9
+ random_target_models: -1
10
+ num_workers: 8
11
+ # base training settings to scale lr, rarely changed
12
+ base_batch_size: 64
13
+ base_world_size: 8
14
+
15
+ weight_decay: 0.01
16
+
17
+
18
+ optimizer:
19
+ _target_: torch.optim.AdamW
20
+ betas: [0.9, 0.999]
21
+
22
+ lr_scheduler:
23
+ _target_: theia.lr_schedulers.get_constant_lrs_with_linear_warm_up
24
+ warm_up_lr_start_factor: 1e-2
25
+
26
+
27
+ grad_clip: False
28
+ grad_clip_norm_warmup: 10.0
29
+ grad_clip_norm: 1.0
30
+
31
+ freeze_translator: False
32
+ freeze_translator_start_steps_ratio: 0.2
33
+ translator_lr_factor: 1.0
34
+
35
+ main_loss: cos_l1
theia/configs/training/target_models/cdds.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "facebook/dinov2-large"
3
+ - "openai/clip-vit-large-patch14"
4
+ - "facebook/sam-vit-huge"
5
+ - "LiheYoung/depth-anything-large-hf"
6
+ target_model_weights: null
theia/configs/training/target_models/cddsv.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "google/vit-huge-patch14-224-in21k"
3
+ - "facebook/dinov2-large"
4
+ - "openai/clip-vit-large-patch14"
5
+ - "facebook/sam-vit-huge"
6
+ - "LiheYoung/depth-anything-large-hf"
7
+ target_model_weights: null
theia/configs/training/target_models/cddv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "google/vit-huge-patch14-224-in21k"
3
+ - "facebook/dinov2-large"
4
+ - "openai/clip-vit-large-patch14"
5
+ - "LiheYoung/depth-anything-large-hf"
6
+ target_model_weights: null
theia/configs/training/target_models/cdesv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "google/vit-huge-patch14-224-in21k"
3
+ - "openai/clip-vit-large-patch14"
4
+ - "facebook/sam-vit-huge"
5
+ - "LiheYoung/depth-anything-large-hf"
6
+ target_model_weights: null
theia/configs/training/target_models/cdis.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "facebook/dinov2-large"
3
+ - "openai/clip-vit-large-patch14"
4
+ - "facebook/sam-vit-huge"
5
+ target_model_weights: null
theia/configs/training/target_models/cdisv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "google/vit-huge-patch14-224-in21k"
3
+ - "facebook/dinov2-large"
4
+ - "openai/clip-vit-large-patch14"
5
+ - "facebook/sam-vit-huge"
6
+ target_model_weights: null
theia/configs/training/target_models/cdiv.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "google/vit-huge-patch14-224-in21k"
3
+ - "facebook/dinov2-large"
4
+ - "openai/clip-vit-large-patch14"
5
+ target_model_weights: null
theia/configs/training/target_models/clip.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ target_model_names:
2
+ - "openai/clip-vit-large-patch14"
3
+ target_model_weights: null
theia/configs/training/target_models/ddsv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ target_model_names:
2
+ - "google/vit-huge-patch14-224-in21k"
3
+ - "facebook/dinov2-large"
4
+ - "facebook/sam-vit-huge"
5
+ - "LiheYoung/depth-anything-large-hf"
6
+ target_model_weights: null
theia/configs/training/target_models/depth_anything.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ target_model_names:
2
+ - "LiheYoung/depth-anything-large-hf"
3
+ target_model_weights: null
theia/configs/training/target_models/dinov2.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ target_model_names:
2
+ - "facebook/dinov2-large"
3
+ target_model_weights: null
theia/configs/training/target_models/sam.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ target_model_names:
2
+ - "facebook/sam-vit-huge"
3
+ target_model_weights: null
theia/configs/training/target_models/vit.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ target_model_names:
2
+ - "google/vit-huge-patch14-224-in21k"
3
+ target_model_weights: null
theia/dataset/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from .image.image_common import ALL_IMAGE_DATASETS
4
+ from .oxe.oxe_common import ALL_OXE_DATASETS
5
+ from .video.video_common import ALL_VIDEO_DATASETS
theia/dataset/data_utils.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ """Defines PyTorch datasets of dataloaders for multiple image, video, and OXE datasets.
4
+ Should use with webdataset >= 0.2.90. See https://github.com/webdataset/webdataset/pull/347"""
5
+
6
+ import glob
7
+ import json
8
+ import math
9
+ import os.path as osp
10
+ from collections import OrderedDict
11
+ from functools import partial
12
+ from io import BytesIO
13
+ from typing import Any, Callable, Generator, Iterator, Literal, Optional
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import omegaconf
18
+ import torch
19
+ import webdataset as wds
20
+ from datasets.combine import DatasetType
21
+ from einops import rearrange
22
+ from numpy.typing import NDArray
23
+ from safetensors.torch import load as sft_load
24
+ from torch import default_generator
25
+ from torch.utils.data import DataLoader, Dataset, IterableDataset, default_collate
26
+
27
+ from theia.foundation_models.common import MODELS
28
+ from theia.dataset.oxe.oxe_common import ALL_OXE_DATASETS
29
+ from theia.dataset.oxe.oxe_mixes import OXE_NAMED_MIXES
30
+
31
+ PACKED_FEATURES = [model_name for model_name in MODELS if "llava" not in model_name]
32
+
33
+
34
+ def normalize_ds_weights_by_ds_len(weights: list[float], lengths: list[int]) -> tuple[list[float], float | Literal[0]]:
35
+ """Normalize dataset weights by dataset lengths (frames).
36
+
37
+ Args:
38
+ weights (list[float]): assigned weights.
39
+ lengths (list[int]): lengths of datasets.
40
+
41
+ Returns:
42
+ tuple[list[float], int]: normalized weights, and sum of the expected lengths of datasets
43
+ """
44
+ expected_lengths = [weight * length for weight, length in zip(weights, lengths, strict=False)]
45
+ sum_expected_lengths = sum(expected_lengths)
46
+ if sum_expected_lengths == 0:
47
+ raise ValueError("Sum of dataset length is 0.")
48
+ normalized_weights = [length * 1.0 / sum_expected_lengths for length in expected_lengths]
49
+ return normalized_weights, sum_expected_lengths
50
+
51
+
52
+ def get_vo_keys(dataset_name: str, image_views: Optional[list | str | dict[str, str | list[str]]] = None) -> list[str]:
53
+ """Get visual observation keys of datasets (to be compatible with OXE).
54
+
55
+ Args:
56
+ dataset_name (str): name of the dataset.
57
+ image_views (Optional[dict[str, str | list[str]]], optional): keys of selected views.
58
+ Defaults to None.
59
+
60
+ Returns:
61
+ list[str]: keys to the views in the dataset.
62
+ """
63
+ default_visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"][:1]
64
+ visual_observation_keys = []
65
+ if image_views is None:
66
+ visual_observation_keys = default_visual_observation_keys
67
+ elif isinstance(image_views, list):
68
+ visual_observation_keys = ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
69
+ elif isinstance(image_views, str):
70
+ if image_views == "static":
71
+ visual_observation_keys = [
72
+ k
73
+ for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"]
74
+ if "wrist" not in k and "hand" not in k
75
+ ]
76
+ elif image_views == "wrist":
77
+ visual_observation_keys = [
78
+ k for k in ALL_OXE_DATASETS[dataset_name]["visual_observation_keys"] if "wrist" in k or "hand" in k
79
+ ]
80
+ if len(visual_observation_keys) == 0:
81
+ visual_observation_keys = default_visual_observation_keys
82
+ return visual_observation_keys
83
+
84
+
85
+ class RandomMix(IterableDataset):
86
+ """A random interleave of multiple iterable datasets."""
87
+
88
+ def __init__(
89
+ self,
90
+ datasets: list[IterableDataset],
91
+ probs: list[float] | NDArray | None = None,
92
+ stopping_strategy: str = "all_exhausted",
93
+ seed: Optional[int | str] = 0,
94
+ ) -> None:
95
+ """Initialization of a random interleave dataset.
96
+
97
+ Args:
98
+ datasets (list[IterableDataset]): datasets to be interleaved.
99
+ probs (list[float] | NDArray, optional): probability of each dataset. Defaults to None.
100
+ stopping_strategy (str, optional): when to end the sampling for one epoch. Defaults to `all_exhausted`.
101
+ `all_exhausted`: each sample in the dataset will be sampled at least once.
102
+ `first_exhausted`: when the first dataset is ran out, this episode ends.
103
+ See also https://huggingface.co/docs/datasets/en/stream#interleave for definitions.
104
+ seed (Optional[int | str]): seed. Defaults to 0.
105
+ """
106
+ self.datasets = datasets
107
+ if probs is None:
108
+ self.probs = [1.0] * len(self.datasets)
109
+ elif isinstance(probs, np.ndarray):
110
+ self.probs = probs.tolist()
111
+ else:
112
+ self.probs = probs
113
+ self.stopping_strategy = stopping_strategy
114
+ self.seed = seed
115
+
116
+ def __iter__(self) -> Generator:
117
+ """Return an iterator over the sources."""
118
+ sources = [iter(d) for d in self.datasets]
119
+ probs = self.probs[:]
120
+ seed_gen = torch.Generator()
121
+ seed_gen.manual_seed(self.seed)
122
+ cum = (np.array(probs) / np.sum(probs)).cumsum()
123
+ while len(sources) > 0:
124
+ r = torch.rand(1, generator=seed_gen).item()
125
+ i = np.searchsorted(cum, r)
126
+ try:
127
+ yield next(sources[i])
128
+ except StopIteration:
129
+ if self.stopping_strategy == "all_exhausted":
130
+ del sources[i]
131
+ del probs[i]
132
+ cum = (np.array(probs) / np.sum(probs)).cumsum()
133
+ elif self.stopping_strategy == "first_exhausted":
134
+ break
135
+
136
+
137
+ def decode_sample(
138
+ key: str, data: bytes, image_transform: Optional[Callable] = None, feature_transform: Optional[Callable] = None
139
+ ) -> Any:
140
+ """Decode a sample from bytes with optional image and feature transforms
141
+
142
+ Args:
143
+ key (str): key of an attribute (a column) of the sample.
144
+ data (bytes): original data bytes.
145
+ image_transform (Optional[Callable], optional): image transform. Defaults to None.
146
+ feature_transform (Optional[Callable], optional): feature transform. Defaults to None.
147
+
148
+ Returns:
149
+ Any: decoded data.
150
+ """
151
+ if ".safetensors" in key:
152
+ sft = sft_load(data)
153
+ embedding = rearrange(sft["embedding"], "c h w -> (h w) c")
154
+ if feature_transform is not None:
155
+ embedding = feature_transform(embedding)
156
+ if "cls_token" in sft:
157
+ cls = sft["cls_token"]
158
+ if feature_transform is not None:
159
+ cls = feature_transform(cls)
160
+ return {"embedding": embedding, "cls": cls}
161
+ return {"embedding": embedding}
162
+ elif key == ".image":
163
+ image = np.load(BytesIO(data))
164
+ if len(image.shape) == 2:
165
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
166
+ elif len(image.shape) == 3 and image.shape[-1] == 4:
167
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
168
+ if image_transform is not None:
169
+ return image_transform(image)
170
+ return image
171
+ else:
172
+ return data
173
+
174
+
175
+ def get_oxe_frame_dataset(
176
+ dataset_root: str,
177
+ dataset_mix: Optional[str | dict[str, float] | list] = "oxe_magic_soup",
178
+ feature_models: Optional[list[str]] = None,
179
+ split: str = "train",
180
+ dataset_ratio: float = 1.0,
181
+ image_views: Optional[dict[str, str | list[str]]] = None,
182
+ image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
183
+ seed: Optional[int | str] = 0,
184
+ shuffle: bool = False,
185
+ world_size: int = 1,
186
+ ) -> tuple[dict[str, DatasetType], float | Literal[0]]:
187
+ """Get OXE datasets at frame level.
188
+
189
+ Args:
190
+ dataset_root (str): root dir of the datasets.
191
+ dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets.
192
+ Defaults to "oxe_magic_soup".
193
+ feature_models (Optional[list[str]], optional): models to load their features. Defaults to None.
194
+ split (str, optional): split "train" or "val" or "test". Defaults to "train".
195
+ dataset_ratio (float, optional): how much data use for the (combined) dataset. Defaults to 1.0.
196
+ image_views (Optional[dict[str, str | list[str]]], optional): image views to select. Defaults to None.
197
+ image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
198
+ Defaults to None.
199
+ seed (Optional[int | str], optional): seed. Defaults to 0.
200
+ shuffle (bool, optional): shuffle or not. Defaults to False.
201
+ world_size (int, optional): world size of DDP training. Defaults to 1.
202
+
203
+ Returns:
204
+ tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}.
205
+ """
206
+ # read dataset mix from any acceptable form
207
+ if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
208
+ dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
209
+ elif isinstance(dataset_mix, dict):
210
+ dataset_mix = OrderedDict(**dataset_mix)
211
+ elif isinstance(dataset_mix, list):
212
+ dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
213
+ else:
214
+ raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")
215
+
216
+ if split == "eval" or split == "val":
217
+ dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
218
+
219
+ # note down the dataset weights
220
+ dataset_weights: list[float] = []
221
+ # get frame level length
222
+ dataset_lens: list[int] = []
223
+
224
+ all_feature_datasets: dict[str, DatasetType] = {}
225
+ for dataset in dataset_mix:
226
+ visual_observation_keys = get_vo_keys(dataset_name=dataset, image_views=image_views)
227
+
228
+ if feature_models is None:
229
+ feature_models = PACKED_FEATURES
230
+
231
+ with open(osp.join(dataset_root, dataset, "splits.json"), "r") as splitf:
232
+ dataset_len = json.load(splitf)[split]
233
+ # if the length is 0, skip
234
+ # this may happen for small datasets with very few shards
235
+ if dataset_len == 0:
236
+ continue
237
+
238
+ for vo_key in visual_observation_keys:
239
+ for model_name in feature_models:
240
+ if model_name not in PACKED_FEATURES:
241
+ feature_set_name = model_name
242
+ path_pattern = osp.join(
243
+ dataset_root, dataset, vo_key + f"_{model_name.replace('/', '_')}", f"*-{split}*.tar"
244
+ )
245
+ rename_kw = {model_name: model_name.replace("/", "_") + ".safetensors"} # replace v by k
246
+ elif "packed" in all_feature_datasets:
247
+ continue
248
+ else:
249
+ feature_set_name = "packed"
250
+ path_pattern = osp.join(dataset_root, dataset, vo_key, f"*-{split}*.tar")
251
+ rename_kw = {
252
+ name: name.replace("/", "_") + ".safetensors" for name in PACKED_FEATURES
253
+ } # replace v by k
254
+ rename_kw["image"] = "image"
255
+
256
+ if feature_set_name not in all_feature_datasets:
257
+ all_feature_datasets[feature_set_name] = []
258
+
259
+ shard_paths = sorted(glob.glob(path_pattern))
260
+ num_shards = len(shard_paths)
261
+ if num_shards < world_size * 8:
262
+ shard_paths *= math.ceil(world_size * 8 / num_shards)
263
+ ds = (
264
+ wds.WebDataset(
265
+ shard_paths,
266
+ nodesplitter=wds.split_by_node,
267
+ workersplitter=wds.split_by_worker,
268
+ detshuffle=True,
269
+ shardshuffle=shuffle,
270
+ seed=seed,
271
+ )
272
+ .decode(partial(decode_sample, image_transform=image_transform))
273
+ .rename(keep=False, **rename_kw)
274
+ )
275
+ all_feature_datasets[feature_set_name].append(ds)
276
+
277
+ dataset_weights.append(dataset_mix[dataset])
278
+ dataset_lens.append(math.ceil(dataset_len * dataset_ratio))
279
+
280
+ normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)
281
+
282
+ combined_feature_datasets: dict[str, Dataset] = {}
283
+ for feature_set_name, fds in all_feature_datasets.items():
284
+ ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted")
285
+ combined_feature_datasets[feature_set_name] = ds
286
+
287
+ return combined_feature_datasets, sum_expected_lengths
288
+
289
+
290
+ def get_oxe_frame_dataloader(
291
+ datasets: dict[str, DatasetType], batch_size: Optional[int] = None, shuffle_buffer_size: int = 1_000, **kwargs: Any
292
+ ) -> dict[str, DataLoader]:
293
+ """Get dataloaders of OXE datasets. Corresponding to `get_oxe_frame_dataset()`.
294
+
295
+ Args:
296
+ datasets (dict[str, DatasetType]): OXE datasets from `get_oxe_frame_dataset().
297
+ batch_size (Optional[int], optional): batch size. Defaults to None.
298
+ shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000.
299
+
300
+ Returns:
301
+ dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
302
+ """
303
+ loaders = {
304
+ k: (
305
+ wds.WebLoader(datasets[k], batch_size=None, **kwargs)
306
+ .shuffle(shuffle_buffer_size) # shuffle after mix
307
+ .batched(batch_size, collation_fn=default_collate)
308
+ )
309
+ for k in datasets
310
+ }
311
+ return loaders
312
+
313
+
314
+ def get_oxe_frame_iterator(
315
+ data_loaders: dict[str, DataLoader],
316
+ ) -> Iterator[dict[str, Any]]:
317
+ """Get iterator from dataloders. Corresponding to `get_oxe_frame_dataloader()`.
318
+
319
+ Args:
320
+ data_loaders (dict[str, DataLoader]): dataloaders from `get_oxe_frame_dataloader()`.
321
+
322
+ Yields:
323
+ Iterator[dict[str, Any]]: data sample.
324
+ """
325
+ packed_loader = data_loaders.get("packed", None)
326
+ # place packed_loader at the first
327
+ if packed_loader is not None:
328
+ loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
329
+ else:
330
+ loaders = list(data_loaders.values())
331
+
332
+ # merge dicts
333
+ for data in zip(*loaders, strict=False):
334
+ # yield data
335
+ for i in range(1, len(loaders)):
336
+ for k in data[i]:
337
+ if k not in data[0]:
338
+ data[0][k] = data[i][k]
339
+ yield data[0]
340
+
341
+
342
+ def normalize_feature(
343
+ x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
344
+ ) -> torch.Tensor:
345
+ """Normalize the feature given mean and std.
346
+
347
+ Args:
348
+ x (torch.Tensor): input features
349
+ mean (Optional[torch.Tensor], optional): mean values. Defaults to None.
350
+ std (Optional[torch.Tensor], optional): std values. Defaults to None.
351
+
352
+ Returns:
353
+ torch.Tensor: feature after normalization
354
+ """
355
+ return x if mean is None or std is None else (x - mean) / std
356
+
357
+
358
+ def load_feature_stats(
359
+ dataset_root: str, feature_models: list[str]
360
+ ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
361
+ """Load feature statictics (mean and variance).
362
+
363
+ Args:
364
+ dataset_root (str): root dir of the dataset (or where to hold the statistics).
365
+ feature_models (list[str]): names of the models/features.
366
+
367
+ Returns:
368
+ tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variances. Keys are model names.
369
+ """
370
+ feature_means: dict[str, torch.Tensor] = {}
371
+ feature_vars: dict[str, torch.Tensor] = {}
372
+ for model in feature_models:
373
+ model_name = model.replace("/", "_")
374
+ feature_means[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_mean_{model_name}.npy"))).to(
375
+ torch.bfloat16
376
+ )
377
+ feature_vars[model] = torch.from_numpy(np.load(osp.join(dataset_root, f"imagenet_var_{model_name}.npy"))).to(
378
+ torch.bfloat16
379
+ )
380
+ return feature_means, feature_vars
381
+
382
+
383
+ def pad_shard_paths(shard_paths: list[str], num_shards: int, num_parts: int) -> list[str]:
384
+ """Pad shard paths to be divided by number of partitions (ranks*nodes).
385
+
386
+ Args:
387
+ shard_paths (list[str]): pathes of dataset shards.
388
+ num_shards (int): number of shards.
389
+ num_parts (int): number of partitions.
390
+
391
+ Returns:
392
+ list[str]: shard paths padded.
393
+ """
394
+ final_shard_paths = shard_paths
395
+ if num_shards % num_parts != 0:
396
+ if num_shards < num_parts - num_shards:
397
+ for _ in range(math.floor((num_parts - num_shards) / num_shards)):
398
+ final_shard_paths += shard_paths[:]
399
+ final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
400
+ else:
401
+ final_shard_paths += shard_paths[: num_parts - len(final_shard_paths)]
402
+ return final_shard_paths
403
+
404
+
405
+ def get_image_video_dataset(
406
+ dataset_root: str,
407
+ feature_models: list[str],
408
+ dataset_mix: Optional[str | dict[str, float] | list] = None,
409
+ split: str = "train",
410
+ dataset_ratio: float = 1.0,
411
+ image_transform: Optional[Callable[[Any], torch.Tensor]] = None,
412
+ feature_norm: bool = False,
413
+ seed: Optional[int | str] = 0,
414
+ shuffle: bool = False,
415
+ world_size: int = 1,
416
+ **kwargs: Any,
417
+ ) -> tuple[dict[str, DatasetType], float | Literal[0]]:
418
+ """Get image and video datasets at frame level.
419
+
420
+ Args:
421
+ dataset_root (str): root dir of the datasets.
422
+ feature_models (list[str]): models to load their features.
423
+ dataset_mix (Optional[str | dict[str, float] | list], optional): how to mix the datasets.
424
+ split (str, optional): split "train" or "val" or "test". Defaults to "train".
425
+ dataset_ratio (float, optional): how much data use for the (combined) dataset. Defaults to 1.0.
426
+ image_transform (Optional[Callable[[Any], torch.Tensor]], optional): image transform applied to samples.
427
+ Defaults to None.
428
+ feature_norm: (bool, optional): whether to normalize the feature. Defaults to False.
429
+ seed (Optional[int | str], optional): seed. Defaults to 0.
430
+ shuffle (bool, optional): shuffle or not. Defaults to False.
431
+ world_size (int, optional): world size of DDP training. Defaults to 1.
432
+ kwargs (Any): arguments to pass-through.
433
+
434
+ Returns:
435
+ tuple[dict[str, DatasetType], int]: a dict of {dataset name: dataset class}.
436
+ """
437
+ # read dataset mix from any acceptable form
438
+ if isinstance(dataset_mix, str) and dataset_mix in OXE_NAMED_MIXES:
439
+ dataset_mix = OrderedDict({k: v for k, v in OXE_NAMED_MIXES[dataset_mix]})
440
+ elif isinstance(dataset_mix, dict):
441
+ dataset_mix = OrderedDict(**dataset_mix)
442
+ elif isinstance(dataset_mix, list) or isinstance(dataset_mix, omegaconf.listconfig.ListConfig):
443
+ dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
444
+ else:
445
+ raise ValueError(f"dataset_mix of {dataset_mix}:{type(dataset_mix)} is not supported.")
446
+
447
+ if split == "eval" or split == "val":
448
+ dataset_mix = OrderedDict({d: 1.0 for d in dataset_mix})
449
+
450
+ # note down the dataset weights
451
+ dataset_weights: list[float] = []
452
+ # get frame level length
453
+ dataset_lens: list[int] = []
454
+
455
+ all_feature_datasets: dict[str, DatasetType] = {}
456
+
457
+ if feature_norm:
458
+ feature_means, feature_vars = load_feature_stats(dataset_root, feature_models)
459
+
460
+ for d in dataset_mix:
461
+
462
+ with open(osp.join(dataset_root, d, "splits.json"), "r") as splitf:
463
+ dataset_len = json.load(splitf)[split]
464
+
465
+ # if the length is 0, skip
466
+ # this may happen for small datasets with very few shards
467
+ if dataset_len == 0:
468
+ continue
469
+
470
+ path_pattern = osp.join(dataset_root, d, "images", f"*-{split}.tar")
471
+ if "image" not in all_feature_datasets:
472
+ all_feature_datasets["image"] = []
473
+ shard_paths = sorted(glob.glob(path_pattern))
474
+ num_shards = len(shard_paths)
475
+ num_parts = world_size
476
+ final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
477
+ ds = wds.WebDataset(
478
+ final_shard_paths,
479
+ nodesplitter=wds.split_by_node,
480
+ workersplitter=wds.split_by_worker,
481
+ detshuffle=True,
482
+ shardshuffle=shuffle,
483
+ seed=seed,
484
+ ).decode(partial(decode_sample, image_transform=image_transform))
485
+ all_feature_datasets["image"].append(ds)
486
+
487
+ for model_name in feature_models:
488
+ path_pattern = osp.join(dataset_root, d, f"{model_name.replace('/', '_')}", f"*-{split}.tar")
489
+ rename_kw = {model_name: model_name.replace("/", "_").lower() + ".safetensors"} # replace v by k
490
+
491
+ if model_name not in all_feature_datasets:
492
+ all_feature_datasets[model_name] = []
493
+
494
+ shard_paths = sorted(glob.glob(path_pattern))
495
+ num_shards = len(shard_paths)
496
+ num_parts = world_size
497
+ final_shard_paths = pad_shard_paths(shard_paths, num_shards, num_parts)
498
+ if feature_norm:
499
+ feature_transform = partial(
500
+ normalize_feature, mean=feature_means[model_name], std=feature_vars[model_name]
501
+ )
502
+ else:
503
+ feature_transform = None
504
+ ds = (
505
+ wds.WebDataset(
506
+ final_shard_paths,
507
+ nodesplitter=wds.split_by_node,
508
+ workersplitter=wds.split_by_worker,
509
+ detshuffle=True,
510
+ shardshuffle=shuffle,
511
+ seed=seed,
512
+ )
513
+ .decode(partial(decode_sample, image_transform=image_transform, feature_transform=feature_transform))
514
+ .rename(keep=False, **rename_kw)
515
+ )
516
+ all_feature_datasets[model_name].append(ds)
517
+
518
+ dataset_weights.append(dataset_mix[d])
519
+ dataset_lens.append(math.ceil(dataset_len * dataset_ratio))
520
+
521
+ normalized_dataset_weights, sum_expected_lengths = normalize_ds_weights_by_ds_len(dataset_weights, dataset_lens)
522
+
523
+ combined_feature_datasets: dict[str, Dataset] = {}
524
+ for feature_set_name, fds in all_feature_datasets.items():
525
+ ds = RandomMix(fds, probs=normalized_dataset_weights, stopping_strategy="all_exhausted", seed=seed)
526
+ combined_feature_datasets[feature_set_name] = ds
527
+
528
+ return combined_feature_datasets, sum_expected_lengths
529
+
530
+
531
+ def get_frame_dataloader(
532
+ datasets: dict[str, DatasetType],
533
+ batch_size: Optional[int] = None,
534
+ shuffle: bool = False,
535
+ shuffle_buffer_size: int = 1_000,
536
+ seed: Optional[int] = 0,
537
+ **kwargs: Any,
538
+ ) -> dict[str, DataLoader]:
539
+ """Get dataloaders of image and video datasets. Corresponding to `get_image_video_dataset()`.
540
+
541
+ Args:
542
+ datasets (dict[str, DatasetType]): image and video datasets from `get_image_video_dataset().
543
+ batch_size (Optional[int], optional): batch size. Defaults to None.
544
+ shuffle_buffer_size (int, optional): buffer for shuffle while streaming. Defaults to 1_000.
545
+
546
+ Returns:
547
+ dict[str, DataLoader]: dataloaders. a dict of {dataset name: dataloader}.
548
+ """
549
+ loaders = {}
550
+ for k in datasets:
551
+ loader = wds.WebLoader(datasets[k], batch_size=None, generator=default_generator, **kwargs)
552
+ if shuffle:
553
+ loader = loader.shuffle(shuffle_buffer_size, seed=seed) # shuffle after mix
554
+ loader = loader.batched(batch_size, collation_fn=default_collate)
555
+ loaders[k] = loader
556
+ return loaders
557
+
558
+
559
+ def get_frame_iterator(
560
+ data_loaders: dict[str, DataLoader],
561
+ ) -> Iterator[dict[str, Any]]:
562
+ """Get iterator from image and video dataset dataloders. Corresponding to `get_frame_dataloader()`.
563
+
564
+ Args:
565
+ data_loaders (dict[str, DataLoader]): dataloaders from `get_frame_dataloader()`.
566
+
567
+ Yields:
568
+ Iterator[dict[str, Any]]: data sample.
569
+ """
570
+ packed_loader = data_loaders.get("packed", None)
571
+ # place packed_loader at the first
572
+ if packed_loader is not None:
573
+ loaders = [packed_loader, *[data_loaders[k] for k in data_loaders if k != "packed"]]
574
+ else:
575
+ loaders = list(data_loaders.values())
576
+
577
+ # merge dicts
578
+ # this is to accommodate the old organization of datasets (each shard contains one or more columns,
579
+ # and images are duplicated columns).
580
+ # In new (current) dataset organization (columns are completely separated),
581
+ # column keys are all different except some "built-in" keys added by webdataset,
582
+ # but they are not related to any data, training, so on.
583
+ # During transit from old to new, where two organizations exist at the same time,
584
+ # this is to ignore extra "image" field in datasets loaded.
585
+ for data in zip(*loaders, strict=False):
586
+ # yield data
587
+ for i in range(1, len(loaders)):
588
+ for k in data[i]:
589
+ if k not in data[0]:
590
+ data[0][k] = data[i][k]
591
+ yield data[0]
theia/dataset/image/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from .image_common import ALL_IMAGE_DATASETS
theia/dataset/image/image_common.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from collections import OrderedDict
4
+
5
+ ALL_IMAGE_DATASETS = OrderedDict({"imagenet": {"steps": 1_281_167}})
theia/dataset/oxe/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
theia/dataset/oxe/oxe_common.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from collections import OrderedDict
4
+ from typing import Optional
5
+
6
+ """
7
+ This ALL_OXE_DATASETS below records metadata of all subsets of OXE dataset.
8
+ The datasets are in alphabetical order.
9
+
10
+ versions (list[str]): available and usable versions, sorted from older to newer.
11
+ Usually use the last one.
12
+ episodes (int): total episodes in the dataset.
13
+ steps (int): total steps in the dataset.
14
+ visual_observation_keys (list[str]): keys to specify image observations.
15
+ """
16
+ ALL_OXE_DATASETS: OrderedDict = OrderedDict(
17
+ {
18
+ "agent_aware_affordances": {
19
+ "versions": ["1.0.0"],
20
+ "episodes": 118,
21
+ "steps": 151628,
22
+ "visual_observation_keys": ["image"],
23
+ },
24
+ "asu_table_top_converted_externally_to_rlds": {
25
+ "versions": ["0.1.0"],
26
+ "episodes": 110,
27
+ "steps": 26113,
28
+ "visual_observation_keys": ["image"],
29
+ },
30
+ "austin_buds_dataset_converted_externally_to_rlds": {
31
+ "versions": ["0.1.0"],
32
+ "episodes": 50,
33
+ "steps": 34112,
34
+ "visual_observation_keys": ["image", "wrist_image"],
35
+ },
36
+ "austin_sailor_dataset_converted_externally_to_rlds": {
37
+ "versions": ["0.1.0"],
38
+ "episodes": 240,
39
+ "steps": 353094,
40
+ "visual_observation_keys": ["image", "wrist_image"],
41
+ },
42
+ "austin_sirius_dataset_converted_externally_to_rlds": {
43
+ "versions": ["0.1.0"],
44
+ "episodes": 559,
45
+ "steps": 279939,
46
+ "visual_observation_keys": ["image", "wrist_image"],
47
+ },
48
+ "bc_z": {
49
+ "versions": [
50
+ "0.1.0", # "1.0.0", "old1.0.1", and "1.0.1" are not usable
51
+ ],
52
+ "episodes": 39350,
53
+ "steps": 5471693,
54
+ "visual_observation_keys": ["image"],
55
+ },
56
+ "berkeley_autolab_ur5": {
57
+ "versions": ["0.1.0"],
58
+ "episodes": 896,
59
+ "steps": 87783,
60
+ "visual_observation_keys": ["image", "hand_image"],
61
+ },
62
+ "berkeley_cable_routing": {
63
+ "versions": ["0.1.0"],
64
+ "episodes": 1482,
65
+ "steps": 38240,
66
+ "visual_observation_keys": ["image", "top_image", "wrist225_image", "wrist45_image"],
67
+ },
68
+ "berkeley_fanuc_manipulation": {
69
+ "versions": ["0.1.0"],
70
+ "episodes": 415,
71
+ "steps": 62613,
72
+ "visual_observation_keys": ["image", "wrist_image"],
73
+ },
74
+ "berkeley_gnm_cory_hall": {
75
+ "versions": ["0.1.0"],
76
+ "episodes": 7331,
77
+ "steps": 156012,
78
+ "visual_observation_keys": ["image"],
79
+ },
80
+ "berkeley_gnm_recon": {
81
+ "versions": ["0.1.0"],
82
+ "episodes": 11834,
83
+ "steps": 610907,
84
+ "visual_observation_keys": ["image"],
85
+ },
86
+ "berkeley_gnm_sac_son": {
87
+ "versions": ["0.1.0"],
88
+ "episodes": 2955,
89
+ "steps": 241059,
90
+ "visual_observation_keys": ["image"],
91
+ },
92
+ "berkeley_mvp_converted_externally_to_rlds": {
93
+ "versions": ["0.1.0"],
94
+ "episodes": 480,
95
+ "steps": 45308,
96
+ "visual_observation_keys": ["hand_image"],
97
+ },
98
+ "berkeley_rpt_converted_externally_to_rlds": {
99
+ "versions": ["0.1.0"],
100
+ "episodes": 908,
101
+ "steps": 392578,
102
+ "visual_observation_keys": ["hand_image"],
103
+ },
104
+ "bridge": {"versions": ["0.1.0"], "episodes": 25460, "steps": 864292, "visual_observation_keys": ["image"]},
105
+ "cmu_franka_exploration_dataset_converted_externally_to_rlds": {
106
+ "versions": ["0.1.0"],
107
+ "episodes": 199,
108
+ "steps": 1990,
109
+ "visual_observation_keys": ["image"],
110
+ },
111
+ "cmu_play_fusion": {
112
+ "versions": ["0.1.0"],
113
+ "episodes": 576,
114
+ "steps": 235922,
115
+ "visual_observation_keys": ["image"],
116
+ },
117
+ "cmu_playing_with_food": { # this dataset seems to be corrupted
118
+ "versions": ["1.0.0"],
119
+ "episodes": 4200,
120
+ "steps": 83240,
121
+ "visual_observation_keys": ["image"],
122
+ },
123
+ "cmu_stretch": {"versions": ["0.1.0"], "episodes": 135, "steps": 25016, "visual_observation_keys": ["image"]},
124
+ "columbia_cairlab_pusht_real": {
125
+ "versions": ["0.1.0"],
126
+ "episodes": 122,
127
+ "steps": 24924,
128
+ "visual_observation_keys": ["image", "wrist_image"],
129
+ },
130
+ "dlr_edan_shared_control_converted_externally_to_rlds": {
131
+ "versions": ["0.1.0"],
132
+ "episodes": 104,
133
+ "steps": 8928,
134
+ "visual_observation_keys": ["image"],
135
+ },
136
+ "dlr_sara_grid_clamp_converted_externally_to_rlds": {
137
+ "versions": ["0.1.0"],
138
+ "episodes": 107,
139
+ "steps": 7622,
140
+ "visual_observation_keys": ["image"],
141
+ },
142
+ "dlr_sara_pour_converted_externally_to_rlds": {
143
+ "versions": ["0.1.0"],
144
+ "episodes": 100,
145
+ "steps": 12971,
146
+ "visual_observation_keys": ["image"],
147
+ },
148
+ "eth_agent_affordances": {
149
+ "versions": ["0.1.0"],
150
+ "episodes": 118,
151
+ "steps": 151628,
152
+ "visual_observation_keys": ["image"],
153
+ },
154
+ "fanuc_manipulation_v2": {
155
+ "versions": ["1.0.0"],
156
+ "episodes": 415,
157
+ "steps": 62613,
158
+ "visual_observation_keys": ["image", "wrist_image"],
159
+ },
160
+ "fractal20220817_data": {
161
+ "versions": ["0.1.0"],
162
+ "episodes": 87212,
163
+ "steps": 3786400,
164
+ "visual_observation_keys": ["image"],
165
+ },
166
+ "furniture_bench_dataset_converted_externally_to_rlds": {
167
+ "versions": ["0.1.0"],
168
+ "episodes": 5100,
169
+ "steps": 3948057,
170
+ "visual_observation_keys": ["image", "wrist_image"],
171
+ },
172
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
173
+ "versions": ["0.1.0"],
174
+ "episodes": 631,
175
+ "steps": 146241,
176
+ "visual_observation_keys": ["image", "wrist_image"],
177
+ },
178
+ "imperial_wrist_dataset": {
179
+ "versions": ["1.0.0"],
180
+ "episodes": 170,
181
+ "steps": 7148,
182
+ "visual_observation_keys": ["image", "wrist_image"],
183
+ },
184
+ "imperialcollege_sawyer_wrist_cam": {
185
+ "versions": ["0.1.0"],
186
+ "episodes": 170,
187
+ "steps": 7148,
188
+ "visual_observation_keys": ["image", "wrist_image"],
189
+ },
190
+ "jaco_play": {
191
+ "versions": ["0.1.0"],
192
+ "episodes": 976,
193
+ "steps": 70127,
194
+ "visual_observation_keys": ["image", "image_wrist"],
195
+ },
196
+ "kaist_nonprehensile_converted_externally_to_rlds": {
197
+ "versions": ["0.1.0"],
198
+ "episodes": 201,
199
+ "steps": 32429,
200
+ "visual_observation_keys": ["image"],
201
+ },
202
+ "kuka": {"versions": ["0.1.0"], "episodes": 580392, "steps": 8583978, "visual_observation_keys": ["image"]},
203
+ "language_table": {
204
+ "versions": ["0.0.1", "0.1.0"],
205
+ "episodes": 442226,
206
+ "steps": 7045476,
207
+ "visual_observation_keys": ["rgb"],
208
+ },
209
+ "language_table_blocktoabsolute_oracle_sim": {
210
+ "versions": ["0.0.1"],
211
+ "episodes": 200000,
212
+ "steps": 15866385,
213
+ "visual_observation_keys": ["rgb"],
214
+ },
215
+ "language_table_blocktoblock_4block_sim": {
216
+ "versions": ["0.0.1"],
217
+ "episodes": 8298,
218
+ "steps": 326768,
219
+ "visual_observation_keys": ["rgb"],
220
+ },
221
+ "language_table_blocktoblock_oracle_sim": {
222
+ "versions": ["0.0.1"],
223
+ "episodes": 200000,
224
+ "steps": 12970620,
225
+ "visual_observation_keys": ["rgb"],
226
+ },
227
+ "language_table_blocktoblock_sim": {
228
+ "versions": ["0.0.1"],
229
+ "episodes": 8000,
230
+ "steps": 351688,
231
+ "visual_observation_keys": ["rgb"],
232
+ },
233
+ "language_table_blocktoblockrelative_oracle_sim": {
234
+ "versions": ["0.0.1"],
235
+ "episodes": 200000,
236
+ "steps": 13016749,
237
+ "visual_observation_keys": ["rgb"],
238
+ },
239
+ "language_table_blocktorelative_oracle_sim": {
240
+ "versions": ["0.0.1"],
241
+ "episodes": 200000,
242
+ "steps": 8655815,
243
+ "visual_observation_keys": ["rgb"],
244
+ },
245
+ "language_table_separate_oracle_sim": {
246
+ "versions": ["0.0.1"],
247
+ "episodes": 200000,
248
+ "steps": 3196661,
249
+ "visual_observation_keys": ["rgb"],
250
+ },
251
+ "language_table_sim": {
252
+ "versions": ["0.0.1"],
253
+ "episodes": 181020,
254
+ "steps": 4665423,
255
+ "visual_observation_keys": ["rgb"],
256
+ },
257
+ "maniskill_dataset_converted_externally_to_rlds": {
258
+ "versions": ["0.1.0"],
259
+ "episodes": 30213,
260
+ "steps": 4537402,
261
+ "visual_observation_keys": ["image", "wrist_image"],
262
+ },
263
+ "mutex_dataset": {
264
+ "versions": ["1.0.0"],
265
+ "episodes": 1500,
266
+ "steps": 361883,
267
+ "visual_observation_keys": ["image", "wrist_image"],
268
+ },
269
+ "nyu_door_opening_surprising_effectiveness": {
270
+ "versions": ["0.1.0"],
271
+ "episodes": 435,
272
+ "steps": 18196,
273
+ "visual_observation_keys": ["image"],
274
+ },
275
+ "nyu_franka_play_dataset_converted_externally_to_rlds": {
276
+ "versions": ["0.1.0"],
277
+ "episodes": 365,
278
+ "steps": 34448,
279
+ "visual_observation_keys": ["image", "image_additional_view"],
280
+ },
281
+ "nyu_rot_dataset_converted_externally_to_rlds": {
282
+ "versions": ["0.1.0"],
283
+ "episodes": 14,
284
+ "steps": 440,
285
+ "visual_observation_keys": ["image"],
286
+ },
287
+ "qut_dexterous_manpulation": {
288
+ "versions": ["0.1.0"],
289
+ "episodes": 200,
290
+ "steps": 176278,
291
+ "visual_observation_keys": ["image", "wrist_image"],
292
+ },
293
+ "robo_net": {
294
+ "versions": ["0.1.0", "1.0.0"],
295
+ "episodes": 82775,
296
+ "steps": 2483250,
297
+ "visual_observation_keys": ["image", "image1", "image2"],
298
+ },
299
+ "robot_vqa": {
300
+ "versions": ["0.1.0"],
301
+ "episodes": 3331523,
302
+ "steps": 3331523,
303
+ "visual_observation_keys": ["images"],
304
+ },
305
+ "roboturk": {
306
+ "versions": ["0.1.0"],
307
+ "episodes": 1796,
308
+ "steps": 168423,
309
+ "visual_observation_keys": ["front_rgb"],
310
+ },
311
+ "stanford_hydra_dataset_converted_externally_to_rlds": {
312
+ "versions": ["0.1.0"],
313
+ "episodes": 570,
314
+ "steps": 358234,
315
+ "visual_observation_keys": ["image", "wrist_image"],
316
+ },
317
+ "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
318
+ "versions": ["0.1.0"],
319
+ "episodes": 3000,
320
+ "steps": 149985,
321
+ "visual_observation_keys": ["image"],
322
+ },
323
+ "stanford_mask_vit_converted_externally_to_rlds": {
324
+ "versions": ["0.1.0"],
325
+ "episodes": 9109,
326
+ "steps": 282379,
327
+ "visual_observation_keys": ["image"],
328
+ },
329
+ "stanford_robocook_converted_externally_to_rlds": {
330
+ "versions": ["0.1.0"],
331
+ "episodes": 2460,
332
+ "steps": 112980,
333
+ "visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
334
+ },
335
+ "taco_play": {
336
+ "versions": ["0.1.0"],
337
+ "episodes": 3242,
338
+ "steps": 213972,
339
+ "visual_observation_keys": ["rgb_static", "rgb_gripper"],
340
+ },
341
+ "tokyo_u_lsmo_converted_externally_to_rlds": {
342
+ "versions": ["0.1.0"],
343
+ "episodes": 50,
344
+ "steps": 11925,
345
+ "visual_observation_keys": ["image"],
346
+ },
347
+ "toto": {"versions": ["0.1.0"], "episodes": 902, "steps": 294139, "visual_observation_keys": ["image"]},
348
+ "ucsd_kitchen_dataset_converted_externally_to_rlds": {
349
+ "versions": ["0.1.0"],
350
+ "episodes": 150,
351
+ "steps": 3970,
352
+ "visual_observation_keys": ["image"],
353
+ },
354
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
355
+ "versions": ["0.1.0"],
356
+ "episodes": 1355,
357
+ "steps": 67750,
358
+ "visual_observation_keys": ["image"],
359
+ },
360
+ "uiuc_d3field": { # this dataset seems to be corrupted
361
+ "versions": ["0.1.0", "1.1.2"],
362
+ "episodes": 196,
363
+ "steps": 13384,
364
+ "visual_observation_keys": ["image_1", "image_2", "image_3", "image_4"],
365
+ },
366
+ "usc_cloth_sim_converted_externally_to_rlds": {
367
+ "versions": ["0.1.0"],
368
+ "episodes": 800,
369
+ "steps": 80000,
370
+ "visual_observation_keys": ["image"],
371
+ },
372
+ "utaustin_mutex": {
373
+ "versions": ["0.1.0"],
374
+ "episodes": 1500,
375
+ "steps": 361883,
376
+ "visual_observation_keys": ["image", "wrist_image"],
377
+ },
378
+ "utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
379
+ "versions": ["0.1.0"],
380
+ "episodes": 64,
381
+ "steps": 9140,
382
+ "visual_observation_keys": ["image"],
383
+ },
384
+ "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
385
+ "versions": ["0.1.0"],
386
+ "episodes": 192,
387
+ "steps": 26346,
388
+ "visual_observation_keys": ["image"],
389
+ },
390
+ "utokyo_saytap_converted_externally_to_rlds": {
391
+ "versions": ["0.1.0"],
392
+ "episodes": 20,
393
+ "steps": 22937,
394
+ "visual_observation_keys": ["image", "wrist_image"],
395
+ },
396
+ "utokyo_xarm_bimanual_converted_externally_to_rlds": {
397
+ "versions": ["0.1.0"],
398
+ "episodes": 64,
399
+ "steps": 1388,
400
+ "visual_observation_keys": ["image"],
401
+ },
402
+ "utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
403
+ "versions": ["0.1.0"],
404
+ "episodes": 92,
405
+ "steps": 6789,
406
+ "visual_observation_keys": ["image", "hand_image", "image2"],
407
+ },
408
+ "viola": {
409
+ "versions": ["0.1.0"],
410
+ "episodes": 135,
411
+ "steps": 68913,
412
+ "visual_observation_keys": ["agentview_rgb", "eye_in_hand_rgb"],
413
+ },
414
+ }
415
+ )
416
+
417
+
418
+ def oxe_dsname2path(dataset_name: str, version: Optional[str] = None) -> str:
419
+ """From dataset name to remote google clound path to the dataset.
420
+
421
+ Args:
422
+ dataset_name (str): dataset name.
423
+ version (Optional[str]): version string.
424
+
425
+ Returns:
426
+ str: google clound path
427
+ """
428
+ if version is None:
429
+ version = ALL_OXE_DATASETS[dataset_name]["versions"][-1]
430
+ return f"gs://gresearch/robotics/{dataset_name}/{version}"
theia/dataset/oxe/oxe_mixes.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File modified. Modifications Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ """MIT License Copyright (c) 2023 Robotic AI & Learning Lab Berkeley
4
+
5
+ From Octo https://github.com/octo-models/octo/blob/main/octo/data/oxe/oxe_dataset_mixes.py
6
+ """
7
+
8
+ BRIDGE_MIX = [
9
+ ("bridge_dataset", 1.0),
10
+ ]
11
+
12
+ RT_X_MIX = [
13
+ ("fractal20220817_data", 0.54087122203),
14
+ ("kuka", 0.8341046294),
15
+ ("bridge_dataset", 1.0),
16
+ ("taco_play", 2.0),
17
+ ("jaco_play", 2.0),
18
+ ("berkeley_cable_routing", 3.0),
19
+ ("roboturk", 1.0),
20
+ ("nyu_door_opening_surprising_effectiveness", 5.0),
21
+ ("viola", 2.0),
22
+ ("berkeley_autolab_ur5", 1.0),
23
+ ("toto", 1.0),
24
+ ]
25
+
26
+
27
+ OXE_FRANKA_MIX = [
28
+ ("taco_play", 1.0),
29
+ ("berkeley_cable_routing", 1.0),
30
+ ("viola", 1.0),
31
+ ("toto", 1.0),
32
+ ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
33
+ ("austin_buds_dataset_converted_externally_to_rlds", 3.0),
34
+ ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
35
+ ("maniskill_dataset_converted_externally_to_rlds", 0.1),
36
+ ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
37
+ ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0),
38
+ ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
39
+ ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
40
+ ("berkeley_rpt_converted_externally_to_rlds", 1.0),
41
+ ("kaist_nonprehensile_converted_externally_to_rlds", 3.0),
42
+ ("stanford_robocook_converted_externally_to_rlds", 1.0),
43
+ ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
44
+ ("utaustin_mutex", 1.0),
45
+ # ("cmu_playing_with_food", 1.0),
46
+ ("cmu_play_fusion", 1.0),
47
+ ]
48
+
49
+ OXE_MAGIC_SOUP = [
50
+ ("fractal20220817_data", 0.54087122203),
51
+ ("kuka", 0.8341046294),
52
+ ("bridge", 1.0),
53
+ ("taco_play", 2.0),
54
+ ("jaco_play", 1.0),
55
+ ("berkeley_cable_routing", 1.0),
56
+ ("roboturk", 2.0),
57
+ ("nyu_door_opening_surprising_effectiveness", 1.0),
58
+ ("viola", 2.0),
59
+ ("berkeley_autolab_ur5", 2.0),
60
+ ("toto", 1.0),
61
+ ("language_table", 0.1),
62
+ ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0),
63
+ ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
64
+ ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0),
65
+ ("furniture_bench_dataset_converted_externally_to_rlds", 0.1),
66
+ ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0),
67
+ ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
68
+ ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
69
+ ("bc_z", 0.2),
70
+ ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
71
+ ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
72
+ # ("uiuc_d3field", 1.0), --> somehow raw data is broken
73
+ ("utaustin_mutex", 1.0),
74
+ ("berkeley_fanuc_manipulation", 2.0),
75
+ ("cmu_stretch", 1.0),
76
+ ]
77
+
78
+
79
+ OXE_FULL_MIX = [
80
+ ("fractal20220817_data", 1.0),
81
+ ("kuka", 1.0),
82
+ ("bridge_dataset", 1),
83
+ ("taco_play", 1.0),
84
+ ("jaco_play", 1.0),
85
+ ("berkeley_cable_routing", 1.0),
86
+ ("roboturk", 1.0),
87
+ ("nyu_door_opening_surprising_effectiveness", 1.0),
88
+ ("viola", 1.0),
89
+ ("berkeley_autolab_ur5", 1.0),
90
+ ("toto", 1.0),
91
+ ("language_table", 1.0),
92
+ ("columbia_cairlab_pusht_real", 1.0),
93
+ ("stanford_kuka_multimodal_dataset_converted_externally_to_rlds", 1.0),
94
+ ("nyu_rot_dataset_converted_externally_to_rlds", 1.0),
95
+ ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0),
96
+ ("austin_buds_dataset_converted_externally_to_rlds", 1.0),
97
+ ("nyu_franka_play_dataset_converted_externally_to_rlds", 1.0),
98
+ ("maniskill_dataset_converted_externally_to_rlds", 1.0),
99
+ ("furniture_bench_dataset_converted_externally_to_rlds", 1.0),
100
+ ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 1.0),
101
+ ("ucsd_kitchen_dataset_converted_externally_to_rlds", 1.0),
102
+ ("ucsd_pick_and_place_dataset_converted_externally_to_rlds", 1.0),
103
+ ("austin_sailor_dataset_converted_externally_to_rlds", 1.0),
104
+ ("austin_sirius_dataset_converted_externally_to_rlds", 1.0),
105
+ ("bc_z", 1.0),
106
+ ("utokyo_pr2_opening_fridge_converted_externally_to_rlds", 1.0),
107
+ ("utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds", 1.0),
108
+ ("utokyo_xarm_pick_and_place_converted_externally_to_rlds", 1.0),
109
+ ("utokyo_xarm_bimanual_converted_externally_to_rlds", 1.0),
110
+ ("robo_net", 1.0),
111
+ ("berkeley_mvp_converted_externally_to_rlds", 1.0),
112
+ ("berkeley_rpt_converted_externally_to_rlds", 1.0),
113
+ ("kaist_nonprehensile_converted_externally_to_rlds", 1.0),
114
+ ("stanford_mask_vit_converted_externally_to_rlds", 1.0),
115
+ ("tokyo_u_lsmo_converted_externally_to_rlds", 1.0),
116
+ ("dlr_sara_pour_converted_externally_to_rlds", 1.0),
117
+ ("dlr_sara_grid_clamp_converted_externally_to_rlds", 1.0),
118
+ ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0),
119
+ ("asu_table_top_converted_externally_to_rlds", 1.0),
120
+ ("stanford_robocook_converted_externally_to_rlds", 1.0),
121
+ ("imperialcollege_sawyer_wrist_cam", 1.0),
122
+ ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0),
123
+ ("uiuc_d3field", 1.0),
124
+ ("utaustin_mutex", 1.0),
125
+ ("berkeley_fanuc_manipulation", 1.0),
126
+ ("cmu_playing_with_food", 1.0),
127
+ ("cmu_play_fusion", 1.0),
128
+ ("cmu_stretch", 1.0),
129
+ ("berkeley_gnm_recon", 1.0),
130
+ ("berkeley_gnm_cory_hall", 1.0),
131
+ ("berkeley_gnm_sac_son", 1.0),
132
+ ]
133
+
134
+ OXE_NAMED_MIXES = {
135
+ "bridge": BRIDGE_MIX,
136
+ "rtx": RT_X_MIX,
137
+ "rtx_franka": RT_X_MIX + OXE_FRANKA_MIX,
138
+ "oxe_magic_soup": OXE_MAGIC_SOUP,
139
+ }
theia/dataset/oxe/oxe_transforms.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ import torch
4
+ from numpy.typing import NDArray
5
+ from torchvision.transforms.v2 import Compose, Normalize, ToDtype, ToImage
6
+
7
+
8
+ def totensor(arr: NDArray) -> torch.Tensor:
9
+ """Convert ndarray to tensor."""
10
+ return torch.from_numpy(arr)
11
+
12
+
13
+ oxe_image_transform = Compose(
14
+ [ToImage(), ToDtype(torch.float32, scale=True), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
15
+ ) # ImageNet statistics normalization
theia/dataset/video/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from .video_common import ALL_VIDEO_DATASETS
theia/dataset/video/video_common.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from collections import OrderedDict
4
+
5
+ ALL_VIDEO_DATASETS = OrderedDict(
6
+ {
7
+ "ego4d_1in150": {"steps": 2_800_871},
8
+ "epic_kitchen_1in60": {"steps": 333_117},
9
+ "ssv2_1in32": {"steps": 312_772},
10
+ }
11
+ )
theia/decoding/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from .decode import decode_everything, load_feature_stats
4
+ from .depth_anything import prepare_depth_decoder
5
+ from .sam import prepare_mask_generator
theia/decoding/decode.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ import os
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange
10
+ from numpy.typing import NDArray
11
+ from PIL import Image
12
+ from sklearn.decomposition import PCA
13
+ from transformers import SamModel, SamProcessor
14
+ from transformers.pipelines import MaskGenerationPipeline
15
+
16
+ from theia.decoding.depth_anything import decode_depth_anything
17
+ from theia.decoding.dinov2 import decode_dinov2
18
+ from theia.decoding.sam import decode_sam
19
+ from theia.preprocessing.feature_extraction_core import (
20
+ get_feature_outputs,
21
+ get_model,
22
+ )
23
+
24
+
25
+ def denormalize_feature(
26
+ x: torch.Tensor, mean: Optional[torch.Tensor] = None, std: Optional[torch.Tensor] = None
27
+ ) -> torch.Tensor:
28
+ """Denormalize the features using mean and std.
29
+
30
+ Args:
31
+ x (torch.Tensor): features to be denomalized.
32
+ mean (Optional[torch.Tensor], optional): mean value of the features. Defaults to None
33
+ std (Optional[torch.Tensor], optional): std value of the features. Defaults to None.
34
+
35
+ Returns:
36
+ torch.Tensor: denormalized features.
37
+ """
38
+ if mean is None and std is None:
39
+ return x
40
+ elif mean is None and std is not None:
41
+ return x * std
42
+ elif mean is not None and std is None:
43
+ return x + mean
44
+ return x * std + mean
45
+
46
+
47
+ def load_feature_stats(
48
+ feature_models: list[str], stat_file_root: str
49
+ ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
50
+ """Load the statistics (mean and variance) of the features, per model.
51
+
52
+ Args:
53
+ feature_models (list[str]): names of the models. Note: there are `/` in the name.
54
+ stat_file_root (str): directory that holds feature stat files.
55
+
56
+ Returns:
57
+ tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: means and variance.
58
+ """
59
+ feature_means: dict[str, torch.Tensor] = {}
60
+ feature_vars: dict[str, torch.Tensor] = {}
61
+ for model in feature_models:
62
+ model_name = model.replace("/", "_")
63
+ feature_means[model] = torch.from_numpy(
64
+ np.load(os.path.join(stat_file_root, f"imagenet_mean_{model_name}.npy"))
65
+ )
66
+ feature_vars[model] = torch.from_numpy(np.load(os.path.join(stat_file_root, f"imagenet_var_{model_name}.npy")))
67
+ return feature_means, feature_vars
68
+
69
+
70
+ def decode_everything(
71
+ theia_model: nn.Module,
72
+ feature_means: dict[str, torch.Tensor],
73
+ feature_vars: dict[str, torch.Tensor],
74
+ images: list[Image.Image],
75
+ mask_generator: MaskGenerationPipeline,
76
+ sam_model: SamModel,
77
+ depth_anything_decoder: nn.Module,
78
+ pred_iou_thresh: float = 0.9,
79
+ stability_score_thresh: float = 0.9,
80
+ gt: bool = False,
81
+ pca: Optional[PCA] = None,
82
+ device: int | str | torch.device = 0,
83
+ ) -> tuple[list[NDArray], Optional[list[NDArray]]]:
84
+ """Decode features from given `theia_model` into different outputs corresponding to upstream models including
85
+ DINOv2, Sam, and Depth-Anything.
86
+
87
+ Args:
88
+ theia_model (nn.Module): theia model.
89
+ feature_means (dict[str, torch.Tensor]): means of the features for denormalization.
90
+ feature_vars (dict[str, torch.Tensor]): variance of the features for denormalization.
91
+ images (list[Image.Image]): input images.
92
+ mask_generator (MaskGenerationPipeline): mask generation pipeline.
93
+ sam_model (SamModel): sam model.
94
+ depth_anything_decoder (nn.Module): depth anything decoder.
95
+ pred_iou_thresh (float, optional): iou threshold for mask generation.
96
+ See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
97
+ stability_score_thresh (float, optional): stability score threshold for mask generation.
98
+ See transformers.pipelines.MaskGenerationPipeline for more details. Defaults to 0.9.
99
+ gt (bool): whether to attach ground truth result in the visualization. Defaults to False.
100
+ pca (Optional[PCA]): pca for DINOv2 decoding. If provided, will use this pca particular. Defaults to None.
101
+ device (int | str | torch.device, optional): device for decoding. Defaults to 0.
102
+
103
+ Returns:
104
+ tuple[list[NDArray], Optional[list[NDArray]]]: decoding results from given model,
105
+ and ground truth (if `gt=True`).
106
+ """
107
+ features: dict[str, torch.Tensor] = {}
108
+ with torch.no_grad():
109
+ for im in images:
110
+ feature = theia_model([im])
111
+ if len(features) == 0:
112
+ features = {k: [] for k in feature}
113
+ for k in feature:
114
+ features[k].append(feature[k].detach().cpu())
115
+ for k in features:
116
+ features[k] = torch.cat(features[k], dim=0)
117
+ for m in features:
118
+ features[m] = denormalize_feature(features[m], feature_means[m], feature_vars[m])
119
+
120
+ dino_model_name = "facebook/dinov2-large"
121
+ sam_model_name = "facebook/sam-vit-huge"
122
+ depth_anything_model_name = "LiheYoung/depth-anything-large-hf"
123
+
124
+ pca = None
125
+ # gt
126
+ gt_decode_results = None
127
+ if gt:
128
+ def legit_model_name(model_name: str) -> str:
129
+ return model_name.replace("/", "_")
130
+
131
+ dino_model, dino_processor = get_model(dino_model_name, device=device)
132
+ dino_gt_feature = []
133
+ for im in images:
134
+ dino_gt_feature.append(
135
+ get_feature_outputs(
136
+ legit_model_name(dino_model_name), dino_model, dino_processor, [im], dtype=torch.float
137
+ )[legit_model_name(dino_model_name)]["embedding"]
138
+ .detach()
139
+ .cpu()
140
+ )
141
+ dino_gt_feature = torch.cat(dino_gt_feature, dim=0)
142
+ dino_gt_feature = rearrange(dino_gt_feature, "b c h w -> b (h w) c")
143
+ dino_gt_dec, pca = decode_dinov2(dino_gt_feature, pca=pca)
144
+ sam_processor = SamProcessor.from_pretrained(sam_model_name)
145
+ sam_gt_feature = []
146
+ for im in images:
147
+ sam_inputs = sam_processor(images=[im], return_tensors="pt").to(device)
148
+ with torch.no_grad():
149
+ sam_gt_feature.append(sam_model.get_image_embeddings(sam_inputs["pixel_values"]).detach().cpu())
150
+ sam_gt_feature = torch.cat(sam_gt_feature, dim=0)
151
+ sam_gt_feature = rearrange(sam_gt_feature, "b c h w -> b (h w) c")
152
+ sam_gt_dec = decode_sam(
153
+ sam_gt_feature, images, mask_generator, pred_iou_thresh=0.9, stability_score_thresh=0.9, device=device
154
+ )
155
+ depth_anything_model, depth_anything_processor = get_model(depth_anything_model_name, device=device)
156
+ depth_anything_gt_feature = []
157
+ for im in images:
158
+ depth_anything_gt_feature.append(
159
+ get_feature_outputs(
160
+ legit_model_name(depth_anything_model_name),
161
+ depth_anything_model,
162
+ depth_anything_processor,
163
+ [im],
164
+ dtype=torch.float,
165
+ )[legit_model_name(depth_anything_model_name)]["embedding"]
166
+ .detach()
167
+ .cpu()
168
+ )
169
+ depth_anything_gt_feature = torch.cat(depth_anything_gt_feature, dim=0)
170
+ depth_anything_gt_feature = rearrange(depth_anything_gt_feature, "b c h w -> b (h w) c")
171
+ depth_gt_dec = decode_depth_anything(depth_anything_gt_feature, depth_anything_decoder, device=device)
172
+
173
+ gt_decode_results = [
174
+ np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_gt_dec[i], sam_gt_dec[i], depth_gt_dec[i]])
175
+ for i in range(len(images))
176
+ ]
177
+
178
+ dino_dec, _ = decode_dinov2(features[dino_model_name], pca=pca)
179
+
180
+ try:
181
+ sam_dec = decode_sam(
182
+ features[sam_model_name],
183
+ images,
184
+ mask_generator,
185
+ pred_iou_thresh=pred_iou_thresh,
186
+ stability_score_thresh=stability_score_thresh,
187
+ device=device,
188
+ )
189
+ except IndexError:
190
+ sam_dec = np.zeros_like(dino_dec)
191
+ depth_dec = decode_depth_anything(features[depth_anything_model_name], depth_anything_decoder, device=device)
192
+
193
+ theia_decode_results = [
194
+ np.hstack([np.array(images[i]).astype(np.float32) / 255.0, dino_dec[i], sam_dec[i], depth_dec[i]])
195
+ for i in range(len(images))
196
+ ]
197
+
198
+ return theia_decode_results, gt_decode_results
theia/decoding/depth_anything.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from theia.foundation_models.vision_models.depth_anything import DepthAnythingForDepthEstimation
7
+ from numpy.typing import NDArray
8
+ from torch.nn.functional import interpolate
9
+
10
+
11
+ def prepare_depth_decoder(model_name: str, device: int | str | torch.device = 0) -> tuple[nn.Module, int]:
12
+ """Prepare a depth decoder using DepthAnythingForDepthEstimation.
13
+
14
+ Args:
15
+ model_name (str): name of the depth anything model.
16
+ device (int | str | torch.device, optional): device to put the model on. Defaults to 0.
17
+
18
+ Returns:
19
+ tuple[nn.Module, int]: the decoder, and the patch size for depth anything model.
20
+ """
21
+ decoder_head = DepthAnythingForDepthEstimation.from_pretrained(model_name)
22
+ patch_size = decoder_head.config.patch_size
23
+ decoder_head = decoder_head.head
24
+ decoder_head = decoder_head.to(device)
25
+ return decoder_head, patch_size
26
+
27
+
28
+ def decode_depth_anything(features: torch.Tensor, decoder: nn.Module, device: int | str | torch.device = 0) -> NDArray:
29
+ """Decode features to predicted depth using depth anything
30
+
31
+ Args:
32
+ features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
33
+ decoder (nn.Module): depth anything decoder
34
+ device (int | str | torch.device, optional): device to perform the decoding. Defaults to 0.
35
+
36
+ Returns:
37
+ NDArray: decoded depth in image format, represented by an NDArray in size [batch_size, height, width, channels]
38
+ with value between [0, 1]. The depth values are min-max normalized to [0, 1] to generate images.
39
+ """
40
+ with torch.no_grad():
41
+ P = int(features.size(1) ** 0.5)
42
+ features = rearrange(features, "b (h w) c -> b c h w", h=P, w=P)
43
+ features = interpolate(features, (224, 224))
44
+ predicted_depths = []
45
+ for feature in features:
46
+ feature = feature.unsqueeze(0).to(device)
47
+
48
+ predicted_depth = decoder.activation1(feature)
49
+ predicted_depth = decoder.conv3(predicted_depth)
50
+ predicted_depth = decoder.activation2(predicted_depth)
51
+ predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width)
52
+ for i in range(len(predicted_depth)):
53
+ min_depth, max_depth = predicted_depth[i].min(), predicted_depth[i].max()
54
+ predicted_depth[i] = (predicted_depth[i] - min_depth) / (max_depth - min_depth)
55
+ predicted_depths.append(predicted_depth.detach().cpu())
56
+ predicted_depths = torch.cat(predicted_depths, dim=0)
57
+ return predicted_depths.unsqueeze(-1).repeat((1, 1, 1, 3)).numpy() # type: ignore [attr-defined]
theia/decoding/dinov2.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from typing import Optional
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from numpy.typing import NDArray
8
+ from sklearn.decomposition import PCA
9
+ from sklearn.preprocessing import minmax_scale
10
+
11
+
12
+ def decode_dinov2(
13
+ features: NDArray, threshold: int | float = -100, interpolation: bool = False, pca: Optional[PCA] = None
14
+ ) -> tuple[NDArray, PCA]:
15
+ """
16
+ Decode the input `features` in DINOv2 style using PCA.
17
+
18
+ Args:
19
+ features (NDArray): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
20
+ threshold (int | float): threshold of foreground-background split in PCA visualization.
21
+ Defaults to -100 (all patches are included).
22
+ interpolation (bool): whether interpolate the 16x16 pca map to the original image size.
23
+ pca (Optional[PCA]): if provided, use the provided PCA. This is to keep visualizations stable across samples.
24
+
25
+ Returns:
26
+ tuple[NDArray, PCA]: the rendered image of this visualization, in NDArray in size
27
+ [batch_size, height, width, channels] with value ranges [0, 1], and the PCA used in this visualization.
28
+ """
29
+ features = features.numpy()
30
+ batch_size, spatial_size, latent_dim = features.shape
31
+ h = w = int(spatial_size**0.5)
32
+
33
+ features = features.reshape(-1, latent_dim)
34
+
35
+ if pca is None:
36
+ pca = PCA(n_components=3)
37
+ pca.fit(features)
38
+
39
+ pca_features = pca.transform(features)
40
+
41
+ # segment using the first component
42
+ bg_mask = pca_features[:, 0] < threshold
43
+ fg_mask = ~bg_mask
44
+
45
+ # PCA for only foreground patches
46
+ # pca.fit(features[fg_mask])
47
+ pca_features_fg = pca.transform(features[fg_mask])
48
+ for i in range(3):
49
+ pca_features_fg[:, i] = minmax_scale(pca_features_fg[:, i])
50
+
51
+ pca_features_rgb = pca_features.copy()
52
+ pca_features_rgb[bg_mask] = 0
53
+ pca_features_rgb[fg_mask] = pca_features_fg
54
+
55
+ pca_features_rgb = pca_features_rgb.reshape(batch_size, h, w, 3)
56
+ if not interpolation:
57
+ H = W = 224
58
+ scale = H // h
59
+ interpolated_pca_features = np.zeros((batch_size, H, W, 3), dtype=pca_features_rgb.dtype)
60
+ for i in range(len(pca_features_rgb)):
61
+ for j in range(h):
62
+ for k in range(w):
63
+ interpolated_pca_features[i, scale * j : scale * (j + 1), scale * k : scale * (k + 1)] = (
64
+ pca_features_rgb[i, j, k]
65
+ )
66
+ pca_features_rgb = interpolated_pca_features
67
+ else:
68
+ pca_features_rgb = np.stack([cv2.resize(p, (224, 224)) for p in pca_features_rgb])
69
+ return pca_features_rgb, pca
theia/decoding/sam.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from typing import Any, Generator, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from numpy.typing import NDArray
9
+ from PIL import Image
10
+ from transformers import SamModel, SamProcessor
11
+ from transformers.image_utils import load_image
12
+ from transformers.pipelines import MaskGenerationPipeline
13
+
14
+
15
+ class MaskGenerationPipelineWithEmbeddings(MaskGenerationPipeline):
16
+ """
17
+ The wrapper class for huggingface transformers.pipelines.MaskGenerationPipeline
18
+ that can decode from intermediate SAM embeddings.
19
+ """
20
+
21
+ def _sanitize_parameters(self, **kwargs: Any) -> tuple[dict[str, Any], ...]:
22
+ preprocess_kwargs = {}
23
+ postprocess_kwargs = {}
24
+ forward_params = {}
25
+ # preprocess args
26
+ if "embeddings" in kwargs: # inject embeddings here
27
+ preprocess_kwargs["embeddings"] = kwargs["embeddings"]
28
+ if "points_per_batch" in kwargs:
29
+ preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
30
+ if "points_per_crop" in kwargs:
31
+ preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
32
+ if "crops_n_layers" in kwargs:
33
+ preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
34
+ if "crop_overlap_ratio" in kwargs:
35
+ preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
36
+ if "crop_n_points_downscale_factor" in kwargs:
37
+ preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
38
+ if "timeout" in kwargs:
39
+ preprocess_kwargs["timeout"] = kwargs["timeout"]
40
+ # postprocess args
41
+ if "pred_iou_thresh" in kwargs:
42
+ forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
43
+ if "stability_score_offset" in kwargs:
44
+ forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
45
+ if "mask_threshold" in kwargs:
46
+ forward_params["mask_threshold"] = kwargs["mask_threshold"]
47
+ if "stability_score_thresh" in kwargs:
48
+ forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
49
+ if "crops_nms_thresh" in kwargs:
50
+ postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
51
+ if "output_rle_mask" in kwargs:
52
+ postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
53
+ if "output_bboxes_mask" in kwargs:
54
+ postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
55
+ return preprocess_kwargs, forward_params, postprocess_kwargs
56
+
57
+ def preprocess(
58
+ self,
59
+ image: list[Image.Image],
60
+ points_per_batch: int = 64,
61
+ crops_n_layers: int = 0,
62
+ crop_overlap_ratio: float = 512 / 1500,
63
+ points_per_crop: int = 32,
64
+ crop_n_points_downscale_factor: int = 1,
65
+ timeout: Optional[float] = None,
66
+ embeddings: Optional[torch.Tensor] = None,
67
+ ) -> Generator[Any, Any, Any]:
68
+ image = load_image(image, timeout=timeout)
69
+ target_size = self.image_processor.size["longest_edge"]
70
+ crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
71
+ image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
72
+ )
73
+ model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")
74
+
75
+ with self.device_placement():
76
+ if self.framework == "pt":
77
+ inference_context = self.get_inference_context()
78
+ with inference_context():
79
+ model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
80
+ if embeddings is None:
81
+ image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
82
+ else:
83
+ model_inputs.pop("pixel_values")
84
+ image_embeddings = embeddings
85
+ model_inputs["image_embeddings"] = image_embeddings
86
+
87
+ n_points = grid_points.shape[1]
88
+ points_per_batch = points_per_batch if points_per_batch is not None else n_points
89
+
90
+ if points_per_batch <= 0:
91
+ raise ValueError(
92
+ "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
93
+ "To return all points at once, set points_per_batch to None"
94
+ )
95
+
96
+ for i in range(0, n_points, points_per_batch):
97
+ batched_points = grid_points[:, i : i + points_per_batch, :, :]
98
+ labels = input_labels[:, i : i + points_per_batch]
99
+ is_last = i == n_points - points_per_batch
100
+ yield {
101
+ "input_points": batched_points,
102
+ "input_labels": labels,
103
+ "input_boxes": crop_boxes,
104
+ "is_last": is_last,
105
+ **model_inputs,
106
+ }
107
+
108
+
109
+ def draw_mask(mask: NDArray, random_color: bool = False) -> NDArray:
110
+ """Draw the mask on an image.
111
+
112
+ Args:
113
+ mask (NDArray): mask in shape [height, width].
114
+ random_color (bool): if using a random color. Defaults to False.
115
+
116
+ Returns:
117
+ NDArray: NDArray format of the image.
118
+ """
119
+ if random_color:
120
+ color = np.concatenate([np.random.random(3)], axis=0)
121
+ else:
122
+ color = np.array([30 / 255, 144 / 255, 255 / 255])
123
+ h, w = mask.shape[-2:]
124
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
125
+ return mask_image
126
+
127
+
128
+ def decode_sam(
129
+ features: torch.Tensor,
130
+ images: list[Image.Image],
131
+ mask_generator: Any,
132
+ points_per_batch: int = 64,
133
+ pred_iou_thresh: float = 0.5,
134
+ stability_score_thresh: float = 0.6,
135
+ random_color: bool = True,
136
+ device: int | str | torch.device = 0,
137
+ ) -> NDArray:
138
+ """Decode features using SAM (auto-prompting) mask generation pipeline.
139
+
140
+ Args:
141
+ features (torch.Tensor): features to be decoded, should be in shape [batch_size, num_tokens, latent_dim].
142
+ images (list[Image.Image]): images corresponding to these features.
143
+ mask_generator (Any): mask generation pipeline.
144
+ points_per_batch (int): points per batch for auto-prompting. Defaults to 64.
145
+ See transformers.pipelines.MaskGenerationPipeline for more details. Same below.
146
+ pred_iou_thresh (float): iou threshold. Defaults to 0.5.
147
+ stability_score_thresh (float): stability threshold. Defaults to 0.6.
148
+ random_color (bool): if using a random color. Defaults to True.
149
+ device (int | str | torch.device): device to perform the decoding. Defaults to 0.
150
+
151
+ Returns:
152
+ NDArray: decoded masks rendered in image format, represented by an NDArray in size
153
+ [batch_size, height, width, channels] with value between [0, 1].
154
+ """
155
+ masks_rgbs = []
156
+ num_patches = int(features.size(1) ** 0.5)
157
+ features = rearrange(features, "b (h w) c -> b c h w", h=num_patches, w=num_patches)
158
+ with torch.no_grad():
159
+ for im, feature in zip(images, features, strict=False):
160
+ predicted_ouputs = mask_generator(
161
+ im,
162
+ points_per_batch=points_per_batch,
163
+ embeddings=feature.unsqueeze(0).to(device),
164
+ pred_iou_thresh=pred_iou_thresh,
165
+ stability_score_thresh=stability_score_thresh,
166
+ )
167
+ predicted_masks = predicted_ouputs["masks"]
168
+ masks_rgb = np.zeros((224, 224, 3), dtype=np.float32)
169
+ for mask in predicted_masks:
170
+ masks_rgb += draw_mask(mask, random_color=random_color)
171
+ # masks_rgb = cv2.cvtColor(masks_rgb, cv2.COLOR_RGBA2RGB)
172
+ masks_rgbs.append(masks_rgb)
173
+ return np.stack(masks_rgbs)
174
+
175
+
176
+ def prepare_mask_generator(device: int | str | torch.device = 0) -> MaskGenerationPipeline:
177
+ """Prepare a mask generation pipeline on device `device`.
178
+
179
+ Args:
180
+ device (int | str | torch.device): device to perform mask generation. Defaults to 0.
181
+
182
+ Returns:
183
+ MaskGenerationPipeline: mask generator.
184
+ """
185
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
186
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
187
+ sam_model.eval()
188
+ mask_generator = MaskGenerationPipelineWithEmbeddings(
189
+ task="mask_generation", model=sam_model, image_processor=processor.image_processor, device=device
190
+ )
191
+ return mask_generator, sam_model
theia/example/decode_to_vfms.ipynb ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import cv2\n",
11
+ "import torch\n",
12
+ "from PIL import Image\n",
13
+ "import numpy as np\n",
14
+ "from transformers import AutoModel\n",
15
+ "from torchvision.io import read_video, write_video\n",
16
+ "from theia.decoding import load_feature_stats, prepare_depth_decoder, prepare_mask_generator, decode_everything\n",
17
+ "\n",
18
+ "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
19
+ "theia_model = AutoModel.from_pretrained(\"theaiinstitute/theia-base-patch16-224-cdiv\", trust_remote_code=True)\n",
20
+ "theia_model = theia_model.to(device)\n",
21
+ "target_model_names = [\n",
22
+ " \"google/vit-huge-patch14-224-in21k\",\n",
23
+ " \"facebook/dinov2-large\",\n",
24
+ " \"openai/clip-vit-large-patch14\",\n",
25
+ " \"facebook/sam-vit-huge\",\n",
26
+ " \"LiheYoung/depth-anything-large-hf\",\n",
27
+ "]\n",
28
+ "feature_means, feature_vars = load_feature_stats(target_model_names, stat_file_root=\"../../../feature_stats\")\n",
29
+ "\n",
30
+ "mask_generator, sam_model = prepare_mask_generator(device)\n",
31
+ "depth_anything_model_name = \"LiheYoung/depth-anything-large-hf\"\n",
32
+ "depth_anything_decoder, _ = prepare_depth_decoder(depth_anything_model_name, device)\n",
33
+ "\n",
34
+ "example_video_path = \"../../../media/example_video_to_visualize.mp4\"\n",
35
+ "video, _, _ = read_video(example_video_path, pts_unit=\"sec\", output_format=\"THWC\")\n",
36
+ "video = video.numpy()\n",
37
+ "images = [Image.fromarray(cv2.resize(im, (224, 224))) for im in video]\n",
38
+ "\n",
39
+ "theia_decode_results, gt_decode_results = decode_everything(\n",
40
+ " theia_model=theia_model,\n",
41
+ " feature_means=feature_means,\n",
42
+ " feature_vars=feature_vars,\n",
43
+ " images=images,\n",
44
+ " mask_generator=mask_generator,\n",
45
+ " sam_model=sam_model,\n",
46
+ " depth_anything_decoder=depth_anything_decoder,\n",
47
+ " pred_iou_thresh=0.5,\n",
48
+ " stability_score_thresh=0.7,\n",
49
+ " gt=True,\n",
50
+ " device=device,\n",
51
+ ")\n",
52
+ "\n",
53
+ "vis_video = np.stack(\n",
54
+ " [np.vstack([tr, gtr]) for tr, gtr in zip(theia_decode_results, gt_decode_results, strict=False)]\n",
55
+ ")\n",
56
+ "vis_video = torch.from_numpy(vis_video * 255.0).to(torch.uint8)\n",
57
+ "vis_save_path = \"./visualized.mp4\"\n",
58
+ "write_video(vis_save_path, vis_video, fps=10)"
59
+ ]
60
+ }
61
+ ],
62
+ "metadata": {
63
+ "language_info": {
64
+ "name": "python"
65
+ }
66
+ },
67
+ "nbformat": 4,
68
+ "nbformat_minor": 2
69
+ }
theia/foundation_models/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ from .vision_language_models.clip import get_clip_feature, get_clip_model
4
+ from .vision_language_models.llava import get_llava_vision_model, get_llava_visual_feature
5
+ from .vision_models.deit import get_deit_feature, get_deit_model
6
+ from .vision_models.depth_anything import get_depth_anything_feature, get_depth_anything_model
7
+ from .vision_models.dinov2 import get_dinov2_feature, get_dinov2_model
8
+ from .vision_models.sam import get_sam_feature, get_sam_model
9
+ from .vision_models.vit import get_vit_feature, get_vit_model
theia/foundation_models/common.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
2
+
3
+ import math
4
+
5
+ import torch
6
+
7
+ MODELS = [
8
+ "facebook/dinov2-large",
9
+ "facebook/sam-vit-huge",
10
+ "google/vit-huge-patch14-224-in21k",
11
+ "llava-hf/llava-1.5-7b-hf",
12
+ "openai/clip-vit-large-patch14",
13
+ "LiheYoung/depth-anything-large-hf",
14
+ ]
15
+
16
+ # handy model feature size constants
17
+ # in the format of (latent_dim, width, height)
18
+ MODEL_FEATURE_SIZES = {
19
+ "facebook/dinov2-large": (1024, 16, 16),
20
+ "facebook/sam-vit-huge": (256, 64, 64),
21
+ "google/vit-huge-patch14-224-in21k": (1280, 16, 16),
22
+ "llava-hf/llava-1.5-7b-hf": (1024, 24, 24),
23
+ "openai/clip-vit-large-patch14": (1024, 16, 16),
24
+ "LiheYoung/depth-anything-large-hf": (32, 64, 64),
25
+ }
26
+
27
+
28
+ def get_model_feature_size(
29
+ model_name: str, keep_spatial: bool = False, return_torch_size: bool = False
30
+ ) -> tuple[int, ...] | torch.Size:
31
+ """
32
+ Get the size of queried model feature.
33
+
34
+ Args:
35
+ model_name (str): name of the model.
36
+ keep_spatial (bool): whether to preserve spatial dim. Defaults to False.
37
+ return_torch_size (bool): return torch.Size instead of python tuple. Defaults to False.
38
+
39
+ Returns:
40
+ tuple[int, ...] | torch.Size: the size of the feature.
41
+ """
42
+ size: tuple[int, ...] = MODEL_FEATURE_SIZES[model_name]
43
+
44
+ if not keep_spatial:
45
+ size = (size[0], math.prod(size[1:]))
46
+
47
+ if return_torch_size:
48
+ size = torch.Size(size)
49
+
50
+ return size
51
+
52
+
53
+ def get_max_model_spatial_size(
54
+ keep_spatial: bool = True,
55
+ return_torch_size: bool = False,
56
+ return_model_name: bool = False,
57
+ ) -> tuple[int, ...] | tuple[tuple[int, ...], str]:
58
+ """Get the maximal spatial dimensions from available models
59
+
60
+ Args:
61
+ keep_spatial (bool): whether to preserve spatial dim. Defaults to True.
62
+ return_torch_size (bool): return torch.Size instead of python tuple. Defaults to False.
63
+ return_model_name (bool): the name of the model with maximal size. Defaults to False.
64
+
65
+ Returns:
66
+ tuple[int, ...] | tuple[tuple[int, ...], str]: the maximal size and optional model name.
67
+ """
68
+ max_flatten_size = -1
69
+ max_size: tuple[int, ...] = ()
70
+ max_size_model_name: str = ""
71
+ for model, size in MODEL_FEATURE_SIZES.items():
72
+ flatten_size = math.prod(size[1:])
73
+ if flatten_size > max_flatten_size:
74
+ max_flatten_size = flatten_size
75
+ max_size = size[1:]
76
+ max_size_model_name = model
77
+
78
+ if not keep_spatial:
79
+ max_size = (max_flatten_size,)
80
+
81
+ if return_torch_size:
82
+ max_size = torch.Size(max_size)
83
+
84
+ if return_model_name:
85
+ return max_size, max_size_model_name
86
+ else:
87
+ return max_size