import numpy as np
import torch
import torch.utils.data as data
import cv2
import os
import h5py
import random
import sys

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
# Put the project root first on the import path so the `utils` package below
# resolves regardless of the current working directory.
sys.path.insert(0, ROOT_DIR)
from utils import train_utils, evaluation_utils  # noqa: E402

# Share tensors between DataLoader worker processes via the file system rather
# than file descriptors (presumably to avoid "too many open files" errors when
# many workers are used -- confirm).
torch.multiprocessing.set_sharing_strategy("file_system")


def _sample_padding_indices(pool_size, count):
    """Pick ``count`` positions into a pool of ``pool_size`` candidates.

    If the pool is large enough, positions are drawn without replacement.
    Otherwise every pool position is used once and the remainder is drawn
    with replacement, so the result always has length ``count``.

    Args:
        pool_size: number of candidate keypoints available for padding.
        count: number of padding positions required.

    Returns:
        Integer array of length ``count`` with values in ``[0, pool_size)``.

    Raises:
        ValueError: if padding is required but the pool is empty (nothing to
            sample from), or if ``count`` is negative.
    """
    if pool_size >= count:
        return np.random.choice(pool_size, count, replace=False)
    if pool_size == 0:
        raise ValueError(
            f"cannot pad to {count} keypoints: no unannotated keypoints available"
        )
    # Use every candidate once, then top up with random repeats.
    return np.concatenate(
        [np.arange(pool_size), np.random.randint(pool_size, size=count - pool_size)]
    )


def _apply_homography(pts, homo_mat):
    """Warp batched 2-D points with per-sample 3x3 homographies.

    Args:
        pts: float tensor of shape (B, N, 2).
        homo_mat: tensor broadcastable to (B, N, 3, 3); callers pass the
            output of ``get_rnd_homography`` with a singleton point axis
            inserted (presumably (B, 1, 3, 3) -- confirm).

    Returns:
        Float tensor of shape (B, N, 2): the warped points divided by their
        homogeneous coordinate.
    """
    batch_size, num_pts = pts.shape[0], pts.shape[1]
    pts_homo = torch.cat(
        [pts, torch.ones([batch_size, num_pts, 1])], dim=-1
    ).unsqueeze(-1)
    warped = torch.matmul(homo_mat.float(), pts_homo.float()).squeeze(-1)
    return warped[:, :, :2] / warped[:, :, 2].unsqueeze(-1)


class Offline_Dataset(data.Dataset):
    """Image pairs with pre-extracted keypoints, descriptors and a GT essential matrix.

    Per-pair geometry, image paths and match annotations are read from
    ``<config.dataset_path>/<seq>/info.h5py``. Per-image features are read
    from ``<config.desc_path>/<seq>/<image name><config.desc_suffix>``.

    Config attributes read:
        dataset_path, rawdata_path, desc_path, desc_suffix, num_kpt,
        input_normalize ("intrinsic" or "img"), data_aug (train mode only).
    """

    def __init__(self, config, mode):
        """Load the pair index for one split.

        Args:
            config: configuration object (attributes listed on the class).
            mode: ``"train"`` or ``"valid"``; selects the metadata directory.
        """
        assert mode == "train" or mode == "valid"
        self.config = config
        self.mode = mode
        metadir = os.path.join(
            config.dataset_path, "valid" if mode == "valid" else "train"
        )
        pair_num_list = np.loadtxt(os.path.join(metadir, "pair_num.txt"), dtype=str)
        # The second column of the first row is the total number of pairs.
        # The full table is handed to train_utils.parse_pair_seq, which
        # presumably maps a global pair index to its sequence and per-sequence
        # offset -- confirm.
        self.total_pairs = int(pair_num_list[0, 1])
        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(
            pair_num_list
        )

    def collate_fn(self, batch):
        """Stack a list of ``__getitem__`` samples into one batch dict.

        Coordinates, descriptors, scores and ``e_gt`` are stacked as float32
        tensors. The annotation counts are stacked as int32 tensors. Image
        paths stay as Python lists. Stacking requires every sample to have
        the same number of points, which ``__getitem__`` ensures by padding.

        In ``"train"`` mode with ``config.data_aug`` set, exactly one view
        (chosen with probability 0.5) is warped by a random homography and
        stored as ``aug_x1`` or ``aug_x2``; the other ``aug_*`` entry aliases
        the unwarped points. Without augmentation, ``aug_x1``/``aug_x2``
        simply alias ``x1``/``x2``.
        """
        keys = (
            "x1",
            "x2",
            "kpt1",
            "kpt2",
            "desc1",
            "desc2",
            "num_corr",
            "num_incorr1",
            "num_incorr2",
            "e_gt",
            "pscore1",
            "pscore2",
            "img_path1",
            "img_path2",
        )
        float_keys = (
            "x1",
            "x2",
            "kpt1",
            "kpt2",
            "desc1",
            "desc2",
            "e_gt",
            "pscore1",
            "pscore2",
        )
        int_keys = ("num_corr", "num_incorr1", "num_incorr2")

        batch_dict = {key: [sample[key] for sample in batch] for key in keys}
        for key in float_keys:
            batch_dict[key] = torch.from_numpy(np.stack(batch_dict[key])).float()
        for key in int_keys:
            batch_dict[key] = torch.from_numpy(np.stack(batch_dict[key])).int()

        # Keypoint augmentation with a random homography applied to one view.
        if self.mode == "train" and self.config.data_aug:
            homo_mat = torch.from_numpy(
                train_utils.get_rnd_homography(len(batch))
            ).unsqueeze(1)
            if random.random() < 0.5:
                batch_dict["aug_x1"] = _apply_homography(batch_dict["x1"], homo_mat)
                batch_dict["aug_x2"] = batch_dict["x2"]
            else:
                batch_dict["aug_x2"] = _apply_homography(batch_dict["x2"], homo_mat)
                batch_dict["aug_x1"] = batch_dict["x1"]
        else:
            batch_dict["aug_x1"], batch_dict["aug_x2"] = (
                batch_dict["x1"],
                batch_dict["x2"],
            )
        return batch_dict

    def __getitem__(self, index):
        """Load one pair, normalise its keypoints and reorder them.

        The returned point arrays are ordered as follows:
          1. ``num_corr`` annotated correspondences, where row ``i`` of image 1
             matches row ``i`` of image 2;
          2. the annotated ``incorr`` keypoints of each image (presumably
             keypoints known to have no match -- confirm);
          3. keypoints without any annotation, sampled to pad each image to
             ``config.num_kpt`` rows.

        Returns:
            dict with ``x1``/``x2`` (normalised coordinates, 2 columns),
            ``kpt1``/``kpt2`` (raw keypoint coordinates), ``desc1``/``desc2``,
            ``pscore1``/``pscore2`` (one column), the three annotation counts,
            ``e_gt`` (3x3, unit Frobenius norm, expressed in the coordinate
            frame of ``x1``/``x2``), and the two image paths.

        Raises:
            NotImplementedError: for an unknown ``config.input_normalize``.
            ValueError: if padding is needed but an image has no unannotated
                keypoints to sample from.
        """
        cfg = self.config
        num_kpt = cfg.num_kpt
        seq = self.pair_seq_list[index]
        key = str(index - self.accu_pair_num[seq])

        with h5py.File(os.path.join(cfg.dataset_path, seq, "info.h5py"), "r") as info:
            R, t = info["dR"][key][()], info["dt"][key][()]
            K1, K2 = info["K1"][key][()], info["K2"][key][()]
            size1, size2 = info["size1"][key][()], info["size2"][key][()]
            img_path1 = info["img_path1"][key][()][0].decode()
            img_path2 = info["img_path2"][key][()][0].decode()
            corr = info["corr"][key][()]
            incorr1, incorr2 = info["incorr1"][key][()], info["incorr2"][key][()]

        # Essential matrix E = [t]_x R, scaled to unit Frobenius norm.
        egt = np.matmul(
            np.reshape(
                evaluation_utils.np_skew_symmetric(t.astype("float64").reshape(1, 3)),
                (3, 3),
            ),
            np.reshape(R.astype("float64"), (3, 3)),
        )
        egt = egt / np.linalg.norm(egt)

        img_name1, img_name2 = img_path1.split("/")[-1], img_path2.split("/")[-1]
        img_path1 = os.path.join(cfg.rawdata_path, img_path1)
        img_path2 = os.path.join(cfg.rawdata_path, img_path2)
        fea_path1 = os.path.join(cfg.desc_path, seq, img_name1 + cfg.desc_suffix)
        fea_path2 = os.path.join(cfg.desc_path, seq, img_name2 + cfg.desc_suffix)

        with h5py.File(fea_path1, "r") as fea1, h5py.File(fea_path2, "r") as fea2:
            desc1, desc2 = fea1["descriptors"][()], fea2["descriptors"][()]
            # Read each keypoint table once: columns 0-1 are the keypoint
            # coordinates and column 2 is the per-keypoint score.
            raw_kpt1, raw_kpt2 = fea1["keypoints"][()], fea2["keypoints"][()]
        kpt1, pscore1 = raw_kpt1[:, :2], raw_kpt1[:, 2]
        kpt2, pscore2 = raw_kpt2[:, :2], raw_kpt2[:, 2]

        # Keep only the first num_kpt keypoints/descriptors per image. pscore
        # is not truncated, but it is only indexed below with indices
        # < num_kpt.
        kpt1, kpt2 = kpt1[:num_kpt], kpt2[:num_kpt]
        desc1, desc2 = desc1[:num_kpt], desc2[:num_kpt]

        # Normalise the keypoints.
        if cfg.input_normalize == "intrinsic":
            # Map pixels to camera coordinates with K^-1.
            hom1 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1)
            hom2 = np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
            x1 = np.matmul(np.linalg.inv(K1), hom1.T).T[:, :2]
            x2 = np.matmul(np.linalg.inv(K2), hom2.T).T[:, :2]
        elif cfg.input_normalize == "img":
            # x = (kpt - size/2) / size, so kpt = S_inv @ [x, 1].
            # Presumably size = (width, height), matching kpt's (x, y) order -- confirm.
            x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2
            S1_inv = np.asarray(
                [
                    [size1[0], 0, 0.5 * size1[0]],
                    [0, size1[1], 0.5 * size1[1]],
                    [0, 0, 1],
                ]
            )
            S2_inv = np.asarray(
                [
                    [size2[0], 0, 0.5 * size2[0]],
                    [0, size2[1], 0.5 * size2[1]],
                    [0, 0, 1],
                ]
            )
            # Re-express E in the normalised image frame so that
            # x2^T E' x1 = 0 with E' = M2^T E M1 and M = K^-1 S_inv.
            M1 = np.matmul(np.linalg.inv(K1), S1_inv)
            M2 = np.matmul(np.linalg.inv(K2), S2_inv)
            egt = np.matmul(np.matmul(M2.transpose(), egt), M1)
            egt = egt / np.linalg.norm(egt)
        else:
            raise NotImplementedError

        # Drop annotations that refer to keypoints removed by truncation.
        # corr rows are (index in image 1, index in image 2).
        valid_corr = corr[corr.max(axis=-1) < num_kpt]
        valid_incorr1 = incorr1[incorr1 < num_kpt]
        valid_incorr2 = incorr2[incorr2 < num_kpt]
        num_corr = len(valid_corr)
        num_incorr1, num_incorr2 = len(valid_incorr1), len(valid_incorr2)

        # Keypoints that appear in neither annotation list form the pool used
        # to pad each image up to num_kpt rows.
        mask1_invalid = np.ones(x1.shape[0], dtype=bool)
        mask2_invalid = np.ones(x2.shape[0], dtype=bool)
        mask1_invalid[valid_corr[:, 0]] = False
        mask2_invalid[valid_corr[:, 1]] = False
        mask1_invalid[valid_incorr1] = False
        mask2_invalid[valid_incorr2] = False
        invalid_index1 = np.nonzero(mask1_invalid)[0]
        invalid_index2 = np.nonzero(mask2_invalid)[0]

        # Randomly sample from the points without a valid annotation.
        # Image 1 is sampled before image 2 (RNG call order is preserved).
        sub_idx1 = _sample_padding_indices(
            len(invalid_index1), num_kpt - num_corr - num_incorr1
        )
        sub_idx2 = _sample_padding_indices(
            len(invalid_index2), num_kpt - num_corr - num_incorr2
        )

        per_idx1 = np.concatenate(
            [valid_corr[:, 0], valid_incorr1, invalid_index1[sub_idx1]]
        )
        per_idx2 = np.concatenate(
            [valid_corr[:, 1], valid_incorr2, invalid_index2[sub_idx2]]
        )

        pscore1 = pscore1[per_idx1][:, np.newaxis]
        pscore2 = pscore2[per_idx2][:, np.newaxis]
        x1, x2 = x1[per_idx1][:, :2], x2[per_idx2][:, :2]
        desc1, desc2 = desc1[per_idx1], desc2[per_idx2]
        kpt1, kpt2 = kpt1[per_idx1], kpt2[per_idx2]

        return {
            "x1": x1,
            "x2": x2,
            "kpt1": kpt1,
            "kpt2": kpt2,
            "desc1": desc1,
            "desc2": desc2,
            "num_corr": num_corr,
            "num_incorr1": num_incorr1,
            "num_incorr2": num_incorr2,
            "e_gt": egt,
            "pscore1": pscore1,
            "pscore2": pscore2,
            "img_path1": img_path1,
            "img_path2": img_path2,
        }

    def __len__(self):
        """Return the total number of pairs listed in ``pair_num.txt``."""
        return self.total_pairs