File size: 5,298 Bytes
437b5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from tqdm import tqdm
from os import path as osp
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from src.datasets.megadepth import MegaDepthDataset
from src.datasets.scannet import ScanNetDataset
from src.datasets.aachen import AachenDataset
from src.datasets.inloc import InLocDataset


class TestDataLoader(DataLoader):
    """
    DataLoader over a single test/evaluation dataset selected by config.

    ``config.DATASET.TEST_DATA_SOURCE`` (matched case-insensitively) selects
    the dataset:

    * 'megadepth' / 'scannet': one sub-dataset per scene listed in
      ``TEST_LIST_PATH``, wrapped in a ``ConcatDataset``;
    * 'aachen_v1.1' / 'inloc': a single dataset built directly.

    Samples are served one at a time, in order (batch_size=1, shuffle=False).
    """

    def __init__(self, config):
        """
        Args:
            config: config object exposing the ``config.DATASET.*`` fields read
                below (TEST_DATA_SOURCE, TEST_DATA_ROOT, TEST_POSE_ROOT,
                TEST_NPZ_ROOT, TEST_LIST_PATH, TEST_INTRINSIC_PATH,
                MIN_OVERLAP_SCORE_TEST, plus MGDPT_IMG_RESIZE for MegaDepth or
                TEST_IMGSIZE for the other sources).

        Raises:
            ValueError: if TEST_DATA_SOURCE is not a supported dataset name.
        """
        # 1. data config
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        dataset_name = str(self.test_data_source).lower()
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # Forwarded to the per-scene datasets; pairs whose overlap score is
        # below this threshold are meant to be omitted (e.g. 0.4).
        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST

        if dataset_name == 'megadepth':
            self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # e.g. 800
            self.mgdpt_img_pad = True
            self.mgdpt_depth_pad = True
            self.mgdpt_df = 8
            self.coarse_scale = 0.125
        elif dataset_name == 'scannet':
            self.img_resize = config.DATASET.TEST_IMGSIZE

        if dataset_name in ('megadepth', 'scannet'):
            test_dataset = self._setup_dataset(
                self.test_data_root,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode='test',
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root)
        elif dataset_name == 'aachen_v1.1':
            test_dataset = AachenDataset(self.test_data_root, self.test_list_path,
                                         img_resize=config.DATASET.TEST_IMGSIZE)
        elif dataset_name == 'inloc':
            test_dataset = InLocDataset(self.test_data_root, self.test_list_path,
                                        img_resize=config.DATASET.TEST_IMGSIZE)
        else:
            # The original code raised a bare string, which Python 3 turns into
            # an unrelated TypeError; raise a real exception with context.
            raise ValueError(f"unknown dataset: {self.test_data_source!r}")

        self.test_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': 4,
            'pin_memory': True
        }

        super(TestDataLoader, self).__init__(test_dataset, **self.test_loader_params)

    @staticmethod
    def _read_scene_names(scene_list_path):
        """Return the first whitespace-separated token of each non-blank line.

        Blank / whitespace-only lines are skipped (they previously raised
        IndexError).
        """
        with open(scene_list_path, 'r', encoding='utf-8') as f:
            return [line.split()[0] for line in f if line.strip()]

    def _setup_dataset(self,
                       data_root,
                       split_npz_root,
                       scene_list_path,
                       intri_path,
                       mode='train',
                       min_overlap_score=0.,
                       pose_dir=None):
        """Build the concatenated per-scene dataset for the scenes in a list file.

        Args:
            data_root: root directory passed through to each scene dataset.
            split_npz_root: directory containing the per-scene npz files.
            scene_list_path: text file whose first token per line is a scene name.
            intri_path: intrinsics path (used by the ScanNet branch only).
            mode: dataset mode string forwarded to each scene dataset.
            min_overlap_score: overlap threshold forwarded to each scene dataset.
            pose_dir: optional pose directory (used by the ScanNet branch only).

        Returns:
            ConcatDataset with one sub-dataset per listed scene.
        """
        npz_names = self._read_scene_names(scene_list_path)
        return self._build_concat_dataset(data_root, npz_names, split_npz_root, intri_path,
                                          mode=mode, min_overlap_score=min_overlap_score,
                                          pose_dir=pose_dir)

    def _build_concat_dataset(
            self,
            data_root,
            npz_names,
            npz_dir,
            intrinsic_path,
            mode,
            min_overlap_score=0.,
            pose_dir=None
    ):
        """Instantiate one dataset per npz file and concatenate them.

        The data source is matched case-insensitively, consistent with
        ``__init__`` (previously 'megadepth'/'scannet' in lowercase passed
        ``__init__`` but hit NotImplementedError here).

        Raises:
            NotImplementedError: if the data source is neither ScanNet nor
                MegaDepth and ``npz_names`` is non-empty.
        """
        source = str(self.test_data_source).lower()
        if source == 'megadepth':
            # MegaDepth scene lists hold bare scene names; the files carry .npz.
            npz_names = [f'{n}.npz' for n in npz_names]

        datasets = []
        for npz_name in tqdm(npz_names):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when
            # initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if source == 'scannet':
                datasets.append(
                    ScanNetDataset(data_root,
                                   npz_path,
                                   intrinsic_path,
                                   mode=mode, img_resize=self.img_resize,
                                   min_overlap_score=min_overlap_score,
                                   pose_dir=pose_dir))
            elif source == 'megadepth':
                datasets.append(
                    MegaDepthDataset(data_root,
                                     npz_path,
                                     mode=mode,
                                     min_overlap_score=min_overlap_score,
                                     img_resize=self.mgdpt_img_resize,
                                     df=self.mgdpt_df,
                                     img_padding=self.mgdpt_img_pad,
                                     depth_padding=self.mgdpt_depth_pad,
                                     coarse_scale=self.coarse_scale))
            else:
                raise NotImplementedError(
                    f"unsupported data source: {self.test_data_source!r}")
        return ConcatDataset(datasets)