Upload 169 files
This view is limited to 50 files because it contains too many changes.
- .gitignore +57 -0
- app.py +19 -0
- mapanything/__init__.py +0 -0
- mapanything/__pycache__/__init__.cpython-312.pyc +0 -0
- mapanything/datasets/__init__.py +177 -0
- mapanything/datasets/base/__init__.py +0 -0
- mapanything/datasets/base/base_dataset.py +697 -0
- mapanything/datasets/base/batched_sampler.py +431 -0
- mapanything/datasets/base/easy_dataset.py +478 -0
- mapanything/datasets/utils/__init__.py +0 -0
- mapanything/datasets/utils/data_splits.py +1734 -0
- mapanything/datasets/wai/__init__.py +0 -0
- mapanything/datasets/wai/ase.py +294 -0
- mapanything/datasets/wai/blendedmvs.py +313 -0
- mapanything/datasets/wai/dl3dv.py +356 -0
- mapanything/datasets/wai/dynamicreplica.py +297 -0
- mapanything/datasets/wai/eth3d.py +277 -0
- mapanything/datasets/wai/megadepth.py +314 -0
- mapanything/datasets/wai/mpsd.py +311 -0
- mapanything/datasets/wai/mvs_synth.py +308 -0
- mapanything/datasets/wai/paralleldomain4d.py +309 -0
- mapanything/datasets/wai/sailvos3d.py +308 -0
- mapanything/datasets/wai/scannetpp.py +307 -0
- mapanything/datasets/wai/spring.py +316 -0
- mapanything/datasets/wai/tav2_wb.py +328 -0
- mapanything/datasets/wai/unrealstereo4k.py +309 -0
- mapanything/models/__init__.py +190 -0
- mapanything/models/__pycache__/__init__.cpython-312.pyc +0 -0
- mapanything/models/external/README.md +5 -0
- mapanything/models/external/__init__.py +0 -0
- mapanything/models/external/anycalib/__init__.py +100 -0
- mapanything/models/external/dinov2/__init__.py +6 -0
- mapanything/models/external/dinov2/hub/__init__.py +4 -0
- mapanything/models/external/dinov2/hub/backbones.py +183 -0
- mapanything/models/external/dinov2/hub/utils.py +42 -0
- mapanything/models/external/dinov2/layers/__init__.py +14 -0
- mapanything/models/external/dinov2/layers/attention.py +90 -0
- mapanything/models/external/dinov2/layers/block.py +290 -0
- mapanything/models/external/dinov2/layers/dino_head.py +67 -0
- mapanything/models/external/dinov2/layers/drop_path.py +36 -0
- mapanything/models/external/dinov2/layers/layer_scale.py +26 -0
- mapanything/models/external/dinov2/layers/mlp.py +40 -0
- mapanything/models/external/dinov2/layers/patch_embed.py +100 -0
- mapanything/models/external/dinov2/layers/swiglu_ffn.py +71 -0
- mapanything/models/external/dinov2/models/__init__.py +44 -0
- mapanything/models/external/dinov2/models/vision_transformer.py +448 -0
- mapanything/models/external/dinov2/utils/__init__.py +4 -0
- mapanything/models/external/dinov2/utils/cluster.py +102 -0
- mapanything/models/external/dinov2/utils/config.py +74 -0
- mapanything/models/external/dinov2/utils/dtype.py +38 -0
.gitignore
ADDED
@@ -0,0 +1,57 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Virtual Environment
+venv/
+ENV/
+env/
+.venv
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+.DS_Store
+
+# HuggingFace Space temporary files
+input_images_*/
+*.glb
+*.npz
+flagged/
+
+# Local model cache (switched to HuggingFace)
+models/
+
+# Logs
+*.log
+logs/
+
+# Test files
+.pytest_cache/
+.coverage
+htmlcov/
+
+# System files
+Thumbs.db
app.py
ADDED
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+HuggingFace Space entry point.
+Directly imports and runs gradio_app_v8.
+"""
+
+import sys
+from pathlib import Path
+
+# Add the scripts directory to the Python path
+scripts_dir = Path(__file__).parent / "scripts"
+sys.path.insert(0, str(scripts_dir))
+
+# Import and run the main application
+if __name__ == "__main__":
+    # Import gradio_app_v8 (it launches the demo automatically)
+    import gradio_app_v8
+
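Note that this entry point relies on an import side effect: `import gradio_app_v8` only works if that module calls `demo.launch()` at module scope. `scripts/gradio_app_v8.py` is not among the files shown in this view, so the following is just a minimal sketch of the shape such a module would need; the handler and interface are placeholders, not the real MapAnything pipeline.

# Hypothetical sketch of scripts/gradio_app_v8.py (not shown in this diff).
import gradio as gr

def run_placeholder(name: str) -> str:
    # Stand-in for the real reconstruction handler.
    return f"Hello, {name}!"

demo = gr.Interface(fn=run_placeholder, inputs="text", outputs="text")

# Executed at import time, so `import gradio_app_v8` in app.py starts the app.
demo.launch()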
mapanything/__init__.py
ADDED
File without changes
mapanything/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (154 Bytes).
mapanything/datasets/__init__.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+MapAnything Datasets
+"""
+
+import torch
+
+from mapanything.datasets.wai.ase import ASEWAI  # noqa
+from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI  # noqa
+from mapanything.datasets.wai.dl3dv import DL3DVWAI  # noqa
+from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI  # noqa
+from mapanything.datasets.wai.eth3d import ETH3DWAI  # noqa
+from mapanything.datasets.wai.megadepth import MegaDepthWAI  # noqa
+from mapanything.datasets.wai.mpsd import MPSDWAI  # noqa
+from mapanything.datasets.wai.mvs_synth import MVSSynthWAI  # noqa
+from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI  # noqa
+from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI  # noqa
+from mapanything.datasets.wai.scannetpp import ScanNetPPWAI  # noqa
+from mapanything.datasets.wai.spring import SpringWAI  # noqa
+from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI  # noqa
+from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI  # noqa
+from mapanything.utils.train_tools import get_rank, get_world_size
+
+
+def get_test_data_loader(
+    dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
+):
+    "Get simple PyTorch dataloader corresponding to the testing dataset"
+    # PyTorch dataset
+    if isinstance(dataset, str):
+        dataset = eval(dataset)
+
+    world_size = get_world_size()
+    rank = get_rank()
+
+    if torch.distributed.is_initialized():
+        sampler = torch.utils.data.DistributedSampler(
+            dataset,
+            num_replicas=world_size,
+            rank=rank,
+            shuffle=shuffle,
+            drop_last=drop_last,
+        )
+    elif shuffle:
+        sampler = torch.utils.data.RandomSampler(dataset)
+    else:
+        sampler = torch.utils.data.SequentialSampler(dataset)
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_mem,
+        drop_last=drop_last,
+    )
+
+    return data_loader
+
+
+def get_test_many_ar_data_loader(
+    dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
+):
+    "Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios"
+    # PyTorch dataset
+    if isinstance(dataset, str):
+        dataset = eval(dataset)
+
+    world_size = get_world_size()
+    rank = get_rank()
+
+    # Get BatchedMultiFeatureRandomSampler
+    sampler = dataset.make_sampler(
+        batch_size,
+        shuffle=True,
+        world_size=world_size,
+        rank=rank,
+        drop_last=drop_last,
+        use_dynamic_sampler=False,
+    )
+
+    # Init the data loader
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_mem,
+        drop_last=drop_last,
+    )
+
+    return data_loader
+
+
+class DynamicBatchDatasetWrapper:
+    """
+    Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.
+
+    The dynamic sampler returns batches (lists of tuples) instead of individual samples.
+    This wrapper ensures that the underlying dataset's __getitem__ method gets called
+    with individual tuples as expected.
+    """
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, batch_indices):
+        """
+        Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler.
+
+        Args:
+            batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]
+
+        Returns:
+            List of samples from the underlying dataset
+        """
+        if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0:
+            # If it's a batch (list of tuples), process each item
+            if isinstance(batch_indices[0], (list, tuple)):
+                return [self.dataset[idx] for idx in batch_indices]
+            else:
+                # Single tuple, call dataset directly
+                return self.dataset[batch_indices]
+        else:
+            # Fallback for single index
+            return self.dataset[batch_indices]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getattr__(self, name):
+        # Delegate all other attributes to the wrapped dataset
+        return getattr(self.dataset, name)
+
+
+def get_train_data_loader(
+    dataset,
+    max_num_of_imgs_per_gpu,
+    num_workers=8,
+    shuffle=True,
+    drop_last=True,
+    pin_mem=True,
+):
+    "Dynamic PyTorch dataloader corresponding to the training dataset"
+    # PyTorch dataset
+    if isinstance(dataset, str):
+        dataset = eval(dataset)
+
+    world_size = get_world_size()
+    rank = get_rank()
+
+    # Get DynamicBatchedMultiFeatureRandomSampler
+    batch_sampler = dataset.make_sampler(
+        shuffle=shuffle,
+        world_size=world_size,
+        rank=rank,
+        drop_last=drop_last,
+        max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
+        use_dynamic_sampler=True,
+    )
+
+    # Wrap the dataset to handle batch format from dynamic sampler
+    wrapped_dataset = DynamicBatchDatasetWrapper(dataset)
+
+    # Init the dynamic data loader
+    data_loader = torch.utils.data.DataLoader(
+        wrapped_dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        pin_memory=pin_mem,
+    )
+
+    return data_loader
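The train loader differs from the test loaders in one key way: the dynamic sampler is passed as `batch_sampler`, so each "index" the DataLoader hands to the dataset is already a whole batch, and `DynamicBatchDatasetWrapper` unpacks it. A minimal sketch of that dispatch, assuming the mapanything package and its imports resolve; `ToyDataset` is hypothetical and its indices mimic the (sample_idx, feat_idx, ...) tuples the samplers yield.

from mapanything.datasets import DynamicBatchDatasetWrapper

class ToyDataset:
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        sample_idx, feat_idx = idx
        return {"sample": sample_idx, "feature": feat_idx}

wrapped = DynamicBatchDatasetWrapper(ToyDataset())
print(wrapped[(3, 1)])            # single tuple -> one sample
print(wrapped[[(0, 2), (5, 2)]])  # list of tuples -> list of samples
print(len(wrapped))               # __len__ delegates to the wrapped dataset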
mapanything/datasets/base/__init__.py
ADDED
File without changes
mapanything/datasets/base/base_dataset.py
ADDED
@@ -0,0 +1,697 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+Base class for MapAnything datasets.
+"""
+
+from typing import List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+import torchvision.transforms as tvf
+from scipy.spatial.transform import Rotation
+
+from mapanything.datasets.base.easy_dataset import EasyDataset
+from mapanything.utils.cropping import (
+    bbox_from_intrinsics_in_out,
+    camera_matrix_of_crop,
+    crop_image_and_other_optional_info,
+    rescale_image_and_other_optional_info,
+)
+from mapanything.utils.geometry import (
+    depthmap_to_camera_coordinates,
+    get_absolute_pointmaps_and_rays_info,
+)
+from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
+
+
+class BaseDataset(EasyDataset):
+    """
+    Define all basic options.
+
+    Usage:
+        class MyDataset(BaseDataset):
+            def _get_views(self, idx):
+                views = []
+                views.append(dict(img=, ...))
+                return views
+    """
+
+    def __init__(
+        self,
+        num_views: int,
+        variable_num_views: bool = False,
+        split: str = None,
+        covisibility_thres: float = None,
+        resolution: Union[int, Tuple[int, int], List[Tuple[int, int]]] = None,
+        principal_point_centered: bool = False,
+        transform: str = None,
+        data_norm_type: str = None,
+        aug_crop: int = 0,
+        seed: int = None,
+        max_num_retries: int = 5,
+    ):
+        """
+        PyTorch dataset for multi-view images sampled from scenes, where the images form a single connected component.
+
+        Args:
+            num_views (int): Number of views.
+            variable_num_views (bool): If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2.
+                On by default for N-view train dataloader (hydra config).
+            split (str): 'train', 'val', 'test', etc.
+            covisibility_thres (float): Covisibility (%) threshold to determine if another image is a neighbor or not
+            resolution (int or tuple or list of tuples): Resolution of the images
+            principal_point_centered (bool): If True, the principal point is centered in the image.
+            transform (str): Transform to apply to the images. Options:
+                - 'colorjitter+grayscale+gaublur':
+                    tvf.Compose([
+                        tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+                        tvf.RandomGrayscale(p=0.05),
+                        tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+                    ]) after ImgNorm
+                - 'colorjitter': tvf.ColorJitter(0.5, 0.5, 0.5, 0.1) after ImgNorm
+                - 'imgnorm': ImgNorm only
+            data_norm_type (str): Image normalization type.
+                For options, see UniCeption image normalization dict.
+            aug_crop (int): Augment crop. If int greater than 0, indicates the number of pixels to increase in target resolution.
+            seed (int): Seed for the random number generator.
+            max_num_retries (int): Maximum number of retries for loading a different sample from the dataset, if provided idx fails.
+        """
+        self.num_views = num_views
+        self.variable_num_views = variable_num_views
+        self.num_views_min = 2
+        self.split = split
+        self.covisibility_thres = covisibility_thres
+        self._set_resolutions(resolution)
+        self.principal_point_centered = principal_point_centered
+
+        # Update the number of views if necessary and make it a list if variable_num_views is True
+        if self.variable_num_views and self.num_views > self.num_views_min:
+            self.num_views = list(range(self.num_views_min, self.num_views + 1))
+
+        # Initialize the image normalization type
+        if data_norm_type in IMAGE_NORMALIZATION_DICT.keys():
+            self.data_norm_type = data_norm_type
+            image_norm = IMAGE_NORMALIZATION_DICT[data_norm_type]
+            ImgNorm = tvf.Compose(
+                [
+                    tvf.ToTensor(),
+                    tvf.Normalize(mean=image_norm.mean, std=image_norm.std),
+                ]
+            )
+        elif data_norm_type == "identity":
+            self.data_norm_type = data_norm_type
+            ImgNorm = tvf.Compose([tvf.ToTensor()])
+        else:
+            raise ValueError(
+                f"Unknown data_norm_type: {data_norm_type}. Available options: identity or {list(IMAGE_NORMALIZATION_DICT.keys())}"
+            )
+
+        # Initialize torchvision transforms
+        if transform == "imgnorm":
+            self.transform = ImgNorm
+        elif transform == "colorjitter":
+            self.transform = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
+        elif transform == "colorjitter+grayscale+gaublur":
+            self.transform = tvf.Compose(
+                [
+                    tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+                    tvf.RandomGrayscale(p=0.05),
+                    tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+                    ImgNorm,
+                ]
+            )
+        else:
+            raise ValueError(
+                'Unknown transform. Available options: "imgnorm", "colorjitter", "colorjitter+grayscale+gaublur"'
+            )
+
+        # Initialize the augmentation parameters
+        self.aug_crop = aug_crop
+
+        # Initialize the seed for the random number generator
+        self.seed = seed
+        self._seed_offset = 0
+
+        # Initialize the maximum number of retries for loading a different sample from the dataset, if the first idx fails
+        self.max_num_retries = max_num_retries
+
+        # Initialize the dataset type flags
+        self.is_metric_scale = False  # by default a dataset is not metric scale, subclasses can overwrite this
+        self.is_synthetic = False  # by default a dataset is not synthetic, subclasses can overwrite this
+
+    def _load_data(self):
+        self.scenes = []
+        self.num_of_scenes = len(self.scenes)
+
+    def __len__(self):
+        "Length of the dataset is determined by the number of scenes in the dataset split"
+        return self.num_of_scenes
+
+    def get_stats(self):
+        "Get the number of scenes in the dataset split"
+        return f"{self.num_of_scenes} scenes"
+
+    def __repr__(self):
+        resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
+        return (
+            f"""{type(self).__name__}({self.get_stats()},
+            {self.num_views=}
+            {self.split=},
+            {self.seed=},
+            resolutions={resolutions_str},
+            {self.transform=})""".replace("self.", "")
+            .replace("\n", "")
+            .replace(" ", "")
+        )
+
+    def _get_views(self, idx, num_views_to_sample, resolution):
+        raise NotImplementedError()
+
+    def _set_seed_offset(self, idx):
+        """
+        Set the seed offset. This is directly added to self.seed when setting the random seed.
+        """
+        self._seed_offset = idx
+
+    def _set_resolutions(self, resolutions):
+        assert resolutions is not None, "undefined resolution"
+
+        if isinstance(resolutions, int):
+            resolutions = [resolutions]
+        elif isinstance(resolutions, tuple):
+            resolutions = [resolutions]
+        elif isinstance(resolutions, list):
+            assert all(isinstance(res, tuple) for res in resolutions), (
+                f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
+            )
+        else:
+            raise ValueError(
+                f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
+            )
+
+        self._resolutions = []
+        for resolution in resolutions:
+            if isinstance(resolution, int):
+                width = height = resolution
+            else:
+                width, height = resolution
+            assert isinstance(width, int), (
+                f"Bad type for {width=} {type(width)=}, should be int"
+            )
+            assert isinstance(height, int), (
+                f"Bad type for {height=} {type(height)=}, should be int"
+            )
+            self._resolutions.append((width, height))
+
+    def _crop_resize_if_necessary(
+        self,
+        image,
+        resolution,
+        depthmap,
+        intrinsics,
+        additional_quantities=None,
+    ):
+        """
+        Process an image by downsampling and cropping as needed to match the target resolution.
+
+        This method performs the following operations:
+        1. Converts the image to PIL.Image if necessary
+        2. Crops the image centered on the principal point if requested
+        3. Downsamples the image using high-quality Lanczos filtering
+        4. Performs final cropping to match the target resolution
+
+        Args:
+            image (numpy.ndarray or PIL.Image.Image): Input image to be processed
+            resolution (tuple): Target resolution as (width, height)
+            depthmap (numpy.ndarray): Depth map corresponding to the image
+            intrinsics (numpy.ndarray): Camera intrinsics matrix (3x3)
+            additional_quantities (dict, optional): Additional image-related data to be processed
+                alongside the main image with nearest interpolation. Defaults to None.
+
+        Returns:
+            tuple: Processed image, depthmap, and updated intrinsics matrix.
+                If additional_quantities is provided, it returns those as well.
+        """
+        if not isinstance(image, PIL.Image.Image):
+            image = PIL.Image.fromarray(image)
+
+        # Cropping centered on the principal point if necessary
+        if self.principal_point_centered:
+            W, H = image.size
+            cx, cy = intrinsics[:2, 2].round().astype(int)
+            if cx < 0 or cx >= W or cy < 0 or cy >= H:
+                # Skip centered cropping if principal point is outside image bounds
+                pass
+            else:
+                min_margin_x = min(cx, W - cx)
+                min_margin_y = min(cy, H - cy)
+                left, top = cx - min_margin_x, cy - min_margin_y
+                right, bottom = cx + min_margin_x, cy + min_margin_y
+                crop_bbox = (left, top, right, bottom)
+                # Only perform the centered crop if the crop_bbox is larger than the target resolution
+                crop_width = right - left
+                crop_height = bottom - top
+                if crop_width > resolution[0] and crop_height > resolution[1]:
+                    image, depthmap, intrinsics, additional_quantities = (
+                        crop_image_and_other_optional_info(
+                            image=image,
+                            crop_bbox=crop_bbox,
+                            depthmap=depthmap,
+                            camera_intrinsics=intrinsics,
+                            additional_quantities=additional_quantities,
+                        )
+                    )
+
+        # Get the target resolution for re-scaling
+        target_rescale_resolution = np.array(resolution)
+        if self.aug_crop > 1:
+            target_rescale_resolution += self._rng.integers(0, self.aug_crop)
+
+        # High-quality Lanczos down-scaling if necessary
+        image, depthmap, intrinsics, additional_quantities = (
+            rescale_image_and_other_optional_info(
+                image=image,
+                output_resolution=target_rescale_resolution,
+                depthmap=depthmap,
+                camera_intrinsics=intrinsics,
+                additional_quantities_to_be_resized_with_nearest=additional_quantities,
+            )
+        )
+
+        # Actual cropping (if necessary)
+        new_intrinsics = camera_matrix_of_crop(
+            input_camera_matrix=intrinsics,
+            input_resolution=image.size,
+            output_resolution=resolution,
+            offset_factor=0.5,
+        )
+        crop_bbox = bbox_from_intrinsics_in_out(
+            input_camera_matrix=intrinsics,
+            output_camera_matrix=new_intrinsics,
+            output_resolution=resolution,
+        )
+        image, depthmap, new_intrinsics, additional_quantities = (
+            crop_image_and_other_optional_info(
+                image=image,
+                crop_bbox=crop_bbox,
+                depthmap=depthmap,
+                camera_intrinsics=intrinsics,
+                additional_quantities=additional_quantities,
+            )
+        )
+
+        # Return the output
+        if additional_quantities is not None:
+            return image, depthmap, new_intrinsics, additional_quantities
+        else:
+            return image, depthmap, new_intrinsics
+
+    def _random_walk_sampling(
+        self,
+        scene_pairwise_covisibility,
+        num_of_samples,
+        max_retries=4,
+        use_bidirectional_covis=True,
+    ):
+        """
+        Randomly samples S indices from an N x N covisibility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected.
+        If the current node has no new unvisited neighbors, backtracking occurs.
+        Retries with different starting indices if the desired number of samples is not reached, excluding previously visited components.
+
+        Args:
+            scene_pairwise_covisibility : np.ndarray (mmap)
+                N x N covisibility matrix for the scene, where N is the number of views in the scene.
+            num_of_samples : int
+                The desired number of nodes to sample (num_of_samples < N).
+            max_retries : int
+                The maximum number of retries with different starting indices.
+            use_bidirectional_covis : bool
+                Whether to compute bidirectional covisibility by averaging row and column values.
+                If False, uses only row access (faster for large memory-mapped arrays).
+                Defaults to True.
+
+        Returns:
+            np.ndarray
+                An array of sampled indices forming a connected subgraph.
+        """
+        excluded_nodes = set()
+        best_walk = []  # To keep track of the best walk found
+        for _ in range(max_retries):
+            visited = set()
+            walk = []  # List to store the random walk sampling order
+            stack = []  # Stack for backtracking
+
+            # Choose a random starting index that is not in the excluded set
+            all_nodes = set(range(len(scene_pairwise_covisibility)))
+            available_nodes = list(all_nodes - excluded_nodes)
+            if not available_nodes:
+                break  # No more nodes to try
+            start = self._rng.choice(available_nodes)
+            walk.append(start)
+            visited.add(start)
+            stack.append(start)
+
+            # Continue until we have sampled S indices or all expandable nodes are exhausted
+            while len(walk) < num_of_samples and stack:
+                current = stack[-1]
+                # Get the pairwise covisibility for the current node
+                if use_bidirectional_covis:
+                    # Use bidirectional covisibility (slower for large memory-mapped arrays)
+                    pairwise_covisibility = (
+                        scene_pairwise_covisibility[current, :]
+                        + scene_pairwise_covisibility[:, current].T
+                    ) / 2
+                else:
+                    # Use only row access (faster for large memory-mapped arrays)
+                    pairwise_covisibility = scene_pairwise_covisibility[current, :]
+                # Normalize the covisibility using self covisibility
+                pairwise_covisibility = pairwise_covisibility / (
+                    pairwise_covisibility[current] + 1e-8
+                )
+                # Assign overlap score of zero to self-pairs
+                pairwise_covisibility[current] = 0
+                # Threshold the covisibility to get adjacency list for the current node
+                adjacency_list_for_current = (
+                    pairwise_covisibility > self.covisibility_thres
+                ).astype(int)
+                adjacency_list_for_current = np.flatnonzero(adjacency_list_for_current)
+                # Get all unvisited neighbors
+                candidates = [
+                    idx for idx in adjacency_list_for_current if idx not in visited
+                ]  # Remove visited nodes
+                if candidates:
+                    # Randomly select one of the unvisited overlapping neighbors
+                    next_node = self._rng.choice(candidates)
+                    walk.append(next_node)
+                    visited.add(next_node)
+                    stack.append(next_node)
+                else:
+                    # If no unvisited neighbor is available, backtrack
+                    stack.pop()
+
+            # Update the best walk if the current walk is larger
+            if len(walk) > len(best_walk):
+                best_walk = walk
+
+            # If we have enough samples, return the result
+            if len(walk) >= num_of_samples:
+                return np.array(walk)
+
+            # Add all visited nodes to the excluded set
+            excluded_nodes.update(visited)
+
+        # If all retries are exhausted and we still don't have enough samples, return the best walk found
+        return np.array(best_walk)
+
+    def _sample_view_indices(
+        self,
+        num_views_to_sample,
+        num_views_in_scene,
+        scene_pairwise_covisibility,
+        use_bidirectional_covis=True,
+    ):
+        """
+        Sample view indices from a scene based on the adjacency list and the number of views to sample.
+
+        Args:
+            num_views_to_sample (int): Number of views to sample.
+            num_views_in_scene (int): Total number of views available in the scene.
+            scene_pairwise_covisibility (np.ndarray): N x N covisibility matrix for the scene, where N is the number of views in the scene.
+            use_bidirectional_covis (bool): Whether to compute bidirectional covisibility by averaging row and column values.
+                If False, uses only row access (faster for large memory-mapped arrays).
+
+        Returns:
+            numpy.ndarray: Array of sampled view indices.
+        """
+        if num_views_to_sample == num_views_in_scene:
+            # Select all views in the scene
+            view_indices = self._rng.permutation(num_views_in_scene)
+        elif num_views_to_sample > num_views_in_scene:
+            # Select all views in the scene and repeat them to get the desired number of views
+            view_indices = self._rng.choice(
+                num_views_in_scene, size=num_views_to_sample, replace=True
+            )
+        else:
+            # Select a subset of single component connected views in the scene using random walk sampling
+            view_indices = self._random_walk_sampling(
+                scene_pairwise_covisibility,
+                num_views_to_sample,
+                use_bidirectional_covis=use_bidirectional_covis,
+            )
+            # If the required num of views can't be obtained even with 4 retries, repeat existing indices to get the desired number of views
+            if len(view_indices) < num_views_to_sample:
+                view_indices = self._rng.choice(
+                    view_indices, size=num_views_to_sample, replace=True
+                )
+
+        return view_indices
+
+    def _getitem_fn(self, idx):
+        if isinstance(idx, tuple):
+            # The idx is a tuple if specifying the aspect-ratio and/or the number of views
+            if isinstance(self.num_views, int):
+                idx, ar_idx = idx
+            else:
+                idx, ar_idx, num_views_to_sample_idx = idx
+        else:
+            assert len(self._resolutions) == 1
+            assert isinstance(self.num_views, int)
+            ar_idx = 0
+
+        # Setup the rng
+        if self.seed:  # reseed for each _getitem_fn
+            # Leads to deterministic sampling where repeating self.seed and self._seed_offset yields the same multi-view set again
+            # Scenes will be repeated if size of dataset is artificially increased using "N @" or "N *"
+            # When scenes are repeated, self._seed_offset is increased to ensure new multi-view sets
+            # This is useful for evaluation if the number of dataset scenes is < N, yet we want unique multi-view sets each iter
+            self._rng = np.random.default_rng(seed=self.seed + self._seed_offset + idx)
+        elif not hasattr(self, "_rng"):
+            seed = torch.initial_seed()  # this is different for each dataloader process
+            self._rng = np.random.default_rng(seed=seed)
+
+        # Get the views for the given index and check that the number of views is correct
+        resolution = self._resolutions[ar_idx]
+        if isinstance(self.num_views, int):
+            num_views_to_sample = self.num_views
+        else:
+            num_views_to_sample = self.num_views[num_views_to_sample_idx]
+        views = self._get_views(idx, num_views_to_sample, resolution)
+        if isinstance(self.num_views, int):
+            assert len(views) == self.num_views
+        else:
+            assert len(views) in self.num_views
+
+        for v, view in enumerate(views):
+            # Store the index and other metadata
+            view["idx"] = (idx, ar_idx, v)
+            view["is_metric_scale"] = self.is_metric_scale
+            view["is_synthetic"] = self.is_synthetic
+
+            # Check the depth, intrinsics, and pose data (also other data if present)
+            assert "camera_intrinsics" in view
+            assert "camera_pose" in view
+            assert np.isfinite(view["camera_pose"]).all(), (
+                f"NaN or infinite values in camera pose for view {view_name(view)}"
+            )
+            assert np.isfinite(view["depthmap"]).all(), (
+                f"NaN or infinite values in depthmap for view {view_name(view)}"
+            )
+            assert "valid_mask" not in view
+            assert "pts3d" not in view, (
+                f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+            )
+            if "prior_depth_z" in view:
+                assert np.isfinite(view["prior_depth_z"]).all(), (
+                    f"NaN or infinite values in prior_depth_z for view {view_name(view)}"
+                )
+            if "non_ambiguous_mask" in view:
+                assert np.isfinite(view["non_ambiguous_mask"]).all(), (
+                    f"NaN or infinite values in non_ambiguous_mask for view {view_name(view)}"
+                )
+
+            # Encode the image
+            width, height = view["img"].size
+            view["true_shape"] = np.int32((height, width))
+            view["img"] = self.transform(view["img"])
+            view["data_norm_type"] = self.data_norm_type
+
+            # Compute the pointmaps, raymap and depth along ray
+            (
+                pts3d,
+                valid_mask,
+                ray_origins_world,
+                ray_directions_world,
+                depth_along_ray,
+                ray_directions_cam,
+                pts3d_cam,
+            ) = get_absolute_pointmaps_and_rays_info(**view)
+            view["pts3d"] = pts3d
+            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+            view["depth_along_ray"] = depth_along_ray
+            view["ray_directions_cam"] = ray_directions_cam
+            view["pts3d_cam"] = pts3d_cam
+
+            # Compute the prior depth along ray if present
+            if "prior_depth_z" in view:
+                prior_pts3d, _ = depthmap_to_camera_coordinates(
+                    view["prior_depth_z"], view["camera_intrinsics"]
+                )
+                view["prior_depth_along_ray"] = np.linalg.norm(prior_pts3d, axis=-1)
+                view["prior_depth_along_ray"] = view["prior_depth_along_ray"][..., None]
+                del view["prior_depth_z"]
+
+            # Convert ambiguous mask dtype to match valid mask dtype
+            if "non_ambiguous_mask" in view:
+                view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+                    view["valid_mask"].dtype
+                )
+            else:
+                ambiguous_mask = view["depthmap"] < 0
+                view["non_ambiguous_mask"] = ~ambiguous_mask
+                view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+                    view["valid_mask"].dtype
+                )
+
+            # Check all datatypes
+            for key, val in view.items():
+                res, err_msg = is_good_type(val)
+                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+
+            # Check shapes
+            assert view["depthmap"].shape == view["img"].shape[1:]
+            assert view["depthmap"].shape == view["pts3d"].shape[:2]
+            assert view["depthmap"].shape == view["valid_mask"].shape
+            assert view["depthmap"].shape == view["depth_along_ray"].shape[:2]
+            assert view["depthmap"].shape == view["ray_directions_cam"].shape[:2]
+            assert view["depthmap"].shape == view["pts3d_cam"].shape[:2]
+            if "prior_depth_along_ray" in view:
+                assert view["depthmap"].shape == view["prior_depth_along_ray"].shape[:2]
+            if "non_ambiguous_mask" in view:
+                assert view["depthmap"].shape == view["non_ambiguous_mask"].shape
+
+            # Expand the last dimension of the depthmap
+            view["depthmap"] = view["depthmap"][..., None]
+
+            # Append RNG state to the views, this allows to check whether the RNG is in the same state each time
+            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
+
+            # Compute and store the quaternions and translation for the camera poses
+            # Notation is (x, y, z, w) for quaternions
+            # This also ensures that the camera poses have a positive determinant (right-handed coordinate system)
+            view["camera_pose_quats"] = (
+                Rotation.from_matrix(view["camera_pose"][:3, :3])
+                .as_quat()
+                .astype(view["camera_pose"].dtype)
+            )
+            view["camera_pose_trans"] = view["camera_pose"][:3, 3].astype(
+                view["camera_pose"].dtype
+            )
+
+            # Check the pointmaps, rays, depth along ray, and camera pose quaternions and translation to ensure they are finite
+            assert np.isfinite(view["pts3d"]).all(), (
+                f"NaN in pts3d for view {view_name(view)}"
+            )
+            assert np.isfinite(view["valid_mask"]).all(), (
+                f"NaN in valid_mask for view {view_name(view)}"
+            )
+            assert np.isfinite(view["depth_along_ray"]).all(), (
+                f"NaN in depth_along_ray for view {view_name(view)}"
+            )
+            assert np.isfinite(view["ray_directions_cam"]).all(), (
+                f"NaN in ray_directions_cam for view {view_name(view)}"
+            )
+            assert np.isfinite(view["pts3d_cam"]).all(), (
+                f"NaN in pts3d_cam for view {view_name(view)}"
+            )
+            assert np.isfinite(view["camera_pose_quats"]).all(), (
+                f"NaN in camera_pose_quats for view {view_name(view)}"
+            )
+            assert np.isfinite(view["camera_pose_trans"]).all(), (
+                f"NaN in camera_pose_trans for view {view_name(view)}"
+            )
+            if "prior_depth_along_ray" in view:
+                assert np.isfinite(view["prior_depth_along_ray"]).all(), (
+                    f"NaN in prior_depth_along_ray for view {view_name(view)}"
+                )
+
+        return views
+
+    def __getitem__(self, idx):
+        if self.max_num_retries == 0:
+            return self._getitem_fn(idx)
+
+        num_retries = 0
+        while num_retries <= self.max_num_retries:
+            try:
+                return self._getitem_fn(idx)
+            except Exception as e:
+                scene_idx = idx[0] if isinstance(idx, tuple) else idx
+                print(
+                    f"Error in {type(self).__name__}.__getitem__ for scene_idx={scene_idx}: {e}"
+                )
+
+                if num_retries >= self.max_num_retries:
+                    print(
+                        f"Max retries ({self.max_num_retries}) reached, raising the exception"
+                    )
+                    raise e
+
+                # Retry with a different scene index
+                num_retries += 1
+                if isinstance(idx, tuple):
+                    # The scene index is the first element of the tuple
+                    idx_list = list(idx)
+                    idx_list[0] = np.random.randint(0, len(self))
+                    idx = tuple(idx_list)
+                else:
+                    # The scene index is idx
+                    idx = np.random.randint(0, len(self))
+                scene_idx = idx[0] if isinstance(idx, tuple) else idx
+                print(
+                    f"Retrying with scene_idx={scene_idx} ({num_retries} of {self.max_num_retries})"
+                )
+
+
+def is_good_type(v):
+    """
+    Check if a value has an acceptable data type for processing in the dataset.
+
+    Args:
+        v: The value to check.
+
+    Returns:
+        tuple: A tuple containing:
+            - bool: True if the type is acceptable, False otherwise.
+            - str or None: Error message if the type is not acceptable, None otherwise.
+    """
+    if isinstance(v, (str, int, tuple)):
+        return True, None
+    if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
+        return False, f"bad {v.dtype=}"
+    return True, None
+
+
+def view_name(view, batch_index=None):
+    """
+    Generate a string identifier for a view based on its dataset, label, and instance.
+
+    Args:
+        view (dict): Dictionary containing view information with 'dataset', 'label', and 'instance' keys.
+        batch_index (int, optional): Index to select from batched data. Defaults to None.
+
+    Returns:
+        str: A formatted string in the form "dataset/label/instance".
+    """
+
+    def sel(x):
+        return x[batch_index] if batch_index not in (None, slice(None)) else x
+
+    db = sel(view["dataset"])
+    label = sel(view["label"])
+    instance = sel(view["instance"])
+    return f"{db}/{label}/{instance}"
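The heart of view selection above is `_random_walk_sampling`: normalize a row of the covisibility matrix by its self-covisibility, threshold it to get the neighbors, extend the walk with a random unvisited neighbor, and backtrack at dead ends. A self-contained sketch of that inner loop follows; the 4-view covisibility matrix is made up for illustration, and the retry-with-exclusion logic of the real method is omitted.

import numpy as np

rng = np.random.default_rng(0)

# Made-up 4-view covisibility matrix; the diagonal is self-covisibility.
covis = np.array(
    [
        [10.0, 6.0, 0.5, 0.0],
        [6.0, 10.0, 7.0, 0.2],
        [0.5, 7.0, 10.0, 8.0],
        [0.0, 0.2, 8.0, 10.0],
    ]
)
covis_thres = 0.25  # neighbor if normalized covisibility exceeds this

def random_walk(covis, num_samples, rng, thres):
    start = int(rng.choice(len(covis)))
    walk, visited, stack = [start], {start}, [start]
    while len(walk) < num_samples and stack:
        current = stack[-1]
        row = covis[current] / (covis[current, current] + 1e-8)  # self-normalize
        row[current] = 0.0  # zero out the self-pair
        candidates = [i for i in np.flatnonzero(row > thres) if i not in visited]
        if candidates:
            nxt = int(rng.choice(candidates))  # random unvisited neighbor
            walk.append(nxt)
            visited.add(nxt)
            stack.append(nxt)
        else:
            stack.pop()  # dead end: backtrack
    return np.array(walk)

print(random_walk(covis, num_samples=3, rng=rng, thres=covis_thres))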
mapanything/datasets/base/batched_sampler.py
ADDED
@@ -0,0 +1,431 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+Utilities for random sampling under a single or multiple constraints
+
+References: DUSt3R
+"""
+
+import numpy as np
+import torch
+
+
+def round_by(total, multiple, up=False):
+    """
+    Round a number to the nearest multiple of another number.
+
+    Args:
+        total (int): The number to round
+        multiple (int): The multiple to round to
+        up (bool, optional): Whether to round up. Defaults to False.
+
+    Returns:
+        int: The rounded number
+    """
+    if up:
+        total = total + multiple - 1
+    return (total // multiple) * multiple
+
+
+class BatchedRandomSampler:
+    """
+    Random sampling under a constraint: each sample in the batch has the same feature,
+    which is chosen randomly from a known pool of 'features' for each batch.
+
+    For instance, the 'feature' could be the image aspect-ratio.
+
+    The index returned is a tuple (sample_idx, feat_idx).
+    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
+    """
+
+    def __init__(
+        self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True
+    ):
+        """
+        Args:
+            dataset: Dataset to sample from
+            batch_size: Number of samples per batch
+            pool_size: Integer representing the size of feature pool
+            world_size: Number of distributed processes
+            rank: Rank of the current process
+            drop_last: Whether to drop the last incomplete batch
+        """
+        self.batch_size = batch_size
+        self.pool_size = pool_size
+
+        self.len_dataset = N = len(dataset)
+        self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+        assert world_size == 1 or drop_last, (
+            "must drop the last batch in distributed mode"
+        )
+
+        # Distributed sampler
+        self.world_size = world_size
+        self.rank = rank
+        self.epoch = None
+
+    def __len__(self):
+        """
+        Get the length of the sampler.
+
+        Returns:
+            int: The number of samples in the sampler for the current process
+        """
+        return self.total_size // self.world_size
+
+    def set_epoch(self, epoch):
+        """
+        Set the epoch for this sampler.
+
+        This should be called before each epoch to ensure proper shuffling of the data.
+
+        Args:
+            epoch (int): The current epoch number
+        """
+        self.epoch = epoch
+
+    def __iter__(self):
+        """
+        Iterator over the indices.
+
+        This method generates random indices for each batch, ensuring that all samples
+        within a batch have the same feature index for the given feature pool.
+
+        Yields:
+            tuple: A tuple containing (sample_idx, feat_idx)
+        """
+        # Prepare RNG
+        if self.epoch is None:
+            assert self.world_size == 1 and self.rank == 0, (
+                "use set_epoch() if distributed mode is used"
+            )
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        else:
+            seed = self.epoch + 777
+        rng = np.random.default_rng(seed=seed)
+
+        # Random indices (will restart from 0 if not drop_last)
+        sample_idxs = np.arange(self.total_size)
+        rng.shuffle(sample_idxs)
+
+        # Random feat_idxs (same across each batch)
+        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+        feat_idxs = rng.integers(self.pool_size, size=n_batches)
+        feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
+        feat_idxs = feat_idxs.ravel()[: self.total_size]
+
+        # Put them together
+        idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)
+
+        # Distributed sampler: we select a subset of batches
+        # Make sure the slice for each node is aligned with batch_size
+        size_per_proc = self.batch_size * (
+            (self.total_size + self.world_size * self.batch_size - 1)
+            // (self.world_size * self.batch_size)
+        )
+        idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
+
+        yield from (tuple(idx) for idx in idxs)
+
+
+class BatchedMultiFeatureRandomSampler:
+    """
+    Random sampling under multiple constraints: each sample in the batch has the same features,
+    which are chosen randomly from known pools of 'features' for each batch.
+
+    For instance, the 'features' could be the image aspect-ratio and scene type.
+
+    The index returned is a tuple (sample_idx, feat_idx_1, feat_idx_2, ...).
+    This sampler ensures that each series of `batch_size` indices has the same feature indices.
+    """
+
+    def __init__(
+        self, dataset, batch_size, pool_sizes, world_size=1, rank=0, drop_last=True
+    ):
+        """
+        Args:
+            dataset: Dataset to sample from
+            batch_size: Number of samples per batch
+            pool_sizes: List of integers representing the size of each feature pool
+            world_size: Number of distributed processes
+            rank: Rank of the current process
+            drop_last: Whether to drop the last incomplete batch
+        """
+        self.batch_size = batch_size
+        self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
+
+        self.len_dataset = N = len(dataset)
+        self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+        assert world_size == 1 or drop_last, (
+            "must drop the last batch in distributed mode"
+        )
+
+        # Distributed sampler
+        self.world_size = world_size
+        self.rank = rank
+        self.epoch = None
+
+    def __len__(self):
+        """
+        Get the length of the sampler.
+
+        Returns:
+            int: The number of samples in the sampler for the current process
+        """
+        return self.total_size // self.world_size
+
+    def set_epoch(self, epoch):
+        """
+        Set the epoch for this sampler.
+
+        This should be called before each epoch to ensure proper shuffling of the data.
+
+        Args:
+            epoch (int): The current epoch number
+        """
+        self.epoch = epoch
+
+    def __iter__(self):
+        """
+        Iterator over the indices.
+
+        This method generates random indices for each batch, ensuring that all samples
+        within a batch have the same feature indices for multiple features.
+
+        Yields:
+            tuple: A tuple containing (sample_idx, feat_idx_1, feat_idx_2, ...)
+        """
+        # Prepare RNG
+        if self.epoch is None:
+            assert self.world_size == 1 and self.rank == 0, (
+                "use set_epoch() if distributed mode is used"
+            )
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        else:
+            seed = self.epoch + 777
+        rng = np.random.default_rng(seed=seed)
+
+        # Random indices (will restart from 0 if not drop_last)
+        sample_idxs = np.arange(self.total_size)
+        rng.shuffle(sample_idxs)
+
+        # Random feat_idxs (same across each batch)
+        n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+
+        # Generate feature indices for each feature pool
+        all_feat_idxs = []
+        for pool_size in self.pool_sizes:
+            feat_idxs = rng.integers(pool_size, size=n_batches)
+            feat_idxs = np.broadcast_to(
+                feat_idxs[:, None], (n_batches, self.batch_size)
+            )
+            feat_idxs = feat_idxs.ravel()[: self.total_size]
+            all_feat_idxs.append(feat_idxs)
+
+        # Put them together
+        idxs = np.column_stack(
+            [sample_idxs] + all_feat_idxs
+        )  # shape = (total_size, 1 + len(pool_sizes))
|
| 232 |
+
|
| 233 |
+
# Distributed sampler: we select a subset of batches
|
| 234 |
+
# Make sure the slice for each node is aligned with batch_size
|
| 235 |
+
size_per_proc = self.batch_size * (
|
| 236 |
+
(self.total_size + self.world_size * self.batch_size - 1)
|
| 237 |
+
// (self.world_size * self.batch_size)
|
| 238 |
+
)
|
| 239 |
+
idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
|
| 240 |
+
|
| 241 |
+
yield from (tuple(idx) for idx in idxs)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class DynamicBatchedMultiFeatureRandomSampler:
|
| 245 |
+
"""
|
| 246 |
+
Random sampling under multiple constraints with dynamic batch size:
|
| 247 |
+
each sample in the batch has the same features, which are chosen randomly
|
| 248 |
+
from known pools of 'features' for each batch.
|
| 249 |
+
|
| 250 |
+
The batch size is dynamically determined based on a specified feature index,
|
| 251 |
+
using a direct mapping from feature values to batch sizes.
|
| 252 |
+
|
| 253 |
+
For instance, if one of the features is the number of images in a multi-view set,
|
| 254 |
+
you can specify different batch sizes for different numbers of images to optimize
|
| 255 |
+
GPU memory usage. This is achieved by using the feature_to_batch_size_map parameter
|
| 256 |
+
to directly specify what batch size to use for each feature value.
|
| 257 |
+
|
| 258 |
+
The returned index is a list of tuples [(sample_idx, feat_idx_1, feat_idx_2, ...), ...].
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
def __init__(
|
| 262 |
+
self,
|
| 263 |
+
dataset,
|
| 264 |
+
pool_sizes,
|
| 265 |
+
scaling_feature_idx=0,
|
| 266 |
+
feature_to_batch_size_map=None,
|
| 267 |
+
world_size=1,
|
| 268 |
+
rank=0,
|
| 269 |
+
drop_last=True,
|
| 270 |
+
):
|
| 271 |
+
"""
|
| 272 |
+
Args:
|
| 273 |
+
dataset: Dataset to sample from
|
| 274 |
+
pool_sizes: List of integers representing the size of each feature pool
|
| 275 |
+
scaling_feature_idx: Index of the feature to use for determining batch size (0-based index into pool_sizes)
|
| 276 |
+
feature_to_batch_size_map: Optional function or dict that maps feature values directly to batch sizes.
|
| 277 |
+
For example, if the feature represents number of views, this maps number of views
|
| 278 |
+
to appropriate batch size that can fit in GPU memory.
|
| 279 |
+
If None, uses a default batch size of 1 for all feature values.
|
| 280 |
+
world_size: Number of distributed processes
|
| 281 |
+
rank: Rank of the current process
|
| 282 |
+
drop_last: Whether to drop the last incomplete batch
|
| 283 |
+
"""
|
| 284 |
+
self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
|
| 285 |
+
self.scaling_feature_idx = scaling_feature_idx
|
| 286 |
+
|
| 287 |
+
# Ensure scaling_feature_idx is valid
|
| 288 |
+
if scaling_feature_idx < 0 or scaling_feature_idx >= len(self.pool_sizes):
|
| 289 |
+
raise ValueError(
|
| 290 |
+
f"scaling_feature_idx must be between 0 and {len(self.pool_sizes) - 1}"
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Set up mapping from feature values to batch sizes
|
| 294 |
+
self.feature_to_batch_size_map = feature_to_batch_size_map
|
| 295 |
+
if self.feature_to_batch_size_map is None:
|
| 296 |
+
# Default: batch size of 1 for all feature values
|
| 297 |
+
self.feature_to_batch_size_map = {
|
| 298 |
+
i: 1 for i in range(self.pool_sizes[scaling_feature_idx])
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
self.len_dataset = N = len(dataset)
|
| 302 |
+
|
| 303 |
+
# We don't know the exact batch size yet, so we use a large number for total_size
|
| 304 |
+
# This will be adjusted during iteration
|
| 305 |
+
self.total_size = N
|
| 306 |
+
|
| 307 |
+
# Distributed sampler
|
| 308 |
+
self.world_size = world_size
|
| 309 |
+
self.rank = rank
|
| 310 |
+
self.epoch = None
|
| 311 |
+
self.drop_last = drop_last
|
| 312 |
+
|
| 313 |
+
def __len__(self):
|
| 314 |
+
"""
|
| 315 |
+
Get the approximate length of the sampler.
|
| 316 |
+
|
| 317 |
+
Since batch size varies, this is an estimate based on the largest batch size
|
| 318 |
+
in the mapping, which provides a lower bound on the number of batches.
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
int: The estimated minimum number of samples in the sampler for the current process
|
| 322 |
+
"""
|
| 323 |
+
# Find the largest batch size in the mapping
|
| 324 |
+
if callable(self.feature_to_batch_size_map):
|
| 325 |
+
# If it's a function, sample some values to find the maximum
|
| 326 |
+
batch_sizes = [
|
| 327 |
+
self.feature_to_batch_size_map(i)
|
| 328 |
+
for i in range(self.pool_sizes[self.scaling_feature_idx])
|
| 329 |
+
]
|
| 330 |
+
max_batch_size = max(batch_sizes)
|
| 331 |
+
else:
|
| 332 |
+
# If it's a dict or similar, find the maximum directly
|
| 333 |
+
max_batch_size = max(self.feature_to_batch_size_map.values())
|
| 334 |
+
|
| 335 |
+
# Ensure minimum batch size of 1
|
| 336 |
+
max_batch_size = max(1, max_batch_size)
|
| 337 |
+
|
| 338 |
+
# Estimate total batches using the largest batch size
|
| 339 |
+
# This gives a lower bound on the number of batches
|
| 340 |
+
total_batches = self.total_size // max_batch_size
|
| 341 |
+
if not self.drop_last and self.total_size % max_batch_size > 0:
|
| 342 |
+
total_batches += 1
|
| 343 |
+
|
| 344 |
+
# Distribute among processes
|
| 345 |
+
return total_batches // self.world_size
|
| 346 |
+
|
| 347 |
+
def set_epoch(self, epoch):
|
| 348 |
+
"""
|
| 349 |
+
Set the epoch for this sampler.
|
| 350 |
+
|
| 351 |
+
This should be called before each epoch to ensure proper shuffling of the data.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
epoch (int): The current epoch number
|
| 355 |
+
"""
|
| 356 |
+
self.epoch = epoch
|
| 357 |
+
|
| 358 |
+
def __iter__(self):
|
| 359 |
+
"""
|
| 360 |
+
Iterator over the indices with dynamic batch sizes.
|
| 361 |
+
|
| 362 |
+
This method generates random indices for each batch, ensuring that all samples
|
| 363 |
+
within a batch have the same feature indices for multiple features.
|
| 364 |
+
The batch size is determined directly from the feature_to_batch_size_map.
|
| 365 |
+
|
| 366 |
+
The iterator enforces the length returned by __len__() by stopping after
|
| 367 |
+
exactly that many batches have been yielded for this process.
|
| 368 |
+
|
| 369 |
+
Yields:
|
| 370 |
+
list of tuples: A batch of tuples, each containing (sample_idx, feat_idx_1, feat_idx_2, ...)
|
| 371 |
+
"""
|
| 372 |
+
# Prepare RNG
|
| 373 |
+
if self.epoch is None:
|
| 374 |
+
assert self.world_size == 1 and self.rank == 0, (
|
| 375 |
+
"use set_epoch() if distributed mode is used"
|
| 376 |
+
)
|
| 377 |
+
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
| 378 |
+
else:
|
| 379 |
+
seed = self.epoch + 777
|
| 380 |
+
rng = np.random.default_rng(seed=seed)
|
| 381 |
+
|
| 382 |
+
# Random indices for the entire dataset
|
| 383 |
+
sample_idxs = np.arange(self.total_size)
|
| 384 |
+
rng.shuffle(sample_idxs)
|
| 385 |
+
|
| 386 |
+
# Get the target number of batches for this process (enforce strict length)
|
| 387 |
+
target_batches_for_process = len(self)
|
| 388 |
+
batches_yielded_for_process = 0
|
| 389 |
+
|
| 390 |
+
# Process indices in batches with dynamic sizing
|
| 391 |
+
idx = 0
|
| 392 |
+
batch_idx = 0 # Track batch index for even distribution
|
| 393 |
+
while idx < len(sample_idxs) and (
|
| 394 |
+
batches_yielded_for_process < target_batches_for_process
|
| 395 |
+
):
|
| 396 |
+
# Randomly select feature indices for this batch
|
| 397 |
+
feat_idxs = [rng.integers(pool_size) for pool_size in self.pool_sizes]
|
| 398 |
+
|
| 399 |
+
# Get the scaling feature value
|
| 400 |
+
scaling_feat = feat_idxs[self.scaling_feature_idx]
|
| 401 |
+
|
| 402 |
+
# Get the batch size directly from the mapping
|
| 403 |
+
if callable(self.feature_to_batch_size_map):
|
| 404 |
+
batch_size = self.feature_to_batch_size_map(scaling_feat)
|
| 405 |
+
else:
|
| 406 |
+
batch_size = self.feature_to_batch_size_map.get(scaling_feat, 1)
|
| 407 |
+
|
| 408 |
+
# Ensure minimum batch size of 1
|
| 409 |
+
batch_size = max(1, batch_size)
|
| 410 |
+
|
| 411 |
+
# Ensure we don't go beyond available samples
|
| 412 |
+
remaining = len(sample_idxs) - idx
|
| 413 |
+
if remaining < batch_size:
|
| 414 |
+
if self.drop_last:
|
| 415 |
+
break
|
| 416 |
+
batch_size = remaining
|
| 417 |
+
|
| 418 |
+
# Create batch with consistent feature indices
|
| 419 |
+
batch = []
|
| 420 |
+
for i in range(batch_size):
|
| 421 |
+
if idx + i < len(sample_idxs):
|
| 422 |
+
sample_idx = sample_idxs[idx + i]
|
| 423 |
+
batch.append(tuple([sample_idx] + feat_idxs))
|
| 424 |
+
|
| 425 |
+
# Distribute batches among processes in round-robin fashion
|
| 426 |
+
if len(batch) > 0 and (batch_idx % self.world_size == self.rank):
|
| 427 |
+
yield batch
|
| 428 |
+
batches_yielded_for_process += 1
|
| 429 |
+
|
| 430 |
+
batch_idx += 1 # Increment batch index
|
| 431 |
+
idx += batch_size
|
mapanything/datasets/base/easy_dataset.py
ADDED
@@ -0,0 +1,478 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
Base dataset class that enables easy resizing and combining

References: DUSt3R
"""

import numpy as np

from mapanything.datasets.base.batched_sampler import (
    BatchedMultiFeatureRandomSampler,
    DynamicBatchedMultiFeatureRandomSampler,
)


class EasyDataset:
    """
    Dataset that can be easily resized and combined.

    Examples:
    ---------
        2 * dataset ==> Duplicate each element 2x

        10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary)

        Dataset1 + Dataset2 ==> Concatenate datasets
    """

    def __add__(self, other):
        """
        Concatenate this dataset with another dataset.

        Args:
            other (EasyDataset): Another dataset to concatenate with this one

        Returns:
            CatDataset: A new dataset that is the concatenation of this dataset and the other
        """
        return CatDataset([self, other])

    def __rmul__(self, factor):
        """
        Multiply the dataset by a factor, duplicating each element.

        Args:
            factor (int): Number of times to duplicate each element

        Returns:
            MulDataset: A new dataset with each element duplicated 'factor' times
        """
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        """
        Resize the dataset to a specific size using random sampling.

        Args:
            factor (int): The new size of the dataset

        Returns:
            ResizedDataset: A new dataset with the specified size
        """
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        """
        Set the current epoch for all constituent datasets.

        Args:
            epoch (int): The current epoch number
        """
        pass  # nothing to do by default

    def make_sampler(
        self,
        batch_size=None,
        shuffle=True,
        world_size=1,
        rank=0,
        drop_last=True,
        max_num_of_images_per_gpu=None,
        use_dynamic_sampler=True,
    ):
        """
        Create a sampler for this dataset.

        Args:
            batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
            world_size (int, optional): Number of distributed processes. Defaults to 1.
            rank (int, optional): Rank of the current process. Defaults to 0.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
            max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None.
            use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True.

        Returns:
            DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset

        Raises:
            NotImplementedError: If shuffle is False
            ValueError: If num_views has an invalid type or required parameters are missing
        """
        if not (shuffle):
            raise NotImplementedError()  # cannot deal yet

        if isinstance(self.num_views, int):
            num_of_aspect_ratios = len(self._resolutions)
            feature_pool_sizes = [num_of_aspect_ratios]
            scaling_feature_idx = 0  # Use aspect ratio as scaling feature
        elif isinstance(self.num_views, list):
            num_of_aspect_ratios = len(self._resolutions)
            num_of_num_views = len(self.num_views)
            feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views]
            scaling_feature_idx = 1  # Use num_views as scaling feature
        else:
            raise ValueError(
                f"Bad type for {self.num_views=}, should be int or list of ints"
            )

        if use_dynamic_sampler:
            if max_num_of_images_per_gpu is None:
                raise ValueError(
                    "max_num_of_images_per_gpu must be provided when using dynamic sampler"
                )

            # Create feature-to-batch-size mapping
            if isinstance(self.num_views, list):
                # Map num_views_idx to batch size: max(1, max_num_of_images_per_gpu // (num_views_idx + dataset.num_views_min))
                feature_to_batch_size_map = {}
                for num_views_idx, num_views in enumerate(self.num_views):
                    batch_size_for_multi_view_sets = max(
                        1, max_num_of_images_per_gpu // num_views
                    )
                    feature_to_batch_size_map[num_views_idx] = (
                        batch_size_for_multi_view_sets
                    )
            else:
                # For fixed num_views, use a simple mapping
                feature_to_batch_size_map = {
                    0: max(1, max_num_of_images_per_gpu // self.num_views)
                }

            return DynamicBatchedMultiFeatureRandomSampler(
                self,
                pool_sizes=feature_pool_sizes,
                scaling_feature_idx=scaling_feature_idx,
                feature_to_batch_size_map=feature_to_batch_size_map,
                world_size=world_size,
                rank=rank,
                drop_last=drop_last,
            )
        else:
            if batch_size is None:
                raise ValueError(
                    "batch_size must be provided when not using dynamic sampler"
                )

            return BatchedMultiFeatureRandomSampler(
                self,
                batch_size,
                feature_pool_sizes,
                world_size=world_size,
                rank=rank,
                drop_last=drop_last,
            )


class MulDataset(EasyDataset):
    """Artificially augmenting the size of a dataset."""

    multiplicator: int

    def __init__(self, multiplicator, dataset):
        """
        Initialize a dataset that artificially augments the size of another dataset.

        Args:
            multiplicator (int): Factor by which to multiply the dataset size
            dataset (EasyDataset): The dataset to augment
        """
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: The number of samples in the dataset
        """
        return self.multiplicator * len(self.dataset)

    def __repr__(self):
        """
        Get a string representation of the dataset.

        Returns:
            str: String representation showing the multiplication factor and the original dataset
        """
        return f"{self.multiplicator}*{repr(self.dataset)}"

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx: Index or tuple of indices to retrieve

        Returns:
            The item at the specified index from the original dataset
        """
        if isinstance(idx, tuple):
            other = idx[1:]
            idx = idx[0]
            new_idx = (idx // self.multiplicator, *other)
            return self.dataset[new_idx]
        else:
            return self.dataset[idx // self.multiplicator]

    @property
    def _resolutions(self):
        """
        Get the resolutions of the dataset.

        Returns:
            The resolutions from the original dataset
        """
        return self.dataset._resolutions

    @property
    def num_views(self):
        """
        Get the number of views used for the dataset.

        Returns:
            int or list: The number of views parameter from the original dataset
        """
        return self.dataset.num_views


class ResizedDataset(EasyDataset):
    """Artificially changing the size of a dataset."""

    new_size: int

    def __init__(self, new_size, dataset):
        """
        Initialize a dataset with an artificially changed size.

        Args:
            new_size (int): The new size of the dataset
            dataset (EasyDataset): The original dataset
        """
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: The new size of the dataset
        """
        return self.new_size

    def __repr__(self):
        """
        Get a string representation of the dataset.

        Returns:
            str: String representation showing the new size and the original dataset
        """
        size_str = str(self.new_size)
        for i in range((len(size_str) - 1) // 3):
            sep = -4 * i - 3
            size_str = size_str[:sep] + "_" + size_str[sep:]
        return f"{size_str} @ {repr(self.dataset)}"

    def set_epoch(self, epoch):
        """
        Set the current epoch and generate a new random mapping of indices.

        This method must be called before using __getitem__.

        Args:
            epoch (int): The current epoch number
        """
        # This random shuffle only depends on the epoch
        rng = np.random.default_rng(seed=epoch + 777)

        # Shuffle all indices
        perm = rng.permutation(len(self.dataset))

        # Calculate how many repetitions we need
        num_repetitions = 1 + (len(self) - 1) // len(self.dataset)

        # Rotary extension until target size is met
        shuffled_idxs = np.concatenate([perm] * num_repetitions)
        self._idxs_mapping = shuffled_idxs[: self.new_size]

        # Generate the seed offset for each repetition
        # This is needed to ensure we see unique samples when we repeat a scene
        seed_offset_per_repetition = [
            np.full(len(self.dataset), i) for i in range(num_repetitions)
        ]
        seed_offset_idxs = np.concatenate(seed_offset_per_repetition)
        self._idxs_seed_offset = seed_offset_idxs[: self.new_size]

        assert len(self._idxs_mapping) == self.new_size
        assert len(self._idxs_seed_offset) == self.new_size

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx: Index or tuple of indices to retrieve

        Returns:
            The item at the mapped index from the original dataset

        Raises:
            AssertionError: If set_epoch has not been called
        """
        assert hasattr(self, "_idxs_mapping"), (
            "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
        )
        if isinstance(idx, tuple):
            other = idx[1:]
            idx = idx[0]
            self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
            new_idx = (self._idxs_mapping[idx], *other)
            return self.dataset[new_idx]
        else:
            self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
            return self.dataset[self._idxs_mapping[idx]]

    @property
    def _resolutions(self):
        """
        Get the resolutions of the dataset.

        Returns:
            The resolutions from the original dataset
        """
        return self.dataset._resolutions

    @property
    def num_views(self):
        """
        Get the number of views used for the dataset.

        Returns:
            int or list: The number of views parameter from the original dataset
        """
        return self.dataset.num_views


class CatDataset(EasyDataset):
    """Concatenation of several datasets"""

    def __init__(self, datasets):
        """
        Initialize a dataset that is a concatenation of several datasets.

        Args:
            datasets (list): List of EasyDataset instances to concatenate
        """
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        """
        Get the length of the concatenated dataset.

        Returns:
            int: Total number of samples across all datasets
        """
        return self._cum_sizes[-1]

    def __repr__(self):
        """
        Get a string representation of the concatenated dataset.

        Returns:
            str: String representation showing all concatenated datasets joined by '+'
        """
        # Remove uselessly long transform
        return " + ".join(
            repr(dataset).replace(
                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
                "",
            )
            for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        """
        Set the current epoch for all constituent datasets.

        Args:
            epoch (int): The current epoch number
        """
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """
        Get an item from the concatenated dataset.

        Args:
            idx: Index or tuple of indices to retrieve

        Returns:
            The item at the specified index from the appropriate constituent dataset

        Raises:
            IndexError: If the index is out of range
        """
        other = None
        if isinstance(idx, tuple):
            other = idx[1:]
            idx = idx[0]

        if not (0 <= idx < len(self)):
            raise IndexError()

        db_idx = np.searchsorted(self._cum_sizes, idx, "right")
        dataset = self.datasets[db_idx]
        new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)

        if other is not None:
            new_idx = (new_idx, *other)
        return dataset[new_idx]

    @property
    def _resolutions(self):
        """
        Get the resolutions of the dataset.

        Returns:
            The resolutions from the first dataset (all datasets must have the same resolutions)

        Raises:
            AssertionError: If datasets have different resolutions
        """
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions), (
                "All datasets must have the same resolutions"
            )
        return resolutions

    @property
    def num_views(self):
        """
        Get the number of views used for the dataset.

        Returns:
            int or list: The number of views parameter from the first dataset

        Raises:
            AssertionError: If datasets have different num_views
        """
        num_views = self.datasets[0].num_views
        for dataset in self.datasets[1:]:
            assert dataset.num_views == num_views, (
                "All datasets must have the same num_views and variable_num_views parameters"
            )
        return num_views
mapanything/datasets/utils/__init__.py
ADDED
File without changes

mapanything/datasets/utils/data_splits.py
ADDED
@@ -0,0 +1,1734 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
Modules containing dataset split information
"""


class BlendedMVSSplits:
    """
    This class contains the information about the BlendedMVS dataset splits.
    """

    def __init__(self):
        """
        The splits are generated using the following logic:
        # Get all seqls and seqhs using self.blendedmvs_info.all_sequences
        all_sequences = self.blendedmvs_info.all_sequences
        all_seqls = [int(seq[8:], 16) for seq in all_sequences]
        all_seqhs = [int(seq[:8], 16) for seq in all_sequences]
        # Split the seqls (& corresponding seqhs) using the DUSt3R train/val split logic
        if split is None:
            selection = slice(None)
        elif split in ["train", "overfit"]:
            # select 90% of all scenes
            selection = [(seql % 10) > 0 for seql in all_seqls]
        elif split == "val":
            # select 10% of all scenes
            selection = [(seql % 10) == 0 for seql in all_seqls]
        else:
            raise ValueError(f"Unknown split {split}, must be None, train, val or overfit")
        # Filter sequences based on the selection
        selected_seqls = [seql for seql, sel in zip(all_seqls, selection) if sel]
        selected_seqhs = [seqh for seqh, sel in zip(all_seqhs, selection) if sel]
        # Put them back into sequence names f"{seqh:08x}{seql:016x}"
        sequence_names = [f"{seqh:08x}{seql:016x}" for seqh, seql in zip(selected_seqhs, selected_seqls)]
        # Remove invalid sequence names which don't exist in self.blendedmvs_info.sequences
        valid_sequences = set(self.blendedmvs_info.sequences)
        valid_sequence_names = [name for name in sequence_names if name in valid_sequences]
        """
        # All the 502 sequences in the dataset (totals to 115k images)
        self.all_scenes = [
            "000000000000000000000000",
            "00000000000000000000000a",
            "00000000000000000000000b",
            "00000000000000000000000c",
            "00000000000000000000000d",
            "00000000000000000000000e",
            "00000000000000000000000f",
            "000000000000000000000001",
            "00000000000000000000001a",
            "00000000000000000000001b",
            "00000000000000000000001d",
            "000000000000000000000002",
            "000000000000000000000003",
            "000000000000000000000004",
            "000000000000000000000005",
            "5a2a95f032a1c655cfe3de62",
            "5a2af22b32a1c655cfe46013",
            "5a2ba6de32a1c655cfe51b79",
            "5a3b9731e24cd76dad1a5f1b",
            "5a3ca9cb270f0e3f14d0eddb",
            "5a3cb4e4270f0e3f14d12f43",
            "5a03e732454a8a7ec672776c",
            "5a3f4aba5889373fbbc5d3b5",
            "5a4a38dad38c8a075495b5d2",
            "5a5a1e48d62c7a12d5d00e47",
            "5a6b1c418d100c2f8fdc4411",
            "5a6feeb54a7fbc3f874f9db7",
            "5a7cb1d6fe5c0d6fb53e64fb",
            "5a7d3db14989e929563eb153",
            "5a8aa0fab18050187cbe060e",
            "5a9e5df65baeef72b4a021cd",
            "5a48ba95c7dab83a7d7b44ed",
            "5a48c4e9c7dab83a7d7b5cc7",
            "5a48d4b2c7dab83a7d7b9851",
            "5a69c47d0d5d0a7f3b2e9752",
            "5a77b46b318efe6c6736e68a",
            "5a355c271b63f53d5970f362",
            "5a489fb1c7dab83a7d7b1070",
            "5a533e8034d7582116e34209",
            "5a562fc7425d0f5186314725",
            "5a572fd9fc597b0478a81d14",
            "5a588a8193ac3d233f77fbca",
            "5a618c72784780334bc1972d",
            "5a752d42acc41e2423f17674",
            "5a969eea91dfc339a9a3ad2c",
            "5a8315f624b8e938486e0bd8",
            "5a57542f333d180827dfc132",
            "5a0271884e62597cdee0d0eb",
            "5a6400933d809f1d8200af15",
            "5a6464143d809f1d8208c43c",
            "5a563183425d0f5186314855",
            "5aa0f9d7a9efce63548c69a1",
            "5aa0f478a9efce63548c1cb4",
            "5aa7db90bfdd572271e95246",
            "5aa235f64a17b335eeaf9609",
            "5aa515e613d42d091d29d300",
            "5aa1196ea9efce63548ed649",
            "5aaadd4cbc13235570d178a7",
            "5ab6af12ac4291329b1072ab",
            "5ab7e00aac4291329b15864d",
            "5ab8b8e029f5351f7f2ccf59",
            "5ab74bf2ac4291329b11e879",
            "5ab85f1dac4291329b17cb50",
            "5ab8713ba3799a1d138bd69a",
            "5abc2506b53b042ead637d86",
            "5acc7459a7853c4b5ebbef59",
            "5acf8ca0f3d8a750097e4b15",
            "5adc6bd52430a05ecb2ffb85",
            "5ae2e9c5fe405c5076abc6b2",
            "5af02e904c8216544b4ab5a2",
            "5af28cea59bc705737003253",
            "5af545d0559359053d25dcf5",
            "5afacb69ab00705d0cefdd5b",
            "5b2c67b5e0878c381608b8d8",
            "5b3b2b9e8d46a939f933fdc0",
            "5b3b353d8d46a939f93524b9",
            "5b6e716d67b396324c2d77cb",
            "5b6eff8b67b396324c5b2672",
            "5b7a3890fc8fcf6781e2593a",
            "5b21e18c58e2823a67a10dd8",
            "5b60fa0c764f146feef84df0",
            "5b69cc0cb44b61786eb959bf",
            "5b78e57afc8fcf6781d0c3ba",
            "5b192eb2170cf166458ff886",
            "5b558a928bbfb62204e77ba2",
            "5b864d850d072a699b32f4ae",
            "5b908d3dc6ab78485f3d24a9",
            "5b950c71608de421b1e7318f",
            "5b4933abf2b5f44e95de482a",
            "5b08286b2775267d5b0634ba",
            "5b37189a35304b6f75e7583e",
            "5b271079e0878c3816dacca4",
            "5b22269758e2823a67a3bd03",
            "5b62647143840965efc0dbde",
            "5ba19a8a360c7c30c1c169df",
            "5ba75d79d76ffa2c86cf2f05",
            "5bb7a08aea1cfa39f1a947ab",
            "5bb8a49aea1cfa39f1aa7f75",
            "5bbb6eb2ea1cfa39f1af7e0c",
            "5bc5f0e896b66a2cd8f9bd36",
            "5bccd6beca24970bce448134",
            "5bce7ac9ca24970bce4934b6",
            "5bcf979a6d5f586b95c258cd",
            "5bd43b4ba6b28b1ee86b92dd",
            "5be3a5fb8cfdd56947f6b67c",
            "5be3ae47f44e235bdbbc9771",
            "5be4ab93870d330ff2dce134",
            "5be47bf9b18881428d8fbc1d",
            "5be883a4f98cee15019d5b83",
            "5bea87f4abd34c35e1860ab5",
            "5beb6e66abd34c35e18e66b9",
            "5bf3a82cd439231948877aed",
            "5bf7d63575c26f32dbf7413b",
            "5bf17c0fd439231948355385",
            "5bf26cbbd43923194854b270",
            "5bf03590d4392319481971dc",
            "5bf18642c50e6f7f8bdbd492",
            "5bf21799d43923194842c001",
            "5bfc9d5aec61ca1dd69132a2",
            "5bfd0f32ec61ca1dd69dc77b",
            "5bfe5ae0fe0ea555e6a969ca",
            "5bff3c5cfe0ea555e6bcbf3a",
            "5c0d13b795da9479e12e2ee9",
            "5c1af2e2bee9a723c963d019",
            "5c1b1500bee9a723c96c3e78",
            "5c1dbf200843bc542d8ef8c4",
            "5c1f33f1d33e1f2e4aa6dda4",
            "5c2b3ed5e611832e8aed46bf",
            "5c20ca3a0843bc542d94e3e2",
            "5c062d84a96e33018ff6f0a6",
            "5c189f2326173c3a09ed7ef3",
            "5c1892f726173c3a09ea9aeb",
            "5c34300a73a8df509add216d",
            "5c34529873a8df509ae57b58",
            "000000000000000000000006",
            "000000000000000000000007",
            "000000000000000000000008",
            "000000000000000000000009",
            "000000000000000000000010",
            "000000000000000000000011",
            "000000000000000000000012",
            "000000000000000000000015",
            "000000000000000000000016",
            "000000000000000000000017",
            "000000000000000000000018",
            "000000000000000000000019",
            "56d73ba74bd29b8c35abade2",
            "56f34064e296120e10484dc4",
            "57a4a7bb6b9272286e26dc18",
            "57f8d9bbe73f6760f10e916a",
            "58a0a2f33d0b4542479a11b1",
            "58a0dd1a3d0b4542479a28f3",
            "58a1a7914a4d262a170b1101",
            "58a1bc804a4d262a170b2f01",
            "58a1d9d14a4d262a170b58fe",
            "58a01dea38486e3c98475871",
            "58a1f5d74a4d262a170b65fc",
            "58a2a09e156b87103d3d668c",
            "58a2d9c3156b87103d3da90f",
            "58a3ccb0156b87103d3e4332",
            "58a3f2f8156b87103d3e5838",
            "58a3f6c0156b87103d3e5971",
            "58a3fc95156b87103d3e5d9b",
            "58a07ce53d0b45424799fdde",
            "58a07f233d0b45424799ffe7",
            "58a44df2156b87103d3ee239",
            "58a164f73d0b4542479a7a8e",
            "58a0365e38486e3c984783eb",
            "58a439cf156b87103d3ec885",
            "58a464aa156b87103d3eec04",
            "58a4452f156b87103d3ed55b",
            "58a160983d0b4542479a7347",
            "58a186444a4d262a170ae3ae",
            "58a285424a4d262a170baf3e",
            "58a41819156b87103d3e92a5",
            "58a44463156b87103d3ed45e",
            "58a47552156b87103d3f00a4",
            "58c4bb4f4a69c55606122be4",
            "58c6451e4a69c556061894f1",
            "58ca7014affdfd07c70a95ce",
            "58cf4771d0f5fb221defe6da",
            "58d36897f387231e6c929903",
            "58eaf1513353456af3a1682a",
            "58f7f7299f5b5647873cb110",
            "58f73e7c9f5b56478738929f",
            "59a8f851597729752c31e7e0",
            "59a452bf9b460239aa5d1c72",
            "59a9619a825418241fb88191",
            "59acd2f4b891807f439c8992",
            "59bf97fe7e7b31545da34439",
            "59c1c3e2fd6e3d4ead9f1013",
            "59d2657f82ca7774b1ec081d",
            "59da1fb88a126011d0394ae9",
            "59e75a2ca9e91f2c5526005d",
            "59e864b2a9e91f2c5529325f",
            "59ecfd02e225f6492d20fcc9",
            "59f37f74b45be2233001ba18",
            "59f70ab1e5c5d366af29bf3e",
            "59f87d0bfa6280566fb38c9a",
            "59f363a8b45be22330016cad",
            "564a27b26d07883f460d8ab0",
            "565fb1dead14d4154dae2b94",
            "567a0fb0a825d2fb79ac9a20",
            "569b92eb826bcba945ca002b",
            "576fefa017ce5a16397e87fd",
            "584a7333fe3cb463906c9fe6",
            "584aa8e9fe3cb463906cc7d0",
            "584ad76bfe3cb463906ce6dc",
            "584af003fe3cb463906d0e9b",
            "584b9a747072670e72bfc49d",
            "584b671f7072670e72bfaaf8",
            "584b81747072670e72bfbbfd",
            "584ba35f7072670e72bfca4d",
            "584ba5977072670e72bfcc2d",
            "584bc53c7072670e72bfe85f",
            "584bc3997072670e72bfe58d",
            "584bc4407072670e72bfe665",
            "584bd5587072670e72bffe39",
            "584bdadf7072670e72c0005c",
            "584be5ed7072670e72c007b3",
            "584c9ad27072670e72c060c5",
            "584c9cc67072670e72c063a1",
            "584c58b77072670e72c03990",
            "584cea557072670e72c07fb4",
            "584d19d47072670e72c0c6c0",
            "584dfe467072670e72c1665a",
            "584e875c7072670e72c1ec94",
            "584e05667072670e72c17167",
            "584f94e87072670e72c2d3f7",
            "584fdffd7072670e72c32dc7",
            "584fe07f7072670e72c32e59",
            "585a2a71b338a62ad50138dc",
            "585a206ab338a62ad501298f",
            "585a217cb338a62ad5012b38",
            "585b34afb338a62ad501e836",
            "585bb25fc49c8507c3ce7812",
            "585bbe55c49c8507c3ce81cd",
            "585d6c8a2a57cc11d4920a1e",
            "585e54c72a57cc11d492f71a",
            "585e34302a57cc11d492be30",
            "585ee0632a57cc11d4933608",
            "585f9661712e2761468dabca",
            "585ffe9a712e2761468df643",
            "586a37ec9d1b5e34c28184fc",
            "586a515a9d1b5e34c281b431",
            "586a94939d1b5e34c2823b5d",
            "586abc689d1b5e34c2826360",
            "586b0e219d1b5e34c2828862",
            "586b3db89d1b5e34c282cd52",
            "586b4c459d1b5e34c282e66d",
            "586b7d7d9d1b5e34c283359e",
            "586b8f149d1b5e34c283497c",
            "586b8f629d1b5e34c28349d6",
            "586c4c4d9d1b5e34c28391a1",
            "586c5b5b9d1b5e34c2839a5b",
            "586c9fdf9d1b5e34c283b657",
            "586c48329d1b5e34c2838e80",
            "586caab99d1b5e34c283c213",
            "586cd0779d1b5e34c28403a7",
            "586d6d249d1b5e34c284b80e",
            "586d8a029d1b5e34c284c948",
            "586d55af9d1b5e34c284a999",
            "586d07869d1b5e34c2842e5b",
            "586d27489d1b5e34c28453af",
            "586df9849d1b5e34c28506de",
            "586e279c9d1b5e34c2852180",
            "587bc5ec2366dd5d06e262c1",
            "587c1abf2366dd5d06e28901",
            "587c03f12366dd5d06e27722",
            "587c19da2366dd5d06e2877b",
            "587c31b92366dd5d06e2a9dc",
            "587c87d02366dd5d06e2f989",
            "587c97a52366dd5d06e30a96",
            "587c45192366dd5d06e2c0eb",
            "587cec702366dd5d06e37862",
            "587cef0a2366dd5d06e379e3",
            "587db5872366dd5d06e3e0af",
            "587e2b1d2366dd5d06e41af0",
            "587e2ea62366dd5d06e41f2e",
            "587e5cb52366dd5d06e4486e",
            "587eb1822366dd5d06e45f29",
            "587f365d2366dd5d06e4906e",
            "588a9c5fec4d5a1c088ec350",
            "588a34cfec4d5a1c088ea8d1",
            "588ab5bdec4d5a1c088ed60f",
            "588aff9d90414422fbe7885a",
            "588b20d290414422fbe79f40",
            "588c08d590414422fbe8200b",
            "588c203d90414422fbe8319e",
            "588c989a90414422fbe86d96",
            "588ca09d90414422fbe871a1",
            "588cce2190414422fbe88520",
            "588cd5ef90414422fbe8875c",
            "588cf0ad90414422fbe8a20f",
            "588e0d8c90414422fbe8f8b2",
            "588e01c490414422fbe8ee2a",
            "588e35e690414422fbe90a53",
            "588f017e90414422fbe9b74b",
            "588f095190414422fbe9c1ee",
            "589aca717dc3d323d55671c4",
            "589af2c97dc3d323d55691e8",
            "589b49ea7dc3d323d556d9b4",
            "589b04287dc3d323d556a185",
            "589bf6a57dc3d323d55743ab",
            "589c3c497dc3d323d5578468",
            "589c3c577dc3d323d5578480",
            "589c300f7dc3d323d5577926",
            "589c24527dc3d323d5577126",
            "589c35457dc3d323d5577d8d",
            "589ca6a6b896147a1b73aff7",
            "589d1e1fb896147a1b73ee5b",
            "589d5c58b896147a1b742256",
            "589d95538fa2cf375df3317b",
            "589df0ffb504a864ad63521a",
            "589ea316b504a864ad639a2b",
            "589ec97cb504a864ad63adc3",
            "589f214338486e3c9846f123",
            "589fdfe738486e3c984736cf",
            "590c2d70336bb52a190be886",
            "590f91851225725be9e25d4e",
            "591a467a6109e14d4f09b776",
            "591cf3033162411cf9047f37",
            "591ea44850991c70dc99a207",
            "599aa591d5b41f366fed0d58",
            "5643df56138263b51db1b5f3",
            "5644bdac138263b51db9f669",
            "5692a4c2adafac1f14201821",
            "5850d4f97072670e72c425d6",
            "5854c405804be105852330fe",
            "5855a4fc804be1058523bd75",
            "5856ac15804be105852419d8",
            "5856ae8b804be10585241bae",
            "5856b460804be10585242059",
            "5857aa5ab338a62ad5ff4dbe",
            "5857acf8b338a62ad5ff5107",
            "5858db6cb338a62ad500103b",
            "5858dbcab338a62ad5001081",
            "5859d84fb338a62ad500e5cf",
            "5861d8ea712e2761468f3cb3",
            "5863edf8712e27614690cce0",
            "5864a935712e2761469111b4",
            "5864b076712e27614691197e",
            "5864da88712e276146913d8b",
            "5865f4a8712e27614691e39b",
            "5867a434833dfe3f7b88edaf",
            "5868cd15833dfe3f7b89bfa3",
            "5880b3692366dd5d06e5d534",
            "5880e3422366dd5d06e5ff8e",
            "5880f0ef2366dd5d06e6166e",
            "5881d2bfb6844814c136a119",
            "5881f11d8ce2c2754d0714c3",
            "5881fee18ce2c2754d0723f8",
            "5882cda2b116682b4adebd25",
            "5882d58fb116682b4adec7db",
            "5884c256932ba84fbed70bf5",
            "5884cc13932ba84fbed71ec4",
            "5885bc5296fa095e0671a7f0",
            "5886d14cb791366d617a362c",
            "5888becfc02346100f4b0b21",
            "5888e408c02346100f4b1a29",
            "5889da66ec4d5a1c088e5187",
            "5889e344ec4d5a1c088e59be",
            "5889e754ec4d5a1c088e60ba",
            "5890c16b90414422fbeb0262",
            "5891d8ae9a8c0314c5cd30ab",
            "5891d0479a8c0314c5cd2abd",
            "5891ecf19a8c0314c5cd490a",
            "5892c0cd9a8c0314c5cdc977",
            "5894ab309a8c0314c5cee57d",
            "5895a6a89a8c0314c5cfca7c",
            "5895b8c29a8c0314c5cfd051",
            "5895d38f9a8c0314c5cfe50c",
            "5895f2329a8c0314c5d00117",
            "5896bb989a8c0314c5d086b6",
            "5896ebf39a8c0314c5d0a8c4",
            "5898b1bac9dccc22987b7f74",
            "5898b6ffc9dccc22987b8a03",
            "5898b31cc9dccc22987b82ec",
|
| 423 |
+
"5898bbaac9dccc22987b8eba",
|
| 424 |
+
"5899cfa6b76d7a3780a4cb64",
|
| 425 |
+
"5899e5dcb76d7a3780a4ecc1",
|
| 426 |
+
"5947b62af1b45630bd0c2a02",
|
| 427 |
+
"57102be2877e1421026358af",
|
| 428 |
+
"57153d4031bb9900425bde85",
|
| 429 |
+
"57177cd7fb8d93461afc4527",
|
| 430 |
+
"58497cdf97b73e0b090c4273",
|
| 431 |
+
"58500b007072670e72c35588",
|
| 432 |
+
"58510bf97072670e72c46ddf",
|
| 433 |
+
"58522bd56789802282f2ecb3",
|
| 434 |
+
"58524a2e0e7012308944bcf3",
|
| 435 |
+
"58524a080e7012308944bcbf",
|
| 436 |
+
"58524c1d0e7012308944bfda",
|
| 437 |
+
"58524f170e7012308944c200",
|
| 438 |
+
"58529a4e0e70123089454c6f",
|
| 439 |
+
"58551bdf804be1058523556d",
|
| 440 |
+
"58568c9a804be10585240b03",
|
| 441 |
+
"58574b35804be105852455fd",
|
| 442 |
+
"58577c60b338a62ad5ff1564",
|
| 443 |
+
"58592d69b338a62ad5007a74",
|
| 444 |
+
"58598db2b338a62ad500bc38",
|
| 445 |
+
"58625f42712e2761468fb44c",
|
| 446 |
+
"58651bcc712e2761469166dc",
|
| 447 |
+
"58660e79712e27614691fe3d",
|
| 448 |
+
"58669aad712e27614692834c",
|
| 449 |
+
"58669c02712e27614692851a",
|
| 450 |
+
"58676c36833dfe3f7b88b7f2",
|
| 451 |
+
"58678b2d833dfe3f7b88e244",
|
| 452 |
+
"58790c82ce911104a3467c88",
|
| 453 |
+
"58800b0b2366dd5d06e5312d",
|
| 454 |
+
"58805eac2366dd5d06e56460",
|
| 455 |
+
"58806e422366dd5d06e57bb6",
|
| 456 |
+
"58831d060db9bf59bf8ab98b",
|
| 457 |
+
"58851ebb932ba84fbed7abad",
|
| 458 |
+
"58871dc3b791366d617a55ff",
|
| 459 |
+
"58873cabb791366d617a65a7",
|
| 460 |
+
"58873d44b791366d617a65dd",
|
| 461 |
+
"58888b3dc02346100f4af665",
|
| 462 |
+
"58897f62c02346100f4b8ee6",
|
| 463 |
+
"58933bac9a8c0314c5ce3508",
|
| 464 |
+
"58938e6d9a8c0314c5ce726f",
|
| 465 |
+
"58951cb49a8c0314c5cf4d5e",
|
| 466 |
+
"58970fd09a8c0314c5d0e383",
|
| 467 |
+
"58977ef09a8c0314c5d17b26",
|
| 468 |
+
"59056e6760bb961de55f3501",
|
| 469 |
+
"59071f2e5a6dbd3af4130f98",
|
| 470 |
+
"59102c811225725be9e64149",
|
| 471 |
+
"59338e76772c3e6384afbb15",
|
| 472 |
+
"59350ca084b7f26bf5ce6eb8",
|
| 473 |
+
"59397e493a87372f2c9e882b",
|
| 474 |
+
"59521e0b9096412211c2aa9d",
|
| 475 |
+
"59817e4a1bd4b175e7038d19",
|
| 476 |
+
"567884f58d2828b95e3c8eba",
|
| 477 |
+
"585559d9804be10585238ddf",
|
| 478 |
+
"585834cdb338a62ad5ffab4d",
|
| 479 |
+
"586082d8712e2761468e2877",
|
| 480 |
+
"586133c2712e2761468ecfe3",
|
| 481 |
+
"586281d2712e2761468fcaa2",
|
| 482 |
+
"586316e5712e276146903c4d",
|
| 483 |
+
"586326ad712e276146904571",
|
| 484 |
+
"586375c9712e276146907429",
|
| 485 |
+
"586389c9712e276146908da6",
|
| 486 |
+
"586496fa712e2761469108e7",
|
| 487 |
+
"586669c6712e27614692597a",
|
| 488 |
+
"586913a49d1b5e34c2808b02",
|
| 489 |
+
"586922da9d1b5e34c2809ff3",
|
| 490 |
+
"588185d8dfb7a15588a114a3",
|
| 491 |
+
"588305ed0db9bf59bf8a8c80",
|
| 492 |
+
"588315c60db9bf59bf8aa928",
|
| 493 |
+
"588332ee0db9bf59bf8ae9c3",
|
| 494 |
+
"588457b8932ba84fbed69942",
|
| 495 |
+
"588519d5932ba84fbed7a04a",
|
| 496 |
+
"588824d1b791366d617adeef",
|
| 497 |
+
"588857f6c02346100f4ac09f",
|
| 498 |
+
"589145ef90414422fbeb2e08",
|
| 499 |
+
"589433fa9a8c0314c5ce9656",
|
| 500 |
+
"589765d39a8c0314c5d16b12",
|
| 501 |
+
"5851165f7072670e72c4860d",
|
| 502 |
+
"5859341ab338a62ad500848d",
|
| 503 |
+
"5862388b712e2761468f84aa",
|
| 504 |
+
"5863915b712e276146909135",
|
| 505 |
+
"5866445b712e27614692383e",
|
| 506 |
+
"5866500d712e2761469240fd",
|
| 507 |
+
"5867785a833dfe3f7b88c764",
|
| 508 |
+
"5867969c833dfe3f7b88e8bc",
|
| 509 |
+
"5868040c833dfe3f7b8934f7",
|
| 510 |
+
"5880675a2366dd5d06e570ca",
|
| 511 |
+
"5882372c8ce2c2754d076af0",
|
| 512 |
+
"5883535e932ba84fbed5ad07",
|
| 513 |
+
"5888358cb791366d617af69d",
|
| 514 |
+
"5890330d90414422fbeaa0cb",
|
| 515 |
+
"5897076e9a8c0314c5d0d31b",
|
| 516 |
+
"5940564ec2d9527ab869f7e2",
|
| 517 |
+
"5947719bf1b45630bd096665",
|
| 518 |
+
"5948194ff1b45630bd0f47e3",
|
| 519 |
+
"5950206a41b158666ac50506",
|
| 520 |
+
"5983012d1bd4b175e70c985a",
|
| 521 |
+
"58586810b338a62ad5ffc20c",
|
| 522 |
+
"58592046b338a62ad5006b33",
|
| 523 |
+
"58592854b338a62ad500750a",
|
| 524 |
+
"58596531b338a62ad500aace",
|
| 525 |
+
"58818685dfb7a15588a11626",
|
| 526 |
+
"58829563f42b1d3ee3ec835f",
|
| 527 |
+
"58894345c02346100f4b51ca",
|
| 528 |
+
"585289980e7012308945276a",
|
| 529 |
+
"585369770e7012308945c709",
|
| 530 |
+
"585373640e7012308945cab9",
|
| 531 |
+
"588230658ce2c2754d076728",
|
| 532 |
+
"589388059a8c0314c5ce718b",
|
| 533 |
+
"595979485ec6a95e86a58c8d",
|
| 534 |
+
"5841206219d291325678ca90",
|
| 535 |
+
"58563650804be1058523da55",
|
| 536 |
+
"58564084804be1058523e116",
|
| 537 |
+
"58636467712e27614690661f",
|
| 538 |
+
"58647495712e27614690f36d",
|
| 539 |
+
"58654563712e276146918643",
|
| 540 |
+
"58664251712e276146923738",
|
| 541 |
+
"588084032366dd5d06e59e82",
|
| 542 |
+
"588159582366dd5d06e66877",
|
| 543 |
+
"5890279190414422fbea9734",
|
| 544 |
+
"5890523090414422fbeab3f0",
|
| 545 |
+
"5890641690414422fbeabbe7",
|
| 546 |
+
"585203546789802282f2aaf5",
|
| 547 |
+
]
|
| 548 |
+
|
| 549 |
+
# Final sequences to be used after filtering (some of the sequences have incorrect/low quality depth)
|
| 550 |
+
# Generally water bodies like lakes have incorrect depth
|
| 551 |
+
# Filtered out sequences:
|
| 552 |
+
# "5692a4c2adafac1f14201821" # Incorrect Depth
|
| 553 |
+
# "5864a935712e2761469111b4" # Noisy Depth and artifacts near horizon
|
| 554 |
+
# "59f87d0bfa6280566fb38c9a" # Object-centric, noise with background and sometimes in front of object
|
| 555 |
+
# "58a44463156b87103d3ed45e" # Very noisy depth in background
|
| 556 |
+
# "5c2b3ed5e611832e8aed46bf" # Depth occluded by artifacts
|
| 557 |
+
# "5bf03590d4392319481971dc" # Depth occluded by artifacts
|
| 558 |
+
# "00000000000000000000001a" # Largely incomplete depth
|
| 559 |
+
# "00000000000000000000000c" # Imprecise depth for buildings
|
| 560 |
+
# "000000000000000000000000" # Incorrect depth for planar terrain
|
| 561 |
+
self.scenes = [
|
            "00000000000000000000000a", "00000000000000000000000b", "00000000000000000000000d", "00000000000000000000000e",
            "00000000000000000000000f", "000000000000000000000001", "00000000000000000000001b", "00000000000000000000001d",
            "000000000000000000000002", "000000000000000000000003", "000000000000000000000004", "000000000000000000000005",
            "5a2a95f032a1c655cfe3de62", "5a2af22b32a1c655cfe46013", "5a2ba6de32a1c655cfe51b79", "5a3b9731e24cd76dad1a5f1b",
            "5a3ca9cb270f0e3f14d0eddb", "5a3cb4e4270f0e3f14d12f43", "5a03e732454a8a7ec672776c", "5a3f4aba5889373fbbc5d3b5",
            "5a4a38dad38c8a075495b5d2", "5a5a1e48d62c7a12d5d00e47", "5a6b1c418d100c2f8fdc4411", "5a6feeb54a7fbc3f874f9db7",
            "5a7cb1d6fe5c0d6fb53e64fb", "5a7d3db14989e929563eb153", "5a8aa0fab18050187cbe060e", "5a9e5df65baeef72b4a021cd",
            "5a48ba95c7dab83a7d7b44ed", "5a48c4e9c7dab83a7d7b5cc7", "5a48d4b2c7dab83a7d7b9851", "5a69c47d0d5d0a7f3b2e9752",
            "5a77b46b318efe6c6736e68a", "5a355c271b63f53d5970f362", "5a489fb1c7dab83a7d7b1070", "5a533e8034d7582116e34209",
            "5a562fc7425d0f5186314725", "5a572fd9fc597b0478a81d14", "5a588a8193ac3d233f77fbca", "5a618c72784780334bc1972d",
            "5a752d42acc41e2423f17674", "5a969eea91dfc339a9a3ad2c", "5a8315f624b8e938486e0bd8", "5a57542f333d180827dfc132",
            "5a0271884e62597cdee0d0eb", "5a6400933d809f1d8200af15", "5a6464143d809f1d8208c43c", "5a563183425d0f5186314855",
            "5aa0f9d7a9efce63548c69a1", "5aa0f478a9efce63548c1cb4", "5aa7db90bfdd572271e95246", "5aa235f64a17b335eeaf9609",
            "5aa515e613d42d091d29d300", "5aa1196ea9efce63548ed649", "5aaadd4cbc13235570d178a7", "5ab6af12ac4291329b1072ab",
            "5ab7e00aac4291329b15864d", "5ab8b8e029f5351f7f2ccf59", "5ab74bf2ac4291329b11e879", "5ab85f1dac4291329b17cb50",
            "5ab8713ba3799a1d138bd69a", "5abc2506b53b042ead637d86", "5acc7459a7853c4b5ebbef59", "5acf8ca0f3d8a750097e4b15",
            "5adc6bd52430a05ecb2ffb85", "5ae2e9c5fe405c5076abc6b2", "5af02e904c8216544b4ab5a2", "5af28cea59bc705737003253",
            "5af545d0559359053d25dcf5", "5afacb69ab00705d0cefdd5b", "5b2c67b5e0878c381608b8d8", "5b3b2b9e8d46a939f933fdc0",
            "5b3b353d8d46a939f93524b9", "5b6e716d67b396324c2d77cb", "5b6eff8b67b396324c5b2672", "5b7a3890fc8fcf6781e2593a",
            "5b21e18c58e2823a67a10dd8", "5b60fa0c764f146feef84df0", "5b69cc0cb44b61786eb959bf", "5b78e57afc8fcf6781d0c3ba",
            "5b192eb2170cf166458ff886", "5b558a928bbfb62204e77ba2", "5b864d850d072a699b32f4ae", "5b908d3dc6ab78485f3d24a9",
            "5b950c71608de421b1e7318f", "5b4933abf2b5f44e95de482a", "5b08286b2775267d5b0634ba", "5b37189a35304b6f75e7583e",
            "5b271079e0878c3816dacca4", "5b22269758e2823a67a3bd03", "5b62647143840965efc0dbde", "5ba19a8a360c7c30c1c169df",
            "5ba75d79d76ffa2c86cf2f05", "5bb7a08aea1cfa39f1a947ab", "5bb8a49aea1cfa39f1aa7f75", "5bbb6eb2ea1cfa39f1af7e0c",
            "5bc5f0e896b66a2cd8f9bd36", "5bccd6beca24970bce448134", "5bce7ac9ca24970bce4934b6", "5bcf979a6d5f586b95c258cd",
            "5bd43b4ba6b28b1ee86b92dd", "5be3a5fb8cfdd56947f6b67c", "5be3ae47f44e235bdbbc9771", "5be4ab93870d330ff2dce134",
            "5be47bf9b18881428d8fbc1d", "5be883a4f98cee15019d5b83", "5bea87f4abd34c35e1860ab5", "5beb6e66abd34c35e18e66b9",
            "5bf3a82cd439231948877aed", "5bf7d63575c26f32dbf7413b", "5bf17c0fd439231948355385", "5bf26cbbd43923194854b270",
            "5bf18642c50e6f7f8bdbd492", "5bf21799d43923194842c001", "5bfc9d5aec61ca1dd69132a2", "5bfd0f32ec61ca1dd69dc77b",
            "5bfe5ae0fe0ea555e6a969ca", "5bff3c5cfe0ea555e6bcbf3a", "5c0d13b795da9479e12e2ee9", "5c1af2e2bee9a723c963d019",
            "5c1b1500bee9a723c96c3e78", "5c1dbf200843bc542d8ef8c4", "5c1f33f1d33e1f2e4aa6dda4", "5c20ca3a0843bc542d94e3e2",
            "5c062d84a96e33018ff6f0a6", "5c189f2326173c3a09ed7ef3", "5c1892f726173c3a09ea9aeb", "5c34300a73a8df509add216d",
            "5c34529873a8df509ae57b58", "000000000000000000000006", "000000000000000000000007", "000000000000000000000008",
            "000000000000000000000009", "000000000000000000000010", "000000000000000000000011", "000000000000000000000012",
            "000000000000000000000015", "000000000000000000000016", "000000000000000000000017", "000000000000000000000018",
            "000000000000000000000019", "56d73ba74bd29b8c35abade2", "56f34064e296120e10484dc4", "57a4a7bb6b9272286e26dc18",
            "57f8d9bbe73f6760f10e916a", "58a0a2f33d0b4542479a11b1", "58a0dd1a3d0b4542479a28f3", "58a1a7914a4d262a170b1101",
            "58a1bc804a4d262a170b2f01", "58a1d9d14a4d262a170b58fe", "58a01dea38486e3c98475871", "58a1f5d74a4d262a170b65fc",
            "58a2a09e156b87103d3d668c", "58a2d9c3156b87103d3da90f", "58a3ccb0156b87103d3e4332", "58a3f2f8156b87103d3e5838",
            "58a3f6c0156b87103d3e5971", "58a3fc95156b87103d3e5d9b", "58a07ce53d0b45424799fdde", "58a07f233d0b45424799ffe7",
            "58a44df2156b87103d3ee239", "58a164f73d0b4542479a7a8e", "58a0365e38486e3c984783eb", "58a439cf156b87103d3ec885",
            "58a464aa156b87103d3eec04", "58a4452f156b87103d3ed55b", "58a160983d0b4542479a7347", "58a186444a4d262a170ae3ae",
            "58a285424a4d262a170baf3e", "58a41819156b87103d3e92a5", "58a47552156b87103d3f00a4", "58c4bb4f4a69c55606122be4",
            "58c6451e4a69c556061894f1", "58ca7014affdfd07c70a95ce", "58cf4771d0f5fb221defe6da", "58d36897f387231e6c929903",
            "58eaf1513353456af3a1682a", "58f7f7299f5b5647873cb110", "58f73e7c9f5b56478738929f", "59a8f851597729752c31e7e0",
            "59a452bf9b460239aa5d1c72", "59a9619a825418241fb88191", "59acd2f4b891807f439c8992", "59bf97fe7e7b31545da34439",
            "59c1c3e2fd6e3d4ead9f1013", "59d2657f82ca7774b1ec081d", "59da1fb88a126011d0394ae9", "59e75a2ca9e91f2c5526005d",
            "59e864b2a9e91f2c5529325f", "59ecfd02e225f6492d20fcc9", "59f37f74b45be2233001ba18", "59f70ab1e5c5d366af29bf3e",
            "59f363a8b45be22330016cad", "564a27b26d07883f460d8ab0", "565fb1dead14d4154dae2b94", "567a0fb0a825d2fb79ac9a20",
            "569b92eb826bcba945ca002b", "576fefa017ce5a16397e87fd", "584a7333fe3cb463906c9fe6", "584aa8e9fe3cb463906cc7d0",
            "584ad76bfe3cb463906ce6dc", "584af003fe3cb463906d0e9b", "584b9a747072670e72bfc49d", "584b671f7072670e72bfaaf8",
            "584b81747072670e72bfbbfd", "584ba35f7072670e72bfca4d", "584ba5977072670e72bfcc2d", "584bc53c7072670e72bfe85f",
            "584bc3997072670e72bfe58d", "584bc4407072670e72bfe665", "584bd5587072670e72bffe39", "584bdadf7072670e72c0005c",
            "584be5ed7072670e72c007b3", "584c9ad27072670e72c060c5", "584c9cc67072670e72c063a1", "584c58b77072670e72c03990",
            "584cea557072670e72c07fb4", "584d19d47072670e72c0c6c0", "584dfe467072670e72c1665a", "584e875c7072670e72c1ec94",
            "584e05667072670e72c17167", "584f94e87072670e72c2d3f7", "584fdffd7072670e72c32dc7", "584fe07f7072670e72c32e59",
            "585a2a71b338a62ad50138dc", "585a206ab338a62ad501298f", "585a217cb338a62ad5012b38", "585b34afb338a62ad501e836",
            "585bb25fc49c8507c3ce7812", "585bbe55c49c8507c3ce81cd", "585d6c8a2a57cc11d4920a1e", "585e54c72a57cc11d492f71a",
            "585e34302a57cc11d492be30", "585ee0632a57cc11d4933608", "585f9661712e2761468dabca", "585ffe9a712e2761468df643",
            "586a37ec9d1b5e34c28184fc", "586a515a9d1b5e34c281b431", "586a94939d1b5e34c2823b5d", "586abc689d1b5e34c2826360",
            "586b0e219d1b5e34c2828862", "586b3db89d1b5e34c282cd52", "586b4c459d1b5e34c282e66d", "586b7d7d9d1b5e34c283359e",
            "586b8f149d1b5e34c283497c", "586b8f629d1b5e34c28349d6", "586c4c4d9d1b5e34c28391a1", "586c5b5b9d1b5e34c2839a5b",
            "586c9fdf9d1b5e34c283b657", "586c48329d1b5e34c2838e80", "586caab99d1b5e34c283c213", "586cd0779d1b5e34c28403a7",
            "586d6d249d1b5e34c284b80e", "586d8a029d1b5e34c284c948", "586d55af9d1b5e34c284a999", "586d07869d1b5e34c2842e5b",
            "586d27489d1b5e34c28453af", "586df9849d1b5e34c28506de", "586e279c9d1b5e34c2852180", "587bc5ec2366dd5d06e262c1",
            "587c1abf2366dd5d06e28901", "587c03f12366dd5d06e27722", "587c19da2366dd5d06e2877b", "587c31b92366dd5d06e2a9dc",
            "587c87d02366dd5d06e2f989", "587c97a52366dd5d06e30a96", "587c45192366dd5d06e2c0eb", "587cec702366dd5d06e37862",
            "587cef0a2366dd5d06e379e3", "587db5872366dd5d06e3e0af", "587e2b1d2366dd5d06e41af0", "587e2ea62366dd5d06e41f2e",
            "587e5cb52366dd5d06e4486e", "587eb1822366dd5d06e45f29", "587f365d2366dd5d06e4906e", "588a9c5fec4d5a1c088ec350",
            "588a34cfec4d5a1c088ea8d1", "588ab5bdec4d5a1c088ed60f", "588aff9d90414422fbe7885a", "588b20d290414422fbe79f40",
            "588c08d590414422fbe8200b", "588c203d90414422fbe8319e", "588c989a90414422fbe86d96", "588ca09d90414422fbe871a1",
            "588cce2190414422fbe88520", "588cd5ef90414422fbe8875c", "588cf0ad90414422fbe8a20f", "588e0d8c90414422fbe8f8b2",
            "588e01c490414422fbe8ee2a", "588e35e690414422fbe90a53", "588f017e90414422fbe9b74b", "588f095190414422fbe9c1ee",
            "589aca717dc3d323d55671c4", "589af2c97dc3d323d55691e8", "589b49ea7dc3d323d556d9b4", "589b04287dc3d323d556a185",
            "589bf6a57dc3d323d55743ab", "589c3c497dc3d323d5578468", "589c3c577dc3d323d5578480", "589c300f7dc3d323d5577926",
            "589c24527dc3d323d5577126", "589c35457dc3d323d5577d8d", "589ca6a6b896147a1b73aff7", "589d1e1fb896147a1b73ee5b",
            "589d5c58b896147a1b742256", "589d95538fa2cf375df3317b", "589df0ffb504a864ad63521a", "589ea316b504a864ad639a2b",
            "589ec97cb504a864ad63adc3", "589f214338486e3c9846f123", "589fdfe738486e3c984736cf", "590c2d70336bb52a190be886",
            "590f91851225725be9e25d4e", "591a467a6109e14d4f09b776", "591cf3033162411cf9047f37", "591ea44850991c70dc99a207",
            "599aa591d5b41f366fed0d58", "5643df56138263b51db1b5f3", "5644bdac138263b51db9f669", "5850d4f97072670e72c425d6",
            "5854c405804be105852330fe", "5855a4fc804be1058523bd75", "5856ac15804be105852419d8", "5856ae8b804be10585241bae",
            "5856b460804be10585242059", "5857aa5ab338a62ad5ff4dbe", "5857acf8b338a62ad5ff5107", "5858db6cb338a62ad500103b",
            "5858dbcab338a62ad5001081", "5859d84fb338a62ad500e5cf", "5861d8ea712e2761468f3cb3", "5863edf8712e27614690cce0",
            "5864b076712e27614691197e", "5864da88712e276146913d8b", "5865f4a8712e27614691e39b", "5867a434833dfe3f7b88edaf",
            "5868cd15833dfe3f7b89bfa3", "5880b3692366dd5d06e5d534", "5880e3422366dd5d06e5ff8e", "5880f0ef2366dd5d06e6166e",
            "5881d2bfb6844814c136a119", "5881f11d8ce2c2754d0714c3", "5881fee18ce2c2754d0723f8", "5882cda2b116682b4adebd25",
            "5882d58fb116682b4adec7db", "5884c256932ba84fbed70bf5", "5884cc13932ba84fbed71ec4", "5885bc5296fa095e0671a7f0",
            "5886d14cb791366d617a362c", "5888becfc02346100f4b0b21", "5888e408c02346100f4b1a29", "5889da66ec4d5a1c088e5187",
            "5889e344ec4d5a1c088e59be", "5889e754ec4d5a1c088e60ba", "5890c16b90414422fbeb0262", "5891d8ae9a8c0314c5cd30ab",
            "5891d0479a8c0314c5cd2abd", "5891ecf19a8c0314c5cd490a", "5892c0cd9a8c0314c5cdc977", "5894ab309a8c0314c5cee57d",
            "5895a6a89a8c0314c5cfca7c", "5895b8c29a8c0314c5cfd051", "5895d38f9a8c0314c5cfe50c", "5895f2329a8c0314c5d00117",
            "5896bb989a8c0314c5d086b6", "5896ebf39a8c0314c5d0a8c4", "5898b1bac9dccc22987b7f74", "5898b6ffc9dccc22987b8a03",
            "5898b31cc9dccc22987b82ec", "5898bbaac9dccc22987b8eba", "5899cfa6b76d7a3780a4cb64", "5899e5dcb76d7a3780a4ecc1",
            "5947b62af1b45630bd0c2a02", "57102be2877e1421026358af", "57153d4031bb9900425bde85", "57177cd7fb8d93461afc4527",
            "58497cdf97b73e0b090c4273", "58500b007072670e72c35588", "58510bf97072670e72c46ddf", "58522bd56789802282f2ecb3",
            "58524a2e0e7012308944bcf3", "58524a080e7012308944bcbf", "58524c1d0e7012308944bfda", "58524f170e7012308944c200",
            "58529a4e0e70123089454c6f", "58551bdf804be1058523556d", "58568c9a804be10585240b03", "58574b35804be105852455fd",
            "58577c60b338a62ad5ff1564", "58592d69b338a62ad5007a74", "58598db2b338a62ad500bc38", "58625f42712e2761468fb44c",
            "58651bcc712e2761469166dc", "58660e79712e27614691fe3d", "58669aad712e27614692834c", "58669c02712e27614692851a",
            "58676c36833dfe3f7b88b7f2", "58678b2d833dfe3f7b88e244", "58790c82ce911104a3467c88", "58800b0b2366dd5d06e5312d",
            "58805eac2366dd5d06e56460", "58806e422366dd5d06e57bb6", "58831d060db9bf59bf8ab98b", "58851ebb932ba84fbed7abad",
            "58871dc3b791366d617a55ff", "58873cabb791366d617a65a7", "58873d44b791366d617a65dd", "58888b3dc02346100f4af665",
            "58897f62c02346100f4b8ee6", "58933bac9a8c0314c5ce3508", "58938e6d9a8c0314c5ce726f", "58951cb49a8c0314c5cf4d5e",
            "58970fd09a8c0314c5d0e383", "58977ef09a8c0314c5d17b26", "59056e6760bb961de55f3501", "59071f2e5a6dbd3af4130f98",
            "59102c811225725be9e64149", "59338e76772c3e6384afbb15", "59350ca084b7f26bf5ce6eb8", "59397e493a87372f2c9e882b",
            "59521e0b9096412211c2aa9d", "59817e4a1bd4b175e7038d19", "567884f58d2828b95e3c8eba", "585559d9804be10585238ddf",
            "585834cdb338a62ad5ffab4d", "586082d8712e2761468e2877", "586133c2712e2761468ecfe3", "586281d2712e2761468fcaa2",
            "586316e5712e276146903c4d", "586326ad712e276146904571", "586375c9712e276146907429", "586389c9712e276146908da6",
            "586496fa712e2761469108e7", "586669c6712e27614692597a", "586913a49d1b5e34c2808b02", "586922da9d1b5e34c2809ff3",
            "588185d8dfb7a15588a114a3", "588305ed0db9bf59bf8a8c80", "588315c60db9bf59bf8aa928", "588332ee0db9bf59bf8ae9c3",
            "588457b8932ba84fbed69942", "588519d5932ba84fbed7a04a", "588824d1b791366d617adeef", "588857f6c02346100f4ac09f",
            "589145ef90414422fbeb2e08", "589433fa9a8c0314c5ce9656", "589765d39a8c0314c5d16b12", "5851165f7072670e72c4860d",
            "5859341ab338a62ad500848d", "5862388b712e2761468f84aa", "5863915b712e276146909135", "5866445b712e27614692383e",
            "5866500d712e2761469240fd", "5867785a833dfe3f7b88c764", "5867969c833dfe3f7b88e8bc", "5868040c833dfe3f7b8934f7",
            "5880675a2366dd5d06e570ca", "5882372c8ce2c2754d076af0", "5883535e932ba84fbed5ad07", "5888358cb791366d617af69d",
            "5890330d90414422fbeaa0cb", "5897076e9a8c0314c5d0d31b", "5940564ec2d9527ab869f7e2", "5947719bf1b45630bd096665",
            "5948194ff1b45630bd0f47e3", "5950206a41b158666ac50506", "5983012d1bd4b175e70c985a", "58586810b338a62ad5ffc20c",
            "58592046b338a62ad5006b33", "58592854b338a62ad500750a", "58596531b338a62ad500aace", "58818685dfb7a15588a11626",
            "58829563f42b1d3ee3ec835f", "58894345c02346100f4b51ca", "585289980e7012308945276a", "585369770e7012308945c709",
            "585373640e7012308945cab9", "588230658ce2c2754d076728", "589388059a8c0314c5ce718b", "595979485ec6a95e86a58c8d",
            "5841206219d291325678ca90", "58563650804be1058523da55", "58564084804be1058523e116", "58636467712e27614690661f",
            "58647495712e27614690f36d", "58654563712e276146918643", "58664251712e276146923738", "588084032366dd5d06e59e82",
            "588159582366dd5d06e66877", "5890279190414422fbea9734", "5890523090414422fbeab3f0", "5890641690414422fbeabbe7",
            "585203546789802282f2aaf5",
        ]

        # Train set sequences after filtering
        self.train_split_scenes = [
            "00000000000000000000000b", "00000000000000000000000d", "00000000000000000000000e", "00000000000000000000000f",
            "000000000000000000000001", "00000000000000000000001b", "00000000000000000000001d", "000000000000000000000002",
            "000000000000000000000003", "000000000000000000000004", "000000000000000000000005", "5a2a95f032a1c655cfe3de62",
            "5a2af22b32a1c655cfe46013", "5a2ba6de32a1c655cfe51b79", "5a3b9731e24cd76dad1a5f1b", "5a3ca9cb270f0e3f14d0eddb",
            "5a3cb4e4270f0e3f14d12f43", "5a03e732454a8a7ec672776c", "5a3f4aba5889373fbbc5d3b5", "5a5a1e48d62c7a12d5d00e47",
            "5a6b1c418d100c2f8fdc4411", "5a6feeb54a7fbc3f874f9db7", "5a7cb1d6fe5c0d6fb53e64fb", "5a7d3db14989e929563eb153",
            "5a8aa0fab18050187cbe060e", "5a9e5df65baeef72b4a021cd", "5a48ba95c7dab83a7d7b44ed", "5a48c4e9c7dab83a7d7b5cc7",
            "5a48d4b2c7dab83a7d7b9851", "5a69c47d0d5d0a7f3b2e9752", "5a77b46b318efe6c6736e68a", "5a355c271b63f53d5970f362",
            "5a533e8034d7582116e34209", "5a562fc7425d0f5186314725", "5a618c72784780334bc1972d", "5a752d42acc41e2423f17674",
            "5a969eea91dfc339a9a3ad2c", "5a8315f624b8e938486e0bd8", "5a57542f333d180827dfc132", "5a0271884e62597cdee0d0eb",
            "5a6400933d809f1d8200af15", "5a6464143d809f1d8208c43c", "5a563183425d0f5186314855", "5aa0f9d7a9efce63548c69a1",
            "5aa7db90bfdd572271e95246", "5aa235f64a17b335eeaf9609", "5aa515e613d42d091d29d300", "5aa1196ea9efce63548ed649",
            "5aaadd4cbc13235570d178a7", "5ab6af12ac4291329b1072ab", "5ab7e00aac4291329b15864d", "5ab8b8e029f5351f7f2ccf59",
            "5ab74bf2ac4291329b11e879", "5ab85f1dac4291329b17cb50", "5ab8713ba3799a1d138bd69a", "5abc2506b53b042ead637d86",
            "5acc7459a7853c4b5ebbef59", "5acf8ca0f3d8a750097e4b15", "5adc6bd52430a05ecb2ffb85", "5af02e904c8216544b4ab5a2",
            "5af28cea59bc705737003253", "5af545d0559359053d25dcf5", "5afacb69ab00705d0cefdd5b", "5b3b2b9e8d46a939f933fdc0",
            "5b3b353d8d46a939f93524b9", "5b6e716d67b396324c2d77cb", "5b6eff8b67b396324c5b2672", "5b7a3890fc8fcf6781e2593a",
            "5b60fa0c764f146feef84df0", "5b69cc0cb44b61786eb959bf", "5b78e57afc8fcf6781d0c3ba", "5b192eb2170cf166458ff886",
            "5b558a928bbfb62204e77ba2", "5b908d3dc6ab78485f3d24a9", "5b950c71608de421b1e7318f", "5b08286b2775267d5b0634ba",
            "5b271079e0878c3816dacca4", "5b22269758e2823a67a3bd03", "5b62647143840965efc0dbde", "5ba19a8a360c7c30c1c169df",
            "5ba75d79d76ffa2c86cf2f05", "5bb7a08aea1cfa39f1a947ab", "5bb8a49aea1cfa39f1aa7f75", "5bbb6eb2ea1cfa39f1af7e0c",
            "5bce7ac9ca24970bce4934b6", "5bcf979a6d5f586b95c258cd", "5bd43b4ba6b28b1ee86b92dd", "5be3a5fb8cfdd56947f6b67c",
            "5be3ae47f44e235bdbbc9771", "5be4ab93870d330ff2dce134", "5be47bf9b18881428d8fbc1d", "5be883a4f98cee15019d5b83",
            "5bea87f4abd34c35e1860ab5", "5beb6e66abd34c35e18e66b9", "5bf3a82cd439231948877aed", "5bf7d63575c26f32dbf7413b",
            "5bf17c0fd439231948355385", "5bf21799d43923194842c001", "5bfd0f32ec61ca1dd69dc77b", "5bfe5ae0fe0ea555e6a969ca",
            "5c0d13b795da9479e12e2ee9", "5c1af2e2bee9a723c963d019", "5c1b1500bee9a723c96c3e78", "5c1dbf200843bc542d8ef8c4",
            "5c20ca3a0843bc542d94e3e2", "5c062d84a96e33018ff6f0a6", "5c189f2326173c3a09ed7ef3", "5c1892f726173c3a09ea9aeb",
            "5c34300a73a8df509add216d", "000000000000000000000006", "000000000000000000000007", "000000000000000000000008",
            "000000000000000000000009", "000000000000000000000010", "000000000000000000000011", "000000000000000000000012",
            "000000000000000000000015", "000000000000000000000016", "000000000000000000000017", "000000000000000000000018",
            "000000000000000000000019", "56d73ba74bd29b8c35abade2", "56f34064e296120e10484dc4", "57a4a7bb6b9272286e26dc18",
            "57f8d9bbe73f6760f10e916a", "58a0a2f33d0b4542479a11b1", "58a0dd1a3d0b4542479a28f3", "58a1a7914a4d262a170b1101",
            "58a1bc804a4d262a170b2f01", "58a1d9d14a4d262a170b58fe", "58a01dea38486e3c98475871", "58a1f5d74a4d262a170b65fc",
            "58a2a09e156b87103d3d668c", "58a2d9c3156b87103d3da90f", "58a3ccb0156b87103d3e4332", "58a3f2f8156b87103d3e5838",
            "58a3f6c0156b87103d3e5971", "58a3fc95156b87103d3e5d9b", "58a07ce53d0b45424799fdde", "58a07f233d0b45424799ffe7",
            "58a44df2156b87103d3ee239", "58a164f73d0b4542479a7a8e", "58a0365e38486e3c984783eb", "58a439cf156b87103d3ec885",
            "58a464aa156b87103d3eec04", "58a4452f156b87103d3ed55b", "58a160983d0b4542479a7347", "58a285424a4d262a170baf3e",
            "58a41819156b87103d3e92a5", "58a47552156b87103d3f00a4", "58c4bb4f4a69c55606122be4", "58c6451e4a69c556061894f1",
            "58ca7014affdfd07c70a95ce", "58cf4771d0f5fb221defe6da", "58d36897f387231e6c929903", "58eaf1513353456af3a1682a",
            "58f73e7c9f5b56478738929f", "59a8f851597729752c31e7e0", "59a452bf9b460239aa5d1c72", "59a9619a825418241fb88191",
            "59bf97fe7e7b31545da34439", "59c1c3e2fd6e3d4ead9f1013", "59d2657f82ca7774b1ec081d", "59da1fb88a126011d0394ae9",
            "59e75a2ca9e91f2c5526005d", "59e864b2a9e91f2c5529325f", "59ecfd02e225f6492d20fcc9", "59f37f74b45be2233001ba18",
            "59f70ab1e5c5d366af29bf3e", "59f363a8b45be22330016cad", "564a27b26d07883f460d8ab0", "565fb1dead14d4154dae2b94",
            "569b92eb826bcba945ca002b", "576fefa017ce5a16397e87fd", "584a7333fe3cb463906c9fe6", "584aa8e9fe3cb463906cc7d0",
            "584af003fe3cb463906d0e9b", "584b9a747072670e72bfc49d", "584b671f7072670e72bfaaf8", "584b81747072670e72bfbbfd",
            "584ba35f7072670e72bfca4d", "584ba5977072670e72bfcc2d", "584bc53c7072670e72bfe85f", "584bc3997072670e72bfe58d",
            "584bc4407072670e72bfe665", "584bd5587072670e72bffe39", "584bdadf7072670e72c0005c", "584be5ed7072670e72c007b3",
            "584c9ad27072670e72c060c5", "584c9cc67072670e72c063a1", "584cea557072670e72c07fb4", "584d19d47072670e72c0c6c0",
            "584dfe467072670e72c1665a", "584e875c7072670e72c1ec94", "584e05667072670e72c17167", "584f94e87072670e72c2d3f7",
            "584fdffd7072670e72c32dc7", "584fe07f7072670e72c32e59", "585a2a71b338a62ad50138dc", "585a206ab338a62ad501298f",
            "585a217cb338a62ad5012b38", "585b34afb338a62ad501e836", "585bb25fc49c8507c3ce7812", "585bbe55c49c8507c3ce81cd",
            "585d6c8a2a57cc11d4920a1e", "585e54c72a57cc11d492f71a", "585e34302a57cc11d492be30", "585ee0632a57cc11d4933608",
            "585f9661712e2761468dabca", "585ffe9a712e2761468df643", "586a37ec9d1b5e34c28184fc", "586a515a9d1b5e34c281b431",
            "586a94939d1b5e34c2823b5d", "586abc689d1b5e34c2826360", "586b0e219d1b5e34c2828862", "586b3db89d1b5e34c282cd52",
            "586b4c459d1b5e34c282e66d", "586b7d7d9d1b5e34c283359e", "586b8f149d1b5e34c283497c", "586b8f629d1b5e34c28349d6",
            "586c4c4d9d1b5e34c28391a1", "586c5b5b9d1b5e34c2839a5b", "586c9fdf9d1b5e34c283b657", "586caab99d1b5e34c283c213",
            "586cd0779d1b5e34c28403a7", "586d6d249d1b5e34c284b80e", "586d8a029d1b5e34c284c948", "586d55af9d1b5e34c284a999",
            "586d07869d1b5e34c2842e5b", "586d27489d1b5e34c28453af", "586e279c9d1b5e34c2852180", "587bc5ec2366dd5d06e262c1",
            "587c1abf2366dd5d06e28901", "587c03f12366dd5d06e27722", "587c19da2366dd5d06e2877b", "587c31b92366dd5d06e2a9dc",
            "587c87d02366dd5d06e2f989", "587c97a52366dd5d06e30a96", "587c45192366dd5d06e2c0eb", "587cec702366dd5d06e37862",
            "587cef0a2366dd5d06e379e3", "587db5872366dd5d06e3e0af", "587e2b1d2366dd5d06e41af0", "587e2ea62366dd5d06e41f2e",
            "587e5cb52366dd5d06e4486e", "587eb1822366dd5d06e45f29", "587f365d2366dd5d06e4906e", "588a9c5fec4d5a1c088ec350",
            "588a34cfec4d5a1c088ea8d1", "588ab5bdec4d5a1c088ed60f", "588aff9d90414422fbe7885a", "588b20d290414422fbe79f40",
            "588c08d590414422fbe8200b", "588c203d90414422fbe8319e", "588c989a90414422fbe86d96", "588ca09d90414422fbe871a1",
            "588cce2190414422fbe88520", "588cd5ef90414422fbe8875c", "588cf0ad90414422fbe8a20f", "588e01c490414422fbe8ee2a",
            "588e35e690414422fbe90a53", "588f017e90414422fbe9b74b", "588f095190414422fbe9c1ee", "589aca717dc3d323d55671c4",
            "589af2c97dc3d323d55691e8", "589b49ea7dc3d323d556d9b4", "589b04287dc3d323d556a185", "589bf6a57dc3d323d55743ab",
            "589c3c497dc3d323d5578468", "589c3c577dc3d323d5578480", "589c24527dc3d323d5577126", "589c35457dc3d323d5577d8d",
            "589ca6a6b896147a1b73aff7", "589d1e1fb896147a1b73ee5b", "589d5c58b896147a1b742256", "589d95538fa2cf375df3317b",
            "589df0ffb504a864ad63521a", "589ea316b504a864ad639a2b", "589ec97cb504a864ad63adc3", "589f214338486e3c9846f123",
            "589fdfe738486e3c984736cf", "590c2d70336bb52a190be886", "591a467a6109e14d4f09b776", "591cf3033162411cf9047f37",
            "591ea44850991c70dc99a207", "599aa591d5b41f366fed0d58", "5643df56138263b51db1b5f3", "5644bdac138263b51db9f669",
            "5850d4f97072670e72c425d6", "5854c405804be105852330fe", "5855a4fc804be1058523bd75", "5856ac15804be105852419d8",
            "5856ae8b804be10585241bae", "5856b460804be10585242059", "5857aa5ab338a62ad5ff4dbe", "5857acf8b338a62ad5ff5107",
            "5858db6cb338a62ad500103b", "5858dbcab338a62ad5001081", "5859d84fb338a62ad500e5cf", "5861d8ea712e2761468f3cb3",
            "5863edf8712e27614690cce0", "5864b076712e27614691197e", "5864da88712e276146913d8b", "5865f4a8712e27614691e39b",
            "5867a434833dfe3f7b88edaf", "5868cd15833dfe3f7b89bfa3", "5880b3692366dd5d06e5d534", "5880e3422366dd5d06e5ff8e",
            "5880f0ef2366dd5d06e6166e", "5881d2bfb6844814c136a119", "5881f11d8ce2c2754d0714c3", "5881fee18ce2c2754d0723f8",
            "5882cda2b116682b4adebd25", "5882d58fb116682b4adec7db", "5884c256932ba84fbed70bf5", "5884cc13932ba84fbed71ec4",
            "5885bc5296fa095e0671a7f0", "5886d14cb791366d617a362c", "5888becfc02346100f4b0b21", "5888e408c02346100f4b1a29",
            "5889da66ec4d5a1c088e5187", "5889e754ec4d5a1c088e60ba", "5890c16b90414422fbeb0262", "5891d8ae9a8c0314c5cd30ab",
            "5891d0479a8c0314c5cd2abd", "5891ecf19a8c0314c5cd490a", "5892c0cd9a8c0314c5cdc977", "5894ab309a8c0314c5cee57d",
            "5895a6a89a8c0314c5cfca7c", "5895b8c29a8c0314c5cfd051", "5895d38f9a8c0314c5cfe50c", "5895f2329a8c0314c5d00117",
            "5896bb989a8c0314c5d086b6", "5896ebf39a8c0314c5d0a8c4", "5898b1bac9dccc22987b7f74", "5898b6ffc9dccc22987b8a03",
            "5898bbaac9dccc22987b8eba", "5899cfa6b76d7a3780a4cb64", "5899e5dcb76d7a3780a4ecc1", "57102be2877e1421026358af",
            "57153d4031bb9900425bde85", "57177cd7fb8d93461afc4527", "58497cdf97b73e0b090c4273", "58500b007072670e72c35588",
            "58510bf97072670e72c46ddf", "58522bd56789802282f2ecb3", "58524a2e0e7012308944bcf3", "58524a080e7012308944bcbf",
            "58524c1d0e7012308944bfda", "58524f170e7012308944c200", "58529a4e0e70123089454c6f", "58551bdf804be1058523556d",
            "58568c9a804be10585240b03", "58574b35804be105852455fd", "58577c60b338a62ad5ff1564", "58592d69b338a62ad5007a74",
            "58625f42712e2761468fb44c", "58651bcc712e2761469166dc", "58660e79712e27614691fe3d", "58669aad712e27614692834c",
            "58676c36833dfe3f7b88b7f2", "58678b2d833dfe3f7b88e244", "58800b0b2366dd5d06e5312d", "58805eac2366dd5d06e56460",
            "58806e422366dd5d06e57bb6", "58831d060db9bf59bf8ab98b", "58851ebb932ba84fbed7abad", "58871dc3b791366d617a55ff",
            "58873cabb791366d617a65a7", "58873d44b791366d617a65dd", "58888b3dc02346100f4af665", "58933bac9a8c0314c5ce3508",
            "58938e6d9a8c0314c5ce726f", "58951cb49a8c0314c5cf4d5e", "58970fd09a8c0314c5d0e383", "58977ef09a8c0314c5d17b26",
            "59056e6760bb961de55f3501", "59071f2e5a6dbd3af4130f98", "59102c811225725be9e64149", "59338e76772c3e6384afbb15",
            "59350ca084b7f26bf5ce6eb8", "59397e493a87372f2c9e882b", "59521e0b9096412211c2aa9d", "59817e4a1bd4b175e7038d19",
            "567884f58d2828b95e3c8eba", "585559d9804be10585238ddf", "585834cdb338a62ad5ffab4d", "586082d8712e2761468e2877",
            "586133c2712e2761468ecfe3", "586281d2712e2761468fcaa2", "586316e5712e276146903c4d", "586326ad712e276146904571",
            "586375c9712e276146907429", "586389c9712e276146908da6", "586496fa712e2761469108e7", "586669c6712e27614692597a",
            "586913a49d1b5e34c2808b02", "586922da9d1b5e34c2809ff3", "588185d8dfb7a15588a114a3", "588315c60db9bf59bf8aa928",
            "588332ee0db9bf59bf8ae9c3", "588519d5932ba84fbed7a04a", "588824d1b791366d617adeef", "588857f6c02346100f4ac09f",
            "589145ef90414422fbeb2e08", "589433fa9a8c0314c5ce9656", "589765d39a8c0314c5d16b12", "5851165f7072670e72c4860d",
            "5859341ab338a62ad500848d", "5863915b712e276146909135", "5866445b712e27614692383e", "5866500d712e2761469240fd",
            "5867785a833dfe3f7b88c764", "5867969c833dfe3f7b88e8bc", "5868040c833dfe3f7b8934f7", "5882372c8ce2c2754d076af0",
            "5883535e932ba84fbed5ad07", "5888358cb791366d617af69d", "5890330d90414422fbeaa0cb", "5897076e9a8c0314c5d0d31b",
            "5940564ec2d9527ab869f7e2", "5947719bf1b45630bd096665", "5948194ff1b45630bd0f47e3", "5950206a41b158666ac50506",
            "5983012d1bd4b175e70c985a", "58586810b338a62ad5ffc20c", "58592046b338a62ad5006b33", "58592854b338a62ad500750a",
            "58596531b338a62ad500aace", "58818685dfb7a15588a11626", "58829563f42b1d3ee3ec835f", "58894345c02346100f4b51ca",
            "585289980e7012308945276a", "585369770e7012308945c709", "585373640e7012308945cab9", "588230658ce2c2754d076728",
            "589388059a8c0314c5ce718b", "595979485ec6a95e86a58c8d", "5841206219d291325678ca90", "58563650804be1058523da55",
            "58564084804be1058523e116", "58636467712e27614690661f", "58647495712e27614690f36d", "58654563712e276146918643",
            "58664251712e276146923738", "588084032366dd5d06e59e82", "588159582366dd5d06e66877", "5890279190414422fbea9734",
            "5890641690414422fbeabbe7", "585203546789802282f2aaf5",
        ]

        # Validation set sequences after filtering
        self.val_split_scenes = [
            "00000000000000000000000a", "5a4a38dad38c8a075495b5d2", "5a489fb1c7dab83a7d7b1070", "5a572fd9fc597b0478a81d14",
            "5a588a8193ac3d233f77fbca", "5aa0f478a9efce63548c1cb4", "5ae2e9c5fe405c5076abc6b2", "5b2c67b5e0878c381608b8d8",
            "5b21e18c58e2823a67a10dd8", "5b864d850d072a699b32f4ae", "5b4933abf2b5f44e95de482a", "5b37189a35304b6f75e7583e",
            "5bc5f0e896b66a2cd8f9bd36", "5bccd6beca24970bce448134", "5bf26cbbd43923194854b270", "5bf18642c50e6f7f8bdbd492",
            "5bfc9d5aec61ca1dd69132a2", "5bff3c5cfe0ea555e6bcbf3a", "5c1f33f1d33e1f2e4aa6dda4", "5c34529873a8df509ae57b58",
            "58a186444a4d262a170ae3ae", "58f7f7299f5b5647873cb110", "59acd2f4b891807f439c8992", "567a0fb0a825d2fb79ac9a20",
            "584ad76bfe3cb463906ce6dc", "584c58b77072670e72c03990", "586c48329d1b5e34c2838e80", "586df9849d1b5e34c28506de",
            "588e0d8c90414422fbe8f8b2", "589c300f7dc3d323d5577926", "590f91851225725be9e25d4e", "5889e344ec4d5a1c088e59be",
            "5898b31cc9dccc22987b82ec", "5947b62af1b45630bd0c2a02", "58598db2b338a62ad500bc38", "58669c02712e27614692851a",
            "58790c82ce911104a3467c88", "58897f62c02346100f4b8ee6", "588305ed0db9bf59bf8a8c80", "588457b8932ba84fbed69942",
            "5862388b712e2761468f84aa", "5880675a2366dd5d06e570ca", "5890523090414422fbeab3f0",
        ]


class TartanAirV2Splits:
    """
    This class contains the information about the splits of the TartanAir V2 dataset.
    """

    def __init__(self):
        """
        Splits of environments with unique geometry, selected based on the TartanVO & UFM splits.
        """
        # Apart from the below 2 splits, all other TAv2 scenes are in the train split
        # Val split
        self.val_split_scenes = ["EndofTheWorld", "HongKong", "WesternDesertTown"]

        # Test split
        self.test_split_scenes = [
            "DesertGasStation",
            "OldScandinavia",
            "PolarSciFi",
            "Sewerage",
            "Supermarket",
        ]


class MegaDepthSplits:
    """
    This class contains the information about the splits of the MegaDepth dataset.
    """

    def __init__(self):
        """
        Validation split is based on scenes used in DUSt3R.
        """
        self.val_split_scenes = ["0015_0", "0015_1", "0022_0"]

class SpringSplits:
    """
    This class contains the information about the splits of the Spring dataset.
    """

    def __init__(self):
        self.val_split_scenes = ["0013", "0023", "0037"]

class MPSDSplits:
    """
    This class contains the information about the splits of the MPSD dataset.
    """

    def __init__(self):
        """
        Train & validation split numpy files containing the folder names are generated during preprocessing of the MPSD dataset.
        Load the numpy files to get the list of scenes in the train & validation splits.
        A 95% (train) / 5% (validation) split is used.
        """
        self.train_split_scenes = "load_numpy_file_with_train_scenes"
        self.val_split_scenes = "load_numpy_file_with_val_scenes"

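The two placeholder strings above stand in for scene lists saved as numpy files during MPSD preprocessing. A minimal sketch of resolving them at load time; the file names here are illustrative assumptions, not the repository's actual ones:

import numpy as np

def load_mpsd_split_scenes(metadata_dir):
    # File names are assumed for illustration; the real names come from the MPSD preprocessing step
    train_scenes = np.load(f"{metadata_dir}/mpsd_train_scenes.npy", allow_pickle=True).tolist()
    val_scenes = np.load(f"{metadata_dir}/mpsd_val_scenes.npy", allow_pickle=True).tolist()
    return train_scenes, val_scenes
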
class ScanNetPPSplits:
    """
    This class contains the information about the splits of the ScanNetPP dataset.
    """

    def __init__(self):
        """
        The validation & test splits only contain scenes from ScanNet++ V2 to prevent data leakage with other methods such as DUSt3R during benchmarking.

        The following logic was used to generate the splits:
        # Select 80%, 10%, 10% of the scenes for train, val, test respectively from ScanNet++ V2 (~300 scene subset; excluding V1 scenes)
        snpp_v2_test_scenes = np.random.choice(
            snpp_v2_processed_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
        )
        remaining_scenes = [scene for scene in snpp_v2_processed_scenes if scene not in snpp_v2_test_scenes]
        snpp_v2_val_scenes = np.random.choice(
            remaining_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
        )
        snpp_v2_train_scenes = [
            scene for scene in remaining_scenes if scene not in snpp_v2_val_scenes and scene not in snpp_v2_test_scenes
        ]
        """
        # Validation Scenes
        self.val_split_scenes = [
            "1c7a683c92", "2a1b555966", "3a43c7b8d2", "4aef651da7", "06bc6d1b24", "7f22d5ef1b",
            "7f77abce34", "8ea517a2fc", "29c7afafed", "41eb967018", "77b40ce601", "086f09d6e3",
            "307e3262f1", "639f2c4d5a", "894dbd41f1", "898a7dfd0c", "2779f8f9e2", "151178afd7",
            "182932a4f3", "635852d56e", "9906136b57", "af112b8903", "b0f057c684", "b37177e6c8",
            "b119249da7", "be8367fcbe", "c8fc01c453", "e1fb8626c8", "e2caaaf5b5", "fe3fc057a1",
        ]

        # Test Scenes
        self.test_split_scenes = [
            "0e900bcc5c", "0eba3981c9", "1cbb105c6a", "3c8d535d49", "5d902f1593", "6bd39ac392",
            "6c14d5fd01", "7c31a42404", "9bfbc75700", "13b4efaf62", "062e5a23a6", "95b9971d01",
            "246fe09e98", "637a27d04b", "725b8f0cba", "413085a827", "696317583f", "a4c043ac48",
            "a9e4791c7e", "b0b004c40f", "c3bc5e82c5", "c31ebd4b22", "cba701332a", "cc5ea8026c",
            "cec8312f4e", "e3b3b0d0c7", "e667e09fe6", "eaa6c90310", "f9397af4cb", "fb893ffaf3",
        ]

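The split-generation logic quoted in the ScanNetPPSplits docstring calls np.random.choice without a fixed seed, so it is not reproducible as written. A minimal sketch of the same 80/10/10 selection with a seeded generator (the seed value is an assumption):

import numpy as np

rng = np.random.default_rng(0)  # assumed seed; fixes the split across runs
scenes = np.array(snpp_v2_processed_scenes)  # the processed V2 scene list named in the docstring
test_scenes = rng.choice(scenes, size=int(0.1 * len(scenes)), replace=False)
remaining = np.setdiff1d(scenes, test_scenes)
val_scenes = rng.choice(remaining, size=int(0.1 * len(scenes)), replace=False)
train_scenes = np.setdiff1d(remaining, val_scenes)
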
class DL3DV10KSplits:
    """
    This class contains the information about the splits of the DL3DV-10K dataset.
    We use the official benchmark split as the val split.
    """

    def __init__(self):
        """
        Validation split is based on DL3DV-Benchmark.
        """
        self.val_split_scenes = [
            "load https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark/raw/main/benchmark-meta.csv \
            & https://raw.githubusercontent.com/DL3DV-10K/Dataset/main/cache/DL3DV-valid.csv"
        ]

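DL3DV10KSplits stores a pointer string rather than scene names, so the val split has to be materialized from the two CSVs. A minimal sketch, assuming the benchmark CSV carries a per-scene hash column named "hash" (an assumption about the schema, not confirmed by this file):

import csv

def load_dl3dv_val_scenes(benchmark_meta_csv):
    # benchmark_meta_csv: local copy of benchmark-meta.csv fetched from the URL above
    with open(benchmark_meta_csv, newline="") as f:
        return [row["hash"] for row in csv.DictReader(f)]  # "hash" column name is assumed
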
class ETH3DSplits:
    """
    This class contains the information about the splits of the ETH3D dataset.
    """

    def __init__(self):
        """
        All scenes are in the test split.
        """
        self.test_split_scenes = "all"
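Each split class in this module exposes its splits as plain instance attributes, so downstream code can consume them uniformly. A short usage sketch (import path inferred from the repository layout):

from mapanything.datasets.utils.data_splits import MegaDepthSplits, TartanAirV2Splits

megadepth_val = MegaDepthSplits().val_split_scenes  # ["0015_0", "0015_1", "0022_0"]
tav2_test = TartanAirV2Splits().test_split_scenes   # the 5 held-out TartanAir V2 environments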
mapanything/datasets/wai/__init__.py
ADDED
File without changes
mapanything/datasets/wai/ase.py
ADDED
@@ -0,0 +1,294 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+ASE Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class ASEWAI(BaseDataset):
+    """
+    ASE dataset containing large diversity of synthetic indoor scenes.
+    """
+
+    def __init__(
+        self,
+        *args,
+        ROOT,
+        dataset_metadata_dir,
+        split,
+        overfit_num_sets=None,
+        sample_specific_scene: bool = False,
+        specific_scene_name: str = None,
+        **kwargs,
+    ):
+        """
+        Initialize the dataset attributes.
+        Args:
+            ROOT: Root directory of the dataset.
+            dataset_metadata_dir: Path to the dataset metadata directory.
+            split: Dataset split (train, val, test).
+            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+            sample_specific_scene: Whether to sample a specific scene from the dataset.
+            specific_scene_name: Name of the specific scene to sample.
+        """
+        # Initialize the dataset attributes
+        super().__init__(*args, **kwargs)
+        self.ROOT = ROOT
+        self.dataset_metadata_dir = dataset_metadata_dir
+        self.split = split
+        self.overfit_num_sets = overfit_num_sets
+        self.sample_specific_scene = sample_specific_scene
+        self.specific_scene_name = specific_scene_name
+        self._load_data()
+
+        # Define the dataset type flags
+        self.is_metric_scale = True
+        self.is_synthetic = True
+
+    def _load_data(self):
+        "Load the precomputed dataset metadata"
+        # Load the dataset metadata corresponding to the split
+        split_metadata_path = os.path.join(
+            self.dataset_metadata_dir,
+            self.split,
+            f"ase_scene_list_{self.split}.npy",
+        )
+        split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+        # Get the list of all scenes
+        if not self.sample_specific_scene:
+            self.scenes = list(split_scene_list)
+        else:
+            self.scenes = [self.specific_scene_name]
+        self.num_of_scenes = len(self.scenes)
+
+    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+        # Get the scene name of the sampled index
+        scene_index = sampled_idx
+        scene_name = self.scenes[scene_index]
+
+        # Get the metadata corresponding to the scene
+        scene_root = os.path.join(self.ROOT, scene_name)
+        scene_meta = load_data(
+            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+        )
+        scene_file_names = list(scene_meta["frame_names"].keys())
+        num_views_in_scene = len(scene_file_names)
+
+        # Load the scene pairwise covisibility mmap
+        covisibility_version_key = "v0"
+        covisibility_map_dir = os.path.join(
+            scene_root, "covisibility", covisibility_version_key
+        )
+        # Assumes only npy file in directory is covisibility map
+        covisibility_map_name = next(
+            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+        )
+        covisibility_map_path = os.path.join(
+            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+        )
+        pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+        # Get the indices of the N views in the scene
+        view_indices = self._sample_view_indices(
+            num_views_to_sample, num_views_in_scene, pairwise_covisibility
+        )
+
+        # Get the views corresponding to the selected view indices
+        views = []
+        for view_index in view_indices:
+            # Load the data corresponding to the view
+            view_file_name = scene_file_names[view_index]
+            view_data = load_frame(
+                scene_root,
+                view_file_name,
+                modalities=["image", "depth"],
+                scene_meta=scene_meta,
+            )
+
+            # Convert necessary data to numpy
+            image = view_data["image"].permute(1, 2, 0).numpy()
+            image = (image * 255).astype(np.uint8)
+            depthmap = view_data["depth"].numpy().astype(np.float32)
+            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+            # Ensure that the depthmap has all valid values
+            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+            # Resize the data to match the desired resolution
+            image, depthmap, intrinsics = self._crop_resize_if_necessary(
+                image=image,
+                resolution=resolution,
+                depthmap=depthmap,
+                intrinsics=intrinsics,
+                additional_quantities=None,
+            )
+
+            # Append the view dictionary to the list of views
+            views.append(
+                dict(
+                    img=image,
+                    depthmap=depthmap,
+                    camera_pose=c2w_pose,  # cam2world
+                    camera_intrinsics=intrinsics,
+                    dataset="ASE",
+                    label=scene_name,
+                    instance=os.path.join("images", str(view_file_name)),
+                )
+            )
+
+        return views
+
+
+def get_parser():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/ase", type=str)
+    parser.add_argument(
+        "-dmd",
+        "--dataset_metadata_dir",
+        default="/fsx/nkeetha/mapanything_dataset_metadata",
+        type=str,
+    )
+    parser.add_argument(
+        "-nv",
+        "--num_of_views",
+        default=2,
+        type=int,
+    )
+    parser.add_argument("--viz", action="store_true")
+
+    return parser
+
+
+if __name__ == "__main__":
+    import rerun as rr
+    from tqdm import tqdm
+
+    from mapanything.datasets.base.base_dataset import view_name
+    from mapanything.utils.image import rgb
+    from mapanything.utils.viz import script_add_rerun_args
+
+    parser = get_parser()
+    script_add_rerun_args(
+        parser
+    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
+    args = parser.parse_args()
+
+    dataset = ASEWAI(
+        num_views=args.num_of_views,
+        split="train",
+        covisibility_thres=0.25,
+        ROOT=args.root_dir,
+        dataset_metadata_dir=args.dataset_metadata_dir,
+        resolution=(518, 518),
+        aug_crop=16,
+        transform="colorjitter+grayscale+gaublur",
+        data_norm_type="dinov2",
+    )
+    # dataset = ASEWAI(
+    #     num_views=args.num_of_views,
+    #     split="val",
+    #     covisibility_thres=0.25,
+    #     ROOT=args.root_dir,
+    #     dataset_metadata_dir=args.dataset_metadata_dir,
+    #     resolution=(518, 518),
+    #     seed=777,
+    #     transform="imgnorm",
+    #     data_norm_type="dinov2",
+    # )
+    print(dataset.get_stats())
+
+    if args.viz:
+        rr.script_setup(args, "ASE_Dataloader")
+        rr.set_time("stable_time", sequence=0)
+        rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+    for num, idx in enumerate(tqdm(sampled_indices)):
+        views = dataset[idx]
+        assert len(views) == args.num_of_views
+        sample_name = f"{idx}"
+        for view_idx in range(args.num_of_views):
+            sample_name += f" {view_name(views[view_idx])}"
+        print(sample_name)
+        for view_idx in range(args.num_of_views):
+            image = rgb(
+                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+            )
+            depthmap = views[view_idx]["depthmap"]
+            pose = views[view_idx]["camera_pose"]
+            intrinsics = views[view_idx]["camera_intrinsics"]
+            pts3d = views[view_idx]["pts3d"]
+            valid_mask = views[view_idx]["valid_mask"]
+            if "non_ambiguous_mask" in views[view_idx]:
+                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+            else:
+                non_ambiguous_mask = None
+            if "prior_depth_along_ray" in views[view_idx]:
+                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+            else:
+                prior_depth_along_ray = None
+            if args.viz:
+                rr.set_time("stable_time", sequence=num)
+                base_name = f"world/view_{view_idx}"
+                pts_name = f"world/view_{view_idx}_pointcloud"
+                # Log camera info and loaded data
+                height, width = image.shape[0], image.shape[1]
+                rr.log(
+                    base_name,
+                    rr.Transform3D(
+                        translation=pose[:3, 3],
+                        mat3x3=pose[:3, :3],
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole",
+                    rr.Pinhole(
+                        image_from_camera=intrinsics,
+                        height=height,
+                        width=width,
+                        camera_xyz=rr.ViewCoordinates.RDF,
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/rgb",
+                    rr.Image(image),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/depth",
+                    rr.DepthImage(depthmap),
+                )
+                if prior_depth_along_ray is not None:
+                    rr.log(
+                        f"prior_depth_along_ray_{view_idx}",
+                        rr.DepthImage(prior_depth_along_ray),
+                    )
+                if non_ambiguous_mask is not None:
+                    rr.log(
+                        f"{base_name}/pinhole/non_ambiguous_mask",
+                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+                    )
+                # Log points in 3D
+                filtered_pts = pts3d[valid_mask]
+                filtered_pts_col = image[valid_mask]
+                rr.log(
+                    pts_name,
+                    rr.Points3D(
+                        positions=filtered_pts.reshape(-1, 3),
+                        colors=filtered_pts_col.reshape(-1, 3),
+                    ),
+                )
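The covisibility-map lookup in ASEWAI._get_views reappears verbatim in every WAI loader added by this diff. A minimal sketch of how that block could be factored into a shared helper (the function name find_covisibility_map is hypothetical; load_data is the same mapanything.utils.wai.core helper imported above):

import os

from mapanything.utils.wai.core import load_data

def find_covisibility_map(scene_root, version_key="v0"):
    # Locate the per-scene pairwise covisibility matrix; like the loaders
    # above, this assumes the only .npy file under covisibility/<version_key>/
    # is the covisibility map, and returns it as a memory-mapped array.
    covis_dir = os.path.join(scene_root, "covisibility", version_key)
    covis_name = next(f for f in os.listdir(covis_dir) if f.endswith(".npy"))
    return load_data(os.path.join(covis_dir, covis_name), "mmap")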
mapanything/datasets/wai/blendedmvs.py
ADDED
@@ -0,0 +1,313 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+BlendedMVS Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class BlendedMVSWAI(BaseDataset):
+    """
+    BlendedMVS dataset containing object-centric and birds-eye-view scenes.
+    """
+
+    def __init__(
+        self,
+        *args,
+        ROOT,
+        dataset_metadata_dir,
+        split,
+        overfit_num_sets=None,
+        sample_specific_scene: bool = False,
+        specific_scene_name: str = None,
+        **kwargs,
+    ):
+        """
+        Initialize the dataset attributes.
+        Args:
+            ROOT: Root directory of the dataset.
+            dataset_metadata_dir: Path to the dataset metadata directory.
+            split: Dataset split (train, val, test).
+            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+            sample_specific_scene: Whether to sample a specific scene from the dataset.
+            specific_scene_name: Name of the specific scene to sample.
+        """
+        # Initialize the dataset attributes
+        super().__init__(*args, **kwargs)
+        self.ROOT = ROOT
+        self.dataset_metadata_dir = dataset_metadata_dir
+        self.split = split
+        self.overfit_num_sets = overfit_num_sets
+        self.sample_specific_scene = sample_specific_scene
+        self.specific_scene_name = specific_scene_name
+        self._load_data()
+
+        # Define the dataset type flags
+        self.is_metric_scale = False
+        self.is_synthetic = False
+
+    def _load_data(self):
+        "Load the precomputed dataset metadata"
+        # Load the dataset metadata corresponding to the split
+        split_metadata_path = os.path.join(
+            self.dataset_metadata_dir,
+            self.split,
+            f"blendedmvs_scene_list_{self.split}.npy",
+        )
+        split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+        # Get the list of all scenes
+        if not self.sample_specific_scene:
+            self.scenes = list(split_scene_list)
+        else:
+            self.scenes = [self.specific_scene_name]
+        self.num_of_scenes = len(self.scenes)
+
+    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+        # Get the scene name of the sampled index
+        scene_index = sampled_idx
+        scene_name = self.scenes[scene_index]
+
+        # Get the metadata corresponding to the scene
+        scene_root = os.path.join(self.ROOT, scene_name)
+        scene_meta = load_data(
+            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+        )
+        scene_file_names = list(scene_meta["frame_names"].keys())
+        num_views_in_scene = len(scene_file_names)
+
+        # Load the scene pairwise covisibility mmap
+        covisibility_version_key = "v0"
+        covisibility_map_dir = os.path.join(
+            scene_root, "covisibility", covisibility_version_key
+        )
+        # Assumes only npy file in directory is covisibility map
+        covisibility_map_name = next(
+            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+        )
+        covisibility_map_path = os.path.join(
+            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+        )
+        pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+        # Get the indices of the N views in the scene
+        view_indices = self._sample_view_indices(
+            num_views_to_sample, num_views_in_scene, pairwise_covisibility
+        )
+
+        # Get the views corresponding to the selected view indices
+        views = []
+        for view_index in view_indices:
+            # Load the data corresponding to the view
+            view_file_name = scene_file_names[view_index]
+            view_data = load_frame(
+                scene_root,
+                view_file_name,
+                modalities=["image", "depth", "pred_mask/moge2"],
+                scene_meta=scene_meta,
+            )
+
+            # Convert necessary data to numpy
+            image = view_data["image"].permute(1, 2, 0).numpy()
+            image = (image * 255).astype(np.uint8)
+            depthmap = view_data["depth"].numpy().astype(np.float32)
+            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+            # Ensure that the depthmap has all valid values
+            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+            # Get the non_ambiguous_mask and ensure it matches image resolution
+            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+            non_ambiguous_mask = cv2.resize(
+                non_ambiguous_mask,
+                (image.shape[1], image.shape[0]),
+                interpolation=cv2.INTER_NEAREST,
+            )
+
+            # Mask out the GT depth using the non_ambiguous_mask
+            depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+            # Resize the data to match the desired resolution
+            additional_quantities_to_resize = [non_ambiguous_mask]
+            image, depthmap, intrinsics, additional_quantities_to_resize = (
+                self._crop_resize_if_necessary(
+                    image=image,
+                    resolution=resolution,
+                    depthmap=depthmap,
+                    intrinsics=intrinsics,
+                    additional_quantities=additional_quantities_to_resize,
+                )
+            )
+            non_ambiguous_mask = additional_quantities_to_resize[0]
+
+            # Append the view dictionary to the list of views
+            views.append(
+                dict(
+                    img=image,
+                    depthmap=depthmap,
+                    camera_pose=c2w_pose,  # cam2world
+                    camera_intrinsics=intrinsics,
+                    non_ambiguous_mask=non_ambiguous_mask,
+                    dataset="BlendedMVS",
+                    label=scene_name,
+                    instance=os.path.join("images", str(view_file_name)),
+                )
+            )
+
+        return views
+
+
+def get_parser():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-rd", "--root_dir", default="/fsx/xrtech/data/blendedmvs", type=str
+    )
+    parser.add_argument(
+        "-dmd",
+        "--dataset_metadata_dir",
+        default="/fsx/nkeetha/mapanything_dataset_metadata",
+        type=str,
+    )
+    parser.add_argument(
+        "-nv",
+        "--num_of_views",
+        default=2,
+        type=int,
+    )
+    parser.add_argument("--viz", action="store_true")
+
+    return parser
+
+
+if __name__ == "__main__":
+    import rerun as rr
+    from tqdm import tqdm
+
+    from mapanything.datasets.base.base_dataset import view_name
+    from mapanything.utils.image import rgb
+    from mapanything.utils.viz import script_add_rerun_args
+
+    parser = get_parser()
+    script_add_rerun_args(
+        parser
+    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
+    args = parser.parse_args()
+
+    dataset = BlendedMVSWAI(
+        num_views=args.num_of_views,
+        split="train",
+        covisibility_thres=0.25,
+        ROOT=args.root_dir,
+        dataset_metadata_dir=args.dataset_metadata_dir,
+        resolution=(518, 392),
+        aug_crop=16,
+        transform="colorjitter+grayscale+gaublur",
+        data_norm_type="dinov2",
+    )
+    # dataset = BlendedMVSWAI(
+    #     num_views=args.num_of_views,
+    #     split="val",
+    #     covisibility_thres=0.25,
+    #     ROOT=args.root_dir,
+    #     dataset_metadata_dir=args.dataset_metadata_dir,
+    #     resolution=(518, 392),
+    #     seed=777,
+    #     transform="imgnorm",
+    #     data_norm_type="dinov2",
+    # )
+    print(dataset.get_stats())
+
+    if args.viz:
+        rr.script_setup(args, "BlendedMVS_Dataloader")
+        rr.set_time("stable_time", sequence=0)
+        rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+    sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+    for num, idx in enumerate(tqdm(sampled_indices)):
+        views = dataset[idx]
+        assert len(views) == args.num_of_views
+        sample_name = f"{idx}"
+        for view_idx in range(args.num_of_views):
+            sample_name += f" {view_name(views[view_idx])}"
+        print(sample_name)
+        for view_idx in range(args.num_of_views):
+            image = rgb(
+                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+            )
+            depthmap = views[view_idx]["depthmap"]
+            pose = views[view_idx]["camera_pose"]
+            intrinsics = views[view_idx]["camera_intrinsics"]
+            pts3d = views[view_idx]["pts3d"]
+            valid_mask = views[view_idx]["valid_mask"]
+            if "non_ambiguous_mask" in views[view_idx]:
+                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+            else:
+                non_ambiguous_mask = None
+            if "prior_depth_along_ray" in views[view_idx]:
+                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+            else:
+                prior_depth_along_ray = None
+            if args.viz:
+                rr.set_time("stable_time", sequence=num)
+                base_name = f"world/view_{view_idx}"
+                pts_name = f"world/view_{view_idx}_pointcloud"
+                # Log camera info and loaded data
+                height, width = image.shape[0], image.shape[1]
+                rr.log(
+                    base_name,
+                    rr.Transform3D(
+                        translation=pose[:3, 3],
+                        mat3x3=pose[:3, :3],
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole",
+                    rr.Pinhole(
+                        image_from_camera=intrinsics,
+                        height=height,
+                        width=width,
+                        camera_xyz=rr.ViewCoordinates.RDF,
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/rgb",
+                    rr.Image(image),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/depth",
+                    rr.DepthImage(depthmap),
+                )
+                if prior_depth_along_ray is not None:
+                    rr.log(
+                        f"prior_depth_along_ray_{view_idx}",
+                        rr.DepthImage(prior_depth_along_ray),
+                    )
+                if non_ambiguous_mask is not None:
+                    rr.log(
+                        f"{base_name}/pinhole/non_ambiguous_mask",
+                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+                    )
+                # Log points in 3D
+                filtered_pts = pts3d[valid_mask]
+                filtered_pts_col = image[valid_mask]
+                rr.log(
+                    pts_name,
+                    rr.Points3D(
+                        positions=filtered_pts.reshape(-1, 3),
+                        colors=filtered_pts_col.reshape(-1, 3),
+                    ),
+                )
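The moge2 mask handling above relies on the convention that a depth value of 0 marks an invalid pixel, so masking out ambiguous regions is a single np.where. A toy illustration with made-up values (not taken from the dataset):

import numpy as np

depth = np.array([[1.5, 2.0], [0.5, 3.0]], dtype=np.float32)
mask = np.array([[1, 0], [1, 1]])  # 1 = non-ambiguous pixel per moge2
masked = np.where(mask, depth, 0)  # ambiguous pixels zeroed out
# masked -> [[1.5, 0.0], [0.5, 3.0]]; downstream code treats depth == 0 as invalid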
mapanything/datasets/wai/dl3dv.py
ADDED
@@ -0,0 +1,356 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+DL3DV Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.cropping import (
+    rescale_image_and_other_optional_info,
+    resize_with_nearest_interpolation_to_match_aspect_ratio,
+)
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class DL3DVWAI(BaseDataset):
+    """
+    DL3DV dataset containing over 10k in-the-wild and indoor scenes.
+    """
+
+    def __init__(
+        self,
+        *args,
+        ROOT,
+        dataset_metadata_dir,
+        split,
+        overfit_num_sets=None,
+        sample_specific_scene: bool = False,
+        specific_scene_name: str = None,
+        mvs_confidence_filter_thres: float = 0.25,
+        **kwargs,
+    ):
+        """
+        Initialize the dataset attributes.
+        Args:
+            ROOT: Root directory of the dataset.
+            dataset_metadata_dir: Path to the dataset metadata directory.
+            split: Dataset split (train, val, test).
+            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+            sample_specific_scene: Whether to sample a specific scene from the dataset.
+            specific_scene_name: Name of the specific scene to sample.
+            mvs_confidence_filter_thres: Confidence threshold to filter MVS depth. Defaults to 0.25.
+        """
+        # Initialize the dataset attributes
+        super().__init__(*args, **kwargs)
+        self.ROOT = ROOT
+        self.dataset_metadata_dir = dataset_metadata_dir
+        self.split = split
+        self.overfit_num_sets = overfit_num_sets
+        self.sample_specific_scene = sample_specific_scene
+        self.specific_scene_name = specific_scene_name
+        self.mvs_confidence_filter_thres = mvs_confidence_filter_thres
+        self._load_data()
+
+        # Define the dataset type flags
+        self.is_metric_scale = False
+        self.is_synthetic = False
+
+    def _load_data(self):
+        "Load the precomputed dataset metadata"
+        # Load the dataset metadata corresponding to the split
+        split_metadata_path = os.path.join(
+            self.dataset_metadata_dir,
+            self.split,
+            f"dl3dv_scene_list_{self.split}.npy",
+        )
+        split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+        # Get the list of all scenes
+        if not self.sample_specific_scene:
+            self.scenes = list(split_scene_list)
+        else:
+            self.scenes = [self.specific_scene_name]
+        self.num_of_scenes = len(self.scenes)
+
+    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+        # Get the scene name of the sampled index
+        scene_index = sampled_idx
+        scene_name = self.scenes[scene_index]
+
+        # Get the metadata corresponding to the scene
+        scene_root = os.path.join(self.ROOT, scene_name)
+        scene_meta = load_data(
+            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+        )
+        scene_file_names = list(scene_meta["frame_names"].keys())
+        num_views_in_scene = len(scene_file_names)
+
+        # Load the scene pairwise covisibility mmap
+        covisibility_version_key = "v0_mvsa_based"
+        covisibility_map_dir = os.path.join(
+            scene_root, "covisibility", covisibility_version_key
+        )
+        # Assumes only npy file in directory is covisibility map
+        covisibility_map_name = next(
+            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+        )
+        covisibility_map_path = os.path.join(
+            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+        )
+        pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+        # Get the indices of the N views in the scene
+        view_indices = self._sample_view_indices(
+            num_views_to_sample, num_views_in_scene, pairwise_covisibility
+        )
+
+        # Get the views corresponding to the selected view indices
+        views = []
+        for view_index in view_indices:
+            # Load the data corresponding to the view
+            view_file_name = scene_file_names[view_index]
+            view_data = load_frame(
+                scene_root,
+                view_file_name,
+                modalities=[
+                    "image",
+                    "pred_depth/mvsanywhere",
+                    "pred_mask/moge2",
+                    "depth_confidence/mvsanywhere",
+                ],
+                scene_meta=scene_meta,
+            )
+
+            # Convert necessary data to numpy
+            image = view_data["image"].permute(1, 2, 0).numpy()
+            image = (image * 255).astype(np.uint8)
+            depthmap = view_data["pred_depth/mvsanywhere"].numpy().astype(np.float32)
+            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+            # Ensure that the depthmap has all valid values
+            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+            # Get the dimensions of the original image
+            img_h, img_w = image.shape[:2]
+
+            # Resize depth to match image aspect ratio while ensuring that depth resolution doesn't increase
+            depthmap, target_depth_h, target_depth_w = (
+                resize_with_nearest_interpolation_to_match_aspect_ratio(
+                    input_data=depthmap, img_h=img_h, img_w=img_w
+                )
+            )
+
+            # Now resize the image and update intrinsics to match the resized depth
+            image, _, intrinsics, _ = rescale_image_and_other_optional_info(
+                image=image,
+                output_resolution=(target_depth_w, target_depth_h),
+                depthmap=None,
+                camera_intrinsics=intrinsics,
+            )
+            image = np.array(image)
+
+            # Get the depth confidence map and mask out the MVS depth
+            confidence_map = view_data["depth_confidence/mvsanywhere"].numpy()
+            confidence_mask = (
+                confidence_map > self.mvs_confidence_filter_thres
+            ).astype(int)
+            confidence_mask = cv2.resize(
+                confidence_mask,
+                (image.shape[1], image.shape[0]),
+                interpolation=cv2.INTER_NEAREST,
+            )
+            depthmap = np.where(confidence_mask, depthmap, 0)
+
+            # Get the non_ambiguous_mask and ensure it matches image resolution
+            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+            non_ambiguous_mask = cv2.resize(
+                non_ambiguous_mask,
+                (image.shape[1], image.shape[0]),
+                interpolation=cv2.INTER_NEAREST,
+            )
+
+            # Mask out the GT depth using the non_ambiguous_mask
+            depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+            # Resize the data to match the desired resolution
+            additional_quantities_to_resize = [non_ambiguous_mask]
+            image, depthmap, intrinsics, additional_quantities_to_resize = (
+                self._crop_resize_if_necessary(
+                    image=image,
+                    resolution=resolution,
+                    depthmap=depthmap,
+                    intrinsics=intrinsics,
+                    additional_quantities=additional_quantities_to_resize,
+                )
+            )
+            non_ambiguous_mask = additional_quantities_to_resize[0]
+
+            # Append the view dictionary to the list of views
+            views.append(
+                dict(
+                    img=image,
+                    depthmap=depthmap,
+                    camera_pose=c2w_pose,  # cam2world
+                    camera_intrinsics=intrinsics,
+                    non_ambiguous_mask=non_ambiguous_mask,
+                    dataset="DL3DV",
+                    label=scene_name,
+                    instance=os.path.join("images", str(view_file_name)),
+                )
+            )
+
+        return views
+
+
+def get_parser():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dl3dv", type=str)
+    parser.add_argument(
+        "-dmd",
+        "--dataset_metadata_dir",
+        default="/fsx/nkeetha/mapanything_dataset_metadata",
+        type=str,
+    )
+    parser.add_argument(
+        "-nv",
+        "--num_of_views",
+        default=2,
+        type=int,
+    )
+    parser.add_argument("--viz", action="store_true")
+
+    return parser
+
+
+if __name__ == "__main__":
+    import rerun as rr
+    from tqdm import tqdm
+
+    from mapanything.datasets.base.base_dataset import view_name
+    from mapanything.utils.image import rgb
+    from mapanything.utils.viz import script_add_rerun_args
+
+    parser = get_parser()
+    script_add_rerun_args(
+        parser
+    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
+    args = parser.parse_args()
+
+    dataset = DL3DVWAI(
+        num_views=args.num_of_views,
+        split="train",
+        covisibility_thres=0.25,
+        ROOT=args.root_dir,
+        dataset_metadata_dir=args.dataset_metadata_dir,
+        mvs_confidence_filter_thres=0.25,
+        resolution=(518, 294),
+        aug_crop=16,
+        transform="colorjitter+grayscale+gaublur",
+        data_norm_type="dinov2",
+    )
+    # dataset = DL3DVWAI(
+    #     num_views=args.num_of_views,
+    #     split="val",
+    #     covisibility_thres=0.25,
+    #     ROOT=args.root_dir,
+    #     dataset_metadata_dir=args.dataset_metadata_dir,
+    #     mvs_confidence_filter_thres=0.25,
+    #     resolution=(518, 294),
+    #     seed=777,
+    #     transform="imgnorm",
+    #     data_norm_type="dinov2",
+    # )
+    print(dataset.get_stats())
+
+    if args.viz:
+        rr.script_setup(args, "DL3DV_Dataloader")
+        rr.set_time("stable_time", sequence=0)
+        rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+    for num, idx in enumerate(tqdm(sampled_indices)):
+        views = dataset[idx]
+        assert len(views) == args.num_of_views
+        sample_name = f"{idx}"
+        for view_idx in range(args.num_of_views):
+            sample_name += f" {view_name(views[view_idx])}"
+        print(sample_name)
+        for view_idx in range(args.num_of_views):
+            image = rgb(
+                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+            )
+            depthmap = views[view_idx]["depthmap"]
+            pose = views[view_idx]["camera_pose"]
+            intrinsics = views[view_idx]["camera_intrinsics"]
+            pts3d = views[view_idx]["pts3d"]
+            valid_mask = views[view_idx]["valid_mask"]
+            if "non_ambiguous_mask" in views[view_idx]:
+                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+            else:
+                non_ambiguous_mask = None
+            if "prior_depth_along_ray" in views[view_idx]:
+                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+            else:
+                prior_depth_along_ray = None
+            if args.viz:
+                rr.set_time("stable_time", sequence=num)
+                base_name = f"world/view_{view_idx}"
+                pts_name = f"world/view_{view_idx}_pointcloud"
+                # Log camera info and loaded data
+                height, width = image.shape[0], image.shape[1]
+                rr.log(
+                    base_name,
+                    rr.Transform3D(
+                        translation=pose[:3, 3],
+                        mat3x3=pose[:3, :3],
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole",
+                    rr.Pinhole(
+                        image_from_camera=intrinsics,
+                        height=height,
+                        width=width,
+                        camera_xyz=rr.ViewCoordinates.RDF,
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/rgb",
+                    rr.Image(image),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/depth",
+                    rr.DepthImage(depthmap),
+                )
+                if prior_depth_along_ray is not None:
+                    rr.log(
+                        f"prior_depth_along_ray_{view_idx}",
+                        rr.DepthImage(prior_depth_along_ray),
+                    )
+                if non_ambiguous_mask is not None:
+                    rr.log(
+                        f"{base_name}/pinhole/non_ambiguous_mask",
+                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+                    )
+                # Log points in 3D
+                filtered_pts = pts3d[valid_mask]
+                filtered_pts_col = image[valid_mask]
+                rr.log(
+                    pts_name,
+                    rr.Points3D(
+                        positions=filtered_pts.reshape(-1, 3),
+                        colors=filtered_pts_col.reshape(-1, 3),
+                    ),
+                )
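DL3DVWAI filters the MVSAnywhere pseudo ground-truth depth in two stages: threshold the confidence map, resize the resulting mask to the image resolution with nearest-neighbor interpolation (so mask values stay binary), then zero out low-confidence depth. A toy sketch of that pattern with stand-in arrays (shapes and values are illustrative only, and uint8 is used here for the mask dtype):

import cv2
import numpy as np

depth = np.random.rand(4, 4).astype(np.float32)   # stand-in MVS depth at image size
conf = np.random.rand(2, 2).astype(np.float32)    # stand-in lower-res confidence map
conf_mask = (conf > 0.25).astype(np.uint8)        # same default threshold as above
conf_mask = cv2.resize(conf_mask, (4, 4), interpolation=cv2.INTER_NEAREST)
depth = np.where(conf_mask, depth, 0)             # 0 marks invalid depth downstream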
mapanything/datasets/wai/dynamicreplica.py
ADDED
@@ -0,0 +1,297 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+"""
+Dynamic Replica Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class DynamicReplicaWAI(BaseDataset):
+    """
+    Dynamic Replica dataset containing synthetic scenes with humans and animals.
+    """
+
+    def __init__(
+        self,
+        *args,
+        ROOT,
+        dataset_metadata_dir,
+        split,
+        overfit_num_sets=None,
+        sample_specific_scene: bool = False,
+        specific_scene_name: str = None,
+        **kwargs,
+    ):
+        """
+        Initialize the dataset attributes.
+        Args:
+            ROOT: Root directory of the dataset.
+            dataset_metadata_dir: Path to the dataset metadata directory.
+            split: Dataset split (train, val, test).
+            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+            sample_specific_scene: Whether to sample a specific scene from the dataset.
+            specific_scene_name: Name of the specific scene to sample.
+        """
+        # Initialize the dataset attributes
+        super().__init__(*args, **kwargs)
+        self.ROOT = ROOT
+        self.dataset_metadata_dir = dataset_metadata_dir
+        self.split = split
+        self.overfit_num_sets = overfit_num_sets
+        self.sample_specific_scene = sample_specific_scene
+        self.specific_scene_name = specific_scene_name
+        self._load_data()
+
+        # Define the dataset type flags
+        self.is_metric_scale = True
+        self.is_synthetic = True
+
+    def _load_data(self):
+        "Load the precomputed dataset metadata"
+        # Load the dataset metadata corresponding to the split
+        split_metadata_path = os.path.join(
+            self.dataset_metadata_dir,
+            self.split,
+            f"dynamicreplica_scene_list_{self.split}.npy",
+        )
+        split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+        # Get the list of all scenes
+        if not self.sample_specific_scene:
+            self.scenes = list(split_scene_list)
+        else:
+            self.scenes = [self.specific_scene_name]
+        self.num_of_scenes = len(self.scenes)
+
+    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+        # Get the scene name of the sampled index
+        scene_index = sampled_idx
+        scene_name = self.scenes[scene_index]
+
+        # Get the metadata corresponding to the scene
+        scene_root = os.path.join(self.ROOT, scene_name)
+        scene_meta = load_data(
+            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+        )
+        scene_file_names = list(scene_meta["frame_names"].keys())
+        num_views_in_scene = len(scene_file_names)
+
+        # Load the scene pairwise covisibility mmap
+        covisibility_version_key = "v0"
+        covisibility_map_dir = os.path.join(
+            scene_root, "covisibility", covisibility_version_key
+        )
+        # Assumes only npy file in directory is covisibility map
+        covisibility_map_name = next(
+            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+        )
+        covisibility_map_path = os.path.join(
+            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+        )
+        pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+        # Get the indices of the N views in the scene
+        view_indices = self._sample_view_indices(
+            num_views_to_sample, num_views_in_scene, pairwise_covisibility
+        )
+
+        # Get the views corresponding to the selected view indices
+        views = []
+        for view_index in view_indices:
+            # Load the data corresponding to the view
+            view_file_name = scene_file_names[view_index]
+            view_data = load_frame(
+                scene_root,
+                view_file_name,
+                modalities=["image", "depth"],
+                scene_meta=scene_meta,
+            )
+
+            # Convert necessary data to numpy
+            image = view_data["image"].permute(1, 2, 0).numpy()
+            image = image[:, :, :3]  # RGBA to RGB
+            image = (image * 255).astype(np.uint8)
+            depthmap = view_data["depth"].numpy().astype(np.float32)
+            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+            # Ensure that the depthmap has all valid values
+            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+            # Resize the data to match the desired resolution
+            image, depthmap, intrinsics = self._crop_resize_if_necessary(
+                image=image,
+                resolution=resolution,
+                depthmap=depthmap,
+                intrinsics=intrinsics,
+                additional_quantities=None,
+            )
+
+            # Append the view dictionary to the list of views
+            views.append(
+                dict(
+                    img=image,
+                    depthmap=depthmap,
+                    camera_pose=c2w_pose,  # cam2world
+                    camera_intrinsics=intrinsics,
+                    dataset="DynamicReplica",
+                    label=scene_name,
+                    instance=os.path.join("images", str(view_file_name)),
+                )
+            )
+
+        return views
+
+
+def get_parser():
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-rd", "--root_dir", default="/fsx/xrtech/data/dynamicreplica", type=str
+    )
+    parser.add_argument(
+        "-dmd",
+        "--dataset_metadata_dir",
+        default="/fsx/nkeetha/mapanything_dataset_metadata",
+        type=str,
+    )
+    parser.add_argument(
+        "-nv",
+        "--num_of_views",
+        default=2,
+        type=int,
+    )
+    parser.add_argument("--viz", action="store_true")
+
+    return parser
+
+
+if __name__ == "__main__":
+    import rerun as rr
+    from tqdm import tqdm
+
+    from mapanything.datasets.base.base_dataset import view_name
+    from mapanything.utils.image import rgb
+    from mapanything.utils.viz import script_add_rerun_args
+
+    parser = get_parser()
+    script_add_rerun_args(
+        parser
+    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
+    args = parser.parse_args()
+
+    dataset = DynamicReplicaWAI(
+        num_views=args.num_of_views,
+        split="train",
+        covisibility_thres=0.25,
+        ROOT=args.root_dir,
+        dataset_metadata_dir=args.dataset_metadata_dir,
+        resolution=(518, 294),
+        aug_crop=16,
+        transform="colorjitter+grayscale+gaublur",
+        data_norm_type="dinov2",
+    )
+    # dataset = DynamicReplicaWAI(
+    #     num_views=args.num_of_views,
+    #     split="val",
+    #     covisibility_thres=0.25,
+    #     ROOT=args.root_dir,
+    #     dataset_metadata_dir=args.dataset_metadata_dir,
+    #     resolution=(518, 294),
+    #     seed=777,
+    #     transform="imgnorm",
+    #     data_norm_type="dinov2",
+    # )
+    print(dataset.get_stats())
+
+    if args.viz:
+        rr.script_setup(args, "DynamicReplica_Dataloader")
+        rr.set_time("stable_time", sequence=0)
+        rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+    for num, idx in enumerate(tqdm(sampled_indices)):
+        views = dataset[idx]
+        assert len(views) == args.num_of_views
+        sample_name = f"{idx}"
+        for view_idx in range(args.num_of_views):
+            sample_name += f" {view_name(views[view_idx])}"
+        print(sample_name)
+        for view_idx in range(args.num_of_views):
+            image = rgb(
+                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+            )
+            depthmap = views[view_idx]["depthmap"]
+            pose = views[view_idx]["camera_pose"]
+            intrinsics = views[view_idx]["camera_intrinsics"]
+            pts3d = views[view_idx]["pts3d"]
+            valid_mask = views[view_idx]["valid_mask"]
+            if "non_ambiguous_mask" in views[view_idx]:
+                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+            else:
+                non_ambiguous_mask = None
+            if "prior_depth_along_ray" in views[view_idx]:
+                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+            else:
+                prior_depth_along_ray = None
+            if args.viz:
+                rr.set_time("stable_time", sequence=num)
+                base_name = f"world/view_{view_idx}"
+                pts_name = f"world/view_{view_idx}_pointcloud"
+                # Log camera info and loaded data
+                height, width = image.shape[0], image.shape[1]
+                rr.log(
+                    base_name,
+                    rr.Transform3D(
+                        translation=pose[:3, 3],
+                        mat3x3=pose[:3, :3],
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole",
+                    rr.Pinhole(
+                        image_from_camera=intrinsics,
+                        height=height,
+                        width=width,
+                        camera_xyz=rr.ViewCoordinates.RDF,
+                    ),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/rgb",
+                    rr.Image(image),
+                )
+                rr.log(
+                    f"{base_name}/pinhole/depth",
+                    rr.DepthImage(depthmap),
+                )
+                if prior_depth_along_ray is not None:
+                    rr.log(
+                        f"prior_depth_along_ray_{view_idx}",
+                        rr.DepthImage(prior_depth_along_ray),
+                    )
+                if non_ambiguous_mask is not None:
+                    rr.log(
+                        f"{base_name}/pinhole/non_ambiguous_mask",
+                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+                    )
+                # Log points in 3D
+                filtered_pts = pts3d[valid_mask]
+                filtered_pts_col = image[valid_mask]
+                rr.log(
+                    pts_name,
+                    rr.Points3D(
+                        positions=filtered_pts.reshape(-1, 3),
+                        colors=filtered_pts_col.reshape(-1, 3),
+                    ),
+                )
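The only loader-specific step in DynamicReplicaWAI is the RGBA-to-RGB conversion, since Dynamic Replica frames carry an alpha channel. A minimal restatement of that conversion on a stand-in tensor (torch is already a dependency of the loaders above; the shape values here are illustrative):

import numpy as np
import torch

rgba = torch.rand(4, 8, 8)             # stand-in C x H x W frame with alpha
img = rgba.permute(1, 2, 0).numpy()    # to H x W x C, as in the loader
img = img[:, :, :3]                    # drop the alpha channel (RGBA to RGB)
img = (img * 255).astype(np.uint8)     # normalized floats to uint8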
mapanything/datasets/wai/eth3d.py
ADDED
@@ -0,0 +1,277 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
ETH3D Dataset using WAI format data.
"""

import os

import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class ETH3DWAI(BaseDataset):
    """
    ETH3D dataset containing high-quality outdoor and indoor scans of the ETH Zurich campus.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = "test"
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = False

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"eth3d_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Resize the data to match the desired resolution
            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image=image,
                resolution=resolution,
                depthmap=depthmap,
                intrinsics=intrinsics,
                additional_quantities=None,
            )

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="ETH3D",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/eth3d", type=str)
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = ETH3DWAI(
        num_views=args.num_of_views,
        covisibility_thres=0.025,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 336),
        seed=777,
        transform="imgnorm",
        data_norm_type="dinov2",
    )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "ETH3D_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
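The view selection above delegates to `BaseDataset._sample_view_indices`, which is defined outside this diff; only the `covisibility_thres` passed at construction (0.025 for ETH3D) hints at its behavior. Purely as a hypothetical illustration, a greedy covisibility-thresholded sampler of the same shape could look like this:

```python
# Hypothetical sketch only: the real _sample_view_indices lives in BaseDataset
# and is not shown in this diff. This illustrates one plausible greedy strategy
# for picking N views whose covisibility with the selected set clears a threshold.
import numpy as np

def sample_view_indices(num_views, num_views_in_scene, pairwise_covisibility,
                        covisibility_thres=0.25, rng=None):
    rng = rng or np.random.default_rng()
    selected = [int(rng.integers(num_views_in_scene))]  # random seed view
    while len(selected) < num_views:
        # Mean covisibility of every candidate against the selected set
        scores = pairwise_covisibility[selected].mean(axis=0)
        scores[selected] = -1.0  # never re-pick an already selected view
        candidates = np.flatnonzero(scores >= covisibility_thres)
        if len(candidates) == 0:  # fall back to the best remaining view
            candidates = [int(np.argmax(scores))]
        selected.append(int(rng.choice(candidates)))
    return np.array(selected)

# Usage on a synthetic symmetric covisibility matrix:
cov = np.random.rand(20, 20)
cov = (cov + cov.T) / 2
print(sample_view_indices(4, 20, cov, covisibility_thres=0.3))
```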
mapanything/datasets/wai/megadepth.py
ADDED
@@ -0,0 +1,314 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
MegaDepth Dataset using WAI format data.
"""

import os

import cv2
import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class MegaDepthWAI(BaseDataset):
    """
    MegaDepth dataset containing outdoor phototourism and in-the-wild scenes.
    Also includes Tanks & Temples scenes.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = False
        self.is_synthetic = False

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"megadepth_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth", "pred_mask/moge2"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non_ambiguous_mask and ensure it matches image resolution
            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
            non_ambiguous_mask = cv2.resize(
                non_ambiguous_mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

            # Mask out the GT depth using the non_ambiguous_mask
            depthmap = np.where(non_ambiguous_mask, depthmap, 0)

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="MegaDepth",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/megadepth", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = MegaDepthWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 336),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = MegaDepthWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "MegaDepth_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
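MegaDepth (and MPSD below) gate the ground-truth depth with a predicted MoGe-2 ambiguity mask that may be stored at a different resolution than the image. A minimal sketch of that step with synthetic stand-in data follows; it uses `np.uint8` for the mask to stay safely within the dtypes `cv2.resize` supports, since the dtype handling inside `load_frame` is not visible in this diff:

```python
# Minimal sketch of the mask-gating step above, with synthetic stand-in data.
# Nearest-neighbor interpolation keeps the resized mask binary; np.where then
# zeroes GT depth wherever the mask marks the pixel as ambiguous (e.g. sky).
import cv2
import numpy as np

depthmap = (np.random.rand(480, 640) * 10.0).astype(np.float32)
moge_mask = (np.random.rand(240, 320) > 0.3).astype(np.uint8)  # half-res prediction

moge_mask = cv2.resize(
    moge_mask,
    (depthmap.shape[1], depthmap.shape[0]),  # (width, height) order for OpenCV
    interpolation=cv2.INTER_NEAREST,
)
depthmap = np.where(moge_mask, depthmap, 0)  # ambiguous pixels become depth 0
```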
mapanything/datasets/wai/mpsd.py
ADDED
@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
MPSD Dataset using WAI format data.
"""

import os

import cv2
import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class MPSDWAI(BaseDataset):
    """
    MPSD dataset containing outdoor planet-scale metric reconstructions.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = False

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"mpsd_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth", "pred_mask/moge2"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non_ambiguous_mask and ensure it matches image resolution
            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
            non_ambiguous_mask = cv2.resize(
                non_ambiguous_mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

            # Mask out the GT depth using the non_ambiguous_mask
            depthmap = np.where(non_ambiguous_mask, depthmap, 0)

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="MPSD",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/mpsd", type=str)
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = MPSDWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.15,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 392),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = MPSDWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.15,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 392),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "MPSD_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
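Both real-world loaders above rely on the same convention: `np.nan_to_num` funnels NaN/Inf depth values to 0, and depth == 0 then serves as the universal invalid-pixel marker for every downstream mask. A two-line sketch of the effect on synthetic values:

```python
# Minimal sketch of the depth sanitization convention used above.
import numpy as np

depthmap = np.array([[1.5, np.nan], [np.inf, -np.inf]], dtype=np.float32)
depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
valid = depthmap > 0
print(depthmap)  # [[1.5 0. ] [0.  0. ]]
print(valid)     # [[ True False] [False False]]
```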
mapanything/datasets/wai/mvs_synth.py
ADDED
@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
MVS Synth Dataset using WAI format data.
"""

import os

import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class MVSSynthWAI(BaseDataset):
    """
    MVS Synth dataset containing a large diversity of synthetic in-the-wild scenes.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"mvs_synth_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
            non_ambiguous_mask = (depthmap > 0).astype(int)

            # Mask out the outlier depth (horizon depth)
            percentile_depth = np.percentile(depthmap, 95)
            depthmap[depthmap > percentile_depth] = 0

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="MVSSynth",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/mvs_synth", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = MVSSynthWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 294),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = MVSSynthWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 294),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "MVSSynth_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
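MVS Synth (and Parallel Domain 4D below) handle far-field "horizon" depth by zeroing everything above the frame's 95th-percentile depth, after the non-ambiguous mask has already been computed from the raw depth. Clipped pixels therefore keep their non-ambiguous label while losing supervision depth. A small synthetic sketch of that order of operations:

```python
# Minimal sketch of the 95th-percentile horizon clipping above, on a synthetic
# depth profile: 99 values between 1 and 50, plus one 1e4 "horizon" outlier.
import numpy as np

depthmap = np.concatenate([np.linspace(1, 50, 99), [1e4]]).astype(np.float32)
non_ambiguous_mask = (depthmap > 0).astype(int)  # computed BEFORE clipping

percentile_depth = np.percentile(depthmap, 95)
depthmap[depthmap > percentile_depth] = 0

print(percentile_depth)               # ~48 for this synthetic profile
print(int((depthmap == 0).sum()))     # 5 pixels above the cutoff are zeroed
print(int(non_ambiguous_mask.sum()))  # still 100: the mask is unaffected
```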
mapanything/datasets/wai/paralleldomain4d.py
ADDED
@@ -0,0 +1,309 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
Parallel Domain 4D Dataset using WAI format data.
"""

import os

import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class ParallelDomain4DWAI(BaseDataset):
    """
    Parallel Domain 4D dataset containing large diversity of synthetic AV scenes.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"paralleldomain4d_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = image[:, :, :3]  # RGBA to RGB
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
            non_ambiguous_mask = (depthmap > 0).astype(int)

            # Mask out the outlier depth (horizon depth)
            percentile_depth = np.percentile(depthmap, 95)
            depthmap[depthmap > percentile_depth] = 0

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="ParallelDomain4D",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/paralleldomain4d", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = ParallelDomain4DWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 392),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = ParallelDomain4DWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 392),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "ParallelDomain4D_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
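Both AV-style loaders here (ParallelDomain4D above and SAIL-VOS 3D below) share the same outlier-depth heuristic: the per-image 95th-percentile depth acts as a robust cutoff, and anything beyond it (typically horizon or sky returns) is zeroed. Note that non_ambiguous_mask is computed before the clamp, so pixels zeroed by the percentile step keep their non-ambiguous label while contributing zero ground-truth depth. A minimal numpy sketch of that step, on made-up toy values:

import numpy as np

# Toy depth map (values invented for illustration); 500.0 stands in for a
# far horizon return that should not survive as metric supervision.
depthmap = np.array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 500.0],
     [2.5, 3.5, 4.5]],
    dtype=np.float32,
)
non_ambiguous_mask = (depthmap > 0).astype(int)  # taken before the clamp
percentile_depth = np.percentile(depthmap, 95)   # ~302 for this toy array
depthmap[depthmap > percentile_depth] = 0        # the 500.0 outlier is zeroed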
mapanything/datasets/wai/sailvos3d.py
ADDED
@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
SAIL-VOS 3D Dataset using WAI format data.
"""

import os

import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class SAILVOS3DWAI(BaseDataset):
    """
    SAIL-VOS 3D dataset containing large diversity of synthetic in-the-wild cut scenes from GTA.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"sailvos3d_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
            non_ambiguous_mask = (depthmap > 0).astype(int)

            # Mask out the outlier depth (horizon depth)
            percentile_depth = np.percentile(depthmap, 95)
            depthmap[depthmap > percentile_depth] = 0

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="SAILVOS3D",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/sailvos3d", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = SAILVOS3DWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 336),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = SAILVOS3DWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "SAILVOS3D_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
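All of these WAI loaders discover the covisibility map with next(f for f in os.listdir(...) if f.endswith(".npy")), which silently picks an arbitrary file if the version directory ever holds more than one .npy. A hypothetical, slightly more defensive lookup (the helper name and error handling below are assumptions for illustration, not repo code):

import os


def find_covisibility_map(covisibility_map_dir: str) -> str:
    # Hypothetical helper: resolve the single covisibility .npy explicitly,
    # failing loudly instead of silently taking whichever file comes first.
    candidates = sorted(
        f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
    )
    if len(candidates) != 1:
        raise FileNotFoundError(
            f"Expected exactly one covisibility .npy in {covisibility_map_dir}, "
            f"found {len(candidates)}: {candidates}"
        )
    return os.path.join(covisibility_map_dir, candidates[0])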
mapanything/datasets/wai/scannetpp.py
ADDED
@@ -0,0 +1,307 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
ScanNet++V2 Dataset using WAI format data.
"""

import os

import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class ScanNetPPWAI(BaseDataset):
    """
    ScanNet++V2 dataset containing large diversity of indoor scenes.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = False

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"scannetppv2_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "rendered_depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["rendered_depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Resize the data to match the desired resolution
            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image=image,
                resolution=resolution,
                depthmap=depthmap,
                intrinsics=intrinsics,
                additional_quantities=None,
            )

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="ScanNetPP",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/scannetppv2", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = ScanNetPPWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 336),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = ScanNetPPWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    # dataset = ScanNetPPWAI(
    #     num_views=args.num_of_views,
    #     split="test",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 336),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "ScanNetPP_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=10, replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
mapanything/datasets/wai/spring.py
ADDED
@@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
Spring Dataset using WAI format data.
"""

import os

import cv2
import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class SpringWAI(BaseDataset):
    """
    Spring dataset containing high-quality large-scale in-the-wild scenes with unique animated objects.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"spring_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )  # Assumes only npy file in directory is covisibility map
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth", "skymask", "pred_mask/moge2"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Get the sky mask and mask out GT depth
            sky_mask = view_data["skymask"].numpy().astype(int)
            depthmap = np.where(sky_mask, 0, depthmap)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non_ambiguous_mask and ensure it matches image resolution
            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
            non_ambiguous_mask = cv2.resize(
                non_ambiguous_mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

            # Mask out the GT depth using the non_ambiguous_mask
            depthmap = np.where(non_ambiguous_mask, depthmap, 0)

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="Spring",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/spring", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = SpringWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 294),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = SpringWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 294),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "Spring_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=10, replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
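The Spring loader above stacks two validity signals before the final resize: the ground-truth sky mask zeroes sky depth, then the MoGe-2 predicted mask, which may be stored at a different resolution, is nearest-neighbor resized to the image grid and used to gate the remaining depth. A minimal sketch on toy shapes (the sizes and mask values below are assumptions for illustration):

import cv2
import numpy as np

h, w = 4, 6  # toy image resolution
depthmap = np.full((h, w), 5.0, dtype=np.float32)
sky_mask = np.zeros((h, w), dtype=bool)
sky_mask[0, :] = True                        # pretend the top row is sky
depthmap = np.where(sky_mask, 0, depthmap)   # sky has no valid GT depth

pred_mask = np.ones((2, 3), dtype=np.uint8)  # coarse MoGe-2 validity mask
pred_mask[1, 2] = 0                          # one ambiguous region
non_ambiguous_mask = cv2.resize(
    pred_mask, (w, h), interpolation=cv2.INTER_NEAREST
)                                            # nearest keeps the mask binary
depthmap = np.where(non_ambiguous_mask, depthmap, 0)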
mapanything/datasets/wai/tav2_wb.py
ADDED
@@ -0,0 +1,328 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
TartanAirV2-WB Dataset using WAI format data.
"""

import os

import cv2
import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class TartanAirV2WBWAI(BaseDataset):
    """
    TartanAirV2-WB dataset containing vastly-sized in-the-wild synthetic scenes.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.
        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"tav2_wb_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes only npy file in directory is covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth", "pred_mask/moge2"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Mask out the outlier depth caused due to transparent windows in TartanAirV2
            percentile_depth = np.percentile(depthmap, 95)
            depthmap[depthmap > percentile_depth] = 0

            # Get the non_ambiguous_mask and ensure it matches image resolution
            non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
            non_ambiguous_mask = cv2.resize(
                non_ambiguous_mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

            # Mask out the GT depth using the non_ambiguous_mask
            depthmap = np.where(non_ambiguous_mask, depthmap, 0)

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="TartanAirV2WB",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/tav2_wb", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
|
| 185 |
+
default="/fsx/nkeetha/mapanything_dataset_metadata",
|
| 186 |
+
type=str,
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"-nv",
|
| 190 |
+
"--num_of_views",
|
| 191 |
+
default=2,
|
| 192 |
+
type=int,
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument("--viz", action="store_true")
|
| 195 |
+
|
| 196 |
+
return parser
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
import rerun as rr
|
| 201 |
+
from tqdm import tqdm
|
| 202 |
+
|
| 203 |
+
from mapanything.datasets.base.base_dataset import view_name
|
| 204 |
+
from mapanything.utils.image import rgb
|
| 205 |
+
from mapanything.utils.viz import script_add_rerun_args
|
| 206 |
+
|
| 207 |
+
parser = get_parser()
|
| 208 |
+
script_add_rerun_args(
|
| 209 |
+
parser
|
| 210 |
+
) # Options: --headless, --connect, --serve, --addr, --save, --stdout
|
| 211 |
+
args = parser.parse_args()
|
| 212 |
+
|
| 213 |
+
dataset = TartanAirV2WBWAI(
|
| 214 |
+
num_views=args.num_of_views,
|
| 215 |
+
split="train",
|
| 216 |
+
covisibility_thres=0.25,
|
| 217 |
+
ROOT=args.root_dir,
|
| 218 |
+
dataset_metadata_dir=args.dataset_metadata_dir,
|
| 219 |
+
resolution=(518, 518),
|
| 220 |
+
aug_crop=16,
|
| 221 |
+
transform="colorjitter+grayscale+gaublur",
|
| 222 |
+
data_norm_type="dinov2",
|
| 223 |
+
)
|
| 224 |
+
# dataset = TartanAirV2WBWAI(
|
| 225 |
+
# num_views=args.num_of_views,
|
| 226 |
+
# split="val",
|
| 227 |
+
# covisibility_thres=0.25,
|
| 228 |
+
# ROOT=args.root_dir,
|
| 229 |
+
# dataset_metadata_dir=args.dataset_metadata_dir,
|
| 230 |
+
# resolution=(518, 518),
|
| 231 |
+
# seed=777,
|
| 232 |
+
# transform="imgnorm",
|
| 233 |
+
# data_norm_type="dinov2",
|
| 234 |
+
# )
|
| 235 |
+
# dataset = TartanAirV2WBWAI(
|
| 236 |
+
# num_views=args.num_of_views,
|
| 237 |
+
# split="test",
|
| 238 |
+
# covisibility_thres=0.25,
|
| 239 |
+
# ROOT=args.root_dir,
|
| 240 |
+
# dataset_metadata_dir=args.dataset_metadata_dir,
|
| 241 |
+
# resolution=(518, 518),
|
| 242 |
+
# seed=777,
|
| 243 |
+
# transform="imgnorm",
|
| 244 |
+
# data_norm_type="dinov2",
|
| 245 |
+
# )
|
| 246 |
+
print(dataset.get_stats())
|
| 247 |
+
|
| 248 |
+
if args.viz:
|
| 249 |
+
rr.script_setup(args, "TartanAirV2WB_Dataloader")
|
| 250 |
+
rr.set_time("stable_time", sequence=0)
|
| 251 |
+
rr.log("world", rr.ViewCoordinates.RDF, static=True)
|
| 252 |
+
|
| 253 |
+
sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
|
| 254 |
+
|
| 255 |
+
for num, idx in enumerate(tqdm(sampled_indices)):
|
| 256 |
+
views = dataset[idx]
|
| 257 |
+
assert len(views) == args.num_of_views
|
| 258 |
+
sample_name = f"{idx}"
|
| 259 |
+
for view_idx in range(args.num_of_views):
|
| 260 |
+
sample_name += f" {view_name(views[view_idx])}"
|
| 261 |
+
print(sample_name)
|
| 262 |
+
for view_idx in range(args.num_of_views):
|
| 263 |
+
image = rgb(
|
| 264 |
+
views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
|
| 265 |
+
)
|
| 266 |
+
depthmap = views[view_idx]["depthmap"]
|
| 267 |
+
pose = views[view_idx]["camera_pose"]
|
| 268 |
+
intrinsics = views[view_idx]["camera_intrinsics"]
|
| 269 |
+
pts3d = views[view_idx]["pts3d"]
|
| 270 |
+
valid_mask = views[view_idx]["valid_mask"]
|
| 271 |
+
if "non_ambiguous_mask" in views[view_idx]:
|
| 272 |
+
non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
|
| 273 |
+
else:
|
| 274 |
+
non_ambiguous_mask = None
|
| 275 |
+
if "prior_depth_along_ray" in views[view_idx]:
|
| 276 |
+
prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
|
| 277 |
+
else:
|
| 278 |
+
prior_depth_along_ray = None
|
| 279 |
+
if args.viz:
|
| 280 |
+
rr.set_time("stable_time", sequence=num)
|
| 281 |
+
base_name = f"world/view_{view_idx}"
|
| 282 |
+
pts_name = f"world/view_{view_idx}_pointcloud"
|
| 283 |
+
# Log camera info and loaded data
|
| 284 |
+
height, width = image.shape[0], image.shape[1]
|
| 285 |
+
rr.log(
|
| 286 |
+
base_name,
|
| 287 |
+
rr.Transform3D(
|
| 288 |
+
translation=pose[:3, 3],
|
| 289 |
+
mat3x3=pose[:3, :3],
|
| 290 |
+
),
|
| 291 |
+
)
|
| 292 |
+
rr.log(
|
| 293 |
+
f"{base_name}/pinhole",
|
| 294 |
+
rr.Pinhole(
|
| 295 |
+
image_from_camera=intrinsics,
|
| 296 |
+
height=height,
|
| 297 |
+
width=width,
|
| 298 |
+
camera_xyz=rr.ViewCoordinates.RDF,
|
| 299 |
+
),
|
| 300 |
+
)
|
| 301 |
+
rr.log(
|
| 302 |
+
f"{base_name}/pinhole/rgb",
|
| 303 |
+
rr.Image(image),
|
| 304 |
+
)
|
| 305 |
+
rr.log(
|
| 306 |
+
f"{base_name}/pinhole/depth",
|
| 307 |
+
rr.DepthImage(depthmap),
|
| 308 |
+
)
|
| 309 |
+
if prior_depth_along_ray is not None:
|
| 310 |
+
rr.log(
|
| 311 |
+
f"prior_depth_along_ray_{view_idx}",
|
| 312 |
+
rr.DepthImage(prior_depth_along_ray),
|
| 313 |
+
)
|
| 314 |
+
if non_ambiguous_mask is not None:
|
| 315 |
+
rr.log(
|
| 316 |
+
f"{base_name}/pinhole/non_ambiguous_mask",
|
| 317 |
+
rr.SegmentationImage(non_ambiguous_mask.astype(int)),
|
| 318 |
+
)
|
| 319 |
+
# Log points in 3D
|
| 320 |
+
filtered_pts = pts3d[valid_mask]
|
| 321 |
+
filtered_pts_col = image[valid_mask]
|
| 322 |
+
rr.log(
|
| 323 |
+
pts_name,
|
| 324 |
+
rr.Points3D(
|
| 325 |
+
positions=filtered_pts.reshape(-1, 3),
|
| 326 |
+
colors=filtered_pts_col.reshape(-1, 3),
|
| 327 |
+
),
|
| 328 |
+
)
|
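The percentile-based outlier filtering in `_get_views` above can be sanity-checked in isolation. A minimal sketch, assuming only numpy and a toy depth map with one spurious far reading (illustrative values, not dataset data):

import numpy as np

# Toy depth map with one outlier reading, e.g. a ray through a transparent window
depthmap = np.array(
    [[2.0, 2.5, 3.0],
     [2.2, 500.0, 2.8],
     [2.1, 2.6, 2.7]],
    dtype=np.float32,
)

# Same rule as in _get_views: zero out anything beyond the 95th percentile;
# downstream code treats zero depth as invalid. Note this also clips the
# farthest ~5% of genuinely valid pixels, a deliberate trade-off.
percentile_depth = np.percentile(depthmap, 95)  # ~301.2 for this toy map
depthmap[depthmap > percentile_depth] = 0.0     # the 500.0 reading becomes 0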
mapanything/datasets/wai/unrealstereo4k.py
ADDED
@@ -0,0 +1,309 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
UnrealStereo4K Dataset using WAI format data.
"""

import os

import numpy as np

from mapanything.datasets.base.base_dataset import BaseDataset
from mapanything.utils.wai.core import load_data, load_frame


class UnrealStereo4KWAI(BaseDataset):
    """
    UnrealStereo4K dataset containing synthetic in-the-wild scenes.
    """

    def __init__(
        self,
        *args,
        ROOT,
        dataset_metadata_dir,
        split,
        overfit_num_sets=None,
        sample_specific_scene: bool = False,
        specific_scene_name: str = None,
        **kwargs,
    ):
        """
        Initialize the dataset attributes.

        Args:
            ROOT: Root directory of the dataset.
            dataset_metadata_dir: Path to the dataset metadata directory.
            split: Dataset split (train, val, test).
            overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
            sample_specific_scene: Whether to sample a specific scene from the dataset.
            specific_scene_name: Name of the specific scene to sample.
        """
        # Initialize the dataset attributes
        super().__init__(*args, **kwargs)
        self.ROOT = ROOT
        self.dataset_metadata_dir = dataset_metadata_dir
        self.split = split
        self.overfit_num_sets = overfit_num_sets
        self.sample_specific_scene = sample_specific_scene
        self.specific_scene_name = specific_scene_name
        self._load_data()

        # Define the dataset type flags
        self.is_metric_scale = True
        self.is_synthetic = True

    def _load_data(self):
        "Load the precomputed dataset metadata"
        # Load the dataset metadata corresponding to the split
        split_metadata_path = os.path.join(
            self.dataset_metadata_dir,
            self.split,
            f"unrealstereo4k_scene_list_{self.split}.npy",
        )
        split_scene_list = np.load(split_metadata_path, allow_pickle=True)

        # Get the list of all scenes
        if not self.sample_specific_scene:
            self.scenes = list(split_scene_list)
        else:
            self.scenes = [self.specific_scene_name]
        self.num_of_scenes = len(self.scenes)

    def _get_views(self, sampled_idx, num_views_to_sample, resolution):
        # Get the scene name of the sampled index
        scene_index = sampled_idx
        scene_name = self.scenes[scene_index]

        # Get the metadata corresponding to the scene
        scene_root = os.path.join(self.ROOT, scene_name)
        scene_meta = load_data(
            os.path.join(scene_root, "scene_meta.json"), "scene_meta"
        )
        scene_file_names = list(scene_meta["frame_names"].keys())
        num_views_in_scene = len(scene_file_names)

        # Load the scene pairwise covisibility mmap
        covisibility_version_key = "v0"
        covisibility_map_dir = os.path.join(
            scene_root, "covisibility", covisibility_version_key
        )
        # Assumes the only npy file in the directory is the covisibility map
        covisibility_map_name = next(
            f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
        )
        covisibility_map_path = os.path.join(
            scene_root, "covisibility", covisibility_version_key, covisibility_map_name
        )
        pairwise_covisibility = load_data(covisibility_map_path, "mmap")

        # Get the indices of the N views in the scene
        view_indices = self._sample_view_indices(
            num_views_to_sample, num_views_in_scene, pairwise_covisibility
        )

        # Get the views corresponding to the selected view indices
        views = []
        for view_index in view_indices:
            # Load the data corresponding to the view
            view_file_name = scene_file_names[view_index]
            view_data = load_frame(
                scene_root,
                view_file_name,
                modalities=["image", "depth"],
                scene_meta=scene_meta,
            )

            # Convert necessary data to numpy
            image = view_data["image"].permute(1, 2, 0).numpy()
            image = image[:, :, :3]  # RGBA to RGB
            image = (image * 255).astype(np.uint8)
            depthmap = view_data["depth"].numpy().astype(np.float32)
            intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
            c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)

            # Ensure that the depthmap has all valid values
            depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)

            # Get the non-ambiguous mask (zero depth pixels are sky or ambiguous)
            non_ambiguous_mask = (depthmap > 0).astype(int)

            # Mask out the outlier depth (horizon depth)
            percentile_depth = np.percentile(depthmap, 95)
            depthmap[depthmap > percentile_depth] = 0

            # Resize the data to match the desired resolution
            additional_quantities_to_resize = [non_ambiguous_mask]
            image, depthmap, intrinsics, additional_quantities_to_resize = (
                self._crop_resize_if_necessary(
                    image=image,
                    resolution=resolution,
                    depthmap=depthmap,
                    intrinsics=intrinsics,
                    additional_quantities=additional_quantities_to_resize,
                )
            )
            non_ambiguous_mask = additional_quantities_to_resize[0]

            # Append the view dictionary to the list of views
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=c2w_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    non_ambiguous_mask=non_ambiguous_mask,
                    dataset="UnrealStereo4K",
                    label=scene_name,
                    instance=os.path.join("images", str(view_file_name)),
                )
            )

        return views


def get_parser():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rd", "--root_dir", default="/fsx/xrtech/data/unrealstereo4k", type=str
    )
    parser.add_argument(
        "-dmd",
        "--dataset_metadata_dir",
        default="/fsx/nkeetha/mapanything_dataset_metadata",
        type=str,
    )
    parser.add_argument(
        "-nv",
        "--num_of_views",
        default=2,
        type=int,
    )
    parser.add_argument("--viz", action="store_true")

    return parser


if __name__ == "__main__":
    import rerun as rr
    from tqdm import tqdm

    from mapanything.datasets.base.base_dataset import view_name
    from mapanything.utils.image import rgb
    from mapanything.utils.viz import script_add_rerun_args

    parser = get_parser()
    script_add_rerun_args(
        parser
    )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
    args = parser.parse_args()

    dataset = UnrealStereo4KWAI(
        num_views=args.num_of_views,
        split="train",
        covisibility_thres=0.25,
        ROOT=args.root_dir,
        dataset_metadata_dir=args.dataset_metadata_dir,
        resolution=(518, 294),
        aug_crop=16,
        transform="colorjitter+grayscale+gaublur",
        data_norm_type="dinov2",
    )
    # dataset = UnrealStereo4KWAI(
    #     num_views=args.num_of_views,
    #     split="val",
    #     covisibility_thres=0.25,
    #     ROOT=args.root_dir,
    #     dataset_metadata_dir=args.dataset_metadata_dir,
    #     resolution=(518, 294),
    #     seed=777,
    #     transform="imgnorm",
    #     data_norm_type="dinov2",
    # )
    print(dataset.get_stats())

    if args.viz:
        rr.script_setup(args, "UnrealStereo4K_Dataloader")
        rr.set_time("stable_time", sequence=0)
        rr.log("world", rr.ViewCoordinates.RDF, static=True)

    sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)

    for num, idx in enumerate(tqdm(sampled_indices)):
        views = dataset[idx]
        assert len(views) == args.num_of_views
        sample_name = f"{idx}"
        for view_idx in range(args.num_of_views):
            sample_name += f" {view_name(views[view_idx])}"
        print(sample_name)
        for view_idx in range(args.num_of_views):
            image = rgb(
                views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
            )
            depthmap = views[view_idx]["depthmap"]
            pose = views[view_idx]["camera_pose"]
            intrinsics = views[view_idx]["camera_intrinsics"]
            pts3d = views[view_idx]["pts3d"]
            valid_mask = views[view_idx]["valid_mask"]
            if "non_ambiguous_mask" in views[view_idx]:
                non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
            else:
                non_ambiguous_mask = None
            if "prior_depth_along_ray" in views[view_idx]:
                prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
            else:
                prior_depth_along_ray = None
            if args.viz:
                rr.set_time("stable_time", sequence=num)
                base_name = f"world/view_{view_idx}"
                pts_name = f"world/view_{view_idx}_pointcloud"
                # Log camera info and loaded data
                height, width = image.shape[0], image.shape[1]
                rr.log(
                    base_name,
                    rr.Transform3D(
                        translation=pose[:3, 3],
                        mat3x3=pose[:3, :3],
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole",
                    rr.Pinhole(
                        image_from_camera=intrinsics,
                        height=height,
                        width=width,
                        camera_xyz=rr.ViewCoordinates.RDF,
                    ),
                )
                rr.log(
                    f"{base_name}/pinhole/rgb",
                    rr.Image(image),
                )
                rr.log(
                    f"{base_name}/pinhole/depth",
                    rr.DepthImage(depthmap),
                )
                if prior_depth_along_ray is not None:
                    rr.log(
                        f"prior_depth_along_ray_{view_idx}",
                        rr.DepthImage(prior_depth_along_ray),
                    )
                if non_ambiguous_mask is not None:
                    rr.log(
                        f"{base_name}/pinhole/non_ambiguous_mask",
                        rr.SegmentationImage(non_ambiguous_mask.astype(int)),
                    )
                # Log points in 3D
                filtered_pts = pts3d[valid_mask]
                filtered_pts_col = image[valid_mask]
                rr.log(
                    pts_name,
                    rr.Points3D(
                        positions=filtered_pts.reshape(-1, 3),
                        colors=filtered_pts_col.reshape(-1, 3),
                    ),
                )
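The zero-depth convention above (sky and other ambiguous pixels carry zero depth) reduces to a one-line mask. A minimal sketch, assuming only numpy and illustrative values:

import numpy as np

depthmap = np.array([[0.0, 4.2], [3.7, 0.0]], dtype=np.float32)

# Zero depth marks sky / ambiguous pixels, as in _get_views above
non_ambiguous_mask = (depthmap > 0).astype(int)  # [[0, 1], [1, 0]]

# Ground-truth supervision is then restricted to pixels where the mask is 1
supervised_depth = np.where(non_ambiguous_mask, depthmap, 0.0)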
mapanything/models/__init__.py
ADDED
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
Model Factory for MapAnything
"""

import importlib.util
import logging
import warnings

import numpy as np
from omegaconf import DictConfig, OmegaConf

# Core models that are always available
from mapanything.models.mapanything import (
    MapAnything,
    MapAnythingAblations,
    ModularDUSt3R,
)

# Suppress DINOv2 warnings
logging.getLogger("dinov2").setLevel(logging.WARNING)
warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning)
warnings.filterwarnings(
    "ignore", message="xFormers is not available", category=UserWarning
)


def resolve_special_float(value):
    if value == "inf":
        return np.inf
    elif value == "-inf":
        return -np.inf
    else:
        raise ValueError(f"Unknown special float value: {value}")


def init_model(
    model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
):
    """
    Initialize a model using OmegaConf configuration.

    Args:
        model_str (str): Name of the model class to create.
        model_config (DictConfig): OmegaConf model configuration.
        torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub.
    """
    if not OmegaConf.has_resolver("special_float"):
        OmegaConf.register_new_resolver("special_float", resolve_special_float)
    model_dict = OmegaConf.to_container(model_config, resolve=True)
    model = model_factory(
        model_str, torch_hub_force_reload=torch_hub_force_reload, **model_dict
    )

    return model


# Define model configurations with import paths
MODEL_CONFIGS = {
    # Core models
    "mapanything": {
        "class": MapAnything,
    },
    "mapanything_ablations": {
        "class": MapAnythingAblations,
    },
    "modular_dust3r": {
        "class": ModularDUSt3R,
    },
    # External models
    "anycalib": {
        "module": "mapanything.models.external.anycalib",
        "class_name": "AnyCalibWrapper",
    },
    "dust3r": {
        "module": "mapanything.models.external.dust3r",
        "class_name": "DUSt3RBAWrapper",
    },
    "mast3r": {
        "module": "mapanything.models.external.mast3r",
        "class_name": "MASt3RSGAWrapper",
    },
    "moge": {
        "module": "mapanything.models.external.moge",
        "class_name": "MoGeWrapper",
    },
    "must3r": {
        "module": "mapanything.models.external.must3r",
        "class_name": "MUSt3RWrapper",
    },
    "pi3": {
        "module": "mapanything.models.external.pi3",
        "class_name": "Pi3Wrapper",
    },
    "pow3r": {
        "module": "mapanything.models.external.pow3r",
        "class_name": "Pow3RWrapper",
    },
    "pow3r_ba": {
        "module": "mapanything.models.external.pow3r",
        "class_name": "Pow3RBAWrapper",
    },
    "vggt": {
        "module": "mapanything.models.external.vggt",
        "class_name": "VGGTWrapper",
    },
    # Add other model classes here
}


def check_module_exists(module_path):
    """
    Check if a module can be imported without actually importing it.

    Args:
        module_path (str): The path to the module to check.

    Returns:
        bool: True if the module can be imported, False otherwise.
    """
    return importlib.util.find_spec(module_path) is not None


def model_factory(model_str: str, **kwargs):
    """
    Model factory for MapAnything.

    Args:
        model_str (str): Name of the model to create.
        **kwargs: Additional keyword arguments to pass to the model constructor.

    Returns:
        nn.Module: An instance of the specified model.
    """
    if model_str not in MODEL_CONFIGS:
        raise ValueError(
            f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
        )

    model_config = MODEL_CONFIGS[model_str]

    # Handle core models directly
    if "class" in model_config:
        model_class = model_config["class"]
    # Handle external models with dynamic imports
    elif "module" in model_config:
        module_path = model_config["module"]
        class_name = model_config["class_name"]

        # Check if the module can be imported
        if not check_module_exists(module_path):
            raise ImportError(
                f"Model '{model_str}' requires module '{module_path}' which is not installed. "
                f"Please install the corresponding submodule or package."
            )

        # Dynamically import the module and get the class
        try:
            module = importlib.import_module(module_path)
            model_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(
                f"Failed to import {class_name} from {module_path}: {str(e)}"
            )
    else:
        raise ValueError(f"Invalid model configuration for {model_str}")

    print(f"Initializing {model_class} with kwargs: {kwargs}")
    if model_str != "org_dust3r":
        return model_class(**kwargs)
    else:
        eval_str = kwargs.get("model_eval_str", None)
        return eval(eval_str)


def get_available_models() -> list:
    """
    Get a list of available models in MapAnything.

    Returns:
        list: A list of available model names.
    """
    return list(MODEL_CONFIGS.keys())


__all__ = ["model_factory", "get_available_models"]
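A hedged usage sketch of the factory above; the config contents below are illustrative placeholders, not the actual MapAnything model schema:

from omegaconf import OmegaConf

from mapanything.models import get_available_models, init_model

print(get_available_models())
# ['mapanything', 'mapanything_ablations', 'modular_dust3r', 'anycalib', ...]

# Hypothetical call: 'model_config' must hold the constructor kwargs of the
# chosen model class; the empty config here is a placeholder, not a schema.
model_config = OmegaConf.create({})
model = init_model("mapanything", model_config)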
mapanything/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (5.63 kB).
mapanything/models/external/README.md
ADDED
@@ -0,0 +1,5 @@
# External Model Code for Benchmarking & Re-Training

This directory contains external model code that we use to train and benchmark external models fairly. These libraries are not part of the core MapAnything codebase and are included only for benchmarking purposes. The code in this directory is licensed under the same license as the source code from which it was derived, unless otherwise specified.

The open-source Apache 2.0 License of MapAnything does not apply to these libraries.
mapanything/models/external/__init__.py
ADDED
File without changes
mapanything/models/external/anycalib/__init__.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

"""
Inference wrapper for AnyCalib
"""

import torch
from anycalib import AnyCalib

from mapanything.utils.geometry import get_rays_in_camera_frame


class AnyCalibWrapper(torch.nn.Module):
    def __init__(
        self,
        name,
        model_id="anycalib_pinhole",
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.model_id = model_id

        # Initialize the model
        self.model = AnyCalib(model_id=self.model_id)

    def forward(self, views):
        """
        Forward pass wrapper for AnyCalib.

        Assumptions:
        - The number of input views is 1.
        - The output camera model is pinhole (fx, fy, cx, cy).
          This can be relaxed by not hardcoding the cam_id.

        Args:
            views (List[dict]): List of dictionaries containing the input views' images and instance information.
                Length of the list should be 1.
                Each dictionary should contain the following keys:
                    "img" (tensor): Image tensor of shape (B, C, H, W).
                    "data_norm_type" (list): ["identity"]

        Returns:
            List[dict]: A list containing the final outputs for the single view. Length of the list will be 1.
        """
        # Check that the number of input views is 1
        assert len(views) == 1, "AnyCalib only supports 1 input view."

        # Get input shape of the images and batch size per view
        _, _, height, width = views[0]["img"].shape

        # Check the data norm type
        # AnyCalib expects a normalized image but without the DINOv2 mean and std applied ("identity")
        data_norm_type = views[0]["data_norm_type"][0]
        assert data_norm_type == "identity", (
            "AnyCalib expects a normalized image but without the DINOv2 mean and std applied"
        )

        # Run AnyCalib inference
        # Corresponding batched output dictionary:
        # {
        #     "intrinsics": List[(D_i,) tensors] for each camera model "i" at the original input resolution,
        #     "fov_field": (B, N, 2) tensor with the regressed FoV field by the network. N≈320^2 (resolution close to the one seen during training),
        #     "tangent_coords": alias for "fov_field",
        #     "rays": (B, N, 3) tensor with the corresponding (via the exponential map) ray directions in the camera frame (x right, y down, z forward),
        #     "pred_size": (H, W) tuple with the image size used by the network. It can be used e.g. for resizing the FoV/ray fields to the original image size.
        # }
        # For the "pinhole" camera model, the intrinsics are (fx, fy, cx, cy).
        model_outputs = self.model.predict(views[0]["img"], cam_id="pinhole")

        # Convert the list of intrinsics to 3x3 matrices
        intrinsics = []
        for intrinsics_per_sample in model_outputs["intrinsics"]:
            pred_fx, pred_fy, pred_cx, pred_cy = intrinsics_per_sample
            intrinsics_per_sample = torch.tensor(
                [
                    [pred_fx, 0, pred_cx],
                    [0, pred_fy, pred_cy],
                    [0, 0, 1],
                ],
                device=views[0]["img"].device,
            )
            intrinsics.append(intrinsics_per_sample)

        # Convert the list of intrinsics to a tensor of size (batch_size_per_view, 3, 3)
        intrinsics = torch.stack(intrinsics)

        # Get the ray directions
        with torch.autocast("cuda", enabled=False):
            _, ray_directions = get_rays_in_camera_frame(
                intrinsics, height, width, normalize_to_unit_sphere=True
            )

        # Return the output in MapAnything format
        res = [{"ray_directions": ray_directions, "intrinsics": intrinsics}]

        return res
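A minimal usage sketch for the wrapper above, assuming the `anycalib` package is installed; the random tensor stands in for a real "identity"-normalized image batch:

import torch

wrapper = AnyCalibWrapper(name="anycalib", model_id="anycalib_pinhole")
views = [
    {
        "img": torch.rand(1, 3, 518, 518),  # (B, C, H, W), values in [0, 1]
        "data_norm_type": ["identity"],
    }
]
with torch.no_grad():
    (out,) = wrapper(views)
print(out["intrinsics"].shape)      # (1, 3, 3) pinhole intrinsics
print(out["ray_directions"].shape)  # per-pixel unit ray directions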
mapanything/models/external/dinov2/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

__version__ = "0.0.1"
mapanything/models/external/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
mapanything/models/external/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from mapanything.models.external.dinov2.hub.utils import (
    _DINOV2_BASE_URL,
    _make_dinov2_model_name,
)


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        model_full_name = _make_dinov2_model_name(
            arch_name, patch_size, num_register_tokens
        )
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitb14(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitl14(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitg14(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(
    *, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs
):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
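A quick construction sketch for the backbones above; `pretrained=False` skips the weight download, 518 is a multiple of the 14-pixel patch size, and the `forward_features` call assumes the standard DINOv2 ViT interface of the vendored vision transformer:

import torch

model = dinov2_vits14(pretrained=False)
x = torch.randn(1, 3, 518, 518)  # 518 = 37 * 14
with torch.no_grad():
    feats = model.forward_features(x)
print(feats["x_norm_patchtokens"].shape)  # (1, 37 * 37, embed_dim) patch tokens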
mapanything/models/external/dinov2/hub/utils.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(
    arch_name: str, patch_size: int, num_register_tokens: int = 0
) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"


class CenterPadding(nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(
            itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])
        )
        output = F.pad(x, pads)
        return output
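`CenterPadding` simply pads the trailing spatial dimensions up to the next multiple, splitting the padding evenly between both sides; a small sketch with illustrative shapes:

import torch

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 500, 500)
y = pad(x)
print(y.shape)  # torch.Size([1, 3, 504, 504]); 4 pixels of padding split 2/2 per side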
mapanything/models/external/dinov2/layers/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mapanything.models.external.dinov2.layers.dino_head import DINOHead  # noqa
from mapanything.models.external.dinov2.layers.mlp import Mlp  # noqa
from mapanything.models.external.dinov2.layers.patch_embed import PatchEmbed  # noqa
from mapanything.models.external.dinov2.layers.swiglu_ffn import (
    SwiGLUFFN,  # noqa
    SwiGLUFFNFused,  # noqa
)
from mapanything.models.external.dinov2.layers.block import NestedTensorBlock  # noqa
from mapanything.models.external.dinov2.layers.attention import MemEffAttention  # noqa
mapanything/models/external/dinov2/layers/attention.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os

from torch import nn, Tensor

logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Attention)")
    else:
        # warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
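A shape sanity check for the attention classes above (random inputs, illustrative only); `MemEffAttention` falls back to the plain `Attention` path when xFormers is absent:

import torch

attn = Attention(dim=64, num_heads=8)
tokens = torch.randn(2, 10, 64)  # (batch, tokens, dim)
print(attn(tokens).shape)  # torch.Size([2, 10, 64]), shape-preserving

mem_attn = MemEffAttention(dim=64, num_heads=8)
print(mem_attn(tokens).shape)  # torch.Size([2, 10, 64])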
mapanything/models/external/dinov2/layers/block.py
ADDED
@@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch import nn, Tensor

from mapanything.models.external.dinov2.layers.attention import (
    Attention,
    MemEffAttention,
)
from mapanything.models.external.dinov2.layers.drop_path import DropPath
from mapanything.models.external.dinov2.layers.layer_scale import LayerScale
from mapanything.models.external.dinov2.layers.mlp import Mlp

logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, index_select_cat, scaled_index_add

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(
        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
    )
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(
            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
        )
    else:
        x_plus_residual = scaled_index_add(
            x,
            brange,
            residual.to(dtype=x.dtype),
            scaling=scaling_vector,
            alpha=residual_scale_factor,
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 179 |
+
"""
|
| 180 |
+
batch_sizes = (
|
| 181 |
+
[b.shape[0] for b in branges]
|
| 182 |
+
if branges is not None
|
| 183 |
+
else [x.shape[0] for x in x_list]
|
| 184 |
+
)
|
| 185 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 186 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 187 |
+
seqlens = []
|
| 188 |
+
for b, x in zip(batch_sizes, x_list):
|
| 189 |
+
for _ in range(b):
|
| 190 |
+
seqlens.append(x.shape[1])
|
| 191 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 192 |
+
attn_bias._batch_sizes = batch_sizes
|
| 193 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 194 |
+
|
| 195 |
+
if branges is not None:
|
| 196 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
|
| 197 |
+
1, -1, x_list[0].shape[-1]
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 201 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 202 |
+
|
| 203 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def drop_add_residual_stochastic_depth_list(
|
| 207 |
+
x_list: List[Tensor],
|
| 208 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 209 |
+
sample_drop_ratio: float = 0.0,
|
| 210 |
+
scaling_vector=None,
|
| 211 |
+
) -> Tensor:
|
| 212 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 213 |
+
branges_scales = [
|
| 214 |
+
get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
|
| 215 |
+
]
|
| 216 |
+
branges = [s[0] for s in branges_scales]
|
| 217 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 218 |
+
|
| 219 |
+
# 2) get attention bias and index+concat the tensors
|
| 220 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 221 |
+
|
| 222 |
+
# 3) apply residual_func to get residual, and split the result
|
| 223 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 224 |
+
|
| 225 |
+
outputs = []
|
| 226 |
+
for x, brange, residual, residual_scale_factor in zip(
|
| 227 |
+
x_list, branges, residual_list, residual_scale_factors
|
| 228 |
+
):
|
| 229 |
+
outputs.append(
|
| 230 |
+
add_residual(
|
| 231 |
+
x, brange, residual, residual_scale_factor, scaling_vector
|
| 232 |
+
).view_as(x)
|
| 233 |
+
)
|
| 234 |
+
return outputs
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class NestedTensorBlock(Block):
|
| 238 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 239 |
+
"""
|
| 240 |
+
x_list contains a list of tensors to nest together and run
|
| 241 |
+
"""
|
| 242 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 243 |
+
|
| 244 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 245 |
+
|
| 246 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 247 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 248 |
+
|
| 249 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 250 |
+
return self.mlp(self.norm2(x))
|
| 251 |
+
|
| 252 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 253 |
+
x_list,
|
| 254 |
+
residual_func=attn_residual_func,
|
| 255 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 256 |
+
scaling_vector=self.ls1.gamma
|
| 257 |
+
if isinstance(self.ls1, LayerScale)
|
| 258 |
+
else None,
|
| 259 |
+
)
|
| 260 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 261 |
+
x_list,
|
| 262 |
+
residual_func=ffn_residual_func,
|
| 263 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 264 |
+
scaling_vector=self.ls2.gamma
|
| 265 |
+
if isinstance(self.ls1, LayerScale)
|
| 266 |
+
else None,
|
| 267 |
+
)
|
| 268 |
+
return x_list
|
| 269 |
+
else:
|
| 270 |
+
|
| 271 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 272 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 273 |
+
|
| 274 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 275 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 276 |
+
|
| 277 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 278 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 279 |
+
x = x + ffn_residual_func(x)
|
| 280 |
+
return attn_bias.split(x)
|
| 281 |
+
|
| 282 |
+
def forward(self, x_or_x_list):
|
| 283 |
+
if isinstance(x_or_x_list, Tensor):
|
| 284 |
+
return super().forward(x_or_x_list)
|
| 285 |
+
elif isinstance(x_or_x_list, list):
|
| 286 |
+
if not XFORMERS_AVAILABLE:
|
| 287 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 288 |
+
return self.forward_nested(x_or_x_list)
|
| 289 |
+
else:
|
| 290 |
+
raise AssertionError
|
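Usage note: a minimal, self-contained sketch of the batch-subset stochastic-depth trick used by drop_add_residual_stochastic_depth above. Only a random subset of the batch runs the residual branch, and the added residual is rescaled by b / subset_size so the expected update matches the full-batch one. The toy residual function is illustrative only.

import torch

def toy_stochastic_depth_step(x, residual_func, drop_ratio):
    # Run the residual branch on a random batch subset, then scale it back up.
    b = x.shape[0]
    subset = max(int(b * (1 - drop_ratio)), 1)
    brange = torch.randperm(b)[:subset]      # which samples get a residual
    residual = residual_func(x[brange])      # branch runs on the subset only
    scale = b / subset                       # keeps the expected update unchanged
    return torch.index_add(
        x.flatten(1), 0, brange, residual.flatten(1), alpha=scale
    ).view_as(x)

x = torch.randn(8, 4, 16)                    # (batch, tokens, dim)
out = toy_stochastic_depth_step(x, lambda t: 0.1 * t, drop_ratio=0.5)
print(out.shape)                             # torch.Size([8, 4, 16])
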
mapanything/models/external/dinov2/layers/dino_head.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(
            nlayers,
            in_dim,
            bottleneck_dim,
            hidden_dim=hidden_dim,
            use_bn=use_bn,
            bias=mlp_bias,
        )
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(
    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)

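Usage note: a minimal sketch of the DINOHead data flow (assuming this package is importable): features pass through an MLP to a bottleneck, are L2-normalized, then projected by a weight-normalized linear layer to one score per prototype. The dimensions below are illustrative, not a trained configuration.

import torch
from mapanything.models.external.dinov2.layers.dino_head import DINOHead

head = DINOHead(in_dim=384, out_dim=1024, bottleneck_dim=256)
feats = torch.randn(4, 384)
logits = head(feats)
print(logits.shape)  # torch.Size([4, 1024]): one score per prototype
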
mapanything/models/external/dinov2/layers/drop_path.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

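Usage note: a self-contained sketch of the per-sample stochastic-depth math above. Each sample gets one Bernoulli draw, broadcast across its remaining dimensions; survivors are rescaled by 1/keep_prob so the expected activation is preserved.

import torch

torch.manual_seed(0)
x = torch.ones(6, 3, 8)                   # (batch, tokens, dim)
keep_prob = 0.5
# One Bernoulli draw per sample, broadcast over all remaining dims:
mask = x.new_empty((6, 1, 1)).bernoulli_(keep_prob).div_(keep_prob)
print((x * mask)[:, 0, 0])                # each sample is either 0.0 or 2.0
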
mapanything/models/external/dinov2/layers/layer_scale.py
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import nn, Tensor


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

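Usage note: a self-contained sketch of what LayerScale does. It is a learnable per-channel gain, initialized near zero, so each residual branch starts out as an almost-identity perturbation and the network can learn how much of the branch to use.

import torch
from torch import nn

dim = 8
gamma = nn.Parameter(1e-5 * torch.ones(dim))  # same init as LayerScale
branch_out = torch.randn(2, 4, dim)
scaled = branch_out * gamma                   # per-channel, broadcast over tokens
print(scaled.abs().max())                     # tiny: the branch barely perturbs x at init
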
mapanything/models/external/dinov2/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import nn, Tensor


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

mapanything/models/external/dinov2/layers/patch_embed.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

import torch.nn as nn
from torch import Tensor


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, (
            f"Input image height {H} is not a multiple of patch height {patch_H}"
        )
        assert W % patch_W == 0, (
            f"Input image width {W} is not a multiple of patch width {patch_W}"
        )

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = (
            Ho
            * Wo
            * self.embed_dim
            * self.in_chans
            * (self.patch_size[0] * self.patch_size[1])
        )
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

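Usage note: a self-contained sketch of the patchify shapes. A convolution whose kernel size equals its stride maps (B, C, H, W) to one embedding per non-overlapping patch, which is then flattened into the (B, N, D) token layout used by the transformer.

import torch
from torch import nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # same projection shape as PatchEmbed
img = torch.randn(1, 3, 224, 224)
x = proj(img)                       # (1, 768, 14, 14): one embedding per 16x16 patch
x = x.flatten(2).transpose(1, 2)    # (1, 196, 768): the (B, N, D) token layout
print(x.shape)
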
mapanything/models/external/dinov2/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional

import torch.nn.functional as F
from torch import nn, Tensor


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )

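Usage note: a self-contained sketch of the SwiGLU gate above. One fused linear layer produces both halves of the hidden state; the output is silu(x1) elementwise-times x2. SwiGLUFFNFused additionally shrinks the hidden width to 2/3 of the nominal MLP width, rounded up to a multiple of 8, which roughly matches the parameter count of a plain MLP.

import torch
import torch.nn.functional as F

w12 = torch.nn.Linear(16, 2 * 32)          # fused projection for both halves
x = torch.randn(5, 16)
x1, x2 = w12(x).chunk(2, dim=-1)
hidden = F.silu(x1) * x2                   # the SwiGLU gate
print(hidden.shape)                        # torch.Size([5, 32])

# Hidden-width rule used by SwiGLUFFNFused, e.g. for mlp_ratio=4, dim=384:
print((int(4 * 384 * 2 / 3) + 7) // 8 * 8)  # 1024
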
mapanything/models/external/dinov2/models/__init__.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

import mapanything.models.external.dinov2.models.vision_transformer as vits

logger = logging.getLogger("dinov2")


def build_model(args, only_teacher=False, img_size=224):
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    return student, teacher, embed_dim


def build_model_from_cfg(cfg, only_teacher=False):
    return build_model(
        cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size
    )

mapanything/models/external/dinov2/models/vision_transformer.py
ADDED
@@ -0,0 +1,448 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import math
from functools import partial
from typing import Callable, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.utils.checkpoint import checkpoint

from mapanything.models.external.dinov2.layers import (
    MemEffAttention,
    Mlp,
    NestedTensorBlock as Block,
    PatchEmbed,
    SwiGLUFFNFused,
)
from mapanything.models.external.pi3.layers.attention import FlashAttention

# logger = logging.getLogger("dinov2")


def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks (int): split block sequence into block_chunks units for FSDP wrap
            num_register_tokens (int): number of extra cls tokens (so-called "registers")
            interpolate_antialias (bool): flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset (float): work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            # logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            # logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            # logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                attn_class=FlashAttention,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [
            self.prepare_tokens_with_masks(x, masks)
            for x, masks in zip(x_list, masks_list)
        ]
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), (
            f"only {len(output)} / {len(blocks_to_take)} blocks found"
        )
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), (
            f"only {len(output)} / {len(blocks_to_take)} blocks found"
        )
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model

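Usage note: a minimal sketch of running this backbone, assuming the package (including the FlashAttention dependency from mapanything.models.external.pi3) is importable. Input sides must be multiples of patch_size; the shapes below are illustrative.

import torch
from mapanything.models.external.dinov2.models.vision_transformer import vit_small

model = vit_small(patch_size=16, img_size=224).eval()
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model.forward_features(img)
    layers = model.get_intermediate_layers(img, n=2, reshape=True)
print(feats["x_norm_patchtokens"].shape)  # (1, 196, 384)
print(layers[0].shape)                    # (1, 384, 14, 14)
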
mapanything/models/external/dinov2/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

mapanything/models/external/dinov2/utils/cluster.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional


class ClusterType(Enum):
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"


def _guess_cluster_type() -> ClusterType:
    uname = os.uname()
    if uname.sysname == "Linux":
        if uname.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        elif uname.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC

    return ClusterType.FAIR


def get_cluster_type(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[ClusterType]:
    if cluster_type is None:
        return _guess_cluster_type()

    return cluster_type


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    CHECKPOINT_DIRNAMES = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]


def get_user_checkpoint_path(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[Path]:
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return checkpoint_path / username


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    SLURM_PARTITIONS = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return SLURM_PARTITIONS[cluster_type]


def get_slurm_executor_parameters(
    nodes: int,
    num_gpus_per_node: int,
    cluster_type: Optional[ClusterType] = None,
    **kwargs,
) -> Dict[str, Any]:
    # create default parameters
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params

mapanything/models/external/dinov2/utils/config.py
ADDED
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging
import math
import os

import dinov2.distributed as distributed
from dinov2.configs import dinov2_default_config
from dinov2.logging import setup_logging
from dinov2.utils import utils
from omegaconf import OmegaConf

logger = logging.getLogger("dinov2")


def apply_scaling_rules_to_cfg(cfg):  # to fix
    if cfg.optim.scaling_rule == "sqrt_wrt_1024":
        base_lr = cfg.optim.base_lr
        cfg.optim.lr = base_lr
        cfg.optim.lr *= math.sqrt(
            cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0
        )
        logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    else:
        raise NotImplementedError
    return cfg


def write_config(cfg, output_dir, name="config.yaml"):
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path


def get_cfg_from_args(args):
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    default_cfg = OmegaConf.create(dinov2_default_config)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
    return cfg


def default_setup(args):
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n  {}\n".format(utils.get_sha()))
    logger.info(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
    )


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg

mapanything/models/external/dinov2/utils/dtype.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


from typing import Dict, Union

import numpy as np
import torch

TypeSpec = Union[str, np.dtype, torch.dtype]


_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    assert isinstance(dtype, np.dtype), (
        f"Expected an instance of numpy dtype, got {type(dtype)}"
    )
    return _NUMPY_TO_TORCH_DTYPE[dtype]
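
Usage note: a short sketch of the conversion helper above (assuming the module is importable). It accepts torch dtypes, numpy dtypes, or dtype strings.

import numpy as np
import torch
from mapanything.models.external.dinov2.utils.dtype import as_torch_dtype

print(as_torch_dtype(torch.float16))      # torch.float16 (passed through)
print(as_torch_dtype(np.dtype("int32")))  # torch.int32
print(as_torch_dtype("float64"))          # torch.float64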