Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| import torch.utils.data as data | |
| import cv2 | |
| import os | |
| import h5py | |
| import random | |
| import sys | |
| ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) | |
| sys.path.insert(0, ROOT_DIR) | |
| from utils import train_utils,evaluation_utils | |
| torch.multiprocessing.set_sharing_strategy('file_system') | |
class Offline_Dataset(data.Dataset):
    """Dataset of pre-extracted keypoint/descriptor pairs with match labels.

    Per-pair metadata is read from ``<dataset_path>/<seq>/info.h5py`` and the
    per-image features from ``<desc_path>/<seq>/<image name><desc_suffix>``.
    Every sample re-orders the keypoints of each image as
    ``[corresponding | annotated non-matching | padding]`` so that a sample
    always holds ``config.num_kpt`` points per image.

    ``config`` must expose the attributes read below: ``dataset_path``,
    ``rawdata_path``, ``desc_path``, ``desc_suffix``, ``num_kpt``,
    ``input_normalize`` (``'intrinsic'`` or ``'img'``) and ``data_aug``.
    """

    def __init__(self, config, mode):
        """Load the pair index of the requested split.

        Args:
            config: configuration object (see class docstring).
            mode: ``'train'`` or ``'valid'``; selects the sub-directory of
                ``config.dataset_path`` that holds ``pair_num.txt``.
        """
        assert mode in ('train', 'valid'), "mode must be 'train' or 'valid'"
        self.config = config
        self.mode = mode
        # Build the directory from `mode` itself so that an invalid mode can
        # never silently fall back to the train split when asserts are
        # stripped (python -O).
        metadir = os.path.join(config.dataset_path, mode)
        pair_num_list = np.loadtxt(os.path.join(metadir, 'pair_num.txt'), dtype=str)
        # The second column of the first row holds the total number of pairs.
        self.total_pairs = int(pair_num_list[0, 1])
        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(pair_num_list)

    @staticmethod
    def _apply_homography(homo_mat, pts):
        """Warp 2-D points with per-batch homographies.

        Args:
            homo_mat: tensor that broadcasts against ``(B, N, 3, 1)`` in
                ``torch.matmul`` (the caller passes one 3x3 matrix per batch
                item with a singleton point axis).
            pts: ``(B, N, 2)`` tensor of point coordinates.

        Returns:
            ``(B, N, 2)`` tensor of warped, de-homogenised coordinates.
        """
        batch_size, num_pts = pts.shape[0], pts.shape[1]
        ones = torch.ones([batch_size, num_pts, 1])
        pts_h = torch.cat([pts, ones], dim=-1).unsqueeze(-1)
        pts_h = torch.matmul(homo_mat.float(), pts_h.float()).squeeze(-1)
        return pts_h[:, :, :2] / pts_h[:, :, 2].unsqueeze(-1)

    @staticmethod
    def _sample_padding_indices(pool_size, n_needed):
        """Pick ``n_needed`` positions out of a pool of ``pool_size`` items.

        If the pool is large enough the positions are drawn without
        replacement; otherwise every position is used once and the remainder
        is drawn with replacement.

        Raises:
            ValueError: if the pool is empty but padding is required.
        """
        if pool_size < n_needed:
            if pool_size == 0:
                # np.random.randint(0, ...) would fail here with an opaque
                # "high <= 0" message; report the real cause instead.
                raise ValueError(
                    'cannot pad keypoints: %d padding points are required but '
                    'no unannotated keypoint is available' % n_needed)
            repeats = np.random.randint(pool_size, size=n_needed - pool_size)
            return np.concatenate([np.arange(pool_size), repeats])
        return np.random.choice(pool_size, n_needed, replace=False)

    def collate_fn(self, batch):
        """Stack a list of samples into batched tensors.

        Array-valued entries become float tensors, the three counters become
        int tensors and the image paths stay Python lists. In train mode with
        ``config.data_aug`` enabled, a random homography is applied to the
        normalised coordinates of one randomly chosen image of each pair; the
        result is stored under ``'aug_x1'`` / ``'aug_x2'``. Otherwise these
        two entries alias ``'x1'`` / ``'x2'``.
        """
        batch_size = len(batch)
        float_keys = ['x1', 'x2', 'kpt1', 'kpt2', 'desc1', 'desc2', 'e_gt', 'pscore1', 'pscore2']
        int_keys = ['num_corr', 'num_incorr1', 'num_incorr2']
        all_keys = ['x1', 'x2', 'kpt1', 'kpt2', 'desc1', 'desc2',
                    'num_corr', 'num_incorr1', 'num_incorr2', 'e_gt',
                    'pscore1', 'pscore2', 'img_path1', 'img_path2']

        collated = {key: [sample[key] for sample in batch] for key in all_keys}
        for key in float_keys:
            collated[key] = torch.from_numpy(np.stack(collated[key])).float()
        for key in int_keys:
            collated[key] = torch.from_numpy(np.stack(collated[key])).int()

        # kpt augmentation with random homography
        if self.mode == 'train' and self.config.data_aug:
            homo_mat = torch.from_numpy(train_utils.get_rnd_homography(batch_size)).unsqueeze(1)
            # Warp only one side of the pair, chosen with equal probability.
            if random.random() < 0.5:
                collated['aug_x1'] = self._apply_homography(homo_mat, collated['x1'])
                collated['aug_x2'] = collated['x2']
            else:
                collated['aug_x2'] = self._apply_homography(homo_mat, collated['x2'])
                collated['aug_x1'] = collated['x1']
        else:
            collated['aug_x1'], collated['aug_x2'] = collated['x1'], collated['x2']
        return collated

    def __getitem__(self, index):
        """Build the training sample for pair ``index``.

        Returns:
            dict with normalised coordinates (``x1``/``x2``), raw keypoints
            (``kpt1``/``kpt2``), descriptors, keypoint scores, the label
            counters ``num_corr``/``num_incorr1``/``num_incorr2``, the
            Frobenius-normalised matrix ``e_gt`` and both image paths.

        Raises:
            NotImplementedError: for an unknown ``config.input_normalize``.
            ValueError: if padding is required but an image has no
                unannotated keypoint left to pad with.
        """
        cfg = self.config
        num_kpt = cfg.num_kpt
        seq = self.pair_seq_list[index]
        index_within_seq = index - self.accu_pair_num[seq]
        pair_key = str(index_within_seq)

        # Read everything needed from the pair metadata in one go.
        with h5py.File(os.path.join(cfg.dataset_path, seq, 'info.h5py'), 'r') as info:
            R, t = info['dR'][pair_key][()], info['dt'][pair_key][()]
            K1, K2 = info['K1'][pair_key][()], info['K2'][pair_key][()]
            size1, size2 = info['size1'][pair_key][()], info['size2'][pair_key][()]
            img_path1 = info['img_path1'][pair_key][()][0].decode()
            img_path2 = info['img_path2'][pair_key][()][0].decode()
            corr = info['corr'][pair_key][()]
            incorr1, incorr2 = info['incorr1'][pair_key][()], info['incorr2'][pair_key][()]

        # Ground-truth matrix [t]_x R, scaled to unit Frobenius norm.
        t_skew = np.reshape(
            evaluation_utils.np_skew_symmetric(t.astype('float64').reshape(1, 3)), (3, 3))
        egt = np.matmul(t_skew, np.reshape(R.astype('float64'), (3, 3)))
        egt = egt / np.linalg.norm(egt)

        img_name1, img_name2 = img_path1.split('/')[-1], img_path2.split('/')[-1]
        img_path1 = os.path.join(cfg.rawdata_path, img_path1)
        img_path2 = os.path.join(cfg.rawdata_path, img_path2)
        fea_path1 = os.path.join(cfg.desc_path, seq, img_name1 + cfg.desc_suffix)
        fea_path2 = os.path.join(cfg.desc_path, seq, img_name2 + cfg.desc_suffix)

        # Read each keypoint array once; columns 0-1 are used as coordinates
        # and column 2 as the keypoint score.
        with h5py.File(fea_path1, 'r') as fea1, h5py.File(fea_path2, 'r') as fea2:
            desc1, raw_kpt1 = fea1['descriptors'][()], fea1['keypoints'][()]
            desc2, raw_kpt2 = fea2['descriptors'][()], fea2['keypoints'][()]
        kpt1, pscore1 = raw_kpt1[:, :2], raw_kpt1[:, 2]
        kpt2, pscore2 = raw_kpt2[:, :2], raw_kpt2[:, 2]
        kpt1, kpt2 = kpt1[:num_kpt], kpt2[:num_kpt]
        desc1, desc2 = desc1[:num_kpt], desc2[:num_kpt]

        # normalize kpt
        if cfg.input_normalize == 'intrinsic':
            x1_h = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1)
            x2_h = np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
            x1 = np.matmul(np.linalg.inv(K1), x1_h.T).T[:, :2]
            x2 = np.matmul(np.linalg.inv(K2), x2_h.T).T[:, :2]
        elif cfg.input_normalize == 'img':
            x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2
            # Map image-normalised coordinates back to pixels, then through
            # K^-1, so that the ground-truth matrix matches the x1/x2 frame.
            S1_inv = np.asarray([[size1[0], 0, 0.5 * size1[0]],
                                 [0, size1[1], 0.5 * size1[1]],
                                 [0, 0, 1]])
            S2_inv = np.asarray([[size2[0], 0, 0.5 * size2[0]],
                                 [0, size2[1], 0.5 * size2[1]],
                                 [0, 0, 1]])
            M1 = np.matmul(np.linalg.inv(K1), S1_inv)
            M2 = np.matmul(np.linalg.inv(K2), S2_inv)
            egt = np.matmul(np.matmul(M2.transpose(), egt), M1)
            egt = egt / np.linalg.norm(egt)
        else:
            raise NotImplementedError

        # permute kpt: keep only annotations that refer to retained keypoints
        valid_corr = corr[corr.max(axis=-1) < num_kpt]
        valid_incorr1 = incorr1[incorr1 < num_kpt]
        valid_incorr2 = incorr2[incorr2 < num_kpt]
        num_corr, num_incorr1, num_incorr2 = len(valid_corr), len(valid_incorr1), len(valid_incorr2)

        unlabeled_mask1 = np.ones(x1.shape[0]).astype(bool)
        unlabeled_mask2 = np.ones(x2.shape[0]).astype(bool)
        unlabeled_mask1[valid_corr[:, 0]] = False
        unlabeled_mask2[valid_corr[:, 1]] = False
        unlabeled_mask1[valid_incorr1] = False
        unlabeled_mask2[valid_incorr2] = False
        unlabeled_idx1 = np.nonzero(unlabeled_mask1)[0]
        unlabeled_idx2 = np.nonzero(unlabeled_mask2)[0]

        # random sample from point w/o valid annotation
        pad_num1 = num_kpt - num_corr - num_incorr1
        pad_num2 = num_kpt - num_corr - num_incorr2
        sub_idx1 = self._sample_padding_indices(len(unlabeled_idx1), pad_num1)
        sub_idx2 = self._sample_padding_indices(len(unlabeled_idx2), pad_num2)

        per_idx1 = np.concatenate([valid_corr[:, 0], valid_incorr1, unlabeled_idx1[sub_idx1]])
        per_idx2 = np.concatenate([valid_corr[:, 1], valid_incorr2, unlabeled_idx2[sub_idx2]])

        pscore1, pscore2 = pscore1[per_idx1][:, np.newaxis], pscore2[per_idx2][:, np.newaxis]
        x1, x2 = x1[per_idx1][:, :2], x2[per_idx2][:, :2]
        desc1, desc2 = desc1[per_idx1], desc2[per_idx2]
        kpt1, kpt2 = kpt1[per_idx1], kpt2[per_idx2]

        return {
            'x1': x1, 'x2': x2,
            'kpt1': kpt1, 'kpt2': kpt2,
            'desc1': desc1, 'desc2': desc2,
            'num_corr': num_corr, 'num_incorr1': num_incorr1, 'num_incorr2': num_incorr2,
            'e_gt': egt,
            'pscore1': pscore1, 'pscore2': pscore2,
            'img_path1': img_path1, 'img_path2': img_path2,
        }

    def __len__(self):
        """Return the total number of pairs in the split."""
        return self.total_pairs