File size: 4,953 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
 
 
a80d6bb
c74a070
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
 
 
 
 
 
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
c74a070
 
 
 
 
a80d6bb
 
 
c74a070
a80d6bb
 
 
 
c74a070
 
a80d6bb
 
 
c74a070
a80d6bb
c74a070
a80d6bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from random import shuffle, seed

from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts
from .utils.common import Notify
from .utils.photaug import photaug


class GL3DDataset(Dataset):
    """Map-style dataset over GL3D ``.match_set`` files.

    Each item corresponds to one ``.match_set`` file (one image pair). The
    file lists found under ``<dataset_dir>/list/<data_split>`` are read once
    at construction time; images and depths are loaded lazily in
    ``__getitem__``.
    """

    def __init__(self, dataset_dir, config, data_split, is_training):
        """Index the match sets available for the requested split.

        Args:
            dataset_dir: Root directory of the dataset.
            config: Configuration object forwarded unchanged to
                ``_parse_img`` and ``_parse_depth``.
            data_split: Name of the sub-folder of ``<dataset_dir>/list``
                holding the list files.
            is_training: Selects ``imageset_train.txt`` when true,
                ``imageset_test.txt`` otherwise.
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.config = config
        self.is_training = is_training
        self.data_split = data_split

        (
            self.match_set_list,
            self.global_img_list,
            self.global_depth_list,
        ) = self.prepare_match_sets()

    def __len__(self):
        """Return the number of match sets that passed the filtering."""
        return len(self.match_set_list)

    def __getitem__(self, idx):
        """Load one image pair together with its per-pair metadata.

        The match-set file is read as a flat ``float32`` array and sliced at
        fixed offsets: ``[0]``/``[1]`` are indices into the global image and
        depth lists, ``[2]`` is ``inlier_num``, ``[3:5]``/``[5:7]`` are the
        two original image sizes, ``[7:16]``/``[16:25]`` are the 3x3 matrices
        ``K0``/``K1`` and ``[34:46]`` is the 3x4 ``rel_pose``. Elements
        ``[25:34]`` are not used here.

        Args:
            idx: Position in ``self.match_set_list``.

        Returns:
            dict with keys ``img0``, ``img1`` (augmented images divided by
            255), ``depth0``, ``depth1``, ``ori_img_size0``,
            ``ori_img_size1``, ``K0``, ``K1``, ``rel_pose`` and
            ``inlier_num``.
        """
        match_set_path = self.match_set_list[idx]
        decoded = np.fromfile(match_set_path, dtype=np.float32)

        idx0, idx1 = int(decoded[0]), int(decoded[1])
        inlier_num = int(decoded[2])
        ori_img_size0 = np.reshape(decoded[3:5], (2,))
        ori_img_size1 = np.reshape(decoded[5:7], (2,))
        K0 = np.reshape(decoded[7:16], (3, 3))
        K1 = np.reshape(decoded[16:25], (3, 3))
        rel_pose = np.reshape(decoded[34:46], (3, 4))

        # Parse images.
        img0 = _parse_img(self.global_img_list, idx0, self.config)
        img1 = _parse_img(self.global_img_list, idx1, self.config)
        # Parse depths.
        depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
        depth1 = _parse_depth(self.global_depth_list, idx1, self.config)

        # Photometric augmentation.
        # NOTE(review): applied regardless of ``self.is_training`` -- confirm
        # that augmenting the test split is intended.
        img0 = photaug(img0)
        img1 = photaug(img1)

        return {
            # Presumably rescales 8-bit intensities to [0, 1] -- TODO confirm
            # the value range produced by ``_parse_img``/``photaug``.
            "img0": img0 / 255.0,
            "img1": img1 / 255.0,
            "depth0": depth0,
            "depth1": depth1,
            "ori_img_size0": ori_img_size0,
            "ori_img_size1": ori_img_size1,
            "K0": K0,
            "K1": K1,
            "rel_pose": rel_pose,
            "inlier_num": inlier_num,
        }

    def points_to_2D(self, pnts, H, W):
        """Rasterise point coordinates into a binary ``(H, W)`` map.

        Args:
            pnts: Array-like of points, one per row, with ``x`` in column 0
                and ``y`` in column 1. Coordinates are cast to ``int``.
            H: Height of the output map.
            W: Width of the output map.

        Returns:
            ``float64`` array of shape ``(H, W)`` that is 1 at every
            ``(y, x)`` listed in ``pnts`` and 0 elsewhere.
        """
        labels = np.zeros((H, W))
        # ``np.asarray`` lets plain Python sequences be passed as well.
        pnts = np.asarray(pnts).astype(int)
        labels[pnts[:, 1], pnts[:, 0]] = 1
        return labels

    def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
        """Collect the match-set paths and the global image/depth lists.

        Uses ``self.dataset_dir``, ``self.data_split`` and
        ``self.is_training`` to locate the list files.

        Args:
            q_diff_thld: Minimum ``q_diff`` a match set must have to be kept.
            rot_diff_thld: Maximum ``rot_diff`` a match set may have to be
                kept.

        Returns:
            match_set_list: List of match set paths.
            global_img_list: List of global image paths.
            global_depth_list: List of global depth paths.
        """
        gl3d_list_folder = os.path.join(self.dataset_dir, "list", self.data_split)

        # Entries of the list files are relative to the dataset root.
        global_img_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "image_list.txt"))
        ]
        global_depth_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "depth_list.txt"))
        ]

        imageset_list_name = (
            "imageset_train.txt" if self.is_training else "imageset_test.txt"
        )
        match_set_list = self.get_match_set_list(
            os.path.join(gl3d_list_folder, imageset_list_name),
            q_diff_thld,
            rot_diff_thld,
        )
        return match_set_list, global_img_list, global_depth_list

    def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
        """Get the path list of match sets.

        Scans ``<dataset_dir>/data/<imageset>/match_sets`` for every imageset
        named in ``imageset_list_path`` and keeps the ``.match_set`` files
        whose name-encoded ``q_diff`` and ``rot_diff`` pass the thresholds.

        Args:
            imageset_list_path: Path to imageset list.
            q_diff_thld: Keep a match set only if its ``q_diff`` (third
                ``_``-separated field of the file name) is at least this.
            rot_diff_thld: Keep a match set only if its ``rot_diff`` (fourth
                ``_``-separated field of the file name) is at most this.

        Returns:
            match_set_list: List of match set paths, in sorted file-name
                order within each imageset.
        """
        imageset_list = [
            os.path.join(self.dataset_dir, "data", i)
            for i in read_list(imageset_list_path)
        ]
        print(Notify.INFO, "Use # imageset", len(imageset_list), Notify.ENDC)

        match_set_list = []
        for imageset_dir in imageset_list:
            match_set_folder = os.path.join(imageset_dir, "match_sets")
            if not os.path.isdir(match_set_folder):
                continue
            # os.listdir returns entries in arbitrary order; sort so that a
            # given index always maps to the same match set.
            for file_name in sorted(os.listdir(match_set_folder)):
                name, ext = os.path.splitext(file_name)
                if ext != ".match_set":
                    continue
                splits = name.split("_")
                q_diff = int(splits[2])
                rot_diff = int(splits[3])
                # Discard pairs outside the requested q_diff/rot_diff range.
                if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
                    match_set_list.append(os.path.join(match_set_folder, file_name))

        print(Notify.INFO, "Get # match sets", len(match_set_list), Notify.ENDC)
        return match_set_list