File size: 5,308 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
c74a070
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
c74a070
a80d6bb
c74a070
 
 
 
 
 
 
 
 
 
 
 
 
 
a80d6bb
 
 
 
c74a070
 
 
 
a80d6bb
 
 
 
 
c74a070
 
 
 
 
 
 
 
 
 
 
 
a80d6bb
 
 
c74a070
 
 
 
 
 
 
 
 
a80d6bb
 
c74a070
 
 
 
 
 
 
 
a80d6bb
 
 
 
c74a070
 
a80d6bb
 
 
c74a070
a80d6bb
c74a070
 
 
 
 
 
 
 
 
 
 
a80d6bb
c74a070
 
 
 
 
 
 
 
 
 
 
 
a80d6bb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from tqdm import tqdm
from os import path as osp
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from src.datasets.megadepth import MegaDepthDataset
from src.datasets.scannet import ScanNetDataset
from src.datasets.aachen import AachenDataset
from src.datasets.inloc import InLocDataset


class TestDataLoader(DataLoader):
    """Build the evaluation ``DataLoader`` described by ``config.DATASET``.

    The dataset is selected by ``config.DATASET.TEST_DATA_SOURCE``, matched
    case-insensitively:

    * ``megadepth`` / ``scannet`` -- one ``MegaDepthDataset`` /
      ``ScanNetDataset`` per scene listed in ``TEST_LIST_PATH``, all wrapped
      in a single ``ConcatDataset``.
    * ``aachen_v1.1`` -- a single ``AachenDataset``.
    * ``inloc`` -- a single ``InLocDataset``.

    Every scene named in the list file is loaded (there is no per-process
    splitting). The loader always uses ``batch_size=1``, no shuffling,
    4 worker processes and pinned memory.
    """

    def __init__(self, config):
        """Read the test options from ``config.DATASET`` and build the loader.

        Args:
            config: configuration object whose ``DATASET`` attribute exposes
                the ``TEST_*`` entries and ``MIN_OVERLAP_SCORE_TEST`` read
                below, plus ``MGDPT_IMG_RESIZE`` when the source is MegaDepth.

        Raises:
            ValueError: if ``TEST_DATA_SOURCE`` is not a supported dataset.
        """
        # 1. data config
        self.test_data_source = config.DATASET.TEST_DATA_SOURCE
        dataset_name = str(self.test_data_source).lower()
        self.test_data_root = config.DATASET.TEST_DATA_ROOT
        self.test_pose_root = config.DATASET.TEST_POSE_ROOT  # (optional)
        self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
        self.test_list_path = config.DATASET.TEST_LIST_PATH
        self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH

        # 2. dataset config
        # Pairs with overlap_score < min_overlap_score are omitted
        # (the original note gives 0.4 as a typical value).
        self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST

        # Per-source options consumed later by _build_concat_dataset.
        if dataset_name == "megadepth":
            self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE  # e.g. 800
            self.mgdpt_img_pad = True
            self.mgdpt_depth_pad = True
            self.mgdpt_df = 8
            self.coarse_scale = 0.125
        elif dataset_name == "scannet":
            self.img_resize = config.DATASET.TEST_IMGSIZE

        if dataset_name in ("megadepth", "scannet"):
            test_dataset = self._setup_dataset(
                self.test_data_root,
                self.test_npz_root,
                self.test_list_path,
                self.test_intrinsic_path,
                mode="test",
                min_overlap_score=self.min_overlap_score_test,
                pose_dir=self.test_pose_root,
            )
        elif dataset_name == "aachen_v1.1":
            test_dataset = AachenDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=config.DATASET.TEST_IMGSIZE,
            )
        elif dataset_name == "inloc":
            test_dataset = InLocDataset(
                self.test_data_root,
                self.test_list_path,
                img_resize=config.DATASET.TEST_IMGSIZE,
            )
        else:
            # Raising a plain string is itself a TypeError in Python 3, which
            # hid the real problem; raise a proper exception naming the value.
            raise ValueError(f"unknown dataset: {self.test_data_source!r}")

        self.test_loader_params = {
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
            "pin_memory": True,
        }

        super().__init__(test_dataset, **self.test_loader_params)

    def _setup_dataset(
        self,
        data_root,
        split_npz_root,
        scene_list_path,
        intri_path,
        mode="train",
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Read the scene list file and build the concatenated dataset.

        Each non-blank line of ``scene_list_path`` contributes its first
        whitespace-separated token as a scene (npz) name. All scenes are used.

        Args:
            data_root: root directory passed through to each scene dataset.
            split_npz_root: directory containing the per-scene npz files.
            scene_list_path: text file listing one scene name per line.
            intri_path: intrinsics path (used only by the ScanNet branch).
            mode: mode string forwarded to each scene dataset.
            min_overlap_score: forwarded to each scene dataset.
            pose_dir: forwarded to ``ScanNetDataset``.

        Returns:
            The ``ConcatDataset`` built by ``_build_concat_dataset``.
        """
        with open(scene_list_path, "r") as f:
            # Skip blank lines (e.g. a trailing newline); ``split()[0]`` would
            # raise IndexError on them.
            npz_names = [line.split()[0] for line in f if line.strip()]

        return self._build_concat_dataset(
            data_root,
            npz_names,
            split_npz_root,
            intri_path,
            mode=mode,
            min_overlap_score=min_overlap_score,
            pose_dir=pose_dir,
        )

    def _build_concat_dataset(
        self,
        data_root,
        npz_names,
        npz_dir,
        intrinsic_path,
        mode,
        min_overlap_score=0.0,
        pose_dir=None,
    ):
        """Instantiate one scene dataset per npz name and concatenate them.

        The dataset class is chosen from ``self.test_data_source``, compared
        case-insensitively, consistent with ``__init__``. For MegaDepth, the
        listed names have no extension, so ``.npz`` is appended.

        Args:
            data_root: root directory passed to each scene dataset.
            npz_names: scene names read from the scene list.
            npz_dir: directory joined with each npz name.
            intrinsic_path: forwarded to ``ScanNetDataset``.
            mode: mode string forwarded to each scene dataset.
            min_overlap_score: forwarded to each scene dataset.
            pose_dir: forwarded to ``ScanNetDataset``.

        Returns:
            A ``ConcatDataset`` over the per-scene datasets.

        Raises:
            NotImplementedError: if the test data source is neither ScanNet
                nor MegaDepth.
        """
        source = str(self.test_data_source).lower()
        if source not in ("scannet", "megadepth"):
            raise NotImplementedError(
                f"unsupported test data source: {self.test_data_source!r}"
            )
        if source == "megadepth":
            npz_names = [f"{n}.npz" for n in npz_names]

        datasets = []
        for npz_name in tqdm(npz_names):
            # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path
            # when initialized, which might take time.
            npz_path = osp.join(npz_dir, npz_name)
            if source == "scannet":
                datasets.append(
                    ScanNetDataset(
                        data_root,
                        npz_path,
                        intrinsic_path,
                        mode=mode,
                        img_resize=self.img_resize,
                        min_overlap_score=min_overlap_score,
                        pose_dir=pose_dir,
                    )
                )
            else:  # megadepth
                datasets.append(
                    MegaDepthDataset(
                        data_root,
                        npz_path,
                        mode=mode,
                        min_overlap_score=min_overlap_score,
                        img_resize=self.mgdpt_img_resize,
                        df=self.mgdpt_df,
                        img_padding=self.mgdpt_img_pad,
                        depth_padding=self.mgdpt_depth_pad,
                        coarse_scale=self.coarse_scale,
                    )
                )
        return ConcatDataset(datasets)