File size: 4,851 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from random import shuffle, seed

from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts
from .utils.common import Notify
from .utils.photaug import photaug


class GL3DDataset(Dataset):
    """PyTorch dataset of GL3D image pairs described by ``.match_set`` files.

    Each item is decoded from one binary ``.match_set`` file (read as a flat
    float32 array) and yields both images, both depth maps, the original image
    sizes, both 3x3 intrinsic matrices and the 3x4 relative pose.

    Args:
        dataset_dir: Root directory of the dataset. Path lists are read from
            ``<dataset_dir>/list/<data_split>`` and imagesets are resolved under
            ``<dataset_dir>/data``.
        config: Configuration object forwarded unchanged to ``_parse_img`` and
            ``_parse_depth``.
        data_split: Name of the list sub-directory to read.
        is_training: If True, match sets come from ``imageset_train.txt``,
            otherwise from ``imageset_test.txt``.
    """

    def __init__(self, dataset_dir, config, data_split, is_training):
        self.dataset_dir = dataset_dir
        self.config = config
        self.is_training = is_training
        self.data_split = data_split

        self.match_set_list, self.global_img_list, \
            self.global_depth_list = self.prepare_match_sets()

    def __len__(self):
        """Return the number of match sets that passed the pair filter."""
        return len(self.match_set_list)

    def __getitem__(self, idx):
        """Load the image pair described by the ``idx``-th match set file.

        The file is decoded as float32 and sliced as follows:
        ``[0]``/``[1]`` global indices of the two images, ``[2]`` inlier count,
        ``[3:5]``/``[5:7]`` original image sizes, ``[7:16]``/``[16:25]`` the two
        3x3 intrinsics, ``[34:46]`` a 3x4 relative pose. Entries ``[25:34]`` are
        not read here.

        Returns:
            dict with keys ``img0``/``img1`` (output of ``photaug`` divided by
            255), ``depth0``/``depth1``, ``ori_img_size0``/``ori_img_size1``,
            ``K0``/``K1``, ``rel_pose`` and ``inlier_num``.
        """
        decoded = np.fromfile(self.match_set_list[idx], dtype=np.float32)

        idx0, idx1 = int(decoded[0]), int(decoded[1])
        inlier_num = int(decoded[2])
        ori_img_size0 = np.reshape(decoded[3:5], (2,))
        ori_img_size1 = np.reshape(decoded[5:7], (2,))
        K0 = np.reshape(decoded[7:16], (3, 3))
        K1 = np.reshape(decoded[16:25], (3, 3))
        rel_pose = np.reshape(decoded[34:46], (3, 4))

        # parse images.
        img0 = _parse_img(self.global_img_list, idx0, self.config)
        img1 = _parse_img(self.global_img_list, idx1, self.config)
        # parse depths
        depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
        depth1 = _parse_depth(self.global_depth_list, idx1, self.config)

        # photometric augmentation
        # NOTE(review): applied even when ``is_training`` is False — confirm
        # this is intended for evaluation splits.
        img0 = photaug(img0)
        img1 = photaug(img1)

        return {
            'img0': img0 / 255.,
            'img1': img1 / 255.,
            'depth0': depth0,
            'depth1': depth1,
            'ori_img_size0': ori_img_size0,
            'ori_img_size1': ori_img_size1,
            'K0': K0,
            'K1': K1,
            'rel_pose': rel_pose,
            'inlier_num': inlier_num
        }

    def points_to_2D(self, pnts, H, W):
        """Rasterize 2D points into a binary ``(H, W)`` label map.

        Args:
            pnts: Array-like of shape (N, 2) holding (x, y) coordinates; values
                are cast to int (truncated).
            H: Height of the output map.
            W: Width of the output map.
        Returns:
            float64 array of shape (H, W) with 1 at every in-bounds point and 0
            elsewhere. Points outside ``[0, W) x [0, H)`` are ignored.
        """
        labels = np.zeros((H, W))
        pnts = np.asarray(pnts).astype(int)
        xs, ys = pnts[:, 0], pnts[:, 1]
        # Without this mask, negative coordinates would silently wrap to the
        # opposite border and coordinates >= W/H would raise IndexError.
        in_bounds = (xs >= 0) & (xs < W) & (ys >= 0) & (ys < H)
        labels[ys[in_bounds], xs[in_bounds]] = 1
        return labels

    def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
        """Collect match set paths plus the global image and depth path lists.

        Lists are read from ``<dataset_dir>/list/<data_split>``. The imageset
        list used is ``imageset_train.txt`` when ``self.is_training`` is True,
        otherwise ``imageset_test.txt``.

        Args:
            q_diff_thld: Minimum ``q_diff`` a match set must have to be kept.
            rot_diff_thld: Maximum ``rot_diff`` a match set may have to be kept.
        Returns:
            match_set_list: Paths of the kept ``.match_set`` files.
            global_img_list: Image paths (joined with ``dataset_dir``); passed
                to ``_parse_img`` together with the image indices decoded from
                each match set.
            global_depth_list: Depth map paths (joined with ``dataset_dir``);
                passed to ``_parse_depth`` the same way.
        """
        gl3d_list_folder = os.path.join(self.dataset_dir, 'list', self.data_split)
        global_img_list = [os.path.join(self.dataset_dir, i) for i in read_list(
            os.path.join(gl3d_list_folder, 'image_list.txt'))]
        global_depth_list = [os.path.join(self.dataset_dir, i) for i in read_list(
            os.path.join(gl3d_list_folder, 'depth_list.txt'))]

        imageset_list_name = 'imageset_train.txt' if self.is_training else 'imageset_test.txt'
        match_set_list = self.get_match_set_list(os.path.join(
            gl3d_list_folder, imageset_list_name), q_diff_thld, rot_diff_thld)
        return match_set_list, global_img_list, global_depth_list

    def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
        """Get the path list of match sets that pass the pair filter.

        For each imageset named in ``imageset_list_path``, the folder
        ``<dataset_dir>/data/<imageset>/match_sets`` is scanned (missing folders
        are skipped). Only files with extension ``.match_set`` are considered;
        the third and fourth ``_``-separated fields of the file stem are parsed
        as ints ``q_diff`` and ``rot_diff``.

        Args:
            imageset_list_path: Path to the imageset list file.
            q_diff_thld: Keep a pair only if ``q_diff >= q_diff_thld``.
            rot_diff_thld: Keep a pair only if ``rot_diff <= rot_diff_thld``.
        Returns:
            match_set_list: List of kept match set paths, in imageset-list order
                and sorted by file name within each imageset.
        """
        imageset_list = [os.path.join(self.dataset_dir, 'data', i)
                         for i in read_list(imageset_list_path)]
        print(Notify.INFO, 'Use # imageset', len(imageset_list), Notify.ENDC)
        match_set_list = []
        for imageset_dir in imageset_list:
            match_set_folder = os.path.join(imageset_dir, 'match_sets')
            if not os.path.exists(match_set_folder):
                continue
            # os.listdir order is arbitrary; sort so dataset indices are
            # reproducible across runs and filesystems.
            for val in sorted(os.listdir(match_set_folder)):
                name, ext = os.path.splitext(val)
                if ext != '.match_set':
                    continue
                splits = name.split('_')
                q_diff = int(splits[2])
                rot_diff = int(splits[3])
                if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
                    match_set_list.append(os.path.join(match_set_folder, val))

        print(Notify.INFO, 'Get # match sets', len(match_set_list), Notify.ENDC)
        return match_set_list