Vincentqyw
update: features and matchers
a80d6bb
raw
history blame
No virus
4.85 kB
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from random import shuffle, seed
from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts
from .utils.common import Notify
from .utils.photaug import photaug
class GL3DDataset(Dataset):
def __init__(self, dataset_dir, config, data_split, is_training):
self.dataset_dir = dataset_dir
self.config = config
self.is_training = is_training
self.data_split = data_split
self.match_set_list, self.global_img_list, \
self.global_depth_list = self.prepare_match_sets()
pass
def __len__(self):
return len(self.match_set_list)
def __getitem__(self, idx):
match_set_path = self.match_set_list[idx]
decoded = np.fromfile(match_set_path, dtype=np.float32)
idx0, idx1 = int(decoded[0]), int(decoded[1])
inlier_num = int(decoded[2])
ori_img_size0 = np.reshape(decoded[3:5], (2,))
ori_img_size1 = np.reshape(decoded[5:7], (2,))
K0 = np.reshape(decoded[7:16], (3, 3))
K1 = np.reshape(decoded[16:25], (3, 3))
rel_pose = np.reshape(decoded[34:46], (3, 4))
# parse images.
img0 = _parse_img(self.global_img_list, idx0, self.config)
img1 = _parse_img(self.global_img_list, idx1, self.config)
# parse depths
depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
depth1 = _parse_depth(self.global_depth_list, idx1, self.config)
# photometric augmentation
img0 = photaug(img0)
img1 = photaug(img1)
return {
'img0': img0 / 255.,
'img1': img1 / 255.,
'depth0': depth0,
'depth1': depth1,
'ori_img_size0': ori_img_size0,
'ori_img_size1': ori_img_size1,
'K0': K0,
'K1': K1,
'rel_pose': rel_pose,
'inlier_num': inlier_num
}
def points_to_2D(self, pnts, H, W):
labels = np.zeros((H, W))
pnts = pnts.astype(int)
labels[pnts[:, 1], pnts[:, 0]] = 1
return labels
def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
"""Get match sets.
Args:
is_training: Use training imageset or testing imageset.
data_split: Data split name.
Returns:
match_set_list: List of match sets path.
global_img_list: List of global image path.
global_context_feat_list:
"""
# get necessary lists.
gl3d_list_folder = os.path.join(self.dataset_dir, 'list', self.data_split)
global_info = read_list(os.path.join(
gl3d_list_folder, 'image_index_offset.txt'))
global_img_list = [os.path.join(self.dataset_dir, i) for i in read_list(
os.path.join(gl3d_list_folder, 'image_list.txt'))]
global_depth_list = [os.path.join(self.dataset_dir, i) for i in read_list(
os.path.join(gl3d_list_folder, 'depth_list.txt'))]
imageset_list_name = 'imageset_train.txt' if self.is_training else 'imageset_test.txt'
match_set_list = self.get_match_set_list(os.path.join(
gl3d_list_folder, imageset_list_name), q_diff_thld, rot_diff_thld)
return match_set_list, global_img_list, global_depth_list
def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
"""Get the path list of match sets.
Args:
imageset_list_path: Path to imageset list.
q_diff_thld: Threshold of image pair sampling regarding camera orientation.
Returns:
match_set_list: List of match set path.
"""
imageset_list = [os.path.join(self.dataset_dir, 'data', i)
for i in read_list(imageset_list_path)]
print(Notify.INFO, 'Use # imageset', len(imageset_list), Notify.ENDC)
match_set_list = []
# discard image pairs whose image simiarity is beyond the threshold.
for i in imageset_list:
match_set_folder = os.path.join(i, 'match_sets')
if os.path.exists(match_set_folder):
match_set_files = os.listdir(match_set_folder)
for val in match_set_files:
name, ext = os.path.splitext(val)
if ext == '.match_set':
splits = name.split('_')
q_diff = int(splits[2])
rot_diff = int(splits[3])
if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
match_set_list.append(
os.path.join(match_set_folder, val))
print(Notify.INFO, 'Get # match sets', len(match_set_list), Notify.ENDC)
return match_set_list