RoyYang0714 committed
Commit 9b33fca · 1 Parent(s): 41b3aa4

feat: Try to build everything locally.

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +81 -0
  2. README.md +1 -0
  3. opendet3d/__init__.py +1 -0
  4. opendet3d/common/__init__.py +0 -0
  5. opendet3d/common/parallel.py +76 -0
  6. opendet3d/data/__init__.py +0 -0
  7. opendet3d/data/datasets/__init__.py +0 -0
  8. opendet3d/data/datasets/argoverse.py +94 -0
  9. opendet3d/data/datasets/coco3d.py +518 -0
  10. opendet3d/data/datasets/odvg.py +280 -0
  11. opendet3d/data/datasets/omni3d/__init__.py +1 -0
  12. opendet3d/data/datasets/omni3d/arkitscenes.py +81 -0
  13. opendet3d/data/datasets/omni3d/hypersim.py +190 -0
  14. opendet3d/data/datasets/omni3d/kitti_object.py +105 -0
  15. opendet3d/data/datasets/omni3d/nuscenes.py +62 -0
  16. opendet3d/data/datasets/omni3d/objectron.py +56 -0
  17. opendet3d/data/datasets/omni3d/omni3d_classes.py +156 -0
  18. opendet3d/data/datasets/omni3d/sunrgbd.py +278 -0
  19. opendet3d/data/datasets/omni3d/util.py +74 -0
  20. opendet3d/data/datasets/scannet.py +449 -0
  21. opendet3d/data/transforms/__init__.py +0 -0
  22. opendet3d/data/transforms/crop.py +43 -0
  23. opendet3d/data/transforms/language.py +267 -0
  24. opendet3d/data/transforms/pad.py +176 -0
  25. opendet3d/data/transforms/resize.py +121 -0
  26. opendet3d/eval/__init__.py +0 -0
  27. opendet3d/eval/detect3d.py +1249 -0
  28. opendet3d/eval/omni3d.py +285 -0
  29. opendet3d/eval/open.py +140 -0
  30. opendet3d/model/__init__.py +0 -0
  31. opendet3d/model/detect/__init__.py +0 -0
  32. opendet3d/model/detect/grounding_dino.py +1050 -0
  33. opendet3d/model/detect3d/__init__.py +0 -0
  34. opendet3d/model/detect3d/grounding_dino_3d.py +812 -0
  35. opendet3d/model/language/__init__.py +0 -0
  36. opendet3d/model/language/mm_bert.py +255 -0
  37. opendet3d/op/__init__.py +0 -0
  38. opendet3d/op/base/__init__.py +0 -0
  39. opendet3d/op/base/swin.py +870 -0
  40. opendet3d/op/box/__init__.py +0 -0
  41. opendet3d/op/box/box2d.py +272 -0
  42. opendet3d/op/box/box3d.py +79 -0
  43. opendet3d/op/box/iou_box3d.py +174 -0
  44. opendet3d/op/box/matchers/__init__.py +0 -0
  45. opendet3d/op/box/matchers/hungarian.py +117 -0
  46. opendet3d/op/detect/__init__.py +0 -0
  47. opendet3d/op/detect/deformable_detr.py +463 -0
  48. opendet3d/op/detect/detr.py +358 -0
  49. opendet3d/op/detect/dino.py +667 -0
  50. opendet3d/op/detect/grounding_dino/__init__.py +17 -0
.gitignore ADDED
@@ -0,0 +1,81 @@
+ # no IntelliJ files
+ .idea
+
+ # don't upload macOS folder info
+ *.DS_Store
+
+ # don't upload node_modules from npm test
+ node_modules/*
+ flow-typed/*
+
+ # potential files generated by golang
+ bin/
+
+ # don't upload webpack bundle file
+ app/dist/
+
+ # potential integration testing data directory
+ # test_data/
+ /data
+
+ # python
+ *.pyc
+ __pycache__/
+
+ # pytype
+ .pytype
+
+ # vscode sftp settings
+ .vscode/sftp.json
+
+ # vscode launch settings
+ .vscode/launch.json
+
+ # redis
+ *.rdb
+
+ # mypy
+ .mypy_cache
+
+ # jest coverage cache
+ coverage/
+
+ # downloaded repos and models
+ scalabel/bot/experimental/*
+
+ # python virtual environment
+ env/
+
+ # vscode workspace configuration
+ *.code-workspace
+
+ # sphinx build folder
+ _build/
+
+ # media files are not in this repo
+ doc/media
+
+ # ignore rope db cache
+ .vscode/.ropeproject
+
+ # python build
+ build/
+ dist/
+
+ # coverage
+ .coverage*
+
+ # package default workspace
+ vis4d-workspace
+
+ *.tmp
+
+ # local test logs and scripts
+ log/
+ /*.sh
+ docs/source/api/*
+ docs/source/tutorials/.ipynb_checkpoints/*
+ wandb/
+
+ # No lightning logs
+ lightning_logs/
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: yellow
  sdk: gradio
  sdk_version: 5.44.0
  app_file: app.py
+ app_build_command: cd vis4d_cuda_ops && pip install -v .
  pinned: false
  ---
opendet3d/__init__.py ADDED
@@ -0,0 +1 @@
+ """3D-MOOD."""
opendet3d/common/__init__.py ADDED
File without changes
opendet3d/common/parallel.py ADDED
@@ -0,0 +1,76 @@
+ """Utilities for CPU parallelization."""
+
+ from __future__ import annotations
+
+ import os
+ from multiprocessing import Process, Queue
+
+ # Disabling unused import because we need Tuple in typing
+ from typing import Callable, Iterable, List, Optional, Tuple, TypeVar
+
+ from tqdm import tqdm
+
+ Inputs = TypeVar("Inputs")
+ Return = TypeVar("Return")
+
+ cpu_num = os.cpu_count()
+ NPROC: int = min(4, cpu_num if cpu_num else 1)
+
+
+ def run(
+     func: Callable[[Inputs], Return],
+     q_in: "Queue[Tuple[int, Optional[Tuple[Inputs]]]]",
+     q_out: "Queue[Tuple[int, Return]]",
+ ) -> None:
+     """Run the function on inputs pulled from the queue."""
+     while True:
+         i, x = q_in.get()
+         if i < 0 or x is None:
+             break
+         q_out.put((i, func(x[0])))
+
+
+ def pmap(
+     func: Callable[[Inputs], Return],
+     inputs: Iterable[Inputs],
+     max_len: int,
+     nprocs: int = NPROC,
+ ) -> List[Return]:
+     """Parallel mapping of func over the inputs.
+
+     Different from the Python pool map, this function will not hang if any of
+     the processes throws an exception.
+     """
+     q_in: "Queue[Tuple[int, Optional[Tuple[Inputs]]]]" = Queue(1)
+     q_out: "Queue[Tuple[int, Return]]" = Queue()
+
+     proc = [
+         Process(target=run, args=(func, q_in, q_out)) for _ in range(nprocs)
+     ]
+     for p in proc:
+         p.daemon = True
+         p.start()
+
+     count = 0
+     with tqdm(total=max_len) as pbar:
+         for i, x in enumerate(inputs):
+             q_in.put((i, (x,)))
+             count += 1
+
+             if count % nprocs == 0:
+                 pbar.update()
+
+                 pbar.refresh()
+
+         pbar.update()
+         pbar.refresh()
+
+     for _ in range(nprocs):
+         q_in.put((-1, None))
+
+     res = [q_out.get() for _ in range(count)]
+
+     for p in proc:
+         p.join()
+
+     return [x for _, x in sorted(res)]
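
A minimal usage sketch for pmap, assuming the module path above; the worker function and inputs are illustrative, not part of the commit:

    from opendet3d.common.parallel import pmap

    def square(x: int) -> int:
        """Toy per-item workload; must be a picklable top-level function."""
        return x * x

    if __name__ == "__main__":  # required on spawn-based platforms
        results = pmap(square, range(8), max_len=8, nprocs=2)
        assert results == [0, 1, 4, 9, 16, 25, 36, 49]  # order restored by index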
opendet3d/data/__init__.py ADDED
File without changes
opendet3d/data/datasets/__init__.py ADDED
File without changes
opendet3d/data/datasets/argoverse.py ADDED
@@ -0,0 +1,94 @@
+ """Argoverse V2 Sensor dataset."""
+
+ from __future__ import annotations
+
+ from vis4d.common.typing import ArgsType, DictStrAny
+
+ from .coco3d import COCO3DDataset
+
+ TRAIN_SAMPLE_RATE = 10
+ VAL_SAMPLE_RATE = 5
+ ACC_FRAMES = 5
+
+
+ av2_class_map = {
+     "regular vehicle": 0,
+     "pedestrian": 1,
+     "bicyclist": 2,
+     "motorcyclist": 3,
+     "wheeled rider": 4,
+     "bollard": 5,
+     "construction cone": 6,
+     "sign": 7,
+     "construction barrel": 8,
+     "stop sign": 9,
+     "mobile pedestrian crossing sign": 10,
+     "large vehicle": 11,
+     "bus": 12,
+     "box truck": 13,
+     "truck": 14,
+     "vehicular trailer": 15,
+     "truck cab": 16,
+     "school bus": 17,
+     "articulated bus": 18,
+     "message board trailer": 19,
+     "bicycle": 20,
+     "motorcycle": 21,
+     "wheeled device": 22,
+     "wheelchair": 23,
+     "stroller": 24,
+     "dog": 25,
+ }
+
+ av2_det_map = {
+     "regular vehicle": 0,
+     "pedestrian": 1,
+     "bicyclist": 2,
+     "motorcyclist": 3,
+     "wheeled rider": 4,
+     "bollard": 5,
+     "construction cone": 6,
+     "sign": 7,
+     "construction barrel": 8,
+     "stop sign": 9,
+     "mobile pedestrian crossing sign": 10,
+     "large vehicle": 11,
+     "bus": 12,
+     "box truck": 13,
+     "truck": 14,
+     "vehicular trailer": 15,
+     "truck cab": 16,
+     "school bus": 17,
+     "articulated bus": 18,
+     "bicycle": 19,
+     "motorcycle": 20,
+     "wheeled device": 21,
+     "stroller": 22,
+ }
+
+
+ class AV2SensorDataset(COCO3DDataset):
+     """Argoverse V2 Sensor dataset."""
+
+     def __init__(
+         self,
+         class_map: dict[str, int] = av2_class_map,
+         max_depth: float = 80.0,
+         depth_scale: float = 256.0,
+         **kwargs: ArgsType,
+     ) -> None:
+         """Creates an instance of the class."""
+         super().__init__(
+             class_map=class_map,
+             max_depth=max_depth,
+             depth_scale=depth_scale,
+             **kwargs,
+         )
+
+     def get_depth_filenames(self, img: DictStrAny) -> str | None:
+         """Get the depth filename derived from the image path."""
+         return (
+             img["file_path"]
+             .replace("images", "depth")
+             .replace(".jpg", "_depth.png")
+         )
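
The depth file location is derived purely by string substitution on the image path. A hypothetical example of the mapping (the path below is made up):

    img = {"file_path": "data/av2/images/val/cam0/0001.jpg"}
    depth = (
        img["file_path"].replace("images", "depth").replace(".jpg", "_depth.png")
    )
    assert depth == "data/av2/depth/val/cam0/0001_depth.png"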
opendet3d/data/datasets/coco3d.py ADDED
@@ -0,0 +1,518 @@
+ """COCO 3D API."""
+
+ from __future__ import annotations
+
+ import contextlib
+ import io
+ import json
+ import os
+ import time
+ from collections import defaultdict
+ from collections.abc import Sequence
+
+ import numpy as np
+ from pycocotools.coco import COCO
+ from pyquaternion import Quaternion
+ from scipy.spatial.transform import Rotation as R
+ from vis4d.common.logging import rank_zero_info, rank_zero_warn
+ from vis4d.common.typing import ArgsType, DictStrAny
+ from vis4d.data.const import AxisMode
+ from vis4d.data.const import CommonKeys as K
+ from vis4d.data.datasets.base import Dataset
+ from vis4d.data.datasets.util import (
+     CacheMappingMixin,
+     im_decode,
+     print_class_histogram,
+ )
+ from vis4d.data.typing import DictData
+
+
+ class COCO3DDataset(CacheMappingMixin, Dataset):
+     """3D object detection dataset using COCO annotation files."""
+
+     def __init__(
+         self,
+         data_root: str,
+         dataset_name: str,
+         class_map: dict[str, int],
+         det_map: dict[str, int],
+         keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d),
+         with_depth: bool = False,
+         max_depth: float = 80.0,
+         depth_scale: float = 256.0,
+         remove_empty: bool = False,
+         data_prefix: str | None = None,
+         text_prompt_mapping: dict[str, dict[str, str]] | None = None,
+         cache_as_binary: bool = False,
+         cached_file_path: str | None = None,
+         **kwargs: ArgsType,
+     ) -> None:
+         """Creates an instance of the class."""
+         super().__init__(**kwargs)
+         self.data_root = data_root
+         self.dataset_name = dataset_name
+         self.annotation_file = f"{dataset_name}.json"
+
+         self.keys_to_load = list(keys_to_load)
+         self.remove_empty = remove_empty
+
+         self.class_map = class_map  # Class mapping in the annotation file
+         self.det_map = det_map  # Class mapping for detection
+         self.categories = sorted(self.det_map, key=self.det_map.get)
+
+         self.data_prefix = data_prefix
+         self.text_prompt_mapping = text_prompt_mapping
+
+         # Metric depth
+         if with_depth and K.depth_maps not in keys_to_load:
+             self.keys_to_load.append(K.depth_maps)
+
+         self.max_depth = max_depth
+         self.depth_scale = depth_scale
+
+         # Load annotations
+         self.samples, _ = self._load_mapping(
+             self._generate_data_mapping,
+             self._filter_data,
+             cache_as_binary=cache_as_binary,
+             cached_file_path=cached_file_path,
+         )
+
+     def __repr__(self) -> str:
+         """Concise representation of the dataset."""
+         return self.dataset_name
+
+     def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
+         """Remove empty samples."""
+         samples = []
+
+         frequencies = {cat: 0 for cat in sorted(self.det_map)}
+
+         empty_samples = 0
+         no_depth_samples = 0
+         for sample in data:
+             if self.remove_empty and len(sample["anns"]) == 0:
+                 empty_samples += 1
+                 continue
+
+             if (
+                 K.depth_maps in self.keys_to_load
+                 and "depth_filename" not in sample
+             ):
+                 empty_samples += 1
+                 no_depth_samples += 1
+                 continue
+
+             for ann in sample["anns"]:
+                 frequencies[ann["category_name"]] += 1
+
+             samples.append(sample)
+
+         rank_zero_info(
+             f"Preprocessing {self.dataset_name} with {len(samples)} samples."
+         )
+         rank_zero_info(f"No depth samples: {no_depth_samples}")
+         rank_zero_info(f"Filtered {empty_samples} empty samples")
+         print_class_histogram(frequencies)
+
+         return samples
+
+     def _get_cat_id(
+         self, img: DictStrAny, ann: DictStrAny, cat_name: str
+     ) -> None:
+         """Set the category id from the category name."""
+         ann["category_id"] = self.det_map[cat_name]
+
+     def _generate_data_mapping(self) -> list[DictStrAny]:
+         """Generates the data mapping."""
+         # Load annotations
+         with contextlib.redirect_stdout(io.StringIO()):
+             coco_api = COCO3D(
+                 os.path.join(
+                     self.data_root, "annotations", self.annotation_file
+                 ),
+                 self.categories,
+             )
+
+         cats_map = {v: k for k, v in self.class_map.items()}
+
+         img_ids = sorted(coco_api.getImgIds())
+         imgs = coco_api.loadImgs(img_ids)
+
+         samples = []
+         for img_id, img in zip(img_ids, imgs):
+             # Fix file path for Omni3D
+             if self.data_prefix is not None:
+                 img["file_path"] = os.path.join(
+                     self.data_prefix, img["file_path"]
+                 )
+
+             valid_anns = []
+             anns = coco_api.imgToAnns[img_id]
+
+             boxes = []
+             boxes3d = np.empty((0, 10), dtype=np.float32)
+             class_ids = np.empty((0,), dtype=np.int64)
+             for ann in anns:
+                 cat_name = cats_map[ann["category_id"]]
+                 assert cat_name == ann["category_name"]
+
+                 if cat_name in {"dontcare", "ignore", "void"}:
+                     continue
+
+                 if ann["ignore"]:
+                     continue
+
+                 self._get_cat_id(img, ann, cat_name)
+
+                 # Box 2D
+                 x1, y1, width, height = ann["bbox"]
+                 x2, y2 = x1 + width, y1 + height
+                 boxes.append((x1, y1, x2, y2))
+
+                 # Class
+                 class_ids = np.concatenate(
+                     [
+                         class_ids,
+                         np.array([ann["category_id"]], dtype=np.int64),
+                     ]
+                 )
+
+                 # Box 3D
+                 center = ann["center_cam"]
+                 width, height, length = ann["dimensions"]
+
+                 # Check if the rotation matrix is valid
+                 try:
+                     x, y, z, w = R.from_matrix(
+                         np.array(ann["R_cam"])
+                     ).as_quat()
+                 except Exception as e:
+                     rank_zero_warn(
+                         f"Error processing rotation matrix for annotation {ann['id']}: {e}"
+                     )
+                     continue
+
+                 orientation = Quaternion([w, x, y, z])
+
+                 boxes3d = np.concatenate(
+                     [
+                         boxes3d,
+                         np.array(
+                             [
+                                 [
+                                     *center,
+                                     width,
+                                     length,
+                                     height,
+                                     *orientation.elements,
+                                 ]
+                             ],
+                             dtype=np.float32,
+                         ),
+                     ]
+                 )
+
+                 valid_anns.append(ann)
+
+             boxes2d = (
+                 np.empty((0, 4), dtype=np.float32)
+                 if not boxes
+                 else np.array(boxes, dtype=np.float32)
+             )
+
+             depth_filename = self.get_depth_filenames(img)
+
+             sample = {
+                 "img_id": img_id,
+                 "img": img,
+                 "anns": valid_anns,
+                 "boxes2d": boxes2d,
+                 "boxes3d": boxes3d,
+                 "class_ids": class_ids,
+             }
+
+             if depth_filename is not None and self.data_backend.exists(
+                 depth_filename
+             ):
+                 sample["depth_filename"] = depth_filename
+
+             samples.append(sample)
+
+         return samples
+
+     def get_depth_filenames(self, img: DictStrAny) -> str | None:
+         """Get the depth filename, or None, since not every dataset has depth."""
+         return None
+
+     def get_cat_ids(self, idx: int) -> list[int]:
+         """Return the category ids of the sample at the given index."""
+         return self.samples[idx]["class_ids"].tolist()
+
+     def __len__(self) -> int:
+         """Total number of samples of data."""
+         return len(self.samples)
+
+     def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
+         """Get the depth map."""
+         depth_bytes = self.data_backend.get(sample["depth_filename"])
+         depth_array = im_decode(depth_bytes)
+
+         depth = np.ascontiguousarray(depth_array, dtype=np.float32)
+
+         depth = depth / self.depth_scale
+
+         return depth
+
+     def __getitem__(self, idx: int) -> DictData:
+         """Get single sample.
+
+         Args:
+             idx (int): Index of sample.
+
+         Returns:
+             DictData: sample at index in Vis4D input format.
+         """
+         sample = self.samples[idx]
+         data_dict: DictData = {}
+
+         # Get image info
+         data_dict[K.sample_names] = sample["img_id"]
+
+         data_dict["dataset_name"] = self.dataset_name
+         data_dict[K.boxes2d_names] = self.categories
+         data_dict["text_prompt_mapping"] = self.text_prompt_mapping
+
+         if K.images in self.keys_to_load:
+             im_bytes = self.data_backend.get(sample["img"]["file_path"])
+             image = np.ascontiguousarray(
+                 im_decode(im_bytes, mode=self.image_channel_mode),
+                 dtype=np.float32,
+             )[None]
+
+             data_dict[K.images] = image
+             data_dict[K.input_hw] = (image.shape[1], image.shape[2])
+
+             data_dict[K.original_images] = image
+             data_dict[K.original_hw] = (image.shape[1], image.shape[2])
+
+         # Get camera info
+         intrinsics = np.array(sample["img"]["K"], dtype=np.float32)
+         data_dict[K.intrinsics] = intrinsics
+         data_dict["original_intrinsics"] = intrinsics
+
+         data_dict[K.boxes2d] = sample["boxes2d"]
+         data_dict[K.boxes2d_classes] = sample["class_ids"]
+         data_dict[K.boxes3d] = sample["boxes3d"]
+         data_dict[K.boxes3d_classes] = sample["class_ids"]
+         data_dict[K.axis_mode] = AxisMode.OPENCV
+
+         if K.depth_maps in self.keys_to_load:
+             depth = self.get_depth_map(sample)
+
+             depth[depth > self.max_depth] = 0
+
+             data_dict[K.depth_maps] = depth
+
+         data_dict["tokens_positive"] = None
+
+         self.data_backend.close()
+
+         return data_dict
+
+
+ class COCO3D(COCO):
+     """COCO API with 3D annotations."""
+
+     def __init__(
+         self,
+         annotation_files: Sequence[str] | str,
+         category_names: Sequence[str] | None = None,
+         ignore_names: Sequence[str] = ("dontcare", "ignore", "void"),
+         truncation_thres: float = 0.33333333,
+         visibility_thres: float = 0.33333333,
+         min_height_thres: float = 0.0625,
+         max_height_thres: float = 1.50,
+         modal_2D_boxes: bool = False,
+         trunc_2D_boxes: bool = True,
+         max_depth: float = 1e8,
+     ) -> None:
+         """Creates an instance of the class."""
+         self.dataset, self.anns, self.cats, self.imgs = {}, {}, {}, {}
+         self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
+
+         self.truncation_thres = truncation_thres
+         self.visibility_thres = visibility_thres
+         self.min_height_thres = min_height_thres
+         self.max_height_thres = max_height_thres
+         self.max_depth = max_depth
+
+         if isinstance(annotation_files, str):
+             annotation_files = [annotation_files]
+
+         cats_ids_master = []
+         cats_master = []
+
+         for annotation_file in annotation_files:
+             _, tail = os.path.split(annotation_file)
+             name, _ = os.path.splitext(tail)
+
+             print(f"loading {name} annotations into memory...")
+             tic = time.time()
+
+             with open(annotation_file, "r") as f:
+                 dataset = json.load(f)
+
+             assert isinstance(
+                 dataset, dict
+             ), f"annotation file format {type(dataset)} not supported"
+             print(f"Done (t={time.time() - tic:.2f}s)")
+
+             if isinstance(dataset["info"], list):
+                 dataset["info"] = dataset["info"][0]
+
+             dataset["info"]["known_category_ids"] = [
+                 cat["id"] for cat in dataset["categories"]
+             ]
+
+             # first dataset
+             if len(self.dataset) == 0:
+                 self.dataset = dataset
+             # concatenate datasets
+             else:
+                 if isinstance(self.dataset["info"], dict):
+                     self.dataset["info"] = [self.dataset["info"]]
+
+                 self.dataset["info"] += [dataset["info"]]
+                 self.dataset["annotations"] += dataset["annotations"]
+                 self.dataset["images"] += dataset["images"]
+
+             # sort through categories
+             for cat in dataset["categories"]:
+                 if not cat["id"] in cats_ids_master:
+                     cats_ids_master.append(cat["id"])
+                     cats_master.append(cat)
+
+         # category names are provided to us
+         if category_names is not None:
+             self.dataset["categories"] = [
+                 cats_master[i]
+                 for i in np.argsort(cats_ids_master)
+                 if cats_master[i]["name"] in category_names
+             ]
+         # no categories are provided, so assume use of ALL available.
+         else:
+             self.dataset["categories"] = [
+                 cats_master[i] for i in np.argsort(cats_ids_master)
+             ]
+
+             category_names = [
+                 cat["name"] for cat in self.dataset["categories"]
+             ]
+
+         # determine which categories we may actually use for filtering.
+         trainable_cats = set(ignore_names) | set(category_names)
+
+         valid_anns = []
+         im_height_map = {}
+
+         for im_obj in self.dataset["images"]:
+             im_height_map[im_obj["id"]] = im_obj["height"]
+
+         # Filter out annotations
+         for anno_idx, anno in enumerate(self.dataset["annotations"]):
+             im_height = im_height_map[anno["image_id"]]
+
+             # tightly annotated 2D boxes are not always available.
+             if (
+                 modal_2D_boxes
+                 and "bbox2D_tight" in anno
+                 and anno["bbox2D_tight"][0] != -1
+             ):
+                 bbox2D = anno["bbox2D_tight"]
+             elif (
+                 trunc_2D_boxes
+                 and "bbox2D_trunc" in anno
+                 and not np.all([val == -1 for val in anno["bbox2D_trunc"]])
+             ):
+                 bbox2D = anno["bbox2D_trunc"]
+             elif anno["bbox2D_proj"][0] != -1:
+                 bbox2D = anno["bbox2D_proj"]
+             elif anno["bbox2D_tight"][0] != -1:
+                 bbox2D = anno["bbox2D_tight"]
+             else:
+                 continue
+
+             # convert to xywh
+             bbox2D[2] = bbox2D[2] - bbox2D[0]
+             bbox2D[3] = bbox2D[3] - bbox2D[1]
+
+             ignore = self.is_ignore(anno, bbox2D, ignore_names, im_height)
+
+             width = bbox2D[2]
+             height = bbox2D[3]
+
+             self.dataset["annotations"][anno_idx]["area"] = width * height
+             self.dataset["annotations"][anno_idx]["iscrowd"] = False
+             self.dataset["annotations"][anno_idx]["ignore"] = ignore
+             self.dataset["annotations"][anno_idx]["ignore2D"] = ignore
+             self.dataset["annotations"][anno_idx]["ignore3D"] = ignore
+
+             self.dataset["annotations"][anno_idx]["bbox"] = bbox2D
+             self.dataset["annotations"][anno_idx]["bbox3D"] = anno[
+                 "bbox3D_cam"
+             ]
+             self.dataset["annotations"][anno_idx]["depth"] = anno[
+                 "center_cam"
+             ][2]
+
+             category_name = anno["category_name"]
+
+             if category_name in trainable_cats:
+                 valid_anns.append(self.dataset["annotations"][anno_idx])
+
+         self.dataset["annotations"] = valid_anns
+
+         self.createIndex()
+
+     def is_ignore(
+         self,
+         anno: DictStrAny,
+         bbox2D: list[float],
+         ignore_names: Sequence[str] | None,
+         image_height: int,
+     ) -> bool:
+         """Check whether an annotation should be ignored."""
+         ignore = anno["behind_camera"]
+         ignore |= not bool(anno["valid3D"])
+
+         if ignore:
+             return ignore
+
+         ignore |= anno["dimensions"][0] <= 0
+         ignore |= anno["dimensions"][1] <= 0
+         ignore |= anno["dimensions"][2] <= 0
+         ignore |= anno["center_cam"][2] > self.max_depth
+         ignore |= anno["lidar_pts"] == 0
+         ignore |= anno["segmentation_pts"] == 0
+         ignore |= anno["depth_error"] > 0.5
+
+         ignore |= bbox2D[3] <= self.min_height_thres * image_height
+         ignore |= bbox2D[3] >= self.max_height_thres * image_height
+
+         ignore |= (
+             anno["truncation"] >= 0
+             and anno["truncation"] >= self.truncation_thres
+         )
+         ignore |= (
+             anno["visibility"] >= 0
+             and anno["visibility"] <= self.visibility_thres
+         )
+
+         if ignore_names is not None:
+             ignore |= anno["category_name"] in ignore_names
+
+         return ignore
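
One detail worth calling out in _generate_data_mapping above: SciPy's as_quat() returns quaternions scalar-last as (x, y, z, w), while pyquaternion's constructor expects scalar-first (w, x, y, z), hence the reordering. A small sanity check of that convention:

    import numpy as np
    from pyquaternion import Quaternion
    from scipy.spatial.transform import Rotation as R

    x, y, z, w = R.from_matrix(np.eye(3)).as_quat()  # scalar-last
    q = Quaternion([w, x, y, z])  # scalar-first
    assert np.allclose(q.rotation_matrix, np.eye(3))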
opendet3d/data/datasets/odvg.py ADDED
@@ -0,0 +1,280 @@
+ """Object detection and visual grounding dataset."""
+
+ from __future__ import annotations
+
+ import json
+ import os.path as osp
+
+ import numpy as np
+ from tqdm import tqdm
+ from vis4d.common.logging import rank_zero_info
+ from vis4d.common.typing import ArgsType, DictStrAny
+ from vis4d.data.const import CommonKeys as K
+ from vis4d.data.datasets.base import Dataset
+ from vis4d.data.datasets.util import (
+     CacheMappingMixin,
+     im_decode,
+     print_class_histogram,
+ )
+ from vis4d.data.typing import DictData
+
+
+ class ODVGDataset(CacheMappingMixin, Dataset):
+     """Object detection and visual grounding dataset."""
+
+     def __init__(
+         self,
+         data_root: str,
+         ann_file: str,
+         label_map_file: str | None = None,
+         dataset_type: str = "VG",
+         dataset_prefix: str | None = None,
+         with_camera: bool = False,  # NOTE: added; _generate_data_mapping references it
+         remove_empty: bool = False,
+         cache_as_binary: bool = False,
+         cached_file_path: str | None = None,
+         **kwargs: ArgsType,
+     ) -> None:
+         """Create an object detection and visual grounding dataset."""
+         super().__init__(**kwargs)
+
+         self.data_root = data_root
+         self.ann_file = ann_file
+         self.dataset_type = dataset_type
+         self.dataset_prefix = dataset_prefix
+         self.with_camera = with_camera
+         self.remove_empty = remove_empty
+
+         if label_map_file is not None:
+             label_map_file = osp.join(self.data_root, label_map_file)
+
+             with open(label_map_file, "r") as file:
+                 # dict[class_id (str): class_name (str)]
+                 self.label_map = json.load(file)
+
+             self.dataset_type = "OD"
+
+             self.det_map = {v: int(k) for k, v in self.label_map.items()}
+             self.categories = sorted(self.det_map, key=self.det_map.get)
+         else:
+             self.label_map = None
+             self.dataset_type = "VG"
+
+         # Load annotations
+         self.samples, _ = self._load_mapping(
+             self._generate_data_mapping,
+             self._filter_data,
+             cache_as_binary=cache_as_binary,
+             cached_file_path=cached_file_path,
+         )
+
+     def __repr__(self) -> str:
+         """Concise representation of the dataset."""
+         return f"ODVGDataset({self.ann_file})"
+
+     def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
+         """Remove empty samples."""
+         samples = []
+
+         if self.dataset_type == "OD":
+             frequencies = {cat: 0 for _, cat in self.label_map.items()}
+
+         empty_samples = 0
+         for sample in data:
+             if self.remove_empty and len(sample["anns"]) == 0:
+                 empty_samples += 1
+                 continue
+
+             if self.dataset_type == "OD":
+                 for ann in sample["anns"]:
+                     frequencies[ann["category"]] += 1
+
+             samples.append(sample)
+
+         rank_zero_info(f"Preprocessing {self} with {len(samples)} samples.")
+         rank_zero_info(f"Filtered {empty_samples} empty samples")
+
+         if self.dataset_type == "OD":
+             frequencies = dict(sorted(frequencies.items()))
+
+             print_class_histogram(frequencies)
+
+         return samples
+
+     def _generate_data_mapping(self) -> list[DictStrAny]:
+         """Generates the data mapping."""
+         with open(osp.join(self.data_root, self.ann_file), "r") as f:
+             data_list = [json.loads(line) for line in f]
+
+         if self.with_camera:
+             with open(osp.join(self.data_root, "cam_info.json"), "r") as f:
+                 cameras = json.load(f)
+
+         samples = []
+         for data in tqdm(data_list):
+             data_info = {}
+
+             if self.dataset_prefix is not None:
+                 img_path = osp.join(
+                     self.data_root, self.dataset_prefix, data["filename"]
+                 )
+             else:
+                 img_path = osp.join(self.data_root, data["filename"])
+
+             data_info["img_path"] = img_path
+
+             # Pseudo K
+             if self.with_camera:
+                 data_info["K"] = cameras[img_path][0]
+
+             # Pseudo depth path
+             if self.dataset_prefix is not None:
+                 depth_path = osp.join(
+                     self.data_root,
+                     f"{self.dataset_prefix}_depth",
+                     data["filename"].replace(".jpg", "_depth.png"),
+                 )
+             else:
+                 depth_path = osp.join(
+                     self.data_root,
+                     data["filename"].replace(".jpg", "_depth.png"),
+                 )
+             data_info["depth_path"] = depth_path
+
+             data_info["height"] = data["height"]
+             data_info["width"] = data["width"]
+
+             valid_anns = []
+             boxes = []
+             class_ids = np.empty((0,), dtype=np.int64)
+             if self.dataset_type == "OD":
+                 instances = data.get("detection", {}).get("instances", [])
+
+                 for ann in instances:
+                     bbox = ann["bbox"]
+
+                     # Box 2D
+                     x1, y1, x2, y2 = bbox
+                     inter_w = max(0, min(x2, data["width"]) - max(x1, 0))
+                     inter_h = max(0, min(y2, data["height"]) - max(y1, 0))
+
+                     if inter_w * inter_h == 0:
+                         continue
+                     if (x2 - x1) < 1 or (y2 - y1) < 1:
+                         continue
+
+                     boxes.append(bbox)
+
+                     # Class
+                     class_ids = np.concatenate(
+                         [class_ids, np.array([ann["label"]], dtype=np.int64)]
+                     )
+
+                     valid_anns.append(ann)
+             else:
+                 anno = data["grounding"]
+
+                 caption = anno["caption"].lower().strip()
+                 if not caption.endswith("."):
+                     caption = caption + ". "
+
+                 data_info["caption"] = caption
+
+                 regions = anno["regions"]
+                 phrases = []
+                 positive_positions = []
+                 for i, region in enumerate(regions):
+                     bboxes = region["bbox"]
+
+                     if not isinstance(bboxes[0], list):
+                         bboxes = [bboxes]
+
+                     for bbox in bboxes:
+                         x1, y1, x2, y2 = bbox
+                         inter_w = max(0, min(x2, data["width"]) - max(x1, 0))
+                         inter_h = max(0, min(y2, data["height"]) - max(y1, 0))
+
+                         if inter_w * inter_h == 0:
+                             continue
+                         if (x2 - x1) < 1 or (y2 - y1) < 1:
+                             continue
+
+                         boxes.append(bbox)
+                         phrases.append(region["phrase"])
+                         positive_positions.append(region["tokens_positive"])
+                         valid_anns.append(region)
+
+                         class_ids = np.concatenate(
+                             [class_ids, np.array([i], dtype=np.int64)]
+                         )
+
+                 data_info["phrases"] = phrases
+                 data_info["positive_positions"] = positive_positions
+
+             boxes2d = (
+                 np.empty((0, 4), dtype=np.float32)
+                 if not boxes
+                 else np.array(boxes, dtype=np.float32)
+             )
+
+             data_info["boxes2d"] = boxes2d
+             data_info["class_ids"] = class_ids
+             data_info["anns"] = valid_anns
+
+             samples.append(data_info)
+
+         del data_list
+         return samples
+
+     def get_cat_ids(self, idx: int) -> list[int]:
+         """Return the category ids of the sample at the given index."""
+         return self.samples[idx]["class_ids"].tolist()
+
+     def __len__(self) -> int:
+         """Total number of samples of data."""
+         return len(self.samples)
+
+     def __getitem__(self, idx: int) -> DictData:
+         """Get single sample.
+
+         Args:
+             idx (int): Index of sample.
+
+         Returns:
+             DictData: sample at index in Vis4D input format.
+         """
+         sample = self.samples[idx]
+         data_dict: DictData = {}
+
+         # Get image info
+         sample_name = sample["img_path"].split("/")[-1]
+         data_dict[K.sample_names] = sample_name
+
+         im_bytes = self.data_backend.get(sample["img_path"])
+         image = np.ascontiguousarray(
+             im_decode(im_bytes, mode=self.image_channel_mode),
+             dtype=np.float32,
+         )[None]
+
+         data_dict[K.images] = image
+         data_dict[K.input_hw] = (image.shape[1], image.shape[2])
+
+         data_dict[K.original_images] = image
+         data_dict[K.original_hw] = (image.shape[1], image.shape[2])
+
+         data_dict[K.boxes2d] = sample["boxes2d"]
+         data_dict[K.boxes2d_classes] = sample["class_ids"]
+
+         if self.dataset_type == "OD":
+             data_dict[K.boxes2d_names] = self.categories
+             data_dict["phrases"] = None
+             data_dict["positive_positions"] = None
+         else:
+             data_dict[K.boxes2d_names] = sample["caption"]
+             data_dict["phrases"] = sample["phrases"]
+             data_dict["positive_positions"] = sample["positive_positions"]
+
+         data_dict["dataset_type"] = self.dataset_type
+         data_dict["label_map"] = self.label_map
+
+         self.data_backend.close()
+
+         return data_dict
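
For reference, the loader above expects one JSON object per line in the annotation file. Minimal illustrative records covering only the keys it reads (all values made up):

    od_record = {  # dataset_type == "OD"
        "filename": "0001.jpg",
        "height": 480,
        "width": 640,
        "detection": {
            "instances": [{"bbox": [10.0, 20.0, 110.0, 220.0], "label": 0}]
        },
    }
    vg_record = {  # dataset_type == "VG"
        "filename": "0002.jpg",
        "height": 480,
        "width": 640,
        "grounding": {
            "caption": "a dog next to a chair.",
            "regions": [
                {
                    "bbox": [5.0, 5.0, 50.0, 60.0],
                    "phrase": "a dog",
                    "tokens_positive": [[0, 5]],
                }
            ],
        },
    }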
opendet3d/data/datasets/omni3d/__init__.py ADDED
@@ -0,0 +1 @@
+ """Omni3D Dataset."""
opendet3d/data/datasets/omni3d/arkitscenes.py ADDED
@@ -0,0 +1,81 @@
+ """ARKitScenes from Omni3D."""
+
+ from __future__ import annotations
+
+ import os
+
+ from vis4d.common.typing import ArgsType, DictStrAny
+
+ from opendet3d.data.datasets.coco3d import COCO3DDataset
+
+ from .omni3d_classes import omni3d_class_map
+
+ arkitscenes_det_map = {
+     "bathtub": 0,
+     "bed": 1,
+     "cabinet": 2,
+     "chair": 3,
+     "fireplace": 4,
+     "machine": 5,
+     "oven": 6,
+     "refrigerator": 7,
+     "shelves": 8,
+     "sink": 9,
+     "sofa": 10,
+     "stove": 11,
+     "table": 12,
+     "television": 13,
+     "toilet": 14,
+ }
+
+ omni3d_arkitscenes_det_map = {
+     "table": 0,
+     "bed": 1,
+     "sofa": 2,
+     "television": 3,
+     "refrigerator": 4,
+     "chair": 5,
+     "oven": 6,
+     "machine": 7,
+     "stove": 8,
+     "shelves": 9,
+     "sink": 10,
+     "cabinet": 11,
+     "bathtub": 12,
+     "toilet": 13,
+ }
+
+
+ class ARKitScenes(COCO3DDataset):
+     """ARKitScenes Dataset."""
+
+     def __init__(
+         self,
+         class_map: dict[str, int] = omni3d_class_map,
+         max_depth: float = 10.0,
+         depth_scale: float = 1000.0,
+         **kwargs: ArgsType,
+     ) -> None:
+         """Creates an instance of the class."""
+         super().__init__(
+             class_map=class_map,
+             max_depth=max_depth,
+             depth_scale=depth_scale,
+             **kwargs,
+         )
+
+     def get_depth_filenames(self, img: DictStrAny) -> str | None:
+         """Get the depth filename, since not every dataset has depth."""
+         _, _, split, video_id, image_name = img["file_path"].split("/")
+
+         depth_filename = os.path.join(
+             "data/ARKitScenes_depth",
+             split,
+             video_id,
+             image_name.replace("jpg", "png"),
+         )
+
+         return depth_filename
opendet3d/data/datasets/omni3d/hypersim.py ADDED
@@ -0,0 +1,190 @@
+ """Hypersim from Omni3D."""
+
+ from __future__ import annotations
+
+ import os
+
+ from vis4d.common.typing import ArgsType, DictStrAny
+
+ from opendet3d.data.datasets.coco3d import COCO3DDataset
+
+ from .omni3d_classes import omni3d_class_map
+
+ hypersim_train_det_map = {
+     "bathtub": 0,
+     "bed": 1,
+     "blinds": 2,
+     "bookcase": 3,
+     "books": 4,
+     "box": 5,
+     "cabinet": 6,
+     "chair": 7,
+     "clothes": 8,
+     "counter": 9,
+     "curtain": 10,
+     "desk": 11,
+     "door": 12,
+     "dresser": 13,
+     "floor mat": 14,
+     "lamp": 15,
+     "mirror": 16,
+     "night stand": 17,
+     "person": 18,
+     "picture": 19,
+     "pillow": 20,
+     "refrigerator": 21,
+     "shelves": 22,
+     "sink": 23,
+     "sofa": 24,
+     "stationery": 25,
+     "table": 26,
+     "television": 27,
+     "toilet": 28,
+     "towel": 29,
+     "window": 30,
+ }
+
+ hypersim_val_det_map = {
+     "bathtub": 0,
+     "bed": 1,
+     "blinds": 2,
+     "bookcase": 3,
+     "books": 4,
+     "box": 5,
+     "cabinet": 6,
+     "chair": 7,
+     "clothes": 8,
+     "counter": 9,
+     "curtain": 10,
+     "desk": 11,
+     "door": 12,
+     "dresser": 13,
+     "floor mat": 14,
+     "lamp": 15,
+     "mirror": 16,
+     "night stand": 17,
+     "picture": 18,
+     "pillow": 19,
+     "refrigerator": 20,
+     "shelves": 21,
+     "sink": 22,
+     "sofa": 23,
+     "stationery": 24,
+     "table": 25,
+     "television": 26,
+     "toilet": 27,
+     "towel": 28,
+     "window": 29,
+ }
+
+ hypersim_test_det_map = {
+     "bathtub": 0,
+     "bed": 1,
+     "blinds": 2,
+     "board": 3,
+     "bookcase": 4,
+     "books": 5,
+     "box": 6,
+     "cabinet": 7,
+     "chair": 8,
+     "clothes": 9,
+     "counter": 10,
+     "curtain": 11,
+     "desk": 12,
+     "door": 13,
+     "floor mat": 14,
+     "lamp": 15,
+     "mirror": 16,
+     "night stand": 17,
+     "picture": 18,
+     "pillow": 19,
+     "refrigerator": 20,
+     "shelves": 21,
+     "sink": 22,
+     "sofa": 23,
+     "stationery": 24,
+     "table": 25,
+     "television": 26,
+     "towel": 27,
+     "window": 28,
+ }
+
+
+ omni3d_hypersim_det_map = {
+     "books": 0,
+     "chair": 1,
+     "towel": 2,
+     "blinds": 3,
+     "window": 4,
+     "lamp": 5,
+     "shelves": 6,
+     "mirror": 7,
+     "sink": 8,
+     "cabinet": 9,
+     "bathtub": 10,
+     "door": 11,
+     "desk": 12,
+     "box": 13,
+     "bookcase": 14,
+     "picture": 15,
+     "table": 16,
+     "counter": 17,
+     "bed": 18,
+     "night stand": 19,
+     "pillow": 20,
+     "sofa": 21,
+     "television": 22,
+     "floor mat": 23,
+     "curtain": 24,
+     "clothes": 25,
+     "stationery": 26,
+     "refrigerator": 27,
+ }
+
+
+ def get_hypersim_det_map(split: str) -> dict[str, int]:
+     """Get the Hypersim detection map."""
+     assert split in {"train", "val", "test"}, f"Invalid split: {split}"
+
+     if split == "train":
+         return hypersim_train_det_map
+     elif split == "val":
+         return hypersim_val_det_map
+     return hypersim_test_det_map
+
+
+ class Hypersim(COCO3DDataset):
+     """Hypersim Dataset."""
+
+     def __init__(
+         self,
+         class_map: dict[str, int] = omni3d_class_map,
+         max_depth: float = 50.0,
+         depth_scale: float = 1000.0,
+         **kwargs: ArgsType,
+     ) -> None:
+         """Creates an instance of the class."""
+         super().__init__(
+             class_map=class_map,
+             max_depth=max_depth,
+             depth_scale=depth_scale,
+             **kwargs,
+         )
+
+     def get_depth_filenames(self, img: DictStrAny) -> str | None:
+         """Get the depth filename, since not every dataset has depth."""
+         _, _, scene, _, img_dir, img_name = img["file_path"].split("/")
+
+         depth_filename = os.path.join(
+             "data/hypersim_depth",
+             scene,
+             "images",
+             img_dir,
+             img_name.replace("jpg", "png"),
+         )
+
+         return depth_filename
opendet3d/data/datasets/omni3d/kitti_object.py ADDED
@@ -0,0 +1,105 @@
+ """KITTI Object from Omni3D.
+
+ KITTI Object Labels:
+     Categories, -, -, alpha, x1, y1, x2, y2, h, w, l, x, bottom_y, z, ry
+
+ KITTI Object Categories:
+     {
+         "Pedestrian": "pedestrian",
+         "Cyclist": "cyclist",
+         "Car": "car",
+         "Van": "car",
+         "Truck": "truck",
+         "Tram": "tram",
+         "Person": "pedestrian",
+         "Person_sitting": "pedestrian",
+         "Misc": "misc",
+         "DontCare": "dontcare",
+     }
+ """
+
+ from __future__ import annotations
+
+ import os
+
+ from vis4d.common.typing import ArgsType, DictStrAny
+
+ from opendet3d.data.datasets.coco3d import COCO3DDataset
+
+ from .omni3d_classes import omni3d_class_map
+
+ kitti_train_det_map = kitti_test_det_map = {
+     "car": 0,
+     "cyclist": 1,
+     "pedestrian": 2,
+     "person": 3,
+     "tram": 4,
+     "truck": 5,
+     "van": 6,
+ }
+
+ kitti_val_det_map = {
+     "car": 0,
+     "cyclist": 1,
+     "pedestrian": 2,
+     "tram": 3,
+     "truck": 4,
+ }
+
+ # KITTI-Omni3D Mapping
+ omni3d_kitti_det_map = {
+     "pedestrian": 0,
+     "car": 1,
+     "cyclist": 2,
+     "van": 3,
+     "truck": 4,
+ }
+
+
+ def get_kitti_det_map(split: str) -> dict[str, int]:
+     """Get the KITTI detection map."""
+     assert split in {"train", "val", "test"}, f"Invalid split: {split}"
+
+     if split == "val":
+         return kitti_val_det_map
+
+     # Train and test share the same map
+     return kitti_train_det_map
+
+
+ class KITTIObject(COCO3DDataset):
+     """KITTI Object Dataset."""
+
+     def __init__(
+         self,
+         class_map: dict[str, int] = omni3d_class_map,
+         max_depth: float = 80.0,
+         depth_scale: float = 256.0,
+         depth_data_root: str = "data/KITTI_object_depth",
+         **kwargs: ArgsType,
+     ) -> None:
+         """Creates an instance of the class."""
+         self.depth_data_root = depth_data_root
+
+         super().__init__(
+             class_map=class_map,
+             max_depth=max_depth,
+             depth_scale=depth_scale,
+             **kwargs,
+         )
+
+     def get_depth_filenames(self, img: DictStrAny) -> str | None:
+         """Get the depth filename, since not every dataset has depth."""
+         _, _, split, image_id, img_filename = img["file_path"].split("/")
+
+         depth_filename = os.path.join(
+             self.depth_data_root,
+             split,
+             image_id,
+             img_filename.replace(".jpg", ".png"),
+         )
+
+         return depth_filename
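
A hedged sketch of how the label fields listed in the module docstring line up with a raw KITTI object label line (the example line is made up):

    line = "Car 0.00 0 1.57 100.0 150.0 300.0 350.0 1.5 1.6 3.9 2.0 1.7 15.0 1.62"
    fields = line.split()
    category = fields[0]
    alpha = float(fields[3])
    x1, y1, x2, y2 = map(float, fields[4:8])  # 2D box
    h, w, l = map(float, fields[8:11])  # dimensions
    x, bottom_y, z = map(float, fields[11:14])  # location in camera frame
    ry = float(fields[14])  # rotation around the camera y-axis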
opendet3d/data/datasets/omni3d/nuscenes.py ADDED
@@ -0,0 +1,62 @@
+ """nuScenes from Omni3D."""
+
+ from __future__ import annotations
+
+ from vis4d.common.typing import ArgsType, DictStrAny
+
+ from opendet3d.data.datasets.coco3d import COCO3DDataset
+
+ from .omni3d_classes import omni3d_class_map
+
+ nusc_det_map = {
+     "bicycle": 0,
+     "motorcycle": 1,
+     "pedestrian": 2,
+     "bus": 3,
+     "car": 4,
+     "trailer": 5,
+     "truck": 6,
+     "traffic cone": 7,
+     "barrier": 8,
+ }
+
+
+ class nuScenes(COCO3DDataset):
+     """nuScenes dataset."""
+
+     def __init__(
+         self,
+         class_map: dict[str, int] = omni3d_class_map,
+         max_depth: float = 80.0,
+         depth_scale: float = 256.0,
+         **kwargs: ArgsType,
+     ) -> None:
+         """Creates an instance of the class."""
+         super().__init__(
+             class_map=class_map,
+             max_depth=max_depth,
+             depth_scale=depth_scale,
+             **kwargs,
+         )
+
+     def get_depth_filenames(self, img: DictStrAny) -> str | None:
+         """Get the depth filename, since not every dataset has depth."""
+         img["file_path"] = img["file_path"].replace("nuScenes", "nuscenes")
+
+         depth_filename = (
+             img["file_path"]
+             .replace("nuscenes", "nuscenes_depth")
+             .replace("jpg", "png")
+         )
+         return depth_filename
+
+     def get_cat_ids(self, idx: int) -> list[int]:
+         """Return the category ids of the sample at the given index."""
+         return self.samples[idx]["class_ids"].tolist()
+
+     def __len__(self) -> int:
+         """Total number of samples of data."""
+         return len(self.samples)
opendet3d/data/datasets/omni3d/objectron.py ADDED
@@ -0,0 +1,56 @@
+ """Objectron from Omni3D."""
+
+ from __future__ import annotations
+
+ import os
+
+ from vis4d.common.typing import ArgsType, DictStrAny
+
+ from opendet3d.data.datasets.coco3d import COCO3DDataset
+
+ from .omni3d_classes import omni3d_class_map
+
+ objectron_det_map = {
+     "bicycle": 0,
+     "books": 1,
+     "bottle": 2,
+     "camera": 3,
+     "cereal box": 4,
+     "chair": 5,
+     "cup": 6,
+     "laptop": 7,
+     "shoes": 8,
+ }
+
+
+ class Objectron(COCO3DDataset):
+     """Objectron dataset."""
+
+     def __init__(
+         self,
+         class_map: dict[str, int] = omni3d_class_map,
+         max_depth: float = 12.0,
+         depth_scale: float = 1000.0,
+         **kwargs: ArgsType,
+     ) -> None:
+         """Creates an instance of the class."""
+         super().__init__(
+             class_map=class_map,
+             max_depth=max_depth,
+             depth_scale=depth_scale,
+             **kwargs,
+         )
+
+     def get_depth_filenames(self, img: DictStrAny) -> str | None:
+         """Get the depth filename, since not every dataset has depth."""
+         _, _, split, img_name = img["file_path"].split("/")
+
+         depth_filename = os.path.join(
+             "data/objectron_depth",
+             split,
+             img_name.replace(".jpg", "_depth.png"),
+         )
+         return depth_filename
opendet3d/data/datasets/omni3d/omni3d_classes.py ADDED
@@ -0,0 +1,156 @@
+ """Omni3D classes."""
+
+ omni3d_class_map = {
+     "pedestrian": 0,
+     "car": 1,
+     "dontcare": 2,
+     "cyclist": 3,
+     "van": 4,
+     "truck": 5,
+     "tram": 6,
+     "person": 7,
+     "traffic cone": 8,
+     "barrier": 9,
+     "motorcycle": 10,
+     "bicycle": 11,
+     "bus": 12,
+     "trailer": 13,
+     "books": 14,
+     "bottle": 15,
+     "camera": 16,
+     "cereal box": 17,
+     "chair": 18,
+     "cup": 19,
+     "laptop": 20,
+     "shoes": 21,
+     "towel": 22,
+     "blinds": 23,
+     "window": 24,
+     "lamp": 25,
+     "shelves": 26,
+     "mirror": 27,
+     "sink": 28,
+     "cabinet": 29,
+     "bathtub": 30,
+     "door": 31,
+     "toilet": 32,
+     "desk": 33,
+     "box": 34,
+     "bookcase": 35,
+     "picture": 36,
+     "table": 37,
+     "counter": 38,
+     "bed": 39,
+     "night stand": 40,
+     "dresser": 41,
+     "pillow": 42,
+     "sofa": 43,
+     "television": 44,
+     "floor mat": 45,
+     "curtain": 46,
+     "clothes": 47,
+     "stationery": 48,
+     "refrigerator": 49,
+     "board": 50,
+     "kitchen pan": 51,
+     "bin": 52,
+     "stove": 53,
+     "microwave": 54,
+     "plates": 55,
+     "bowl": 56,
+     "oven": 57,
+     "vase": 58,
+     "faucet": 59,
+     "tissues": 60,
+     "machine": 61,
+     "printer": 62,
+     "monitor": 63,
+     "podium": 64,
+     "cart": 65,
+     "projector": 66,
+     "electronics": 67,
+     "computer": 68,
+     "air conditioner": 69,
+     "drawers": 70,
+     "coffee maker": 71,
+     "toaster": 72,
+     "potted plant": 73,
+     "painting": 74,
+     "bag": 75,
+     "tray": 76,
+     "keyboard": 77,
+     "blanket": 78,
+     "rack": 79,
+     "phone": 80,
+     "mouse": 81,
+     "fire extinguisher": 82,
+     "toys": 83,
+     "ladder": 84,
+     "fan": 85,
+     "glass": 86,
+     "clock": 87,
+     "toilet paper": 88,
+     "closet": 89,
+     "fume hood": 90,
+     "utensils": 91,
+     "soundsystem": 92,
+     "fire place": 93,
+     "shower curtain": 94,
+     "remote": 95,
+     "pen": 96,
+     "fireplace": 97,
+ }
+
+ # Used for Cube R-CNN and Omni3D benchmark
+ omni3d_det_map = {
+     "pedestrian": 0,
+     "car": 1,
+     "cyclist": 2,
+     "van": 3,
+     "truck": 4,
+     "traffic cone": 5,
+     "barrier": 6,
+     "motorcycle": 7,
+     "bicycle": 8,
+     "bus": 9,
+     "trailer": 10,
+     "books": 11,
+     "bottle": 12,
+     "camera": 13,
+     "cereal box": 14,
+     "chair": 15,
+     "cup": 16,
+     "laptop": 17,
+     "shoes": 18,
+     "towel": 19,
+     "blinds": 20,
+     "window": 21,
+     "lamp": 22,
+     "shelves": 23,
+     "mirror": 24,
+     "sink": 25,
+     "cabinet": 26,
+     "bathtub": 27,
+     "door": 28,
+     "toilet": 29,
+     "desk": 30,
+     "box": 31,
+     "bookcase": 32,
+     "picture": 33,
+     "table": 34,
+     "counter": 35,
+     "bed": 36,
+     "night stand": 37,
+     "pillow": 38,
+     "sofa": 39,
+     "television": 40,
+     "floor mat": 41,
+     "curtain": 42,
+     "clothes": 43,
+     "stationery": 44,
+     "refrigerator": 45,
+     "bin": 46,
+     "stove": 47,
+     "oven": 48,
+     "machine": 49,
+ }
opendet3d/data/datasets/omni3d/sunrgbd.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SUN RGB-D from Omni3D."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import numpy as np
8
+ from vis4d.common.typing import ArgsType, DictStrAny
9
+ from vis4d.data.datasets.util import im_decode
10
+
11
+ from opendet3d.data.datasets.coco3d import COCO3DDataset
12
+
13
+ from .omni3d_classes import omni3d_class_map
14
+
15
+ # Train and Test are sharing the classes
16
+ sun_rgbd_train_det_map = sun_rgbd_test_det_map = {
17
+ "air conditioner": 0,
18
+ "bag": 1,
19
+ "bathtub": 2,
20
+ "bed": 3,
21
+ "bicycle": 4,
22
+ "bin": 5,
23
+ "blanket": 6,
24
+ "blinds": 7,
25
+ "board": 8,
26
+ "bookcase": 9,
27
+ "books": 10,
28
+ "bottle": 11,
29
+ "bowl": 12,
30
+ "box": 13,
31
+ "cabinet": 14,
32
+ "cart": 15,
33
+ "chair": 16,
34
+ "clock": 17,
35
+ "closet": 18,
36
+ "clothes": 19,
37
+ "coffee maker": 20,
38
+ "computer": 21,
39
+ "counter": 22,
40
+ "cup": 23,
41
+ "curtain": 24,
42
+ "desk": 25,
43
+ "door": 26,
44
+ "drawers": 27,
45
+ "dresser": 28,
46
+ "electronics": 29,
47
+ "fan": 30,
48
+ "faucet": 31,
49
+ "fire extinguisher": 32,
50
+ "fire place": 33,
51
+ "floor mat": 34,
52
+ "fume hood": 35,
53
+ "glass": 36,
54
+ "keyboard": 37,
55
+ "kitchen pan": 38,
56
+ "ladder": 39,
57
+ "lamp": 40,
58
+ "laptop": 41,
59
+ "machine": 42,
60
+ "microwave": 43,
61
+ "mirror": 44,
62
+ "monitor": 45,
63
+ "mouse": 46,
64
+ "night stand": 47,
65
+ "oven": 48,
66
+ "painting": 49,
67
+ "pen": 50,
68
+ "person": 51,
69
+ "phone": 52,
70
+ "picture": 53,
71
+ "pillow": 54,
72
+ "plates": 55,
73
+ "podium": 56,
74
+ "potted plant": 57,
75
+ "printer": 58,
76
+ "projector": 59,
77
+ "rack": 60,
78
+ "refrigerator": 61,
79
+ "remote": 62,
80
+ "shelves": 63,
81
+ "shoes": 64,
82
+ "shower curtain": 65,
83
+ "sink": 66,
84
+ "sofa": 67,
85
+ "soundsystem": 68,
86
+ "stationery": 69,
87
+ "stove": 70,
88
+ "table": 71,
89
+ "television": 72,
90
+ "tissues": 73,
91
+ "toaster": 74,
92
+ "toilet": 75,
93
+ "toilet paper": 76,
94
+ "towel": 77,
95
+ "toys": 78,
96
+ "tray": 79,
97
+ "utensils": 80,
98
+ "vase": 81,
99
+ "window": 82,
100
+ }
101
+
102
+ sun_rgbd_val_det_map = {
103
+ "air conditioner": 0,
104
+ "bag": 1,
105
+ "bathtub": 2,
106
+ "bed": 3,
107
+ "bin": 4,
108
+ "blanket": 5,
109
+ "blinds": 6,
110
+ "board": 7,
111
+ "bookcase": 8,
112
+ "books": 9,
113
+ "bottle": 10,
114
+ "bowl": 11,
115
+ "box": 12,
116
+ "cabinet": 13,
117
+ "cart": 14,
118
+ "chair": 15,
119
+ "closet": 16,
120
+ "clothes": 17,
121
+ "coffee maker": 18,
122
+ "computer": 19,
123
+ "counter": 20,
124
+ "cup": 21,
125
+ "curtain": 22,
126
+ "desk": 23,
127
+ "door": 24,
128
+ "drawers": 25,
129
+ "dresser": 26,
130
+ "electronics": 27,
131
+ "fan": 28,
132
+ "faucet": 29,
133
+ "fire extinguisher": 30,
134
+ "fire place": 31,
135
+ "fume hood": 32,
136
+ "keyboard": 33,
137
+ "kitchen pan": 34,
138
+ "lamp": 35,
139
+ "laptop": 36,
140
+ "machine": 37,
141
+ "microwave": 38,
142
+ "mirror": 39,
143
+ "monitor": 40,
144
+ "night stand": 41,
145
+ "oven": 42,
146
+ "painting": 43,
147
+ "pen": 44,
148
+ "person": 45,
149
+ "phone": 46,
150
+ "picture": 47,
151
+ "pillow": 48,
152
+ "plates": 49,
153
+ "potted plant": 50,
154
+ "printer": 51,
155
+ "projector": 52,
156
+ "rack": 53,
157
+ "refrigerator": 54,
158
+ "shelves": 55,
159
+ "sink": 56,
160
+ "sofa": 57,
161
+ "soundsystem": 58,
162
+ "stationery": 59,
163
+ "stove": 60,
164
+ "table": 61,
165
+ "television": 62,
166
+ "tissues": 63,
167
+ "toaster": 64,
168
+ "toilet": 65,
169
+ "towel": 66,
170
+ "toys": 67,
171
+ "tray": 68,
172
+ "utensils": 69,
173
+ "vase": 70,
174
+ "window": 71,
175
+ }
176
+
177
+ omni3d_sun_rgbd_det_map = {
178
+ "bicycle": 0,
179
+ "books": 1,
180
+ "bottle": 2,
181
+ "chair": 3,
182
+ "cup": 4,
183
+ "laptop": 5,
184
+ "shoes": 6,
185
+ "towel": 7,
186
+ "blinds": 8,
187
+ "window": 9,
188
+ "lamp": 10,
189
+ "shelves": 11,
190
+ "mirror": 12,
191
+ "sink": 13,
192
+ "cabinet": 14,
193
+ "bathtub": 15,
194
+ "door": 16,
195
+ "toilet": 17,
196
+ "desk": 18,
197
+ "box": 19,
198
+ "bookcase": 20,
199
+ "picture": 21,
200
+ "table": 22,
201
+ "counter": 23,
202
+ "bed": 24,
203
+ "night stand": 25,
204
+ "pillow": 26,
205
+ "sofa": 27,
206
+ "television": 28,
207
+ "floor mat": 29,
208
+ "curtain": 30,
209
+ "clothes": 31,
210
+ "stationery": 32,
211
+ "refrigerator": 33,
212
+ "bin": 34,
213
+ "stove": 35,
214
+ "oven": 36,
215
+ "machine": 37,
216
+ }
217
+
218
+
219
+ def get_sunrgbd_det_map(split: str) -> dict[str, int]:
220
+ """Get the SUN RGB-D detection map."""
221
+ assert split in {"train", "val", "test"}, f"Invalid split: {split}"
222
+
223
+ if split == "train":
224
+ return sun_rgbd_train_det_map
225
+ elif split == "val":
226
+ return sun_rgbd_val_det_map
227
+ else:
228
+ return sun_rgbd_test_det_map
229
+
230
+
231
+ class SUNRGBD(COCO3DDataset):
232
+ """SUN RGB-D Dataset."""
233
+
234
+ def __init__(
235
+ self,
236
+ class_map: dict[str, int] = omni3d_class_map,
237
+ max_depth: float = 8.0,
238
+ depth_scale: float = 1000.0,
239
+ **kwargs: ArgsType,
240
+ ) -> None:
241
+ """Initialize SUN RGB-D dataset."""
242
+ super().__init__(
243
+ class_map=class_map,
244
+ max_depth=max_depth,
245
+ depth_scale=depth_scale,
246
+ **kwargs,
247
+ )
248
+
249
+ def get_depth_filenames(self, img: DictStrAny) -> str | None:
250
+ """Get the depth filenames.
251
+
252
+ Since not every data has depth.
253
+ """
254
+ img["file_path"] = img["file_path"].replace("//", "/")
255
+
256
+ data_dir = img["file_path"].split("/image")[0]
257
+
258
+ depth_files = self.data_backend.listdir(
259
+ os.path.join(data_dir, "depth")
260
+ )
261
+ assert len(depth_files) == 1
262
+
263
+ depth_filename = os.path.join(data_dir, "depth", depth_files[0])
264
+
265
+ return depth_filename
266
+
267
+ def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
268
+ """Get the depth map."""
269
+ depth_bytes = self.data_backend.get(sample["depth_filename"])
270
+ depth_array = im_decode(depth_bytes)
271
+
272
+ depth_array = depth_array >> 3 | depth_array << (16 - 3)
273
+
274
+ depth = np.ascontiguousarray(depth_array, dtype=np.float32)
275
+
276
+ depth = depth / self.depth_scale
277
+
278
+ return depth
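For reference, the bit manipulation in `get_depth_map` undoes the circular bit shift of SUN RGB-D's on-disk depth encoding. A minimal sketch of the round trip, assuming 16-bit depth arrays (the values are illustrative):

```python
import numpy as np

# Hypothetical raw depth in millimeters, as the sensor produced it.
raw = np.array([1500, 2047, 8000], dtype=np.uint16)

# SUN RGB-D stores depth rotated left by 3 bits ...
stored = (raw << 3) | (raw >> (16 - 3))

# ... so loading rotates right by 3 bits, as in SUNRGBD.get_depth_map.
decoded = (stored >> 3) | (stored << (16 - 3))

assert np.array_equal(decoded, raw)
print(decoded / 1000.0)  # depth in meters after dividing by depth_scale
```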
opendet3d/data/datasets/omni3d/util.py ADDED
@@ -0,0 +1,74 @@
"""Omni3D data util."""

from __future__ import annotations

from .arkitscenes import arkitscenes_det_map, omni3d_arkitscenes_det_map
from .hypersim import get_hypersim_det_map, omni3d_hypersim_det_map
from .kitti_object import get_kitti_det_map, omni3d_kitti_det_map
from .nuscenes import nusc_det_map
from .objectron import objectron_det_map
from .sunrgbd import get_sunrgbd_det_map, omni3d_sun_rgbd_det_map

DATASET_ID_MAP = {
    0: "KITTI_train",
    1: "KITTI_val",
    2: "KITTI_test",
    3: "nuScenes_train",
    4: "nuScenes_val",
    5: "nuScenes_test",
    6: "Objectron_train",
    7: "Objectron_val",
    8: "Objectron_test",
    9: "Hypersim_train",
    10: "Hypersim_val",
    11: "Hypersim_test",
    12: "SUNRGBD_train",
    13: "SUNRGBD_val",
    14: "SUNRGBD_test",
    15: "ARKitScenes_train",
    16: "ARKitScenes_val",
    17: "ARKitScenes_test",
}


def get_dataset_det_map(
    dataset_name: str,
    omni3d50: bool = True,
) -> dict[str, int]:
    """Get the detection map for a dataset split name."""
    if "train" in dataset_name:
        split = "train"
    elif "val" in dataset_name:
        split = "val"
    elif "test" in dataset_name:
        split = "test"
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    if "nuScenes" in dataset_name:
        det_map = nusc_det_map
    elif "KITTI" in dataset_name:
        if omni3d50:
            det_map = omni3d_kitti_det_map
        else:
            det_map = get_kitti_det_map(split)
    elif "Objectron" in dataset_name:
        det_map = objectron_det_map
    elif "SUNRGBD" in dataset_name:
        if omni3d50:
            det_map = omni3d_sun_rgbd_det_map
        else:
            det_map = get_sunrgbd_det_map(split)
    elif "Hypersim" in dataset_name:
        if omni3d50:
            det_map = omni3d_hypersim_det_map
        else:
            det_map = get_hypersim_det_map(split)
    elif "ARKitScenes" in dataset_name:
        det_map = (
            omni3d_arkitscenes_det_map if omni3d50 else arkitscenes_det_map
        )
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    return det_map
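A quick usage sketch, assuming the package is importable; dataset names follow `DATASET_ID_MAP`, and the assertions reflect the SUN RGB-D maps defined above:

```python
from opendet3d.data.datasets.omni3d.util import get_dataset_det_map

# Omni3D-50 vocabulary for the SUN RGB-D validation split.
det_map = get_dataset_det_map("SUNRGBD_val", omni3d50=True)
assert det_map["bicycle"] == 0 and det_map["machine"] == 37

# Dataset-specific (non-Omni3D-50) vocabulary instead.
full_map = get_dataset_det_map("SUNRGBD_val", omni3d50=False)
```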
opendet3d/data/datasets/scannet.py ADDED
@@ -0,0 +1,449 @@
"""ScanNet dataset."""

from __future__ import annotations

from vis4d.common.typing import ArgsType, DictStrAny

from .coco3d import COCO3DDataset

scannet_class_map = {
    "cabinet": 3,
    "bed": 4,
    "chair": 5,
    "sofa": 6,
    "table": 7,
    "door": 8,
    "window": 9,
    "bookshelf": 10,
    "picture": 11,
    "counter": 12,
    "desk": 14,
    "curtain": 16,
    "refrigerator": 24,
    "shower curtain": 28,
    "toilet": 33,
    "sink": 34,
    "bathtub": 36,
    "other furniture": 39,
}

scannet_det_map = {
    "cabinet": 0,
    "bed": 1,
    "chair": 2,
    "sofa": 3,
    "table": 4,
    "door": 5,
    "window": 6,
    "bookshelf": 7,
    "picture": 8,
    "counter": 9,
    "desk": 10,
    "curtain": 11,
    "refrigerator": 12,
    "shower curtain": 13,
    "toilet": 14,
    "sink": 15,
    "bathtub": 16,
    "other furniture": 17,
}

scannet200_class_map = {
    "chair": 2,
    "book": 22,
    "door": 5,
    "object": 1163,
    "window": 16,
    "table": 4,
    "trash can": 56,
    "pillow": 13,
    "picture": 15,
    "box": 26,
    "doorframe": 161,
    "monitor": 19,
    "cabinet": 7,
    "desk": 9,
    "shelf": 8,
    "office chair": 10,
    "towel": 31,
    "couch": 6,
    "sink": 14,
    "backpack": 48,
    "lamp": 28,
    "bed": 11,
    "bookshelf": 18,
    "mirror": 71,
    "curtain": 21,
    "plant": 40,
    "whiteboard": 52,
    "radiator": 96,
    "kitchen cabinet": 29,
    "toilet paper": 49,
    "armchair": 23,
    "shoe": 63,
    "coffee table": 24,
    "toilet": 17,
    "bag": 47,
    "clothes": 32,
    "keyboard": 46,
    "bottle": 65,
    "recycling bin": 97,
    "nightstand": 34,
    "stool": 38,
    "tv": 33,
    "file cabinet": 75,
    "dresser": 36,
    "computer tower": 64,
    "telephone": 101,
    "cup": 130,
    "refrigerator": 27,
    "end table": 44,
    "jacket": 131,
    "shower curtain": 55,
    "bathtub": 42,
    "microwave": 59,
    "kitchen counter": 159,
    "sofa chair": 74,
    "paper towel dispenser": 82,
    "bathroom vanity": 1164,
    "suitcase": 93,
    "laptop": 77,
    "ottoman": 67,
    "shower wall": 128,
    "printer": 50,
    "counter": 35,
    "board": 69,
    "soap dispenser": 100,
    "stove": 62,
    "light": 105,
    "closet wall": 1165,
    "mini fridge": 165,
    "fan": 76,
    "tissue box": 230,
    "blanket": 54,
    "bathroom stall": 125,
    "copier": 72,
    "bench": 68,
    "bar": 145,
    "soap dish": 157,
    "laundry hamper": 1166,
    "storage bin": 132,
    "bathroom stall door": 1167,
    "light switch": 232,
    "coffee maker": 134,
    "tv stand": 51,
    "decoration": 250,
    "ceiling light": 1168,
    "range hood": 342,
    "blackboard": 89,
    "clock": 103,
    "wardrobe": 99,
    "rail": 95,
    "bulletin board": 154,
    "mat": 140,
    "trash bin": 1169,
    "ledge": 193,
    "seat": 116,
    "mouse": 202,
    "basket": 73,
    "shower": 78,
    "dumbbell": 1170,
    "paper": 79,
    "person": 80,
    "windowsill": 141,
    "closet": 57,
    "bucket": 102,
    "sign": 261,
    "speaker": 118,
    "dishwasher": 136,
    "container": 98,
    "stair rail": 1171,
    "shower curtain rod": 170,
    "tube": 1172,
    "bathroom cabinet": 1173,
    "storage container": 221,
    "paper bag": 570,
    "paper towel roll": 138,
    "ball": 168,
    "closet door": 276,
    "laundry basket": 106,
    "cart": 214,
    "dish rack": 323,
    "stairs": 58,
    "blinds": 86,
    "purse": 399,
    "bicycle": 121,
    "tray": 185,
    "plunger": 300,
    "paper cutter": 180,
    "toilet paper dispenser": 163,
    "bin": 66,
    "toilet seat cover dispenser": 208,
    "guitar": 112,
    "mailbox": 540,
    "handicap bar": 395,
    "fire extinguisher": 166,
    "ladder": 122,
    "column": 120,
    "pipe": 107,
    "vacuum cleaner": 283,
    "plate": 88,
    "piano": 90,
    "water cooler": 177,
    "cd case": 1174,
    "bowl": 562,
    "closet rod": 1175,
    "bathroom counter": 1156,
    "oven": 84,
    "stand": 104,
    "scale": 229,
    "washing machine": 70,
    "broom": 325,
    "hat": 169,
    "guitar case": 331,
    "rack": 87,
    "water pitcher": 488,
    "laundry detergent": 776,
    "hair dryer": 370,
    "pillar": 191,
    "divider": 748,
    "power outlet": 242,
    "dining table": 45,
    "shower floor": 417,
    "shower door": 188,
    "coffee kettle": 1176,
    "structure": 1178,
    "clothes dryer": 110,
    "toaster": 148,
    "ironing board": 155,
    "alarm clock": 572,
    "shower head": 1179,
    "water bottle": 392,
    "keyboard piano": 1180,
    "projector screen": 609,
    "case of water bottles": 1181,
    "toaster oven": 195,
    "music stand": 581,
    "coat rack": 1182,
    "storage organizer": 1183,
    "machine": 139,
    "folded chair": 1184,
    "fire alarm": 1185,
    "fireplace": 156,
    "vent": 408,
    "furniture": 213,
    "power strip": 1186,
    "calendar": 1187,
    "poster": 1188,
    "toilet paper holder": 115,
    "potted plant": 1189,
    "stuffed animal": 304,
    "luggage": 1190,
    "headphones": 312,
    "crate": 233,
    "candle": 286,
    "projector": 264,
    "mattress": 1191,
    "dustpan": 356,
    "cushion": 39,
    "stick": 1163,
}

scannet200_det_map = {
    "chair": 0,
    "table": 1,
    "door": 2,
    "couch": 3,
    "cabinet": 4,
    "shelf": 5,
    "desk": 6,
    "office chair": 7,
    "bed": 8,
    "pillow": 9,
    "sink": 10,
    "picture": 11,
    "window": 12,
    "toilet": 13,
    "bookshelf": 14,
    "monitor": 15,
    "curtain": 16,
    "book": 17,
    "armchair": 18,
    "coffee table": 19,
    "box": 20,
    "refrigerator": 21,
    "lamp": 22,
    "kitchen cabinet": 23,
    "towel": 24,
    "clothes": 25,
    "tv": 26,
    "nightstand": 27,
    "counter": 28,
    "dresser": 29,
    "stool": 30,
    "plant": 31,
    "bathtub": 32,
    "end table": 33,
    "dining table": 34,
    "keyboard": 35,
    "bag": 36,
    "backpack": 37,
    "toilet paper": 38,
    "printer": 39,
    "tv stand": 40,
    "whiteboard": 41,
    "blanket": 42,
    "shower curtain": 43,
    "trash can": 44,
    "closet": 45,
    "stairs": 46,
    "microwave": 47,
    "stove": 48,
    "shoe": 49,
    "computer tower": 50,
    "bottle": 51,
    "bin": 52,
    "ottoman": 53,
    "bench": 54,
    "board": 55,
    "washing machine": 56,
    "mirror": 57,
    "copier": 58,
    "basket": 59,
    "sofa chair": 60,
    "file cabinet": 61,
    "fan": 62,
    "laptop": 63,
    "shower": 64,
    "paper": 65,
    "person": 66,
    "paper towel dispenser": 67,
    "oven": 68,
    "blinds": 69,
    "rack": 70,
    "plate": 71,
    "blackboard": 72,
    "piano": 73,
    "suitcase": 74,
    "rail": 75,
    "radiator": 76,
    "recycling bin": 77,
    "container": 78,
    "wardrobe": 79,
    "soap dispenser": 80,
    "telephone": 81,
    "bucket": 82,
    "clock": 83,
    "stand": 84,
    "light": 85,
    "laundry basket": 86,
    "pipe": 87,
    "clothes dryer": 88,
    "guitar": 89,
    "toilet paper holder": 90,
    "seat": 91,
    "speaker": 92,
    "column": 93,
    "ladder": 94,
    "cup": 95,
    "jacket": 96,
    "storage bin": 97,
    "coffee maker": 98,
    "dishwasher": 99,
    "paper towel roll": 100,
    "machine": 101,
    "mat": 102,
    "windowsill": 103,
    "bar": 104,
    "bulletin board": 105,
    "ironing board": 106,
    "fireplace": 107,
    "soap dish": 108,
    "kitchen counter": 109,
    "doorframe": 110,
    "toilet paper dispenser": 111,
    "mini fridge": 112,
    "fire extinguisher": 113,
    "ball": 114,
    "hat": 115,
    "shower curtain rod": 116,
    "water cooler": 117,
    "paper cutter": 118,
    "tray": 119,
    "pillar": 120,
    "ledge": 121,
    "toaster oven": 122,
    "mouse": 123,
    "toilet seat cover dispenser": 124,
    "cart": 125,
    "scale": 126,
    "tissue box": 127,
    "light switch": 128,
    "crate": 129,
    "power outlet": 130,
    "decoration": 131,
    "sign": 132,
    "projector": 133,
    "closet door": 134,
    "vacuum cleaner": 135,
    "headphones": 136,
    "dish rack": 137,
    "broom": 138,
    "range hood": 139,
    "hair dryer": 140,
    "water bottle": 141,
    "vent": 142,
    "mailbox": 143,
    "bowl": 144,
    "paper bag": 145,
    "projector screen": 146,
    "divider": 147,
    "laundry detergent": 148,
    "bathroom counter": 149,
    "stick": 150,
    "bathroom vanity": 151,
    "closet wall": 152,
    "laundry hamper": 153,
    "bathroom stall door": 154,
    "ceiling light": 155,
    "trash bin": 156,
    "dumbbell": 157,
    "stair rail": 158,
    "tube": 159,
    "bathroom cabinet": 160,
    "coffee kettle": 161,
    "shower head": 162,
    "case of water bottles": 163,
    "power strip": 164,
    "calendar": 165,
    "poster": 166,
    "mattress": 167,
}


class ScanNetDataset(COCO3DDataset):
    """ScanNetV2 dataset."""

    def __init__(
        self,
        class_map: dict[str, int] = scannet_class_map,
        max_depth: float = 12.0,
        depth_scale: float = 1000.0,
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filename for an image.

        Not every sample comes with a depth map, hence the optional return.
        """
        return (
            img["file_path"].replace("image", "depth").replace(".jpg", ".png")
        )
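The ScanNet depth path is derived from the image path purely by string substitution; a small sketch with an illustrative path layout:

```python
# Mirrors ScanNetDataset.get_depth_filenames (path is hypothetical).
path = "scans/scene0000_00/image/000000.jpg"
depth = path.replace("image", "depth").replace(".jpg", ".png")
assert depth == "scans/scene0000_00/depth/000000.png"
```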
opendet3d/data/transforms/__init__.py ADDED
File without changes
opendet3d/data/transforms/crop.py ADDED
@@ -0,0 +1,43 @@
"""Crop transforms."""

from __future__ import annotations

from vis4d.common.typing import (
    NDArrayBool,
    NDArrayF32,
    NDArrayI64,
)
from vis4d.data.const import CommonKeys as K
from vis4d.data.transforms.base import Transform


@Transform(
    in_keys=[
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
        "transforms.crop.keep_mask",
    ],
    out_keys=[K.boxes3d, K.boxes3d_classes, K.boxes3d_track_ids],
)
class CropBoxes3D:
    """Crop 3D bounding boxes."""

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        keep_mask_list: list[NDArrayBool],
    ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
        """Crop 3D bounding boxes."""
        for i, (boxes, classes, keep_mask) in enumerate(
            zip(boxes_list, classes_list, keep_mask_list)
        ):
            boxes_list[i] = boxes[keep_mask]
            classes_list[i] = classes[keep_mask]

            if track_ids_list is not None:
                track_ids_list[i] = track_ids_list[i][keep_mask]

        return boxes_list, classes_list, track_ids_list
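A minimal sketch of the boolean-mask filtering `CropBoxes3D` applies, with illustrative labels for one sample:

```python
import numpy as np

# Three boxes, of which two survive a crop.
boxes3d = np.zeros((3, 10), dtype=np.float32)
classes = np.array([1, 4, 2], dtype=np.int64)
keep_mask = np.array([True, False, True])

# The crop's keep mask is applied to every per-box field in lockstep.
boxes3d, classes = boxes3d[keep_mask], classes[keep_mask]
assert boxes3d.shape[0] == 2 and classes.tolist() == [1, 2]
```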
opendet3d/data/transforms/language.py ADDED
@@ -0,0 +1,267 @@
"""Language related transforms."""

from __future__ import annotations

import random
import re

import numpy as np

try:  # transformers is an optional dependency
    from transformers import AutoTokenizer
except ImportError:
    AutoTokenizer = None

from vis4d.common.logging import rank_zero_warn
from vis4d.common.typing import NDArrayF32, NDArrayI64
from vis4d.data.const import CommonKeys as K
from vis4d.data.transforms.base import Transform


def clean_name(name: str) -> str:
    """Clean a category name for use inside a caption."""
    name = re.sub(r"\(.*\)", "", name)
    name = re.sub(r"_", " ", name)
    name = re.sub(r" {2,}", " ", name)  # collapse repeated spaces
    name = name.lower()
    return name


def generate_sentence_given_labels(
    positive_label_list: list[int],
    negative_label_list: list[int],
    label_map: dict[str, str],
) -> tuple[dict[int, list[list[int]]], str, dict[int, int]]:
    """Generate a sentence given positive and negative labels."""
    label_to_positions = {}

    label_list = negative_label_list + positive_label_list

    random.shuffle(label_list)

    pheso_caption = ""

    label_remap_dict = {}
    for index, label in enumerate(label_list):
        start_index = len(pheso_caption)

        pheso_caption += clean_name(label_map[str(label)])

        end_index = len(pheso_caption)

        # record the character span of each positive label in the caption
        if label in positive_label_list:
            label_to_positions[index] = [[start_index, end_index]]
            label_remap_dict[int(label)] = index

        pheso_caption += ". "

    return label_to_positions, pheso_caption, label_remap_dict


@Transform(
    [
        "dataset_type",
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_names,
        "label_map",
        "positive_positions",
    ],
    [K.boxes2d, K.boxes2d_classes, K.boxes2d_names, "tokens_positive"],
)
class RandomSamplingNegPos:
    """Randomly sample negative and positive labels for object detection."""

    def __init__(
        self,
        tokenizer_name: str = "bert-base-uncased",
        num_sample_negative: int = 85,
        max_tokens: int = 256,
        full_sampling_prob: float = 0.5,
    ) -> None:
        """Creates an instance of RandomSamplingNegPos."""
        if AutoTokenizer is None:
            raise RuntimeError(
                "transformers is not installed, please install it via: "
                "pip install transformers."
            )

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.num_sample_negative = num_sample_negative
        self.full_sampling_prob = full_sampling_prob
        self.max_tokens = max_tokens

    def __call__(
        self,
        dataset_type_list: list[str],
        boxes_list: list[NDArrayF32],
        class_ids_list: list[NDArrayI64],
        texts_list: list[str] | None = None,
        label_map_list: list[dict] | None = None,
        positive_positions_list: list[dict] | None = None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[str],
        list[dict[int, list[list[int]]]],
    ]:
        """Randomly sample negative and positive labels."""
        new_texts_list = []
        tokens_positive_list = []
        for i, (boxes, class_ids) in enumerate(
            zip(boxes_list, class_ids_list)
        ):
            if dataset_type_list[i] == "OD":
                assert (
                    label_map_list[i] is not None
                ), "label_map should not be None"
                boxes_list[i], class_ids_list[i], text, tokens_positive = (
                    self.od_aug(boxes, class_ids, label_map_list[i])
                )
                new_texts_list.append(text)
                tokens_positive_list.append(tokens_positive)
            else:
                assert (
                    positive_positions_list[i] is not None
                ), "positive_positions should not be None"
                tokens_positive = self.vg_aug(
                    class_ids, positive_positions_list[i]
                )
                new_texts_list.append(texts_list[i])
                tokens_positive_list.append(tokens_positive)

        return boxes_list, class_ids_list, new_texts_list, tokens_positive_list

    def vg_aug(
        self, class_ids: NDArrayI64, positive_positions: dict
    ) -> dict:
        """Visual Genome data augmentation."""
        positive_label_list = np.unique(class_ids).tolist()

        label_to_positions = {}
        for label in positive_label_list:
            label_to_positions[label] = positive_positions[label]

        return label_to_positions

    def od_aug(
        self,
        boxes: NDArrayF32,
        class_ids: NDArrayI64,
        label_map: dict,
    ) -> tuple[NDArrayF32, NDArrayI64, str, dict[int, list[list[int]]]]:
        """Object detection data augmentation."""
        original_box_num = len(class_ids)

        # If the category name is in the format of 'a/b' (in Objects365),
        # we randomly select one of them.
        for key, value in label_map.items():
            if "/" in value:
                label_map[key] = random.choice(value.split("/")).strip()

        keep_box_index, class_ids, positive_caption_length = (
            self.check_for_positive_overflow(class_ids, label_map)
        )

        boxes = boxes[keep_box_index]

        if len(boxes) < original_box_num:
            rank_zero_warn(
                f"Remove {original_box_num - len(boxes)} boxes due to "
                "positive caption overflow."
            )

        valid_negative_indexes = list(label_map.keys())

        positive_label_list = np.unique(class_ids).tolist()

        full_negative = self.num_sample_negative
        if full_negative > len(valid_negative_indexes):
            full_negative = len(valid_negative_indexes)

        outer_prob = random.random()

        if outer_prob < self.full_sampling_prob:
            # probability_full: add all positives and all negatives
            num_negatives = full_negative
        else:
            # otherwise sample a random number of negatives in
            # [1, full_negative]
            num_negatives = np.random.choice(max(1, full_negative)) + 1

        # Keep some negatives
        negative_label_list = set()
        if num_negatives != -1:
            if num_negatives > len(valid_negative_indexes):
                num_negatives = len(valid_negative_indexes)

            for i in np.random.choice(
                valid_negative_indexes, size=num_negatives, replace=False
            ):
                if int(i) not in positive_label_list:
                    negative_label_list.add(i)

        random.shuffle(positive_label_list)

        negative_label_list = list(negative_label_list)
        random.shuffle(negative_label_list)

        # Drop negatives that would push the caption past the token budget.
        negative_max_length = self.max_tokens - positive_caption_length
        screened_negative_label_list = []

        for negative_label in negative_label_list:
            label_text = clean_name(label_map[str(negative_label)]) + ". "

            tokenized = self.tokenizer.tokenize(label_text)

            negative_max_length -= len(tokenized)

            if negative_max_length > 0:
                screened_negative_label_list.append(negative_label)
            else:
                break

        negative_label_list = screened_negative_label_list
        label_to_positions, pheso_caption, label_remap_dict = (
            generate_sentence_given_labels(
                positive_label_list, negative_label_list, label_map
            )
        )

        # remap labels to their position in the shuffled caption
        if len(class_ids) > 0:
            class_ids = np.vectorize(lambda x: label_remap_dict[x])(class_ids)

        return boxes, class_ids, pheso_caption, label_to_positions

    def check_for_positive_overflow(
        self, class_ids: NDArrayI64, label_map: dict[str, str]
    ) -> tuple[list[int], NDArrayI64, int]:
        """Check whether there are too many positive labels.

        Builds a caption by appending the positive labels and drops labels
        (and their boxes) once the tokenized length exceeds max_tokens.
        """
        positive_label_list = np.unique(class_ids).tolist()

        # random shuffle so we can sample different annotations
        # at different epochs
        random.shuffle(positive_label_list)

        kept_labels = []
        length = 0
        for label in positive_label_list:
            label_text = clean_name(label_map[str(label)]) + ". "

            tokenized = self.tokenizer.tokenize(label_text)

            length += len(tokenized)

            if length > self.max_tokens:
                break
            else:
                kept_labels.append(label)

        keep_box_index = []
        keep_gt_labels = []
        for i, class_id in enumerate(class_ids):
            if class_id in kept_labels:
                keep_box_index.append(i)
                keep_gt_labels.append(class_id)

        return (
            keep_box_index,
            np.array(keep_gt_labels, dtype=np.int64),
            length,
        )
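A minimal sketch of the caption the OD branch builds, with shuffling omitted for determinism and an illustrative label map (the `split("/")[0]` stands in for the random 'a/b' choice above):

```python
label_map = {"0": "chair", "1": "sofa/couch", "2": "table"}
positives, negatives = [0], [2]

caption, tokens_positive = "", {}
for index, label in enumerate(negatives + positives):
    start = len(caption)
    caption += label_map[str(label)].split("/")[0]
    if label in positives:
        # character span of each positive phrase inside the caption
        tokens_positive[index] = [[start, len(caption)]]
    caption += ". "

print(caption)          # "table. chair. "
print(tokens_positive)  # {1: [[7, 12]]} -> span of "chair"
```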
opendet3d/data/transforms/pad.py ADDED
@@ -0,0 +1,176 @@
"""Pad transformation."""

from __future__ import annotations

from typing import TypedDict

import torch
import torch.nn.functional as F
from vis4d.common.typing import NDArrayF32
from vis4d.data.const import CommonKeys as K
from vis4d.data.transforms.base import Transform
from vis4d.data.transforms.pad import _get_max_shape


class PadParam(TypedDict):
    """Parameters for padding."""

    pad_top: int
    pad_bottom: int
    pad_left: int
    pad_right: int


@Transform(
    [K.images, K.input_hw],
    [K.images, "transforms.pad", K.input_hw, "padding"],
)
class CenterPadImages:
    """Center-pad a batch of images, splitting the padding evenly."""

    def __init__(
        self,
        stride: int = 32,
        mode: str = "constant",
        value: float = 0.0,
        update_input_hw: bool = False,
        shape: tuple[int, int] | None = None,
        pad2square: bool = False,
    ) -> None:
        """Creates an instance of CenterPadImages.

        Args:
            stride (int, optional): Chooses padding size so that the input
                will be divisible by stride. Defaults to 32.
            mode (str, optional): Padding mode. One of constant, reflect,
                replicate or circular. Defaults to "constant".
            value (float, optional): Value for constant padding.
                Defaults to 0.0.
            update_input_hw (bool, optional): Whether to overwrite input_hw
                with the padded shape. Defaults to False.
            shape (tuple[int, int], optional): Shape of the padded image
                (H, W). Defaults to None.
            pad2square (bool, optional): Pad to square. Defaults to False.
        """
        self.stride = stride
        self.mode = mode
        self.value = value
        self.update_input_hw = update_input_hw
        self.shape = shape
        self.pad2square = pad2square

    def __call__(
        self, images: list[NDArrayF32], input_hw: list[tuple[int, int]]
    ) -> tuple[
        list[NDArrayF32],
        list[PadParam],
        list[tuple[int, int]],
        list[list[int]],
    ]:
        """Pad images to a consistent size."""
        heights = [im.shape[1] for im in images]
        widths = [im.shape[2] for im in images]

        max_hw = _get_max_shape(
            heights, widths, self.stride, self.shape, self.pad2square
        )

        # generate params for torch pad
        pad_params = []
        target_input_hw = []
        paddings = []
        for i, (image, h, w) in enumerate(zip(images, heights, widths)):
            # split the padding evenly between the two sides; the extra
            # pixel of an odd difference goes to the bottom/right
            pad_h = max_hw[0] - h
            pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2

            pad_w = max_hw[1] - w
            pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2

            image_ = torch.from_numpy(image).permute(0, 3, 1, 2)
            image_ = F.pad(
                image_,
                (pad_left, pad_right, pad_top, pad_bottom),
                self.mode,
                self.value,
            )
            images[i] = image_.permute(0, 2, 3, 1).numpy()

            pad_params.append(
                PadParam(
                    pad_top=pad_top,
                    pad_bottom=pad_bottom,
                    pad_left=pad_left,
                    pad_right=pad_right,
                )
            )

            paddings.append([pad_left, pad_right, pad_top, pad_bottom])

            target_input_hw.append(max_hw)

        if self.update_input_hw:
            input_hw = target_input_hw

        return images, pad_params, input_hw, paddings


@Transform([K.intrinsics, "transforms.pad"], K.intrinsics)
class CenterPadIntrinsics:
    """Adjust camera intrinsics for center padding."""

    def __call__(
        self, intrinsics: list[NDArrayF32], pad_params: list[PadParam]
    ) -> list[NDArrayF32]:
        """Shift the principal point by the left/top padding."""
        for i, intrinsic in enumerate(intrinsics):
            intrinsic[0, 2] += pad_params[i]["pad_left"]
            intrinsic[1, 2] += pad_params[i]["pad_top"]

            intrinsics[i] = intrinsic
        return intrinsics


@Transform([K.boxes2d, "transforms.pad"], K.boxes2d)
class CenterPadBoxes2D:
    """Shift a batch of 2D boxes to account for center padding."""

    def __call__(
        self, boxes_list: list[NDArrayF32], pad_params: list[PadParam]
    ) -> list[NDArrayF32]:
        """Shift 2D boxes by the left/top padding offsets."""
        for i, boxes in enumerate(boxes_list):
            boxes[:, 0] += pad_params[i]["pad_left"]
            boxes[:, 1] += pad_params[i]["pad_top"]
            boxes[:, 2] += pad_params[i]["pad_left"]
            boxes[:, 3] += pad_params[i]["pad_top"]

            boxes_list[i] = boxes

        return boxes_list


@Transform([K.depth_maps, "transforms.pad"], K.depth_maps)
class CenterPadDepthMaps:
    """Center-pad a batch of depth maps."""

    def __init__(self, mode: str = "constant", value: int = 0) -> None:
        """Creates an instance."""
        self.mode = mode
        self.value = value

    def __call__(
        self, depth_maps: list[NDArrayF32], pad_params: list[PadParam]
    ) -> list[NDArrayF32]:
        """Pad depth maps to a consistent size."""
        # generate params for torch pad
        for i, (depth, pad_param_dict) in enumerate(
            zip(depth_maps, pad_params)
        ):
            pad_param = (
                pad_param_dict["pad_left"],
                pad_param_dict["pad_right"],
                pad_param_dict["pad_top"],
                pad_param_dict["pad_bottom"],
            )

            depth_ = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0)
            depth_ = F.pad(depth_, pad_param, self.mode, self.value)
            depth_maps[i] = depth_.squeeze(0).squeeze(0).numpy()

        return depth_maps
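A worked example of the center-padding arithmetic and the matching intrinsics shift (the image and camera numbers are illustrative):

```python
# Padding a 480x640 image to 512x672 with center padding.
h, w, target = 480, 640, (512, 672)

pad_h, pad_w = target[0] - h, target[1] - w
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2  # 16, 16
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2  # 16, 16

# The principal point moves with the image content, so CenterPadIntrinsics
# adds the left/top offsets to (cx, cy); focal lengths are unchanged.
cx, cy = 320.0, 240.0
cx, cy = cx + pad_left, cy + pad_top
assert (cx, cy) == (336.0, 256.0)
```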
opendet3d/data/transforms/resize.py ADDED
@@ -0,0 +1,121 @@
"""Resize transformation."""

from __future__ import annotations

import math

import numpy as np
import torch
from vis4d.common.typing import NDArrayF32, NDArrayI64
from vis4d.data.const import CommonKeys as K
from vis4d.data.transforms.base import Transform
from vis4d.data.transforms.resize import ResizeParam, resize_tensor


@Transform(K.images, ["transforms.resize", K.input_hw])
class GenResizeParameters:
    """Generate the parameters for an aspect-ratio-preserving resize."""

    def __init__(
        self, shape: tuple[int, int], scales: tuple[float, float] | float = 1.0
    ) -> None:
        """Create a new instance of the class."""
        self.shape = shape
        self.scales = scales

    def __call__(
        self, images: list[NDArrayF32]
    ) -> tuple[list[ResizeParam], list[tuple[int, int]]]:
        """Compute the parameters and put them in the data dict."""
        if isinstance(self.scales, float):
            random_scale = self.scales
        else:
            random_scale = np.random.uniform(self.scales[0], self.scales[1])

        shape = (
            math.ceil(self.shape[0] * random_scale - 0.5),
            math.ceil(self.shape[1] * random_scale - 0.5),
        )

        output_ratio = shape[1] / shape[0]

        image = images[0]

        input_h, input_w = (image.shape[1], image.shape[2])
        input_ratio = input_w / input_h

        # fit the image inside the target window: the dimension that hits
        # its bound first determines the scale
        if output_ratio > input_ratio:
            scale = shape[0] / input_h
        else:
            scale = shape[1] / input_w

        target_shape = (
            math.ceil(input_h * scale - 0.5),
            math.ceil(input_w * scale - 0.5),
        )

        scale_factor = (target_shape[0] / input_h, target_shape[1] / input_w)

        resize_params = [
            ResizeParam(target_shape=target_shape, scale_factor=scale_factor)
        ] * len(images)
        target_shapes = [target_shape] * len(images)

        return resize_params, target_shapes


@Transform(
    [K.panoptic_masks, "transforms.resize.target_shape"], K.panoptic_masks
)
class ResizePanopticMasks:
    """Resize panoptic segmentation masks."""

    def __call__(
        self,
        masks_list: list[NDArrayI64],
        target_shape_list: list[tuple[int, int]],
    ) -> list[NDArrayI64]:
        """Resize masks with nearest-neighbor interpolation."""
        for i, (masks, target_shape) in enumerate(
            zip(masks_list, target_shape_list)
        ):
            masks_ = torch.from_numpy(masks)
            masks_ = (
                resize_tensor(
                    masks_.float().unsqueeze(0).unsqueeze(0),
                    target_shape,
                    interpolation="nearest",
                )
                .type(masks_.dtype)
                .squeeze(0)
                .squeeze(0)
            )
            masks_list[i] = masks_.numpy()
        return masks_list


@Transform([K.boxes3d, "transforms.resize.scale_factor"], K.boxes3d)
class ResizeBoxes3D:
    """Resize a list of 3D bounding boxes."""

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        scale_factors: list[tuple[float, float]],
    ) -> list[NDArrayF32]:
        """Adjust 3D bounding boxes to an image resize.

        Only the depth (z) component is rescaled by the inverse of the
        vertical scale factor.

        Args:
            boxes_list (list[NDArrayF32]): The bounding boxes to be resized.
            scale_factors (list[tuple[float, float]]): Scaling factors.

        Returns:
            list[NDArrayF32]: Adjusted bounding boxes according to the
                resize parameters.
        """
        for i, (boxes, scale_factor) in enumerate(
            zip(boxes_list, scale_factors)
        ):
            boxes[:, 2] /= scale_factor[0]
            boxes_list[i] = boxes
        return boxes_list
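A worked example of the keep-aspect-ratio fit in `GenResizeParameters`, with scale jitter disabled (i.e. 1.0) and illustrative image sizes:

```python
import math

# Resize a 720x1280 image toward an 800x1333 window.
shape, (input_h, input_w) = (800, 1333), (720, 1280)

if shape[1] / shape[0] > input_w / input_h:
    scale = shape[0] / input_h  # height is the binding constraint
else:
    scale = shape[1] / input_w  # width is the binding constraint

target = (
    math.ceil(input_h * scale - 0.5),
    math.ceil(input_w * scale - 0.5),
)
print(target)  # (750, 1333): width hits 1333 first, height stays in bounds
```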
opendet3d/eval/__init__.py ADDED
File without changes
opendet3d/eval/detect3d.py ADDED
@@ -0,0 +1,1249 @@
"""3D Multiple Object Detection Evaluator."""

import contextlib
import copy
import datetime
import io
import itertools
import json
import os
import time
from collections import defaultdict

import numpy as np
import pycocotools.mask as maskUtils
import torch
from pycocotools.cocoeval import COCOeval
from scipy.spatial.distance import cdist
from terminaltables import AsciiTable
from vis4d.common.array import array_to_numpy
from vis4d.common.distributed import all_gather_object_cpu
from vis4d.common.typing import (
    ArrayLike,
    DictStrAny,
    GenericFunc,
    MetricLogs,
    NDArrayF32,
    NDArrayI64,
)
from vis4d.data.const import AxisMode
from vis4d.eval.base import Evaluator
from vis4d.eval.coco.detect import xyxy_to_xywh
from vis4d.op.box.box3d import boxes3d_to_corners
from vis4d.op.geometry.rotation import quaternion_to_matrix

from opendet3d.data.datasets.coco3d import COCO3D
from opendet3d.op.box.box3d import box3d_overlap
from opendet3d.op.geometric.rotation import so3_relative_angle


class Detect3DEvaluator(Evaluator):
    """3D object detection evaluation in COCO format."""

    def __init__(
        self,
        det_map: dict[str, int],
        cat_map: dict[str, int],
        annotation: str,
        id2name: dict[int, str] | None = None,
        per_class_eval: bool = True,
        eval_prox: bool = False,
        iou_type: str = "bbox",
        num_columns: int = 6,
        base_classes: list[str] | None = None,
    ) -> None:
        """Create an instance of the class."""
        if id2name is None:
            self.id2name = {v: k for k, v in det_map.items()}
        else:
            self.id2name = id2name

        self.annotation = annotation
        self.per_class_eval = per_class_eval
        self.eval_prox = eval_prox
        self.iou_type = iou_type
        self.num_columns = num_columns
        self.base_classes = base_classes

        self.tp_errors = ["ATE", "AOE", "ASE"]

        category_names = sorted(det_map, key=det_map.get)

        with contextlib.redirect_stdout(io.StringIO()):
            self._coco_gt = COCO3D([annotation], category_names)

        self.cat_map = cat_map

        self.bbox_2D_evals_per_cat_area: DictStrAny = {}
        self.bbox_3D_evals_per_cat_area: DictStrAny = {}
        self._predictions: list[DictStrAny] = []

    def __repr__(self) -> str:
        """Return the string representation of the object."""
        return f"3D Object Detection Evaluator with {self.annotation}"

    @property
    def metrics(self) -> list[str]:
        """Supported metrics.

        Returns:
            list[str]: Metrics to evaluate.
        """
        return ["2D", "3D"]

    def gather(self, gather_func: GenericFunc) -> None:
        """Accumulate predictions across processes."""
        all_preds = all_gather_object_cpu(
            self._predictions, use_system_tmp=False
        )
        if all_preds is not None:
            self._predictions = list(itertools.chain(*all_preds))

    def reset(self) -> None:
        """Reset the saved predictions to start a new round of evaluation."""
        self._predictions.clear()
        self.bbox_2D_evals_per_cat_area.clear()
        self.bbox_3D_evals_per_cat_area.clear()

    def process_batch(
        self,
        coco_image_id: list[int],
        pred_boxes: list[ArrayLike],
        pred_scores: list[ArrayLike],
        pred_classes: list[ArrayLike],
        pred_boxes3d: list[ArrayLike] | None = None,
    ) -> None:
        """Process a batch and convert detections to COCO format."""
        for i, image_id in enumerate(coco_image_id):
            boxes = array_to_numpy(
                pred_boxes[i].to(torch.float32), n_dims=None, dtype=np.float32
            )
            scores = array_to_numpy(
                pred_scores[i].to(torch.float32), n_dims=None, dtype=np.float32
            )
            classes = array_to_numpy(
                pred_classes[i], n_dims=None, dtype=np.int64
            )

            if pred_boxes3d is not None:
                boxes3d = array_to_numpy(
                    pred_boxes3d[i].to(torch.float32),
                    n_dims=None,
                    dtype=np.float32,
                )
            else:
                boxes3d = None

            self._predictions_to_coco(
                image_id, boxes, boxes3d, scores, classes
            )

    def _predictions_to_coco(
        self,
        img_id: int,
        boxes: NDArrayF32,
        boxes3d: NDArrayF32 | None,
        scores: NDArrayF32,
        classes: NDArrayI64,
    ) -> None:
        """Convert predictions to COCO format."""
        boxes_xyxy = copy.deepcopy(boxes)
        boxes_xywh = xyxy_to_xywh(boxes_xyxy)

        if boxes3d is not None:
            # FIXME: Make axis mode configurable
            corners_3d = boxes3d_to_corners(
                torch.from_numpy(boxes3d), AxisMode.OPENCV
            )

        for i, (box, box_score, box_class) in enumerate(
            zip(boxes_xywh, scores, classes)
        ):
            xywh = box.tolist()

            result = {
                "image_id": img_id,
                "bbox": xywh,
                "category_id": self.cat_map[self.id2name[box_class.item()]],
                "score": box_score.item(),
            }

            # mapping to the Omni3D format
            if boxes3d is not None:
                result["center_cam"] = boxes3d[i][:3].tolist()

                # wlh to whl
                result["dimensions"] = boxes3d[i][[3, 5, 4]].tolist()

                result["R_cam"] = (
                    quaternion_to_matrix(torch.from_numpy(boxes3d[i][6:10]))
                    .numpy()
                    .tolist()
                )

                corners = corners_3d[i].numpy().tolist()

                # reorder corners to the Omni3D corner convention
                result["bbox3D"] = [
                    corners[6],
                    corners[4],
                    corners[0],
                    corners[2],
                    corners[7],
                    corners[5],
                    corners[1],
                    corners[3],
                ]

                result["depth"] = boxes3d[i][2].item()

            self._predictions.append(result)

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate predictions."""
        if metric == "2D":
            metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"]
            main_metric = "AP"
        else:
            if self.iou_type == "bbox":
                metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"]
                main_metric = "AP"
            else:
                metrics = ["AP", "ATE", "ASE", "AOE", "ODS"]
                main_metric = "ODS"

        if self.base_classes is not None:
            metrics += [f"{main_metric}_Base", f"{main_metric}_Novel"]

        if len(self._predictions) == 0:
            return {m: 0.0 for m in metrics}, "No predictions to evaluate."

        with contextlib.redirect_stdout(io.StringIO()):
            coco_dt = self._coco_gt.loadRes(self._predictions)

        assert coco_dt is not None
        evaluator = Detect3Deval(
            self._coco_gt,
            coco_dt,
            mode=metric,
            eval_prox=self.eval_prox,
            iou_type=self.iou_type,
        )
        evaluator.evaluate()
        evaluator.accumulate()

        if self.iou_type == "bbox":
            log_str = "\n" + evaluator.summarize()

        # precision has shape (iou, recall, cls, area range, max dets)
        precisions = evaluator.eval["precision"]
        assert len(self._coco_gt.getCatIds()) == precisions.shape[2]

        if metric == "2D":
            self.bbox_2D_evals_per_cat_area = evaluator.evals_per_cat_area

            score_dict = dict(zip(metrics, evaluator.stats))
        else:
            if self.iou_type == "bbox":
                self.bbox_3D_evals_per_cat_area = evaluator.evals_per_cat_area

                score_dict = dict(zip(metrics, evaluator.stats))
            else:
                trans_tp_errors = evaluator.eval["trans_tp_errors"]
                rot_tp_errors = evaluator.eval["rot_tp_errors"]
                scale_tp_errors = evaluator.eval["scale_tp_errors"]

                precision = precisions[:, :, :, 0, -1]
                precision = precision[precision > -1]
                if precision.size:
                    mAP = np.mean(precision).item()
                else:
                    mAP = float("nan")

                trans_tp = trans_tp_errors[:, :, :, 0, -1]
                trans_tp = trans_tp[trans_tp > -1]

                rot_tp = rot_tp_errors[:, :, :, 0, -1]
                rot_tp = rot_tp[rot_tp > -1]

                scale_tp = scale_tp_errors[:, :, :, 0, -1]
                scale_tp = scale_tp[scale_tp > -1]

                if trans_tp.size:
                    mATE = np.mean(trans_tp).item()
                    mAOE = np.mean(rot_tp).item()
                    mASE = np.mean(scale_tp).item()

                    # ODS: AP weighted 3x, averaged with the complements
                    # of the three TP errors.
                    mODS = (
                        3 * mAP + (1 - mATE) + (1 - mAOE) + (1 - mASE)
                    ) / 6
                else:
                    mATE = float("nan")
                    mAOE = float("nan")
                    mASE = float("nan")
                    mODS = float("nan")

                score_dict = {
                    "AP": mAP,
                    "ATE": mATE,
                    "ASE": mASE,
                    "AOE": mAOE,
                    "ODS": mODS,
                }

                log_str = "\nHigh-level metrics:"
                for k, v in score_dict.items():
                    log_str += f"\n{k}: {v:.4f}"

        if self.per_class_eval:
            results_per_category = []
            score_base_list = []
            score_novel_list = []

            for idx, cat_id in enumerate(self._coco_gt.getCatIds()):
                # area range index 0: all area ranges
                # max dets index -1: typically 100 per image
                nm = self._coco_gt.loadCats(cat_id)[0]
                precision = precisions[:, :, idx, 0, -1]
                precision = precision[precision > -1]
                if precision.size:
                    ap = np.mean(precision).item()
                else:
                    ap = float("nan")

                if self.iou_type == "dist":
                    trans_tp = trans_tp_errors[:, :, idx, 0, -1]
                    trans_tp = trans_tp[trans_tp > -1]

                    rot_tp = rot_tp_errors[:, :, idx, 0, -1]
                    rot_tp = rot_tp[rot_tp > -1]

                    scale_tp = scale_tp_errors[:, :, idx, 0, -1]
                    scale_tp = scale_tp[scale_tp > -1]

                    if trans_tp.size:
                        ate = np.mean(trans_tp).item()
                        aoe = np.mean(rot_tp).item()
                        ase = np.mean(scale_tp).item()

                        ods = (
                            3 * ap + (1 - ate) + (1 - aoe) + (1 - ase)
                        ) / 6
                    else:
                        ate = float("nan")
                        aoe = float("nan")
                        ase = float("nan")
                        ods = float("nan")

                    results_per_category.append(
                        (
                            nm["name"],
                            f"{ap:0.3f}",
                            f"{ate:0.3f}",
                            f"{ase:0.3f}",
                            f"{aoe:0.3f}",
                            f"{ods:0.3f}",
                        )
                    )
                else:
                    results_per_category.append(
                        (nm["name"], f"{ap:0.3f}")
                    )

                if self.base_classes is not None:
                    score = ods if self.iou_type == "dist" else ap

                    if nm["name"] in self.base_classes:
                        score_base_list.append(score)
                    else:
                        score_novel_list.append(score)

            results_flatten = list(itertools.chain(*results_per_category))

            if self.iou_type == "dist":
                num_columns = 6
                headers = ["category", "AP", "ATE", "ASE", "AOE", "ODS"]
            else:
                num_columns = min(
                    self.num_columns, len(results_per_category) * 2
                )
                headers = ["category", "AP"] * (num_columns // 2)
            results = itertools.zip_longest(
                *[results_flatten[i::num_columns] for i in range(num_columns)]
            )
            table_data = [headers] + list(results)
            table = AsciiTable(table_data)
            log_str = f"\n{table.table}\n{log_str}"

            if self.base_classes is not None:
                score_dict[f"{main_metric}_Base"] = np.mean(
                    score_base_list
                ).item()
                score_dict[f"{main_metric}_Novel"] = np.mean(
                    score_novel_list
                ).item()

        return score_dict, log_str

    def save(
        self, metric: str, output_dir: str, prefix: str | None = None
    ) -> None:
        """Save the results to json files."""
        assert metric in self.metrics

        if prefix is not None:
            result_folder = os.path.join(output_dir, prefix)
            os.makedirs(result_folder, exist_ok=True)
        else:
            result_folder = output_dir

        result_file = os.path.join(
            result_folder, f"detect_{metric}_results.json"
        )

        with open(result_file, mode="w", encoding="utf-8") as f:
            json.dump(self._predictions, f)

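# Worked example of the ODS score computed above (illustrative numbers, not
# from any dataset): ODS averages AP (weighted 3x) with the complements of
# the translation, scale, and orientation errors of the true positives:
#
#   AP = 0.4, ATE = ASE = AOE = 0.4
#   ODS = (3 * 0.4 + (1 - 0.4) + (1 - 0.4) + (1 - 0.4)) / 6 = 0.5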
411
+ class Detect3Deval(COCOeval):
412
+ """COCOeval Wrapper for 2D and 3D box evaluation.
413
+
414
+ Now it support bbox IoU matching only.
415
+ """
416
+
417
+ def __init__(
418
+ self,
419
+ cocoGt=None,
420
+ cocoDt=None,
421
+ mode: str = "2D",
422
+ iou_type: str = "bbox",
423
+ eval_prox: bool = False,
424
+ ):
425
+ """Initialize Detect3Deval using coco APIs for Gt and Dt.
426
+
427
+ Args:
428
+ cocoGt: COCO object with ground truth annotations
429
+ cocoDt: COCO object with detection results
430
+ mode: (str) defines whether to evaluate 2D or 3D performance.
431
+ One of {"2D", "3D"}
432
+ eval_prox: (bool) if True, performs "Proximity Evaluation", i.e.
433
+ evaluates detections in the proximity of the ground truth2D
434
+ boxes. This is used for datasets which are not exhaustively
435
+ annotated.
436
+ """
437
+ if mode not in {"2D", "3D"}:
438
+ raise Exception(f"{mode} mode is not supported")
439
+ self.mode = mode
440
+ self.iou_type = iou_type
441
+ self.eval_prox = eval_prox
442
+
443
+ self.cocoGt = cocoGt # ground truth COCO API
444
+ self.cocoDt = cocoDt # detections COCO API
445
+
446
+ # per-image per-category evaluation results [KxAxI] elements
447
+ self.evalImgs = defaultdict(list)
448
+
449
+ self.eval = {} # accumulated evaluation results
450
+ self._gts = defaultdict(list) # gt for evaluation
451
+ self._dts = defaultdict(list) # dt for evaluation
452
+ self.params = Detect3DParams(mode=mode, iouType=iou_type) # parameters
453
+ self._paramsEval = {} # parameters for evaluation
454
+ self.stats = [] # result summarization
455
+ self.ious = {} # ious between all gts and dts
456
+
457
+ if cocoGt is not None:
458
+ self.params.imgIds = sorted(cocoGt.getImgIds())
459
+ self.params.catIds = sorted(cocoGt.getCatIds())
460
+
461
+ self.evals_per_cat_area = None
462
+
463
+ def _prepare(self) -> None:
464
+ """Prepare ._gts and ._dts for evaluation based on params."""
465
+ p = self.params
466
+
467
+ if p.useCats:
468
+ gts = self.cocoGt.loadAnns(
469
+ self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
470
+ )
471
+ dts = self.cocoDt.loadAnns(
472
+ self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
473
+ )
474
+
475
+ else:
476
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
477
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
478
+
479
+ # set ignore flag
480
+ ignore_flag = "ignore2D" if self.mode == "2D" else "ignore3D"
481
+ for gt in gts:
482
+ gt[ignore_flag] = gt[ignore_flag] if ignore_flag in gt else 0
483
+
484
+ self._gts = defaultdict(list) # gt for evaluation
485
+ self._dts = defaultdict(list) # dt for evaluation
486
+
487
+ for gt in gts:
488
+ self._gts[gt["image_id"], gt["category_id"]].append(gt)
489
+
490
+ for dt in dts:
491
+ self._dts[dt["image_id"], dt["category_id"]].append(dt)
492
+
493
+ self.evalImgs = defaultdict(
494
+ list
495
+ ) # per-image per-category evaluation results
496
+ self.eval = {} # accumulated evaluation results
497
+
498
+ def accumulate(self, p=None) -> None:
499
+ """Accumulate per image evaluation and store the result in self.eval.
500
+
501
+ Args:
502
+ p: input params for evaluation
503
+ """
504
+ print("Accumulating evaluation results...")
505
+ assert self.evalImgs, "Please run evaluate() first"
506
+
507
+ tic = time.time()
508
+
509
+ # allows input customized parameters
510
+ if p is None:
511
+ p = self.params
512
+
513
+ p.catIds = p.catIds if p.useCats == 1 else [-1]
514
+
515
+ T = len(p.iouThrs)
516
+ R = len(p.recThrs)
517
+ K = len(p.catIds) if p.useCats else 1
518
+ A = len(p.areaRng)
519
+ M = len(p.maxDets)
520
+
521
+ precision = -np.ones(
522
+ (T, R, K, A, M)
523
+ ) # -1 for the precision of absent categories
524
+ trans_tp_errors = -np.ones((T, R, K, A, M))
525
+ rot_tp_errors = -np.ones((T, R, K, A, M))
526
+ scale_tp_errors = -np.ones((T, R, K, A, M))
527
+ recall = -np.ones((T, K, A, M))
528
+ scores = -np.ones((T, R, K, A, M))
529
+
530
+ # create dictionary for future indexing
531
+ _pe = self._paramsEval
532
+
533
+ catIds = _pe.catIds if _pe.useCats else [-1]
534
+ setK = set(catIds)
535
+ setA = set(map(tuple, _pe.areaRng))
536
+ setM = set(_pe.maxDets)
537
+ setI = set(_pe.imgIds)
538
+
539
+ # get inds to evaluate
540
+ catid_list = [k for n, k in enumerate(p.catIds) if k in setK]
541
+ k_list = [n for n, k in enumerate(p.catIds) if k in setK]
542
+ m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
543
+ a_list = [
544
+ n
545
+ for n, a in enumerate(map(lambda x: tuple(x), p.areaRng))
546
+ if a in setA
547
+ ]
548
+ i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
549
+
550
+ I0 = len(_pe.imgIds)
551
+ A0 = len(_pe.areaRng)
552
+
553
+ has_precomputed_evals = not (self.evals_per_cat_area is None)
554
+
555
+ if has_precomputed_evals:
556
+ evals_per_cat_area = self.evals_per_cat_area
557
+ else:
558
+ evals_per_cat_area = {}
559
+
560
+ # retrieve E at each category, area range, and max number of detections
561
+ for k, (k0, catId) in enumerate(zip(k_list, catid_list)):
562
+ Nk = k0 * A0 * I0
563
+ for a, a0 in enumerate(a_list):
564
+ Na = a0 * I0
565
+
566
+ if has_precomputed_evals:
567
+ E = evals_per_cat_area[(catId, a)]
568
+
569
+ else:
570
+ E = [self.evalImgs[Nk + Na + i] for i in i_list]
571
+ E = [e for e in E if not e is None]
572
+ evals_per_cat_area[(catId, a)] = E
573
+
574
+ if len(E) == 0:
575
+ continue
576
+
577
+ for m, maxDet in enumerate(m_list):
578
+
579
+ dtScores = np.concatenate(
580
+ [e["dtScores"][0:maxDet] for e in E]
581
+ )
582
+
583
+ # different sorting method generates slightly different results.
584
+ # mergesort is used to be consistent as Matlab implementation.
585
+ inds = np.argsort(-dtScores, kind="mergesort")
586
+ dtScoresSorted = dtScores[inds]
587
+
588
+ dtm = np.concatenate(
589
+ [e["dtMatches"][:, 0:maxDet] for e in E], axis=1
590
+ )[:, inds]
591
+ dtIg = np.concatenate(
592
+ [e["dtIgnore"][:, 0:maxDet] for e in E], axis=1
593
+ )[:, inds]
594
+ gtIg = np.concatenate([e["gtIgnore"] for e in E])
595
+ npig = np.count_nonzero(gtIg == 0)
596
+
597
+ if npig == 0:
598
+ continue
599
+
600
+ tps = np.logical_and(dtm, np.logical_not(dtIg))
601
+ fps = np.logical_and(
602
+ np.logical_not(dtm), np.logical_not(dtIg)
603
+ )
604
+
605
+ tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
606
+ fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
607
+
608
+ # Compute TP error
609
+ if self.iou_type == "dist":
610
+ tems = np.concatenate(
611
+ [e["dtTranslationError"][:, 0:maxDet] for e in E],
612
+ axis=1,
613
+ )[:, inds]
614
+
615
+ oems = np.concatenate(
616
+ [e["dtOrientationError"][:, 0:maxDet] for e in E],
617
+ axis=1,
618
+ )[:, inds]
619
+
620
+ sems = np.concatenate(
621
+ [e["dtScaleError"][:, 0:maxDet] for e in E], axis=1
622
+ )[:, inds]
623
+
624
+ for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
625
+ tp = np.array(tp)
626
+ fp = np.array(fp)
627
+ nd = len(tp)
628
+ rc = tp / npig
629
+ pr = tp / (fp + tp + np.spacing(1))
630
+
631
+ q = np.zeros((R,))
632
+ ss = np.zeros((R,))
633
+ tran_tp_error = np.ones((R,))
634
+ rot_tp_error = np.ones((R,))
635
+ scale_tp_error = np.ones((R,))
636
+
637
+ if nd:
638
+ recall[t, k, a, m] = rc[-1]
639
+
640
+ else:
641
+ recall[t, k, a, m] = 0
642
+
643
+ # numpy is slow without cython optimization for accessing elements
644
+ # use python array gets significant speed improvement
645
+ pr = pr.tolist()
646
+ q = q.tolist()
647
+ tran_tp_error = tran_tp_error.tolist()
648
+ rot_tp_error = rot_tp_error.tolist()
649
+ scale_tp_error = scale_tp_error.tolist()
650
+
651
+ for i in range(nd - 1, 0, -1):
652
+ if pr[i] > pr[i - 1]:
653
+ pr[i - 1] = pr[i]
654
+
655
+ inds = np.searchsorted(rc, p.recThrs, side="left")
656
+
657
+ try:
658
+ for ri, pi in enumerate(inds):
659
+ q[ri] = pr[pi]
660
+ ss[ri] = dtScoresSorted[pi]
661
+ if self.iou_type == "dist":
662
+ tran_tp_error[ri] = tems[t][pi]
663
+ rot_tp_error[ri] = oems[t][pi]
664
+ scale_tp_error[ri] = sems[t][pi]
665
+ except:
666
+ pass
667
+
668
+ precision[t, :, k, a, m] = np.array(q)
669
+ scores[t, :, k, a, m] = np.array(ss)
670
+
671
+ if self.iou_type == "dist":
672
+ trans_tp_errors[t, :, k, a, m] = np.array(
673
+ tran_tp_error
674
+ )
675
+ rot_tp_errors[t, :, k, a, m] = np.array(
676
+ rot_tp_error
677
+ )
678
+ scale_tp_errors[t, :, k, a, m] = np.array(
679
+ scale_tp_error
680
+ )
681
+
682
+ self.evals_per_cat_area = evals_per_cat_area
683
+
684
+ self.eval = {
685
+ "params": p,
686
+ "counts": [T, R, K, A, M],
687
+ "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
688
+ "precision": precision,
689
+ "recall": recall,
690
+ "scores": scores,
691
+ "trans_tp_errors": trans_tp_errors,
692
+ "rot_tp_errors": rot_tp_errors,
693
+ "scale_tp_errors": scale_tp_errors,
694
+ }
695
+
696
+ toc = time.time()
697
+ print("DONE (t={:0.2f}s).".format(toc - tic))
698
+
699
+ def evaluate(self) -> None:
700
+ """Run per image evaluation on given images.
701
+
702
+ It will store results (a list of dict) in self.evalImgs
703
+ """
704
+ print("Running per image evaluation...")
705
+
706
+ p = self.params
707
+ print(f"Evaluate annotation type *{p.iouType}*")
708
+
709
+ tic = time.time()
710
+
711
+ p.imgIds = list(np.unique(p.imgIds))
712
+ if p.useCats:
713
+ p.catIds = list(np.unique(p.catIds))
714
+
715
+ p.maxDets = sorted(p.maxDets)
716
+ self.params = p
717
+
718
+ self._prepare()
719
+
720
+ catIds = p.catIds if p.useCats else [-1]
721
+
722
+ # loop through images, area range, max detection number
723
+ self.ious = {
724
+ (imgId, catId): self.computeIoU(imgId, catId)
725
+ for imgId in p.imgIds
726
+ for catId in catIds
727
+ }
728
+
729
+ maxDet = p.maxDets[-1]
730
+
731
+ self.evalImgs = [
732
+ self.evaluateImg(imgId, catId, areaRng, maxDet)
733
+ for catId in catIds
734
+ for areaRng in p.areaRng
735
+ for imgId in p.imgIds
736
+ ]
737
+
738
+ self._paramsEval = copy.deepcopy(self.params)
739
+
740
+ toc = time.time()
741
+ print("DONE (t={:0.2f}s).".format(toc - tic))
742
+
743
+ def computeIoU(self, imgId, catId) -> tuple[NDArrayF32, NDArrayF32 | None] | list:
744
+ """Compute the IoUs (or center distances) after sorting detections by score."""
745
+ p = self.params
746
+
747
+ if p.useCats:
748
+ gt = self._gts[imgId, catId]
749
+ dt = self._dts[imgId, catId]
750
+ else:
751
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
752
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
753
+
754
+ if len(gt) == 0 and len(dt) == 0:
755
+ return []
756
+
757
+ inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
758
+ dt = [dt[i] for i in inds]
759
+ if len(dt) > p.maxDets[-1]:
760
+ dt = dt[0 : p.maxDets[-1]]
761
+
762
+ if self.mode == "2D":
763
+ g = [g["bbox"] for g in gt]
764
+ d = [d["bbox"] for d in dt]
765
+ elif self.mode == "3D":
766
+ g = [g["bbox3D"] for g in gt]
767
+ d = [d["bbox3D"] for d in dt]
768
+
769
+ # compute iou between each dt and gt region
770
+ # iscrowd is required in builtin maskUtils so we
771
+ # use a dummy buffer for it
772
+ iscrowd = [0 for _ in gt]
773
+ if self.mode == "2D":
774
+ ious = maskUtils.iou(d, g, iscrowd)
775
+ elif len(d) > 0 and len(g) > 0:
776
+ if p.iouType == "bbox":
777
+ dd = torch.tensor(d, dtype=torch.float32)
778
+ gg = torch.tensor(g, dtype=torch.float32)
779
+
780
+ ious = box3d_overlap(dd, gg).cpu().numpy()
781
+ else:
782
+ ious = np.zeros((len(d), len(g)))
783
+
784
+ dd = [d["center_cam"] for d in dt]
785
+ gg = [g["center_cam"] for g in gt]
786
+
787
+ ious = cdist(dd, gg, metric="euclidean")
788
+ else:
789
+ ious = []
790
+
791
+ in_prox = None
792
+
793
+ if self.eval_prox:
794
+ g = [g["bbox"] for g in gt]
795
+ d = [d["bbox"] for d in dt]
796
+ iscrowd = [0 for o in gt]
797
+ ious2d = maskUtils.iou(d, g, iscrowd)
798
+
799
+ if isinstance(ious2d, list):
800
+ in_prox = []
801
+
802
+ else:
803
+ in_prox = ious2d > p.proximity_thresh
804
+
805
+ return ious, in_prox
806
+
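For intuition, a minimal standalone sketch of the two matching costs computed in computeIoU, on made-up boxes and centers; maskUtils.iou and cdist are the same calls used above, and all values are illustrative only.

import numpy as np
from pycocotools import mask as maskUtils
from scipy.spatial.distance import cdist

# 2D mode: xywh boxes with dummy iscrowd flags -> (num_dt, num_gt) IoUs.
dt_boxes = [[10.0, 10.0, 20.0, 20.0]]
gt_boxes = [[12.0, 12.0, 20.0, 20.0], [100.0, 100.0, 10.0, 10.0]]
iou_2d = maskUtils.iou(dt_boxes, gt_boxes, [0] * len(gt_boxes))

# 3D "dist" mode: Euclidean distance between camera-space centers,
# later compared against a fraction of the GT object radius.
dt_centers = np.array([[0.1, 0.0, 5.0]])
gt_centers = np.array([[0.0, 0.0, 5.0], [2.0, 1.0, 20.0]])
dist_3d = cdist(dt_centers, gt_centers, metric="euclidean")

print(iou_2d.shape, dist_3d.shape)  # (1, 2) (1, 2)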
807
+ def evaluateImg(self, imgId, catId, aRng, maxDet):
808
+ """
809
+ Perform evaluation for single category and image
810
+ Returns:
811
+ dict (single image results)
812
+ """
813
+
814
+ p = self.params
815
+ if p.useCats:
816
+ gt = self._gts[imgId, catId]
817
+ dt = self._dts[imgId, catId]
818
+
819
+ else:
820
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
821
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
822
+
823
+ if len(gt) == 0 and len(dt) == 0:
824
+ return None
825
+
826
+ flag_range = "area" if self.mode == "2D" else "depth"
827
+ flag_ignore = "ignore2D" if self.mode == "2D" else "ignore3D"
828
+
829
+ for g in gt:
830
+ if g[flag_ignore] or (
831
+ g[flag_range] < aRng[0] or g[flag_range] > aRng[1]
832
+ ):
833
+ g["_ignore"] = 1
834
+ else:
835
+ g["_ignore"] = 0
836
+
837
+ # sort dt highest score first, sort gt ignore last
838
+ gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort")
839
+ gt = [gt[i] for i in gtind]
840
+ dtind = np.argsort([-d["score"] for d in dt], kind="mergesort")
841
+ dt = [dt[i] for i in dtind[0:maxDet]]
842
+
843
+ # load computed ious
844
+ ious = (
845
+ self.ious[imgId, catId][0][:, gtind]
846
+ if len(self.ious[imgId, catId][0]) > 0
847
+ else self.ious[imgId, catId][0]
848
+ )
849
+
850
+ if self.eval_prox:
851
+ in_prox = (
852
+ self.ious[imgId, catId][1][:, gtind]
853
+ if len(self.ious[imgId, catId][1]) > 0
854
+ else self.ious[imgId, catId][1]
855
+ )
856
+
857
+ T = len(p.iouThrs)
858
+ G = len(gt)
859
+ D = len(dt)
860
+ gtm = np.zeros((T, G))
861
+ dtm = np.zeros((T, D))
862
+ tem = np.ones((T, D)) # Translation Error
863
+ sem = np.ones((T, D)) # Scale Error
864
+ oem = np.ones((T, D)) # Orientation Error
865
+ gtIg = np.array([g["_ignore"] for g in gt])
866
+ dtIg = np.zeros((T, D))
867
+
868
+ dist_thres = 1
869
+ if len(ious) != 0:
870
+ for tind, t in enumerate(p.iouThrs):
871
+ for dind, d in enumerate(dt):
872
+
873
+ # information about best match so far (m=-1 -> unmatched)
874
+ iou = min([t, 1 - 1e-10])
875
+ m = -1
876
+
877
+ for gind, g in enumerate(gt):
878
+ # in case of proximity evaluation, if not in proximity continue
879
+ if self.eval_prox and not in_prox[dind, gind]:
880
+ continue
881
+
882
+ # if this gt already matched, continue
883
+ if gtm[tind, gind] > 0:
884
+ continue
885
+
886
+ # if dt already matched to a regular gt and now on an ignore gt, stop
887
+ if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
888
+ break
889
+
890
+ # continue to next gt unless better match made
891
+ if p.iouType == "bbox" and ious[dind, gind] < iou:
892
+ continue
893
+
894
+ if p.iouType == "dist":
895
+ # Compute Object Radius
896
+ gt_obj_radius = (
897
+ np.linalg.norm(np.array(g["dimensions"])) / 2
898
+ )
899
+ if ious[dind, gind] > gt_obj_radius * iou:
900
+ continue
901
+ else:
902
+ dist_thres = gt_obj_radius * iou
903
+
904
+ # if match successful and best so far, store appropriately
905
+ iou = ious[dind, gind]
906
+ m = gind
907
+
908
+ # if match made store id of match for both dt and gt
909
+ if m == -1:
910
+ continue
911
+
912
+ dtIg[tind, dind] = gtIg[m]
913
+ dtm[tind, dind] = gt[m]["id"]
914
+ gtm[tind, m] = d["id"]
915
+
916
+ if p.iouType == "dist":
917
+ # Translation Error
918
+ tem[tind, dind] = np.linalg.norm(
919
+ np.array(d["center_cam"])
920
+ - np.array(gt[m]["center_cam"])
921
+ ) / (dist_thres)
922
+
923
+ # Orientation Error
924
+ oem[tind, dind] = (
925
+ so3_relative_angle(
926
+ torch.tensor(d["R_cam"])[None],
927
+ torch.tensor(gt[m]["R_cam"])[None],
928
+ cos_bound=1e-2,
929
+ eps=1e-2,
930
+ ).item()
931
+ / np.pi
932
+ )
933
+
934
+ # Scale Error
935
+ min_whl = np.minimum(
936
+ d["dimensions"], gt[m]["dimensions"]
937
+ )
938
+ volume_annotation = np.prod(gt[m]["dimensions"])
939
+ volume_result = np.prod(d["dimensions"])
940
+
941
+ intersection = np.prod(min_whl)
942
+ union = (
943
+ volume_annotation + volume_result - intersection
944
+ )
945
+ scale_iou = intersection / union
946
+
947
+ sem[tind, dind] = 1 - scale_iou
948
+
949
+ # set unmatched detections outside of area range to ignore
950
+ a = np.array(
951
+ [d[flag_range] < aRng[0] or d[flag_range] > aRng[1] for d in dt]
952
+ ).reshape((1, len(dt)))
953
+
954
+ dtIg = np.logical_or(
955
+ dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0))
956
+ )
957
+
958
+ # in case of proximity evaluation, ignore detections which are far from gt regions
959
+ if self.eval_prox and len(in_prox) > 0:
960
+ dt_far = in_prox.any(1) == 0
961
+ dtIg = np.logical_or(
962
+ dtIg, np.repeat(dt_far.reshape((1, len(dt))), T, 0)
963
+ )
964
+
965
+ # store results for given image and category
966
+ return {
967
+ "image_id": imgId,
968
+ "category_id": catId,
969
+ "aRng": aRng,
970
+ "maxDet": maxDet,
971
+ "dtIds": [d["id"] for d in dt],
972
+ "gtIds": [g["id"] for g in gt],
973
+ "dtMatches": dtm,
974
+ "gtMatches": gtm,
975
+ "dtScores": [d["score"] for d in dt],
976
+ "gtIgnore": gtIg,
977
+ "dtIgnore": dtIg,
978
+ "dtTranslationError": tem,
979
+ "dtScaleError": sem,
980
+ "dtOrientationError": oem,
981
+ }
982
+
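A standalone sketch of the three true-positive errors filled into tem, oem, and sem above, on made-up matched boxes. It assumes pytorch3d's so3_relative_angle, which matches the call used in evaluateImg; all numbers are illustrative.

import numpy as np
import torch
from pytorch3d.transforms import so3_relative_angle

d_center, g_center = np.array([0.2, 0.0, 5.0]), np.array([0.0, 0.0, 5.0])
d_dims, g_dims = np.array([1.8, 1.5, 4.0]), np.array([2.0, 1.6, 4.2])
dist_thres = np.linalg.norm(g_dims) / 2  # gt_obj_radius * iou, with iou = 1.0

# Translation error: center distance normalized by the match threshold.
tem = np.linalg.norm(d_center - g_center) / dist_thres

# Orientation error: relative rotation angle, normalized by pi.
theta = 0.3  # detection yawed 0.3 rad away from the ground truth
R_d = torch.tensor(
    [
        [np.cos(theta), -np.sin(theta), 0.0],
        [np.sin(theta), np.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ]
)[None].float()
R_g = torch.eye(3)[None]
oem = so3_relative_angle(R_d, R_g, cos_bound=1e-2, eps=1e-2).item() / np.pi

# Scale error: one minus the IoU of the axis-aligned dimension boxes.
inter = np.prod(np.minimum(d_dims, g_dims))
union = np.prod(d_dims) + np.prod(g_dims) - inter
sem = 1.0 - inter / union
print(round(tem, 3), round(oem, 3), round(sem, 3))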
983
+ def summarize(self):
984
+ """
985
+ Compute and display summary metrics for evaluation results.
986
+ Note this function can *only* be applied with the default parameter setting.
987
+ """
988
+
989
+ def _summarize(
990
+ mode, ap=1, iouThr=None, areaRng="all", maxDets=100, log_str=""
991
+ ):
992
+ p = self.params
993
+ eval = self.eval
994
+
995
+ if mode == "2D":
996
+ if self.iou_type == "bbox":
997
+ iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
998
+ else:
999
+ iStr = " {:<18} {} @[ Dist={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
1000
+
1001
+ elif mode == "3D":
1002
+ if self.iou_type == "bbox":
1003
+ iStr = " {:<18} {} @[ IoU={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}"
1004
+ else:
1005
+ iStr = " {:<18} {} @[ Dist={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}"
1006
+
1007
+ titleStr = "Average Precision" if ap == 1 else "Average Recall"
1008
+ typeStr = "(AP)" if ap == 1 else "(AR)"
1009
+
1010
+ iouStr = (
1011
+ "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
1012
+ if iouThr is None
1013
+ else "{:0.2f}".format(iouThr)
1014
+ )
1015
+
1016
+ aind = [
1017
+ i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng
1018
+ ]
1019
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
1020
+
1021
+ if ap == 1:
1022
+
1023
+ # dimension of precision: [TxRxKxAxM]
1024
+ s = eval["precision"]
1025
+
1026
+ # IoU
1027
+ if iouThr is not None:
1028
+ t = np.where(np.isclose(iouThr, p.iouThrs.astype(float)))[
1029
+ 0
1030
+ ]
1031
+ s = s[t]
1032
+
1033
+ s = s[:, :, :, aind, mind]
1034
+
1035
+ else:
1036
+ # dimension of recall: [TxKxAxM]
1037
+ s = eval["recall"]
1038
+ if iouThr is not None:
1039
+ t = np.where(iouThr == p.iouThrs)[0]
1040
+ s = s[t]
1041
+ s = s[:, :, aind, mind]
1042
+
1043
+ if len(s[s > -1]) == 0:
1044
+ mean_s = -1
1045
+
1046
+ else:
1047
+ mean_s = np.mean(s[s > -1])
1048
+
1049
+ if log_str != "":
1050
+ log_str += "\n"
1051
+
1052
+ log_str += "mode={} ".format(mode) + iStr.format(
1053
+ titleStr, typeStr, iouStr, areaRng, maxDets, mean_s
1054
+ )
1055
+
1056
+ return mean_s, log_str
1057
+
1058
+ def _summarizeDets(mode):
1059
+
1060
+ params = self.params
1061
+
1062
+ # Define the thresholds to be printed
1063
+ if mode == "2D":
1064
+ thres = [0.5, 0.75, 0.95]
1065
+ else:
1066
+ if self.iou_type == "bbox":
1067
+ thres = [0.15, 0.25, 0.50]
1068
+ else:
1069
+ thres = [0.5, 0.75, 1.0]
1070
+
1071
+ stats = np.zeros((13,))
1072
+ stats[0], log_str = _summarize(mode, 1)
1073
+
1074
+ stats[1], log_str = _summarize(
1075
+ mode,
1076
+ 1,
1077
+ iouThr=thres[0],
1078
+ maxDets=params.maxDets[2],
1079
+ log_str=log_str,
1080
+ )
1081
+
1082
+ stats[2], log_str = _summarize(
1083
+ mode,
1084
+ 1,
1085
+ iouThr=thres[1],
1086
+ maxDets=params.maxDets[2],
1087
+ log_str=log_str,
1088
+ )
1089
+
1090
+ stats[3], log_str = _summarize(
1091
+ mode,
1092
+ 1,
1093
+ iouThr=thres[2],
1094
+ maxDets=params.maxDets[2],
1095
+ log_str=log_str,
1096
+ )
1097
+
1098
+ stats[4], log_str = _summarize(
1099
+ mode,
1100
+ 1,
1101
+ areaRng=params.areaRngLbl[1],
1102
+ maxDets=params.maxDets[2],
1103
+ log_str=log_str,
1104
+ )
1105
+
1106
+ stats[5], log_str = _summarize(
1107
+ mode,
1108
+ 1,
1109
+ areaRng=params.areaRngLbl[2],
1110
+ maxDets=params.maxDets[2],
1111
+ log_str=log_str,
1112
+ )
1113
+
1114
+ stats[6], log_str = _summarize(
1115
+ mode,
1116
+ 1,
1117
+ areaRng=params.areaRngLbl[3],
1118
+ maxDets=params.maxDets[2],
1119
+ log_str=log_str,
1120
+ )
1121
+
1122
+ stats[7], log_str = _summarize(
1123
+ mode, 0, maxDets=params.maxDets[0], log_str=log_str
1124
+ )
1125
+
1126
+ stats[8], log_str = _summarize(
1127
+ mode, 0, maxDets=params.maxDets[1], log_str=log_str
1128
+ )
1129
+
1130
+ stats[9], log_str = _summarize(
1131
+ mode, 0, maxDets=params.maxDets[2], log_str=log_str
1132
+ )
1133
+
1134
+ stats[10], log_str = _summarize(
1135
+ mode,
1136
+ 0,
1137
+ areaRng=params.areaRngLbl[1],
1138
+ maxDets=params.maxDets[2],
1139
+ log_str=log_str,
1140
+ )
1141
+
1142
+ stats[11], log_str = _summarize(
1143
+ mode,
1144
+ 0,
1145
+ areaRng=params.areaRngLbl[2],
1146
+ maxDets=params.maxDets[2],
1147
+ log_str=log_str,
1148
+ )
1149
+
1150
+ stats[12], log_str = _summarize(
1151
+ mode,
1152
+ 0,
1153
+ areaRng=params.areaRngLbl[3],
1154
+ maxDets=params.maxDets[2],
1155
+ log_str=log_str,
1156
+ )
1157
+
1158
+ return stats, log_str
1159
+
1160
+ if not self.eval:
1161
+ raise Exception("Please run accumulate() first")
1162
+
1163
+ stats, log_str = _summarizeDets(self.mode)
1164
+ self.stats = stats
1165
+
1166
+ return log_str
1167
+
1168
+
1169
+ class Detect3DParams:
1170
+ """Params for the 3d detection evaluation API."""
1171
+
1172
+ def __init__(
1173
+ self,
1174
+ mode: str = "2D",
1175
+ iouType: str = "bbox",
1176
+ proximity_thresh: float = 0.3,
1177
+ ) -> None:
1178
+ """Create an instance of Detect3DParams.
1179
+
1180
+ Args:
1181
+ mode (str): Defines whether to evaluate 2D or 3D performance.
1182
+ iouType (str): Defines the type of IoU used for evaluation.
1183
+ proximity_thresh (float): Defines the neighborhood when
1184
+ evaluating on non-exhaustively annotated datasets.
1185
+ """
1186
+ assert iouType in {"bbox", "dist"}, f"Invalid iouType {iouType}."
1187
+ self.iouType = iouType
1188
+
1189
+ if mode == "2D":
1190
+ self.setDet2DParams()
1191
+ elif mode == "3D":
1192
+ self.setDet3DParams()
1193
+ else:
1194
+ raise Exception(f"{mode} mode is not supported")
1195
+ self.mode = mode
1196
+ self.proximity_thresh = proximity_thresh
1197
+
1198
+ def setDet2DParams(self) -> None:
1199
+ """Set parameters for 2D detection evaluation."""
1200
+ self.imgIds = []
1201
+ self.catIds = []
1202
+
1203
+ # np.arange causes trouble: the data points it produces can be slightly larger than the true values
1204
+ self.iouThrs = np.linspace(
1205
+ 0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
1206
+ )
1207
+
1208
+ self.recThrs = np.linspace(
1209
+ 0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
1210
+ )
1211
+
1212
+ self.maxDets = [1, 10, 100]
1213
+ self.areaRng = [
1214
+ [0**2, 1e5**2],
1215
+ [0**2, 32**2],
1216
+ [32**2, 96**2],
1217
+ [96**2, 1e5**2],
1218
+ ]
1219
+
1220
+ self.areaRngLbl = ["all", "small", "medium", "large"]
1221
+ self.useCats = 1
1222
+
1223
+ def setDet3DParams(self) -> None:
1224
+ """Set parameters for 3D detection evaluation."""
1225
+ self.imgIds = []
1226
+ self.catIds = []
1227
+
1228
+ # np.arange causes trouble: the data points it produces can be
1229
+ # slightly larger than the true values
1230
+ if self.iouType == "bbox":
1231
+ self.iouThrs = np.linspace(
1232
+ 0.05,
1233
+ 0.5,
1234
+ int(np.round((0.5 - 0.05) / 0.05)) + 1,
1235
+ endpoint=True,
1236
+ )
1237
+ else:
1238
+ self.iouThrs = np.linspace(
1239
+ 0.5, 1.0, int(np.round((1.00 - 0.5) / 0.05)) + 1, endpoint=True
1240
+ )
1241
+
1242
+ self.recThrs = np.linspace(
1243
+ 0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
1244
+ )
1245
+
1246
+ self.maxDets = [1, 10, 100]
1247
+ self.areaRng = [[0, 1e5], [0, 10], [10, 35], [35, 1e5]]
1248
+ self.areaRngLbl = ["all", "near", "medium", "far"]
1249
+ self.useCats = 1
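To make the accumulate() logic above concrete, a self-contained sketch of the 101-point precision interpolation: the precision curve is made monotonically non-increasing from the right, then sampled at the recall thresholds with np.searchsorted. The tp/fp sequence is invented for illustration.

import numpy as np

tp = np.cumsum([1, 1, 0, 1, 0])  # cumulative true positives per detection
fp = np.cumsum([0, 0, 1, 0, 1])  # cumulative false positives per detection
npig = 4                         # number of non-ignored ground truths

rc = tp / npig
pr = (tp / (fp + tp + np.spacing(1))).tolist()

# Envelope: replace each precision by the maximum precision to its right.
for i in range(len(pr) - 1, 0, -1):
    if pr[i] > pr[i - 1]:
        pr[i - 1] = pr[i]

recThrs = np.linspace(0.0, 1.0, 101)
q = np.zeros(len(recThrs))
inds = np.searchsorted(rc, recThrs, side="left")
for ri, pi in enumerate(inds):
    if pi < len(pr):  # recall levels that are never reached keep precision 0
        q[ri] = pr[pi]

print(f"AP = {q.mean():.3f}")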
opendet3d/eval/omni3d.py ADDED
@@ -0,0 +1,285 @@
1
+ """Omni3D 3D detection evaluation."""
2
+
3
+ import contextlib
4
+ import copy
5
+ import io
6
+ import itertools
7
+ import os
8
+ from collections.abc import Sequence
9
+
10
+ import numpy as np
11
+ from terminaltables import AsciiTable
12
+ from vis4d.common.logging import rank_zero_info
13
+ from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber
14
+ from vis4d.eval.base import Evaluator
15
+
16
+ from opendet3d.data.datasets.omni3d.omni3d_classes import omni3d_class_map
17
+ from opendet3d.data.datasets.omni3d.util import get_dataset_det_map
18
+
19
+ from .detect3d import Detect3Deval, Detect3DEvaluator
20
+
21
+ omni3d_in = {
22
+ "stationery",
23
+ "sink",
24
+ "table",
25
+ "floor mat",
26
+ "bottle",
27
+ "bookcase",
28
+ "bin",
29
+ "blinds",
30
+ "pillow",
31
+ "bicycle",
32
+ "refrigerator",
33
+ "night stand",
34
+ "chair",
35
+ "sofa",
36
+ "books",
37
+ "oven",
38
+ "towel",
39
+ "cabinet",
40
+ "window",
41
+ "curtain",
42
+ "bathtub",
43
+ "laptop",
44
+ "desk",
45
+ "television",
46
+ "clothes",
47
+ "stove",
48
+ "cup",
49
+ "shelves",
50
+ "box",
51
+ "shoes",
52
+ "mirror",
53
+ "door",
54
+ "picture",
55
+ "lamp",
56
+ "machine",
57
+ "counter",
58
+ "bed",
59
+ "toilet",
60
+ }
61
+
62
+ omni3d_out = {
63
+ "cyclist",
64
+ "pedestrian",
65
+ "trailer",
66
+ "bus",
67
+ "motorcycle",
68
+ "car",
69
+ "barrier",
70
+ "truck",
71
+ "van",
72
+ "traffic cone",
73
+ "bicycle",
74
+ }
75
+
76
+
77
+ class Omni3DEvaluator(Evaluator):
78
+ """Omni3D 3D detection evaluator."""
79
+
80
+ def __init__(
81
+ self,
82
+ data_root: str = "data/omni3d",
83
+ omni3d50: bool = True,
84
+ datasets: Sequence[str] = (
85
+ "KITTI_test",
86
+ "nuScenes_test",
87
+ "SUNRGBD_test",
88
+ "Hypersim_test",
89
+ "ARKitScenes_test",
90
+ "Objectron_test",
91
+ ),
92
+ per_class_eval: bool = True,
93
+ ) -> None:
94
+ """Initialize the evaluator."""
95
+ super().__init__()
96
+ self.id_to_name = {v: k for k, v in omni3d_class_map.items()}
97
+ self.dataset_names = datasets
98
+ self.per_class_eval = per_class_eval
99
+
100
+ # Each dataset evaluator is stored here
101
+ self.evaluators: dict[str, Detect3DEvaluator] = {}
102
+
103
+ # These store the evaluations for each category and area,
104
+ # concatenated from ALL evaluated datasets. Doing so avoids
105
+ # the need to re-compute them when accumulating results.
106
+ self.evals_per_cat_area2D = {}
107
+ self.evals_per_cat_area3D = {}
108
+
109
+ self.overall_imgIds = set()
110
+ self.overall_catIds = set()
111
+
112
+ for dataset_name in self.dataset_names:
113
+ annotation = os.path.join(
114
+ data_root, "annotations", f"{dataset_name}.json"
115
+ )
116
+
117
+ det_map = get_dataset_det_map(
118
+ dataset_name=dataset_name, omni3d50=omni3d50
119
+ )
120
+
121
+ # create an individual dataset evaluator
122
+ self.evaluators[dataset_name] = Detect3DEvaluator(
123
+ det_map,
124
+ cat_map=omni3d_class_map,
125
+ annotation=annotation,
126
+ eval_prox=(
127
+ "Objectron" in dataset_name or "SUNRGBD" in dataset_name
128
+ ),
129
+ )
130
+
131
+ self.overall_imgIds.update(
132
+ set(self.evaluators[dataset_name]._coco_gt.getImgIds())
133
+ )
134
+ self.overall_catIds.update(
135
+ set(self.evaluators[dataset_name]._coco_gt.getCatIds())
136
+ )
137
+
138
+ def __repr__(self) -> str:
139
+ """Returns the string representation of the object."""
140
+ datasets_str = ", ".join(self.dataset_names)
141
+ return f"Omni3DEvaluator ({datasets_str})"
142
+
143
+ @property
144
+ def metrics(self) -> list[str]:
145
+ """Supported metrics.
146
+
147
+ Returns:
148
+ list[str]: Metrics to evaluate.
149
+ """
150
+ return ["2D", "3D"]
151
+
152
+ def reset(self) -> None:
153
+ """Reset the saved predictions to start new round of evaluation."""
154
+ for dataset_name in self.dataset_names:
155
+ self.evaluators[dataset_name].reset()
156
+ self.evals_per_cat_area2D.clear()
157
+ self.evals_per_cat_area3D.clear()
158
+
159
+ def gather(self, gather_func: GenericFunc) -> None:
160
+ """Accumulate predictions across processes."""
161
+ for dataset_name in self.dataset_names:
162
+ self.evaluators[dataset_name].gather(gather_func)
163
+
164
+ def process_batch(
165
+ self,
166
+ coco_image_id: list[int],
167
+ dataset_names: list[str],
168
+ pred_boxes: list[NDArrayNumber],
169
+ pred_scores: list[NDArrayNumber],
170
+ pred_classes: list[NDArrayNumber],
171
+ pred_boxes3d: list[NDArrayNumber] | None = None,
172
+ ) -> None:
173
+ """Process sample and convert detections to coco format."""
174
+ for i, dataset_name in enumerate(dataset_names):
175
+ self.evaluators[dataset_name].process_batch(
176
+ [coco_image_id[i]],
177
+ [pred_boxes[i]],
178
+ [pred_scores[i]],
179
+ [pred_classes[i]],
180
+ pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
181
+ )
182
+
183
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
184
+ """Evaluate predictions and return the results."""
185
+ assert metric in self.metrics, f"Unsupported metric: {metric}"
186
+
187
+ log_dict = {}
188
+
189
+ for dataset_name in self.dataset_names:
190
+ rank_zero_info(f"Evaluating {dataset_name}...")
191
+ per_dataset_log_dict, dataset_log_str = self.evaluators[
192
+ dataset_name
193
+ ].evaluate(metric)
194
+
195
+ log_dict[f"AP_{dataset_name}"] = per_dataset_log_dict["AP"]
196
+
197
+ rank_zero_info(dataset_log_str + "\n")
198
+
199
+ # store the partially accumulated evaluations per category per area
200
+ if metric == "2D":
201
+ for key, item in self.evaluators[
202
+ dataset_name
203
+ ].bbox_2D_evals_per_cat_area.items():
204
+ if key not in self.evals_per_cat_area2D:
205
+ self.evals_per_cat_area2D[key] = []
206
+ self.evals_per_cat_area2D[key] += item
207
+ else:
208
+ for key, item in self.evaluators[
209
+ dataset_name
210
+ ].bbox_3D_evals_per_cat_area.items():
211
+ if key not in self.evals_per_cat_area3D:
212
+ self.evals_per_cat_area3D[key] = []
213
+ self.evals_per_cat_area3D[key] += item
214
+
215
+ results_per_category_dict = {}
216
+ results_per_category = []
217
+
218
+ rank_zero_info(f"Evaluating Omni3D for {metric} Detection...")
219
+
220
+ evaluator = Detect3Deval(mode=metric)
221
+ evaluator.params.catIds = list(self.overall_catIds)
222
+ evaluator.params.imgIds = list(self.overall_imgIds)
223
+ evaluator.evalImgs = True
224
+
225
+ if metric == "2D":
226
+ evaluator.evals_per_cat_area = self.evals_per_cat_area2D
227
+ metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"]
228
+ else:
229
+ evaluator.evals_per_cat_area = self.evals_per_cat_area3D
230
+ metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"]
231
+
232
+ evaluator._paramsEval = copy.deepcopy(evaluator.params)
233
+
234
+ with contextlib.redirect_stdout(io.StringIO()):
235
+ evaluator.accumulate()
236
+ log_str = "\n" + evaluator.summarize()
237
+
238
+ log_dict.update(dict(zip(metrics, evaluator.stats)))
239
+
240
+ if self.per_class_eval:
241
+ precisions = evaluator.eval["precision"]
242
+ for idx, cat_id in enumerate(self.overall_catIds):
243
+ cat_name = self.id_to_name[cat_id]
244
+ precision = precisions[:, :, idx, 0, -1]
245
+ precision = precision[precision > -1]
246
+ if precision.size:
247
+ ap = float(np.mean(precision).item())
248
+ else:
249
+ ap = float("nan")
250
+
251
+ results_per_category_dict[cat_name] = ap
252
+ results_per_category.append((f"{cat_name}", f"{ap:0.3f}"))
253
+
254
+ num_columns = min(6, len(results_per_category) * 2)
255
+ results_flatten = list(itertools.chain(*results_per_category))
256
+ headers = ["category", "AP"] * (num_columns // 2)
257
+ results_2d = itertools.zip_longest(
258
+ *[results_flatten[i::num_columns] for i in range(num_columns)]
259
+ )
260
+ table_data = [headers] + list(results_2d)
261
+ table = AsciiTable(table_data)
262
+ log_str = f"\n{table.table}\n{log_str}"
263
+
264
+ # Omni3D Outdoor performance
265
+ ap_out_lst = []
266
+ for cat in omni3d_out:
267
+ ap_out_lst.append(results_per_category_dict.get(cat, 0.0))
268
+
269
+ log_dict["Omni3D_Out"] = np.mean(ap_out_lst).item()
270
+
271
+ # Omni3D Indoor performance
272
+ ap_in_lst = []
273
+ for cat in omni3d_in:
274
+ ap_in_lst.append(results_per_category_dict.get(cat, 0.0))
275
+
276
+ log_dict["Omni3D_In"] = np.mean(ap_in_lst).item()
277
+
278
+ return log_dict, log_str
279
+
280
+ def save(self, metric: str, output_dir: str) -> None:
281
+ """Save the results to json files."""
282
+ for dataset_name in self.dataset_names:
283
+ self.evaluators[dataset_name].save(
284
+ metric, output_dir, prefix=dataset_name
285
+ )
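A minimal usage sketch for this evaluator, assuming the Omni3D annotation files exist under data/omni3d/annotations. The prediction arrays are placeholders, and the 2D box format is an assumption rather than a verified interface.

import numpy as np
from opendet3d.eval.omni3d import Omni3DEvaluator

evaluator = Omni3DEvaluator(data_root="data/omni3d", datasets=("KITTI_test",))
evaluator.process_batch(
    coco_image_id=[1],
    dataset_names=["KITTI_test"],
    pred_boxes=[np.array([[10.0, 10.0, 50.0, 60.0]])],  # one dummy 2D box
    pred_scores=[np.array([0.9])],
    pred_classes=[np.array([0])],
    pred_boxes3d=None,  # 3D boxes omitted in this 2D-only sketch
)
log_dict, log_str = evaluator.evaluate("2D")
print(log_dict["AP_KITTI_test"])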
opendet3d/eval/open.py ADDED
@@ -0,0 +1,140 @@
1
+ """Multi-dataset 3D detection evaluation."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from vis4d.common.logging import rank_zero_info
6
+ from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber
7
+ from vis4d.eval.base import Evaluator
8
+
9
+ from .detect3d import Detect3DEvaluator
10
+ from .omni3d import Omni3DEvaluator
11
+
12
+
13
+ class OpenDetect3DEvaluator(Evaluator):
14
+ """Multi-dataset 3D detection evaluator."""
15
+
16
+ def __init__(
17
+ self,
18
+ datasets: Sequence[str],
19
+ evaluators: Sequence[Detect3DEvaluator],
20
+ omni3d_evaluator: Omni3DEvaluator | None = None,
21
+ ) -> None:
22
+ """Initialize the evaluator."""
23
+ super().__init__()
24
+ self.dataset_names = datasets
25
+ self.evaluators = {
26
+ name: evaluator for name, evaluator in zip(datasets, evaluators)
27
+ }
28
+
29
+ self.omni3d_evaluator = omni3d_evaluator
30
+
31
+ def __repr__(self) -> str:
32
+ """Returns the string representation of the object."""
33
+ datasets_str = ", ".join(self.dataset_names)
34
+ return f"Open 3D Object Detection Evaluator ({datasets_str})"
35
+
36
+ @property
37
+ def metrics(self) -> list[str]:
38
+ """Supported metrics.
39
+
40
+ Returns:
41
+ list[str]: Metrics to evaluate.
42
+ """
43
+ return ["2D", "3D"]
44
+
45
+ def reset(self) -> None:
46
+ """Reset the saved predictions to start new round of evaluation."""
47
+ for dataset_name in self.dataset_names:
48
+ self.evaluators[dataset_name].reset()
49
+
50
+ if self.omni3d_evaluator is not None:
51
+ self.omni3d_evaluator.reset()
52
+
53
+ def gather(self, gather_func: GenericFunc) -> None:
54
+ """Accumulate predictions across processes."""
55
+ for dataset_name in self.dataset_names:
56
+ self.evaluators[dataset_name].gather(gather_func)
57
+
58
+ if self.omni3d_evaluator is not None:
59
+ self.omni3d_evaluator.gather(gather_func)
60
+
61
+ def process_batch(
62
+ self,
63
+ coco_image_id: list[int],
64
+ dataset_names: list[str],
65
+ pred_boxes: list[NDArrayNumber],
66
+ pred_scores: list[NDArrayNumber],
67
+ pred_classes: list[NDArrayNumber],
68
+ pred_boxes3d: list[NDArrayNumber] | None = None,
69
+ ) -> None:
70
+ """Process sample and convert detections to coco format."""
71
+ for i, dataset_name in enumerate(dataset_names):
72
+ if (
73
+ self.omni3d_evaluator is not None
74
+ and dataset_name in self.omni3d_evaluator.dataset_names
75
+ ):
76
+ self.omni3d_evaluator.process_batch(
77
+ [coco_image_id[i]],
78
+ [dataset_name],
79
+ [pred_boxes[i]],
80
+ [pred_scores[i]],
81
+ [pred_classes[i]],
82
+ pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
83
+ )
84
+ else:
85
+ self.evaluators[dataset_name].process_batch(
86
+ [coco_image_id[i]],
87
+ [pred_boxes[i]],
88
+ [pred_scores[i]],
89
+ [pred_classes[i]],
90
+ pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
91
+ )
92
+
93
+ def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
94
+ """Evaluate predictions and return the results."""
95
+ assert metric in self.metrics, f"Unsupported metric: {metric}"
96
+
97
+ log_dict = {}
98
+ log_str = ""
99
+
100
+ if self.omni3d_evaluator is not None:
101
+ log_dict_omni3d, omni3d_log_str = self.omni3d_evaluator.evaluate(
102
+ metric
103
+ )
104
+
105
+ log_dict.update(log_dict_omni3d)
106
+ log_str += omni3d_log_str
107
+
108
+ for dataset_name in self.dataset_names:
109
+ rank_zero_info(f"Evaluating {dataset_name}...")
110
+ per_dataset_log_dict, dataset_log_str = self.evaluators[
111
+ dataset_name
112
+ ].evaluate(metric)
113
+
114
+ if "ODS" in per_dataset_log_dict:
115
+ score = "ODS"
116
+ else:
117
+ score = "AP"
118
+
119
+ log_dict[f"{score}_{dataset_name}"] = per_dataset_log_dict[score]
120
+
121
+ if self.evaluators[dataset_name].base_classes is not None:
122
+ log_dict[f"{score}_Base_{dataset_name}"] = (
123
+ per_dataset_log_dict[f"{score}_Base"]
124
+ )
125
+ log_dict[f"{score}_Novel_{dataset_name}"] = (
126
+ per_dataset_log_dict[f"{score}_Novel"]
127
+ )
128
+
129
+ log_str += f"\nCheck {dataset_name} results in log dict."
130
+
131
+ rank_zero_info(dataset_log_str + "\n")
132
+
133
+ return log_dict, log_str
134
+
135
+ def save(self, metric: str, output_dir: str) -> None:
136
+ """Save the results to json files."""
137
+ for dataset_name in self.dataset_names:
138
+ self.evaluators[dataset_name].save(
139
+ metric, output_dir, prefix=dataset_name
140
+ )
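A sketch of composing the evaluators above. The Detect3DEvaluator arguments follow the call pattern seen in omni3d.py, but the mapping values, dataset name, and annotation path here are hypothetical.

from opendet3d.eval.detect3d import Detect3DEvaluator
from opendet3d.eval.omni3d import Omni3DEvaluator
from opendet3d.eval.open import OpenDetect3DEvaluator

# Hypothetical evaluator for a dataset outside the Omni3D protocol.
scannet_eval = Detect3DEvaluator(
    {"chair": 0, "table": 1},          # hypothetical det_map
    cat_map={"chair": 1, "table": 2},  # hypothetical cat_map
    annotation="data/scannet/annotations/ScanNet_val.json",
)

evaluator = OpenDetect3DEvaluator(
    datasets=["ScanNet_val"],
    evaluators=[scannet_eval],
    omni3d_evaluator=Omni3DEvaluator(data_root="data/omni3d"),
)
# process_batch routes each sample by dataset name: names known to the
# Omni3D evaluator go there, everything else to its own evaluator.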
opendet3d/model/__init__.py ADDED
File without changes
opendet3d/model/detect/__init__.py ADDED
File without changes
opendet3d/model/detect/grounding_dino.py ADDED
@@ -0,0 +1,1050 @@
1
+ """Grounding DINO model.
2
+
3
+ Modified from mmdetection.
4
+ """
5
+
6
+ from collections.abc import Sequence
7
+ from typing import NamedTuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import Tensor, nn
12
+ from transformers import BatchEncoding
13
+ from vis4d.common.ckpt import load_model_checkpoint
14
+ from vis4d.common.logging import rank_zero_warn
15
+ from vis4d.op.base import BaseModel
16
+ from vis4d.op.layer.positional_encoding import SinePositionalEncoding
17
+
18
+ from opendet3d.model.language.mm_bert import BertModel
19
+ from opendet3d.op.detect.deformable_detr import get_valid_ratio
20
+ from opendet3d.op.detect.dino import CdnQueryGenerator
21
+ from opendet3d.op.detect.grounding_dino import (
22
+ GroundingDINOHead,
23
+ GroundingDinoTransformerDecoder,
24
+ GroundingDinoTransformerEncoder,
25
+ RoI2Det,
26
+ )
27
+ from opendet3d.op.fpp.channel_mapper import ChannelMapper
28
+ from opendet3d.op.language.grounding import (
29
+ chunks,
30
+ clean_label_name,
31
+ create_positive_map,
32
+ create_positive_map_label_to_token,
33
+ run_ner,
34
+ )
35
+
36
+ REV_KEYS = [
37
+ (r"\.conv.weight", ".weight"),
38
+ (r"\.conv.bias", ".bias"),
39
+ (r"\.gn", ".norm"),
40
+ ]
41
+
42
+
43
+ class GroundingDINOOut(NamedTuple):
44
+ """Output of the Grounding DINO model."""
45
+
46
+ all_layers_cls_scores: list[Tensor]
47
+ all_layers_bbox_preds: list[Tensor]
48
+ enc_outputs_class: Tensor
49
+ enc_outputs_coord: Tensor
50
+ text_token_mask: Tensor
51
+ dn_meta: dict[str, Tensor]
52
+ positive_maps: list[Tensor]
53
+
54
+
55
+ class DetOut(NamedTuple):
56
+ """Decoded detection output of the Grounding DINO model."""
57
+
58
+ boxes: list[Tensor]
59
+ scores: list[Tensor]
60
+ class_ids: list[Tensor]
61
+ categories: list[list[str]] | None = None
62
+
63
+
64
+ class GroundingDINO(nn.Module):
65
+ """Grounding DINO."""
66
+
67
+ def __init__(
68
+ self,
69
+ basemodel: BaseModel,
70
+ neck: ChannelMapper,
71
+ texts: list[str] | None = None,
72
+ custom_entities: bool = True,
73
+ chunked_size: int = -1,
74
+ num_queries: int = 900,
75
+ num_feature_levels: int = 4,
76
+ use_checkpoint: bool = False,
77
+ bbox_head: GroundingDINOHead | None = None,
78
+ language_model: BertModel | None = None,
79
+ roi2det: RoI2Det | None = None,
80
+ weights: str | None = None,
81
+ ) -> None:
82
+ """Create the Grounding DINO model."""
83
+ super().__init__()
84
+ self.texts = texts
85
+ self.custom_entities = custom_entities
86
+ self.chunked_size = chunked_size
87
+
88
+ self.num_queries = num_queries
89
+
90
+ self.backbone = basemodel
91
+ self.neck = neck
92
+
93
+ # Encoder
94
+ self.encoder = GroundingDinoTransformerEncoder(
95
+ num_levels=num_feature_levels, use_checkpoint=use_checkpoint
96
+ )
97
+
98
+ self.embed_dims = self.encoder.embed_dims
99
+ self.positional_encoding = SinePositionalEncoding(
100
+ num_feats=128, normalize=True, offset=0.0, temperature=20
101
+ )
102
+
103
+ num_feats = self.positional_encoding.num_feats
104
+ assert num_feats * 2 == self.embed_dims, (
105
+ f"embed_dims should be exactly 2 times of num_feats. "
106
+ f"Found {self.embed_dims} and {num_feats}."
107
+ )
108
+
109
+ self.level_embed = nn.Parameter(
110
+ torch.Tensor(num_feature_levels, self.embed_dims)
111
+ )
112
+
113
+ self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
114
+ self.memory_trans_norm = nn.LayerNorm(self.embed_dims)
115
+
116
+ # Decoder
117
+ self.decoder = GroundingDinoTransformerDecoder(
118
+ num_levels=num_feature_levels
119
+ )
120
+
121
+ self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
122
+
123
+ # Grounding DINO head
124
+ self.bbox_head = bbox_head or GroundingDINOHead(
125
+ num_classes=256, num_decoder_layer=self.decoder.num_layers
126
+ )
127
+
128
+ self.roi2det = roi2det or RoI2Det()
129
+
130
+ self.dn_query_generator = CdnQueryGenerator(
131
+ num_classes=self.bbox_head.num_classes,
132
+ embed_dims=self.embed_dims,
133
+ num_matching_queries=self.num_queries,
134
+ label_noise_scale=0.5,
135
+ box_noise_scale=1.0, # 0.4 for DN-DETR
136
+ dynamic=True,
137
+ num_groups=None,
138
+ num_dn_queries=100,
139
+ )
140
+
141
+ # Language model configuration
142
+ self._special_tokens = ". "
143
+
144
+ # text modules
145
+ self.language_model = language_model or BertModel(
146
+ name="bert-base-uncased",
147
+ max_tokens=256,
148
+ pad_to_max=False,
149
+ use_sub_sentence_represent=True,
150
+ special_tokens_list=["[CLS]", "[SEP]", ".", "?"],
151
+ add_pooling_layer=False,
152
+ use_checkpoint=use_checkpoint,
153
+ )
154
+
155
+ self.text_feat_map = nn.Linear(
156
+ self.language_model.language_backbone.body.language_dim,
157
+ self.embed_dims,
158
+ bias=True,
159
+ )
160
+
161
+ self._init_weights()
162
+
163
+ if weights is not None:
164
+ load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
165
+
166
+ def _init_weights(self) -> None:
167
+ """Initialize weights."""
168
+ # DINO
169
+ for coder in self.encoder, self.decoder:
170
+ for p in coder.parameters():
171
+ if p.dim() > 1:
172
+ nn.init.xavier_uniform_(p)
173
+
174
+ nn.init.xavier_uniform_(self.memory_trans_fc.weight)
175
+ nn.init.xavier_uniform_(self.query_embedding.weight)
176
+ nn.init.normal_(self.level_embed)
177
+
178
+ # G-DINO
179
+ nn.init.constant_(self.text_feat_map.bias.data, 0)
180
+ nn.init.xavier_uniform_(self.text_feat_map.weight.data)
181
+
182
+ def get_captions_and_tokens_positive(
183
+ self,
184
+ text_prompt: list[str],
185
+ text_prompt_mapping: dict[str, dict[str, str]] | None = None,
186
+ ) -> tuple[str, list[list[int]]]:
187
+ """Enhance the text prompts with the text mapping."""
188
+ captions = ""
189
+ tokens_positive = []
190
+ for word in text_prompt:
191
+ if text_prompt_mapping is not None and word in text_prompt_mapping:
192
+ enhanced_text_dict = text_prompt_mapping[word]
193
+ if "prefix" in enhanced_text_dict:
194
+ captions += enhanced_text_dict["prefix"]
195
+
196
+ start_i = len(captions)
197
+ if "name" in enhanced_text_dict:
198
+ captions += enhanced_text_dict["name"]
199
+ else:
200
+ captions += word
201
+ end_i = len(captions)
202
+
203
+ tokens_positive.append([[start_i, end_i]])
204
+
205
+ if "suffix" in enhanced_text_dict:
206
+ captions += enhanced_text_dict["suffix"]
207
+ else:
208
+ tokens_positive.append(
209
+ [[len(captions), len(captions) + len(word)]]
210
+ )
211
+ captions += word
212
+ captions += self._special_tokens
213
+ return captions, tokens_positive
214
+
215
+ def get_tokens_and_prompts(
216
+ self,
217
+ text_prompt: str | list[str],
218
+ text_prompt_mapping: dict[str, dict[str, str]] | None = None,
219
+ ) -> tuple[BatchEncoding, str, list[list[int]], list[str]]:
220
+ """Get the tokens positive and prompts for the caption."""
221
+ if isinstance(text_prompt, list):
222
+ captions, tokens_positive = self.get_captions_and_tokens_positive(
223
+ text_prompt, text_prompt_mapping
224
+ )
225
+
226
+ tokenized = self.language_model.tokenizer(
227
+ [captions], padding="longest", return_tensors="pt"
228
+ )
229
+ entities = text_prompt
230
+ else:
231
+ if not text_prompt.endswith("."):
232
+ captions = text_prompt + self._special_tokens
233
+ else:
234
+ captions = text_prompt
235
+
236
+ tokenized = self.language_model.tokenizer(
237
+ [captions], padding="longest", return_tensors="pt"
238
+ )
239
+ tokens_positive, entities = run_ner(captions)
240
+
241
+ return tokenized, captions, tokens_positive, entities
242
+
243
+ def get_positive_map(
244
+ self, tokenized: BatchEncoding, tokens_positive: list[list[int]]
245
+ ) -> tuple[dict, Tensor]:
246
+ """Get the positive map and label to token."""
247
+ positive_map = create_positive_map(
248
+ tokenized,
249
+ tokens_positive,
250
+ max_num_entities=self.bbox_head.cls_branches[
251
+ self.decoder.num_layers
252
+ ].max_text_len,
253
+ )
254
+ positive_map_label_to_token = create_positive_map_label_to_token(
255
+ positive_map, plus=1
256
+ )
257
+ return positive_map_label_to_token, positive_map
258
+
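For intuition, a toy illustration of what the positive map encodes: each entity owns a character span in the caption, and the tokenizer's char_to_token resolves those spans to wordpiece positions, which is roughly what create_positive_map does per entity. The caption and spans are illustrative.

from transformers import AutoTokenizer

caption = "bench. car. "
tokens_positive = [[[0, 5]], [[7, 10]]]  # char spans of "bench" and "car"

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenized = tokenizer([caption], padding="longest", return_tensors="pt")

for spans in tokens_positive:
    for start, end in spans:
        print(
            caption[start:end],
            tokenized.char_to_token(start),
            tokenized.char_to_token(end - 1),
        )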
259
+ def get_tokens_positive_and_prompts(
260
+ self,
261
+ text_prompt: str | list[str],
262
+ custom_entities: bool = False,
263
+ tokens_positive: list[list[list[int]]] | int | None = None,
264
+ text_prompt_mapping: dict[str, dict[str, str]] | None = None,
265
+ ) -> tuple[dict, str, Tensor, list]:
266
+ """Get the tokens positive and prompts for the caption.
267
+
268
+ Args:
269
+ text_prompt (str | list[str]): The original caption, e.g. 'bench . car .'
270
+ custom_entities (bool, optional): Whether to use custom entities.
271
+ If ``True``, ``text_prompt`` should be a list of
272
+ strings, each of which is a word. Defaults to False.
273
+
274
+ Returns:
275
+ Tuple[dict, str, Tensor, list]: The dict is a mapping from each
276
+ entity id, numbered from 1, to its positive token ids. The str
277
+ represents the prompts.
278
+ """
279
+ if tokens_positive is not None:
280
+ assert isinstance(
281
+ text_prompt, str
282
+ ), "Text prompt should be a string with given positive tokens."
283
+
284
+ if not text_prompt.endswith("."):
285
+ captions = text_prompt + self._special_tokens
286
+ else:
287
+ captions = text_prompt
288
+
289
+ if tokens_positive == -1:
290
+ return None, captions, None, captions
291
+ else:
292
+ assert isinstance(
293
+ tokens_positive, list
294
+ ), "Positive tokens should be a list of list[int] if not -1."
295
+ tokenized = self.language_model.tokenizer(
296
+ [captions], padding="longest", return_tensors="pt"
297
+ )
298
+ positive_map_label_to_token, positive_map = (
299
+ self.get_positive_map(tokenized, tokens_positive)
300
+ )
301
+
302
+ entities = []
303
+ for token_positive in tokens_positive:
304
+ instance_entities = []
305
+ for t in token_positive:
306
+ instance_entities.append(captions[t[0] : t[1]])
307
+ entities.append(" / ".join(instance_entities))
308
+
309
+ return (
310
+ positive_map_label_to_token,
311
+ captions,
312
+ positive_map,
313
+ entities,
314
+ )
315
+
316
+ if custom_entities:
317
+ if isinstance(text_prompt, str):
318
+ text_prompt = text_prompt.strip(self._special_tokens)
319
+ text_prompt = text_prompt.split(self._special_tokens)
320
+ text_prompt = list(filter(lambda x: len(x) > 0, text_prompt))
321
+ text_prompt = [clean_label_name(i) for i in text_prompt]
322
+
323
+ if self.chunked_size > 0:
324
+ assert not self.training, "Chunked size is only for testing."
325
+ (
326
+ positive_map_label_to_token,
327
+ captions,
328
+ positive_map,
329
+ entities,
330
+ ) = self.get_tokens_positive_and_prompts_chunked(
331
+ text_prompt, text_prompt_mapping
332
+ )
333
+ else:
334
+ tokenized, captions, tokens_positive, entities = (
335
+ self.get_tokens_and_prompts(text_prompt, text_prompt_mapping)
336
+ )
337
+ positive_map_label_to_token, positive_map = self.get_positive_map(
338
+ tokenized, tokens_positive
339
+ )
340
+
341
+ return positive_map_label_to_token, captions, positive_map, entities
342
+
343
+ def get_tokens_positive_and_prompts_chunked(
344
+ self,
345
+ text_prompt: list[str],
346
+ text_prompt_mapping: dict[str, dict[str, str]] | None = None,
347
+ ):
348
+ """Get the tokens positive and prompts for the caption."""
349
+ text_prompt_chunked = chunks(text_prompt, self.chunked_size)
350
+ ids_chunked = chunks(
351
+ list(range(1, len(text_prompt) + 1)), self.chunked_size
352
+ )
353
+
354
+ positive_map_label_to_token_chunked = []
355
+ captions_chunked = []
356
+ positive_map_chunked = []
357
+ entities_chunked = []
358
+ for i in range(len(ids_chunked)):
359
+ captions, tokens_positive = self.get_captions_and_tokens_positive(
360
+ text_prompt_chunked[i], text_prompt_mapping
361
+ )
362
+
363
+ tokenized = self.language_model.tokenizer(
364
+ [captions], padding="longest", return_tensors="pt"
365
+ )
366
+ if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
367
+ rank_zero_warn(
368
+ "The caption is too long and will result in poor performance. "
369
+ "Please reduce the chunked size."
370
+ )
371
+
372
+ positive_map_label_to_token, positive_map = self.get_positive_map(
373
+ tokenized, tokens_positive
374
+ )
375
+
376
+ captions_chunked.append(captions)
377
+ positive_map_label_to_token_chunked.append(
378
+ positive_map_label_to_token
379
+ )
380
+ positive_map_chunked.append(positive_map)
381
+ entities_chunked.append(text_prompt_chunked[i])
382
+
383
+ return (
384
+ positive_map_label_to_token_chunked,
385
+ captions_chunked,
386
+ positive_map_chunked,
387
+ entities_chunked,
388
+ )
389
+
390
+ # TODO: Move this to deformable DETR
391
+ def pre_transformer(
392
+ self,
393
+ feats: list[Tensor],
394
+ input_hw: list[tuple[int, int]],
395
+ batch_input_shape: tuple[int, int],
396
+ padding: list[list[int]] | None = None,
397
+ ) -> tuple[Tensor, Tensor, Tensor | None, Tensor, Tensor, Tensor]:
398
+ """Process image features before transformer."""
399
+ batch_size = feats[0].size(0)
400
+
401
+ # construct binary masks for the transformer.
402
+ batch_input_img_h, batch_input_img_w = batch_input_shape
403
+ same_shape_flag = all(
404
+ [
405
+ s[0] == batch_input_img_h and s[1] == batch_input_img_w
406
+ for s in input_hw
407
+ ]
408
+ )
409
+
410
+ if same_shape_flag:
411
+ mlvl_masks = []
412
+ mlvl_pos_embeds = []
413
+ for feat in feats:
414
+ mlvl_masks.append(None)
415
+ mlvl_pos_embeds.append(
416
+ self.positional_encoding(None, inputs=feat)
417
+ )
418
+ else:
419
+ check_center = not (padding is None)
420
+ masks = feats[0].new_ones(
421
+ (batch_size, batch_input_img_h, batch_input_img_w)
422
+ )
423
+ for img_id in range(batch_size):
424
+ img_h, img_w = input_hw[img_id]
425
+
426
+ if padding is None:
427
+ masks[img_id, :img_h, :img_w] = 0
428
+ else:
429
+ pad_left, pad_right, pad_top, pad_bottom = padding[img_id]
430
+ masks[
431
+ img_id,
432
+ pad_top : batch_input_img_h - pad_bottom,
433
+ pad_left : batch_input_img_w - pad_right,
434
+ ] = 0
435
+
436
+ # NOTE Following the official DETR repo, non-zero
437
+ # values represent ignored positions, while zero
438
+ # values mean valid positions.
439
+ mlvl_masks = []
440
+ mlvl_pos_embeds = []
441
+ for feat in feats:
442
+ mlvl_masks.append(
443
+ F.interpolate(masks[None], size=feat.shape[-2:])
444
+ .to(torch.bool)
445
+ .squeeze(0)
446
+ )
447
+ mlvl_pos_embeds.append(
448
+ self.positional_encoding(mlvl_masks[-1])
449
+ )
450
+
451
+ feat_flatten = []
452
+ lvl_pos_embed_flatten = []
453
+ mask_flatten = []
454
+ spatial_shapes = []
455
+ for lvl, (feat, mask, pos_embed) in enumerate(
456
+ zip(feats, mlvl_masks, mlvl_pos_embeds)
457
+ ):
458
+ batch_size, c, _, _ = feat.shape
459
+ spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
460
+ # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
461
+ feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
462
+ pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
463
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
464
+ # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
465
+ if mask is not None:
466
+ mask = mask.flatten(1)
467
+
468
+ feat_flatten.append(feat)
469
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
470
+ mask_flatten.append(mask)
471
+ spatial_shapes.append(spatial_shape)
472
+
473
+ # (bs, num_feat_points, dim)
474
+ feat_flatten = torch.cat(feat_flatten, 1)
475
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
476
+ # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
477
+ if mask_flatten[0] is not None:
478
+ mask_flatten = torch.cat(mask_flatten, 1)
479
+ else:
480
+ mask_flatten = None
481
+
482
+ # (num_level, 2)
483
+ spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
484
+ level_start_index = torch.cat(
485
+ (
486
+ spatial_shapes.new_zeros((1,)), # (num_level)
487
+ spatial_shapes.prod(1).cumsum(0)[:-1],
488
+ )
489
+ )
490
+ if mlvl_masks[0] is not None:
491
+ valid_ratios = torch.stack( # (bs, num_level, 2)
492
+ [get_valid_ratio(m, check_center) for m in mlvl_masks],
493
+ 1,
494
+ )
495
+ else:
496
+ valid_ratios = feats[0].new_ones(batch_size, len(feats), 2)
497
+
498
+ return (
499
+ feat_flatten,
500
+ lvl_pos_embed_flatten,
501
+ mask_flatten,
502
+ spatial_shapes,
503
+ level_start_index,
504
+ valid_ratios,
505
+ )
506
+
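A shape-level sketch of the flattening performed in pre_transformer: each (bs, c, h, w) level becomes (bs, h*w, c), levels are concatenated, and level_start_index records where each level begins. The sizes are illustrative.

import torch

feats = [torch.randn(2, 256, 8, 8), torch.randn(2, 256, 4, 4)]
flat = torch.cat([f.flatten(2).permute(0, 2, 1) for f in feats], dim=1)

spatial_shapes = torch.tensor([[8, 8], [4, 4]])
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
print(flat.shape, level_start_index)  # torch.Size([2, 80, 256]) tensor([0, 64])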
507
+ def forward_transformer(
508
+ self,
509
+ feat_flatten: Tensor,
510
+ lvl_pos_embed_flatten: Tensor,
511
+ memory_mask: Tensor | None,
512
+ spatial_shapes: Tensor,
513
+ level_start_index: Tensor,
514
+ valid_ratios: Tensor,
515
+ text_dict: dict[str, Tensor],
516
+ boxes: Tensor | None = None,
517
+ class_ids: Tensor | None = None,
518
+ input_hw: list[tuple[int, int]] | None = None,
519
+ ) -> tuple[Tensor, Tensor, Tensor, list[Tensor]]:
520
+ """Forward function for the transformer."""
521
+ text_token_mask = text_dict["text_token_mask"]
522
+
523
+ memory, memory_text = self.encoder(
524
+ query=feat_flatten,
525
+ query_pos=lvl_pos_embed_flatten,
526
+ key_padding_mask=memory_mask,
527
+ spatial_shapes=spatial_shapes,
528
+ level_start_index=level_start_index,
529
+ valid_ratios=valid_ratios,
530
+ memory_text=text_dict["embedded"],
531
+ text_attention_mask=~text_token_mask,
532
+ position_ids=text_dict["position_ids"],
533
+ text_self_attention_masks=text_dict["masks"],
534
+ )
535
+
536
+ bs = memory.shape[0]
537
+
538
+ output_memory, output_proposals = self.gen_encoder_output_proposals(
539
+ memory, memory_mask, spatial_shapes
540
+ )
541
+
542
+ enc_outputs_class = self.bbox_head.cls_branches[
543
+ self.decoder.num_layers
544
+ ](output_memory, memory_text, text_token_mask)
545
+ cls_out_features = self.bbox_head.cls_branches[
546
+ self.decoder.num_layers
547
+ ].max_text_len
548
+ enc_outputs_coord_unact = (
549
+ self.bbox_head.reg_branches[self.decoder.num_layers](output_memory)
550
+ + output_proposals
551
+ )
552
+
553
+ # NOTE DINO selects the top-k proposals according to the scores
554
+ # of multi-class classification, while Deformable DETR selects
555
+ # them according to the scores of binary classification, i.e. its
556
+ # input is `enc_outputs_class[..., 0]`.
557
+ topk_indices = torch.topk(
558
+ enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1
559
+ )[1]
560
+
561
+ topk_score = torch.gather(
562
+ enc_outputs_class,
563
+ 1,
564
+ topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features),
565
+ )
566
+ topk_coords_unact = torch.gather(
567
+ enc_outputs_coord_unact,
568
+ 1,
569
+ topk_indices.unsqueeze(-1).repeat(1, 1, 4),
570
+ )
571
+ topk_coords = topk_coords_unact.sigmoid()
572
+ topk_coords_unact = topk_coords_unact.detach()
573
+
574
+ query = self.query_embedding.weight[:, None, :]
575
+ query = query.repeat(1, bs, 1).transpose(0, 1)
576
+
577
+ if self.training:
578
+ dn_label_query, dn_bbox_query, dn_mask, dn_meta = (
579
+ self.dn_query_generator(boxes, class_ids, input_hw)
580
+ )
581
+ query = torch.cat([dn_label_query, query], dim=1)
582
+ reference_points = torch.cat(
583
+ [dn_bbox_query, topk_coords_unact], dim=1
584
+ )
585
+ else:
586
+ reference_points = topk_coords_unact
587
+ dn_mask, dn_meta = None, None
588
+
589
+ reference_points = reference_points.sigmoid()
590
+
591
+ # NOTE DINO calculates the encoder losses on the scores and
592
+ # coordinates of the selected top-k encoder queries, while
593
+ # Deformable DETR uses all encoder queries.
594
+ if self.training:
595
+ enc_outputs_class = topk_score
596
+ enc_outputs_coord = topk_coords
597
+
598
+ hidden_states, references = self.decoder(
599
+ query=query,
600
+ value=memory,
601
+ key_padding_mask=memory_mask,
602
+ self_attn_mask=dn_mask,
603
+ reference_points=reference_points,
604
+ spatial_shapes=spatial_shapes,
605
+ level_start_index=level_start_index,
606
+ valid_ratios=valid_ratios,
607
+ reg_branches=self.bbox_head.reg_branches,
608
+ memory_text=memory_text,
609
+ text_attention_mask=~text_token_mask,
610
+ )
611
+
612
+ if len(query) == self.num_queries:
613
+ # NOTE: This makes sure label_embedding is involved in producing
614
+ # a loss even if there is no denoising query (no ground-truth
615
+ # target on this GPU); otherwise, this raises a runtime error in
616
+ # distributed training.
617
+ hidden_states[0] += (
618
+ self.dn_query_generator.label_embedding.weight[0, 0] * 0.0
619
+ )
620
+
621
+ if self.training:
622
+ return (
623
+ memory_text,
624
+ text_token_mask,
625
+ hidden_states,
626
+ list(references),
627
+ enc_outputs_class,
628
+ enc_outputs_coord,
629
+ dn_meta,
630
+ )
631
+
632
+ return memory_text, text_token_mask, hidden_states, list(references)
633
+
634
+ # TODO: Move this to deformable DETR
635
+ def gen_encoder_output_proposals(
636
+ self,
637
+ memory: Tensor,
638
+ memory_mask: Tensor | None,
639
+ spatial_shapes: Tensor,
640
+ ) -> tuple[Tensor, Tensor]:
641
+ """Generate proposals from encoded memory. The function will only be
642
+ used when `as_two_stage` is `True`.
643
+
644
+ Args:
645
+ memory (Tensor): The output embeddings of the Transformer encoder,
646
+ has shape (bs, num_feat_points, dim).
647
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
648
+ has shape (bs, num_feat_points).
649
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
650
+ has shape (num_levels, 2), last dimension represents (h, w).
651
+
652
+ Returns:
653
+ tuple: A tuple of transformed memory and proposals.
654
+
655
+ - output_memory (Tensor): The transformed memory for obtaining
656
+ top-k proposals, has shape (bs, num_feat_points, dim).
657
+ - output_proposals (Tensor): The inverse-normalized proposal, has
658
+ shape (batch_size, num_keys, 4) with the last dimension arranged
659
+ as (cx, cy, w, h).
660
+ """
661
+
662
+ bs = memory.size(0)
663
+ proposals = []
664
+ _cur = 0 # start index in the sequence of the current level
665
+ for lvl, HW in enumerate(spatial_shapes):
666
+ H, W = HW
667
+
668
+ if memory_mask is not None:
669
+ mask_flatten_ = memory_mask[:, _cur : (_cur + H * W)].view(
670
+ bs, H, W, 1
671
+ )
672
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(
673
+ -1
674
+ )
675
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(
676
+ -1
677
+ )
678
+ scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
679
+ else:
680
+ if not isinstance(HW, Tensor):
681
+ HW = memory.new_tensor(HW)
682
+ scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2)
683
+
684
+ grid_y, grid_x = torch.meshgrid(
685
+ torch.linspace(
686
+ 0, H - 1, H, dtype=torch.float32, device=memory.device
687
+ ),
688
+ torch.linspace(
689
+ 0, W - 1, W, dtype=torch.float32, device=memory.device
690
+ ),
691
+ indexing="ij",
692
+ )
693
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
694
+ grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
695
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
696
+ proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
697
+ proposals.append(proposal)
698
+ _cur += H * W
699
+
700
+ output_proposals = torch.cat(proposals, 1)
701
+
702
+ # do not use `all` to make it exportable to onnx
703
+ output_proposals_valid = (
704
+ (output_proposals > 0.01) & (output_proposals < 0.99)
705
+ ).sum(-1, keepdim=True) == output_proposals.shape[-1]
706
+
707
+ # inverse_sigmoid
708
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
709
+ if memory_mask is not None:
710
+ output_proposals = output_proposals.masked_fill(
711
+ memory_mask.unsqueeze(-1), float("inf")
712
+ )
713
+
714
+ output_proposals = output_proposals.masked_fill(
715
+ ~output_proposals_valid, float("inf")
716
+ )
717
+
718
+ if memory_mask is not None:
719
+ output_memory = memory.masked_fill(
720
+ memory_mask.unsqueeze(-1), float(0)
721
+ )
722
+ else:
723
+ output_memory = memory
724
+
725
+ # [bs, sum(hw), 2]
726
+ output_memory = output_memory.masked_fill(
727
+ ~output_proposals_valid, float(0)
728
+ )
729
+ output_memory = self.memory_trans_fc(output_memory)
730
+ output_memory = self.memory_trans_norm(output_memory)
731
+
732
+ return output_memory, output_proposals
733
+
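A minimal sketch of the proposal grid built in gen_encoder_output_proposals for one feature level: normalized cell centers plus a level-dependent base size of 0.05 * 2**lvl, mapped through the inverse sigmoid so the regression branch can refine the proposals additively. Shapes are illustrative.

import torch

bs, H, W, lvl = 1, 4, 4, 0
grid_y, grid_x = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing="ij",
)
grid = torch.stack([grid_x, grid_y], dim=-1)
scale = torch.tensor([W, H], dtype=torch.float32)
centers = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(centers) * 0.05 * (2.0**lvl)
proposals = torch.cat((centers, wh), -1).view(bs, -1, 4)  # (bs, H*W, 4)

# Inverse sigmoid maps the (cx, cy, w, h) proposals into logit space.
logits = torch.log(proposals / (1 - proposals))
print(proposals.shape, logits.shape)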
734
+ def _forward_train(
735
+ self,
736
+ images: Tensor,
737
+ input_texts: list[list[str]] | None,
738
+ boxes2d: Tensor,
739
+ boxes2d_classes: Tensor,
740
+ input_hw: list[tuple[int, int]],
741
+ input_tokens_positive: list[dict[int, list[list[int]]]] | None = None,
742
+ ) -> GroundingDINOOut:
743
+ """Forward train."""
744
+ batch_size = images.shape[0]
745
+
746
+ # Positive tokens are provided with the batch.
747
+ if input_tokens_positive is not None:
748
+ positive_maps = []
749
+ for tokens_positive_dict, text_prompt, gt_label in zip(
750
+ input_tokens_positive, input_texts, boxes2d_classes
751
+ ):
752
+ tokenized = self.language_model.tokenizer(
753
+ [text_prompt], padding="longest", return_tensors="pt"
754
+ )
755
+ new_tokens_positive = [
756
+ tokens_positive_dict[label.item()] for label in gt_label
757
+ ]
758
+ _, positive_map = self.get_positive_map(
759
+ tokenized, new_tokens_positive
760
+ )
761
+ positive_maps.append(positive_map)
762
+ new_text_prompts = input_texts
763
+ else:
764
+ new_text_prompts = []
765
+ positive_maps = []
766
+
767
+ # All the text prompts are the same, so there is no need to
768
+ # calculate them multiple times.
769
+ if (
770
+ input_texts is None
771
+ or len(set(["".join(t) for t in input_texts])) == 1
772
+ ):
773
+ if input_texts is None:
774
+ text_prompt = self.texts
775
+ else:
776
+ text_prompt = input_texts[0]
777
+
778
+ tokenized, caption_string, tokens_positive, _ = (
779
+ self.get_tokens_and_prompts(text_prompt)
780
+ )
781
+ new_text_prompts = [caption_string] * batch_size
782
+ for gt_label in boxes2d_classes:
783
+ new_tokens_positive = [
784
+ tokens_positive[label] for label in gt_label
785
+ ]
786
+ _, positive_map = self.get_positive_map(
787
+ tokenized, new_tokens_positive
788
+ )
789
+ positive_maps.append(positive_map)
790
+ else:
791
+ for text_prompt, gt_label in zip(input_texts, boxes2d_classes):
792
+ tokenized, caption_string, tokens_positive, _ = (
793
+ self.get_tokens_and_prompts(text_prompt)
794
+ )
795
+
796
+ new_text_prompts.append(caption_string)
797
+
798
+ new_tokens_positive = [
799
+ tokens_positive[label] for label in gt_label
800
+ ]
801
+ _, positive_map = self.get_positive_map(
802
+ tokenized, new_tokens_positive
803
+ )
804
+
805
+ positive_maps.append(positive_map)
806
+
807
+ for i in range(batch_size):
808
+ positive_maps[i] = (
809
+ positive_maps[i].to(images.device).bool().float()
810
+ )
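# A toy standalone sketch of the positive maps collected above: each
# row is a [max_text_len] indicator marking which caption tokens
# describe one ground-truth box, so classification becomes token
# grounding. The token indices below are made up.
import torch

max_text_len = 256
positive_map = torch.zeros(2, max_text_len)
positive_map[0, 1] = 1.0  # hypothetical token span of "car" for box 0
positive_map[1, 3] = 1.0  # hypothetical token span of "person" for box 1
assert positive_map.shape == (2, max_text_len)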
811
+
812
+ text_dict = self.language_model(new_text_prompts)
813
+
814
+ if self.text_feat_map is not None:
815
+ text_dict["embedded"] = self.text_feat_map(text_dict["embedded"])
816
+
817
+ text_token_masks = []
818
+ for i in range(batch_size):
819
+ text_token_masks.append(
820
+ text_dict["text_token_mask"][i]
821
+ .unsqueeze(0)
822
+ .repeat(len(positive_maps[i]), 1)
823
+ )
824
+
825
+ visual_feats = self.backbone(images)[2:]
826
+ visual_feats = self.neck(visual_feats)
827
+
828
+ batch_input_img_h, batch_input_img_w = images.shape[-2:]
829
+ batch_input_shape = (batch_input_img_h, batch_input_img_w)
830
+
831
+ (
832
+ feat_flatten,
833
+ lvl_pos_embed_flatten,
834
+ memory_mask,
835
+ spatial_shapes,
836
+ level_start_index,
837
+ valid_ratios,
838
+ ) = self.pre_transformer(visual_feats, input_hw, batch_input_shape)
839
+
840
+ (
841
+ memory_text,
842
+ text_token_mask,
843
+ hidden_states,
844
+ references,
845
+ enc_outputs_class,
846
+ enc_outputs_coord,
847
+ dn_meta,
848
+ ) = self.forward_transformer(
849
+ feat_flatten,
850
+ lvl_pos_embed_flatten,
851
+ memory_mask,
852
+ spatial_shapes,
853
+ level_start_index,
854
+ valid_ratios,
855
+ text_dict,
856
+ boxes2d,
857
+ boxes2d_classes,
858
+ input_hw,
859
+ )
860
+
861
+ all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
862
+ hidden_states, references, memory_text, text_token_mask
863
+ )
864
+
865
+ return GroundingDINOOut(
866
+ all_layers_cls_scores,
867
+ all_layers_bbox_preds,
868
+ enc_outputs_class,
869
+ enc_outputs_coord,
870
+ text_token_mask,
871
+ dn_meta,
872
+ positive_maps,
873
+ )
874
+
875
+ def _forward_test(
876
+ self,
877
+ images: Tensor,
878
+ input_texts: list[str] | None,
879
+ text_prompt_mapping: list[dict[str, dict[str, str]]] | None,
880
+ input_hw: list[tuple[int, int]],
881
+ original_hw: list[tuple[int, int]],
882
+ ) -> DetOut:
883
+ """Forward."""
884
+ batch_size = images.shape[0]
885
+
886
+ token_positive_maps = []
887
+ text_prompts = []
888
+ entities = []
889
+ for i in range(batch_size):
890
+ if self.texts is not None:
891
+ text_prompt = self.texts
892
+ else:
893
+ text_prompt = input_texts[i]
894
+
895
+ if text_prompt_mapping is not None:
896
+ prompt_mapping = text_prompt_mapping[i]
897
+ else:
898
+ prompt_mapping = None
899
+
900
+ token_positive_map, captions, _, _entities = (
901
+ self.get_tokens_positive_and_prompts(
902
+ text_prompt,
903
+ self.custom_entities,
904
+ text_prompt_mapping=prompt_mapping,
905
+ )
906
+ )
907
+
908
+ token_positive_maps.append(token_positive_map)
909
+ text_prompts.append(captions)
910
+ entities.append(_entities)
911
+
912
+ # image feature extraction
913
+ batch_input_img_h, batch_input_img_w = images.shape[-2:]
914
+
915
+ visual_feats = self.backbone(images)[2:]
916
+ visual_feats = self.neck(visual_feats)
917
+
918
+ if isinstance(text_prompts[0], list):
919
+ # TODO: Support chunked text prompts in the future.
920
+ pass
921
+ # assert batch_size == 1, "Batch size should be 1 for chunked text."
922
+ # count = 0
923
+ # results_list = []
924
+
925
+ # entities = [[item for lst in entities[0] for item in lst]]
926
+
927
+ # for i, captions in enumerate(text_prompts[0]):
928
+ # token_positive_map = token_positive_maps[0][i]
929
+
930
+ # text_dict = self.language_model(captions)
931
+
932
+ # # text feature map layer
933
+ # if self.text_feat_map is not None:
934
+ # text_dict["embedded"] = self.text_feat_map(
935
+ # text_dict["embedded"]
936
+ # )
937
+
938
+ # head_inputs_dict = self.forward_transformer(
939
+ # copy.deepcopy(visual_feats), text_dict, input_hw
940
+ # )
941
+ # pred_instances = self.bbox_head.predict(
942
+ # **head_inputs_dict,
943
+ # batch_token_positive_maps=token_positive_map,
944
+ # )[0]
945
+
946
+ # if len(pred_instances) > 0:
947
+ # pred_instances.labels += count
948
+ # count += len(token_positive_maps_once)
949
+ # results_list.append(pred_instances)
950
+ # results_list = [results_list[0].cat(results_list)]
951
+ else:
952
+ # extract text feats
953
+ text_dict = self.language_model(list(text_prompts))
954
+
955
+ # text feature map layer
956
+ if self.text_feat_map is not None:
957
+ text_dict["embedded"] = self.text_feat_map(
958
+ text_dict["embedded"]
959
+ )
960
+
961
+ batch_input_shape = (batch_input_img_h, batch_input_img_w)
962
+
963
+ (
964
+ feat_flatten,
965
+ lvl_pos_embed_flatten,
966
+ memory_mask,
967
+ spatial_shapes,
968
+ level_start_index,
969
+ valid_ratios,
970
+ ) = self.pre_transformer(visual_feats, input_hw, batch_input_shape)
971
+
972
+ (
973
+ memory_text,
974
+ text_token_mask,
975
+ hidden_states,
976
+ references,
977
+ ) = self.forward_transformer(
978
+ feat_flatten,
979
+ lvl_pos_embed_flatten,
980
+ memory_mask,
981
+ spatial_shapes,
982
+ level_start_index,
983
+ valid_ratios,
984
+ text_dict,
985
+ )
986
+
987
+ all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
988
+ hidden_states, references, memory_text, text_token_mask
989
+ )
990
+
991
+ cls_scores = all_layers_cls_scores[-1]
992
+ bbox_preds = all_layers_bbox_preds[-1]
993
+
994
+ boxes = []
995
+ scores = []
996
+ class_ids = []
997
+ categories = []
998
+ for i, bbox_pred in enumerate(bbox_preds):
999
+ cls_score = cls_scores[i]
1000
+ det_bboxes, det_scores, det_labels = self.roi2det(
1001
+ cls_score,
1002
+ bbox_pred,
1003
+ token_positive_maps[i],
1004
+ input_hw[i],
1005
+ original_hw[i],
1006
+ )
1007
+ boxes.append(det_bboxes)
1008
+ scores.append(det_scores)
1009
+ class_ids.append(det_labels)
1010
+
1011
+ # Get the categories text
1012
+ cur_categories = []
1013
+ for label in det_labels:
1014
+ cur_categories.append(entities[i][label])
1015
+
1016
+ categories.append(cur_categories)
1017
+
1018
+ return DetOut(boxes, scores, class_ids, categories=categories)
1019
+
1020
+ def forward(
1021
+ self,
1022
+ images: Tensor,
1023
+ input_hw: list[tuple[int, int]],
1024
+ boxes2d: Tensor | None = None,
1025
+ boxes2d_classes: Tensor | None = None,
1026
+ original_hw: list[tuple[int, int]] | None = None,
1027
+ input_texts: Sequence[str] | str | None = None,
1028
+ input_tokens_positive: list[dict[int, list[tuple[int, int]]]] | None = None,
1029
+ text_prompt_mapping: dict[str, dict[str, str]] | None = None,
1030
+ ) -> GroundingDINOOut | DetOut:
1031
+ """Forward function."""
1032
+ if self.training:
1033
+ assert boxes2d is not None and boxes2d_classes is not None
1034
+ return self._forward_train(
1035
+ images,
1036
+ input_texts,
1037
+ boxes2d,
1038
+ boxes2d_classes,
1039
+ input_hw,
1040
+ input_tokens_positive=input_tokens_positive,
1041
+ )
1042
+
1043
+ assert original_hw is not None
1044
+ return self._forward_test(
1045
+ images,
1046
+ input_texts,
1047
+ text_prompt_mapping,
1048
+ input_hw,
1049
+ original_hw,
1050
+ )
opendet3d/model/detect3d/__init__.py ADDED
File without changes
opendet3d/model/detect3d/grounding_dino_3d.py ADDED
@@ -0,0 +1,812 @@
1
+ """3D-MOOD."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ from collections.abc import Sequence
7
+ from typing import NamedTuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import Tensor, nn
12
+ from vis4d.op.base import BaseModel
13
+ from vis4d.op.fpp.fpn import FPN
14
+
15
+ from opendet3d.model.detect.grounding_dino import GroundingDINO
16
+ from opendet3d.model.language.mm_bert import BertModel
17
+ from opendet3d.op.detect3d.grounding_dino_3d import (
18
+ GroundingDINO3DHead,
19
+ RoI2Det3D,
20
+ )
21
+ from opendet3d.op.detect.grounding_dino import GroundingDINOHead, RoI2Det
22
+ from opendet3d.op.fpp.channel_mapper import ChannelMapper
23
+
24
+
25
+ class Det3DOut(NamedTuple):
26
+ """Output of the detection model.
27
+
28
+ boxes (list[Tensor]): 2D bounding boxes of shape [N, 4] in xyxy format.
29
+ boxes3d (list[Tensor]): 3D bounding boxes of shape [N, 10].
30
+ scores (list[Tensor]): confidence scores of shape [N,].
31
+ class_ids (list[Tensor]): class ids of shape [N,].
+ depth_maps (list[Tensor] | None): dense depth maps, one per image.
+ categories (list[list[str]] | None): category names per detection.
32
+ """
33
+
34
+ boxes: list[Tensor]
35
+ boxes3d: list[Tensor]
36
+ scores: list[Tensor]
37
+ class_ids: list[Tensor]
38
+ depth_maps: list[Tensor] | None
39
+ categories: list[list[str]] | None = None
40
+
41
+
42
+ class GroundingDINO3DOut(NamedTuple):
43
+ """Output of the Grounding DINO model."""
44
+
45
+ all_layers_cls_scores: list[Tensor]
46
+ all_layers_bbox_preds: list[Tensor]
47
+ all_layers_bbox_3d_preds: list[Tensor]
48
+ enc_outputs_class: Tensor
49
+ enc_outputs_coord: Tensor
50
+ enc_outputs_3d: Tensor
51
+ text_token_mask: Tensor
52
+ dn_meta: dict[str, Tensor]
53
+ positive_maps: list[Tensor]
54
+ depth_maps: Tensor | None
55
+
56
+
57
+ class GroundingDINO3D(GroundingDINO):
58
+ """Grounding DINO."""
59
+
60
+ def __init__(
61
+ self,
62
+ basemodel: BaseModel,
63
+ neck: ChannelMapper,
64
+ texts: list[str] | None = None,
65
+ custom_entities: bool = True,
66
+ chunked_size: int = -1,
67
+ num_queries: int = 900,
68
+ num_feature_levels: int = 4,
69
+ use_checkpoint: bool = False,
70
+ bbox_head: GroundingDINOHead | None = None,
71
+ language_model: BertModel | None = None,
72
+ roi2det: RoI2Det | None = None,
73
+ bbox3d_head: GroundingDINO3DHead | None = None,
74
+ roi2det3d: RoI2Det3D | None = None,
75
+ depth_head: nn.Module | None = None,
76
+ fpn: FPN | None = None,
77
+ freeze_detector: bool = False,
78
+ weights: str | None = None,
79
+ cat_mapping: dict[str, int] | None = None,
80
+ ) -> None:
81
+ """Create the Grounding DINO model."""
82
+ super().__init__(
83
+ basemodel=basemodel,
84
+ neck=neck,
85
+ texts=texts,
86
+ custom_entities=custom_entities,
87
+ chunked_size=chunked_size,
88
+ num_queries=num_queries,
89
+ num_feature_levels=num_feature_levels,
90
+ use_checkpoint=use_checkpoint,
91
+ bbox_head=bbox_head,
92
+ roi2det=roi2det,
93
+ language_model=language_model,
94
+ weights=weights,
95
+ )
96
+
97
+ self.bbox3d_head = bbox3d_head or GroundingDINO3DHead()
98
+ self.roi2det3d = roi2det3d or RoI2Det3D()
99
+
100
+ # Depth Head
101
+ self.fpn = fpn
102
+ self.depth_head = depth_head
103
+
104
+ if freeze_detector:
105
+ self._freeze_detector()
106
+
107
+ self.cat_mapping = cat_mapping
108
+
109
+ def _freeze_detector(self):
110
+ """Freeze the detector."""
111
+ for model in [
112
+ self.backbone,
113
+ self.neck,
114
+ self.encoder,
115
+ self.positional_encoding,
116
+ self.memory_trans_fc,
117
+ self.memory_trans_norm,
118
+ self.decoder,
119
+ self.bbox_head,
120
+ self.dn_query_generator,
121
+ self.language_model,
122
+ self.text_feat_map,
123
+ ]:
124
+ model.eval()
125
+ for param in model.parameters():
126
+ param.requires_grad = False
127
+
128
+ self.level_embed.requires_grad = False
129
+ self.query_embedding.requires_grad = False
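# A minimal standalone sketch of the freezing pattern above: switch the
# module to eval (disabling dropout/batch-norm updates) and turn off
# requires_grad so only the new 3D heads receive gradients.
from torch import nn

frozen = nn.Linear(4, 4)
frozen.eval()
for p in frozen.parameters():
    p.requires_grad = False
assert all(not p.requires_grad for p in frozen.parameters())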
130
+
131
+ def forward_transformer(
132
+ self,
133
+ feat_flatten: Tensor,
134
+ lvl_pos_embed_flatten: Tensor,
135
+ memory_mask: Tensor | None,
136
+ spatial_shapes: Tensor,
137
+ level_start_index: Tensor,
138
+ valid_ratios: Tensor,
139
+ text_dict: dict[str, Tensor],
140
+ boxes: Tensor | None = None,
141
+ class_ids: Tensor | None = None,
142
+ input_hw: list[tuple[int, int]] | None = None,
143
+ ray_embeddings: Tensor | None = None,
144
+ depth_latents: Tensor | None = None,
145
+ ) -> tuple:
146
+ """Forward function for the transformer."""
147
+ text_token_mask = text_dict["text_token_mask"]
148
+
149
+ memory, memory_text = self.encoder(
150
+ query=feat_flatten,
151
+ query_pos=lvl_pos_embed_flatten,
152
+ key_padding_mask=memory_mask,
153
+ spatial_shapes=spatial_shapes,
154
+ level_start_index=level_start_index,
155
+ valid_ratios=valid_ratios,
156
+ memory_text=text_dict["embedded"],
157
+ text_attention_mask=~text_token_mask,
158
+ position_ids=text_dict["position_ids"],
159
+ text_self_attention_masks=text_dict["masks"],
160
+ )
161
+
162
+ bs = memory.shape[0]
163
+
164
+ output_memory, output_proposals = self.gen_encoder_output_proposals(
165
+ memory, memory_mask, spatial_shapes
166
+ )
167
+
168
+ enc_outputs_class = self.bbox_head.cls_branches[
169
+ self.decoder.num_layers
170
+ ](output_memory, memory_text, text_token_mask)
171
+ cls_out_features = self.bbox_head.cls_branches[
172
+ self.decoder.num_layers
173
+ ].max_text_len
174
+ enc_outputs_coord_unact = (
175
+ self.bbox_head.reg_branches[self.decoder.num_layers](output_memory)
176
+ + output_proposals
177
+ )
178
+
179
+ # NOTE The DINO selects top-k proposals according to scores of
180
+ # multi-class classification, while DeformDETR, where the input
181
+ # is `enc_outputs_class[..., 0]` selects according to scores of
182
+ # binary classification.
183
+ topk_indices = torch.topk(
184
+ enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1
185
+ )[1]
186
+
187
+ topk_score = torch.gather(
188
+ enc_outputs_class,
189
+ 1,
190
+ topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features),
191
+ )
192
+ topk_coords_unact = torch.gather(
193
+ enc_outputs_coord_unact,
194
+ 1,
195
+ topk_indices.unsqueeze(-1).repeat(1, 1, 4),
196
+ )
197
+ topk_coords = topk_coords_unact.sigmoid()
198
+ topk_coords_unact = topk_coords_unact.detach()
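# A hedged standalone sketch of the selection above: per-token scores
# are reduced with a max over the text dimension, the top-k token
# indices are taken, and matching rows are gathered from each tensor.
import torch

bs, num_tokens, text_len, k = 2, 5, 3, 2
enc_cls = torch.rand(bs, num_tokens, text_len)
enc_coord = torch.rand(bs, num_tokens, 4)
topk_idx = torch.topk(enc_cls.max(-1)[0], k=k, dim=1)[1]  # [bs, k]
topk_cls = torch.gather(
    enc_cls, 1, topk_idx.unsqueeze(-1).repeat(1, 1, text_len)
)
topk_coord = torch.gather(
    enc_coord, 1, topk_idx.unsqueeze(-1).repeat(1, 1, 4)
)
assert topk_cls.shape == (bs, k, text_len)
assert topk_coord.shape == (bs, k, 4)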
199
+
200
+ # Top-k 3D proposals
201
+ topk_output_memory = torch.gather(
202
+ output_memory, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 256)
203
+ )
204
+
205
+ topk_output_3d = self.bbox3d_head.single_forward(
206
+ self.decoder.num_layers,
207
+ topk_output_memory,
208
+ ray_embeddings,
209
+ depth_latents,
210
+ )
211
+
212
+ query = self.query_embedding.weight[:, None, :]
213
+ query = query.repeat(1, bs, 1).transpose(0, 1)
214
+
215
+ if self.training:
216
+ dn_label_query, dn_bbox_query, dn_mask, dn_meta = (
217
+ self.dn_query_generator(boxes, class_ids, input_hw)
218
+ )
219
+ query = torch.cat([dn_label_query, query], dim=1)
220
+ reference_points = torch.cat(
221
+ [dn_bbox_query, topk_coords_unact], dim=1
222
+ )
223
+ else:
224
+ reference_points = topk_coords_unact
225
+ dn_mask, dn_meta = None, None
226
+
227
+ reference_points = reference_points.sigmoid()
228
+
229
+ # NOTE DINO calculates encoder losses on scores and coordinates
230
+ # of selected top-k encoder queries, while DeformDETR is of all
231
+ # encoder queries.
232
+ if self.training:
233
+ enc_outputs_class = topk_score
234
+ enc_outputs_coord = topk_coords
235
+
236
+ hidden_states, references = self.decoder(
237
+ query=query,
238
+ value=memory,
239
+ key_padding_mask=memory_mask,
240
+ self_attn_mask=dn_mask,
241
+ reference_points=reference_points,
242
+ spatial_shapes=spatial_shapes,
243
+ level_start_index=level_start_index,
244
+ valid_ratios=valid_ratios,
245
+ reg_branches=self.bbox_head.reg_branches,
246
+ memory_text=memory_text,
247
+ text_attention_mask=~text_token_mask,
248
+ )
249
+
250
+ if len(query) == self.num_queries:
251
+ # NOTE: This is to make sure label_embeding can be involved to
252
+ # produce loss even if there is no denoising query (no ground truth
253
+ # target in this GPU), otherwise, this will raise runtime error in
254
+ # distributed training.
255
+ hidden_states[0] += (
256
+ self.dn_query_generator.label_embedding.weight[0, 0] * 0.0
257
+ )
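# The `* 0.0` above is the usual trick for keeping an otherwise-unused
# parameter in the autograd graph (so distributed training sees a
# gradient for it) without changing the forward value. Standalone check:
import torch

w = torch.nn.Parameter(torch.ones(3))
y = torch.zeros(3) + w[0] * 0.0  # value unchanged, w joins the graph
y.sum().backward()
assert w.grad is not None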
258
+
259
+ if self.training:
260
+ return (
261
+ memory_text,
262
+ text_token_mask,
263
+ hidden_states,
264
+ list(references),
265
+ enc_outputs_class,
266
+ enc_outputs_coord,
267
+ topk_output_3d,
268
+ dn_meta,
269
+ )
270
+
271
+ return (
272
+ memory_text,
273
+ text_token_mask,
274
+ hidden_states,
275
+ list(references),
276
+ )
277
+
278
+ def _extract_image_features(
279
+ self, images: Tensor
280
+ ) -> tuple[list[Tensor], list[Tensor] | None]:
281
+ """Extract image features."""
282
+ visual_feats = self.backbone(images)[2:]
283
+
284
+ if self.fpn is not None:
285
+ depth_feats = self.fpn(visual_feats)
286
+
287
+ if len(visual_feats) > len(self.neck.in_channels):
288
+ start_index = len(visual_feats) - len(self.neck.in_channels)
289
+ # NOTE: Drop the earliest pyramid levels so that the number of
290
+ # feature maps matches the inputs expected by the neck.
291
+ visual_feats = visual_feats[start_index:]
292
+
293
+ visual_feats = self.neck(visual_feats)
294
+
295
+ if self.fpn is None:
296
+ depth_feats = visual_feats
297
+
298
+ return visual_feats, depth_feats
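# A standalone sketch of the level trimming above; the channel counts
# here are illustrative, not the repo's actual configuration.
import torch

visual_feats = [torch.rand(1, c, 8, 8) for c in (96, 192, 384, 768)]
neck_in_channels = (192, 384, 768)  # hypothetical neck configuration
if len(visual_feats) > len(neck_in_channels):
    start_index = len(visual_feats) - len(neck_in_channels)
    visual_feats = visual_feats[start_index:]
assert [f.shape[1] for f in visual_feats] == list(neck_in_channels)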
299
+
300
+ def _forward_train(
301
+ self,
302
+ images: Tensor,
303
+ input_texts: list[list[str]] | None,
304
+ boxes2d: list[Tensor],
305
+ boxes2d_classes: list[Tensor],
306
+ input_hw: list[tuple[int, int]],
307
+ intrinsics: Tensor,
308
+ input_tokens_positive: list[dict[int, list[tuple[int, int]]] | None] | None = None,
309
+ padding: list[list[int]] | None = None,
310
+ ) -> GroundingDINO3DOut:
311
+ """Forward function for training."""
312
+ batch_size = images.shape[0]
313
+
314
+ new_text_prompts = []
315
+ positive_maps = []
316
+ if input_tokens_positive is not None:
317
+ for tokens_positive_dict, text_prompt, gt_label in zip(
318
+ input_tokens_positive, input_texts, boxes2d_classes
319
+ ):
320
+ if tokens_positive_dict is not None:
321
+ tokenized = self.language_model.tokenizer(
322
+ [text_prompt], padding="longest", return_tensors="pt"
323
+ )
324
+
325
+ new_text_prompts.append(text_prompt)
326
+
327
+ new_tokens_positive = [
328
+ tokens_positive_dict[label.item()]
329
+ for label in gt_label
330
+ ]
331
+ else:
332
+ tokenized, caption_string, tokens_positive, _ = (
333
+ self.get_tokens_and_prompts(text_prompt)
334
+ )
335
+
336
+ new_text_prompts.append(caption_string)
337
+
338
+ new_tokens_positive = [
339
+ tokens_positive[label] for label in gt_label
340
+ ]
341
+
342
+ _, positive_map = self.get_positive_map(
343
+ tokenized, new_tokens_positive
344
+ )
345
+ positive_maps.append(positive_map)
346
+ else:
347
+ # All the text prompts are the same or using the self.texts,
348
+ # so there is no need to calculate them multiple times.
349
+ if (
350
+ input_texts is None
351
+ or len(set(["".join(t) for t in input_texts])) == 1
352
+ ):
353
+ if input_texts is None:
354
+ assert self.texts is not None, "Texts should be provided."
355
+ text_prompt = self.texts
356
+ else:
357
+ text_prompt = input_texts[0]
358
+
359
+ tokenized, caption_string, tokens_positive, _ = (
360
+ self.get_tokens_and_prompts(text_prompt)
361
+ )
362
+ new_text_prompts = [caption_string] * batch_size
363
+ for gt_label in boxes2d_classes:
364
+ new_tokens_positive = [
365
+ tokens_positive[label] for label in gt_label
366
+ ]
367
+ _, positive_map = self.get_positive_map(
368
+ tokenized, new_tokens_positive
369
+ )
370
+ positive_maps.append(positive_map)
371
+ else:
372
+ for text_prompt, gt_label in zip(input_texts, boxes2d_classes):
373
+ tokenized, caption_string, tokens_positive, _ = (
374
+ self.get_tokens_and_prompts(text_prompt)
375
+ )
376
+
377
+ new_text_prompts.append(caption_string)
378
+
379
+ new_tokens_positive = [
380
+ tokens_positive[label] for label in gt_label
381
+ ]
382
+ _, positive_map = self.get_positive_map(
383
+ tokenized, new_tokens_positive
384
+ )
385
+
386
+ positive_maps.append(positive_map)
387
+
388
+ for i in range(batch_size):
389
+ positive_maps[i] = (
390
+ positive_maps[i].to(images.device).bool().float()
391
+ )
392
+
393
+ text_dict = self.language_model(new_text_prompts)
394
+
395
+ if self.text_feat_map is not None:
396
+ text_dict["embedded"] = self.text_feat_map(text_dict["embedded"])
397
+
398
+ text_token_masks = []
399
+ for i in range(batch_size):
400
+ text_token_masks.append(
401
+ text_dict["text_token_mask"][i]
402
+ .unsqueeze(0)
403
+ .repeat(len(positive_maps[i]), 1)
404
+ )
405
+
406
+ visual_feats, depth_feats = self._extract_image_features(images)
407
+
408
+ batch_input_img_h, batch_input_img_w = images.shape[-2:]
409
+ batch_input_shape = (batch_input_img_h, batch_input_img_w)
410
+
411
+ (
412
+ feat_flatten,
413
+ lvl_pos_embed_flatten,
414
+ memory_mask,
415
+ spatial_shapes,
416
+ level_start_index,
417
+ valid_ratios,
418
+ ) = self.pre_transformer(
419
+ visual_feats, input_hw, batch_input_shape, padding=padding
420
+ )
421
+
422
+ ray_embeddings = self.bbox3d_head.get_camera_embeddings(
423
+ intrinsics, batch_input_shape
424
+ )
425
+
426
+ # Depth Head
427
+ depth_preds, depth_latents = self.depth_head(
428
+ depth_feats, intrinsics, batch_input_shape
429
+ )
430
+
431
+ (
432
+ memory_text,
433
+ text_token_mask,
434
+ hidden_states,
435
+ references,
436
+ enc_outputs_class,
437
+ enc_outputs_coord,
438
+ enc_outputs_3d,
439
+ dn_meta,
440
+ ) = self.forward_transformer(
441
+ feat_flatten,
442
+ lvl_pos_embed_flatten,
443
+ memory_mask,
444
+ spatial_shapes,
445
+ level_start_index,
446
+ valid_ratios,
447
+ text_dict,
448
+ boxes2d,
449
+ boxes2d_classes,
450
+ input_hw,
451
+ ray_embeddings,
452
+ depth_latents,
453
+ )
454
+
455
+ all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
456
+ hidden_states, references, memory_text, text_token_mask
457
+ )
458
+
459
+ # Not using denoising for 3D
460
+ hidden_states_3d = hidden_states[
461
+ :, :, dn_meta["num_denoising_queries"] :, :
462
+ ]
463
+
464
+ all_layers_outputs_3d = self.bbox3d_head(
465
+ hidden_states_3d, ray_embeddings, depth_latents
466
+ )
467
+
468
+ return GroundingDINO3DOut(
469
+ all_layers_cls_scores,
470
+ all_layers_bbox_preds,
471
+ all_layers_outputs_3d,
472
+ enc_outputs_class,
473
+ enc_outputs_coord,
474
+ enc_outputs_3d,
475
+ text_token_mask,
476
+ dn_meta,
477
+ positive_maps,
478
+ depth_maps=depth_preds,
479
+ )
480
+
481
+ def _forward_test(
482
+ self,
483
+ images: Tensor,
484
+ input_texts: list[str] | None,
485
+ text_prompt_mapping: list[dict[str, dict[str, str]]] | None,
486
+ input_hw: list[tuple[int, int]],
487
+ original_hw: list[tuple[int, int]],
488
+ intrinsics: list[Tensor] | None,
489
+ padding: list[list[int]] | None,
490
+ ) -> Det3DOut:
491
+ """Forward."""
492
+ batch_size = images.shape[0]
493
+
494
+ token_positive_maps = []
495
+ text_prompts = []
496
+ entities = []
497
+ for i in range(batch_size):
498
+ if self.texts is not None:
499
+ text_prompt = self.texts
500
+ else:
501
+ text_prompt = input_texts[i]
502
+
503
+ if text_prompt_mapping is not None:
504
+ prompt_mapping = text_prompt_mapping[i]
505
+ else:
506
+ prompt_mapping = None
507
+
508
+ token_positive_map, captions, _, _entities = (
509
+ self.get_tokens_positive_and_prompts(
510
+ text_prompt,
511
+ self.custom_entities,
512
+ text_prompt_mapping=prompt_mapping,
513
+ )
514
+ )
515
+
516
+ token_positive_maps.append(token_positive_map)
517
+ text_prompts.append(captions)
518
+ entities.append(_entities)
519
+
520
+ # image feature extraction
521
+ batch_input_img_h, batch_input_img_w = images.shape[-2:]
522
+
523
+ visual_feats, depth_feats = self._extract_image_features(images)
524
+
525
+ batch_input_shape = (batch_input_img_h, batch_input_img_w)
526
+
527
+ if isinstance(text_prompts[0], list):
528
+ assert batch_size == 1, "Batch size should be 1 for chunked text."
529
+ assert (
530
+ self.cat_mapping is not None
531
+ ), "Category mapping should be provided."
532
+
533
+ boxes = []
534
+ boxes3d = []
535
+ scores = []
536
+ class_ids = []
537
+ categories = []
538
+ for i, captions in enumerate(text_prompts[0]):
539
+ token_positive_map = token_positive_maps[0][i]
540
+ cur_entities = entities[0][i]
541
+
542
+ text_dict = self.language_model([captions])
543
+
544
+ # text feature map layer
545
+ if self.text_feat_map is not None:
546
+ text_dict["embedded"] = self.text_feat_map(
547
+ text_dict["embedded"]
548
+ )
549
+
550
+ (
551
+ feat_flatten,
552
+ lvl_pos_embed_flatten,
553
+ memory_mask,
554
+ spatial_shapes,
555
+ level_start_index,
556
+ valid_ratios,
557
+ ) = self.pre_transformer(
558
+ copy.deepcopy(visual_feats), input_hw, batch_input_shape
559
+ )
560
+
561
+ ray_embeddings = self.bbox3d_head.get_camera_embeddings(
562
+ intrinsics, batch_input_shape
563
+ )
564
+
565
+ # Depth Head
566
+ depth_preds, depth_latents = self.depth_head(
567
+ depth_feats, intrinsics, batch_input_shape
568
+ )
569
+
570
+ depth_maps = []
571
+ for j, depth_pred in enumerate(depth_preds):
572
+ if padding is not None:
573
+ pad_left, pad_right, pad_top, pad_bottom = padding[j]
574
+
575
+ depth_pred = depth_pred[
576
+ pad_top : batch_input_img_h - pad_bottom,
577
+ pad_left : batch_input_img_w - pad_right,
578
+ ]
579
+
580
+ depth_maps.append(
581
+ F.interpolate(
582
+ depth_pred.unsqueeze(0).unsqueeze(0),
583
+ size=original_hw[j],
584
+ mode="bilinear",
585
+ align_corners=False,
586
+ antialias=True,
587
+ )
588
+ .squeeze(0)
589
+ .squeeze(0)
590
+ )
591
+
592
+ (
593
+ memory_text,
594
+ text_token_mask,
595
+ hidden_states,
596
+ references,
597
+ ) = self.forward_transformer(
598
+ feat_flatten,
599
+ lvl_pos_embed_flatten,
600
+ memory_mask,
601
+ spatial_shapes,
602
+ level_start_index,
603
+ valid_ratios,
604
+ text_dict,
605
+ ray_embeddings=ray_embeddings,
606
+ depth_latents=depth_latents,
607
+ )
608
+
609
+ all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
610
+ hidden_states, references, memory_text, text_token_mask
611
+ )
612
+
613
+ all_layers_outputs_3d = self.bbox3d_head(
614
+ hidden_states, ray_embeddings, depth_latents=depth_latents
615
+ )
616
+
617
+ cls_scores = all_layers_cls_scores[-1]
618
+ bbox_preds = all_layers_bbox_preds[-1]
619
+ boxes3d_preds = all_layers_outputs_3d[-1]
620
+
621
+ cls_score = cls_scores[0]
622
+ det_bboxes, det_scores, det_labels, det_bboxes3d = (
623
+ self.roi2det3d(
624
+ cls_score,
625
+ bbox_preds[0],
626
+ token_positive_map,
627
+ input_hw[0],
628
+ original_hw[0],
629
+ boxes3d_preds[0],
630
+ intrinsics[0],
631
+ padding[0] if padding is not None else None,
632
+ )
633
+ )
634
+
635
+ boxes.append(det_bboxes)
636
+ scores.append(det_scores)
637
+ boxes3d.append(det_bboxes3d)
638
+
639
+ # Get the categories text and class ids
640
+ cur_class_ids = []
641
+ cur_categories = []
642
+ for label in det_labels:
643
+ cur_class_ids.append(self.cat_mapping[cur_entities[label]])
644
+ cur_categories.append(cur_entities[label])
645
+
646
+ class_ids.append(cur_class_ids)
647
+ categories.append(cur_categories)
648
+
649
+ boxes = [torch.cat(boxes)]
650
+ boxes3d = [torch.cat(boxes3d)]
651
+ scores = [torch.cat(scores)]
652
+ class_ids = [sum(class_ids, [])]
653
+ categories = [sum(categories, [])]
654
+ else:
655
+ # extract text feats
656
+ text_dict = self.language_model(list(text_prompts))
657
+
658
+ # text feature map layer
659
+ if self.text_feat_map is not None:
660
+ text_dict["embedded"] = self.text_feat_map(
661
+ text_dict["embedded"]
662
+ )
663
+
664
+ (
665
+ feat_flatten,
666
+ lvl_pos_embed_flatten,
667
+ memory_mask,
668
+ spatial_shapes,
669
+ level_start_index,
670
+ valid_ratios,
671
+ ) = self.pre_transformer(visual_feats, input_hw, batch_input_shape)
672
+
673
+ ray_embeddings = self.bbox3d_head.get_camera_embeddings(
674
+ intrinsics, batch_input_shape
675
+ )
676
+
677
+ # Depth Head
678
+ depth_preds, depth_latents = self.depth_head(
679
+ depth_feats, intrinsics, batch_input_shape
680
+ )
681
+
682
+ depth_maps = []
683
+ for i, depth_pred in enumerate(depth_preds):
684
+ if padding is not None:
685
+ pad_left, pad_right, pad_top, pad_bottom = padding[i]
686
+
687
+ depth_pred = depth_pred[
688
+ pad_top : batch_input_img_h - pad_bottom,
689
+ pad_left : batch_input_img_w - pad_right,
690
+ ]
691
+
692
+ depth_maps.append(
693
+ F.interpolate(
694
+ depth_pred.unsqueeze(0).unsqueeze(0),
695
+ size=original_hw[i],
696
+ mode="bilinear",
697
+ align_corners=False,
698
+ antialias=True,
699
+ )
700
+ .squeeze(0)
701
+ .squeeze(0)
702
+ )
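# A standalone sketch of the depth post-processing above with made-up
# sizes: crop the padded border, then resize back to the original
# image resolution.
import torch
import torch.nn.functional as F

depth = torch.rand(608, 800)  # padded network-input resolution
pad_left, pad_right, pad_top, pad_bottom = 0, 32, 0, 8
depth = depth[pad_top : 608 - pad_bottom, pad_left : 800 - pad_right]
depth = F.interpolate(
    depth[None, None],
    size=(1080, 1920),
    mode="bilinear",
    align_corners=False,
    antialias=True,
)[0, 0]
assert depth.shape == (1080, 1920)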
703
+
704
+ (
705
+ memory_text,
706
+ text_token_mask,
707
+ hidden_states,
708
+ references,
709
+ ) = self.forward_transformer(
710
+ feat_flatten,
711
+ lvl_pos_embed_flatten,
712
+ memory_mask,
713
+ spatial_shapes,
714
+ level_start_index,
715
+ valid_ratios,
716
+ text_dict,
717
+ ray_embeddings=ray_embeddings,
718
+ depth_latents=depth_latents,
719
+ )
720
+
721
+ all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
722
+ hidden_states, references, memory_text, text_token_mask
723
+ )
724
+
725
+ all_layers_outputs_3d = self.bbox3d_head(
726
+ hidden_states, ray_embeddings, depth_latents=depth_latents
727
+ )
728
+
729
+ cls_scores = all_layers_cls_scores[-1]
730
+ bbox_preds = all_layers_bbox_preds[-1]
731
+ boxes3d_preds = all_layers_outputs_3d[-1]
732
+
733
+ boxes = []
734
+ boxes3d = []
735
+ scores = []
736
+ class_ids = []
737
+ categories = []
738
+ for i, bbox_pred in enumerate(bbox_preds):
739
+ cls_score = cls_scores[i]
740
+ det_bboxes, det_scores, det_labels, det_bboxes3d = (
741
+ self.roi2det3d(
742
+ cls_score,
743
+ bbox_pred,
744
+ token_positive_maps[i],
745
+ input_hw[i],
746
+ original_hw[i],
747
+ boxes3d_preds[i],
748
+ intrinsics[i],
749
+ padding[i] if padding is not None else None,
750
+ )
751
+ )
752
+
753
+ boxes.append(det_bboxes)
754
+ scores.append(det_scores)
755
+ class_ids.append(det_labels)
756
+ boxes3d.append(det_bboxes3d)
757
+
758
+ # Get the categories text
759
+ cur_categories = []
760
+ for label in det_labels:
761
+ cur_categories.append(entities[i][label])
762
+
763
+ categories.append(cur_categories)
764
+
765
+ return Det3DOut(
766
+ boxes,
767
+ boxes3d,
768
+ scores,
769
+ class_ids,
770
+ depth_maps=depth_maps,
771
+ categories=categories,
772
+ )
773
+
774
+ def forward(
775
+ self,
776
+ images: Tensor,
777
+ input_hw: list[tuple[int, int]],
778
+ intrinsics: Tensor,
779
+ boxes2d: Tensor | None = None,
780
+ boxes2d_classes: Tensor | None = None,
781
+ original_hw: list[tuple[int, int]] | None = None,
782
+ input_texts: Sequence[str] | str | None = None,
783
+ input_tokens_positive: list[dict[int, list[tuple[int, int]]]] | None = None,
784
+ text_prompt_mapping: dict[str, dict[str, str]] | None = None,
785
+ padding: list[list[int]] | None = None,
786
+ **kwargs,
787
+ ) -> GroundingDINO3DOut | Det3DOut:
788
+ """Forward function."""
789
+ if self.training:
790
+ assert boxes2d is not None and boxes2d_classes is not None
791
+ return self._forward_train(
792
+ images,
793
+ input_texts,
794
+ boxes2d,
795
+ boxes2d_classes,
796
+ input_hw,
797
+ intrinsics,
798
+ input_tokens_positive=input_tokens_positive,
799
+ padding=padding,
800
+ **kwargs,
801
+ )
802
+
803
+ assert original_hw is not None
804
+ return self._forward_test(
805
+ images,
806
+ input_texts,
807
+ text_prompt_mapping,
808
+ input_hw,
809
+ original_hw,
810
+ intrinsics,
811
+ padding,
812
+ )
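# A hedged sketch of consuming Det3DOut at inference time, assuming the
# NamedTuple defined above is in scope; the tensors are fabricated
# stand-ins with the documented shapes.
import torch

out = Det3DOut(
    boxes=[torch.rand(2, 4)],
    boxes3d=[torch.rand(2, 10)],
    scores=[torch.rand(2)],
    class_ids=[torch.tensor([0, 1])],
    depth_maps=[torch.rand(480, 640)],
    categories=[["car", "person"]],
)
for box2d, box3d, score, name in zip(
    out.boxes[0], out.boxes3d[0], out.scores[0], out.categories[0]
):
    print(name, round(float(score), 3), box2d.tolist())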
opendet3d/model/language/__init__.py ADDED
File without changes
opendet3d/model/language/mm_bert.py ADDED
@@ -0,0 +1,255 @@
1
+ """BERT model from mmdetection."""
2
+
3
+ import logging
4
+ import os
5
+ from collections import OrderedDict
6
+ from collections.abc import Sequence
7
+
8
+ import torch
9
+ from torch import nn
10
+ from transformers import AutoTokenizer, BertConfig
11
+ from transformers import BertModel as HFBertModel
12
+
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
+
16
+ def generate_masks_with_special_tokens_and_transfer_map(
17
+ tokenized, special_tokens_list
18
+ ):
19
+ """Generate attention mask between each pair of special tokens.
20
+
21
+ Only token pairs in between two special tokens are attended to
22
+ and thus the attention mask for these pairs is positive.
23
+
24
+ Args:
25
+ tokenized (dict): tokenizer output holding "input_ids" of shape [bs, num_token].
27
+ special_tokens_list (list): ids of the special tokens that split sub-sentences.
27
+
28
+ Returns:
29
+ Tuple(Tensor, Tensor):
30
+ - attention_mask is the attention mask between each pair of tokens.
31
+ Only token pairs in between two special tokens are positive.
32
+ Shape: [bs, num_token, num_token].
33
+ - position_ids is the position id of tokens within each valid sentence.
34
+ The id starts from 0 whenever a special token is encountered.
35
+ Shape: [bs, num_token]
36
+ """
37
+ input_ids = tokenized["input_ids"]
38
+ bs, num_token = input_ids.shape
39
+ # special_tokens_mask:
40
+ # bs, num_token. 1 for special tokens. 0 for normal tokens
41
+ special_tokens_mask = torch.zeros(
42
+ (bs, num_token), device=input_ids.device
43
+ ).bool()
44
+
45
+ for special_token in special_tokens_list:
46
+ special_tokens_mask |= input_ids == special_token
47
+
48
+ # idxs: each row is a list of indices of special tokens
49
+ idxs = torch.nonzero(special_tokens_mask)
50
+
51
+ # generate attention mask and positional ids
52
+ attention_mask = (
53
+ torch.eye(num_token, device=input_ids.device)
54
+ .bool()
55
+ .unsqueeze(0)
56
+ .repeat(bs, 1, 1)
57
+ )
58
+ position_ids = torch.zeros((bs, num_token), device=input_ids.device)
59
+ previous_col = 0
60
+ for i in range(idxs.shape[0]):
61
+ row, col = idxs[i]
62
+ if (col == 0) or (col == num_token - 1):
63
+ attention_mask[row, col, col] = True
64
+ position_ids[row, col] = 0
65
+ else:
66
+ attention_mask[
67
+ row, previous_col + 1 : col + 1, previous_col + 1 : col + 1
68
+ ] = True
69
+ position_ids[row, previous_col + 1 : col + 1] = torch.arange(
70
+ 0, col - previous_col, device=input_ids.device
71
+ )
72
+ previous_col = col
73
+
74
+ return attention_mask, position_ids.to(torch.long)
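# A hedged worked example for the function above, assuming it is in
# scope. The ids are illustrative: 101/102 stand for [CLS]/[SEP] and
# 1012 for the "." separator; only membership in special_tokens_list
# matters.
import torch

tokenized = {"input_ids": torch.tensor([[101, 2482, 1012, 2711, 1012, 102]])}
attn, pos = generate_masks_with_special_tokens_and_transfer_map(
    tokenized, [101, 102, 1012]
)
# Tokens of the first phrase attend to themselves and their separator,
# but not to the second phrase.
assert bool(attn[0, 1, 1]) and bool(attn[0, 1, 2]) and not bool(attn[0, 1, 3])
assert pos[0].tolist() == [0, 0, 1, 0, 1, 0]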
75
+
76
+
77
+ class BertModel(nn.Module):
78
+ """BERT model for language embedding only encoder.
79
+
80
+ Args:
81
+ name (str, optional): name of the pretrained BERT model from
82
+ HuggingFace. Defaults to bert-base-uncased.
83
+ max_tokens (int, optional): maximum number of tokens to be
84
+ used for BERT. Defaults to 256.
85
+ pad_to_max (bool, optional): whether to pad the tokens to max_tokens.
86
+ Defaults to True.
87
+ use_sub_sentence_represent (bool, optional): whether to use sub
88
+ sentence represent introduced in `Grounding DINO
89
+ <https://arxiv.org/abs/2303.05499>`. Defaults to False.
90
+ special_tokens_list (list, optional): special tokens used to split
91
+ subsentence. It cannot be None when `use_sub_sentence_represent`
92
+ is True. Defaults to None.
93
+ add_pooling_layer (bool, optional): whether to adding pooling
94
+ layer in bert encoder. Defaults to False.
95
+ num_layers_of_embedded (int, optional): number of layers of
96
+ the embedded model. Defaults to 1.
97
+ use_checkpoint (bool, optional): whether to use gradient checkpointing.
98
+ Defaults to False.
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ name: str = "bert-base-uncased",
104
+ max_tokens: int = 256,
105
+ pad_to_max: bool = True,
106
+ use_sub_sentence_represent: bool = False,
107
+ special_tokens_list: list | None = None,
108
+ add_pooling_layer: bool = False,
109
+ num_layers_of_embedded: int = 1,
110
+ use_checkpoint: bool = False,
111
+ ) -> None:
112
+ """Create the BERT model."""
113
+ super().__init__()
114
+ self.max_tokens = max_tokens
115
+ self.pad_to_max = pad_to_max
116
+
117
+ self.tokenizer = AutoTokenizer.from_pretrained(name)
118
+ self.language_backbone = nn.Sequential(
119
+ OrderedDict(
120
+ [
121
+ (
122
+ "body",
123
+ BertEncoder(
124
+ name,
125
+ add_pooling_layer=add_pooling_layer,
126
+ num_layers_of_embedded=num_layers_of_embedded,
127
+ use_checkpoint=use_checkpoint,
128
+ ),
129
+ )
130
+ ]
131
+ )
132
+ )
133
+
134
+ self.use_sub_sentence_represent = use_sub_sentence_represent
135
+ if self.use_sub_sentence_represent:
136
+ assert (
137
+ special_tokens_list is not None
138
+ ), "special_tokens should not be None \
139
+ if use_sub_sentence_represent is True"
140
+
141
+ self.special_tokens = self.tokenizer.convert_tokens_to_ids(
142
+ special_tokens_list
143
+ )
144
+
145
+ def forward(self, captions: Sequence[str]) -> dict:
146
+ """Forward function."""
147
+ device = next(self.language_backbone.parameters()).device
148
+ tokenized = self.tokenizer.batch_encode_plus(
149
+ captions,
150
+ max_length=self.max_tokens,
151
+ padding="max_length" if self.pad_to_max else "longest",
152
+ return_special_tokens_mask=True,
153
+ return_tensors="pt",
154
+ truncation=True,
155
+ ).to(device)
156
+ input_ids = tokenized.input_ids
157
+ if self.use_sub_sentence_represent:
158
+ attention_mask, position_ids = (
159
+ generate_masks_with_special_tokens_and_transfer_map(
160
+ tokenized, self.special_tokens
161
+ )
162
+ )
163
+ token_type_ids = tokenized["token_type_ids"]
164
+
165
+ else:
166
+ attention_mask = tokenized.attention_mask
167
+ position_ids = None
168
+ token_type_ids = None
169
+
170
+ tokenizer_input = {
171
+ "input_ids": input_ids,
172
+ "attention_mask": attention_mask,
173
+ "position_ids": position_ids,
174
+ "token_type_ids": token_type_ids,
175
+ }
176
+ language_dict_features = self.language_backbone(tokenizer_input)
177
+ if self.use_sub_sentence_represent:
178
+ language_dict_features["position_ids"] = position_ids
179
+ language_dict_features["text_token_mask"] = (
180
+ tokenized.attention_mask.bool()
181
+ )
182
+ return language_dict_features
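# For reference, the batch_encode_plus call above behaves like this
# standalone snippet (bert-base-uncased is downloaded on first use;
# the captions are illustrative):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer.batch_encode_plus(
    ["car. person.", "chair."],
    max_length=256,
    padding="max_length",
    return_special_tokens_mask=True,
    return_tensors="pt",
    truncation=True,
)
print(batch.input_ids.shape)        # torch.Size([2, 256])
print(batch.attention_mask.sum(1))  # real-token counts per caption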
183
+
184
+
185
+ class BertEncoder(nn.Module):
186
+ """BERT encoder for language embedding.
187
+
188
+ Args:
189
+ name (str): name of the pretrained BERT model from HuggingFace.
190
+ Defaults to bert-base-uncased.
191
+ add_pooling_layer (bool): whether to add a pooling layer.
192
+ num_layers_of_embedded (int): number of layers of the embedded model.
193
+ Defaults to 1.
194
+ use_checkpoint (bool): whether to use gradient checkpointing.
195
+ Defaults to False.
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ name: str,
201
+ add_pooling_layer: bool = False,
202
+ num_layers_of_embedded: int = 1,
203
+ use_checkpoint: bool = False,
204
+ ):
205
+ super().__init__()
206
+ config = BertConfig.from_pretrained(name)
207
+ config.gradient_checkpointing = use_checkpoint
208
+
209
+ loggers = [
210
+ logging.getLogger(name) for name in logging.root.manager.loggerDict
211
+ ]
212
+ for logger in loggers:
213
+ if "transformers" in logger.name.lower():
214
+ logger.setLevel(logging.ERROR)
215
+
216
+ # only encoder
217
+ self.model = HFBertModel.from_pretrained(
218
+ name,
219
+ add_pooling_layer=add_pooling_layer,
220
+ config=config,
221
+ attn_implementation="eager",
222
+ )
223
+
224
+ self.language_dim = config.hidden_size
225
+ self.num_layers_of_embedded = num_layers_of_embedded
226
+
227
+ def forward(self, x: dict) -> dict:
+ """Forward pass returning averaged hidden-layer text features."""
228
+ mask = x["attention_mask"]
229
+
230
+ outputs = self.model(
231
+ input_ids=x["input_ids"],
232
+ attention_mask=mask,
233
+ position_ids=x["position_ids"],
234
+ token_type_ids=x["token_type_ids"],
235
+ output_hidden_states=True,
236
+ )
237
+
238
+ # outputs has 13 layers, 1 input layer and 12 hidden layers
239
+ encoded_layers = outputs.hidden_states[1:]
240
+ features = torch.stack(
241
+ encoded_layers[-self.num_layers_of_embedded :], 1
242
+ ).mean(1)
243
+ # language embedding has shape [len(phrase), seq_len, language_dim]
244
+ features = features / self.num_layers_of_embedded
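+ # NOTE: this extra division after the mean additionally scales the
+ # embedding by 1 / num_layers_of_embedded; it is kept as in the
+ # mmdetection implementation this file ports.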
245
+ if mask.dim() == 2:
246
+ embedded = features * mask.unsqueeze(-1).float()
247
+ else:
248
+ embedded = features
249
+
250
+ results = {
251
+ "embedded": embedded,
252
+ "masks": mask,
253
+ "hidden": encoded_layers[-1],
254
+ }
255
+ return results
opendet3d/op/__init__.py ADDED
File without changes
opendet3d/op/base/__init__.py ADDED
File without changes
opendet3d/op/base/swin.py ADDED
@@ -0,0 +1,870 @@