Spaces:
Running
Running
File size: 5,298 Bytes
437b5f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from tqdm import tqdm
from os import path as osp
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from src.datasets.megadepth import MegaDepthDataset
from src.datasets.scannet import ScanNetDataset
from src.datasets.aachen import AachenDataset
from src.datasets.inloc import InLocDataset
class TestDataLoader(DataLoader):
    """DataLoader over the test split selected by ``config.DATASET``.

    The dataset is chosen by ``config.DATASET.TEST_DATA_SOURCE``, matched
    case-insensitively:

    * ``megadepth`` / ``scannet``: a ``ConcatDataset`` with one dataset per
      scene named in ``TEST_LIST_PATH``.
    * ``aachen_v1.1`` / ``inloc``: a single dataset built from
      ``TEST_DATA_ROOT`` and ``TEST_LIST_PATH``.

    Every listed scene is loaded; nothing is sharded across processes.
    Batches have size 1, are not shuffled, and are loaded with 4 workers
    and pinned memory.
    """

    def __init__(self, config):
        """Read the test options from ``config.DATASET`` and build the loader.

        Args:
            config: object exposing a ``DATASET`` attribute that carries the
                ``TEST_*`` / ``MIN_OVERLAP_SCORE_TEST`` fields read below
                (presumably a yacs-style config node — confirm with caller).

        Raises:
            ValueError: if ``TEST_DATA_SOURCE`` is not a supported dataset.
        """
        # 1. data config
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        dataset_name = str(self.test_data_source).lower()
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # Pairs whose overlap_score is below this threshold are omitted
        # (e.g. 0.4).
        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST

        # Dataset-specific options, consumed by _build_concat_dataset.
        if dataset_name == 'megadepth':
            self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # e.g. 800
            self.mgdpt_img_pad = True
            self.mgdpt_depth_pad = True
            self.mgdpt_df = 8
            self.coarse_scale = 0.125
        if dataset_name == 'scannet':
            self.img_resize = config.DATASET.TEST_IMGSIZE

        if dataset_name in ('megadepth', 'scannet'):
            test_dataset = self._setup_dataset(
                self.test_data_root,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode='test',
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root)
        elif dataset_name == 'aachen_v1.1':
            test_dataset = AachenDataset(self.test_data_root, self.test_list_path,
                                         img_resize=config.DATASET.TEST_IMGSIZE)
        elif dataset_name == 'inloc':
            test_dataset = InLocDataset(self.test_data_root, self.test_list_path,
                                        img_resize=config.DATASET.TEST_IMGSIZE)
        else:
            # Raising a bare string is itself a TypeError in Python 3, which
            # would hide this message; raise a real exception instead.
            raise ValueError(f"unknown dataset: {self.test_data_source!r}")

        self.test_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': 4,
            'pin_memory': True
        }
        super().__init__(test_dataset, **self.test_loader_params)

    def _setup_dataset(self,
                       data_root,
                       split_npz_root,
                       scene_list_path,
                       intri_path,
                       mode='train',
                       min_overlap_score=0.,
                       pose_dir=None):
        """Build the concatenated per-scene dataset named in ``scene_list_path``.

        Each non-blank line of the list file names one scene; only its first
        whitespace-separated token is used. Blank lines are skipped rather
        than crashing on ``split()[0]``.

        Returns:
            The ``ConcatDataset`` produced by ``_build_concat_dataset``.
        """
        with open(scene_list_path, 'r') as f:
            npz_names = [line.split()[0] for line in f if line.strip()]
        return self._build_concat_dataset(data_root, npz_names, split_npz_root, intri_path,
                                          mode=mode, min_overlap_score=min_overlap_score,
                                          pose_dir=pose_dir)

    def _build_concat_dataset(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.,
        pose_dir=None
    ):
        """Create one dataset per scene npz and concatenate them.

        ``self.test_data_source`` is matched case-insensitively, consistent
        with ``__init__``. For MegaDepth, a ``.npz`` suffix is appended to
        each scene name. ``intrinsic_path`` and ``pose_dir`` are only passed
        to ``ScanNetDataset``.

        Raises:
            NotImplementedError: if the data source is neither ScanNet nor
                MegaDepth and at least one scene is listed.
        """
        data_source = str(self.test_data_source).lower()
        if data_source == 'megadepth':
            npz_names = [f'{n}.npz' for n in npz_names]

        datasets = []
        for npz_name in tqdm(npz_names):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path
            # when initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if data_source == 'scannet':
                datasets.append(
                    ScanNetDataset(data_root,
                                   npz_path,
                                   intrinsic_path,
                                   mode=mode, img_resize=self.img_resize,
                                   min_overlap_score=min_overlap_score,
                                   pose_dir=pose_dir))
            elif data_source == 'megadepth':
                datasets.append(
                    MegaDepthDataset(data_root,
                                     npz_path,
                                     mode=mode,
                                     min_overlap_score=min_overlap_score,
                                     img_resize=self.mgdpt_img_resize,
                                     df=self.mgdpt_df,
                                     img_padding=self.mgdpt_img_pad,
                                     depth_padding=self.mgdpt_depth_pad,
                                     coarse_scale=self.coarse_scale))
            else:
                raise NotImplementedError()
        return ConcatDataset(datasets)
|