|
from tqdm import tqdm |
|
from os import path as osp |
|
from torch.utils.data import Dataset, DataLoader, ConcatDataset |
|
|
|
from src.datasets.megadepth import MegaDepthDataset |
|
from src.datasets.scannet import ScanNetDataset |
|
from src.datasets.aachen import AachenDataset |
|
from src.datasets.inloc import InLocDataset |
|
|
|
|
|
class TestDataLoader(DataLoader):
    """DataLoader over the evaluation split selected by ``config.DATASET``.

    The dataset is chosen by ``config.DATASET.TEST_DATA_SOURCE``
    (compared case-insensitively):

    * ``"megadepth"`` / ``"scannet"`` -- one dataset per scene listed in
      ``TEST_LIST_PATH`` (first whitespace-separated token of each non-blank
      line), concatenated with ``ConcatDataset``. Every listed scene is
      loaded; no per-process sharding is performed.
    * ``"aachen_v1.1"`` -- ``AachenDataset``.
    * ``"inloc"`` -- ``InLocDataset``.

    Samples are served one per batch, in order (``batch_size=1``,
    ``shuffle=False``), with 4 worker processes and pinned memory.
    """

    def __init__(self, config):
        """Build the test dataset from ``config`` and initialise the loader.

        Args:
            config: attribute-style config exposing a ``DATASET`` section
                (presumably a yacs ``CfgNode``; only attribute access is
                used here).

        Raises:
            ValueError: if ``TEST_DATA_SOURCE`` names an unsupported dataset.
        """
        dataset_cfg = config.DATASET

        self.test_data_source = dataset_cfg.TEST_DATA_SOURCE
        dataset_name = str(self.test_data_source).lower()

        self.test_data_root = dataset_cfg.TEST_DATA_ROOT
        self.test_pose_root = dataset_cfg.TEST_POSE_ROOT
        self.test_npz_root = dataset_cfg.TEST_NPZ_ROOT
        self.test_list_path = dataset_cfg.TEST_LIST_PATH
        self.test_intrinsic_path = dataset_cfg.TEST_INTRINSIC_PATH
        self.min_overlap_score_test = dataset_cfg.MIN_OVERLAP_SCORE_TEST

        # Per-dataset settings consumed later by _build_concat_dataset.
        if dataset_name == "megadepth":
            self.mgdpt_img_resize = dataset_cfg.MGDPT_IMG_RESIZE
            self.mgdpt_img_pad = True
            self.mgdpt_depth_pad = True
            # NOTE(review): presumably the divisibility factor images are
            # padded/resized to -- confirm against MegaDepthDataset.
            self.mgdpt_df = 8
            # 1/8, matching mgdpt_df above; presumably the coarse feature
            # stride of the matcher -- confirm.
            self.coarse_scale = 0.125
        elif dataset_name == "scannet":
            self.img_resize = dataset_cfg.TEST_IMGSIZE

        if dataset_name in ("megadepth", "scannet"):
            test_dataset = self._setup_dataset(
                self.test_data_root,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode="test",
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root,
            )
        elif dataset_name == "aachen_v1.1":
            test_dataset = AachenDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=dataset_cfg.TEST_IMGSIZE,
            )
        elif dataset_name == "inloc":
            test_dataset = InLocDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=dataset_cfg.TEST_IMGSIZE,
            )
        else:
            raise ValueError(f"unknown dataset: {self.test_data_source!r}")

        self.test_loader_params = {
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
            "pin_memory": True,
        }

        super().__init__(test_dataset, **self.test_loader_params)

    def _setup_dataset(
        self,
        data_root,
        split_npz_root,
        scene_list_path,
        intri_path,
        mode="train",
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Read the scene list and build the concatenated per-scene dataset.

        Args:
            data_root: root directory forwarded to every scene dataset.
            split_npz_root: directory containing the per-scene ``.npz`` files.
            scene_list_path: text file whose non-blank lines each start with
                a scene name (extra whitespace-separated tokens are ignored).
            intri_path: intrinsics path (forwarded to ``ScanNetDataset`` only).
            mode: split name forwarded to the scene datasets.
            min_overlap_score: forwarded to the scene datasets.
            pose_dir: forwarded to ``ScanNetDataset`` only.

        Returns:
            A ``ConcatDataset`` over all scenes in the list.
        """
        with open(scene_list_path, "r") as f:
            # Blank / whitespace-only lines carry no scene name; skip them
            # instead of failing with IndexError.
            npz_names = [tokens[0] for tokens in map(str.split, f) if tokens]

        return self._build_concat_dataset(
            data_root,
            npz_names,
            split_npz_root,
            intri_path,
            mode=mode,
            min_overlap_score=min_overlap_score,
            pose_dir=pose_dir,
        )

    def _build_concat_dataset(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Instantiate one scene dataset per ``.npz`` file and concatenate them.

        The dataset class is selected from ``self.test_data_source``, compared
        case-insensitively just like in ``__init__``. For MegaDepth the
        ``.npz`` suffix is appended to each scene name; ScanNet names are
        joined to ``npz_dir`` unchanged.

        Raises:
            NotImplementedError: if the data source is neither ScanNet nor
                MegaDepth.
        """
        source = str(self.test_data_source).lower()

        # Resolve the per-scene factory once, so an unsupported source fails
        # before any scene is processed.
        if source == "scannet":

            def build(npz_path):
                return ScanNetDataset(
                    data_root,
                    npz_path,
                    intrinsic_path,
                    mode=mode,
                    img_resize=self.img_resize,
                    min_overlap_score=min_overlap_score,
                    pose_dir=pose_dir,
                )

        elif source == "megadepth":
            npz_names = [f"{n}.npz" for n in npz_names]

            def build(npz_path):
                return MegaDepthDataset(
                    data_root,
                    npz_path,
                    mode=mode,
                    min_overlap_score=min_overlap_score,
                    img_resize=self.mgdpt_img_resize,
                    df=self.mgdpt_df,
                    img_padding=self.mgdpt_img_pad,
                    depth_padding=self.mgdpt_depth_pad,
                    coarse_scale=self.coarse_scale,
                )

        else:
            raise NotImplementedError(
                f"unsupported data source: {self.test_data_source!r}"
            )

        datasets = [build(osp.join(npz_dir, name)) for name in tqdm(npz_names)]
        return ConcatDataset(datasets)
|
|