import os
from random import shuffle, seed

import numpy as np
import torch
from torch.utils.data import Dataset

from .gl3d.io import read_list, _parse_img, _parse_depth, _parse_kpts
from .utils.common import Notify
from .utils.photaug import photaug


class GL3DDataset(Dataset):
    """Dataset of GL3D image pairs, one sample per ``.match_set`` file.

    Every sample is backed by a binary float32 record (a "match set") that
    stores the indices of two images in the global image/depth lists together
    with per-image 3x3 matrices, original image sizes and a 3x4 relative pose.
    """

    def __init__(self, dataset_dir, config, data_split, is_training):
        """Index the match sets available for the requested split.

        Args:
            dataset_dir: Root directory of the dataset. It is expected to
                contain ``list/<data_split>/`` and ``data/`` sub-folders.
            config: Configuration object forwarded untouched to
                ``_parse_img`` and ``_parse_depth``.
            data_split: Name of the sub-folder of ``list/`` to read from.
            is_training: If true, ``imageset_train.txt`` is used, otherwise
                ``imageset_test.txt``.
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.config = config
        self.is_training = is_training
        self.data_split = data_split
        (
            self.match_set_list,
            self.global_img_list,
            self.global_depth_list,
        ) = self.prepare_match_sets()

    def __len__(self) -> int:
        """Return the number of match sets that passed the filtering."""
        return len(self.match_set_list)

    def __getitem__(self, idx):
        """Load the image pair described by the ``idx``-th match set.

        The match-set file is read as a flat float32 array with this layout
        (element ranges, end exclusive):

            [0], [1]   indices of the two images in the global lists
            [2]        inlier number
            [3:5]      original size of image 0
            [5:7]      original size of image 1
            [7:16]     K0, reshaped to 3x3
            [16:25]    K1, reshaped to 3x3
            [25:34]    not used here
            [34:46]    relative pose, reshaped to 3x4

        Returns:
            A dict with keys ``img0``, ``img1`` (augmented images divided by
            255), ``depth0``, ``depth1``, ``ori_img_size0``,
            ``ori_img_size1``, ``K0``, ``K1``, ``rel_pose`` and
            ``inlier_num``.
        """
        match_set_path = self.match_set_list[idx]
        decoded = np.fromfile(match_set_path, dtype=np.float32)

        idx0, idx1 = int(decoded[0]), int(decoded[1])
        inlier_num = int(decoded[2])
        ori_img_size0 = np.reshape(decoded[3:5], (2,))
        ori_img_size1 = np.reshape(decoded[5:7], (2,))
        K0 = np.reshape(decoded[7:16], (3, 3))
        K1 = np.reshape(decoded[16:25], (3, 3))
        rel_pose = np.reshape(decoded[34:46], (3, 4))

        # parse images.
        img0 = _parse_img(self.global_img_list, idx0, self.config)
        img1 = _parse_img(self.global_img_list, idx1, self.config)
        # parse depths
        depth0 = _parse_depth(self.global_depth_list, idx0, self.config)
        depth1 = _parse_depth(self.global_depth_list, idx1, self.config)

        # photometric augmentation
        # NOTE(review): applied regardless of ``self.is_training`` -- confirm
        # that augmenting the test split is intended.
        img0 = photaug(img0)
        img1 = photaug(img1)

        return {
            "img0": img0 / 255.0,
            "img1": img1 / 255.0,
            "depth0": depth0,
            "depth1": depth1,
            "ori_img_size0": ori_img_size0,
            "ori_img_size1": ori_img_size1,
            "K0": K0,
            "K1": K1,
            "rel_pose": rel_pose,
            "inlier_num": inlier_num,
        }

    def points_to_2D(self, pnts, H, W):
        """Rasterize points into a binary ``(H, W)`` map.

        Args:
            pnts: Array whose column 0 is used as the column (x) index and
                column 1 as the row (y) index. Values are truncated to int.
            H: Height of the output map.
            W: Width of the output map.

        Returns:
            A ``(H, W)`` array of zeros with ones at the given points.
            Coordinates are not range-checked: out-of-range values raise
            ``IndexError`` and negative values wrap around.
        """
        labels = np.zeros((H, W))
        pnts = pnts.astype(int)
        labels[pnts[:, 1], pnts[:, 0]] = 1
        return labels

    def prepare_match_sets(self, q_diff_thld=3, rot_diff_thld=60):
        """Collect the match sets and the global image/depth path lists.

        All list files are read from ``<dataset_dir>/list/<data_split>/``.
        The imageset list is ``imageset_train.txt`` when ``self.is_training``
        is true and ``imageset_test.txt`` otherwise.

        Args:
            q_diff_thld: Minimum ``q_diff`` a match set must have to be kept
                (forwarded to ``get_match_set_list``).
            rot_diff_thld: Maximum ``rot_diff`` a match set may have to be
                kept (forwarded to ``get_match_set_list``).

        Returns:
            match_set_list: List of match set paths.
            global_img_list: List of global image paths.
            global_depth_list: List of global depth paths.
        """
        # get necessary lists.
        gl3d_list_folder = os.path.join(self.dataset_dir, "list", self.data_split)
        global_img_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "image_list.txt"))
        ]
        global_depth_list = [
            os.path.join(self.dataset_dir, i)
            for i in read_list(os.path.join(gl3d_list_folder, "depth_list.txt"))
        ]

        imageset_list_name = (
            "imageset_train.txt" if self.is_training else "imageset_test.txt"
        )
        match_set_list = self.get_match_set_list(
            os.path.join(gl3d_list_folder, imageset_list_name),
            q_diff_thld,
            rot_diff_thld,
        )
        return match_set_list, global_img_list, global_depth_list

    def get_match_set_list(self, imageset_list_path, q_diff_thld, rot_diff_thld):
        """Get the path list of match sets.

        For every imageset named in ``imageset_list_path`` the folder
        ``<dataset_dir>/data/<imageset>/match_sets`` is scanned for
        ``*.match_set`` files. The third and fourth underscore-separated
        fields of the file name are parsed as integers ``q_diff`` and
        ``rot_diff``.

        Args:
            imageset_list_path: Path to imageset list.
            q_diff_thld: Threshold of image pair sampling regarding camera
                orientation; pairs with ``q_diff`` below it are dropped.
            rot_diff_thld: Pairs with ``rot_diff`` above it are dropped.

        Returns:
            match_set_list: List of match set paths, in a deterministic
                (sorted per folder) order.
        """
        imageset_list = [
            os.path.join(self.dataset_dir, "data", i)
            for i in read_list(imageset_list_path)
        ]
        print(Notify.INFO, "Use # imageset", len(imageset_list), Notify.ENDC)

        match_set_list = []
        # keep only image pairs with q_diff >= q_diff_thld and
        # rot_diff <= rot_diff_thld.
        for imageset_dir in imageset_list:
            match_set_folder = os.path.join(imageset_dir, "match_sets")
            # Imagesets without a match_sets directory are silently skipped.
            if not os.path.isdir(match_set_folder):
                continue
            # os.listdir returns entries in arbitrary order; sort so that the
            # index -> sample mapping is reproducible across runs/machines.
            for val in sorted(os.listdir(match_set_folder)):
                name, ext = os.path.splitext(val)
                if ext != ".match_set":
                    continue
                splits = name.split("_")
                q_diff = int(splits[2])
                rot_diff = int(splits[3])
                if q_diff >= q_diff_thld and rot_diff <= rot_diff_thld:
                    match_set_list.append(os.path.join(match_set_folder, val))

        print(Notify.INFO, "Get # match sets", len(match_set_list), Notify.ENDC)
        return match_set_list