Vincentqyw
update: features and matchers
437b5f6
raw
history blame
7.89 kB
import numpy as np
import torch
import torch.utils.data as data
import cv2
import os
import h5py
import random
import sys
# Put the parent of this file's directory at the front of sys.path so the
# project-local `utils` package below can be imported when this module is
# loaded from its own sub-directory.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
sys.path.insert(0, ROOT_DIR)
from utils import train_utils,evaluation_utils
# Tensors crossing process boundaries (e.g. DataLoader workers -> main process)
# are shared through the file system instead of file descriptors; presumably
# chosen to avoid "too many open files" errors with many workers — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')
class Offline_Dataset(data.Dataset):
    """Image-pair matching dataset backed by precomputed HDF5 files.

    Per-pair metadata (``dR``/``dt`` relative pose, ``K1``/``K2`` intrinsics,
    ``size1``/``size2``, ``img_path1``/``img_path2`` and the keypoint index
    annotations ``corr``/``incorr1``/``incorr2``) is read from
    ``<dataset_path>/<seq>/info.h5py``.  Keypoints and descriptors of each
    image are read from ``<desc_path>/<seq>/<image name><desc_suffix>``.

    Every sample is resampled to exactly ``config.num_kpt`` keypoints per
    image, ordered as: rows from ``corr``, rows from ``incorr1``/``incorr2``,
    then unannotated keypoints (sub-sampled, or repeated when too few).

    Config attributes read: ``dataset_path``, ``rawdata_path``, ``desc_path``,
    ``desc_suffix``, ``num_kpt``, ``input_normalize`` (``'intrinsic'`` or
    ``'img'``) and ``data_aug``.
    """

    def __init__(self, config, mode):
        """Load the pair index of one split.

        Args:
            config: configuration object (attributes listed on the class).
            mode: ``'train'`` or ``'valid'``; selects the metadata sub-directory.
        """
        assert mode == 'train' or mode == 'valid'
        self.config = config
        self.mode = mode
        metadir = os.path.join(config.dataset_path, 'valid' if mode == 'valid' else 'train')
        pair_num_list = np.loadtxt(os.path.join(metadir, 'pair_num.txt'), dtype=str)
        # NOTE(review): assumes the first row of pair_num.txt carries the total
        # pair count in column 1 — confirm against the file generator.
        self.total_pairs = int(pair_num_list[0, 1])
        # As used in __getitem__: pair_seq_list[index] is the sequence name of a
        # global pair index, and accu_pair_num[seq] is subtracted from that index
        # to obtain the index within the sequence.
        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(pair_num_list)

    @staticmethod
    def _warp_points(pts, homo_mat):
        """Apply homographies to 2-D points and dehomogenize.

        Args:
            pts: float tensor (B, N, 2).
            homo_mat: tensor matmul-broadcastable against (B, N, 3, 1),
                e.g. (B, 1, 3, 3).

        Returns:
            float tensor (B, N, 2) of warped points.
        """
        pts_h = torch.cat([pts, torch.ones_like(pts[..., :1])], dim=-1).unsqueeze(-1)
        warped = torch.matmul(homo_mat.float(), pts_h.float()).squeeze(-1)
        return warped[..., :2] / warped[..., 2:3]

    def collate_fn(self, batch):
        """Stack a list of ``__getitem__`` samples into one batch dict.

        Numeric entries become tensors (float32, or int32 for the counts);
        image paths stay lists of strings.  ``aug_x1``/``aug_x2`` are added:
        in train mode with ``config.data_aug`` one randomly chosen side is
        warped by a random homography per sample, otherwise they alias
        ``x1``/``x2``.
        """
        keys = ['x1', 'x2', 'kpt1', 'kpt2', 'desc1', 'desc2', 'num_corr', 'num_incorr1',
                'num_incorr2', 'e_gt', 'pscore1', 'pscore2', 'img_path1', 'img_path2']
        out = {key: [sample[key] for sample in batch] for key in keys}
        for key in ['x1', 'x2', 'kpt1', 'kpt2', 'desc1', 'desc2', 'e_gt', 'pscore1', 'pscore2']:
            out[key] = torch.from_numpy(np.stack(out[key])).float()
        for key in ['num_corr', 'num_incorr1', 'num_incorr2']:
            out[key] = torch.from_numpy(np.stack(out[key])).int()

        # Keypoint augmentation with a random homography (one per sample).
        if self.mode == 'train' and self.config.data_aug:
            # NOTE(review): presumably get_rnd_homography returns (B, 3, 3);
            # unsqueeze(1) lets it broadcast over the N points.
            homo_mat = torch.from_numpy(train_utils.get_rnd_homography(len(batch))).unsqueeze(1)
            if random.random() < 0.5:
                out['aug_x1'] = self._warp_points(out['x1'], homo_mat)
                out['aug_x2'] = out['x2']
            else:
                out['aug_x1'] = out['x1']
                out['aug_x2'] = self._warp_points(out['x2'], homo_mat)
        else:
            out['aug_x1'], out['aug_x2'] = out['x1'], out['x2']
        return out

    def _load_features(self, seq, img_name):
        """Read descriptors and keypoints of one image, truncated to num_kpt.

        Returns:
            (desc, kpt, pscore): descriptor rows, the first two keypoint
            columns, and the third keypoint column.
        """
        num_kpt = self.config.num_kpt
        fea_path = os.path.join(self.config.desc_path, seq, img_name + self.config.desc_suffix)
        with h5py.File(fea_path, 'r') as fea:
            desc = fea['descriptors'][()]
            keypoints = fea['keypoints'][()]  # read the dataset once, slice in memory
        return desc[:num_kpt], keypoints[:num_kpt, :2], keypoints[:num_kpt, 2]

    def _normalize_keypoints(self, kpt1, kpt2, K1, K2, size1, size2, egt):
        """Map pixel keypoints into the frame chosen by ``config.input_normalize``.

        ``'intrinsic'``: homogeneous keypoints are multiplied by K^-1; ``egt``
        is returned unchanged.
        ``'img'``: keypoints are centered on the image and divided by its size;
        ``egt`` is transformed into that frame and rescaled to unit norm.

        Raises:
            NotImplementedError: for any other ``input_normalize`` value.
        """
        mode = self.config.input_normalize
        if mode == 'intrinsic':
            def unproject(kpt, K):
                kpt_h = np.concatenate([kpt, np.ones([kpt.shape[0], 1])], axis=-1)
                return np.matmul(np.linalg.inv(K), kpt_h.T).T[:, :2]
            return unproject(kpt1, K1), unproject(kpt2, K2), egt
        if mode == 'img':
            def img_to_pixel(size):
                # Homogeneous inverse of x -> (x - size/2) / size.
                return np.asarray([[size[0], 0, 0.5 * size[0]],
                                   [0, size[1], 0.5 * size[1]],
                                   [0, 0, 1]])
            x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2
            # x_cam = K^-1 * S^-1 * x_img, hence E_img = M2^T E M1.
            M1 = np.matmul(np.linalg.inv(K1), img_to_pixel(size1))
            M2 = np.matmul(np.linalg.inv(K2), img_to_pixel(size2))
            egt = np.matmul(np.matmul(M2.transpose(), egt), M1)
            return x1, x2, egt / np.linalg.norm(egt)
        raise NotImplementedError

    @staticmethod
    def _unlabeled_indices(num_points, *labeled):
        """Return the indices in ``range(num_points)`` not listed in any of ``labeled``."""
        mask = np.ones(num_points).astype(bool)
        for idx in labeled:
            mask[idx] = False
        return np.nonzero(mask)[0]

    @staticmethod
    def _pad_or_subsample(pool, count):
        """Pick exactly ``count`` entries of ``pool``.

        With at least ``count`` entries, ``count`` distinct ones are drawn at
        random; otherwise all of ``pool`` is kept and topped up with entries
        drawn with replacement.

        Raises:
            ValueError: ``count`` > 0 but ``pool`` is empty (nothing to pad
                with), or ``count`` is negative.
        """
        n = len(pool)
        if n >= count:
            sub_idx = np.random.choice(n, count, replace=False)
        elif n == 0:
            raise ValueError(
                'need {} unannotated keypoints to fill the sample, '
                'but none are available'.format(count))
        else:
            sub_idx = np.concatenate([np.arange(n), np.random.randint(n, size=count - n)])
        return pool[sub_idx]

    def __getitem__(self, index):
        """Load, normalize and resample the image pair at global ``index``.

        Returns:
            dict with
              - ``x1``, ``x2``: normalized keypoints, shape (num_kpt, 2);
              - ``kpt1``, ``kpt2``: the same keypoints in pixel coordinates;
              - ``desc1``, ``desc2``: descriptors, one row per keypoint;
              - ``num_corr``, ``num_incorr1``, ``num_incorr2``: number of
                leading rows taken from ``corr`` / ``incorr1`` / ``incorr2``;
              - ``e_gt``: 3x3 matrix [t]_x R with unit Frobenius norm
                (re-expressed in the normalized frame for ``'img'``);
              - ``pscore1``, ``pscore2``: third keypoint column, (num_kpt, 1);
              - ``img_path1``, ``img_path2``: paths joined onto ``rawdata_path``.

        Raises:
            NotImplementedError: unsupported ``config.input_normalize``.
            ValueError: an image needs fill keypoints but has no unannotated ones.
        """
        num_kpt = self.config.num_kpt
        seq = self.pair_seq_list[index]
        key = str(index - self.accu_pair_num[seq])  # index within the sequence
        with h5py.File(os.path.join(self.config.dataset_path, seq, 'info.h5py'), 'r') as info:
            R, t = info['dR'][key][()], info['dt'][key][()]
            K1, K2 = info['K1'][key][()], info['K2'][key][()]
            size1, size2 = info['size1'][key][()], info['size2'][key][()]
            rel_path1 = info['img_path1'][key][()][0].decode()
            rel_path2 = info['img_path2'][key][()][0].decode()
            corr = info['corr'][key][()]
            incorr1, incorr2 = info['incorr1'][key][()], info['incorr2'][key][()]

        # Ground truth E = [t]_x R, scaled to unit Frobenius norm.
        skew_t = np.reshape(evaluation_utils.np_skew_symmetric(t.astype('float64').reshape(1, 3)), (3, 3))
        egt = np.matmul(skew_t, np.reshape(R.astype('float64'), (3, 3)))
        egt = egt / np.linalg.norm(egt)

        img_path1 = os.path.join(self.config.rawdata_path, rel_path1)
        img_path2 = os.path.join(self.config.rawdata_path, rel_path2)
        desc1, kpt1, pscore1 = self._load_features(seq, rel_path1.split('/')[-1])
        desc2, kpt2, pscore2 = self._load_features(seq, rel_path2.split('/')[-1])

        x1, x2, egt = self._normalize_keypoints(kpt1, kpt2, K1, K2, size1, size2, egt)

        # Keep only annotations that refer to keypoints surviving the num_kpt cut.
        valid_corr = corr[corr.max(axis=-1) < num_kpt]
        valid_incorr1 = incorr1[incorr1 < num_kpt]
        valid_incorr2 = incorr2[incorr2 < num_kpt]
        num_corr, num_incorr1, num_incorr2 = len(valid_corr), len(valid_incorr1), len(valid_incorr2)

        # Fill the remaining slots with keypoints that carry no annotation.
        unlabeled1 = self._unlabeled_indices(x1.shape[0], valid_corr[:, 0], valid_incorr1)
        unlabeled2 = self._unlabeled_indices(x2.shape[0], valid_corr[:, 1], valid_incorr2)
        fill1 = self._pad_or_subsample(unlabeled1, num_kpt - num_corr - num_incorr1)
        fill2 = self._pad_or_subsample(unlabeled2, num_kpt - num_corr - num_incorr2)

        per_idx1 = np.concatenate([valid_corr[:, 0], valid_incorr1, fill1])
        per_idx2 = np.concatenate([valid_corr[:, 1], valid_incorr2, fill2])

        return {
            'x1': x1[per_idx1], 'x2': x2[per_idx2],
            'kpt1': kpt1[per_idx1], 'kpt2': kpt2[per_idx2],
            'desc1': desc1[per_idx1], 'desc2': desc2[per_idx2],
            'num_corr': num_corr, 'num_incorr1': num_incorr1, 'num_incorr2': num_incorr2,
            'e_gt': egt,
            'pscore1': pscore1[per_idx1][:, np.newaxis],
            'pscore2': pscore2[per_idx2][:, np.newaxis],
            'img_path1': img_path1, 'img_path2': img_path2,
        }

    def __len__(self):
        """Number of image pairs in this split."""
        return self.total_pairs