Spaces:
Running
Running
from tqdm import tqdm | |
from os import path as osp | |
from torch.utils.data import Dataset, DataLoader, ConcatDataset | |
from src.datasets.megadepth import MegaDepthDataset | |
from src.datasets.scannet import ScanNetDataset | |
from src.datasets.aachen import AachenDataset | |
from src.datasets.inloc import InLocDataset | |
class TestDataLoader(DataLoader): | |
""" | |
For distributed training, each training process is assgined | |
only a part of the training scenes to reduce memory overhead. | |
""" | |
def __init__(self, config): | |
# 1. data config | |
self.test_data_source = config.DATASET.TEST_DATA_SOURCE | |
dataset_name = str(self.test_data_source).lower() | |
# testing | |
self.test_data_root = config.DATASET.TEST_DATA_ROOT | |
self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional) | |
self.test_npz_root = config.DATASET.TEST_NPZ_ROOT | |
self.test_list_path = config.DATASET.TEST_LIST_PATH | |
self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH | |
# 2. dataset config | |
# general options | |
self.min_overlap_score_test = ( | |
config.DATASET.MIN_OVERLAP_SCORE_TEST | |
) # 0.4, omit data with overlap_score < min_overlap_score | |
# MegaDepth options | |
if dataset_name == "megadepth": | |
self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 800 | |
self.mgdpt_img_pad = True | |
self.mgdpt_depth_pad = True | |
self.mgdpt_df = 8 | |
self.coarse_scale = 0.125 | |
if dataset_name == "scannet": | |
self.img_resize = config.DATASET.TEST_IMGSIZE | |
if (dataset_name == "megadepth") or (dataset_name == "scannet"): | |
test_dataset = self._setup_dataset( | |
self.test_data_root, | |
self.test_npz_root, | |
self.test_list_path, | |
self.test_intrinsic_path, | |
mode="test", | |
min_overlap_score=self.min_overlap_score_test, | |
pose_dir=self.test_pose_root, | |
) | |
elif dataset_name == "aachen_v1.1": | |
test_dataset = AachenDataset( | |
self.test_data_root, | |
self.test_list_path, | |
img_resize=config.DATASET.TEST_IMGSIZE, | |
) | |
elif dataset_name == "inloc": | |
test_dataset = InLocDataset( | |
self.test_data_root, | |
self.test_list_path, | |
img_resize=config.DATASET.TEST_IMGSIZE, | |
) | |
else: | |
raise "unknown dataset" | |
self.test_loader_params = { | |
"batch_size": 1, | |
"shuffle": False, | |
"num_workers": 4, | |
"pin_memory": True, | |
} | |
# sampler = Seq(self.test_dataset, shuffle=False) | |
super(TestDataLoader, self).__init__(test_dataset, **self.test_loader_params) | |
def _setup_dataset( | |
self, | |
data_root, | |
split_npz_root, | |
scene_list_path, | |
intri_path, | |
mode="train", | |
min_overlap_score=0.0, | |
pose_dir=None, | |
): | |
"""Setup train / val / test set""" | |
with open(scene_list_path, "r") as f: | |
npz_names = [name.split()[0] for name in f.readlines()] | |
local_npz_names = npz_names | |
return self._build_concat_dataset( | |
data_root, | |
local_npz_names, | |
split_npz_root, | |
intri_path, | |
mode=mode, | |
min_overlap_score=min_overlap_score, | |
pose_dir=pose_dir, | |
) | |
def _build_concat_dataset( | |
self, | |
data_root, | |
npz_names, | |
npz_dir, | |
intrinsic_path, | |
mode, | |
min_overlap_score=0.0, | |
pose_dir=None, | |
): | |
datasets = [] | |
# augment_fn = self.augment_fn if mode == 'train' else None | |
data_source = self.test_data_source | |
if str(data_source).lower() == "megadepth": | |
npz_names = [f"{n}.npz" for n in npz_names] | |
for npz_name in tqdm(npz_names): | |
# `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. | |
npz_path = osp.join(npz_dir, npz_name) | |
if data_source == "ScanNet": | |
datasets.append( | |
ScanNetDataset( | |
data_root, | |
npz_path, | |
intrinsic_path, | |
mode=mode, | |
img_resize=self.img_resize, | |
min_overlap_score=min_overlap_score, | |
pose_dir=pose_dir, | |
) | |
) | |
elif data_source == "MegaDepth": | |
datasets.append( | |
MegaDepthDataset( | |
data_root, | |
npz_path, | |
mode=mode, | |
min_overlap_score=min_overlap_score, | |
img_resize=self.mgdpt_img_resize, | |
df=self.mgdpt_df, | |
img_padding=self.mgdpt_img_pad, | |
depth_padding=self.mgdpt_depth_pad, | |
coarse_scale=self.coarse_scale, | |
) | |
) | |
else: | |
raise NotImplementedError() | |
return ConcatDataset(datasets) | |