Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
9b33fca
1
Parent(s):
41b3aa4
feat: Try to build everything locally.
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +81 -0
- README.md +1 -0
- opendet3d/__init__.py +1 -0
- opendet3d/common/__init__.py +0 -0
- opendet3d/common/parallel.py +76 -0
- opendet3d/data/__init__.py +0 -0
- opendet3d/data/datasets/__init__.py +0 -0
- opendet3d/data/datasets/argoverse.py +94 -0
- opendet3d/data/datasets/coco3d.py +518 -0
- opendet3d/data/datasets/odvg.py +280 -0
- opendet3d/data/datasets/omni3d/__init__.py +1 -0
- opendet3d/data/datasets/omni3d/arkitscenes.py +81 -0
- opendet3d/data/datasets/omni3d/hypersim.py +190 -0
- opendet3d/data/datasets/omni3d/kitti_object.py +105 -0
- opendet3d/data/datasets/omni3d/nuscenes.py +62 -0
- opendet3d/data/datasets/omni3d/objectron.py +56 -0
- opendet3d/data/datasets/omni3d/omni3d_classes.py +156 -0
- opendet3d/data/datasets/omni3d/sunrgbd.py +278 -0
- opendet3d/data/datasets/omni3d/util.py +74 -0
- opendet3d/data/datasets/scannet.py +449 -0
- opendet3d/data/transforms/__init__.py +0 -0
- opendet3d/data/transforms/crop.py +43 -0
- opendet3d/data/transforms/language.py +267 -0
- opendet3d/data/transforms/pad.py +176 -0
- opendet3d/data/transforms/resize.py +121 -0
- opendet3d/eval/__init__.py +0 -0
- opendet3d/eval/detect3d.py +1249 -0
- opendet3d/eval/omni3d.py +285 -0
- opendet3d/eval/open.py +140 -0
- opendet3d/model/__init__.py +0 -0
- opendet3d/model/detect/__init__.py +0 -0
- opendet3d/model/detect/grounding_dino.py +1050 -0
- opendet3d/model/detect3d/__init__.py +0 -0
- opendet3d/model/detect3d/grounding_dino_3d.py +812 -0
- opendet3d/model/language/__init__.py +0 -0
- opendet3d/model/language/mm_bert.py +255 -0
- opendet3d/op/__init__.py +0 -0
- opendet3d/op/base/__init__.py +0 -0
- opendet3d/op/base/swin.py +870 -0
- opendet3d/op/box/__init__.py +0 -0
- opendet3d/op/box/box2d.py +272 -0
- opendet3d/op/box/box3d.py +79 -0
- opendet3d/op/box/iou_box3d.py +174 -0
- opendet3d/op/box/matchers/__init__.py +0 -0
- opendet3d/op/box/matchers/hungarian.py +117 -0
- opendet3d/op/detect/__init__.py +0 -0
- opendet3d/op/detect/deformable_detr.py +463 -0
- opendet3d/op/detect/detr.py +358 -0
- opendet3d/op/detect/dino.py +667 -0
- opendet3d/op/detect/grounding_dino/__init__.py +17 -0
.gitignore
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# no IntelliJ files
|
| 2 |
+
.idea
|
| 3 |
+
|
| 4 |
+
# don't upload macOS folder info
|
| 5 |
+
*.DS_Store
|
| 6 |
+
|
| 7 |
+
# don't upload node_modules from npm test
|
| 8 |
+
node_modules/*
|
| 9 |
+
flow-typed/*
|
| 10 |
+
|
| 11 |
+
# potential files generated by golang
|
| 12 |
+
bin/
|
| 13 |
+
|
| 14 |
+
# don't upload webpack bundle file
|
| 15 |
+
app/dist/
|
| 16 |
+
|
| 17 |
+
# potential integration testing data directory
|
| 18 |
+
# test_data/
|
| 19 |
+
/data
|
| 20 |
+
|
| 21 |
+
#python
|
| 22 |
+
*.pyc
|
| 23 |
+
__pycache__/
|
| 24 |
+
|
| 25 |
+
# pytype
|
| 26 |
+
.pytype
|
| 27 |
+
|
| 28 |
+
# vscode sftp settings
|
| 29 |
+
.vscode/sftp.json
|
| 30 |
+
|
| 31 |
+
# vscode launch settings
|
| 32 |
+
.vscode/launch.json
|
| 33 |
+
|
| 34 |
+
# redis
|
| 35 |
+
*.rdb
|
| 36 |
+
|
| 37 |
+
# mypy
|
| 38 |
+
.mypy_cache
|
| 39 |
+
|
| 40 |
+
# jest coverage cache
|
| 41 |
+
coverage/
|
| 42 |
+
|
| 43 |
+
# downloaded repos and models
|
| 44 |
+
scalabel/bot/experimental/*
|
| 45 |
+
|
| 46 |
+
# python virtual environment
|
| 47 |
+
env/
|
| 48 |
+
|
| 49 |
+
# vscode workspace configuration
|
| 50 |
+
*.code-workspace
|
| 51 |
+
|
| 52 |
+
# sphinx build folder
|
| 53 |
+
_build/
|
| 54 |
+
|
| 55 |
+
# media files are not in this repo
|
| 56 |
+
doc/media
|
| 57 |
+
|
| 58 |
+
# ignore rope db cache
|
| 59 |
+
.vscode/.ropeproject
|
| 60 |
+
|
| 61 |
+
# python build
|
| 62 |
+
build/
|
| 63 |
+
dist/
|
| 64 |
+
|
| 65 |
+
# coverage
|
| 66 |
+
.coverage*
|
| 67 |
+
|
| 68 |
+
# package default workspace
|
| 69 |
+
vis4d-workspace
|
| 70 |
+
|
| 71 |
+
*.tmp
|
| 72 |
+
|
| 73 |
+
# local test logs and scripts
|
| 74 |
+
log/
|
| 75 |
+
/*.sh
|
| 76 |
+
docs/source/api/*
|
| 77 |
+
docs/source/tutorials/.ipynb_checkpoints/*
|
| 78 |
+
wandb/
|
| 79 |
+
|
| 80 |
+
# No lightning logs
|
| 81 |
+
lightning_logs/
|
README.md
CHANGED
|
@@ -6,6 +6,7 @@ colorTo: yellow
|
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.44.0
|
| 8 |
app_file: app.py
|
|
|
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
|
|
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.44.0
|
| 8 |
app_file: app.py
|
| 9 |
+
app_build_command: cd vis4d_cuda_ops && pip install -v .
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
opendet3d/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""3D-MOOD."""
|
opendet3d/common/__init__.py
ADDED
|
File without changes
|
opendet3d/common/parallel.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for cpu parallelization."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from multiprocessing import Process, Queue
|
| 7 |
+
|
| 8 |
+
# Disabling unused import because we need Tuple in typing
|
| 9 |
+
from typing import Callable, Iterable, List, Optional, Tuple, TypeVar
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
Inputs = TypeVar("Inputs")
|
| 14 |
+
Return = TypeVar("Return")
|
| 15 |
+
|
| 16 |
+
cpu_num = os.cpu_count()
|
| 17 |
+
NPROC: int = min(4, cpu_num if cpu_num else 1)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run(
    func: Callable[[Inputs], Return],
    q_in: "Queue[Tuple[int, Optional[Tuple[Inputs]]]]",
    q_out: "Queue[Tuple[int, Return]]",
) -> None:
    """Worker loop: apply ``func`` to queued inputs until a sentinel arrives.

    Each queue item is ``(index, (input,))``. A negative index or a ``None``
    payload acts as the shutdown sentinel. Results are emitted to ``q_out``
    as ``(index, func(input))`` so the consumer can restore ordering.
    """
    while True:
        index, payload = q_in.get()
        # Negative index (or missing payload) signals shutdown.
        if index < 0 or payload is None:
            return
        q_out.put((index, func(payload[0])))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def pmap(
    func: Callable[[Inputs], Return],
    inputs: Iterable[Inputs],
    max_len: int,
    nprocs: int = NPROC,
) -> List[Return]:
    """Parallel mapping of ``func`` over ``inputs`` using worker processes.

    Different from the python pool map, this function will not hang if any of
    the processes throws an exception.

    Args:
        func: Function applied to each input item.
        max_len: Expected number of items; used as the progress-bar total.
        nprocs: Number of worker processes to spawn.

    Returns:
        Results in the same order as ``inputs``.
    """
    q_in: "Queue[Tuple[int, Optional[Tuple[Inputs]]]]" = Queue(1)
    q_out: "Queue[Tuple[int, Return]]" = Queue()

    proc = [
        Process(target=run, args=(func, q_in, q_out)) for _ in range(nprocs)
    ]
    for p in proc:
        p.daemon = True
        p.start()

    count = 0
    with tqdm(total=max_len) as pbar:
        for i, x in enumerate(inputs):
            q_in.put((i, (x,)))
            count += 1
            # Fix: advance the bar by one per dispatched item. The previous
            # code called pbar.update() only once every nprocs items, so the
            # bar could never get past max_len / nprocs.
            pbar.update()

        pbar.refresh()

        # One sentinel per worker to terminate the run() loops.
        for _ in range(nprocs):
            q_in.put((-1, None))

        # Drain exactly as many results as were dispatched.
        res = [q_out.get() for _ in range(count)]

        for p in proc:
            p.join()

    # Results arrive out of order; sort by the dispatch index.
    return [x for _, x in sorted(res)]
|
opendet3d/data/__init__.py
ADDED
|
File without changes
|
opendet3d/data/datasets/__init__.py
ADDED
|
File without changes
|
opendet3d/data/datasets/argoverse.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Argoverse V2 Sensor dataset."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 6 |
+
|
| 7 |
+
from .coco3d import COCO3DDataset
|
| 8 |
+
|
| 9 |
+
# Sampling/accumulation constants for AV2 — presumably frame subsampling
# rates for train/val and the number of accumulated frames; TODO confirm
# against the code that consumes them (not visible here).
TRAIN_SAMPLE_RATE = 10
VAL_SAMPLE_RATE = 5
ACC_FRAMES = 5


# Mapping from AV2 category name to the class id used in the annotations.
av2_class_map = {
    "regular vehicle": 0,
    "pedestrian": 1,
    "bicyclist": 2,
    "motorcyclist": 3,
    "wheeled rider": 4,
    "bollard": 5,
    "construction cone": 6,
    "sign": 7,
    "construction barrel": 8,
    "stop sign": 9,
    "mobile pedestrian crossing sign": 10,
    "large vehicle": 11,
    "bus": 12,
    "box truck": 13,
    "truck": 14,
    "vehicular trailer": 15,
    "truck cab": 16,
    "school bus": 17,
    "articulated bus": 18,
    "message board trailer": 19,
    "bicycle": 20,
    "motorcycle": 21,
    "wheeled device": 22,
    "wheelchair": 23,
    "stroller": 24,
    "dog": 25,
}

# Category mapping used for detection. Compared to av2_class_map it drops
# "message board trailer", "wheelchair" and "dog" and renumbers the
# remaining classes contiguously.
av2_det_map = {
    "regular vehicle": 0,
    "pedestrian": 1,
    "bicyclist": 2,
    "motorcyclist": 3,
    "wheeled rider": 4,
    "bollard": 5,
    "construction cone": 6,
    "sign": 7,
    "construction barrel": 8,
    "stop sign": 9,
    "mobile pedestrian crossing sign": 10,
    "large vehicle": 11,
    "bus": 12,
    "box truck": 13,
    "truck": 14,
    "vehicular trailer": 15,
    "truck cab": 16,
    "school bus": 17,
    "articulated bus": 18,
    "bicycle": 19,
    "motorcycle": 20,
    "wheeled device": 21,
    "stroller": 22,
}
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class AV2SensorDataset(COCO3DDataset):
    """Argoverse V2 Sensor dataset.

    Thin specialization of :class:`COCO3DDataset` that supplies AV2-specific
    defaults (class map, depth range, depth scale) and derives depth-map
    file names from image file names.
    """

    def __init__(
        self,
        class_map: dict[str, int] = av2_class_map,
        max_depth: float = 80.0,
        depth_scale: float = 256.0,
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        AV2 stores depth maps next to images: ``images/.../x.jpg`` maps to
        ``depth/.../x_depth.png``.
        """
        depth_path = img["file_path"].replace("images", "depth")
        return depth_path.replace(".jpg", "_depth.png")
|
opendet3d/data/datasets/coco3d.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""COCO 3D API."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import contextlib
|
| 6 |
+
import io
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from collections.abc import Sequence
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from pycocotools.coco import COCO
|
| 15 |
+
from pyquaternion import Quaternion
|
| 16 |
+
from scipy.spatial.transform import Rotation as R
|
| 17 |
+
from vis4d.common.logging import rank_zero_info, rank_zero_warn
|
| 18 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 19 |
+
from vis4d.data.const import AxisMode
|
| 20 |
+
from vis4d.data.const import CommonKeys as K
|
| 21 |
+
from vis4d.data.datasets.base import Dataset
|
| 22 |
+
from vis4d.data.datasets.util import (
|
| 23 |
+
CacheMappingMixin,
|
| 24 |
+
im_decode,
|
| 25 |
+
print_class_histogram,
|
| 26 |
+
)
|
| 27 |
+
from vis4d.data.typing import DictData
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class COCO3DDataset(CacheMappingMixin, Dataset):
    """3D Object Detection Dataset using coco annotation files.

    Loads a COCO-style annotation file (via :class:`COCO3D`), builds a
    per-image sample mapping with 2D boxes, 3D boxes and class ids, and
    optionally loads per-image metric depth maps.
    """

    def __init__(
        self,
        data_root: str,
        dataset_name: str,
        class_map: dict[str, int],
        det_map: dict[str, int],
        keys_to_load: Sequence[str] = (K.images, K.boxes2d, K.boxes3d),
        with_depth: bool = False,
        max_depth: float = 80.0,
        depth_scale: float = 256.0,
        remove_empty: bool = False,
        data_prefix: str | None = None,
        text_prompt_mapping: dict[str, dict[str, str]] | None = None,
        cache_as_binary: bool = False,
        cached_file_path: str | None = None,
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        Args:
            data_root: Root folder containing an ``annotations`` directory.
            dataset_name: Name of the dataset; ``<dataset_name>.json`` is the
                annotation file.
            class_map: Class name -> id mapping as used in the annotations.
            det_map: Class name -> id mapping used for detection targets.
            keys_to_load: Data keys populated in ``__getitem__``.
            with_depth: If True, ensure depth maps are loaded.
            max_depth: Depth values above this are zeroed out.
            depth_scale: Divisor applied to raw depth pixel values.
            remove_empty: Drop samples without annotations.
            data_prefix: Optional prefix prepended to image file paths.
            text_prompt_mapping: Optional per-class text prompt overrides.
            cache_as_binary: Cache the data mapping as a binary file.
            cached_file_path: Location of the binary cache.
        """
        super().__init__(**kwargs)
        self.data_root = data_root
        self.dataset_name = dataset_name
        self.annotation_file = f"{dataset_name}.json"

        self.keys_to_load = list(keys_to_load)
        self.remove_empty = remove_empty

        self.class_map = class_map  # Class mapping in the annotation file
        self.det_map = det_map  # Class mapping for detection
        self.categories = sorted(self.det_map, key=self.det_map.get)

        self.data_prefix = data_prefix
        self.text_prompt_mapping = text_prompt_mapping

        # Metric Depth
        if with_depth and K.depth_maps not in keys_to_load:
            self.keys_to_load.append(K.depth_maps)

        self.max_depth = max_depth
        self.depth_scale = depth_scale

        # Load annotations (optionally from / into the binary cache).
        self.samples, _ = self._load_mapping(
            self._generate_data_mapping,
            self._filter_data,
            cache_as_binary=cache_as_binary,
            cached_file_path=cached_file_path,
        )

    def __repr__(self) -> str:
        """Concise representation of the dataset."""
        return self.dataset_name

    def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
        """Remove empty samples and report per-class statistics."""
        samples = []

        frequencies = {cat: 0 for cat in sorted(self.det_map)}

        # NOTE: empty_samples also counts samples dropped for missing depth,
        # which are additionally tracked in no_depth_samples.
        empty_samples = 0
        no_depth_samples = 0
        for sample in data:
            if self.remove_empty and len(sample["anns"]) == 0:
                empty_samples += 1
                continue

            if (
                K.depth_maps in self.keys_to_load
                and "depth_filename" not in sample
            ):
                empty_samples += 1
                no_depth_samples += 1
                continue

            for ann in sample["anns"]:
                frequencies[ann["category_name"]] += 1

            samples.append(sample)

        # Fix: original message read "Propocessing".
        rank_zero_info(
            f"Preprocessing {self.dataset_name} with {len(samples)} samples."
        )
        rank_zero_info(f"No depth samples: {no_depth_samples}")
        rank_zero_info(f"Filtered {empty_samples} empty samples")
        print_class_histogram(frequencies)

        return samples

    def _get_cat_id(
        self, img: DictStrAny, ann: DictStrAny, cat_name: str
    ) -> None:
        """Set the detection category id on the annotation in place."""
        ann["category_id"] = self.det_map[cat_name]

    def _generate_data_mapping(self) -> list[DictStrAny]:
        """Generates the data mapping.

        Returns:
            One dict per image with keys ``img_id``, ``img``, ``anns``,
            ``boxes2d``, ``boxes3d``, ``class_ids`` and, when available,
            ``depth_filename``.
        """
        # Load annotations; COCO prints to stdout, silence it.
        with contextlib.redirect_stdout(io.StringIO()):
            coco_api = COCO3D(
                os.path.join(
                    self.data_root, "annotations", self.annotation_file
                ),
                self.categories,
            )

        cats_map = {v: k for k, v in self.class_map.items()}

        img_ids = sorted(coco_api.getImgIds())
        imgs = coco_api.loadImgs(img_ids)

        samples = []
        for img_id, img in zip(img_ids, imgs):
            # Fix file path for Omni3D
            if self.data_prefix is not None:
                img["file_path"] = os.path.join(
                    self.data_prefix, img["file_path"]
                )

            valid_anns = []
            anns = coco_api.imgToAnns[img_id]

            boxes = []
            boxes3d = np.empty((0, 10), dtype=np.float32)
            class_ids = np.empty((0,), dtype=np.int64)
            for ann in anns:
                cat_name = cats_map[ann["category_id"]]
                assert cat_name == ann["category_name"]

                if cat_name in {"dontcare", "ignore", "void"}:
                    continue

                if ann["ignore"]:
                    continue

                self._get_cat_id(img, ann, cat_name)

                # Box 2D: convert xywh -> xyxy.
                x1, y1, width, height = ann["bbox"]
                x2, y2 = x1 + width, y1 + height
                boxes.append((x1, y1, x2, y2))

                # Class
                class_ids = np.concatenate(
                    [
                        class_ids,
                        np.array([ann["category_id"]], dtype=np.int64),
                    ]
                )

                # Box 3D
                center = ann["center_cam"]
                width, height, length = ann["dimensions"]

                # Check if the rotation matrix is valid
                try:
                    x, y, z, w = R.from_matrix(
                        np.array(ann["R_cam"])
                    ).as_quat()
                except Exception as e:  # pylint: disable=broad-except
                    rank_zero_warn(
                        f"Error processing rotation matrix for annotation {ann['id']}: {e}"
                    )
                    continue

                # scipy returns (x, y, z, w); pyquaternion expects (w, x, y, z).
                orientation = Quaternion([w, x, y, z])

                boxes3d = np.concatenate(
                    [
                        boxes3d,
                        np.array(
                            [
                                [
                                    *center,
                                    width,
                                    length,
                                    height,
                                    *orientation.elements,
                                ]
                            ],
                            dtype=np.float32,
                        ),
                    ]
                )

                valid_anns.append(ann)

            boxes2d = (
                np.empty((0, 4), dtype=np.float32)
                if not boxes
                else np.array(boxes, dtype=np.float32)
            )

            depth_filename = self.get_depth_filenames(img)

            sample = {
                "img_id": img_id,
                "img": img,
                "anns": valid_anns,
                "boxes2d": boxes2d,
                "boxes3d": boxes3d,
                "class_ids": class_ids,
            }

            if depth_filename is not None and self.data_backend.exists(
                depth_filename
            ):
                sample["depth_filename"] = depth_filename

            samples.append(sample)

        return samples

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth. Subclasses override this to derive
        a depth path from the image entry.
        """
        return None

    def get_cat_ids(self, idx: int) -> list[int]:
        """Return the category ids of the sample at ``idx``."""
        return self.samples[idx]["class_ids"].tolist()

    def __len__(self) -> int:
        """Total number of samples of data."""
        return len(self.samples)

    def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
        """Load and rescale the depth map for a sample."""
        depth_bytes = self.data_backend.get(sample["depth_filename"])
        depth_array = im_decode(depth_bytes)

        depth = np.ascontiguousarray(depth_array, dtype=np.float32)

        # Raw pixel values encode depth scaled by depth_scale.
        depth = depth / self.depth_scale

        return depth

    def __getitem__(self, idx: int) -> DictData:
        """Get single sample.

        Args:
            idx (int): Index of sample.

        Returns:
            DictData: sample at index in Vis4D input format.
        """
        sample = self.samples[idx]
        data_dict: DictData = {}

        # Get image info
        data_dict[K.sample_names] = sample["img_id"]

        data_dict["dataset_name"] = self.dataset_name
        data_dict[K.boxes2d_names] = self.categories
        data_dict["text_prompt_mapping"] = self.text_prompt_mapping

        if K.images in self.keys_to_load:
            im_bytes = self.data_backend.get(sample["img"]["file_path"])
            # Add a leading batch dimension of 1.
            image = np.ascontiguousarray(
                im_decode(im_bytes, mode=self.image_channel_mode),
                dtype=np.float32,
            )[None]

            data_dict[K.images] = image
            data_dict[K.input_hw] = (image.shape[1], image.shape[2])

            data_dict[K.original_images] = image
            data_dict[K.original_hw] = (image.shape[1], image.shape[2])

        # Get camera info
        intrinsics = np.array(sample["img"]["K"], dtype=np.float32)
        data_dict[K.intrinsics] = intrinsics
        data_dict["original_intrinsics"] = intrinsics

        data_dict[K.boxes2d] = sample["boxes2d"]
        data_dict[K.boxes2d_classes] = sample["class_ids"]
        data_dict[K.boxes3d] = sample["boxes3d"]
        data_dict[K.boxes3d_classes] = sample["class_ids"]
        data_dict[K.axis_mode] = AxisMode.OPENCV

        if K.depth_maps in self.keys_to_load:
            depth = self.get_depth_map(sample)

            # Invalidate (zero out) depths beyond the supported range.
            depth[depth > self.max_depth] = 0

            data_dict[K.depth_maps] = depth

        data_dict["tokens_positive"] = None

        self.data_backend.close()

        return data_dict
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class COCO3D(COCO):
    """COCO API with 3D annotations.

    Extends the pycocotools :class:`COCO` API to merge several annotation
    files, normalize their category lists, derive 2D boxes from the best
    available 3D-aware source, and mark annotations to ignore based on
    truncation / visibility / size / depth heuristics.
    """

    def __init__(
        self,
        annotation_files: Sequence[str] | str,
        category_names: Sequence[str] | None = None,
        ignore_names: Sequence[str] = ("dontcare", "ignore", "void"),
        truncation_thres: float = 0.33333333,
        visibility_thres: float = 0.33333333,
        min_height_thres: float = 0.0625,
        max_height_thres: float = 1.50,
        modal_2D_boxes: bool = False,
        trunc_2D_boxes: bool = True,
        max_depth: float = 1e8,  # fix: was annotated int with a float default
    ) -> None:
        """Creates an instance of the class.

        Args:
            annotation_files: One or more COCO-style annotation files; when
                several are given they are concatenated.
            category_names: Keep only these categories; None keeps all.
            ignore_names: Category names treated as ignore regions.
            truncation_thres: Annotations truncated at least this much are
                ignored.
            visibility_thres: Annotations at most this visible are ignored.
            min_height_thres: Minimum 2D box height as a fraction of image
                height.
            max_height_thres: Maximum 2D box height as a fraction of image
                height.
            modal_2D_boxes: Prefer tightly annotated 2D boxes when present.
            trunc_2D_boxes: Prefer truncation-corrected 2D boxes when present.
            max_depth: Annotations with a camera depth beyond this are
                ignored.
        """
        self.dataset, self.anns, self.cats, self.imgs = {}, {}, {}, {}
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)

        self.truncation_thres = truncation_thres
        self.visibility_thres = visibility_thres
        self.min_height_thres = min_height_thres
        self.max_height_thres = max_height_thres
        self.max_depth = max_depth

        if isinstance(annotation_files, str):
            annotation_files = [annotation_files]

        cats_ids_master = []
        cats_master = []

        for annotation_file in annotation_files:
            _, tail = os.path.split(annotation_file)
            name, _ = os.path.splitext(tail)

            print(f"loading {name} annotations into memory...")
            tic = time.time()

            with open(annotation_file, "r") as f:
                dataset = json.load(f)

            # Fix: isinstance instead of type(...) == dict.
            assert isinstance(
                dataset, dict
            ), f"annotation file format {type(dataset)} not supported"
            print(f"Done (t={time.time() - tic:.2f}s)")

            if isinstance(dataset["info"], list):
                dataset["info"] = dataset["info"][0]

            dataset["info"]["known_category_ids"] = [
                cat["id"] for cat in dataset["categories"]
            ]

            # first dataset
            if len(self.dataset) == 0:
                self.dataset = dataset
            # concatenate datasets
            else:
                if isinstance(self.dataset["info"], dict):
                    self.dataset["info"] = [self.dataset["info"]]

                self.dataset["info"] += [dataset["info"]]
                self.dataset["annotations"] += dataset["annotations"]
                self.dataset["images"] += dataset["images"]

            # sort through categories
            for cat in dataset["categories"]:
                if cat["id"] not in cats_ids_master:
                    cats_ids_master.append(cat["id"])
                    cats_master.append(cat)

        # category names are provided to us
        if category_names is not None:
            self.dataset["categories"] = [
                cats_master[i]
                for i in np.argsort(cats_ids_master)
                if cats_master[i]["name"] in category_names
            ]
        # no categories are provided, so assume use ALL available.
        else:
            self.dataset["categories"] = [
                cats_master[i] for i in np.argsort(cats_ids_master)
            ]

            category_names = [
                cat["name"] for cat in self.dataset["categories"]
            ]

        # determine which categories we may actually use for filtering.
        trainable_cats = set(ignore_names) | set(category_names)

        valid_anns = []
        im_height_map = {}

        for im_obj in self.dataset["images"]:
            im_height_map[im_obj["id"]] = im_obj["height"]

        # Filter out annotations
        for anno_idx, anno in enumerate(self.dataset["annotations"]):

            im_height = im_height_map[anno["image_id"]]

            # tightly annotated 2D boxes are not always available.
            if (
                modal_2D_boxes
                and "bbox2D_tight" in anno
                and anno["bbox2D_tight"][0] != -1
            ):
                bbox2D = anno["bbox2D_tight"]
            elif (
                trunc_2D_boxes
                and "bbox2D_trunc" in anno
                and not np.all([val == -1 for val in anno["bbox2D_trunc"]])
            ):
                bbox2D = anno["bbox2D_trunc"]
            elif anno["bbox2D_proj"][0] != -1:
                bbox2D = anno["bbox2D_proj"]
            elif anno["bbox2D_tight"][0] != -1:
                bbox2D = anno["bbox2D_tight"]
            else:
                continue

            # convert to xywh
            bbox2D[2] = bbox2D[2] - bbox2D[0]
            bbox2D[3] = bbox2D[3] - bbox2D[1]

            ignore = self.is_ignore(anno, bbox2D, ignore_names, im_height)

            width = bbox2D[2]
            height = bbox2D[3]

            self.dataset["annotations"][anno_idx]["area"] = width * height
            self.dataset["annotations"][anno_idx]["iscrowd"] = False
            self.dataset["annotations"][anno_idx]["ignore"] = ignore
            self.dataset["annotations"][anno_idx]["ignore2D"] = ignore
            self.dataset["annotations"][anno_idx]["ignore3D"] = ignore

            self.dataset["annotations"][anno_idx]["bbox"] = bbox2D
            self.dataset["annotations"][anno_idx]["bbox3D"] = anno[
                "bbox3D_cam"
            ]
            self.dataset["annotations"][anno_idx]["depth"] = anno[
                "center_cam"
            ][2]

            category_name = anno["category_name"]

            if category_name in trainable_cats:
                valid_anns.append(self.dataset["annotations"][anno_idx])

        self.dataset["annotations"] = valid_anns

        self.createIndex()

    def is_ignore(
        self,
        anno: DictStrAny,
        bbox2D: list[float],  # fix: list[float, float, float, float] is invalid
        ignore_names: Sequence[str] | None,
        image_height: int,
    ) -> bool:
        """Decide whether an annotation should be ignored.

        Combines geometric validity, depth/point-count sanity checks,
        relative box height, truncation/visibility thresholds and the
        ignore-category list. ``bbox2D`` is in xywh format.
        """
        ignore = anno["behind_camera"]
        ignore |= not bool(anno["valid3D"])

        # Invalid 3D boxes are ignored unconditionally; skip further checks.
        if ignore:
            return ignore

        ignore |= anno["dimensions"][0] <= 0
        ignore |= anno["dimensions"][1] <= 0
        ignore |= anno["dimensions"][2] <= 0
        ignore |= anno["center_cam"][2] > self.max_depth
        ignore |= anno["lidar_pts"] == 0
        ignore |= anno["segmentation_pts"] == 0
        ignore |= anno["depth_error"] > 0.5

        # Box height relative to the image must lie within the thresholds.
        ignore |= bbox2D[3] <= self.min_height_thres * image_height
        ignore |= bbox2D[3] >= self.max_height_thres * image_height

        # Negative truncation/visibility mark "unknown" and are not filtered.
        ignore |= (
            anno["truncation"] >= 0
            and anno["truncation"] >= self.truncation_thres
        )
        ignore |= (
            anno["visibility"] >= 0
            and anno["visibility"] <= self.visibility_thres
        )

        if ignore_names is not None:
            ignore |= anno["category_name"] in ignore_names

        return ignore
|
opendet3d/data/datasets/odvg.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Object detection and visual grounding dataset."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os.path as osp
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from vis4d.common.logging import rank_zero_info
|
| 11 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 12 |
+
from vis4d.data.const import CommonKeys as K
|
| 13 |
+
from vis4d.data.datasets.base import Dataset
|
| 14 |
+
from vis4d.data.datasets.util import (
|
| 15 |
+
CacheMappingMixin,
|
| 16 |
+
im_decode,
|
| 17 |
+
print_class_histogram,
|
| 18 |
+
)
|
| 19 |
+
from vis4d.data.typing import DictData
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ODVGDataset(CacheMappingMixin, Dataset):
    """Object detection (OD) and visual grounding (VG) dataset.

    Annotations are stored in ODVG format: one JSON object per line. If
    ``label_map_file`` is provided the dataset is treated as an object
    detection ("OD") dataset, otherwise as a visual grounding ("VG") one.
    """

    def __init__(
        self,
        data_root: str,
        ann_file: str,
        label_map_file: str | None = None,
        dataset_type: str = "VG",
        dataset_prefix: str | None = None,
        remove_empty: bool = False,
        cache_as_binary: bool = False,
        cached_file_path: str | None = None,
        with_camera: bool = False,
        **kwargs: ArgsType,
    ) -> None:
        """Create an object detection and visual grounding dataset.

        Args:
            data_root: Root directory of the dataset.
            ann_file: Annotation file (JSON lines) relative to ``data_root``.
            label_map_file: Optional label map file (class id -> class name)
                relative to ``data_root``. If given, the dataset type is
                forced to "OD".
            dataset_type: "OD" or "VG". Note this value is overridden based
                on whether ``label_map_file`` is given; kept for backward
                compatibility.
            dataset_prefix: Optional sub-directory containing the images.
            dataset_prefix: Optional sub-directory containing the images.
            remove_empty: Whether to drop samples without any annotation.
            cache_as_binary: Whether to cache the data mapping as binary.
            cached_file_path: Path of the cached data mapping file.
            with_camera: Whether to load pseudo camera intrinsics from
                ``cam_info.json`` inside ``data_root``. Defaults to False.
        """
        super().__init__(**kwargs)

        self.data_root = data_root
        self.ann_file = ann_file
        self.dataset_prefix = dataset_prefix
        self.remove_empty = remove_empty
        # Fix: this attribute is read in _generate_data_mapping but was
        # previously never initialized, raising AttributeError at runtime.
        self.with_camera = with_camera

        if label_map_file is not None:
            label_map_file = osp.join(self.data_root, label_map_file)

            with open(label_map_file, "r", encoding="utf-8") as file:
                # dict[class_id (str): class_name (str)]
                self.label_map = json.load(file)

            self.dataset_type = "OD"

            # Inverse map (class name -> class id) and category names
            # sorted by class id.
            self.det_map = {v: int(k) for k, v in self.label_map.items()}
            self.categories = sorted(self.det_map, key=self.det_map.get)
        else:
            self.label_map = None
            self.dataset_type = "VG"

        # Load annotations.
        self.samples, _ = self._load_mapping(
            self._generate_data_mapping,
            self._filter_data,
            cache_as_binary=cache_as_binary,
            cached_file_path=cached_file_path,
        )

    def __repr__(self) -> str:
        """Concise representation of the dataset."""
        return f"ODVGDataset({self.ann_file})"

    def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]:
        """Remove empty samples and log a class histogram (OD only)."""
        samples = []

        if self.dataset_type == "OD":
            frequencies = {cat: 0 for _, cat in self.label_map.items()}

        empty_samples = 0
        for sample in data:
            if self.remove_empty and len(sample["anns"]) == 0:
                empty_samples += 1
                continue

            if self.dataset_type == "OD":
                for ann in sample["anns"]:
                    frequencies[ann["category"]] += 1

            samples.append(sample)

        # Fix: "Propocessing" typo in the log message.
        rank_zero_info(f"Preprocessing {self} with {len(samples)} samples.")
        rank_zero_info(f"Filtered {empty_samples} empty samples")

        if self.dataset_type == "OD":
            frequencies = dict(sorted(frequencies.items()))

            print_class_histogram(frequencies)

        return samples

    def _generate_data_mapping(self) -> list[DictStrAny]:
        """Generates the data mapping.

        Returns:
            list[DictStrAny]: One entry per image with paths, image size,
                2D boxes, class ids and the raw valid annotations.
        """
        with open(
            osp.join(self.data_root, self.ann_file), "r", encoding="utf-8"
        ) as f:
            data_list = [json.loads(line) for line in f]

        if self.with_camera:
            with open(
                osp.join(self.data_root, "cam_info.json"), "r", encoding="utf-8"
            ) as f:
                cameras = json.load(f)

        samples = []
        for data in tqdm(data_list):
            data_info = {}

            if self.dataset_prefix is not None:
                img_path = osp.join(
                    self.data_root, self.dataset_prefix, data["filename"]
                )
            else:
                img_path = osp.join(self.data_root, data["filename"])

            data_info["img_path"] = img_path

            # Pseudo K (camera intrinsics keyed by image path).
            if self.with_camera:
                data_info["K"] = cameras[img_path][0]

            # Pseudo depth path.
            if self.dataset_prefix is not None:
                depth_path = osp.join(
                    self.data_root,
                    f"{self.dataset_prefix}_depth",
                    data["filename"].replace(".jpg", "_depth.png"),
                )
            else:
                depth_path = osp.join(
                    self.data_root,
                    data["filename"].replace(".jpg", "_depth.png"),
                )
            data_info["depth_path"] = depth_path

            data_info["height"] = data["height"]
            data_info["width"] = data["width"]

            valid_anns = []
            boxes = []
            # Collect class ids in a plain list and convert once at the
            # end, instead of repeated np.concatenate (quadratic).
            class_id_list: list[int] = []
            if self.dataset_type == "OD":
                instances = data.get("detection", {}).get("instances", [])

                for ann in instances:
                    bbox = ann["bbox"]

                    # Box 2D: skip boxes that do not overlap the image or
                    # are smaller than one pixel.
                    x1, y1, x2, y2 = bbox
                    inter_w = max(0, min(x2, data["width"]) - max(x1, 0))
                    inter_h = max(0, min(y2, data["height"]) - max(y1, 0))

                    if inter_w * inter_h == 0:
                        continue
                    if (x2 - x1) < 1 or (y2 - y1) < 1:
                        continue

                    boxes.append(bbox)

                    # Class
                    class_id_list.append(ann["label"])

                    valid_anns.append(ann)
            else:
                anno = data["grounding"]

                caption = anno["caption"].lower().strip()
                if not caption.endswith("."):
                    caption = caption + ". "

                data_info["caption"] = caption

                regions = anno["regions"]
                phrases = []
                positive_positions = []
                for i, region in enumerate(regions):
                    bboxes = region["bbox"]

                    # A region may hold one box or a list of boxes.
                    if not isinstance(bboxes[0], list):
                        bboxes = [bboxes]

                    for bbox in bboxes:
                        x1, y1, x2, y2 = bbox
                        inter_w = max(0, min(x2, data["width"]) - max(x1, 0))
                        inter_h = max(0, min(y2, data["height"]) - max(y1, 0))

                        if inter_w * inter_h == 0:
                            continue
                        if (x2 - x1) < 1 or (y2 - y1) < 1:
                            continue

                        boxes.append(bbox)
                        phrases.append(region["phrase"])
                        positive_positions.append(region["tokens_positive"])
                        valid_anns.append(region)

                        # The class id of a grounding box is its region index.
                        class_id_list.append(i)

                data_info["phrases"] = phrases
                data_info["positive_positions"] = positive_positions

            boxes2d = (
                np.empty((0, 4), dtype=np.float32)
                if not boxes
                else np.array(boxes, dtype=np.float32)
            )

            data_info["boxes2d"] = boxes2d
            data_info["class_ids"] = np.array(class_id_list, dtype=np.int64)
            data_info["anns"] = valid_anns

            samples.append(data_info)

        del data_list
        return samples

    def get_cat_ids(self, idx: int) -> list[int]:
        """Return the class ids of the sample at `idx`."""
        return self.samples[idx]["class_ids"].tolist()

    def __len__(self) -> int:
        """Total number of samples of data."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> DictData:
        """Get single sample.

        Args:
            idx (int): Index of sample.

        Returns:
            DictData: sample at index in Vis4D input format.
        """
        sample = self.samples[idx]
        data_dict: DictData = {}

        # Get image info.
        sample_name = sample["img_path"].split("/")[-1]
        data_dict[K.sample_names] = sample_name

        im_bytes = self.data_backend.get(sample["img_path"])
        image = np.ascontiguousarray(
            im_decode(im_bytes, mode=self.image_channel_mode),
            dtype=np.float32,
        )[None]

        data_dict[K.images] = image
        data_dict[K.input_hw] = (image.shape[1], image.shape[2])

        data_dict[K.original_images] = image
        data_dict[K.original_hw] = (image.shape[1], image.shape[2])

        data_dict[K.boxes2d] = sample["boxes2d"]
        data_dict[K.boxes2d_classes] = sample["class_ids"]

        if self.dataset_type == "OD":
            data_dict[K.boxes2d_names] = self.categories
            data_dict["phrases"] = None
            data_dict["positive_positions"] = None
        else:
            data_dict[K.boxes2d_names] = sample["caption"]
            data_dict["phrases"] = sample["phrases"]
            data_dict["positive_positions"] = sample["positive_positions"]

        data_dict["dataset_type"] = self.dataset_type
        data_dict["label_map"] = self.label_map

        self.data_backend.close()

        return data_dict
|
opendet3d/data/datasets/omni3d/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Omni3D Dataset."""
|
opendet3d/data/datasets/omni3d/arkitscenes.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ARKitScenes from Omni3D."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 8 |
+
|
| 9 |
+
from opendet3d.data.datasets.coco3d import COCO3DDataset
|
| 10 |
+
|
| 11 |
+
from .omni3d_classes import omni3d_class_map
|
| 12 |
+
|
| 13 |
+
# Detection class map of the ARKitScenes annotations
# (category name -> contiguous class id, alphabetical order).
arkitscenes_det_map = {
    "bathtub": 0,
    "bed": 1,
    "cabinet": 2,
    "chair": 3,
    "fireplace": 4,
    "machine": 5,
    "oven": 6,
    "refrigerator": 7,
    "shelves": 8,
    "sink": 9,
    "sofa": 10,
    "stove": 11,
    "table": 12,
    "television": 13,
    "toilet": 14,
}

# ARKitScenes class ids as used in the Omni3D setting: a 14-class subset
# (no "fireplace") with a different id ordering than `arkitscenes_det_map`.
omni3d_arkitscenes_det_map = {
    "table": 0,
    "bed": 1,
    "sofa": 2,
    "television": 3,
    "refrigerator": 4,
    "chair": 5,
    "oven": 6,
    "machine": 7,
    "stove": 8,
    "shelves": 9,
    "sink": 10,
    "cabinet": 11,
    "bathtub": 12,
    "toilet": 13,
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ARKitScenes(COCO3DDataset):
    """ARKitScenes Dataset."""

    def __init__(
        self,
        class_map: dict[str, int] = omni3d_class_map,
        max_depth: float = 10.0,
        depth_scale: float = 1000.0,
        depth_data_root: str = "data/ARKitScenes_depth",
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        Args:
            class_map: Mapping from category name to class id.
            max_depth: Maximum depth value considered valid.
            depth_scale: Scale factor of the stored depth maps.
            depth_data_root: Root directory of the depth maps. Defaults to
                the previously hard-coded "data/ARKitScenes_depth"
                (consistent with the `depth_data_root` parameter of
                KITTIObject).
        """
        # Set before super().__init__ in case the base class resolves depth
        # paths during initialization.
        self.depth_data_root = depth_data_root

        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth.
        """
        # Expects a 5-component image path: <root>/<ds>/<split>/<video>/<name>
        # — TODO confirm layout against the annotation files.
        _, _, split, video_id, image_name = img["file_path"].split("/")

        depth_filename = os.path.join(
            self.depth_data_root,
            split,
            video_id,
            # NOTE: replaces every "jpg" substring, not only the extension.
            image_name.replace("jpg", "png"),
        )

        return depth_filename
|
opendet3d/data/datasets/omni3d/hypersim.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hypersim from Omni3D."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 8 |
+
|
| 9 |
+
from opendet3d.data.datasets.coco3d import COCO3DDataset
|
| 10 |
+
|
| 11 |
+
from .omni3d_classes import omni3d_class_map
|
| 12 |
+
|
| 13 |
+
# Hypersim detection class map for the train split
# (category name -> contiguous class id, alphabetical order).
hypersim_train_det_map = {
    "bathtub": 0,
    "bed": 1,
    "blinds": 2,
    "bookcase": 3,
    "books": 4,
    "box": 5,
    "cabinet": 6,
    "chair": 7,
    "clothes": 8,
    "counter": 9,
    "curtain": 10,
    "desk": 11,
    "door": 12,
    "dresser": 13,
    "floor mat": 14,
    "lamp": 15,
    "mirror": 16,
    "night stand": 17,
    "person": 18,
    "picture": 19,
    "pillow": 20,
    "refrigerator": 21,
    "shelves": 22,
    "sink": 23,
    "sofa": 24,
    "stationery": 25,
    "table": 26,
    "television": 27,
    "toilet": 28,
    "towel": 29,
    "window": 30,
}

# Val split class map: like the train split but without "person".
hypersim_val_det_map = {
    "bathtub": 0,
    "bed": 1,
    "blinds": 2,
    "bookcase": 3,
    "books": 4,
    "box": 5,
    "cabinet": 6,
    "chair": 7,
    "clothes": 8,
    "counter": 9,
    "curtain": 10,
    "desk": 11,
    "door": 12,
    "dresser": 13,
    "floor mat": 14,
    "lamp": 15,
    "mirror": 16,
    "night stand": 17,
    "picture": 18,
    "pillow": 19,
    "refrigerator": 20,
    "shelves": 21,
    "sink": 22,
    "sofa": 23,
    "stationery": 24,
    "table": 25,
    "television": 26,
    "toilet": 27,
    "towel": 28,
    "window": 29,
}

# Test split class map: adds "board" and lacks "dresser", "person" and
# "toilet" compared to the train split.
hypersim_test_det_map = {
    "bathtub": 0,
    "bed": 1,
    "blinds": 2,
    "board": 3,
    "bookcase": 4,
    "books": 5,
    "box": 6,
    "cabinet": 7,
    "chair": 8,
    "clothes": 9,
    "counter": 10,
    "curtain": 11,
    "desk": 12,
    "door": 13,
    "floor mat": 14,
    "lamp": 15,
    "mirror": 16,
    "night stand": 17,
    "picture": 18,
    "pillow": 19,
    "refrigerator": 20,
    "shelves": 21,
    "sink": 22,
    "sofa": 23,
    "stationery": 24,
    "table": 25,
    "television": 26,
    "towel": 27,
    "window": 28,
}


# Hypersim class ids as used in the Omni3D setting (note the non-alphabetical
# id ordering).
omni3d_hypersim_det_map = {
    "books": 0,
    "chair": 1,
    "towel": 2,
    "blinds": 3,
    "window": 4,
    "lamp": 5,
    "shelves": 6,
    "mirror": 7,
    "sink": 8,
    "cabinet": 9,
    "bathtub": 10,
    "door": 11,
    "desk": 12,
    "box": 13,
    "bookcase": 14,
    "picture": 15,
    "table": 16,
    "counter": 17,
    "bed": 18,
    "night stand": 19,
    "pillow": 20,
    "sofa": 21,
    "television": 22,
    "floor mat": 23,
    "curtain": 24,
    "clothes": 25,
    "stationery": 26,
    "refrigerator": 27,
}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_hypersim_det_map(split: str) -> dict[str, int]:
    """Get Hypersim detection map."""
    assert split in {"train", "val", "test"}, f"Invalid split: {split}"

    # Dispatch via lookup table instead of an if/elif chain.
    split_to_map = {
        "train": hypersim_train_det_map,
        "val": hypersim_val_det_map,
        "test": hypersim_test_det_map,
    }
    return split_to_map[split]
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class Hypersim(COCO3DDataset):
    """Hypersim Dataset."""

    def __init__(
        self,
        class_map: dict[str, int] = omni3d_class_map,
        max_depth: float = 50.0,
        depth_scale: float = 1000.0,
        depth_data_root: str = "data/hypersim_depth",
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        Args:
            class_map: Mapping from category name to class id.
            max_depth: Maximum depth value considered valid.
            depth_scale: Scale factor of the stored depth maps.
            depth_data_root: Root directory of the depth maps. Defaults to
                the previously hard-coded "data/hypersim_depth" (consistent
                with the `depth_data_root` parameter of KITTIObject).
        """
        # Set before super().__init__ in case the base class resolves depth
        # paths during initialization.
        self.depth_data_root = depth_data_root

        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth.
        """
        # Expects a 6-component image path:
        # <root>/<ds>/<scene>/<_>/<img_dir>/<name> — TODO confirm layout.
        _, _, scene, _, img_dir, img_name = img["file_path"].split("/")

        depth_filename = os.path.join(
            self.depth_data_root,
            scene,
            "images",
            img_dir,
            # NOTE: replaces every "jpg" substring, not only the extension.
            img_name.replace("jpg", "png"),
        )

        return depth_filename
|
opendet3d/data/datasets/omni3d/kitti_object.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KITTI Object from Omni3D.
|
| 2 |
+
|
| 3 |
+
KITTI Object Labels:
|
| 4 |
+
Categories, -, -, alpha, x1, y1, x2, y2, h, w, l, x, botom_y, z, ry
|
| 5 |
+
|
| 6 |
+
KITTI Object Categories:
|
| 7 |
+
{
|
| 8 |
+
"Pedestrian": "pedestrian",
|
| 9 |
+
"Cyclist": "cyclist",
|
| 10 |
+
"Car": "car",
|
| 11 |
+
"Van": "car",
|
| 12 |
+
"Truck": "truck",
|
| 13 |
+
"Tram": "tram",
|
| 14 |
+
"Person": "pedestrian",
|
| 15 |
+
"Person_sitting": "pedestrian",
|
| 16 |
+
"Misc": "misc",
|
| 17 |
+
"DontCare": "dontcare",
|
| 18 |
+
}
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 26 |
+
|
| 27 |
+
from opendet3d.data.datasets.coco3d import COCO3DDataset
|
| 28 |
+
|
| 29 |
+
from .omni3d_classes import omni3d_class_map
|
| 30 |
+
|
| 31 |
+
# KITTI detection class maps (category name -> class id). The chained
# assignment makes train and test share the *same* dict object.
kitti_train_det_map = kitti_test_det_map = {
    "car": 0,
    "cyclist": 1,
    "pedestrian": 2,
    "person": 3,
    "tram": 4,
    "truck": 5,
    "van": 6,
}

# Val split class map: lacks "person" and "van" compared to train/test.
kitti_val_det_map = {
    "car": 0,
    "cyclist": 1,
    "pedestrian": 2,
    "tram": 3,
    "truck": 4,
}

# KITTI-Omni3D Mapping
omni3d_kitti_det_map = {
    "pedestrian": 0,
    "car": 1,
    "cyclist": 2,
    "van": 3,
    "truck": 4,
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_kitti_det_map(split: str) -> dict[str, int]:
    """Get the KITTI detection map."""
    assert split in {"train", "val", "test"}, f"Invalid split: {split}"

    # Train and test share the same class map; only val differs.
    return kitti_val_det_map if split == "val" else kitti_train_det_map
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class KITTIObject(COCO3DDataset):
    """KITTI Object Dataset."""

    def __init__(
        self,
        class_map: dict[str, int] = omni3d_class_map,
        max_depth: float = 80.0,
        depth_scale: float = 256.0,
        depth_data_root: str = "data/KITTI_object_depth",
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        Args:
            class_map: Mapping from category name to class id.
            max_depth: Maximum depth value considered valid.
            depth_scale: Scale factor of the stored depth maps.
            depth_data_root: Root directory of the depth maps.
        """
        # Stored before super().__init__ so it is available during base
        # class initialization.
        self.depth_data_root = depth_data_root

        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth.
        """
        # Unpack the 5-component image path into its relevant pieces.
        _, _, split_name, image_id, filename = img["file_path"].split("/")

        return os.path.join(
            self.depth_data_root,
            split_name,
            image_id,
            filename.replace(".jpg", ".png"),
        )
|
opendet3d/data/datasets/omni3d/nuscenes.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""nuScenes from Omni3D."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 6 |
+
|
| 7 |
+
from opendet3d.data.datasets.coco3d import COCO3DDataset
|
| 8 |
+
|
| 9 |
+
from .omni3d_classes import omni3d_class_map
|
| 10 |
+
|
| 11 |
+
# nuScenes detection class map (category name -> class id).
nusc_det_map = {
    "bicycle": 0,
    "motorcycle": 1,
    "pedestrian": 2,
    "bus": 3,
    "car": 4,
    "trailer": 5,
    "truck": 6,
    "traffic cone": 7,
    "barrier": 8,
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class nuScenes(COCO3DDataset):
    """nuScenes dataset."""

    def __init__(
        self,
        class_map: dict[str, int] = omni3d_class_map,
        max_depth: float = 80.0,
        depth_scale: float = 256.0,
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        Args:
            class_map: Mapping from category name to class id.
            max_depth: Maximum depth value considered valid.
            depth_scale: Scale factor of the stored depth maps.
        """
        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth.
        """
        # NOTE: this normalizes the path casing *in place*, i.e. it mutates
        # the caller's image entry, not just a local copy.
        img["file_path"] = img["file_path"].replace("nuScenes", "nuscenes")

        # Derive the depth path from the image path; `replace` affects every
        # occurrence of the substring, not only the directory/extension.
        depth_filename = (
            img["file_path"]
            .replace("nuscenes", "nuscenes_depth")
            .replace("jpg", "png")
        )
        return depth_filename

    def get_cat_ids(self, idx: int) -> list[int]:
        """Return the class ids of the sample at `idx`."""
        # NOTE(review): `self.samples` is presumably populated by the base
        # COCO3DDataset; this override looks identical to base behavior —
        # confirm whether it is redundant.
        return self.samples[idx]["class_ids"].tolist()

    def __len__(self) -> int:
        """Total number of samples of data."""
        return len(self.samples)
|
opendet3d/data/datasets/omni3d/objectron.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Objectron from Omni3D."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 8 |
+
|
| 9 |
+
from opendet3d.data.datasets.coco3d import COCO3DDataset
|
| 10 |
+
|
| 11 |
+
from .omni3d_classes import omni3d_class_map
|
| 12 |
+
|
| 13 |
+
# Objectron detection class map
# (category name -> contiguous class id, alphabetical order).
objectron_det_map = {
    "bicycle": 0,
    "books": 1,
    "bottle": 2,
    "camera": 3,
    "cereal box": 4,
    "chair": 5,
    "cup": 6,
    "laptop": 7,
    "shoes": 8,
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Objectron(COCO3DDataset):
    """Objectron dataset."""

    def __init__(
        self,
        class_map: dict[str, int] = omni3d_class_map,
        max_depth: float = 12.0,
        depth_scale: float = 1000.0,
        depth_data_root: str = "data/objectron_depth",
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        Args:
            class_map: Mapping from category name to class id.
            max_depth: Maximum depth value considered valid.
            depth_scale: Scale factor of the stored depth maps.
            depth_data_root: Root directory of the depth maps. Defaults to
                the previously hard-coded "data/objectron_depth" (consistent
                with the `depth_data_root` parameter of KITTIObject).
        """
        # Set before super().__init__ in case the base class resolves depth
        # paths during initialization.
        self.depth_data_root = depth_data_root

        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth.
        """
        # Expects a 4-component image path: <root>/<ds>/<split>/<name>
        # — TODO confirm layout against the annotation files.
        _, _, split, img_name = img["file_path"].split("/")

        depth_filename = os.path.join(
            self.depth_data_root,
            split,
            img_name.replace(".jpg", "_depth.png"),
        )
        return depth_filename
|
opendet3d/data/datasets/omni3d/omni3d_classes.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Omni3D classes."""
|
| 2 |
+
|
| 3 |
+
# Full Omni3D vocabulary: category name -> contiguous class id (98 classes).
omni3d_class_map = {
    "pedestrian": 0,
    "car": 1,
    "dontcare": 2,
    "cyclist": 3,
    "van": 4,
    "truck": 5,
    "tram": 6,
    "person": 7,
    "traffic cone": 8,
    "barrier": 9,
    "motorcycle": 10,
    "bicycle": 11,
    "bus": 12,
    "trailer": 13,
    "books": 14,
    "bottle": 15,
    "camera": 16,
    "cereal box": 17,
    "chair": 18,
    "cup": 19,
    "laptop": 20,
    "shoes": 21,
    "towel": 22,
    "blinds": 23,
    "window": 24,
    "lamp": 25,
    "shelves": 26,
    "mirror": 27,
    "sink": 28,
    "cabinet": 29,
    "bathtub": 30,
    "door": 31,
    "toilet": 32,
    "desk": 33,
    "box": 34,
    "bookcase": 35,
    "picture": 36,
    "table": 37,
    "counter": 38,
    "bed": 39,
    "night stand": 40,
    "dresser": 41,
    "pillow": 42,
    "sofa": 43,
    "television": 44,
    "floor mat": 45,
    "curtain": 46,
    "clothes": 47,
    "stationery": 48,
    "refrigerator": 49,
    "board": 50,
    "kitchen pan": 51,
    "bin": 52,
    "stove": 53,
    "microwave": 54,
    "plates": 55,
    "bowl": 56,
    "oven": 57,
    "vase": 58,
    "faucet": 59,
    "tissues": 60,
    "machine": 61,
    "printer": 62,
    "monitor": 63,
    "podium": 64,
    "cart": 65,
    "projector": 66,
    "electronics": 67,
    "computer": 68,
    "air conditioner": 69,
    "drawers": 70,
    "coffee maker": 71,
    "toaster": 72,
    "potted plant": 73,
    "painting": 74,
    "bag": 75,
    "tray": 76,
    "keyboard": 77,
    "blanket": 78,
    "rack": 79,
    "phone": 80,
    "mouse": 81,
    "fire extinguisher": 82,
    "toys": 83,
    "ladder": 84,
    "fan": 85,
    "glass": 86,
    "clock": 87,
    "toilet paper": 88,
    "closet": 89,
    "fume hood": 90,
    "utensils": 91,
    "soundsystem": 92,
    "fire place": 93,
    "shower curtain": 94,
    "remote": 95,
    "pen": 96,
    "fireplace": 97,
}

# Used for Cube R-CNN and Omni3D benchmark
# (50-class subset of `omni3d_class_map`; excludes e.g. "dontcare",
# "person" and "tram", and uses its own id ordering).
omni3d_det_map = {
    "pedestrian": 0,
    "car": 1,
    "cyclist": 2,
    "van": 3,
    "truck": 4,
    "traffic cone": 5,
    "barrier": 6,
    "motorcycle": 7,
    "bicycle": 8,
    "bus": 9,
    "trailer": 10,
    "books": 11,
    "bottle": 12,
    "camera": 13,
    "cereal box": 14,
    "chair": 15,
    "cup": 16,
    "laptop": 17,
    "shoes": 18,
    "towel": 19,
    "blinds": 20,
    "window": 21,
    "lamp": 22,
    "shelves": 23,
    "mirror": 24,
    "sink": 25,
    "cabinet": 26,
    "bathtub": 27,
    "door": 28,
    "toilet": 29,
    "desk": 30,
    "box": 31,
    "bookcase": 32,
    "picture": 33,
    "table": 34,
    "counter": 35,
    "bed": 36,
    "night stand": 37,
    "pillow": 38,
    "sofa": 39,
    "television": 40,
    "floor mat": 41,
    "curtain": 42,
    "clothes": 43,
    "stationery": 44,
    "refrigerator": 45,
    "bin": 46,
    "stove": 47,
    "oven": 48,
    "machine": 49,
}
|
opendet3d/data/datasets/omni3d/sunrgbd.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SUN RGB-D from Omni3D."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 9 |
+
from vis4d.data.datasets.util import im_decode
|
| 10 |
+
|
| 11 |
+
from opendet3d.data.datasets.coco3d import COCO3DDataset
|
| 12 |
+
|
| 13 |
+
from .omni3d_classes import omni3d_class_map
|
| 14 |
+
|
| 15 |
+
# Train and Test are sharing the classes
|
| 16 |
+
sun_rgbd_train_det_map = sun_rgbd_test_det_map = {
|
| 17 |
+
"air conditioner": 0,
|
| 18 |
+
"bag": 1,
|
| 19 |
+
"bathtub": 2,
|
| 20 |
+
"bed": 3,
|
| 21 |
+
"bicycle": 4,
|
| 22 |
+
"bin": 5,
|
| 23 |
+
"blanket": 6,
|
| 24 |
+
"blinds": 7,
|
| 25 |
+
"board": 8,
|
| 26 |
+
"bookcase": 9,
|
| 27 |
+
"books": 10,
|
| 28 |
+
"bottle": 11,
|
| 29 |
+
"bowl": 12,
|
| 30 |
+
"box": 13,
|
| 31 |
+
"cabinet": 14,
|
| 32 |
+
"cart": 15,
|
| 33 |
+
"chair": 16,
|
| 34 |
+
"clock": 17,
|
| 35 |
+
"closet": 18,
|
| 36 |
+
"clothes": 19,
|
| 37 |
+
"coffee maker": 20,
|
| 38 |
+
"computer": 21,
|
| 39 |
+
"counter": 22,
|
| 40 |
+
"cup": 23,
|
| 41 |
+
"curtain": 24,
|
| 42 |
+
"desk": 25,
|
| 43 |
+
"door": 26,
|
| 44 |
+
"drawers": 27,
|
| 45 |
+
"dresser": 28,
|
| 46 |
+
"electronics": 29,
|
| 47 |
+
"fan": 30,
|
| 48 |
+
"faucet": 31,
|
| 49 |
+
"fire extinguisher": 32,
|
| 50 |
+
"fire place": 33,
|
| 51 |
+
"floor mat": 34,
|
| 52 |
+
"fume hood": 35,
|
| 53 |
+
"glass": 36,
|
| 54 |
+
"keyboard": 37,
|
| 55 |
+
"kitchen pan": 38,
|
| 56 |
+
"ladder": 39,
|
| 57 |
+
"lamp": 40,
|
| 58 |
+
"laptop": 41,
|
| 59 |
+
"machine": 42,
|
| 60 |
+
"microwave": 43,
|
| 61 |
+
"mirror": 44,
|
| 62 |
+
"monitor": 45,
|
| 63 |
+
"mouse": 46,
|
| 64 |
+
"night stand": 47,
|
| 65 |
+
"oven": 48,
|
| 66 |
+
"painting": 49,
|
| 67 |
+
"pen": 50,
|
| 68 |
+
"person": 51,
|
| 69 |
+
"phone": 52,
|
| 70 |
+
"picture": 53,
|
| 71 |
+
"pillow": 54,
|
| 72 |
+
"plates": 55,
|
| 73 |
+
"podium": 56,
|
| 74 |
+
"potted plant": 57,
|
| 75 |
+
"printer": 58,
|
| 76 |
+
"projector": 59,
|
| 77 |
+
"rack": 60,
|
| 78 |
+
"refrigerator": 61,
|
| 79 |
+
"remote": 62,
|
| 80 |
+
"shelves": 63,
|
| 81 |
+
"shoes": 64,
|
| 82 |
+
"shower curtain": 65,
|
| 83 |
+
"sink": 66,
|
| 84 |
+
"sofa": 67,
|
| 85 |
+
"soundsystem": 68,
|
| 86 |
+
"stationery": 69,
|
| 87 |
+
"stove": 70,
|
| 88 |
+
"table": 71,
|
| 89 |
+
"television": 72,
|
| 90 |
+
"tissues": 73,
|
| 91 |
+
"toaster": 74,
|
| 92 |
+
"toilet": 75,
|
| 93 |
+
"toilet paper": 76,
|
| 94 |
+
"towel": 77,
|
| 95 |
+
"toys": 78,
|
| 96 |
+
"tray": 79,
|
| 97 |
+
"utensils": 80,
|
| 98 |
+
"vase": 81,
|
| 99 |
+
"window": 82,
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
# SUN RGB-D categories present in the validation split, alphabetically
# ordered; contiguous ids 0..71 follow from the position in this tuple.
_SUN_RGBD_VAL_CLASSES = (
    "air conditioner", "bag", "bathtub", "bed", "bin", "blanket",
    "blinds", "board", "bookcase", "books", "bottle", "bowl", "box",
    "cabinet", "cart", "chair", "closet", "clothes", "coffee maker",
    "computer", "counter", "cup", "curtain", "desk", "door", "drawers",
    "dresser", "electronics", "fan", "faucet", "fire extinguisher",
    "fire place", "fume hood", "keyboard", "kitchen pan", "lamp",
    "laptop", "machine", "microwave", "mirror", "monitor", "night stand",
    "oven", "painting", "pen", "person", "phone", "picture", "pillow",
    "plates", "potted plant", "printer", "projector", "rack",
    "refrigerator", "shelves", "sink", "sofa", "soundsystem",
    "stationery", "stove", "table", "television", "tissues", "toaster",
    "toilet", "towel", "toys", "tray", "utensils", "vase", "window",
)

sun_rgbd_val_det_map = {
    name: idx for idx, name in enumerate(_SUN_RGBD_VAL_CLASSES)
}
|
| 176 |
+
|
| 177 |
+
# SUN RGB-D categories that participate in the Omni3D-50 benchmark, in
# benchmark order; contiguous ids 0..37 follow from the position.
_OMNI3D_SUN_RGBD_CLASSES = (
    "bicycle", "books", "bottle", "chair", "cup", "laptop", "shoes",
    "towel", "blinds", "window", "lamp", "shelves", "mirror", "sink",
    "cabinet", "bathtub", "door", "toilet", "desk", "box", "bookcase",
    "picture", "table", "counter", "bed", "night stand", "pillow",
    "sofa", "television", "floor mat", "curtain", "clothes",
    "stationery", "refrigerator", "bin", "stove", "oven", "machine",
)

omni3d_sun_rgbd_det_map = {
    name: idx for idx, name in enumerate(_OMNI3D_SUN_RGBD_CLASSES)
}
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_sunrgbd_det_map(split: str) -> dict[str, int]:
    """Return the SUN RGB-D detection class map for a dataset split.

    Args:
        split: One of "train", "val", or "test".

    Returns:
        Mapping from category name to contiguous class index.
    """
    split_to_map = {
        "train": sun_rgbd_train_det_map,
        "val": sun_rgbd_val_det_map,
        "test": sun_rgbd_test_det_map,
    }
    assert split in split_to_map, f"Invalid split: {split}"
    return split_to_map[split]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class SUNRGBD(COCO3DDataset):
    """SUN RGB-D Dataset.

    COCO3D-format dataset wrapper for SUN RGB-D that resolves per-sample
    depth file paths and decodes the bit-rotated 16-bit depth PNGs.
    """

    def __init__(
        self,
        class_map: dict[str, int] = omni3d_class_map,
        max_depth: float = 8.0,
        depth_scale: float = 1000.0,
        **kwargs: ArgsType,
    ) -> None:
        """Initialize SUN RGB-D dataset.

        Args:
            class_map: Mapping from category name to class id. Defaults to
                the Omni3D class map.
            max_depth: Maximum depth in meters used for this dataset.
            depth_scale: Divisor converting raw depth values to meters
                (millimeters -> meters by default).
            **kwargs: Forwarded to the COCO3DDataset constructor.
        """
        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth.
        """
        # Normalize accidental double slashes in the annotation path.
        img["file_path"] = img["file_path"].replace("//", "/")

        # Sample root is everything before the "/image" folder.
        data_dir = img["file_path"].split("/image")[0]

        # Each SUN RGB-D sample ships exactly one depth file whose name is
        # not derivable from the image name, so list the folder instead.
        depth_files = self.data_backend.listdir(
            os.path.join(data_dir, "depth")
        )
        assert len(depth_files) == 1

        depth_filename = os.path.join(data_dir, "depth", depth_files[0])

        return depth_filename

    def get_depth_map(self, sample: DictStrAny) -> np.ndarray:
        """Get the depth map."""
        depth_bytes = self.data_backend.get(sample["depth_filename"])
        depth_array = im_decode(depth_bytes)

        # 16-bit rotate-right by 3: SUN RGB-D stores depth bit-rotated in
        # its PNGs. NOTE(review): relies on depth_array being uint16 so the
        # left shift wraps around — confirm the decoded dtype.
        depth_array = depth_array >> 3 | depth_array << (16 - 3)

        depth = np.ascontiguousarray(depth_array, dtype=np.float32)

        # Convert raw units (millimeters) to meters.
        depth = depth / self.depth_scale

        return depth
|
opendet3d/data/datasets/omni3d/util.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Omni3D data util."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from .arkitscenes import arkitscenes_det_map, omni3d_arkitscenes_det_map
|
| 6 |
+
from .hypersim import get_hypersim_det_map, omni3d_hypersim_det_map
|
| 7 |
+
from .kitti_object import get_kitti_det_map, omni3d_kitti_det_map
|
| 8 |
+
from .nuscenes import nusc_det_map
|
| 9 |
+
from .objectron import objectron_det_map
|
| 10 |
+
from .sunrgbd import get_sunrgbd_det_map, omni3d_sun_rgbd_det_map
|
| 11 |
+
|
| 12 |
+
# Omni3D sub-datasets in canonical order; each contributes three splits,
# so dataset i / split j receives id 3 * i + j (ids 0..17).
DATASET_ID_MAP = {
    3 * i + j: f"{dataset}_{split}"
    for i, dataset in enumerate(
        ["KITTI", "nuScenes", "Objectron", "Hypersim", "SUNRGBD", "ARKitScenes"]
    )
    for j, split in enumerate(["train", "val", "test"])
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_dataset_det_map(
    dataset_name: str,
    omni3d50: bool = True,
) -> dict[str, int]:
    """Get the detection class map for an Omni3D sub-dataset.

    Args:
        dataset_name: Name containing the sub-dataset (e.g. "KITTI") and
            the split ("train", "val", or "test"), e.g. "KITTI_train".
        omni3d50: If True, use the Omni3D-50 category map where one exists
            for the sub-dataset; otherwise use the per-split native map.

    Returns:
        Mapping from category name to contiguous class index. (The previous
        annotation claimed a ``tuple[str, dict[str, int]]``; only the map is
        returned.)

    Raises:
        ValueError: If the split or sub-dataset cannot be determined from
            ``dataset_name``.
    """
    # The split is always validated, even for datasets whose map does not
    # depend on it, so malformed names fail fast.
    if "train" in dataset_name:
        split = "train"
    elif "val" in dataset_name:
        split = "val"
    elif "test" in dataset_name:
        split = "test"
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    if "nuScenes" in dataset_name:
        det_map = nusc_det_map
    elif "KITTI" in dataset_name:
        det_map = (
            omni3d_kitti_det_map if omni3d50 else get_kitti_det_map(split)
        )
    elif "Objectron" in dataset_name:
        det_map = objectron_det_map
    elif "SUNRGBD" in dataset_name:
        det_map = (
            omni3d_sun_rgbd_det_map if omni3d50 else get_sunrgbd_det_map(split)
        )
    elif "Hypersim" in dataset_name:
        det_map = (
            omni3d_hypersim_det_map if omni3d50 else get_hypersim_det_map(split)
        )
    elif "ARKitScenes" in dataset_name:
        det_map = (
            omni3d_arkitscenes_det_map if omni3d50 else arkitscenes_det_map
        )
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")

    return det_map
|
opendet3d/data/datasets/scannet.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ScanNet dataset."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from vis4d.common.typing import ArgsType, DictStrAny
|
| 6 |
+
|
| 7 |
+
from .coco3d import COCO3DDataset
|
| 8 |
+
|
| 9 |
+
# ScanNetV2 benchmark categories mapped to their original NYU40 label ids.
scannet_class_map = {
    "cabinet": 3, "bed": 4, "chair": 5, "sofa": 6, "table": 7,
    "door": 8, "window": 9, "bookshelf": 10, "picture": 11,
    "counter": 12, "desk": 14, "curtain": 16, "refrigerator": 24,
    "shower curtain": 28, "toilet": 33, "sink": 34, "bathtub": 36,
    "other furniture": 39,
}

# Contiguous detection ids 0..17 assigned in the class-map order above.
scannet_det_map = {name: idx for idx, name in enumerate(scannet_class_map)}
|
| 50 |
+
|
| 51 |
+
# ScanNet200 categories mapped to their raw ScanNet label ids (not
# contiguous; ids >= 1156 are ScanNet200-specific additions).
scannet200_class_map = {
    "chair": 2, "book": 22, "door": 5, "object": 1163, "window": 16,
    "table": 4, "trash can": 56, "pillow": 13, "picture": 15, "box": 26,
    "doorframe": 161, "monitor": 19, "cabinet": 7, "desk": 9, "shelf": 8,
    "office chair": 10, "towel": 31, "couch": 6, "sink": 14,
    "backpack": 48, "lamp": 28, "bed": 11, "bookshelf": 18, "mirror": 71,
    "curtain": 21, "plant": 40, "whiteboard": 52, "radiator": 96,
    "kitchen cabinet": 29, "toilet paper": 49, "armchair": 23,
    "shoe": 63, "coffee table": 24, "toilet": 17, "bag": 47,
    "clothes": 32, "keyboard": 46, "bottle": 65, "recycling bin": 97,
    "nightstand": 34, "stool": 38, "tv": 33, "file cabinet": 75,
    "dresser": 36, "computer tower": 64, "telephone": 101, "cup": 130,
    "refrigerator": 27, "end table": 44, "jacket": 131,
    "shower curtain": 55, "bathtub": 42, "microwave": 59,
    "kitchen counter": 159, "sofa chair": 74,
    "paper towel dispenser": 82, "bathroom vanity": 1164, "suitcase": 93,
    "laptop": 77, "ottoman": 67, "shower wall": 128, "printer": 50,
    "counter": 35, "board": 69, "soap dispenser": 100, "stove": 62,
    "light": 105, "closet wall": 1165, "mini fridge": 165, "fan": 76,
    "tissue box": 230, "blanket": 54, "bathroom stall": 125,
    "copier": 72, "bench": 68, "bar": 145, "soap dish": 157,
    "laundry hamper": 1166, "storage bin": 132,
    "bathroom stall door": 1167, "light switch": 232,
    "coffee maker": 134, "tv stand": 51, "decoration": 250,
    "ceiling light": 1168, "range hood": 342, "blackboard": 89,
    "clock": 103, "wardrobe": 99, "rail": 95, "bulletin board": 154,
    "mat": 140, "trash bin": 1169, "ledge": 193, "seat": 116,
    "mouse": 202, "basket": 73, "shower": 78, "dumbbell": 1170,
    "paper": 79, "person": 80, "windowsill": 141, "closet": 57,
    "bucket": 102, "sign": 261, "speaker": 118, "dishwasher": 136,
    "container": 98, "stair rail": 1171, "shower curtain rod": 170,
    "tube": 1172, "bathroom cabinet": 1173, "storage container": 221,
    "paper bag": 570, "paper towel roll": 138, "ball": 168,
    "closet door": 276, "laundry basket": 106, "cart": 214,
    "dish rack": 323, "stairs": 58, "blinds": 86, "purse": 399,
    "bicycle": 121, "tray": 185, "plunger": 300, "paper cutter": 180,
    "toilet paper dispenser": 163, "bin": 66,
    "toilet seat cover dispenser": 208, "guitar": 112, "mailbox": 540,
    "handicap bar": 395, "fire extinguisher": 166, "ladder": 122,
    "column": 120, "pipe": 107, "vacuum cleaner": 283, "plate": 88,
    "piano": 90, "water cooler": 177, "cd case": 1174, "bowl": 562,
    "closet rod": 1175, "bathroom counter": 1156, "oven": 84,
    "stand": 104, "scale": 229, "washing machine": 70, "broom": 325,
    "hat": 169, "guitar case": 331, "rack": 87, "water pitcher": 488,
    "laundry detergent": 776, "hair dryer": 370, "pillar": 191,
    "divider": 748, "power outlet": 242, "dining table": 45,
    "shower floor": 417, "shower door": 188, "coffee kettle": 1176,
    "structure": 1178, "clothes dryer": 110, "toaster": 148,
    "ironing board": 155, "alarm clock": 572, "shower head": 1179,
    "water bottle": 392, "keyboard piano": 1180,
    "projector screen": 609, "case of water bottles": 1181,
    "toaster oven": 195, "music stand": 581, "coat rack": 1182,
    "storage organizer": 1183, "machine": 139, "folded chair": 1184,
    "fire alarm": 1185, "fireplace": 156, "vent": 408, "furniture": 213,
    "power strip": 1186, "calendar": 1187, "poster": 1188,
    "toilet paper holder": 115, "potted plant": 1189,
    "stuffed animal": 304, "luggage": 1190, "headphones": 312,
    "crate": 233, "candle": 286, "projector": 264, "mattress": 1191,
    "dustpan": 356, "cushion": 39, "stick": 1163,
}
|
| 251 |
+
|
| 252 |
+
# ScanNet200 detection categories in frequency order; contiguous ids
# 0..167 are assigned by position in this tuple.
_SCANNET200_DET_CLASSES = (
    "chair", "table", "door", "couch", "cabinet", "shelf", "desk",
    "office chair", "bed", "pillow", "sink", "picture", "window",
    "toilet", "bookshelf", "monitor", "curtain", "book", "armchair",
    "coffee table", "box", "refrigerator", "lamp", "kitchen cabinet",
    "towel", "clothes", "tv", "nightstand", "counter", "dresser",
    "stool", "plant", "bathtub", "end table", "dining table",
    "keyboard", "bag", "backpack", "toilet paper", "printer",
    "tv stand", "whiteboard", "blanket", "shower curtain", "trash can",
    "closet", "stairs", "microwave", "stove", "shoe", "computer tower",
    "bottle", "bin", "ottoman", "bench", "board", "washing machine",
    "mirror", "copier", "basket", "sofa chair", "file cabinet", "fan",
    "laptop", "shower", "paper", "person", "paper towel dispenser",
    "oven", "blinds", "rack", "plate", "blackboard", "piano",
    "suitcase", "rail", "radiator", "recycling bin", "container",
    "wardrobe", "soap dispenser", "telephone", "bucket", "clock",
    "stand", "light", "laundry basket", "pipe", "clothes dryer",
    "guitar", "toilet paper holder", "seat", "speaker", "column",
    "ladder", "cup", "jacket", "storage bin", "coffee maker",
    "dishwasher", "paper towel roll", "machine", "mat", "windowsill",
    "bar", "bulletin board", "ironing board", "fireplace", "soap dish",
    "kitchen counter", "doorframe", "toilet paper dispenser",
    "mini fridge", "fire extinguisher", "ball", "hat",
    "shower curtain rod", "water cooler", "paper cutter", "tray",
    "pillar", "ledge", "toaster oven", "mouse",
    "toilet seat cover dispenser", "cart", "scale", "tissue box",
    "light switch", "crate", "power outlet", "decoration", "sign",
    "projector", "closet door", "vacuum cleaner", "headphones",
    "dish rack", "broom", "range hood", "hair dryer", "water bottle",
    "vent", "mailbox", "bowl", "paper bag", "projector screen",
    "divider", "laundry detergent", "bathroom counter", "stick",
    "bathroom vanity", "closet wall", "laundry hamper",
    "bathroom stall door", "ceiling light", "trash bin", "dumbbell",
    "stair rail", "tube", "bathroom cabinet", "coffee kettle",
    "shower head", "case of water bottles", "power strip", "calendar",
    "poster", "mattress",
)

scannet200_det_map = {
    name: idx for idx, name in enumerate(_SCANNET200_DET_CLASSES)
}
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class ScanNetDataset(COCO3DDataset):
    """ScanNetV2 dataset.

    COCO3D-format wrapper for ScanNetV2 frames, deriving the per-image
    depth-map path from the RGB file path.
    """

    def __init__(
        self,
        class_map: dict[str, int] = scannet_class_map,
        max_depth: float = 12.0,
        depth_scale: float = 1000.0,
        **kwargs: ArgsType,
    ) -> None:
        """Initialize the ScanNetV2 dataset.

        Args:
            class_map: Mapping from category name to class id. Defaults to
                the ScanNet benchmark class map.
            max_depth: Maximum depth in meters used for this dataset.
            depth_scale: Divisor converting raw depth values to meters.
            **kwargs: Forwarded to the COCO3DDataset constructor.
        """
        super().__init__(
            class_map=class_map,
            max_depth=max_depth,
            depth_scale=depth_scale,
            **kwargs,
        )

    def get_depth_filenames(self, img: DictStrAny) -> str | None:
        """Get the depth filenames.

        Since not every data has depth.
        """
        # Depth maps live next to the RGB frames: swap the "image" folder
        # for "depth" and the .jpg extension for .png.
        depth_path = img["file_path"].replace("image", "depth")
        return depth_path.replace(".jpg", ".png")
|
opendet3d/data/transforms/__init__.py
ADDED
|
File without changes
|
opendet3d/data/transforms/crop.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Crop transforms."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from vis4d.common.typing import (
|
| 6 |
+
NDArrayBool,
|
| 7 |
+
NDArrayF32,
|
| 8 |
+
NDArrayI64,
|
| 9 |
+
)
|
| 10 |
+
from vis4d.data.const import CommonKeys as K
|
| 11 |
+
from vis4d.data.transforms.base import Transform
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@Transform(
    in_keys=[
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
        "transforms.crop.keep_mask",
    ],
    out_keys=[K.boxes3d, K.boxes3d_classes, K.boxes3d_track_ids],
)
class CropBoxes3D:
    """Filter 3D bounding boxes by the crop keep mask.

    Applies the per-sample boolean keep mask produced by the crop
    transform to the 3D boxes, their class labels, and (if present)
    their track ids, mutating the input lists in place.
    """

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        keep_mask_list: list[NDArrayBool],
    ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
        """Apply the keep masks to every sample's 3D annotations."""
        samples = zip(boxes_list, classes_list, keep_mask_list)
        for idx, (boxes, labels, keep) in enumerate(samples):
            boxes_list[idx] = boxes[keep]
            classes_list[idx] = labels[keep]

            # Track ids are optional; filter them with the same mask.
            if track_ids_list is not None:
                track_ids_list[idx] = track_ids_list[idx][keep]

        return boxes_list, classes_list, track_ids_list
|
opendet3d/data/transforms/language.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Language related transforms."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
+
from vis4d.common.logging import rank_zero_warn
|
| 11 |
+
from vis4d.common.typing import NDArrayF32, NDArrayI64
|
| 12 |
+
from vis4d.data.const import CommonKeys as K
|
| 13 |
+
from vis4d.data.transforms.base import Transform
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def clean_name(name: str) -> str:
    """Normalize a category name for use in a text prompt.

    Removes parenthesized qualifiers, replaces underscores with spaces,
    collapses runs of spaces, and lower-cases the result.

    Args:
        name (str): Raw category name, e.g. "Traffic_Light (red)".

    Returns:
        str: Cleaned, lower-cased name.
    """
    name = re.sub(r"\(.*\)", "", name)
    name = re.sub(r"_", " ", name)
    # Collapse repeated spaces left over by the substitutions above. The
    # original pattern replaced a single space with a single space, which
    # is a no-op; match two-or-more spaces explicitly instead.
    name = re.sub(r" {2,}", " ", name)
    name = name.lower()
    return name
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def generate_senetence_given_labels(
    positive_label_list: list[int],
    negative_label_list: list[str],
    label_map: dict[str, str],
) -> tuple[dict[int, list[list[int]]], str, dict[int, int]]:
    """Build a detection caption from positive and negative labels.

    Negative and positive labels are concatenated, shuffled, and written
    one after another into a single caption, each followed by ". ".

    Args:
        positive_label_list: Label ids present in the image.
        negative_label_list: Sampled label ids not present in the image.
        label_map: Mapping from label id (as string) to category name.

    Returns:
        tuple: A mapping from caption slot index to the character span(s)
            of positive labels, the assembled caption, and a mapping from
            original positive label id to its slot index in the caption.
    """
    positions_of_positives: dict[int, list[list[int]]] = {}
    remap: dict[int, int] = {}

    # Mix negatives and positives so their order in the caption is random.
    all_labels = negative_label_list + positive_label_list
    random.shuffle(all_labels)

    caption = ""
    for slot, label in enumerate(all_labels):
        span_begin = len(caption)
        caption += clean_name(label_map[str(label)])
        span_end = len(caption)

        # Only positive labels get grounded character spans and a remap.
        if label in positive_label_list:
            positions_of_positives[slot] = [[span_begin, span_end]]
            remap[int(label)] = slot

        caption += ". "

    return positions_of_positives, caption, remap
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@Transform(
    [
        "dataset_type",
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_names,
        "label_map",
        "positive_positions",
    ],
    [K.boxes2d, K.boxes2d_classes, K.boxes2d_names, "tokens_positive"],
)
class RandomSamplingNegPos:
    """Randomly sample negative and positive labels for object detection.

    For object-detection ("OD") samples, builds a text caption from the
    positive categories plus randomly sampled negative categories, capped
    by the tokenizer budget. For other (visual-grounding style) samples,
    it only looks up the provided positive character positions.
    """

    def __init__(
        self,
        tokenizer_name: str = "bert-base-uncased",
        num_sample_negative: int = 85,
        max_tokens: int = 256,
        full_sampling_prob: float = 0.5,
    ) -> None:
        """Creates an instance of RandomSamplingNegPos.

        Args:
            tokenizer_name (str): HuggingFace tokenizer used to measure
                caption length in tokens.
            num_sample_negative (int): Upper bound on sampled negatives.
            max_tokens (int): Token budget for the full caption.
            full_sampling_prob (float): Probability of taking all
                ``num_sample_negative`` negatives instead of a random count.
        """
        # NOTE(review): AutoTokenizer is imported directly at module level,
        # so a missing transformers install raises ImportError before this
        # check can ever trigger; the guard only makes sense with a
        # try/except import — confirm intent.
        if AutoTokenizer is None:
            raise RuntimeError(
                "transformers is not installed, please install it by: "
                "pip install transformers."
            )

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.num_sample_negative = num_sample_negative
        self.full_sampling_prob = full_sampling_prob
        self.max_tokens = max_tokens

    def __call__(
        self,
        dataset_type_list: list[str],
        boxes_list: list[NDArrayF32],
        class_ids_list: list[NDArrayI64],
        texts_list: list[str] | None = None,
        label_map_list: dict | None = None,
        positive_positions_list: list[dict] | None = None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[str],
        list[dict[int, list[list[int]]]],
    ]:
        """Randomly sample negative and positive labels.

        Args:
            dataset_type_list: Per-sample dataset type; "OD" selects the
                detection augmentation path, anything else the grounding
                path (which requires ``texts_list`` and
                ``positive_positions_list`` to be provided).
            boxes_list: Per-sample 2D boxes.
            class_ids_list: Per-sample class ids.
            texts_list: Per-sample captions (grounding path only).
            label_map_list: Per-sample id->name maps (OD path only).
            positive_positions_list: Per-sample label->character-span maps
                (grounding path only).

        Returns:
            Possibly filtered boxes and remapped class ids, the captions,
            and per-sample positive token spans.
        """
        new_texts_list = []
        tokens_positive_list = []
        for i, (boxes, class_ids) in enumerate(
            zip(boxes_list, class_ids_list)
        ):
            if dataset_type_list[i] == "OD":
                assert (
                    label_map_list[i] is not None
                ), "label_map should not be None"
                # OD path may drop boxes (caption overflow) and remaps
                # class ids to caption slot indices.
                boxes_list[i], class_ids_list[i], text, tokens_positive = (
                    self.od_aug(boxes, class_ids, label_map_list[i])
                )
                new_texts_list.append(text)
                tokens_positive_list.append(tokens_positive)
            else:
                assert (
                    positive_positions_list[i] is not None
                ), "positive_positions should not be None"
                # Grounding path: caption is fixed, only spans are looked up.
                tokens_positive = self.vg_aug(
                    class_ids, positive_positions_list[i]
                )
                new_texts_list.append(texts_list[i])
                tokens_positive_list.append(tokens_positive)

        return boxes_list, class_ids_list, new_texts_list, tokens_positive_list

    def vg_aug(self, class_ids: NDArrayI64, positive_positions):
        """Visual Genome data augmentation.

        Looks up the pre-computed character spans for every distinct label
        present in ``class_ids``; the caption itself is left untouched.
        """
        positive_label_list = np.unique(class_ids).tolist()

        label_to_positions = {}
        for label in positive_label_list:
            label_to_positions[label] = positive_positions[label]

        return label_to_positions

    def od_aug(
        self,
        boxes: NDArrayF32,
        class_ids: NDArrayI64,
        label_map: dict,
    ) -> tuple[NDArrayF32, NDArrayI64, str, dict[int, list[list[int]]]]:
        """Object detection data augmentation.

        Builds a caption of positive plus sampled negative categories and
        remaps class ids onto caption slot indices. Boxes whose categories
        would overflow the token budget are dropped.
        """
        original_box_num = len(class_ids)

        # If the category name is in the format of 'a/b' (in object365),
        # we randomly select one of them. NOTE: mutates label_map in place.
        for key, value in label_map.items():
            if "/" in value:
                label_map[key] = random.choice(value.split("/")).strip()

        # Drop boxes whose category names would overflow the token budget.
        keep_box_index, class_ids, positive_caption_length = (
            self.check_for_positive_overflow(class_ids, label_map)
        )

        boxes = boxes[keep_box_index]

        if len(boxes) < original_box_num:
            rank_zero_warn(
                f"Remove {original_box_num - len(boxes)} boxes due to "
                "positive caption overflow."
            )

        valid_negative_indexes = list(label_map.keys())

        positive_label_list = np.unique(class_ids).tolist()

        # Cap the negative sample count at the number of known categories.
        full_negative = self.num_sample_negative
        if full_negative > len(valid_negative_indexes):
            full_negative = len(valid_negative_indexes)

        outer_prob = random.random()

        if outer_prob < self.full_sampling_prob:
            # c. probability_full: add both all positive and all negatives
            num_negatives = full_negative
        else:
            # NOTE(review): `random.random() < 1.0` is (almost) always true,
            # so the else branch below is effectively dead code; the same
            # structure exists in the upstream GLIP implementation — confirm
            # whether a tunable probability was intended.
            if random.random() < 1.0:
                num_negatives = np.random.choice(max(1, full_negative)) + 1
            else:
                num_negatives = full_negative

        # Keep some negatives
        negative_label_list = set()
        if num_negatives != -1:
            if num_negatives > len(valid_negative_indexes):
                num_negatives = len(valid_negative_indexes)

            # Sample without replacement and discard any id that is
            # already a positive.
            for i in np.random.choice(
                valid_negative_indexes, size=num_negatives, replace=False
            ):
                if int(i) not in positive_label_list:
                    negative_label_list.add(i)

        random.shuffle(positive_label_list)

        negative_label_list = list(negative_label_list)
        random.shuffle(negative_label_list)

        # Screen negatives against the remaining token budget; stop at the
        # first one that no longer fits.
        negative_max_length = self.max_tokens - positive_caption_length
        screened_negative_label_list = []

        for negative_label in negative_label_list:
            label_text = clean_name(label_map[str(negative_label)]) + ". "

            tokenized = self.tokenizer.tokenize(label_text)

            negative_max_length -= len(tokenized)

            if negative_max_length > 0:
                screened_negative_label_list.append(negative_label)
            else:
                break

        negative_label_list = screened_negative_label_list
        label_to_positions, pheso_caption, label_remap_dict = (
            generate_senetence_given_labels(
                positive_label_list, negative_label_list, label_map
            )
        )

        # label remap: original class ids -> caption slot indices.
        if len(class_ids) > 0:
            class_ids = np.vectorize(lambda x: label_remap_dict[x])(class_ids)

        return boxes, class_ids, pheso_caption, label_to_positions

    def check_for_positive_overflow(
        self, class_ids: NDArrayI64, label_map: dict[str, str]
    ) -> tuple[list[int], NDArrayI64, int]:
        """Check if having too many positive labels.

        Tokenizes each distinct positive category name and keeps only as
        many (in random order) as fit into ``self.max_tokens``.

        Returns:
            Indices of boxes whose category survived, the surviving class
            ids, and the token length consumed by positive categories.
        """
        # generate a caption by appending the positive labels
        positive_label_list = np.unique(class_ids).tolist()

        # random shuffule so we can sample different annotations
        # at different epochs
        random.shuffle(positive_label_list)

        kept_lables = []
        length = 0
        for _, label in enumerate(positive_label_list):
            label_text = clean_name(label_map[str(label)]) + ". "

            tokenized = self.tokenizer.tokenize(label_text)

            length += len(tokenized)

            if length > self.max_tokens:
                break
            else:
                kept_lables.append(label)

        # Keep only boxes whose label survived the budget check.
        keep_box_index = []
        keep_gt_labels = []
        for i, class_id in enumerate(class_ids):
            if class_id in kept_lables:
                keep_box_index.append(i)
                keep_gt_labels.append(class_id)

        return (
            keep_box_index,
            np.array(keep_gt_labels, dtype=np.int64),
            length,
        )
|
opendet3d/data/transforms/pad.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pad transformation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import TypedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from vis4d.common.typing import NDArrayF32
|
| 10 |
+
from vis4d.data.const import CommonKeys as K
|
| 11 |
+
from vis4d.data.transforms.base import Transform
|
| 12 |
+
from vis4d.data.transforms.pad import _get_max_shape
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PadParam(TypedDict):
    """Per-image padding amounts in pixels, as produced by center padding."""

    pad_top: int  # rows added above the image
    pad_bottom: int  # rows added below the image
    pad_left: int  # columns added left of the image
    pad_right: int  # columns added right of the image
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@Transform(
    [K.images, K.input_hw],
    [K.images, "transforms.pad", K.input_hw, "padding"],
)
class CenterPadImages:
    """Pad a batch of images equally on all four sides (center padding).

    The odd remainder of the required padding goes to the bottom/right.
    """

    def __init__(
        self,
        stride: int = 32,
        mode: str = "constant",
        value: float = 0.0,
        update_input_hw: bool = False,
        shape: tuple[int, int] | None = None,
        pad2square: bool = False,
    ) -> None:
        """Creates an instance of PadImage.

        Args:
            stride (int, optional): Chooses padding size so that the input will
                be divisible by stride. Defaults to 32.
            mode (str, optional): Padding mode. One of constant, reflect,
                replicate or circular. Defaults to "constant".
            value (float, optional): Value for constant padding.
                Defaults to 0.0.
            update_input_hw (bool, optional): Whether to overwrite the
                returned input_hw entries with the padded shape.
                Defaults to False.
            shape (tuple[int, int], optional): Shape of the padded image
                (H, W). Defaults to None.
            pad2square (bool, optional): Pad to square. Defaults to False.
        """
        self.stride = stride
        self.mode = mode
        self.value = value
        self.update_input_hw = update_input_hw
        self.shape = shape
        self.pad2square = pad2square

    def __call__(
        self, images: list[NDArrayF32], input_hw: list[tuple[int, int]]
    ) -> tuple[
        list[NDArrayF32],
        list[PadParam],
        list[tuple[int, int]],
        list[list[int]],
    ]:
        """Pad images to consistent size.

        Args:
            images: Per-sample image arrays; height is axis 1, width axis 2
                (i.e. layout N, H, W, C).
            input_hw: Per-sample (height, width) before padding.

        Returns:
            Padded images, per-image PadParam dicts, (possibly updated)
            input shapes, and [left, right, top, bottom] padding lists.
        """
        heights = [im.shape[1] for im in images]
        widths = [im.shape[2] for im in images]

        # Common target size for the whole batch (stride-aligned, or the
        # fixed/square shape if configured).
        max_hw = _get_max_shape(
            heights, widths, self.stride, self.shape, self.pad2square
        )

        # generate params for torch pad
        pad_params = []
        target_input_hw = []
        paddings = []
        for i, (image, h, w) in enumerate(zip(images, heights, widths)):
            # Split the total padding roughly in half per axis; the odd
            # pixel goes to bottom/right.
            pad_top, pad_bottom = (max_hw[0] - h) // 2, max_hw[0] - h - (
                max_hw[0] - h
            ) // 2

            pad_left, pad_right = (max_hw[1] - w) // 2, max_hw[1] - w - (
                max_hw[1] - w
            ) // 2

            # F.pad expects channels-first, so permute NHWC -> NCHW and back.
            image_ = torch.from_numpy(image).permute(0, 3, 1, 2)
            image_ = F.pad(
                image_,
                (pad_left, pad_right, pad_top, pad_bottom),
                self.mode,
                self.value,
            )
            images[i] = image_.permute(0, 2, 3, 1).numpy()

            pad_params.append(
                PadParam(
                    pad_top=pad_top,
                    pad_bottom=pad_bottom,
                    pad_left=pad_left,
                    pad_right=pad_right,
                )
            )

            paddings.append([pad_left, pad_right, pad_top, pad_bottom])

            target_input_hw.append(max_hw)

        if self.update_input_hw:
            input_hw = target_input_hw

        return images, pad_params, input_hw, paddings
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@Transform([K.intrinsics, "transforms.pad"], K.intrinsics)
class CenterPadIntrinsics:
    """Adjust camera intrinsics after center padding."""

    def __call__(
        self, intrinsics: list[NDArrayF32], pad_params: list[PadParam]
    ) -> list[NDArrayF32]:
        """Shift each principal point by the applied left/top padding.

        Args:
            intrinsics: Per-sample 3x3 camera matrices, modified in place.
            pad_params: Per-sample padding amounts.

        Returns:
            The updated intrinsics list.
        """
        for idx, cam_mat in enumerate(intrinsics):
            # Padding moves the image origin, so the principal point
            # (cx, cy) shifts by the left/top pad.
            cam_mat[0, 2] += pad_params[idx]["pad_left"]
            cam_mat[1, 2] += pad_params[idx]["pad_top"]
            intrinsics[idx] = cam_mat
        return intrinsics
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@Transform([K.boxes2d, "transforms.pad"], K.boxes2d)
class CenterPadBoxes2D:
    """Shift 2D boxes to account for center padding of the images."""

    def __call__(
        self, boxes_list: list[NDArrayF32], pad_params: list[PadParam]
    ) -> list[NDArrayF32]:
        """Translate boxes by the applied left/top padding.

        Args:
            boxes_list: Per-sample boxes in (x1, y1, x2, y2) layout,
                modified in place.
            pad_params: Per-sample padding amounts.

        Returns:
            The updated boxes list.
        """
        for idx, box_arr in enumerate(boxes_list):
            shift_x = pad_params[idx]["pad_left"]
            shift_y = pad_params[idx]["pad_top"]

            # Both corners move by the same translation.
            box_arr[:, 0] += shift_x
            box_arr[:, 2] += shift_x
            box_arr[:, 1] += shift_y
            box_arr[:, 3] += shift_y

            boxes_list[idx] = box_arr

        return boxes_list
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@Transform([K.depth_maps, "transforms.pad"], K.depth_maps)
class CenterPadDepthMaps:
    """Apply the image center padding to the corresponding depth maps."""

    def __init__(self, mode: str = "constant", value: int = 0) -> None:
        """Creates an instance.

        Args:
            mode: torch.nn.functional.pad padding mode.
            value: Fill value for constant padding.
        """
        self.mode = mode
        self.value = value

    def __call__(
        self, depth_maps: list[NDArrayF32], pad_params: list[PadParam]
    ) -> list[NDArrayF32]:
        """Pad each depth map with the per-sample padding amounts."""
        for idx, (depth_map, params) in enumerate(
            zip(depth_maps, pad_params)
        ):
            # F.pad ordering is (left, right, top, bottom) for 2D input.
            padding = (
                params["pad_left"],
                params["pad_right"],
                params["pad_top"],
                params["pad_bottom"],
            )

            # Add batch and channel dims for F.pad, then strip them again.
            padded = F.pad(
                torch.from_numpy(depth_map)[None, None],
                padding,
                self.mode,
                self.value,
            )
            depth_maps[idx] = padded[0, 0].numpy()

        return depth_maps
|
opendet3d/data/transforms/resize.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Resize transformation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from vis4d.common.typing import NDArrayF32, NDArrayI64
|
| 10 |
+
from vis4d.data.const import CommonKeys as K
|
| 11 |
+
from vis4d.data.transforms.base import Transform
|
| 12 |
+
from vis4d.data.transforms.resize import ResizeParam, resize_tensor
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@Transform(K.images, ["transforms.resize", K.input_hw])
class GenResizeParameters:
    """Generate the parameters for a resize operation.

    Computes an aspect-ratio-preserving target shape so the input fits the
    (optionally randomly scaled) desired shape.
    """

    def __init__(
        self, shape: tuple[int, int], scales: tuple[float, float] | float = 1.0
    ) -> None:
        """Create a new instance of the class.

        Args:
            shape: Desired (height, width) before random scaling.
            scales: Fixed scale, or (low, high) range sampled uniformly.
        """
        self.shape = shape
        self.scales = scales

    def __call__(
        self, images: list[NDArrayF32]
    ) -> tuple[list[ResizeParam], list[tuple[int, int]]]:
        """Compute the parameters and put them in the data dict.

        Parameters are computed from the first image only and replicated
        for the whole batch (presumably all images share one shape —
        TODO confirm against callers).
        """
        if isinstance(self.scales, float):
            random_scale = self.scales
        else:
            random_scale = np.random.uniform(self.scales[0], self.scales[1])

        # Randomly scaled desired shape; "- 0.5" with ceil rounds to nearest.
        shape = (
            math.ceil(self.shape[0] * random_scale - 0.5),
            math.ceil(self.shape[1] * random_scale - 0.5),
        )

        output_ratio = shape[1] / shape[0]

        image = images[0]

        # Images are laid out (N, H, W, C): axis 1 is height, axis 2 width.
        input_h, input_w = (image.shape[1], image.shape[2])
        input_ratio = input_w / input_h

        # Choose the scale that keeps the aspect ratio while matching the
        # constraining side of the desired shape.
        if output_ratio > input_ratio:
            scale = shape[0] / input_h
        else:
            scale = shape[1] / input_w

        target_shape = (
            math.ceil(input_h * scale - 0.5),
            math.ceil(input_w * scale - 0.5),
        )

        # Effective per-axis scale after rounding.
        scale_factor = (target_shape[0] / input_h, target_shape[1] / input_w)

        resize_params = [
            ResizeParam(target_shape=target_shape, scale_factor=scale_factor)
        ] * len(images)
        target_shapes = [target_shape] * len(images)

        return resize_params, target_shapes
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@Transform(
    [K.panoptic_masks, "transforms.resize.target_shape"], K.panoptic_masks
)
class ResizePanopticMasks:
    """Resize panoptic segmentation masks."""

    def __call__(
        self,
        masks_list: list[NDArrayI64],
        target_shape_list: list[tuple[int, int]],
    ) -> list[NDArrayI64]:
        """Resize each mask to its target shape with nearest interpolation.

        Nearest-neighbor keeps the integer segment ids intact.
        """
        for idx, (mask, new_shape) in enumerate(
            zip(masks_list, target_shape_list)
        ):
            mask_tensor = torch.from_numpy(mask)

            # resize_tensor expects a float NCHW tensor; add the batch and
            # channel dims, resize, then restore dtype and dims.
            resized = resize_tensor(
                mask_tensor.float()[None, None],
                new_shape,
                interpolation="nearest",
            )

            masks_list[idx] = resized.type(mask_tensor.dtype)[0, 0].numpy()

        return masks_list
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@Transform([K.boxes3d, "transforms.resize.scale_factor"], K.boxes3d)
class ResizeBoxes3D:
    """Adjust 3D bounding boxes for an image resize.

    Only the third box coordinate (column 2) is divided by the vertical
    scale factor; presumably this is the camera-frame depth being rescaled
    for virtual-depth-style training — TODO confirm box3d layout and the
    intended direction of the adjustment.
    """

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        scale_factors: list[tuple[float, float]],
    ) -> list[NDArrayF32]:
        """Adjust 3D bounding boxes for resized images.

        Args:
            boxes_list: (list[NDArrayF32]): The 3D bounding boxes to be
                adjusted in place.
            scale_factors (list[tuple[float, float]]): Per-sample
                (height, width) scaling factors from the resize.

        Returns:
            list[NDArrayF32]: Adjusted 3D bounding boxes.
        """
        for i, (boxes, scale_factor) in enumerate(
            zip(boxes_list, scale_factors)
        ):
            # Divide column 2 by the height scale factor; all other box
            # parameters are left unchanged.
            boxes[:, 2] /= scale_factor[0]
            boxes_list[i] = boxes
        return boxes_list
|
opendet3d/eval/__init__.py
ADDED
|
File without changes
|
opendet3d/eval/detect3d.py
ADDED
|
@@ -0,0 +1,1249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""3D Multiple Object Detection Evaluator."""
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import copy
|
| 5 |
+
import datetime
|
| 6 |
+
import io
|
| 7 |
+
import itertools
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pycocotools.mask as maskUtils
|
| 15 |
+
import torch
|
| 16 |
+
from pycocotools.cocoeval import COCOeval
|
| 17 |
+
from scipy.spatial.distance import cdist
|
| 18 |
+
from terminaltables import AsciiTable
|
| 19 |
+
from vis4d.common.array import array_to_numpy
|
| 20 |
+
from vis4d.common.distributed import all_gather_object_cpu
|
| 21 |
+
from vis4d.common.typing import (
|
| 22 |
+
ArrayLike,
|
| 23 |
+
DictStrAny,
|
| 24 |
+
GenericFunc,
|
| 25 |
+
MetricLogs,
|
| 26 |
+
NDArrayF32,
|
| 27 |
+
NDArrayI64,
|
| 28 |
+
)
|
| 29 |
+
from vis4d.data.const import AxisMode
|
| 30 |
+
from vis4d.eval.base import Evaluator
|
| 31 |
+
from vis4d.eval.coco.detect import xyxy_to_xywh
|
| 32 |
+
from vis4d.op.box.box3d import boxes3d_to_corners
|
| 33 |
+
from vis4d.op.geometry.rotation import quaternion_to_matrix
|
| 34 |
+
|
| 35 |
+
from opendet3d.data.datasets.coco3d import COCO3D
|
| 36 |
+
from opendet3d.op.box.box3d import box3d_overlap
|
| 37 |
+
from opendet3d.op.geometric.rotation import so3_relative_angle
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Detect3DEvaluator(Evaluator):
|
| 41 |
+
"""3D object detection evaluation with COCO format."""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
det_map: dict[str, int],
|
| 46 |
+
cat_map: dict[str, int],
|
| 47 |
+
annotation: str,
|
| 48 |
+
id2name: dict[int, str] | None = None,
|
| 49 |
+
per_class_eval: bool = True,
|
| 50 |
+
eval_prox: bool = False,
|
| 51 |
+
iou_type: str = "bbox",
|
| 52 |
+
num_columns: int = 6,
|
| 53 |
+
base_classes: list[str] | None = None,
|
| 54 |
+
) -> None:
|
| 55 |
+
"""Create an instance of the class."""
|
| 56 |
+
if id2name is None:
|
| 57 |
+
self.id2name = {v: k for k, v in det_map.items()}
|
| 58 |
+
else:
|
| 59 |
+
self.id2name = id2name
|
| 60 |
+
|
| 61 |
+
self.annotation = annotation
|
| 62 |
+
self.per_class_eval = per_class_eval
|
| 63 |
+
self.eval_prox = eval_prox
|
| 64 |
+
self.iou_type = iou_type
|
| 65 |
+
self.num_columns = num_columns
|
| 66 |
+
self.base_classes = base_classes
|
| 67 |
+
|
| 68 |
+
self.tp_errors = ["ATE", "AOE", "ASE"]
|
| 69 |
+
|
| 70 |
+
category_names = sorted(det_map, key=det_map.get)
|
| 71 |
+
|
| 72 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
| 73 |
+
self._coco_gt = COCO3D([annotation], category_names)
|
| 74 |
+
|
| 75 |
+
self.cat_map = cat_map
|
| 76 |
+
|
| 77 |
+
self.bbox_2D_evals_per_cat_area: DictStrAny = {}
|
| 78 |
+
self.bbox_3D_evals_per_cat_area: DictStrAny = {}
|
| 79 |
+
self._predictions: list[DictStrAny] = []
|
| 80 |
+
|
| 81 |
+
def __repr__(self) -> str:
|
| 82 |
+
"""Returns the string representation of the object."""
|
| 83 |
+
return f"3D Object Detection Evaluator with {self.annotation}"
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def metrics(self) -> list[str]:
|
| 87 |
+
"""Supported metrics.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
list[str]: Metrics to evaluate.
|
| 91 |
+
"""
|
| 92 |
+
return ["2D", "3D"]
|
| 93 |
+
|
| 94 |
+
def gather(self, gather_func: GenericFunc) -> None:
|
| 95 |
+
"""Accumulate predictions across processes."""
|
| 96 |
+
all_preds = all_gather_object_cpu(
|
| 97 |
+
self._predictions, use_system_tmp=False
|
| 98 |
+
)
|
| 99 |
+
if all_preds is not None:
|
| 100 |
+
self._predictions = list(itertools.chain(*all_preds))
|
| 101 |
+
|
| 102 |
+
def reset(self) -> None:
|
| 103 |
+
"""Reset the saved predictions to start new round of evaluation."""
|
| 104 |
+
self._predictions.clear()
|
| 105 |
+
self.bbox_2D_evals_per_cat_area.clear()
|
| 106 |
+
self.bbox_3D_evals_per_cat_area.clear()
|
| 107 |
+
|
| 108 |
+
def process_batch(
|
| 109 |
+
self,
|
| 110 |
+
coco_image_id: list[int],
|
| 111 |
+
pred_boxes: list[ArrayLike],
|
| 112 |
+
pred_scores: list[ArrayLike],
|
| 113 |
+
pred_classes: list[ArrayLike],
|
| 114 |
+
pred_boxes3d: list[ArrayLike] | None = None,
|
| 115 |
+
) -> None:
|
| 116 |
+
"""Process sample and convert detections to coco format."""
|
| 117 |
+
for i, image_id in enumerate(coco_image_id):
|
| 118 |
+
boxes = array_to_numpy(
|
| 119 |
+
pred_boxes[i].to(torch.float32), n_dims=None, dtype=np.float32
|
| 120 |
+
)
|
| 121 |
+
scores = array_to_numpy(
|
| 122 |
+
pred_scores[i].to(torch.float32), n_dims=None, dtype=np.float32
|
| 123 |
+
)
|
| 124 |
+
classes = array_to_numpy(
|
| 125 |
+
pred_classes[i], n_dims=None, dtype=np.int64
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
if pred_boxes3d is not None:
|
| 129 |
+
boxes3d = array_to_numpy(
|
| 130 |
+
pred_boxes3d[i].to(torch.float32),
|
| 131 |
+
n_dims=None,
|
| 132 |
+
dtype=np.float32,
|
| 133 |
+
)
|
| 134 |
+
else:
|
| 135 |
+
boxes3d = None
|
| 136 |
+
|
| 137 |
+
self._predictions_to_coco(
|
| 138 |
+
image_id, boxes, boxes3d, scores, classes
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
    def _predictions_to_coco(
        self,
        img_id: int,
        boxes: NDArrayF32,
        boxes3d: NDArrayF32 | None,
        scores: NDArrayF32,
        classes: NDArrayI64,
    ) -> None:
        """Convert predictions to COCO format.

        Appends one COCO-style result dict per detection to
        ``self._predictions``. When ``boxes3d`` is given, Omni3D-style 3D
        fields (camera center, dimensions, rotation matrix, corner list and
        depth) are added to each record.

        Args:
            img_id: COCO image id the detections belong to.
            boxes: 2D boxes in xyxy format.
            boxes3d: Optional 3D boxes; indices [:3] are the camera-space
                center, [3:6] dimensions, [6:10] a rotation quaternion
                (inferred from the slicing below — TODO confirm layout).
            scores: Detection confidences.
            boxes3d: Optional 3D box parameters, or None.
            classes: Detection class indices.
        """
        # COCO expects xywh boxes; keep the input untouched via a deep copy.
        boxes_xyxy = copy.deepcopy(boxes)
        boxes_xywh = xyxy_to_xywh(boxes_xyxy)

        if boxes3d is not None:
            # FIXME: Make axismode configurable
            corners_3d = boxes3d_to_corners(
                torch.from_numpy(boxes3d), AxisMode.OPENCV
            )

        for i, (box, box_score, box_class) in enumerate(
            zip(boxes_xywh, scores, classes)
        ):
            xywh = box.tolist()

            # Translate the internal class index to the dataset category id.
            result = {
                "image_id": img_id,
                "bbox": xywh,
                "category_id": self.cat_map[self.id2name[box_class.item()]],
                "score": box_score.item(),
            }

            # mapping to Omni3D format
            if boxes3d is not None:
                result["center_cam"] = boxes3d[i][:3].tolist()

                # wlh to whl
                result["dimensions"] = boxes3d[i][[3, 5, 4]].tolist()

                # Rotation as a 3x3 matrix derived from the quaternion part.
                result["R_cam"] = (
                    quaternion_to_matrix(torch.from_numpy(boxes3d[i][6:10]))
                    .numpy()
                    .tolist()
                )

                corners = corners_3d[i].numpy().tolist()

                # Reorder corners into the ordering expected downstream
                # (presumably the Omni3D corner convention — TODO confirm).
                result["bbox3D"] = [
                    corners[6],
                    corners[4],
                    corners[0],
                    corners[2],
                    corners[7],
                    corners[5],
                    corners[1],
                    corners[3],
                ]

                # Depth is the z component of the camera-space center.
                result["depth"] = boxes3d[i][2].item()

            self._predictions.append(result)
|
| 200 |
+
|
| 201 |
+
def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
|
| 202 |
+
"""Evaluate predictions."""
|
| 203 |
+
if metric == "2D":
|
| 204 |
+
metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"]
|
| 205 |
+
else:
|
| 206 |
+
if self.iou_type == "bbox":
|
| 207 |
+
metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"]
|
| 208 |
+
main_metric = "AP"
|
| 209 |
+
else:
|
| 210 |
+
metrics = ["AP", "ATE", "ASE", "AOE", "ODS"]
|
| 211 |
+
main_metric = "ODS"
|
| 212 |
+
|
| 213 |
+
if self.base_classes is not None:
|
| 214 |
+
metrics += [f"{main_metric}_Base", f"{main_metric}_Novel"]
|
| 215 |
+
|
| 216 |
+
if len(self._predictions) == 0:
|
| 217 |
+
return {m: 0.0 for m in metrics}, "No predictions to evaluate."
|
| 218 |
+
|
| 219 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
| 220 |
+
coco_dt = self._coco_gt.loadRes(self._predictions)
|
| 221 |
+
|
| 222 |
+
assert coco_dt is not None
|
| 223 |
+
evaluator = Detect3Deval(
|
| 224 |
+
self._coco_gt,
|
| 225 |
+
coco_dt,
|
| 226 |
+
mode=metric,
|
| 227 |
+
eval_prox=self.eval_prox,
|
| 228 |
+
iou_type=self.iou_type,
|
| 229 |
+
)
|
| 230 |
+
evaluator.evaluate()
|
| 231 |
+
evaluator.accumulate()
|
| 232 |
+
|
| 233 |
+
if self.iou_type == "bbox":
|
| 234 |
+
log_str = "\n" + evaluator.summarize()
|
| 235 |
+
|
| 236 |
+
# precision: (iou, recall, cls, area range, max dets)
|
| 237 |
+
precisions = evaluator.eval["precision"]
|
| 238 |
+
assert len(self._coco_gt.getCatIds()) == precisions.shape[2]
|
| 239 |
+
|
| 240 |
+
if metric == "2D":
|
| 241 |
+
self.bbox_2D_evals_per_cat_area = evaluator.evals_per_cat_area
|
| 242 |
+
|
| 243 |
+
score_dict = dict(zip(metrics, evaluator.stats))
|
| 244 |
+
else:
|
| 245 |
+
if self.iou_type == "bbox":
|
| 246 |
+
self.bbox_3D_evals_per_cat_area = evaluator.evals_per_cat_area
|
| 247 |
+
|
| 248 |
+
score_dict = dict(zip(metrics, evaluator.stats))
|
| 249 |
+
else:
|
| 250 |
+
trans_tp_errors = evaluator.eval["trans_tp_errors"]
|
| 251 |
+
rot_tp_errors = evaluator.eval["rot_tp_errors"]
|
| 252 |
+
scale_tp_errors = evaluator.eval["scale_tp_errors"]
|
| 253 |
+
|
| 254 |
+
precision = precisions[:, :, :, 0, -1]
|
| 255 |
+
precision = precision[precision > -1]
|
| 256 |
+
if precision.size:
|
| 257 |
+
mAP = np.mean(precision).item()
|
| 258 |
+
else:
|
| 259 |
+
mAP = float("nan")
|
| 260 |
+
|
| 261 |
+
trans_tp = trans_tp_errors[:, :, :, 0, -1]
|
| 262 |
+
trans_tp = trans_tp[trans_tp > -1]
|
| 263 |
+
|
| 264 |
+
rot_tp = rot_tp_errors[:, :, :, 0, -1]
|
| 265 |
+
rot_tp = rot_tp[rot_tp > -1]
|
| 266 |
+
|
| 267 |
+
scale_tp = scale_tp_errors[:, :, :, 0, -1]
|
| 268 |
+
scale_tp = scale_tp[scale_tp > -1]
|
| 269 |
+
|
| 270 |
+
if trans_tp.size:
|
| 271 |
+
mATE = np.mean(trans_tp).item()
|
| 272 |
+
mAOE = np.mean(rot_tp).item()
|
| 273 |
+
mASE = np.mean(scale_tp).item()
|
| 274 |
+
|
| 275 |
+
mODS = (
|
| 276 |
+
np.sum(mAP * 3 + (1 - mATE) + (1 - mAOE) + (1 - mASE))
|
| 277 |
+
/ 6
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
else:
|
| 281 |
+
mATE = float("nan")
|
| 282 |
+
mAOE = float("nan")
|
| 283 |
+
mASE = float("nan")
|
| 284 |
+
mODS = float("nan")
|
| 285 |
+
|
| 286 |
+
score_dict = {
|
| 287 |
+
"AP": mAP,
|
| 288 |
+
"ATE": mATE,
|
| 289 |
+
"ASE": mASE,
|
| 290 |
+
"AOE": mAOE,
|
| 291 |
+
"ODS": mODS,
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
log_str = "\nHigh-level metrics:"
|
| 295 |
+
for k, v in score_dict.items():
|
| 296 |
+
log_str += f"\n{k}: {v:.4f}"
|
| 297 |
+
|
| 298 |
+
if self.per_class_eval:
|
| 299 |
+
results_per_category = []
|
| 300 |
+
score_base_list = []
|
| 301 |
+
score_novel_list = []
|
| 302 |
+
|
| 303 |
+
for idx, cat_id in enumerate(self._coco_gt.getCatIds()):
|
| 304 |
+
# area range index 0: all area ranges
|
| 305 |
+
# max dets index -1: typically 100 per image
|
| 306 |
+
nm = self._coco_gt.loadCats(cat_id)[0]
|
| 307 |
+
precision = precisions[:, :, idx, 0, -1]
|
| 308 |
+
precision = precision[precision > -1]
|
| 309 |
+
if precision.size:
|
| 310 |
+
ap = np.mean(precision).item()
|
| 311 |
+
else:
|
| 312 |
+
ap = float("nan")
|
| 313 |
+
|
| 314 |
+
if self.iou_type == "dist":
|
| 315 |
+
trans_tp = trans_tp_errors[:, :, idx, 0, -1]
|
| 316 |
+
trans_tp = trans_tp[trans_tp > -1]
|
| 317 |
+
|
| 318 |
+
rot_tp = rot_tp_errors[:, :, idx, 0, -1]
|
| 319 |
+
rot_tp = rot_tp[rot_tp > -1]
|
| 320 |
+
|
| 321 |
+
scale_tp = scale_tp_errors[:, :, idx, 0, -1]
|
| 322 |
+
scale_tp = scale_tp[scale_tp > -1]
|
| 323 |
+
|
| 324 |
+
if trans_tp.size:
|
| 325 |
+
ate = np.mean(trans_tp).item()
|
| 326 |
+
aoe = np.mean(rot_tp).item()
|
| 327 |
+
ase = np.mean(scale_tp).item()
|
| 328 |
+
|
| 329 |
+
ods = (
|
| 330 |
+
np.sum(ap * 3 + (1 - ate) + (1 - aoe) + (1 - ase))
|
| 331 |
+
/ 6
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
else:
|
| 335 |
+
ate = float("nan")
|
| 336 |
+
aoe = float("nan")
|
| 337 |
+
ase = float("nan")
|
| 338 |
+
ods = float("nan")
|
| 339 |
+
|
| 340 |
+
results_per_category.append(
|
| 341 |
+
(
|
| 342 |
+
f'{nm["name"]}',
|
| 343 |
+
f"{ap:0.3f}",
|
| 344 |
+
f"{ate:0.3f}",
|
| 345 |
+
f"{ase:0.3f}",
|
| 346 |
+
f"{aoe:0.3f}",
|
| 347 |
+
f"{ods:0.3f}",
|
| 348 |
+
)
|
| 349 |
+
)
|
| 350 |
+
else:
|
| 351 |
+
results_per_category.append(
|
| 352 |
+
(f'{nm["name"]}', f"{ap:0.3f}")
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
if self.base_classes is not None:
|
| 356 |
+
if self.iou_type == "dist":
|
| 357 |
+
score = ods
|
| 358 |
+
else:
|
| 359 |
+
score = ap
|
| 360 |
+
|
| 361 |
+
if nm["name"] in self.base_classes:
|
| 362 |
+
score_base_list.append(score)
|
| 363 |
+
else:
|
| 364 |
+
score_novel_list.append(score)
|
| 365 |
+
|
| 366 |
+
results_flatten = list(itertools.chain(*results_per_category))
|
| 367 |
+
|
| 368 |
+
if self.iou_type == "dist":
|
| 369 |
+
num_columns = 6
|
| 370 |
+
headers = ["category", "AP", "ATE", "ASE", "AOE", "ODS"]
|
| 371 |
+
else:
|
| 372 |
+
num_columns = min(
|
| 373 |
+
self.num_columns, len(results_per_category) * 2
|
| 374 |
+
)
|
| 375 |
+
headers = ["category", "AP"] * (num_columns // 2)
|
| 376 |
+
results = itertools.zip_longest(
|
| 377 |
+
*[results_flatten[i::num_columns] for i in range(num_columns)]
|
| 378 |
+
)
|
| 379 |
+
table_data = [headers] + list(results)
|
| 380 |
+
table = AsciiTable(table_data)
|
| 381 |
+
log_str = f"\n{table.table}\n{log_str}"
|
| 382 |
+
|
| 383 |
+
if self.base_classes is not None:
|
| 384 |
+
score_dict[f"{main_metric}_Base"] = np.mean(score_base_list).item()
|
| 385 |
+
score_dict[f"{main_metric}_Novel"] = np.mean(
|
| 386 |
+
score_novel_list
|
| 387 |
+
).item()
|
| 388 |
+
|
| 389 |
+
return score_dict, log_str
|
| 390 |
+
|
| 391 |
+
def save(
|
| 392 |
+
self, metric: str, output_dir: str, prefix: str | None = None
|
| 393 |
+
) -> None:
|
| 394 |
+
"""Save the results to json files."""
|
| 395 |
+
assert metric in self.metrics
|
| 396 |
+
|
| 397 |
+
if prefix is not None:
|
| 398 |
+
result_folder = os.path.join(output_dir, prefix)
|
| 399 |
+
os.makedirs(result_folder, exist_ok=True)
|
| 400 |
+
else:
|
| 401 |
+
result_folder = output_dir
|
| 402 |
+
|
| 403 |
+
result_file = os.path.join(
|
| 404 |
+
result_folder, f"detect_{metric}_results.json"
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
with open(result_file, mode="w", encoding="utf-8") as f:
|
| 408 |
+
json.dump(self._predictions, f)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class Detect3Deval(COCOeval):
    """COCOeval wrapper for 2D and 3D box evaluation.

    Matching is either IoU based (``iou_type="bbox"``) or center-distance
    based (``iou_type="dist"``); the latter additionally records
    translation, orientation and scale true-positive errors.
    """
|
| 416 |
+
|
| 417 |
+
def __init__(
|
| 418 |
+
self,
|
| 419 |
+
cocoGt=None,
|
| 420 |
+
cocoDt=None,
|
| 421 |
+
mode: str = "2D",
|
| 422 |
+
iou_type: str = "bbox",
|
| 423 |
+
eval_prox: bool = False,
|
| 424 |
+
):
|
| 425 |
+
"""Initialize Detect3Deval using coco APIs for Gt and Dt.
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
cocoGt: COCO object with ground truth annotations
|
| 429 |
+
cocoDt: COCO object with detection results
|
| 430 |
+
mode: (str) defines whether to evaluate 2D or 3D performance.
|
| 431 |
+
One of {"2D", "3D"}
|
| 432 |
+
eval_prox: (bool) if True, performs "Proximity Evaluation", i.e.
|
| 433 |
+
evaluates detections in the proximity of the ground truth2D
|
| 434 |
+
boxes. This is used for datasets which are not exhaustively
|
| 435 |
+
annotated.
|
| 436 |
+
"""
|
| 437 |
+
if mode not in {"2D", "3D"}:
|
| 438 |
+
raise Exception(f"{mode} mode is not supported")
|
| 439 |
+
self.mode = mode
|
| 440 |
+
self.iou_type = iou_type
|
| 441 |
+
self.eval_prox = eval_prox
|
| 442 |
+
|
| 443 |
+
self.cocoGt = cocoGt # ground truth COCO API
|
| 444 |
+
self.cocoDt = cocoDt # detections COCO API
|
| 445 |
+
|
| 446 |
+
# per-image per-category evaluation results [KxAxI] elements
|
| 447 |
+
self.evalImgs = defaultdict(list)
|
| 448 |
+
|
| 449 |
+
self.eval = {} # accumulated evaluation results
|
| 450 |
+
self._gts = defaultdict(list) # gt for evaluation
|
| 451 |
+
self._dts = defaultdict(list) # dt for evaluation
|
| 452 |
+
self.params = Detect3DParams(mode=mode, iouType=iou_type) # parameters
|
| 453 |
+
self._paramsEval = {} # parameters for evaluation
|
| 454 |
+
self.stats = [] # result summarization
|
| 455 |
+
self.ious = {} # ious between all gts and dts
|
| 456 |
+
|
| 457 |
+
if cocoGt is not None:
|
| 458 |
+
self.params.imgIds = sorted(cocoGt.getImgIds())
|
| 459 |
+
self.params.catIds = sorted(cocoGt.getCatIds())
|
| 460 |
+
|
| 461 |
+
self.evals_per_cat_area = None
|
| 462 |
+
|
| 463 |
+
def _prepare(self) -> None:
|
| 464 |
+
"""Prepare ._gts and ._dts for evaluation based on params."""
|
| 465 |
+
p = self.params
|
| 466 |
+
|
| 467 |
+
if p.useCats:
|
| 468 |
+
gts = self.cocoGt.loadAnns(
|
| 469 |
+
self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
|
| 470 |
+
)
|
| 471 |
+
dts = self.cocoDt.loadAnns(
|
| 472 |
+
self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds)
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
else:
|
| 476 |
+
gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
|
| 477 |
+
dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
|
| 478 |
+
|
| 479 |
+
# set ignore flag
|
| 480 |
+
ignore_flag = "ignore2D" if self.mode == "2D" else "ignore3D"
|
| 481 |
+
for gt in gts:
|
| 482 |
+
gt[ignore_flag] = gt[ignore_flag] if ignore_flag in gt else 0
|
| 483 |
+
|
| 484 |
+
self._gts = defaultdict(list) # gt for evaluation
|
| 485 |
+
self._dts = defaultdict(list) # dt for evaluation
|
| 486 |
+
|
| 487 |
+
for gt in gts:
|
| 488 |
+
self._gts[gt["image_id"], gt["category_id"]].append(gt)
|
| 489 |
+
|
| 490 |
+
for dt in dts:
|
| 491 |
+
self._dts[dt["image_id"], dt["category_id"]].append(dt)
|
| 492 |
+
|
| 493 |
+
self.evalImgs = defaultdict(
|
| 494 |
+
list
|
| 495 |
+
) # per-image per-category evaluation results
|
| 496 |
+
self.eval = {} # accumulated evaluation results
|
| 497 |
+
|
| 498 |
+
    def accumulate(self, p=None) -> None:
        """Accumulate per image evaluation and store the result in self.eval.

        Builds the precision/recall/score tensors of shape [T, R, K, A, M]
        (iou thresholds x recall thresholds x categories x area ranges x
        max-dets), plus translation/orientation/scale TP-error tensors when
        ``iou_type == "dist"``.

        Args:
            p: input params for evaluation
        """
        print("Accumulating evaluation results...")
        assert self.evalImgs, "Please run evaluate() first"

        tic = time.time()

        # allows input customized parameters
        if p is None:
            p = self.params

        p.catIds = p.catIds if p.useCats == 1 else [-1]

        T = len(p.iouThrs)
        R = len(p.recThrs)
        K = len(p.catIds) if p.useCats else 1
        A = len(p.areaRng)
        M = len(p.maxDets)

        precision = -np.ones(
            (T, R, K, A, M)
        )  # -1 for the precision of absent categories
        trans_tp_errors = -np.ones((T, R, K, A, M))
        rot_tp_errors = -np.ones((T, R, K, A, M))
        scale_tp_errors = -np.ones((T, R, K, A, M))
        recall = -np.ones((T, K, A, M))
        scores = -np.ones((T, R, K, A, M))

        # create dictionary for future indexing
        _pe = self._paramsEval

        catIds = _pe.catIds if _pe.useCats else [-1]
        setK = set(catIds)
        setA = set(map(tuple, _pe.areaRng))
        setM = set(_pe.maxDets)
        setI = set(_pe.imgIds)

        # get inds to evaluate
        catid_list = [k for n, k in enumerate(p.catIds) if k in setK]
        k_list = [n for n, k in enumerate(p.catIds) if k in setK]
        m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
        a_list = [
            n
            for n, a in enumerate(map(lambda x: tuple(x), p.areaRng))
            if a in setA
        ]
        i_list = [n for n, i in enumerate(p.imgIds) if i in setI]

        I0 = len(_pe.imgIds)
        A0 = len(_pe.areaRng)

        # Reuse cached per-(category, area) eval lists if available.
        has_precomputed_evals = not (self.evals_per_cat_area is None)

        if has_precomputed_evals:
            evals_per_cat_area = self.evals_per_cat_area
        else:
            evals_per_cat_area = {}

        # retrieve E at each category, area range, and max number of detections
        for k, (k0, catId) in enumerate(zip(k_list, catid_list)):
            # evalImgs is a flat list laid out as [K][A][I]; Nk/Na are the
            # strides for the category and area-range dimensions.
            Nk = k0 * A0 * I0
            for a, a0 in enumerate(a_list):
                Na = a0 * I0

                if has_precomputed_evals:
                    E = evals_per_cat_area[(catId, a)]

                else:
                    E = [self.evalImgs[Nk + Na + i] for i in i_list]
                    E = [e for e in E if not e is None]
                    evals_per_cat_area[(catId, a)] = E

                if len(E) == 0:
                    continue

                for m, maxDet in enumerate(m_list):

                    dtScores = np.concatenate(
                        [e["dtScores"][0:maxDet] for e in E]
                    )

                    # different sorting method generates slightly different results.
                    # mergesort is used to be consistent as Matlab implementation.
                    inds = np.argsort(-dtScores, kind="mergesort")
                    dtScoresSorted = dtScores[inds]

                    dtm = np.concatenate(
                        [e["dtMatches"][:, 0:maxDet] for e in E], axis=1
                    )[:, inds]
                    dtIg = np.concatenate(
                        [e["dtIgnore"][:, 0:maxDet] for e in E], axis=1
                    )[:, inds]
                    gtIg = np.concatenate([e["gtIgnore"] for e in E])
                    npig = np.count_nonzero(gtIg == 0)

                    if npig == 0:
                        continue

                    tps = np.logical_and(dtm, np.logical_not(dtIg))
                    fps = np.logical_and(
                        np.logical_not(dtm), np.logical_not(dtIg)
                    )

                    tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
                    fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)

                    # Compute TP error
                    if self.iou_type == "dist":
                        tems = np.concatenate(
                            [e["dtTranslationError"][:, 0:maxDet] for e in E],
                            axis=1,
                        )[:, inds]

                        oems = np.concatenate(
                            [e["dtOrientationError"][:, 0:maxDet] for e in E],
                            axis=1,
                        )[:, inds]

                        sems = np.concatenate(
                            [e["dtScaleError"][:, 0:maxDet] for e in E], axis=1
                        )[:, inds]

                    for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
                        tp = np.array(tp)
                        fp = np.array(fp)
                        nd = len(tp)
                        rc = tp / npig
                        # np.spacing(1) avoids division by zero.
                        pr = tp / (fp + tp + np.spacing(1))

                        q = np.zeros((R,))
                        ss = np.zeros((R,))
                        tran_tp_error = np.ones((R,))
                        rot_tp_error = np.ones((R,))
                        scale_tp_error = np.ones((R,))

                        if nd:
                            recall[t, k, a, m] = rc[-1]

                        else:
                            recall[t, k, a, m] = 0

                        # numpy is slow without cython optimization for accessing elements
                        # use python array gets significant speed improvement
                        pr = pr.tolist()
                        q = q.tolist()
                        tran_tp_error = tran_tp_error.tolist()
                        rot_tp_error = rot_tp_error.tolist()
                        scale_tp_error = scale_tp_error.tolist()

                        # Make precision monotonically non-increasing
                        # (standard COCO interpolation).
                        for i in range(nd - 1, 0, -1):
                            if pr[i] > pr[i - 1]:
                                pr[i - 1] = pr[i]

                        inds = np.searchsorted(rc, p.recThrs, side="left")

                        # NOTE(review): bare except inherited from COCOeval —
                        # presumably swallows the IndexError when a recall
                        # threshold exceeds the achieved recall (pi == nd).
                        try:
                            for ri, pi in enumerate(inds):
                                q[ri] = pr[pi]
                                ss[ri] = dtScoresSorted[pi]
                                if self.iou_type == "dist":
                                    tran_tp_error[ri] = tems[t][pi]
                                    rot_tp_error[ri] = oems[t][pi]
                                    scale_tp_error[ri] = sems[t][pi]
                        except:
                            pass

                        precision[t, :, k, a, m] = np.array(q)
                        scores[t, :, k, a, m] = np.array(ss)

                        if self.iou_type == "dist":
                            trans_tp_errors[t, :, k, a, m] = np.array(
                                tran_tp_error
                            )
                            rot_tp_errors[t, :, k, a, m] = np.array(
                                rot_tp_error
                            )
                            scale_tp_errors[t, :, k, a, m] = np.array(
                                scale_tp_error
                            )

        self.evals_per_cat_area = evals_per_cat_area

        self.eval = {
            "params": p,
            "counts": [T, R, K, A, M],
            "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "precision": precision,
            "recall": recall,
            "scores": scores,
            "trans_tp_errors": trans_tp_errors,
            "rot_tp_errors": rot_tp_errors,
            "scale_tp_errors": scale_tp_errors,
        }

        toc = time.time()
        print("DONE (t={:0.2f}s).".format(toc - tic))
|
| 698 |
+
|
| 699 |
+
def evaluate(self) -> None:
|
| 700 |
+
"""Run per image evaluation on given images.
|
| 701 |
+
|
| 702 |
+
It will store results (a list of dict) in self.evalImgs
|
| 703 |
+
"""
|
| 704 |
+
print("Running per image evaluation...")
|
| 705 |
+
|
| 706 |
+
p = self.params
|
| 707 |
+
print(f"Evaluate annotation type *{p.iouType}*")
|
| 708 |
+
|
| 709 |
+
tic = time.time()
|
| 710 |
+
|
| 711 |
+
p.imgIds = list(np.unique(p.imgIds))
|
| 712 |
+
if p.useCats:
|
| 713 |
+
p.catIds = list(np.unique(p.catIds))
|
| 714 |
+
|
| 715 |
+
p.maxDets = sorted(p.maxDets)
|
| 716 |
+
self.params = p
|
| 717 |
+
|
| 718 |
+
self._prepare()
|
| 719 |
+
|
| 720 |
+
catIds = p.catIds if p.useCats else [-1]
|
| 721 |
+
|
| 722 |
+
# loop through images, area range, max detection number
|
| 723 |
+
self.ious = {
|
| 724 |
+
(imgId, catId): self.computeIoU(imgId, catId)
|
| 725 |
+
for imgId in p.imgIds
|
| 726 |
+
for catId in catIds
|
| 727 |
+
}
|
| 728 |
+
|
| 729 |
+
maxDet = p.maxDets[-1]
|
| 730 |
+
|
| 731 |
+
self.evalImgs = [
|
| 732 |
+
self.evaluateImg(imgId, catId, areaRng, maxDet)
|
| 733 |
+
for catId in catIds
|
| 734 |
+
for areaRng in p.areaRng
|
| 735 |
+
for imgId in p.imgIds
|
| 736 |
+
]
|
| 737 |
+
|
| 738 |
+
self._paramsEval = copy.deepcopy(self.params)
|
| 739 |
+
|
| 740 |
+
toc = time.time()
|
| 741 |
+
print("DONE (t={:0.2f}s).".format(toc - tic))
|
| 742 |
+
|
| 743 |
+
    def computeIoU(self, imgId, catId) -> tuple[NDArrayF32, NDArrayF32]:
        """Computes the IoUs by sorting based on score.

        For 2D mode, IoU is computed with pycocotools; for 3D mode either a
        3D box overlap (``iouType == "bbox"``) or a Euclidean center
        distance matrix (``iouType == "dist"``) is produced.

        Returns:
            (ious, in_prox) where ``in_prox`` is a boolean matrix of
            dt/gt pairs whose 2D IoU exceeds ``p.proximity_thresh`` (only
            populated when ``eval_prox`` is on, else None).

        NOTE(review): when there are neither gts nor dts an empty list is
        returned instead of a 2-tuple; callers guard this case by returning
        early in evaluateImg — confirm before relying on the return type.
        """
        p = self.params

        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]
        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]

        if len(gt) == 0 and len(dt) == 0:
            return []

        # Sort detections by descending score; cap at the largest maxDets.
        inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
        dt = [dt[i] for i in inds]
        if len(dt) > p.maxDets[-1]:
            dt = dt[0 : p.maxDets[-1]]

        if self.mode == "2D":
            g = [g["bbox"] for g in gt]
            d = [d["bbox"] for d in dt]
        elif self.mode == "3D":
            g = [g["bbox3D"] for g in gt]
            d = [d["bbox3D"] for d in dt]

        # compute iou between each dt and gt region
        # iscrowd is required in builtin maskUtils so we
        # use a dummy buffer for it
        iscrowd = [0 for _ in gt]
        if self.mode == "2D":
            ious = maskUtils.iou(d, g, iscrowd)
        elif len(d) > 0 and len(g) > 0:
            if p.iouType == "bbox":
                dd = torch.tensor(d, dtype=torch.float32)
                gg = torch.tensor(g, dtype=torch.float32)

                ious = box3d_overlap(dd, gg).cpu().numpy()
            else:
                # NOTE(review): this zeros allocation is a dead store — it
                # is immediately overwritten by the cdist result below.
                ious = np.zeros((len(d), len(g)))

                # "dist" matching uses center distances, not IoUs.
                dd = [d["center_cam"] for d in dt]
                gg = [g["center_cam"] for g in gt]

                ious = cdist(dd, gg, metric="euclidean")
        else:
            ious = []

        in_prox = None

        if self.eval_prox:
            # Proximity is always decided on 2D box IoU, regardless of mode.
            g = [g["bbox"] for g in gt]
            d = [d["bbox"] for d in dt]
            iscrowd = [0 for o in gt]
            ious2d = maskUtils.iou(d, g, iscrowd)

            if type(ious2d) == list:
                in_prox = []

            else:
                in_prox = ious2d > p.proximity_thresh

        return ious, in_prox
|
| 806 |
+
|
| 807 |
+
    def evaluateImg(self, imgId, catId, aRng, maxDet):
        """
        Perform evaluation for single category and image.

        Greedily matches detections (highest score first) to ground truths
        under the configured matching criterion, and records per-threshold
        match ids, ignore masks and — for "dist" matching — translation,
        orientation and scale TP errors.

        Returns:
            dict (single image results), or None when the image has neither
            gts nor dts for this category.
        """

        p = self.params
        if p.useCats:
            gt = self._gts[imgId, catId]
            dt = self._dts[imgId, catId]

        else:
            gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
            dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]

        if len(gt) == 0 and len(dt) == 0:
            return None

        # 2D evaluation ranges over box area; 3D ranges over depth.
        flag_range = "area" if self.mode == "2D" else "depth"
        flag_ignore = "ignore2D" if self.mode == "2D" else "ignore3D"

        # Mark gts that are flagged as ignore or fall outside the range.
        for g in gt:
            if g[flag_ignore] or (
                g[flag_range] < aRng[0] or g[flag_range] > aRng[1]
            ):
                g["_ignore"] = 1
            else:
                g["_ignore"] = 0

        # sort dt highest score first, sort gt ignore last
        gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort")
        gt = [gt[i] for i in gtind]
        dtind = np.argsort([-d["score"] for d in dt], kind="mergesort")
        dt = [dt[i] for i in dtind[0:maxDet]]

        # load computed ious
        ious = (
            self.ious[imgId, catId][0][:, gtind]
            if len(self.ious[imgId, catId][0]) > 0
            else self.ious[imgId, catId][0]
        )

        if self.eval_prox:
            in_prox = (
                self.ious[imgId, catId][1][:, gtind]
                if len(self.ious[imgId, catId][1]) > 0
                else self.ious[imgId, catId][1]
            )

        T = len(p.iouThrs)
        G = len(gt)
        D = len(dt)
        gtm = np.zeros((T, G))
        dtm = np.zeros((T, D))
        tem = np.ones((T, D))  # Translation Error
        sem = np.ones((T, D))  # Scale Error
        oem = np.ones((T, D))  # Oritentation Error
        gtIg = np.array([g["_ignore"] for g in gt])
        dtIg = np.zeros((T, D))

        dist_thres = 1
        if not len(ious) == 0:
            for tind, t in enumerate(p.iouThrs):
                for dind, d in enumerate(dt):

                    # information about best match so far (m=-1 -> unmatched)
                    iou = min([t, 1 - 1e-10])
                    m = -1

                    for gind, g in enumerate(gt):
                        # in case of proximity evaluation, if not in proximity continue
                        if self.eval_prox and not in_prox[dind, gind]:
                            continue

                        # if this gt already matched, continue
                        if gtm[tind, gind] > 0:
                            continue

                        # if dt matched to reg gt, and on ignore gt, stop
                        if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
                            break

                        # continue to next gt unless better match made
                        if p.iouType == "bbox" and ious[dind, gind] < iou:
                            continue

                        if p.iouType == "dist":
                            # Compute Object Radius: the distance threshold
                            # scales with the gt object's size.
                            gt_obj_radius = (
                                np.linalg.norm(np.array(g["dimensions"])) / 2
                            )
                            # For "dist", ious holds center distances, so a
                            # LARGER value means a worse match.
                            if ious[dind, gind] > gt_obj_radius * iou:
                                continue
                            else:
                                dist_thres = gt_obj_radius * iou

                        # if match successful and best so far, store appropriately
                        iou = ious[dind, gind]
                        m = gind

                    # if match made store id of match for both dt and gt
                    if m == -1:
                        continue

                    dtIg[tind, dind] = gtIg[m]
                    dtm[tind, dind] = gt[m]["id"]
                    gtm[tind, m] = d["id"]

                    if p.iouType == "dist":
                        # Translation Error, normalized by the distance
                        # threshold of the matched gt.
                        tem[tind, dind] = np.linalg.norm(
                            np.array(d["center_cam"])
                            - np.array(gt[m]["center_cam"])
                        ) / (dist_thres)

                        # Orientation Error: relative rotation angle,
                        # normalized to [0, 1] by dividing by pi.
                        oem[tind, dind] = (
                            so3_relative_angle(
                                torch.tensor(d["R_cam"])[None],
                                torch.tensor(gt[m]["R_cam"])[None],
                                cos_bound=1e-2,
                                eps=1e-2,
                            ).item()
                            / np.pi
                        )

                        # Scale Error: 1 - IoU of the axis-aligned dimension
                        # volumes (boxes treated as centered and aligned).
                        min_whl = np.minimum(
                            d["dimensions"], gt[m]["dimensions"]
                        )
                        volume_annotation = np.prod(gt[m]["dimensions"])
                        volume_result = np.prod(d["dimensions"])

                        intersection = np.prod(min_whl)
                        union = (
                            volume_annotation + volume_result - intersection
                        )
                        scale_iou = intersection / union

                        sem[tind, dind] = 1 - scale_iou

        # set unmatched detections outside of area range to ignore
        a = np.array(
            [d[flag_range] < aRng[0] or d[flag_range] > aRng[1] for d in dt]
        ).reshape((1, len(dt)))

        dtIg = np.logical_or(
            dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0))
        )

        # in case of proximity evaluation, ignore detections which are far from gt regions
        if self.eval_prox and len(in_prox) > 0:
            dt_far = in_prox.any(1) == 0
            dtIg = np.logical_or(
                dtIg, np.repeat(dt_far.reshape((1, len(dt))), T, 0)
            )

        # store results for given image and category
        return {
            "image_id": imgId,
            "category_id": catId,
            "aRng": aRng,
            "maxDet": maxDet,
            "dtIds": [d["id"] for d in dt],
            "gtIds": [g["id"] for g in gt],
            "dtMatches": dtm,
            "gtMatches": gtm,
            "dtScores": [d["score"] for d in dt],
            "gtIgnore": gtIg,
            "dtIgnore": dtIg,
            "dtTranslationError": tem,
            "dtScaleError": sem,
            "dtOrientationError": oem,
        }
|
| 982 |
+
|
| 983 |
+
def summarize(self):
    """Compute and display summary metrics for evaluation results.

    Note this function can *only* be applied on the default parameter
    setting.

    Returns:
        str: Multi-line log string, one line per summary metric, in the
            same order as ``self.stats``.
    """

    def _summarize(
        mode, ap=1, iouThr=None, areaRng="all", maxDets=100, log_str=""
    ):
        """Append one AP/AR summary line for the given setting.

        Returns the mean score (-1 when no valid entries exist) and the
        log string extended with the formatted line.
        """
        p = self.params
        eval = self.eval

        # The header template depends on the evaluation mode (2D vs. 3D)
        # and on whether matching is IoU- or distance-based.
        if mode == "2D":
            if self.iou_type == "bbox":
                iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
            else:
                iStr = " {:<18} {} @[ Dist={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"

        elif mode == "3D":
            if self.iou_type == "bbox":
                iStr = " {:<18} {} @[ IoU={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}"
            else:
                iStr = " {:<18} {} @[ Dist={:<9} | depth={:>6s} | maxDets={:>3d} ] = {:0.3f}"

        titleStr = "Average Precision" if ap == 1 else "Average Recall"
        typeStr = "(AP)" if ap == 1 else "(AR)"

        iouStr = (
            "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
            if iouThr is None
            else "{:0.2f}".format(iouThr)
        )

        aind = [
            i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng
        ]
        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]

        if ap == 1:
            # dimension of precision: [TxRxKxAxM]
            s = eval["precision"]

            # np.isclose guards against float drift when selecting the
            # requested IoU threshold.
            if iouThr is not None:
                t = np.where(np.isclose(iouThr, p.iouThrs.astype(float)))[
                    0
                ]
                s = s[t]

            s = s[:, :, :, aind, mind]

        else:
            # dimension of recall: [TxKxAxM]
            s = eval["recall"]
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]
            s = s[:, :, aind, mind]

        # Entries of -1 mark missing/ignored cells and are excluded.
        if len(s[s > -1]) == 0:
            mean_s = -1

        else:
            mean_s = np.mean(s[s > -1])

        if log_str != "":
            log_str += "\n"

        log_str += "mode={} ".format(mode) + iStr.format(
            titleStr, typeStr, iouStr, areaRng, maxDets, mean_s
        )

        return mean_s, log_str

    def _summarizeDets(mode):
        """Compute the 13 standard summary stats for the given mode."""
        params = self.params

        # Define the thresholds to be printed
        if mode == "2D":
            thres = [0.5, 0.75, 0.95]
        else:
            if self.iou_type == "bbox":
                thres = [0.15, 0.25, 0.50]
            else:
                thres = [0.5, 0.75, 1.0]

        # One kwargs dict per _summarize call. The order of this list
        # defines the meaning of each index of ``stats`` and must not
        # change: 0 = AP, 1-3 = AP at fixed thresholds, 4-6 = AP per
        # area/depth range, 7-9 = AR per maxDets, 10-12 = AR per range.
        top = params.maxDets[2]
        settings = [
            {"ap": 1},
            {"ap": 1, "iouThr": thres[0], "maxDets": top},
            {"ap": 1, "iouThr": thres[1], "maxDets": top},
            {"ap": 1, "iouThr": thres[2], "maxDets": top},
            {"ap": 1, "areaRng": params.areaRngLbl[1], "maxDets": top},
            {"ap": 1, "areaRng": params.areaRngLbl[2], "maxDets": top},
            {"ap": 1, "areaRng": params.areaRngLbl[3], "maxDets": top},
            {"ap": 0, "maxDets": params.maxDets[0]},
            {"ap": 0, "maxDets": params.maxDets[1]},
            {"ap": 0, "maxDets": params.maxDets[2]},
            {"ap": 0, "areaRng": params.areaRngLbl[1], "maxDets": top},
            {"ap": 0, "areaRng": params.areaRngLbl[2], "maxDets": top},
            {"ap": 0, "areaRng": params.areaRngLbl[3], "maxDets": top},
        ]

        stats = np.zeros((len(settings),))
        log_str = ""
        for i, kwargs in enumerate(settings):
            stats[i], log_str = _summarize(mode, log_str=log_str, **kwargs)

        return stats, log_str

    if not self.eval:
        raise Exception("Please run accumulate() first")

    stats, log_str = _summarizeDets(self.mode)
    self.stats = stats

    return log_str
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
class Detect3DParams:
    """Params for the 3d detection evaluation API."""

    def __init__(
        self,
        mode: str = "2D",
        iouType: str = "bbox",
        proximity_thresh: float = 0.3,
    ) -> None:
        """Create an instance of Detect3DParams.

        Args:
            mode: (str) defines whether to evaluate 2D or 3D performance.
            iouType: (str) defines the type of IoU to be used for evaluation.
            proximity_thresh (float): It defines the neighborhood when
                evaluating on non-exhaustively annotated datasets.

        Raises:
            ValueError: If ``iouType`` is not one of "bbox" or "dist".
            Exception: If ``mode`` is not one of "2D" or "3D".
        """
        # Explicit raise instead of assert so the validation survives
        # running Python with -O (asserts are stripped there).
        if iouType not in {"bbox", "dist"}:
            raise ValueError(f"Invalid iouType {iouType}.")
        self.iouType = iouType

        if mode == "2D":
            self.setDet2DParams()
        elif mode == "3D":
            self.setDet3DParams()
        else:
            raise Exception(f"{mode} mode is not supported")
        self.mode = mode
        self.proximity_thresh = proximity_thresh

    def setDet2DParams(self) -> None:
        """Set parameters for 2D detection evaluation."""
        self.imgIds = []
        self.catIds = []

        # np.arange causes trouble. the data point on arange is slightly
        # larger than the true value, so use np.linspace with an explicit
        # point count instead. 10 IoU thresholds: 0.50, 0.55, ..., 0.95.
        self.iouThrs = np.linspace(
            0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
        )

        # 101 recall thresholds: 0.00, 0.01, ..., 1.00.
        self.recThrs = np.linspace(
            0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
        )

        self.maxDets = [1, 10, 100]
        # Pixel-area ranges matching the labels below.
        self.areaRng = [
            [0**2, 1e5**2],
            [0**2, 32**2],
            [32**2, 96**2],
            [96**2, 1e5**2],
        ]

        self.areaRngLbl = ["all", "small", "medium", "large"]
        self.useCats = 1

    def setDet3DParams(self) -> None:
        """Set parameters for 3D detection evaluation."""
        self.imgIds = []
        self.catIds = []

        # np.arange causes trouble. The data point on arange is slightly
        # larger than the true value, so use np.linspace instead.
        if self.iouType == "bbox":
            # 10 3D-IoU thresholds: 0.05, 0.10, ..., 0.50.
            self.iouThrs = np.linspace(
                0.05,
                0.5,
                int(np.round((0.5 - 0.05) / 0.05)) + 1,
                endpoint=True,
            )
        else:
            # 11 distance thresholds: 0.50, 0.55, ..., 1.00.
            self.iouThrs = np.linspace(
                0.5, 1.0, int(np.round((1.00 - 0.5) / 0.05)) + 1, endpoint=True
            )

        self.recThrs = np.linspace(
            0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
        )

        self.maxDets = [1, 10, 100]
        # Depth ranges (meters) matching the labels below.
        self.areaRng = [[0, 1e5], [0, 10], [10, 35], [35, 1e5]]
        self.areaRngLbl = ["all", "near", "medium", "far"]
        self.useCats = 1
|
opendet3d/eval/omni3d.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Omni3D 3D detection evaluation."""
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import copy
|
| 5 |
+
import io
|
| 6 |
+
import itertools
|
| 7 |
+
import os
|
| 8 |
+
from collections.abc import Sequence
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from terminaltables import AsciiTable
|
| 12 |
+
from vis4d.common.logging import rank_zero_info
|
| 13 |
+
from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber
|
| 14 |
+
from vis4d.eval.base import Evaluator
|
| 15 |
+
|
| 16 |
+
from opendet3d.data.datasets.omni3d.omni3d_classes import omni3d_class_map
|
| 17 |
+
from opendet3d.data.datasets.omni3d.util import get_dataset_det_map
|
| 18 |
+
|
| 19 |
+
from .detect3d import Detect3Deval, Detect3DEvaluator
|
| 20 |
+
|
| 21 |
+
# Indoor category names; their per-class APs are averaged into the
# "Omni3D_In" metric. Listed alphabetically for easy scanning.
omni3d_in = {
    "bathtub",
    "bed",
    "bicycle",
    "bin",
    "blinds",
    "bookcase",
    "books",
    "bottle",
    "box",
    "cabinet",
    "chair",
    "clothes",
    "counter",
    "cup",
    "curtain",
    "desk",
    "door",
    "floor mat",
    "lamp",
    "laptop",
    "machine",
    "mirror",
    "night stand",
    "oven",
    "picture",
    "pillow",
    "refrigerator",
    "shelves",
    "shoes",
    "sink",
    "sofa",
    "stationery",
    "stove",
    "table",
    "television",
    "toilet",
    "towel",
    "window",
}

# Outdoor category names; their per-class APs are averaged into the
# "Omni3D_Out" metric. Note "bicycle" appears in both splits.
omni3d_out = {
    "barrier",
    "bicycle",
    "bus",
    "car",
    "cyclist",
    "motorcycle",
    "pedestrian",
    "traffic cone",
    "trailer",
    "truck",
    "van",
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Omni3DEvaluator(Evaluator):
    """Omni3D 3D detection evaluator.

    Wraps one :class:`Detect3DEvaluator` per dataset split and, in
    addition to the per-dataset results, accumulates all per-category
    evaluations into a single joint Omni3D evaluation.
    """

    def __init__(
        self,
        data_root: str = "data/omni3d",
        omni3d50: bool = True,
        datasets: Sequence[str] = (
            "KITTI_test",
            "nuScenes_test",
            "SUNRGBD_test",
            "Hypersim_test",
            "ARKitScenes_test",
            "Objectron_test",
        ),
        per_class_eval: bool = True,
    ) -> None:
        """Initialize the evaluator.

        Args:
            data_root: Root folder holding ``annotations/<dataset>.json``.
            omni3d50: Whether to use the Omni3D50 category mapping.
            datasets: Names of the dataset splits to evaluate.
            per_class_eval: Whether to report a per-category AP table.
        """
        super().__init__()
        self.id_to_name = {v: k for k, v in omni3d_class_map.items()}
        self.dataset_names = datasets
        self.per_class_eval = per_class_eval

        # Each dataset evaluator is stored here
        self.evaluators: dict[str, Detect3DEvaluator] = {}

        # These store the evaluations for each category and area,
        # concatenated from ALL evaluated datasets. Doing so avoids
        # the need to re-compute them when accumulating results.
        self.evals_per_cat_area2D = {}
        self.evals_per_cat_area3D = {}

        self.overall_imgIds = set()
        self.overall_catIds = set()

        for dataset_name in self.dataset_names:
            annotation = os.path.join(
                data_root, "annotations", f"{dataset_name}.json"
            )

            det_map = get_dataset_det_map(
                dataset_name=dataset_name, omni3d50=omni3d50
            )

            # create an individual dataset evaluator; Objectron and
            # SUNRGBD are not exhaustively annotated, so they use
            # proximity-based evaluation.
            self.evaluators[dataset_name] = Detect3DEvaluator(
                det_map,
                cat_map=omni3d_class_map,
                annotation=annotation,
                eval_prox=(
                    "Objectron" in dataset_name or "SUNRGBD" in dataset_name
                ),
            )

            self.overall_imgIds.update(
                set(self.evaluators[dataset_name]._coco_gt.getImgIds())
            )
            self.overall_catIds.update(
                set(self.evaluators[dataset_name]._coco_gt.getCatIds())
            )

    def __repr__(self) -> str:
        """Returns the string representation of the object."""
        datasets_str = ", ".join(self.dataset_names)
        return f"Omni3DEvaluator ({datasets_str})"

    @property
    def metrics(self) -> list[str]:
        """Supported metrics.

        Returns:
            list[str]: Metrics to evaluate.
        """
        return ["2D", "3D"]

    def reset(self) -> None:
        """Reset the saved predictions to start new round of evaluation."""
        for dataset_name in self.dataset_names:
            self.evaluators[dataset_name].reset()
        self.evals_per_cat_area2D.clear()
        self.evals_per_cat_area3D.clear()

    def gather(self, gather_func: GenericFunc) -> None:
        """Accumulate predictions across processes."""
        for dataset_name in self.dataset_names:
            self.evaluators[dataset_name].gather(gather_func)

    def process_batch(
        self,
        coco_image_id: list[int],
        dataset_names: list[str],
        pred_boxes: list[NDArrayNumber],
        pred_scores: list[NDArrayNumber],
        pred_classes: list[NDArrayNumber],
        pred_boxes3d: list[NDArrayNumber] | None = None,
    ) -> None:
        """Process sample and convert detections to coco format.

        Each sample in the batch is routed to the evaluator of the
        dataset it came from.
        """
        for i, dataset_name in enumerate(dataset_names):
            self.evaluators[dataset_name].process_batch(
                [coco_image_id[i]],
                [pred_boxes[i]],
                [pred_scores[i]],
                [pred_classes[i]],
                pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
            )

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate predictions and return the results.

        Args:
            metric: Either "2D" or "3D".

        Returns:
            tuple[MetricLogs, str]: Per-dataset and joint Omni3D metrics,
                plus a printable summary string.
        """
        assert metric in self.metrics, f"Unsupported metric: {metric}"

        log_dict = {}

        for dataset_name in self.dataset_names:
            rank_zero_info(f"Evaluating {dataset_name}...")
            per_dataset_log_dict, dataset_log_str = self.evaluators[
                dataset_name
            ].evaluate(metric)

            log_dict[f"AP_{dataset_name}"] = per_dataset_log_dict["AP"]

            rank_zero_info(dataset_log_str + "\n")

            # store the partially accumulated evaluations per category per
            # area; setdefault replaces the non-idiomatic
            # "if not key in ..." initialization.
            if metric == "2D":
                for key, item in self.evaluators[
                    dataset_name
                ].bbox_2D_evals_per_cat_area.items():
                    self.evals_per_cat_area2D.setdefault(key, []).extend(
                        item
                    )
            else:
                for key, item in self.evaluators[
                    dataset_name
                ].bbox_3D_evals_per_cat_area.items():
                    self.evals_per_cat_area3D.setdefault(key, []).extend(
                        item
                    )

        results_per_category_dict = {}
        results_per_category = []

        rank_zero_info(f"Evaluating Omni3D for {metric} Detection...")

        # Build a joint evaluator over all images/categories, reusing the
        # already-computed per-image evaluations.
        evaluator = Detect3Deval(mode=metric)
        evaluator.params.catIds = list(self.overall_catIds)
        evaluator.params.imgIds = list(self.overall_imgIds)
        evaluator.evalImgs = True

        if metric == "2D":
            evaluator.evals_per_cat_area = self.evals_per_cat_area2D
            metrics = ["AP", "AP50", "AP75", "AP95", "APs", "APm", "APl"]
        else:
            evaluator.evals_per_cat_area = self.evals_per_cat_area3D
            metrics = ["AP", "AP15", "AP25", "AP50", "APn", "APm", "APf"]

        evaluator._paramsEval = copy.deepcopy(evaluator.params)

        # Suppress the verbose stdout of accumulate/summarize.
        with contextlib.redirect_stdout(io.StringIO()):
            evaluator.accumulate()
            log_str = "\n" + evaluator.summarize()

        log_dict.update(dict(zip(metrics, evaluator.stats)))

        if self.per_class_eval:
            precisions = evaluator.eval["precision"]
            for idx, cat_id in enumerate(self.overall_catIds):
                cat_name = self.id_to_name[cat_id]
                # area range index 0 ("all"), max dets index -1 (100).
                precision = precisions[:, :, idx, 0, -1]
                precision = precision[precision > -1]
                if precision.size:
                    ap = float(np.mean(precision).item())
                else:
                    ap = float("nan")

                results_per_category_dict[cat_name] = ap
                results_per_category.append((f"{cat_name}", f"{ap:0.3f}"))

            # Render the per-category APs as a multi-column ASCII table.
            num_columns = min(6, len(results_per_category) * 2)
            results_flatten = list(itertools.chain(*results_per_category))
            headers = ["category", "AP"] * (num_columns // 2)
            results_2d = itertools.zip_longest(
                *[results_flatten[i::num_columns] for i in range(num_columns)]
            )
            table_data = [headers] + list(results_2d)
            table = AsciiTable(table_data)
            log_str = f"\n{table.table}\n{log_str}"

            # Omni3D Outdoor performance
            ap_out_lst = []
            for cat in omni3d_out:
                ap_out_lst.append(results_per_category_dict.get(cat, 0.0))

            log_dict["Omni3D_Out"] = np.mean(ap_out_lst).item()

            # Omni3D Indoor performance
            ap_in_lst = []
            for cat in omni3d_in:
                ap_in_lst.append(results_per_category_dict.get(cat, 0.0))

            log_dict["Omni3D_In"] = np.mean(ap_in_lst).item()

        return log_dict, log_str

    def save(self, metric: str, output_dir: str) -> None:
        """Save the results to json files."""
        for dataset_name in self.dataset_names:
            self.evaluators[dataset_name].save(
                metric, output_dir, prefix=dataset_name
            )
|
opendet3d/eval/open.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-data 3D detection evaluation."""
|
| 2 |
+
|
| 3 |
+
from collections.abc import Sequence
|
| 4 |
+
|
| 5 |
+
from vis4d.common.logging import rank_zero_info
|
| 6 |
+
from vis4d.common.typing import GenericFunc, MetricLogs, NDArrayNumber
|
| 7 |
+
from vis4d.eval.base import Evaluator
|
| 8 |
+
|
| 9 |
+
from .detect3d import Detect3DEvaluator
|
| 10 |
+
from .omni3d import Omni3DEvaluator
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class OpenDetect3DEvaluator(Evaluator):
    """Multi-data 3D detection evaluator.

    Routes each sample to its dataset's evaluator; datasets covered by
    the optional Omni3D evaluator are evaluated jointly through it.
    """

    def __init__(
        self,
        datasets: Sequence[str],
        evaluators: Sequence[Detect3DEvaluator],
        omni3d_evaluator: Omni3DEvaluator | None = None,
    ) -> None:
        """Initialize the evaluator.

        Args:
            datasets: Dataset names, aligned one-to-one with
                ``evaluators``.
            evaluators: Per-dataset detection evaluators.
            omni3d_evaluator: Optional joint Omni3D evaluator. Samples
                from datasets it covers are routed to it instead of the
                per-dataset evaluators.
        """
        super().__init__()
        self.dataset_names = datasets
        # dict(zip(...)) is the idiomatic pairing of two aligned
        # sequences into a mapping.
        self.evaluators = dict(zip(datasets, evaluators))

        self.omni3d_evaluator = omni3d_evaluator

    def __repr__(self) -> str:
        """Returns the string representation of the object."""
        datasets_str = ", ".join(self.dataset_names)
        return f"Open 3D Object Detection Evaluator ({datasets_str})"

    @property
    def metrics(self) -> list[str]:
        """Supported metrics.

        Returns:
            list[str]: Metrics to evaluate.
        """
        return ["2D", "3D"]

    def reset(self) -> None:
        """Reset the saved predictions to start new round of evaluation."""
        for dataset_name in self.dataset_names:
            self.evaluators[dataset_name].reset()

        if self.omni3d_evaluator is not None:
            self.omni3d_evaluator.reset()

    def gather(self, gather_func: GenericFunc) -> None:
        """Accumulate predictions across processes."""
        for dataset_name in self.dataset_names:
            self.evaluators[dataset_name].gather(gather_func)

        if self.omni3d_evaluator is not None:
            self.omni3d_evaluator.gather(gather_func)

    def process_batch(
        self,
        coco_image_id: list[int],
        dataset_names: list[str],
        pred_boxes: list[NDArrayNumber],
        pred_scores: list[NDArrayNumber],
        pred_classes: list[NDArrayNumber],
        pred_boxes3d: list[NDArrayNumber] | None = None,
    ) -> None:
        """Process sample and convert detections to coco format."""
        for i, dataset_name in enumerate(dataset_names):
            # Prefer the joint Omni3D evaluator for datasets it covers.
            if (
                self.omni3d_evaluator is not None
                and dataset_name in self.omni3d_evaluator.dataset_names
            ):
                self.omni3d_evaluator.process_batch(
                    [coco_image_id[i]],
                    [dataset_name],
                    [pred_boxes[i]],
                    [pred_scores[i]],
                    [pred_classes[i]],
                    pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
                )
            else:
                self.evaluators[dataset_name].process_batch(
                    [coco_image_id[i]],
                    [pred_boxes[i]],
                    [pred_scores[i]],
                    [pred_classes[i]],
                    pred_boxes3d=[pred_boxes3d[i]] if pred_boxes3d else None,
                )

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate predictions and return the results.

        Args:
            metric: Either "2D" or "3D".

        Returns:
            tuple[MetricLogs, str]: Collected metrics from the Omni3D
                evaluator (when set) and every per-dataset evaluator,
                plus a printable summary string.
        """
        assert metric in self.metrics, f"Unsupported metric: {metric}"

        log_dict = {}
        log_str = ""

        if self.omni3d_evaluator is not None:
            log_dict_omni3d, omni3d_log_str = self.omni3d_evaluator.evaluate(
                metric
            )

            log_dict.update(log_dict_omni3d)
            log_str += omni3d_log_str

        for dataset_name in self.dataset_names:
            rank_zero_info(f"Evaluating {dataset_name}...")
            per_dataset_log_dict, dataset_log_str = self.evaluators[
                dataset_name
            ].evaluate(metric)

            # Datasets reporting an open-vocabulary detection score use
            # "ODS" as the headline metric; otherwise plain AP.
            if "ODS" in per_dataset_log_dict:
                score = "ODS"
            else:
                score = "AP"

            log_dict[f"{score}_{dataset_name}"] = per_dataset_log_dict[score]

            if self.evaluators[dataset_name].base_classes is not None:
                log_dict[f"{score}_Base_{dataset_name}"] = (
                    per_dataset_log_dict[f"{score}_Base"]
                )
                log_dict[f"{score}_Novel_{dataset_name}"] = (
                    per_dataset_log_dict[f"{score}_Novel"]
                )

            log_str += f"\nCheck {dataset_name} results in log dict."

            rank_zero_info(dataset_log_str + "\n")

        return log_dict, log_str

    def save(self, metric: str, output_dir: str) -> None:
        """Save the results to json files."""
        for dataset_name in self.dataset_names:
            self.evaluators[dataset_name].save(
                metric, output_dir, prefix=dataset_name
            )
|
opendet3d/model/__init__.py
ADDED
|
File without changes
|
opendet3d/model/detect/__init__.py
ADDED
|
File without changes
|
opendet3d/model/detect/grounding_dino.py
ADDED
|
@@ -0,0 +1,1050 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grounding DINO model.
|
| 2 |
+
|
| 3 |
+
modified from mmdetection.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from collections.abc import Sequence
|
| 7 |
+
from typing import NamedTuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
from transformers import BatchEncoding
|
| 13 |
+
from vis4d.common.ckpt import load_model_checkpoint
|
| 14 |
+
from vis4d.common.logging import rank_zero_warn
|
| 15 |
+
from vis4d.op.base import BaseModel
|
| 16 |
+
from vis4d.op.layer.positional_encoding import SinePositionalEncoding
|
| 17 |
+
|
| 18 |
+
from opendet3d.model.language.mm_bert import BertModel
|
| 19 |
+
from opendet3d.op.detect.deformable_detr import get_valid_ratio
|
| 20 |
+
from opendet3d.op.detect.dino import CdnQueryGenerator
|
| 21 |
+
from opendet3d.op.detect.grounding_dino import (
|
| 22 |
+
GroundingDINOHead,
|
| 23 |
+
GroundingDinoTransformerDecoder,
|
| 24 |
+
GroundingDinoTransformerEncoder,
|
| 25 |
+
RoI2Det,
|
| 26 |
+
)
|
| 27 |
+
from opendet3d.op.fpp.channel_mapper import ChannelMapper
|
| 28 |
+
from opendet3d.op.language.grounding import (
|
| 29 |
+
chunks,
|
| 30 |
+
clean_label_name,
|
| 31 |
+
create_positive_map,
|
| 32 |
+
create_positive_map_label_to_token,
|
| 33 |
+
run_ner,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Regex rewrite rules applied to legacy checkpoint parameter names when
# loading pretrained weights (old ConvModule-style names -> current names).
REV_KEYS = [
    (r"\.conv.weight", ".weight"),
    (r"\.conv.bias", ".bias"),
    (r"\.gn", ".norm"),
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GroundingDINOOut(NamedTuple):
    """Training-time raw outputs of the Grounding DINO model.

    Bundles everything the loss computation needs: per-decoder-layer
    predictions, the two-stage encoder outputs, text masking information,
    denoising metadata, and the per-image label-to-token positive maps.
    """

    all_layers_cls_scores: list[Tensor]  # classification logits per decoder layer
    all_layers_bbox_preds: list[Tensor]  # box predictions per decoder layer
    enc_outputs_class: Tensor  # encoder (two-stage) classification output
    enc_outputs_coord: Tensor  # encoder (two-stage) box output
    text_token_mask: Tensor  # valid-token mask of the text branch
    dn_meta: dict[str, Tensor]  # denoising-query meta information
    positive_maps: list[Tensor]  # per-image label-to-token positive maps
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class DetOut(NamedTuple):
|
| 56 |
+
"""Output of the Grounding DINO model."""
|
| 57 |
+
|
| 58 |
+
boxes: list[Tensor]
|
| 59 |
+
scores: list[Tensor]
|
| 60 |
+
class_ids: list[Tensor]
|
| 61 |
+
categories: list[list[str]] | None = None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class GroundingDINO(nn.Module):
|
| 65 |
+
"""Grounding DINO."""
|
| 66 |
+
|
| 67 |
+
def __init__(
    self,
    basemodel: BaseModel,
    neck: ChannelMapper,
    texts: list[str] | None = None,
    custom_entities: bool = True,
    chunked_size: int = -1,
    num_queries: int = 900,
    num_feature_levels: int = 4,
    use_checkpoint: bool = False,
    bbox_head: GroundingDINOHead | None = None,
    language_model: BertModel | None = None,
    roi2det: RoI2Det | None = None,
    weights: str | None = None,
) -> None:
    """Create the Grounding DINO model.

    Args:
        basemodel: Visual backbone producing multi-scale features.
        neck: Channel mapper projecting backbone features to ``embed_dims``.
        texts: Fixed text prompts to use for every sample (optional).
        custom_entities: Whether prompts are lists of entity words.
        chunked_size: If > 0, split long prompt lists into chunks of this
            size at test time (-1 disables chunking).
        num_queries: Number of object queries for the decoder.
        num_feature_levels: Number of multi-scale feature levels.
        use_checkpoint: Enable gradient checkpointing in encoder/text model.
        bbox_head: Detection head; a default ``GroundingDINOHead`` is built
            when ``None``.
        language_model: Text encoder; a default BERT is built when ``None``.
        roi2det: Post-processor; a default ``RoI2Det`` is built when ``None``.
        weights: Optional checkpoint path loaded after initialization.
    """
    super().__init__()
    self.texts = texts
    self.custom_entities = custom_entities
    self.chunked_size = chunked_size
    self.num_queries = num_queries

    # Visual feature extraction.
    self.backbone = basemodel
    self.neck = neck

    # Cross-modality deformable encoder.
    self.encoder = GroundingDinoTransformerEncoder(
        num_levels=num_feature_levels, use_checkpoint=use_checkpoint
    )
    self.embed_dims = self.encoder.embed_dims

    self.positional_encoding = SinePositionalEncoding(
        num_feats=128, normalize=True, offset=0.0, temperature=20
    )
    num_feats = self.positional_encoding.num_feats
    assert num_feats * 2 == self.embed_dims, (
        f"embed_dims should be exactly 2 times of num_feats. "
        f"Found {self.embed_dims} and {num_feats}."
    )

    # Learnable per-level embedding added to the positional encoding.
    self.level_embed = nn.Parameter(
        torch.Tensor(num_feature_levels, self.embed_dims)
    )

    # Projection of encoder memory for two-stage proposal generation.
    self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
    self.memory_trans_norm = nn.LayerNorm(self.embed_dims)

    # Decoder and its object queries.
    self.decoder = GroundingDinoTransformerDecoder(
        num_levels=num_feature_levels
    )
    self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

    # Grounding DINO detection head.
    self.bbox_head = bbox_head or GroundingDINOHead(
        num_classes=256, num_decoder_layer=self.decoder.num_layers
    )

    self.roi2det = roi2det or RoI2Det()

    # Contrastive denoising (CDN) query generator used during training.
    self.dn_query_generator = CdnQueryGenerator(
        num_classes=self.bbox_head.num_classes,
        embed_dims=self.embed_dims,
        num_matching_queries=self.num_queries,
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        dynamic=True,
        num_groups=None,
        num_dn_queries=100,
    )

    # Separator appended between entity words when building captions.
    self._special_tokens = ". "

    # Text encoder (BERT) and its projection into the visual embed space.
    self.language_model = language_model or BertModel(
        name="bert-base-uncased",
        max_tokens=256,
        pad_to_max=False,
        use_sub_sentence_represent=True,
        special_tokens_list=["[CLS]", "[SEP]", ".", "?"],
        add_pooling_layer=False,
        use_checkpoint=use_checkpoint,
    )
    self.text_feat_map = nn.Linear(
        self.language_model.language_backbone.body.language_dim,
        self.embed_dims,
        bias=True,
    )

    self._init_weights()

    if weights is not None:
        load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
|
| 165 |
+
|
| 166 |
+
def _init_weights(self) -> None:
    """Initialize transformer and text-projection weights.

    Matrix-shaped parameters get Xavier-uniform init (DINO convention);
    the level embedding is normal-initialized and the text feature map
    starts with zero bias.
    """
    # DINO: Xavier-init every >1-D parameter of both transformers.
    for transformer in (self.encoder, self.decoder):
        for param in transformer.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    nn.init.xavier_uniform_(self.memory_trans_fc.weight)
    nn.init.xavier_uniform_(self.query_embedding.weight)
    nn.init.normal_(self.level_embed)

    # G-DINO: text-to-visual projection with zero bias.
    nn.init.constant_(self.text_feat_map.bias.data, 0)
    nn.init.xavier_uniform_(self.text_feat_map.weight.data)
|
| 181 |
+
|
| 182 |
+
def get_captions_and_tokens_positive(
|
| 183 |
+
self,
|
| 184 |
+
text_prompt: list[str],
|
| 185 |
+
text_prompt_mapping: dict[str, dict[str, str]] | None = None,
|
| 186 |
+
) -> tuple[str, list[list[int]]]:
|
| 187 |
+
"""Enhance the text prompts with the text mapping."""
|
| 188 |
+
captions = ""
|
| 189 |
+
tokens_positive = []
|
| 190 |
+
for word in text_prompt:
|
| 191 |
+
if text_prompt_mapping is not None and word in text_prompt_mapping:
|
| 192 |
+
enhanced_text_dict = text_prompt_mapping[word]
|
| 193 |
+
if "prefix" in enhanced_text_dict:
|
| 194 |
+
captions += enhanced_text_dict["prefix"]
|
| 195 |
+
|
| 196 |
+
start_i = len(captions)
|
| 197 |
+
if "name" in enhanced_text_dict:
|
| 198 |
+
captions += enhanced_text_dict["name"]
|
| 199 |
+
else:
|
| 200 |
+
captions += word
|
| 201 |
+
end_i = len(captions)
|
| 202 |
+
|
| 203 |
+
tokens_positive.append([[start_i, end_i]])
|
| 204 |
+
|
| 205 |
+
if "suffix" in enhanced_text_dict:
|
| 206 |
+
captions += enhanced_text_dict["suffix"]
|
| 207 |
+
else:
|
| 208 |
+
tokens_positive.append(
|
| 209 |
+
[[len(captions), len(captions) + len(word)]]
|
| 210 |
+
)
|
| 211 |
+
captions += word
|
| 212 |
+
captions += self._special_tokens
|
| 213 |
+
return captions, tokens_positive
|
| 214 |
+
|
| 215 |
+
def get_tokens_and_prompts(
    self,
    text_prompt: str | list[str],
    text_prompt_mapping: dict[str, dict[str, str]] | None = None,
) -> tuple[BatchEncoding, str, list[list[int]], list[str]]:
    """Tokenize the prompt and derive per-entity character spans.

    A list prompt is joined into a caption with known spans; a plain
    string prompt is terminated with the separator if needed and NER is
    run on it to discover entity spans.

    Returns:
        Tokenized caption, the caption string, the positive character
        spans, and the entity names.
    """
    prompt_is_list = isinstance(text_prompt, list)
    if prompt_is_list:
        captions, tokens_positive = self.get_captions_and_tokens_positive(
            text_prompt, text_prompt_mapping
        )
    else:
        # Ensure the caption ends with the sub-sentence separator.
        captions = (
            text_prompt
            if text_prompt.endswith(".")
            else text_prompt + self._special_tokens
        )

    tokenized = self.language_model.tokenizer(
        [captions], padding="longest", return_tensors="pt"
    )

    if prompt_is_list:
        entities = text_prompt
    else:
        tokens_positive, entities = run_ner(captions)

    return tokenized, captions, tokens_positive, entities
|
| 242 |
+
|
| 243 |
+
def get_positive_map(
    self, tokenized: BatchEncoding, tokens_positive: list[list[int]]
) -> tuple[dict, Tensor]:
    """Map character spans to token positions.

    Returns:
        A dict mapping each label id (numbered from 1) to its positive
        token ids, and the dense entity-by-token positive map tensor.
    """
    # The head's max text length caps the number of entities.
    max_text_len = self.bbox_head.cls_branches[
        self.decoder.num_layers
    ].max_text_len
    positive_map = create_positive_map(
        tokenized, tokens_positive, max_num_entities=max_text_len
    )
    positive_map_label_to_token = create_positive_map_label_to_token(
        positive_map, plus=1
    )
    return positive_map_label_to_token, positive_map
|
| 258 |
+
|
| 259 |
+
def get_tokens_positive_and_prompts(
    self,
    text_prompt: str | list[str],
    custom_entities: bool = False,
    tokens_positive: list[list[int]] | int | None = None,
    text_prompt_mapping: dict[str, dict[str, str]] | None = None,
) -> tuple[dict, str, Tensor, list]:
    """Get the tokens positive and prompts for the caption.

    Args:
        text_prompt: The original caption, e.g. ``'bench . car .'``, or a
            list of entity words.
        custom_entities: Whether to use custom entities. If ``True``, a
            string ``text_prompt`` is split on the separator into a list
            of words. Defaults to False.
        tokens_positive: Pre-computed positive character spans, ``-1`` to
            skip grounding-map construction, or ``None`` to derive spans
            from the prompt.
        text_prompt_mapping: Optional per-word caption enhancement.

    Returns:
        Tuple[dict, str, Tensor, list]: The dict maps each entity id
        (numbered from 1) to its positive token ids; the str is the
        caption; then the dense positive map and the entity names.
    """
    if tokens_positive is not None:
        assert isinstance(
            text_prompt, str
        ), "Text prompt should be a string with given positive tokens."

        if not text_prompt.endswith("."):
            captions = text_prompt + self._special_tokens
        else:
            captions = text_prompt

        if tokens_positive == -1:
            # Caller explicitly opts out of grounding-map construction.
            return None, captions, None, captions

        assert isinstance(
            tokens_positive, list
        ), "Positive tokens should be a list of list[int] if not -1."
        tokenized = self.language_model.tokenizer(
            [captions], padding="longest", return_tensors="pt"
        )
        positive_map_label_to_token, positive_map = self.get_positive_map(
            tokenized, tokens_positive
        )

        # Recover entity strings from their character spans.
        entities = [
            " / ".join(captions[start:end] for start, end in spans)
            for spans in tokens_positive
        ]

        return (
            positive_map_label_to_token,
            captions,
            positive_map,
            entities,
        )

    if custom_entities:
        if isinstance(text_prompt, str):
            # NOTE: str.strip strips the characters '.' and ' ', which is
            # the intended trimming of leading/trailing separators.
            text_prompt = text_prompt.strip(self._special_tokens)
            text_prompt = text_prompt.split(self._special_tokens)
            text_prompt = list(filter(lambda x: len(x) > 0, text_prompt))
        text_prompt = [clean_label_name(i) for i in text_prompt]

    if self.chunked_size > 0:
        assert not self.training, "Chunked size is only for testing."
        (
            positive_map_label_to_token,
            captions,
            positive_map,
            entities,
        ) = self.get_tokens_positive_and_prompts_chunked(
            text_prompt, text_prompt_mapping
        )
    else:
        tokenized, captions, tokens_positive, entities = (
            self.get_tokens_and_prompts(text_prompt, text_prompt_mapping)
        )
        positive_map_label_to_token, positive_map = self.get_positive_map(
            tokenized, tokens_positive
        )

    return positive_map_label_to_token, captions, positive_map, entities
|
| 342 |
+
|
| 343 |
+
def get_tokens_positive_and_prompts_chunked(
    self,
    text_prompt: list[str],
    text_prompt_mapping: dict[str, dict[str, str]] | None = None,
) -> tuple[list[dict], list[str], list[Tensor], list[list[str]]]:
    """Get the tokens positive and prompts, split into prompt chunks.

    Long entity lists are processed ``self.chunked_size`` words at a
    time so each chunk stays within the language model's token budget.

    Args:
        text_prompt: List of entity words.
        text_prompt_mapping: Optional per-word caption enhancement.

    Returns:
        Per-chunk lists of: label-to-token maps, captions, positive
        maps, and entity-name lists.
    """
    text_prompt_chunked = chunks(text_prompt, self.chunked_size)
    # Only the number of chunks is needed; ids start at 1 by convention.
    ids_chunked = chunks(
        list(range(1, len(text_prompt) + 1)), self.chunked_size
    )

    positive_map_label_to_token_chunked = []
    captions_chunked = []
    positive_map_chunked = []
    entities_chunked = []
    for i in range(len(ids_chunked)):
        captions, tokens_positive = self.get_captions_and_tokens_positive(
            text_prompt_chunked[i], text_prompt_mapping
        )

        tokenized = self.language_model.tokenizer(
            [captions], padding="longest", return_tensors="pt"
        )
        if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
            rank_zero_warn(
                "Caption is too long will result in poor performance. "
                "Please reduce the chunked size."
            )

        positive_map_label_to_token, positive_map = self.get_positive_map(
            tokenized, tokens_positive
        )

        captions_chunked.append(captions)
        positive_map_label_to_token_chunked.append(
            positive_map_label_to_token
        )
        positive_map_chunked.append(positive_map)
        entities_chunked.append(text_prompt_chunked[i])

    return (
        positive_map_label_to_token_chunked,
        captions_chunked,
        positive_map_chunked,
        entities_chunked,
    )
|
| 389 |
+
|
| 390 |
+
# TODO: Move this to deformable DETR
|
| 391 |
+
def pre_transformer(
    self,
    feats: list[Tensor],
    input_hw: list[tuple[int, int]],
    batch_input_shape: tuple[int, int],
    padding: list[list[int]] | None = None,
) -> tuple[Tensor, Tensor, Tensor | None, Tensor, Tensor, Tensor]:
    """Flatten multi-level image features for the transformer.

    Builds the per-level padding masks and positional embeddings, then
    flattens everything into the (bs, sum_lvl(h*w), dim) layout the
    deformable attention expects.

    Args:
        feats: Multi-level feature maps, each (bs, c, h_lvl, w_lvl).
        input_hw: Un-padded (h, w) of every image in the batch.
        batch_input_shape: Padded (h, w) shared by the whole batch.
        padding: Optional per-image (left, right, top, bottom) padding.

    Returns:
        Flattened features, positional embeddings, padding mask (or
        ``None`` when no image is padded), spatial shapes, level start
        indices and valid ratios.
    """
    batch_size = feats[0].size(0)

    batch_input_img_h, batch_input_img_w = batch_input_shape
    # When every image already fills the padded canvas, masks are skipped.
    same_shape_flag = all(
        h == batch_input_img_h and w == batch_input_img_w
        for h, w in input_hw
    )

    if same_shape_flag:
        mlvl_masks = [None for _ in feats]
        mlvl_pos_embeds = [
            self.positional_encoding(None, inputs=feat) for feat in feats
        ]
    else:
        check_center = padding is not None
        # NOTE following the official DETR repo: non-zero values mark
        # ignored (padded) positions, zeros mark valid positions.
        masks = feats[0].new_ones(
            (batch_size, batch_input_img_h, batch_input_img_w)
        )
        for img_id in range(batch_size):
            img_h, img_w = input_hw[img_id]
            if padding is None:
                masks[img_id, :img_h, :img_w] = 0
            else:
                pad_left, pad_right, pad_top, pad_bottom = padding[img_id]
                masks[
                    img_id,
                    pad_top : batch_input_img_h - pad_bottom,
                    pad_left : batch_input_img_w - pad_right,
                ] = 0

        mlvl_masks = []
        mlvl_pos_embeds = []
        for feat in feats:
            lvl_mask = (
                F.interpolate(masks[None], size=feat.shape[-2:])
                .to(torch.bool)
                .squeeze(0)
            )
            mlvl_masks.append(lvl_mask)
            mlvl_pos_embeds.append(self.positional_encoding(lvl_mask))

    feat_flatten = []
    lvl_pos_embed_flatten = []
    mask_flatten = []
    spatial_shapes = []
    for lvl, (feat, mask, pos_embed) in enumerate(
        zip(feats, mlvl_masks, mlvl_pos_embeds)
    ):
        batch_size, c, _, _ = feat.shape
        spatial_shape = torch._shape_as_tensor(feat)[2:].to(feat.device)
        # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
        feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
        pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
        lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
        # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
        if mask is not None:
            mask = mask.flatten(1)

        feat_flatten.append(feat)
        lvl_pos_embed_flatten.append(lvl_pos_embed)
        mask_flatten.append(mask)
        spatial_shapes.append(spatial_shape)

    # (bs, num_feat_points, dim)
    feat_flatten = torch.cat(feat_flatten, 1)
    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
    # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
    mask_flatten = (
        torch.cat(mask_flatten, 1)
        if mask_flatten[0] is not None
        else None
    )

    # (num_level, 2)
    spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
    level_start_index = torch.cat(
        (
            spatial_shapes.new_zeros((1,)),  # (num_level)
            spatial_shapes.prod(1).cumsum(0)[:-1],
        )
    )
    if mlvl_masks[0] is not None:
        # (bs, num_level, 2)
        valid_ratios = torch.stack(
            [get_valid_ratio(m, check_center) for m in mlvl_masks], 1
        )
    else:
        valid_ratios = feats[0].new_ones(batch_size, len(feats), 2)

    return (
        feat_flatten,
        lvl_pos_embed_flatten,
        mask_flatten,
        spatial_shapes,
        level_start_index,
        valid_ratios,
    )
|
| 506 |
+
|
| 507 |
+
def forward_transformer(
    self,
    feat_flatten: Tensor,
    lvl_pos_embed_flatten: Tensor,
    memory_mask: Tensor | None,
    spatial_shapes: Tensor,
    level_start_index: Tensor,
    valid_ratios: Tensor,
    text_dict: dict[str, Tensor],
    boxes: Tensor | None = None,
    class_ids: Tensor | None = None,
    input_hw: list[tuple[int, int]] | None = None,
) -> tuple[Tensor, Tensor, Tensor, list[Tensor]]:
    """Run encoder, two-stage query selection, and decoder.

    In training mode the return is extended with the top-k encoder
    outputs and the denoising meta (7-tuple); in eval mode only the
    4-tuple declared in the annotation is returned.
    """
    text_token_mask = text_dict["text_token_mask"]

    # Joint image/text encoding.
    memory, memory_text = self.encoder(
        query=feat_flatten,
        query_pos=lvl_pos_embed_flatten,
        key_padding_mask=memory_mask,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        memory_text=text_dict["embedded"],
        text_attention_mask=~text_token_mask,
        position_ids=text_dict["position_ids"],
        text_self_attention_masks=text_dict["masks"],
    )

    bs = memory.shape[0]

    output_memory, output_proposals = self.gen_encoder_output_proposals(
        memory, memory_mask, spatial_shapes
    )

    enc_cls_branch = self.bbox_head.cls_branches[self.decoder.num_layers]
    enc_outputs_class = enc_cls_branch(
        output_memory, memory_text, text_token_mask
    )
    cls_out_features = enc_cls_branch.max_text_len
    enc_outputs_coord_unact = (
        self.bbox_head.reg_branches[self.decoder.num_layers](output_memory)
        + output_proposals
    )

    # NOTE The DINO selects top-k proposals according to scores of
    # multi-class classification, while DeformDETR, where the input
    # is `enc_outputs_class[..., 0]` selects according to scores of
    # binary classification.
    topk_indices = torch.topk(
        enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1
    )[1]

    topk_score = torch.gather(
        enc_outputs_class,
        1,
        topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features),
    )
    topk_coords_unact = torch.gather(
        enc_outputs_coord_unact,
        1,
        topk_indices.unsqueeze(-1).repeat(1, 1, 4),
    )
    topk_coords = topk_coords_unact.sigmoid()
    topk_coords_unact = topk_coords_unact.detach()

    # Learned content queries, broadcast over the batch.
    query = self.query_embedding.weight[:, None, :]
    query = query.repeat(1, bs, 1).transpose(0, 1)

    if self.training:
        # Prepend contrastive denoising queries during training.
        dn_label_query, dn_bbox_query, dn_mask, dn_meta = (
            self.dn_query_generator(boxes, class_ids, input_hw)
        )
        query = torch.cat([dn_label_query, query], dim=1)
        reference_points = torch.cat(
            [dn_bbox_query, topk_coords_unact], dim=1
        )
    else:
        reference_points = topk_coords_unact
        dn_mask, dn_meta = None, None

    reference_points = reference_points.sigmoid()

    # NOTE DINO calculates encoder losses on scores and coordinates
    # of selected top-k encoder queries, while DeformDETR is of all
    # encoder queries.
    if self.training:
        enc_outputs_class = topk_score
        enc_outputs_coord = topk_coords

    hidden_states, references = self.decoder(
        query=query,
        value=memory,
        key_padding_mask=memory_mask,
        self_attn_mask=dn_mask,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        reg_branches=self.bbox_head.reg_branches,
        memory_text=memory_text,
        text_attention_mask=~text_token_mask,
    )

    if len(query) == self.num_queries:
        # NOTE: This is to make sure label_embeding can be involved to
        # produce loss even if there is no denoising query (no ground truth
        # target in this GPU), otherwise, this will raise runtime error in
        # distributed training.
        hidden_states[0] += (
            self.dn_query_generator.label_embedding.weight[0, 0] * 0.0
        )

    if self.training:
        return (
            memory_text,
            text_token_mask,
            hidden_states,
            list(references),
            enc_outputs_class,
            enc_outputs_coord,
            dn_meta,
        )

    return memory_text, text_token_mask, hidden_states, list(references)
|
| 633 |
+
|
| 634 |
+
# TODO: Move this to deformable DETR
|
| 635 |
+
def gen_encoder_output_proposals(
|
| 636 |
+
self,
|
| 637 |
+
memory: Tensor,
|
| 638 |
+
memory_mask: Tensor | None,
|
| 639 |
+
spatial_shapes: Tensor,
|
| 640 |
+
) -> tuple[Tensor, Tensor]:
|
| 641 |
+
"""Generate proposals from encoded memory. The function will only be
|
| 642 |
+
used when `as_two_stage` is `True`.
|
| 643 |
+
|
| 644 |
+
Args:
|
| 645 |
+
memory (Tensor): The output embeddings of the Transformer encoder,
|
| 646 |
+
has shape (bs, num_feat_points, dim).
|
| 647 |
+
memory_mask (Tensor): ByteTensor, the padding mask of the memory,
|
| 648 |
+
has shape (bs, num_feat_points).
|
| 649 |
+
spatial_shapes (Tensor): Spatial shapes of features in all levels,
|
| 650 |
+
has shape (num_levels, 2), last dimension represents (h, w).
|
| 651 |
+
|
| 652 |
+
Returns:
|
| 653 |
+
tuple: A tuple of transformed memory and proposals.
|
| 654 |
+
|
| 655 |
+
- output_memory (Tensor): The transformed memory for obtaining
|
| 656 |
+
top-k proposals, has shape (bs, num_feat_points, dim).
|
| 657 |
+
- output_proposals (Tensor): The inverse-normalized proposal, has
|
| 658 |
+
shape (batch_size, num_keys, 4) with the last dimension arranged
|
| 659 |
+
as (cx, cy, w, h).
|
| 660 |
+
"""
|
| 661 |
+
|
| 662 |
+
bs = memory.size(0)
|
| 663 |
+
proposals = []
|
| 664 |
+
_cur = 0 # start index in the sequence of the current level
|
| 665 |
+
for lvl, HW in enumerate(spatial_shapes):
|
| 666 |
+
H, W = HW
|
| 667 |
+
|
| 668 |
+
if memory_mask is not None:
|
| 669 |
+
mask_flatten_ = memory_mask[:, _cur : (_cur + H * W)].view(
|
| 670 |
+
bs, H, W, 1
|
| 671 |
+
)
|
| 672 |
+
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(
|
| 673 |
+
-1
|
| 674 |
+
)
|
| 675 |
+
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(
|
| 676 |
+
-1
|
| 677 |
+
)
|
| 678 |
+
scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
|
| 679 |
+
else:
|
| 680 |
+
if not isinstance(HW, Tensor):
|
| 681 |
+
HW = memory.new_tensor(HW)
|
| 682 |
+
scale = HW.unsqueeze(0).flip(dims=[0, 1]).view(1, 1, 1, 2)
|
| 683 |
+
|
| 684 |
+
grid_y, grid_x = torch.meshgrid(
|
| 685 |
+
torch.linspace(
|
| 686 |
+
0, H - 1, H, dtype=torch.float32, device=memory.device
|
| 687 |
+
),
|
| 688 |
+
torch.linspace(
|
| 689 |
+
0, W - 1, W, dtype=torch.float32, device=memory.device
|
| 690 |
+
),
|
| 691 |
+
indexing="ij",
|
| 692 |
+
)
|
| 693 |
+
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
|
| 694 |
+
grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
|
| 695 |
+
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
|
| 696 |
+
proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
|
| 697 |
+
proposals.append(proposal)
|
| 698 |
+
_cur += H * W
|
| 699 |
+
|
| 700 |
+
output_proposals = torch.cat(proposals, 1)
|
| 701 |
+
|
| 702 |
+
# do not use `all` to make it exportable to onnx
|
| 703 |
+
output_proposals_valid = (
|
| 704 |
+
(output_proposals > 0.01) & (output_proposals < 0.99)
|
| 705 |
+
).sum(-1, keepdim=True) == output_proposals.shape[-1]
|
| 706 |
+
|
| 707 |
+
# inverse_sigmoid
|
| 708 |
+
output_proposals = torch.log(output_proposals / (1 - output_proposals))
|
| 709 |
+
if memory_mask is not None:
|
| 710 |
+
output_proposals = output_proposals.masked_fill(
|
| 711 |
+
memory_mask.unsqueeze(-1), float("inf")
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
output_proposals = output_proposals.masked_fill(
|
| 715 |
+
~output_proposals_valid, float("inf")
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
if memory_mask is not None:
|
| 719 |
+
output_memory = memory.masked_fill(
|
| 720 |
+
memory_mask.unsqueeze(-1), float(0)
|
| 721 |
+
)
|
| 722 |
+
else:
|
| 723 |
+
output_memory = memory
|
| 724 |
+
|
| 725 |
+
# [bs, sum(hw), 2]
|
| 726 |
+
output_memory = output_memory.masked_fill(
|
| 727 |
+
~output_proposals_valid, float(0)
|
| 728 |
+
)
|
| 729 |
+
output_memory = self.memory_trans_fc(output_memory)
|
| 730 |
+
output_memory = self.memory_trans_norm(output_memory)
|
| 731 |
+
|
| 732 |
+
return output_memory, output_proposals
|
| 733 |
+
|
| 734 |
+
def _forward_train(
    self,
    images: Tensor,
    input_texts: list[list[str]] | None,
    boxes2d: Tensor,
    boxes2d_classes: Tensor,
    input_hw: list[tuple[int, int]],
    input_tokens_positive: list | None = None,
) -> GroundingDINOOut:
    """Forward pass in training mode.

    Builds per-image positive maps from the text prompts (or the given
    positive-token dicts), encodes text and image features, and runs the
    full transformer plus detection head.

    Args:
        images: Batched input images (bs, c, h, w).
        input_texts: Per-image prompt word lists; ``None`` to fall back
            to ``self.texts``.
        boxes2d: Ground-truth 2D boxes used for denoising queries.
        boxes2d_classes: Per-image ground-truth class indices.
        input_hw: Un-padded (h, w) of each image.
        input_tokens_positive: Optional per-image containers indexed by
            class label yielding positive character spans.
            NOTE(review): elements are indexed like dicts keyed by the
            integer label — confirm exact type against the caller.

    Returns:
        The raw model outputs needed for loss computation.
    """
    batch_size = images.shape[0]

    if input_tokens_positive is not None:
        positive_maps = []
        for tokens_positive_dict, text_prompt, gt_label in zip(
            input_tokens_positive, input_texts, boxes2d_classes
        ):
            tokenized = self.language_model.tokenizer(
                [text_prompt], padding="longest", return_tensors="pt"
            )
            new_tokens_positive = [
                tokens_positive_dict[label.item()] for label in gt_label
            ]
            _, positive_map = self.get_positive_map(
                tokenized, new_tokens_positive
            )
            positive_maps.append(positive_map)
        new_text_prompts = input_texts
    else:
        new_text_prompts = []
        positive_maps = []

        # All the text prompts are the same, so there is no need to
        # calculate them multiple times.
        if (
            input_texts is None
            or len({"".join(t) for t in input_texts}) == 1
        ):
            text_prompt = (
                self.texts if input_texts is None else input_texts[0]
            )

            tokenized, caption_string, tokens_positive, _ = (
                self.get_tokens_and_prompts(text_prompt)
            )
            new_text_prompts = [caption_string] * batch_size
            for gt_label in boxes2d_classes:
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive
                )
                positive_maps.append(positive_map)
        else:
            for text_prompt, gt_label in zip(input_texts, boxes2d_classes):
                tokenized, caption_string, tokens_positive, _ = (
                    self.get_tokens_and_prompts(text_prompt)
                )
                new_text_prompts.append(caption_string)
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive
                )
                positive_maps.append(positive_map)

    # Binarize the maps on the model device.
    for i in range(batch_size):
        positive_maps[i] = (
            positive_maps[i].to(images.device).bool().float()
        )

    text_dict = self.language_model(new_text_prompts)

    if self.text_feat_map is not None:
        text_dict["embedded"] = self.text_feat_map(text_dict["embedded"])

    # NOTE: a previously-built `text_token_masks` list was removed here:
    # it was never read afterwards (dead code) and depended on a
    # loop-leaked variable.

    visual_feats = self.backbone(images)[2:]
    visual_feats = self.neck(visual_feats)

    batch_input_img_h, batch_input_img_w = images.shape[-2:]
    batch_input_shape = (batch_input_img_h, batch_input_img_w)

    (
        feat_flatten,
        lvl_pos_embed_flatten,
        memory_mask,
        spatial_shapes,
        level_start_index,
        valid_ratios,
    ) = self.pre_transformer(visual_feats, input_hw, batch_input_shape)

    (
        memory_text,
        text_token_mask,
        hidden_states,
        references,
        enc_outputs_class,
        enc_outputs_coord,
        dn_meta,
    ) = self.forward_transformer(
        feat_flatten,
        lvl_pos_embed_flatten,
        memory_mask,
        spatial_shapes,
        level_start_index,
        valid_ratios,
        text_dict,
        boxes2d,
        boxes2d_classes,
        input_hw,
    )

    all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
        hidden_states, references, memory_text, text_token_mask
    )

    return GroundingDINOOut(
        all_layers_cls_scores,
        all_layers_bbox_preds,
        enc_outputs_class,
        enc_outputs_coord,
        text_token_mask,
        dn_meta,
        positive_maps,
    )
|
| 874 |
+
|
| 875 |
+
    def _forward_test(
        self,
        images: Tensor,
        input_texts: list[str] | None,
        text_prompt_mapping: list[dict[str, dict[str, str]]] | None,
        input_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> DetOut:
        """Inference forward pass: open-vocabulary 2D detection.

        Builds per-image text prompts (from ``self.texts`` or
        ``input_texts``), extracts visual and language features, runs the
        transformer, and decodes per-image boxes, scores, class ids and
        category names.

        Args:
            images: Batched input images, shape [B, C, H, W].
            input_texts: Per-image text prompts; ignored when ``self.texts``
                is set.
            text_prompt_mapping: Optional per-image prompt remapping passed
                through to ``get_tokens_positive_and_prompts``.
            input_hw: Per-image (height, width) of the network input.
            original_hw: Per-image (height, width) of the original images,
                used to rescale detections.

        Returns:
            DetOut with per-image boxes, scores, class ids and categories.
        """
        batch_size = images.shape[0]

        # Per-image tokenization artifacts: token->class positive maps, the
        # caption strings fed to the language model, and the entity names.
        token_positive_maps = []
        text_prompts = []
        entities = []
        for i in range(batch_size):
            # A fixed vocabulary (self.texts) takes precedence over per-image
            # prompts.
            if self.texts is not None:
                text_prompt = self.texts
            else:
                text_prompt = input_texts[i]

            if text_prompt_mapping is not None:
                prompt_mapping = text_prompt_mapping[i]
            else:
                prompt_mapping = None

            token_positive_map, captions, _, _entities = (
                self.get_tokens_positive_and_prompts(
                    text_prompt,
                    self.custom_entities,
                    text_prompt_mapping=prompt_mapping,
                )
            )

            token_positive_maps.append(token_positive_map)
            text_prompts.append(captions)
            entities.append(_entities)

        # image feature extraction
        batch_input_img_h, batch_input_img_w = images.shape[-2:]

        # Drop the two highest-resolution backbone stages before the neck.
        visual_feats = self.backbone(images)[2:]
        visual_feats = self.neck(visual_feats)

        if isinstance(text_prompts[0], list):
            # TODO: Support chunked text prompts in the future.
            # NOTE(review): this branch leaves ``text_dict`` undefined, so
            # falling through to the transformer below would raise a
            # NameError — confirm chunked prompts cannot reach this path.
            pass
            # assert batch_size == 1, "Batch size should be 1 for chunked text."
            # count = 0
            # results_list = []

            # entities = [[item for lst in entities[0] for item in lst]]

            # for i, captions in enumerate(text_prompts[0]):
            #     token_positive_map = token_positive_maps[0][i]

            #     text_dict = self.language_model(captions)

            #     # text feature map layer
            #     if self.text_feat_map is not None:
            #         text_dict["embedded"] = self.text_feat_map(
            #             text_dict["embedded"]
            #         )

            #     head_inputs_dict = self.forward_transformer(
            #         copy.deepcopy(visual_feats), text_dict, input_hw
            #     )
            #     pred_instances = self.bbox_head.predict(
            #         **head_inputs_dict,
            #         batch_token_positive_maps=token_positive_map,
            #     )[0]

            #     if len(pred_instances) > 0:
            #         pred_instances.labels += count
            #     count += len(token_positive_maps_once)
            #     results_list.append(pred_instances)
            # results_list = [results_list[0].cat(results_list)]
        else:
            # extract text feats
            text_dict = self.language_model(list(text_prompts))

            # text feature map layer
            if self.text_feat_map is not None:
                text_dict["embedded"] = self.text_feat_map(
                    text_dict["embedded"]
                )

        batch_input_shape = (batch_input_img_h, batch_input_img_w)

        # Flatten multi-level features into the sequence form the
        # transformer consumes.
        (
            feat_flatten,
            lvl_pos_embed_flatten,
            memory_mask,
            spatial_shapes,
            level_start_index,
            valid_ratios,
        ) = self.pre_transformer(visual_feats, input_hw, batch_input_shape)

        # Inference path: no denoising queries, so only four outputs.
        (
            memory_text,
            text_token_mask,
            hidden_states,
            references,
        ) = self.forward_transformer(
            feat_flatten,
            lvl_pos_embed_flatten,
            memory_mask,
            spatial_shapes,
            level_start_index,
            valid_ratios,
            text_dict,
        )

        all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
            hidden_states, references, memory_text, text_token_mask
        )

        # Only the last decoder layer is used for prediction.
        cls_scores = all_layers_cls_scores[-1]
        bbox_preds = all_layers_bbox_preds[-1]

        boxes = []
        scores = []
        class_ids = []
        categories = []
        for i, bbox_pred in enumerate(bbox_preds):
            cls_score = cls_scores[i]
            # Decode raw head outputs into rescaled detections.
            det_bboxes, det_scores, det_labels = self.roi2det(
                cls_score,
                bbox_pred,
                token_positive_maps[i],
                input_hw[i],
                original_hw[i],
            )
            boxes.append(det_bboxes)
            scores.append(det_scores)
            class_ids.append(det_labels)

            # Get the categories text
            cur_categories = []
            for label in det_labels:
                cur_categories.append(entities[i][label])

            categories.append(cur_categories)

        return DetOut(boxes, scores, class_ids, categories=categories)
|
| 1019 |
+
|
| 1020 |
+
def forward(
|
| 1021 |
+
self,
|
| 1022 |
+
images: Tensor,
|
| 1023 |
+
input_hw: list[tuple[int, int]],
|
| 1024 |
+
boxes2d: Tensor | None = None,
|
| 1025 |
+
boxes2d_classes: Tensor | None = None,
|
| 1026 |
+
original_hw: list[tuple[int, int]] | None = None,
|
| 1027 |
+
input_texts: Sequence[str] | str | None = None,
|
| 1028 |
+
input_tokens_positive: list[dict[int, list[int, int]]] | None = None,
|
| 1029 |
+
text_prompt_mapping: dict[str, dict[str, str]] | None = None,
|
| 1030 |
+
) -> GroundingDINOOut | DetOut:
|
| 1031 |
+
"""Forward function."""
|
| 1032 |
+
if self.training:
|
| 1033 |
+
assert boxes2d is not None and boxes2d_classes is not None
|
| 1034 |
+
return self._forward_train(
|
| 1035 |
+
images,
|
| 1036 |
+
input_texts,
|
| 1037 |
+
boxes2d,
|
| 1038 |
+
boxes2d_classes,
|
| 1039 |
+
input_hw,
|
| 1040 |
+
input_tokens_positive=input_tokens_positive,
|
| 1041 |
+
)
|
| 1042 |
+
|
| 1043 |
+
assert original_hw is not None
|
| 1044 |
+
return self._forward_test(
|
| 1045 |
+
images,
|
| 1046 |
+
input_texts,
|
| 1047 |
+
text_prompt_mapping,
|
| 1048 |
+
input_hw,
|
| 1049 |
+
original_hw,
|
| 1050 |
+
)
|
opendet3d/model/detect3d/__init__.py
ADDED
|
File without changes
|
opendet3d/model/detect3d/grounding_dino_3d.py
ADDED
|
@@ -0,0 +1,812 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""3D-MOOD."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
from collections.abc import Sequence
|
| 7 |
+
from typing import NamedTuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
from vis4d.op.base import BaseModel
|
| 13 |
+
from vis4d.op.fpp.fpn import FPN
|
| 14 |
+
|
| 15 |
+
from opendet3d.model.detect.grounding_dino import GroundingDINO
|
| 16 |
+
from opendet3d.model.language.mm_bert import BertModel
|
| 17 |
+
from opendet3d.op.detect3d.grounding_dino_3d import (
|
| 18 |
+
GroundingDINO3DHead,
|
| 19 |
+
RoI2Det3D,
|
| 20 |
+
)
|
| 21 |
+
from opendet3d.op.detect.grounding_dino import GroundingDINOHead, RoI2Det
|
| 22 |
+
from opendet3d.op.fpp.channel_mapper import ChannelMapper
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Det3DOut(NamedTuple):
    """Output of the 3D detection model.

    Attributes:
        boxes (list[Tensor]): Per-image 2D bounding boxes of shape [N, 4] in
            xyxy format.
        boxes3d (list[Tensor]): Per-image 3D bounding boxes of shape [N, 10].
        scores (list[Tensor]): Per-image confidence scores of shape [N,].
        class_ids (list[Tensor]): Per-image class ids of shape [N,].
        depth_maps (list[Tensor] | None): Optional per-image predicted depth
            maps from the depth head.
        categories (list[list[str]] | None): Optional per-image category
            names matching ``class_ids``.
    """

    boxes: list[Tensor]
    boxes3d: list[Tensor]
    scores: list[Tensor]
    class_ids: list[Tensor]
    depth_maps: list[Tensor] | None
    categories: list[list[str]] | None = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class GroundingDINO3DOut(NamedTuple):
    """Raw training-time outputs of the Grounding DINO 3D model.

    Attributes:
        all_layers_cls_scores (list[Tensor]): Classification scores from
            every decoder layer.
        all_layers_bbox_preds (list[Tensor]): 2D box predictions from every
            decoder layer.
        all_layers_bbox_3d_preds (list[Tensor]): 3D box predictions from
            every decoder layer.
        enc_outputs_class (Tensor): Encoder (top-k proposal) class scores.
        enc_outputs_coord (Tensor): Encoder (top-k proposal) 2D coordinates.
        enc_outputs_3d (Tensor): Encoder (top-k proposal) 3D outputs.
        text_token_mask (Tensor): Valid-token mask of the text branch.
        dn_meta (dict[str, Tensor]): Denoising-query metadata used to split
            denoising from matching queries in the loss.
        positive_maps (list[Tensor]): Per-image token-to-instance positive
            maps for the grounding loss.
        depth_maps (Tensor | None): Optional predicted depth maps.
    """

    all_layers_cls_scores: list[Tensor]
    all_layers_bbox_preds: list[Tensor]
    all_layers_bbox_3d_preds: list[Tensor]
    enc_outputs_class: Tensor
    enc_outputs_coord: Tensor
    enc_outputs_3d: Tensor
    text_token_mask: Tensor
    dn_meta: dict[str, Tensor]
    positive_maps: list[Tensor]
    depth_maps: Tensor | None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class GroundingDINO3D(GroundingDINO):
|
| 58 |
+
"""Grounding DINO."""
|
| 59 |
+
|
| 60 |
+
    def __init__(
        self,
        basemodel: BaseModel,
        neck: ChannelMapper,
        texts: list[str] | None = None,
        custom_entities: bool = True,
        chunked_size: int = -1,
        num_queries: int = 900,
        num_feature_levels: int = 4,
        use_checkpoint: bool = False,
        bbox_head: GroundingDINOHead | None = None,
        language_model: BertModel | None = None,
        roi2det: RoI2Det | None = None,
        bbox3d_head: GroundingDINO3DHead | None = None,
        roi2det3d: RoI2Det3D | None = None,
        depth_head: nn.Module | None = None,
        fpn: FPN | None = None,
        freeze_detector: bool = False,
        weights: str | None = None,
        cat_mapping: dict[str, int] | None = None,
    ) -> None:
        """Create the Grounding DINO 3D model.

        Args:
            basemodel: Image backbone.
            neck: Channel mapper turning backbone features into transformer
                inputs.
            texts: Optional fixed vocabulary; overrides per-sample prompts.
            custom_entities: Whether prompts are custom entity lists.
            chunked_size: Chunk size for long vocabularies (-1 disables).
            num_queries: Number of decoder object queries.
            num_feature_levels: Number of feature levels for the transformer.
            use_checkpoint: Enable gradient checkpointing in the base model.
            bbox_head: 2D detection head; a default is built when None.
            language_model: Text encoder; a default is built when None.
            roi2det: 2D post-processing; a default is built when None.
            bbox3d_head: 3D detection head; defaults to GroundingDINO3DHead().
            roi2det3d: 3D post-processing; defaults to RoI2Det3D().
            depth_head: Optional depth prediction head.
            fpn: Optional FPN feeding the depth head; when None the depth
                head reuses the neck outputs.
            freeze_detector: Freeze all 2D detector components (backbone,
                neck, transformer, heads, language model).
            weights: Optional checkpoint to load.
            cat_mapping: Optional category-name to id mapping used for
                chunked-text inference.
        """
        super().__init__(
            basemodel=basemodel,
            neck=neck,
            texts=texts,
            custom_entities=custom_entities,
            chunked_size=chunked_size,
            num_queries=num_queries,
            num_feature_levels=num_feature_levels,
            use_checkpoint=use_checkpoint,
            bbox_head=bbox_head,
            roi2det=roi2det,
            language_model=language_model,
            weights=weights,
        )

        # Fall back to default 3D head / post-processing when not provided.
        self.bbox3d_head = bbox3d_head or GroundingDINO3DHead()
        self.roi2det3d = roi2det3d or RoI2Det3D()

        # Depth Head
        self.fpn = fpn
        self.depth_head = depth_head

        if freeze_detector:
            self._freeze_detector()

        self.cat_mapping = cat_mapping
|
| 108 |
+
|
| 109 |
+
def _freeze_detector(self):
|
| 110 |
+
"""Freeze the detector."""
|
| 111 |
+
for model in [
|
| 112 |
+
self.backbone,
|
| 113 |
+
self.neck,
|
| 114 |
+
self.encoder,
|
| 115 |
+
self.positional_encoding,
|
| 116 |
+
self.memory_trans_fc,
|
| 117 |
+
self.memory_trans_norm,
|
| 118 |
+
self.decoder,
|
| 119 |
+
self.bbox_head,
|
| 120 |
+
self.dn_query_generator,
|
| 121 |
+
self.language_model,
|
| 122 |
+
self.text_feat_map,
|
| 123 |
+
]:
|
| 124 |
+
model.eval()
|
| 125 |
+
for param in model.parameters():
|
| 126 |
+
param.requires_grad = False
|
| 127 |
+
|
| 128 |
+
self.level_embed.requires_grad = False
|
| 129 |
+
self.query_embedding.requires_grad = False
|
| 130 |
+
|
| 131 |
+
    def forward_transformer(
        self,
        feat_flatten: Tensor,
        lvl_pos_embed_flatten: Tensor,
        memory_mask: Tensor | None,
        spatial_shapes: Tensor,
        level_start_index: Tensor,
        valid_ratios: Tensor,
        text_dict: dict[str, Tensor],
        boxes: Tensor | None = None,
        class_ids: Tensor | None = None,
        input_hw: list[tuple[int, int]] | None = None,
        ray_embeddings: Tensor | None = None,
        depth_latents: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, list[Tensor]]:
        """Run the cross-modal encoder/decoder transformer.

        Extends the 2D Grounding DINO transformer with 3D proposal outputs
        computed from the top-k encoder queries via ``self.bbox3d_head``.

        Args:
            feat_flatten: Flattened multi-level visual features.
            lvl_pos_embed_flatten: Flattened level + positional embeddings.
            memory_mask: Key padding mask for the visual features.
            spatial_shapes: Per-level (h, w) of the feature maps.
            level_start_index: Start offset of each level in the flat
                sequence.
            valid_ratios: Valid (unpadded) ratio per image and level.
            text_dict: Language-model outputs (embedded tokens, masks,
                position ids).
            boxes: GT boxes for the denoising query generator (training).
            class_ids: GT class ids for the denoising queries (training).
            input_hw: Input sizes for the denoising queries (training).
            ray_embeddings: Camera ray embeddings for the 3D head.
            depth_latents: Depth-head latents for the 3D head.

        Returns:
            In eval mode the annotated 4-tuple (memory_text,
            text_token_mask, hidden_states, references).
            NOTE(review): in training mode an 8-tuple is returned instead
            (adding enc_outputs_class, enc_outputs_coord, topk_output_3d,
            dn_meta) — the annotation only covers the eval case.
        """
        text_token_mask = text_dict["text_token_mask"]

        # Joint visual/text feature enhancement.
        memory, memory_text = self.encoder(
            query=feat_flatten,
            query_pos=lvl_pos_embed_flatten,
            key_padding_mask=memory_mask,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            memory_text=text_dict["embedded"],
            text_attention_mask=~text_token_mask,
            position_ids=text_dict["position_ids"],
            text_self_attention_masks=text_dict["masks"],
        )

        bs = memory.shape[0]

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes
        )

        # The extra branch at index ``decoder.num_layers`` scores/regresses
        # the encoder proposals (two-stage DINO).
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers
        ](output_memory, memory_text, text_token_mask)
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers
        ].max_text_len
        enc_outputs_coord_unact = (
            self.bbox_head.reg_branches[self.decoder.num_layers](output_memory)
            + output_proposals
        )

        # NOTE The DINO selects top-k proposals according to scores of
        # multi-class classification, while DeformDETR, where the input
        # is `enc_outputs_class[..., 0]` selects according to scores of
        # binary classification.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1
        )[1]

        topk_score = torch.gather(
            enc_outputs_class,
            1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features),
        )
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact,
            1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4),
        )
        topk_coords = topk_coords_unact.sigmoid()
        # Detach so decoder refinement does not backprop into the proposals.
        topk_coords_unact = topk_coords_unact.detach()

        # Top-k 3D proposals
        # NOTE(review): the repeat width 256 assumes the transformer embed
        # dim is 256 — confirm against the model config.
        topk_output_memory = torch.gather(
            output_memory, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 256)
        )

        topk_output_3d = self.bbox3d_head.single_forward(
            self.decoder.num_layers,
            topk_output_memory,
            ray_embeddings,
            depth_latents,
        )

        # Learned content queries, broadcast across the batch.
        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)

        if self.training:
            # Prepend denoising queries built from (noised) ground truth.
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = (
                self.dn_query_generator(boxes, class_ids, input_hw)
            )
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat(
                [dn_bbox_query, topk_coords_unact], dim=1
            )
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None

        reference_points = reference_points.sigmoid()

        # NOTE DINO calculates encoder losses on scores and coordinates
        # of selected top-k encoder queries, while DeformDETR is of all
        # encoder queries.
        if self.training:
            enc_outputs_class = topk_score
            enc_outputs_coord = topk_coords

        hidden_states, references = self.decoder(
            query=query,
            value=memory,
            key_padding_mask=memory_mask,
            self_attn_mask=dn_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=self.bbox_head.reg_branches,
            memory_text=memory_text,
            text_attention_mask=~text_token_mask,
        )

        if len(query) == self.num_queries:
            # NOTE: This is to make sure label_embeding can be involved to
            # produce loss even if there is no denoising query (no ground truth
            # target in this GPU), otherwise, this will raise runtime error in
            # distributed training.
            hidden_states[0] += (
                self.dn_query_generator.label_embedding.weight[0, 0] * 0.0
            )

        if self.training:
            return (
                memory_text,
                text_token_mask,
                hidden_states,
                list(references),
                enc_outputs_class,
                enc_outputs_coord,
                topk_output_3d,
                dn_meta,
            )

        return (
            memory_text,
            text_token_mask,
            hidden_states,
            list(references),
        )
|
| 277 |
+
|
| 278 |
+
def _extract_image_features(
|
| 279 |
+
self, images: Tensor
|
| 280 |
+
) -> tuple[list[Tensor], list[Tensor] | None]:
|
| 281 |
+
"""Extract image features."""
|
| 282 |
+
visual_feats = self.backbone(images)[2:]
|
| 283 |
+
|
| 284 |
+
if self.fpn is not None:
|
| 285 |
+
depth_feats = self.fpn(visual_feats)
|
| 286 |
+
|
| 287 |
+
if len(visual_feats) > len(self.neck.in_channels):
|
| 288 |
+
start_index = len(visual_feats) - len(self.neck.in_channels)
|
| 289 |
+
# NOTE: This is to make sure the number of input channels is the
|
| 290 |
+
# same as the number of input channels of the neck.
|
| 291 |
+
visual_feats = visual_feats[start_index:]
|
| 292 |
+
|
| 293 |
+
visual_feats = self.neck(visual_feats)
|
| 294 |
+
|
| 295 |
+
if self.fpn is None:
|
| 296 |
+
depth_feats = visual_feats
|
| 297 |
+
|
| 298 |
+
return visual_feats, depth_feats
|
| 299 |
+
|
| 300 |
+
    def _forward_train(
        self,
        images: Tensor,
        input_texts: list[list[str]] | None,
        boxes2d: list[Tensor],
        boxes2d_classes: list[Tensor],
        input_hw: list[tuple[int, int]],
        intrinsics: Tensor,
        input_tokens_positive: list[list[int, int]] | list[None] | None = None,
        padding: list[list[int]] | None = None,
    ) -> GroundingDINO3DOut:
        """Training forward pass for joint 2D/3D grounding detection.

        Builds per-image captions and token positive maps, extracts visual,
        text and depth features, runs the transformer with denoising
        queries, and returns all raw head outputs for loss computation.

        Args:
            images: Batched input images, shape [B, C, H, W].
            input_texts: Per-image class-name lists; ``self.texts`` is used
                when None.
            boxes2d: Per-image GT 2D boxes.
            boxes2d_classes: Per-image GT class ids.
            input_hw: Per-image (height, width) of the network input.
            intrinsics: Camera intrinsics used for ray embeddings and depth.
            input_tokens_positive: Optional precomputed per-class token
                spans; when given, tokenization of the raw prompt is reused.
            padding: Optional per-image padding passed to pre_transformer.

        Returns:
            GroundingDINO3DOut with all decoder/encoder outputs, denoising
            metadata, positive maps and depth predictions.
        """
        batch_size = images.shape[0]

        new_text_prompts = []
        positive_maps = []
        if input_tokens_positive is not None:
            # Token spans are (partially) precomputed per image.
            for tokens_positive_dict, text_prompt, gt_label in zip(
                input_tokens_positive, input_texts, boxes2d_classes
            ):
                if tokens_positive_dict is not None:
                    tokenized = self.language_model.tokenizer(
                        [text_prompt], padding="longest", return_tensors="pt"
                    )

                    new_text_prompts.append(text_prompt)

                    # Look up the provided span for each GT label.
                    new_tokens_positive = [
                        tokens_positive_dict[label.item()]
                        for label in gt_label
                    ]
                else:
                    tokenized, caption_string, tokens_positive, _ = (
                        self.get_tokens_and_prompts(text_prompt)
                    )

                    new_text_prompts.append(caption_string)

                    new_tokens_positive = [
                        tokens_positive[label] for label in gt_label
                    ]

                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive
                )
                positive_maps.append(positive_map)
        else:
            # All the text prompts are the same or using the self.texts,
            # so there is no need to calculate them multiple times.
            if (
                input_texts is None
                or len(set(["".join(t) for t in input_texts])) == 1
            ):
                if input_texts is None:
                    assert self.texts is not None, "Texts should be provided."
                    text_prompt = self.texts
                else:
                    text_prompt = input_texts[0]

                # Tokenize the shared prompt once for the whole batch.
                tokenized, caption_string, tokens_positive, _ = (
                    self.get_tokens_and_prompts(text_prompt)
                )
                new_text_prompts = [caption_string] * batch_size
                for gt_label in boxes2d_classes:
                    new_tokens_positive = [
                        tokens_positive[label] for label in gt_label
                    ]
                    _, positive_map = self.get_positive_map(
                        tokenized, new_tokens_positive
                    )
                    positive_maps.append(positive_map)
            else:
                # Heterogeneous prompts: tokenize per image.
                for text_prompt, gt_label in zip(input_texts, boxes2d_classes):
                    tokenized, caption_string, tokens_positive, _ = (
                        self.get_tokens_and_prompts(text_prompt)
                    )

                    new_text_prompts.append(caption_string)

                    new_tokens_positive = [
                        tokens_positive[label] for label in gt_label
                    ]
                    _, positive_map = self.get_positive_map(
                        tokenized, new_tokens_positive
                    )

                    positive_maps.append(positive_map)

        # Binarize and move positive maps to the model device.
        for i in range(batch_size):
            positive_maps[i] = (
                positive_maps[i].to(images.device).bool().float()
            )

        text_dict = self.language_model(new_text_prompts)

        if self.text_feat_map is not None:
            text_dict["embedded"] = self.text_feat_map(text_dict["embedded"])

        # NOTE(review): ``positive_map`` below is the loop variable leaked
        # from the tokenization loops above (the last image's map), not the
        # per-image ``positive_maps[i]``; ``text_token_masks`` is also never
        # used afterwards in this method — confirm whether this is dead code
        # or a latent bug.
        text_token_masks = []
        for i in range(batch_size):
            text_token_masks.append(
                text_dict["text_token_mask"][i]
                .unsqueeze(0)
                .repeat(len(positive_map), 1)
            )

        visual_feats, depth_feats = self._extract_image_features(images)

        batch_input_img_h, batch_input_img_w = images.shape[-2:]
        batch_input_shape = (batch_input_img_h, batch_input_img_w)

        (
            feat_flatten,
            lvl_pos_embed_flatten,
            memory_mask,
            spatial_shapes,
            level_start_index,
            valid_ratios,
        ) = self.pre_transformer(
            visual_feats, input_hw, batch_input_shape, padding=padding
        )

        # Camera-conditioned ray embeddings for the 3D head.
        ray_embeddings = self.bbox3d_head.get_camera_embeddings(
            intrinsics, batch_input_shape
        )

        # Depth Head
        # NOTE(review): ``self.depth_head`` may be None per __init__ — this
        # call assumes a depth head is always configured for training.
        depth_preds, depth_latents = self.depth_head(
            depth_feats, intrinsics, batch_input_shape
        )

        (
            memory_text,
            text_token_mask,
            hidden_states,
            references,
            enc_outputs_class,
            enc_outputs_coord,
            enc_outputs_3d,
            dn_meta,
        ) = self.forward_transformer(
            feat_flatten,
            lvl_pos_embed_flatten,
            memory_mask,
            spatial_shapes,
            level_start_index,
            valid_ratios,
            text_dict,
            boxes2d,
            boxes2d_classes,
            input_hw,
            ray_embeddings,
            depth_latents,
        )

        all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
            hidden_states, references, memory_text, text_token_mask
        )

        # Not using denoising for 3D
        hidden_states_3d = hidden_states[
            :, :, dn_meta["num_denoising_queries"] :, :
        ]

        all_layers_outputs_3d = self.bbox3d_head(
            hidden_states_3d, ray_embeddings, depth_latents
        )

        return GroundingDINO3DOut(
            all_layers_cls_scores,
            all_layers_bbox_preds,
            all_layers_outputs_3d,
            enc_outputs_class,
            enc_outputs_coord,
            enc_outputs_3d,
            text_token_mask,
            dn_meta,
            positive_maps,
            depth_maps=depth_preds,
        )
|
| 480 |
+
|
| 481 |
+
def _forward_test(
    self,
    images: Tensor,
    input_texts: list[str] | None,
    text_prompt_mapping: list[dict[str, dict[str, str]]] | None,
    input_hw: list[tuple[int, int]],
    original_hw: list[tuple[int, int]],
    intrinsics: list[Tensor] | None,
    padding: list[list[int]] | None,
) -> Det3DOut:
    """Run open-vocabulary 3D detection inference.

    Tokenizes the text prompts, extracts visual/depth features once, runs
    the grounding transformer plus the 2D/3D heads, and decodes per-image
    detections. Two paths exist: a "chunked text" path (the prompt for the
    single image is a list of caption chunks, each processed separately and
    the detections concatenated) and the regular batched path.

    Args:
        images: Batched input images; shape (B, C, H, W).
        input_texts: Per-image text prompts; ignored when ``self.texts``
            is set.
        text_prompt_mapping: Optional per-image prompt remapping passed to
            ``get_tokens_positive_and_prompts``.
        input_hw: Per-image (height, width) of the resized inputs.
        original_hw: Per-image (height, width) of the original images.
        intrinsics: Per-image camera intrinsics.
        padding: Optional per-image (left, right, top, bottom) padding that
            was applied to reach ``batch_input_shape``.

    Returns:
        Det3DOut with 2D boxes, 3D boxes, scores, class ids, depth maps and
        category names per image.
    """
    batch_size = images.shape[0]

    # Tokenize each image's prompt: positive-token maps for score decoding,
    # the caption string(s) fed to the language model, and entity names.
    token_positive_maps = []
    text_prompts = []
    entities = []
    for i in range(batch_size):
        if self.texts is not None:
            # Fixed vocabulary overrides per-image prompts.
            text_prompt = self.texts
        else:
            text_prompt = input_texts[i]

        if text_prompt_mapping is not None:
            prompt_mapping = text_prompt_mapping[i]
        else:
            prompt_mapping = None

        token_positive_map, captions, _, _entities = (
            self.get_tokens_positive_and_prompts(
                text_prompt,
                self.custom_entities,
                text_prompt_mapping=prompt_mapping,
            )
        )

        token_positive_maps.append(token_positive_map)
        text_prompts.append(captions)
        entities.append(_entities)

    # image feature extraction
    batch_input_img_h, batch_input_img_w = images.shape[-2:]

    visual_feats, depth_feats = self._extract_image_features(images)

    batch_input_shape = (batch_input_img_h, batch_input_img_w)

    if isinstance(text_prompts[0], list):
        # Chunked-text path: the (single) image's vocabulary was split into
        # several caption chunks; each chunk is a separate forward pass.
        assert batch_size == 1, "Batch size should be 1 for chunked text."
        assert (
            self.cat_mapping is not None
        ), "Category mapping should be provided."

        boxes = []
        boxes3d = []
        scores = []
        class_ids = []
        categories = []
        for i, captions in enumerate(text_prompts[0]):
            # Bind per-chunk metadata before `i` is reused below.
            # NOTE(review): the inner depth loop shadows `i`; harmless here
            # because these are bound first, but worth renaming upstream.
            token_positive_map = token_positive_maps[0][i]
            cur_entities = entities[0][i]

            text_dict = self.language_model([captions])

            # text feature map layer
            if self.text_feat_map is not None:
                text_dict["embedded"] = self.text_feat_map(
                    text_dict["embedded"]
                )

            (
                feat_flatten,
                lvl_pos_embed_flatten,
                memory_mask,
                spatial_shapes,
                level_start_index,
                valid_ratios,
            ) = self.pre_transformer(
                # deepcopy: pre_transformer appears to consume/modify the
                # feature maps, and they are reused for every chunk.
                copy.deepcopy(visual_feats), input_hw, batch_input_shape
            )

            ray_embeddings = self.bbox3d_head.get_camera_embeddings(
                intrinsics, batch_input_shape
            )

            # Depth Head
            # NOTE(review): depth is recomputed (identically) per chunk;
            # could be hoisted out of the loop as an optimization.
            depth_preds, depth_latents = self.depth_head(
                depth_feats, intrinsics, batch_input_shape
            )

            depth_maps = []
            for i, depth_pred in enumerate(depth_preds):
                if padding is not None:
                    # Crop away the padded border before upsampling.
                    pad_left, pad_right, pad_top, pad_bottom = padding[i]

                    depth_pred = depth_pred[
                        pad_top : batch_input_img_h - pad_bottom,
                        pad_left : batch_input_img_w - pad_right,
                    ]

                # Resize the (H, W) depth map back to the original image
                # resolution; unsqueeze to (1, 1, H, W) for interpolate.
                depth_maps.append(
                    F.interpolate(
                        depth_pred.unsqueeze(0).unsqueeze(0),
                        size=original_hw[i],
                        mode="bilinear",
                        align_corners=False,
                        antialias=True,
                    )
                    .squeeze(0)
                    .squeeze(0)
                )

            (
                memory_text,
                text_token_mask,
                hidden_states,
                references,
            ) = self.forward_transformer(
                feat_flatten,
                lvl_pos_embed_flatten,
                memory_mask,
                spatial_shapes,
                level_start_index,
                valid_ratios,
                text_dict,
                ray_embeddings=ray_embeddings,
                depth_latents=depth_latents,
            )

            all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
                hidden_states, references, memory_text, text_token_mask
            )

            all_layers_outputs_3d = self.bbox3d_head(
                hidden_states, ray_embeddings, depth_latents=depth_latents
            )

            # Decode only the last decoder layer's predictions.
            cls_scores = all_layers_cls_scores[-1]
            bbox_preds = all_layers_bbox_preds[-1]
            boxes3d_preds = all_layers_outputs_3d[-1]

            cls_score = cls_scores[0]
            det_bboxes, det_scores, det_labels, det_bboxes3d = (
                self.roi2det3d(
                    cls_score,
                    bbox_preds[0],
                    token_positive_map,
                    input_hw[0],
                    original_hw[0],
                    boxes3d_preds[0],
                    intrinsics[0],
                    padding[0] if padding is not None else None,
                )
            )

            boxes.append(det_bboxes)
            scores.append(det_scores)
            boxes3d.append(det_bboxes3d)

            # Get the categories text and class ids
            # Chunk-local labels are remapped to global class ids through
            # the entity name and ``self.cat_mapping``.
            cur_class_ids = []
            cur_categories = []
            for label in det_labels:
                cur_class_ids.append(self.cat_mapping[cur_entities[label]])
                cur_categories.append(cur_entities[label])

            class_ids.append(cur_class_ids)
            categories.append(cur_categories)

        # Merge all chunk detections into a single-image result.
        boxes = [torch.cat([b for b in boxes])]
        boxes3d = [torch.cat([b for b in boxes3d])]
        scores = [torch.cat([s for s in scores])]
        class_ids = [sum(class_ids, [])]
        categories = [sum(categories, [])]
    else:
        # extract text feats
        text_dict = self.language_model(list(text_prompts))

        # text feature map layer
        if self.text_feat_map is not None:
            text_dict["embedded"] = self.text_feat_map(
                text_dict["embedded"]
            )

        (
            feat_flatten,
            lvl_pos_embed_flatten,
            memory_mask,
            spatial_shapes,
            level_start_index,
            valid_ratios,
        ) = self.pre_transformer(visual_feats, input_hw, batch_input_shape)

        ray_embeddings = self.bbox3d_head.get_camera_embeddings(
            intrinsics, batch_input_shape
        )

        # Depth Head
        depth_preds, depth_latents = self.depth_head(
            depth_feats, intrinsics, batch_input_shape
        )

        depth_maps = []
        for i, depth_pred in enumerate(depth_preds):
            if padding is not None:
                # Crop away the padded border before upsampling.
                pad_left, pad_right, pad_top, pad_bottom = padding[i]

                depth_pred = depth_pred[
                    pad_top : batch_input_img_h - pad_bottom,
                    pad_left : batch_input_img_w - pad_right,
                ]

            # Resize depth back to the original image resolution.
            depth_maps.append(
                F.interpolate(
                    depth_pred.unsqueeze(0).unsqueeze(0),
                    size=original_hw[i],
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                .squeeze(0)
                .squeeze(0)
            )

        (
            memory_text,
            text_token_mask,
            hidden_states,
            references,
        ) = self.forward_transformer(
            feat_flatten,
            lvl_pos_embed_flatten,
            memory_mask,
            spatial_shapes,
            level_start_index,
            valid_ratios,
            text_dict,
            ray_embeddings=ray_embeddings,
            depth_latents=depth_latents,
        )

        all_layers_cls_scores, all_layers_bbox_preds = self.bbox_head(
            hidden_states, references, memory_text, text_token_mask
        )

        all_layers_outputs_3d = self.bbox3d_head(
            hidden_states, ray_embeddings, depth_latents=depth_latents
        )

        # Decode only the last decoder layer's predictions.
        cls_scores = all_layers_cls_scores[-1]
        bbox_preds = all_layers_bbox_preds[-1]
        boxes3d_preds = all_layers_outputs_3d[-1]

        boxes = []
        boxes3d = []
        scores = []
        class_ids = []
        categories = []
        for i, bbox_pred in enumerate(bbox_preds):
            cls_score = cls_scores[i]
            det_bboxes, det_scores, det_labels, det_bboxes3d = (
                self.roi2det3d(
                    cls_score,
                    bbox_pred,
                    token_positive_maps[i],
                    input_hw[i],
                    original_hw[i],
                    boxes3d_preds[i],
                    intrinsics[i],
                    padding[i] if padding is not None else None,
                )
            )

            boxes.append(det_bboxes)
            scores.append(det_scores)
            class_ids.append(det_labels)
            boxes3d.append(det_bboxes3d)

            # Get the categories text
            cur_categories = []
            for label in det_labels:
                cur_categories.append(entities[i][label])

            categories.append(cur_categories)

    return Det3DOut(
        boxes,
        boxes3d,
        scores,
        class_ids,
        depth_maps=depth_maps,
        categories=categories,
    )
|
| 773 |
+
|
| 774 |
+
def forward(
    self,
    images: Tensor,
    input_hw: list[tuple[int, int]],
    intrinsics: Tensor,
    boxes2d: Tensor | None = None,
    boxes2d_classes: Tensor | None = None,
    original_hw: list[tuple[int, int]] | None = None,
    input_texts: Sequence[str] | str | None = None,
    input_tokens_positive: list[dict[int, list[int, int]]] | None = None,
    text_prompt_mapping: dict[str, dict[str, str]] | None = None,
    padding: list[list[int]] | None = None,
    **kwargs,
) -> GroundingDINO3DOut | Det3DOut:
    """Dispatch to the training or inference forward pass.

    In training mode, ground-truth 2D boxes and classes are required and
    the raw multi-layer outputs are returned for loss computation. In eval
    mode, decoded detections are returned instead.
    """
    # Inference path: decode detections at original image resolution.
    if not self.training:
        assert original_hw is not None
        return self._forward_test(
            images,
            input_texts,
            text_prompt_mapping,
            input_hw,
            original_hw,
            intrinsics,
            padding,
        )

    # Training path: ground truth is mandatory.
    assert boxes2d is not None and boxes2d_classes is not None
    return self._forward_train(
        images,
        input_texts,
        boxes2d,
        boxes2d_classes,
        input_hw,
        intrinsics,
        input_tokens_positive=input_tokens_positive,
        padding=padding,
        **kwargs,
    )
|
opendet3d/model/language/__init__.py
ADDED
|
File without changes
|
opendet3d/model/language/mm_bert.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BERT model from mmdetection."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
from collections.abc import Sequence
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from transformers import AutoTokenizer, BertConfig
|
| 11 |
+
from transformers import BertModel as HFBertModel
|
| 12 |
+
|
| 13 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def generate_masks_with_special_tokens_and_transfer_map(
    tokenized, special_tokens_list
):
    """Build per-subsentence attention masks and position ids.

    Special tokens (e.g. ``[CLS]``, ``[SEP]``, ``.``) delimit subsentences.
    Tokens may only attend to tokens within the same delimited span, and
    position ids restart from 0 after every special token.

    Args:
        tokenized: Tokenizer output; ``tokenized["input_ids"]`` has shape
            [bs, num_token].
        special_tokens_list: Token ids treated as subsentence delimiters.

    Returns:
        Tuple(Tensor, Tensor):
            - attention_mask: bool mask of shape [bs, num_token, num_token];
              True only within each subsentence (and on the diagonal).
            - position_ids: long tensor of shape [bs, num_token]; restarts
              at 0 whenever a special token is encountered.
    """
    ids = tokenized["input_ids"]
    batch, seq_len = ids.shape
    device = ids.device

    # Boolean map marking which positions hold any special token.
    is_special = torch.zeros((batch, seq_len), device=device).bool()
    for token_id in special_tokens_list:
        is_special |= ids == token_id

    # Start from the diagonal: every token attends to itself.
    attention_mask = (
        torch.eye(seq_len, device=device)
        .bool()
        .unsqueeze(0)
        .repeat(batch, 1, 1)
    )
    position_ids = torch.zeros(
        (batch, seq_len), device=device, dtype=torch.long
    )

    # Walk the special-token positions in order, opening up the attention
    # block for each span (previous special token, current special token].
    prev_col = 0
    for sample, col in is_special.nonzero().tolist():
        if col == 0 or col == seq_len - 1:
            # Sentence boundary token: attends only to itself, position 0.
            attention_mask[sample, col, col] = True
            position_ids[sample, col] = 0
        else:
            span = slice(prev_col + 1, col + 1)
            attention_mask[sample, span, span] = True
            position_ids[sample, span] = torch.arange(
                0, col - prev_col, device=device
            )
        prev_col = col

    return attention_mask, position_ids
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class BertModel(nn.Module):
    """BERT model for language embedding only encoder.

    Wraps a HuggingFace tokenizer plus a :class:`BertEncoder` backbone and
    produces the language feature dict consumed by Grounding DINO.

    Args:
        name (str, optional): name of the pretrained BERT model from
            HuggingFace. Defaults to bert-base-uncased.
        max_tokens (int, optional): maximum number of tokens to be
            used for BERT. Defaults to 256.
        pad_to_max (bool, optional): whether to pad the tokens to max_tokens.
            Defaults to True.
        use_sub_sentence_represent (bool, optional): whether to use sub
            sentence represent introduced in `Grounding DINO
            <https://arxiv.org/abs/2303.05499>`. Defaults to False.
        special_tokens_list (list, optional): special tokens used to split
            subsentence. It cannot be None when `use_sub_sentence_represent`
            is True. Defaults to None.
        add_pooling_layer (bool, optional): whether to adding pooling
            layer in bert encoder. Defaults to False.
        num_layers_of_embedded (int, optional): number of layers of
            the embedded model. Defaults to 1.
        use_checkpoint (bool, optional): whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(
        self,
        name: str = "bert-base-uncased",
        max_tokens: int = 256,
        pad_to_max: bool = True,
        use_sub_sentence_represent: bool = False,
        special_tokens_list: list | None = None,
        add_pooling_layer: bool = False,
        num_layers_of_embedded: int = 1,
        use_checkpoint: bool = False,
    ) -> None:
        """Create the BERT model."""
        super().__init__()
        self.max_tokens = max_tokens
        self.pad_to_max = pad_to_max

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        # nn.Sequential with a named "body" entry to match the checkpoint
        # key layout (language_backbone.body.*).
        self.language_backbone = nn.Sequential(
            OrderedDict(
                [
                    (
                        "body",
                        BertEncoder(
                            name,
                            add_pooling_layer=add_pooling_layer,
                            num_layers_of_embedded=num_layers_of_embedded,
                            use_checkpoint=use_checkpoint,
                        ),
                    )
                ]
            )
        )

        self.use_sub_sentence_represent = use_sub_sentence_represent
        if self.use_sub_sentence_represent:
            assert (
                special_tokens_list is not None
            ), "special_tokens should not be None \
                if use_sub_sentence_represent is True"

            # Resolve delimiter strings (e.g. "[CLS]", ".") to token ids.
            self.special_tokens = self.tokenizer.convert_tokens_to_ids(
                special_tokens_list
            )

    def forward(self, captions: Sequence[str]) -> dict:
        """Forward function.

        Tokenizes *captions* and runs the BERT backbone, returning the
        language feature dict (keys: "embedded", "masks", "hidden", plus
        "position_ids" and "text_token_mask" in sub-sentence mode).
        """
        # Place tokenized tensors on the same device as the backbone.
        device = next(self.language_backbone.parameters()).device
        tokenized = self.tokenizer.batch_encode_plus(
            captions,
            max_length=self.max_tokens,
            padding="max_length" if self.pad_to_max else "longest",
            return_special_tokens_mask=True,
            return_tensors="pt",
            truncation=True,
        ).to(device)
        input_ids = tokenized.input_ids
        if self.use_sub_sentence_represent:
            # Restrict attention to within-subsentence token pairs and
            # restart position ids at each special token.
            attention_mask, position_ids = (
                generate_masks_with_special_tokens_and_transfer_map(
                    tokenized, self.special_tokens
                )
            )
            token_type_ids = tokenized["token_type_ids"]

        else:
            attention_mask = tokenized.attention_mask
            position_ids = None
            token_type_ids = None

        tokenizer_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "token_type_ids": token_type_ids,
        }
        language_dict_features = self.language_backbone(tokenizer_input)
        if self.use_sub_sentence_represent:
            language_dict_features["position_ids"] = position_ids
            # Padding mask from the tokenizer (1 = real token), distinct
            # from the square sub-sentence attention mask above.
            language_dict_features["text_token_mask"] = (
                tokenized.attention_mask.bool()
            )
        return language_dict_features
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class BertEncoder(nn.Module):
    """BERT encoder for language embedding.

    Args:
        name (str): name of the pretrained BERT model from HuggingFace.
            Defaults to bert-base-uncased.
        add_pooling_layer (bool): whether to add a pooling layer.
        num_layers_of_embedded (int): number of layers of the embedded model.
            Defaults to 1.
        use_checkpoint (bool): whether to use gradient checkpointing.
            Defaults to False.
    """

    def __init__(
        self,
        name: str,
        add_pooling_layer: bool = False,
        num_layers_of_embedded: int = 1,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        config = BertConfig.from_pretrained(name)
        config.gradient_checkpointing = use_checkpoint

        # Silence the (noisy) transformers loggers during weight loading.
        loggers = [
            logging.getLogger(name) for name in logging.root.manager.loggerDict
        ]
        for logger in loggers:
            if "transformers" in logger.name.lower():
                logger.setLevel(logging.ERROR)

        # only encoder
        # "eager" attention so the custom square attention mask from
        # generate_masks_with_special_tokens_and_transfer_map is honored.
        self.model = HFBertModel.from_pretrained(
            name,
            add_pooling_layer=add_pooling_layer,
            config=config,
            attn_implementation="eager",
        )

        self.language_dim = config.hidden_size
        self.num_layers_of_embedded = num_layers_of_embedded

    def forward(self, x: dict) -> dict:
        """Run BERT and average the last ``num_layers_of_embedded`` layers.

        Args:
            x: Dict with "input_ids", "attention_mask", "position_ids",
                and "token_type_ids".

        Returns:
            Dict with "embedded" (masked averaged features), "masks"
            (the attention mask passed in), and "hidden" (last layer).
        """
        mask = x["attention_mask"]

        outputs = self.model(
            input_ids=x["input_ids"],
            attention_mask=mask,
            position_ids=x["position_ids"],
            token_type_ids=x["token_type_ids"],
            output_hidden_states=True,
        )

        # outputs has 13 layers, 1 input layer and 12 hidden layers
        encoded_layers = outputs.hidden_states[1:]
        features = torch.stack(
            encoded_layers[-self.num_layers_of_embedded :], 1
        ).mean(1)
        # language embedding has shape [len(phrase), seq_len, language_dim]
        # NOTE(review): mean(1) already divides by num_layers_of_embedded, so
        # this divides a second time — presumably inherited from the
        # mmdet/GLIP reference implementation and kept for checkpoint
        # compatibility; confirm before changing.
        features = features / self.num_layers_of_embedded
        if mask.dim() == 2:
            # Zero out padded token embeddings (2-D padding mask case).
            embedded = features * mask.unsqueeze(-1).float()
        else:
            # Square (sub-sentence) mask: leave features unmasked here.
            embedded = features

        results = {
            "embedded": embedded,
            "masks": mask,
            "hidden": encoded_layers[-1],
        }
        return results
|
opendet3d/op/__init__.py
ADDED
|
File without changes
|
opendet3d/op/base/__init__.py
ADDED
|
File without changes
|
opendet3d/op/base/swin.py
ADDED
|
@@ -0,0 +1,870 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|