Vincentqyw
fix: roma
c74a070
raw
history blame
No virus
5.31 kB
from tqdm import tqdm
from os import path as osp
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from src.datasets.megadepth import MegaDepthDataset
from src.datasets.scannet import ScanNetDataset
from src.datasets.aachen import AachenDataset
from src.datasets.inloc import InLocDataset
class TestDataLoader(DataLoader):
    """
    DataLoader over the test split selected by `config.DATASET.TEST_DATA_SOURCE`.

    The data source name is matched case-insensitively against
    "megadepth", "scannet", "aachen_v1.1" and "inloc". The loader always uses
    batch_size=1, shuffle=False, num_workers=4 and pin_memory=True.

    Raises:
        ValueError: if the configured test data source is not one of the
            supported names.
    """

    def __init__(self, config):
        # 1. data config
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        dataset_name = str(self.test_data_source).lower()

        # testing
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # general options
        self.min_overlap_score_test = (
            config.DATASET.MIN_OVERLAP_SCORE_TEST
        )  # 0.4, omit data with overlap_score < min_overlap_score

        if dataset_name in ("megadepth", "scannet"):
            if dataset_name == "megadepth":
                # MegaDepth options, forwarded to MegaDepthDataset
                self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # 800
                self.mgdpt_img_pad = True
                self.mgdpt_depth_pad = True
                self.mgdpt_df = 8
                self.coarse_scale = 0.125
            else:
                # ScanNet options, forwarded to ScanNetDataset
                self.img_resize = config.DATASET.TEST_IMGSIZE

            test_dataset = self._setup_dataset(
                self.test_data_root,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode="test",
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root,
            )
        elif dataset_name == "aachen_v1.1":
            test_dataset = AachenDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=config.DATASET.TEST_IMGSIZE,
            )
        elif dataset_name == "inloc":
            test_dataset = InLocDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=config.DATASET.TEST_IMGSIZE,
            )
        else:
            # Raising a plain string is itself a TypeError in Python 3, so
            # raise a real exception that names the offending value.
            raise ValueError(f"unknown dataset: {self.test_data_source!r}")

        self.test_loader_params = {
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
            "pin_memory": True,
        }

        # sampler = Seq(self.test_dataset, shuffle=False)
        super().__init__(test_dataset, **self.test_loader_params)

    def _setup_dataset(
        self,
        data_root,
        split_npz_root,
        scene_list_path,
        intri_path,
        mode="train",
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Read the scene list file and build the concatenated dataset.

        The first whitespace-separated token of every non-blank line in
        `scene_list_path` is taken as an npz name; all remaining arguments
        are forwarded to `_build_concat_dataset`.
        """
        with open(scene_list_path, "r") as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # with IndexError on `split()[0]`.
            npz_names = [line.split()[0] for line in f if line.strip()]

        return self._build_concat_dataset(
            data_root,
            npz_names,
            split_npz_root,
            intri_path,
            mode=mode,
            min_overlap_score=min_overlap_score,
            pose_dir=pose_dir,
        )

    def _build_concat_dataset(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Create one dataset per npz file and wrap them in a ConcatDataset.

        For MegaDepth the ".npz" suffix is appended to every name first.

        Raises:
            NotImplementedError: if an npz entry is processed while the
                configured data source is neither ScanNet nor MegaDepth.
        """
        datasets = []
        # augment_fn = self.augment_fn if mode == 'train' else None

        # Match the data source case-insensitively, exactly as __init__ does,
        # so e.g. "megadepth" and "MegaDepth" select the same dataset class.
        source_name = str(self.test_data_source).lower()
        if source_name == "megadepth":
            npz_names = [f"{n}.npz" for n in npz_names]

        for npz_name in tqdm(npz_names):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if source_name == "scannet":
                datasets.append(
                    ScanNetDataset(
                        data_root,
                        npz_path,
                        intrinsic_path,
                        mode=mode,
                        img_resize=self.img_resize,
                        min_overlap_score=min_overlap_score,
                        pose_dir=pose_dir,
                    )
                )
            elif source_name == "megadepth":
                datasets.append(
                    MegaDepthDataset(
                        data_root,
                        npz_path,
                        mode=mode,
                        min_overlap_score=min_overlap_score,
                        img_resize=self.mgdpt_img_resize,
                        df=self.mgdpt_df,
                        img_padding=self.mgdpt_img_pad,
                        depth_padding=self.mgdpt_depth_pad,
                        coarse_scale=self.coarse_scale,
                    )
                )
            else:
                raise NotImplementedError()
        return ConcatDataset(datasets)