# Offline pair dataset: precomputed keypoints/descriptors + ground-truth geometry.
import numpy as np | |
import torch | |
import torch.utils.data as data | |
import cv2 | |
import os | |
import h5py | |
import random | |
import sys | |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) | |
sys.path.insert(0, ROOT_DIR) | |
from utils import train_utils,evaluation_utils | |
torch.multiprocessing.set_sharing_strategy('file_system') | |
class Offline_Dataset(data.Dataset):
    """Image-pair dataset backed by precomputed features and geometry.

    Metadata (relative pose, intrinsics, correspondence labels) lives in a
    per-sequence ``info.h5py``; keypoints/descriptors live in per-image h5py
    feature files. Each item is one pair with its ground-truth essential
    matrix and correspondence annotations.
    """

    def __init__(self, config, mode):
        """
        Args:
            config: namespace providing dataset_path, rawdata_path, desc_path,
                desc_suffix, num_kpt, input_normalize and data_aug.
            mode: 'train' or 'valid'; selects the metadata sub-directory.
        """
        assert mode == 'train' or mode == 'valid'
        self.config = config
        self.mode = mode
        metadir = os.path.join(config.dataset_path, 'valid') if mode == 'valid' \
            else os.path.join(config.dataset_path, 'train')
        # pair_num.txt: first row holds the total pair count; the remaining
        # rows hold per-sequence pair counts (parsed by parse_pair_seq).
        pair_num_list = np.loadtxt(os.path.join(metadir, 'pair_num.txt'), dtype=str)
        self.total_pairs = int(pair_num_list[0, 1])
        # pair_seq_list maps a global pair index -> sequence name;
        # accu_pair_num gives each sequence's starting index offset.
        self.pair_seq_list, self.accu_pair_num = train_utils.parse_pair_seq(pair_num_list)

    def collate_fn(self, batch):
        """Stack per-pair samples into batched tensors.

        Numeric fields become float32/int32 tensors; image paths stay lists.
        In training mode with ``config.data_aug`` set, one randomly chosen
        side of the pair has its normalized keypoints warped by a random
        homography, exposed as 'aug_x1'/'aug_x2' (the other side is passed
        through unchanged).
        """
        batch_size, num_pts = len(batch), batch[0]['x1'].shape[0]
        keys = ['x1', 'x2', 'kpt1', 'kpt2', 'desc1', 'desc2', 'num_corr',
                'num_incorr1', 'num_incorr2', 'e_gt', 'pscore1', 'pscore2',
                'img_path1', 'img_path2']
        data = {key: [sample[key] for sample in batch] for key in keys}
        for key in ['x1', 'x2', 'kpt1', 'kpt2', 'desc1', 'desc2', 'e_gt',
                    'pscore1', 'pscore2']:
            data[key] = torch.from_numpy(np.stack(data[key])).float()
        for key in ['num_corr', 'num_incorr1', 'num_incorr2']:
            data[key] = torch.from_numpy(np.stack(data[key])).int()
        # kpt augmentation with random homography (training only)
        if (self.mode == 'train' and self.config.data_aug):
            homo_mat = torch.from_numpy(train_utils.get_rnd_homography(batch_size)).unsqueeze(1)
            if random.random() < 0.5:
                # warp side 1, keep side 2
                x1_homo = torch.cat([data['x1'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1)
                x1_homo = torch.matmul(homo_mat.float(), x1_homo.float()).squeeze(-1)
                data['aug_x1'] = x1_homo[:, :, :2] / x1_homo[:, :, 2].unsqueeze(-1)
                data['aug_x2'] = data['x2']
            else:
                # warp side 2, keep side 1
                x2_homo = torch.cat([data['x2'], torch.ones([batch_size, num_pts, 1])], dim=-1).unsqueeze(-1)
                x2_homo = torch.matmul(homo_mat.float(), x2_homo.float()).squeeze(-1)
                data['aug_x2'] = x2_homo[:, :, :2] / x2_homo[:, :, 2].unsqueeze(-1)
                data['aug_x1'] = data['x1']
        else:
            data['aug_x1'], data['aug_x2'] = data['x1'], data['x2']
        return data

    def __getitem__(self, index):
        """Load one pair.

        Returns a dict with normalized keypoints (x1/x2), raw keypoints,
        descriptors, keypoint scores, the ground-truth essential matrix and
        correspondence counts. Rows are permuted so each side is ordered
        [correspondences, side-specific non-matches, random padding].
        """
        seq = self.pair_seq_list[index]
        index_within_seq = index - self.accu_pair_num[seq]
        # NOTE: the h5py handle is named `info` (not `data`) so it does not
        # shadow the module-level `torch.utils.data as data` alias.
        with h5py.File(os.path.join(self.config.dataset_path, seq, 'info.h5py'), 'r') as info:
            # relative pose -> ground-truth essential matrix E = [t]_x R
            R, t = info['dR'][str(index_within_seq)][()], info['dt'][str(index_within_seq)][()]
            egt = np.reshape(np.matmul(
                np.reshape(evaluation_utils.np_skew_symmetric(t.astype('float64').reshape(1, 3)), (3, 3)),
                np.reshape(R.astype('float64'), (3, 3))), (3, 3))
            egt = egt / np.linalg.norm(egt)
            K1, K2 = info['K1'][str(index_within_seq)][()], info['K2'][str(index_within_seq)][()]
            size1, size2 = info['size1'][str(index_within_seq)][()], info['size2'][str(index_within_seq)][()]
            img_path1 = info['img_path1'][str(index_within_seq)][()][0].decode()
            img_path2 = info['img_path2'][str(index_within_seq)][()][0].decode()
            img_name1, img_name2 = img_path1.split('/')[-1], img_path2.split('/')[-1]
            img_path1 = os.path.join(self.config.rawdata_path, img_path1)
            img_path2 = os.path.join(self.config.rawdata_path, img_path2)
            fea_path1 = os.path.join(self.config.desc_path, seq, img_name1 + self.config.desc_suffix)
            fea_path2 = os.path.join(self.config.desc_path, seq, img_name2 + self.config.desc_suffix)
            with h5py.File(fea_path1, 'r') as fea1, h5py.File(fea_path2, 'r') as fea2:
                # feature files: keypoints are (N, 3) rows of x, y, score
                desc1, kpt1, pscore1 = fea1['descriptors'][()], fea1['keypoints'][()][:, :2], fea1['keypoints'][()][:, 2]
                desc2, kpt2, pscore2 = fea2['descriptors'][()], fea2['keypoints'][()][:, :2], fea2['keypoints'][()][:, 2]
            # keep only the top-num_kpt detections per image
            kpt1, kpt2 = kpt1[:self.config.num_kpt], kpt2[:self.config.num_kpt]
            desc1, desc2 = desc1[:self.config.num_kpt], desc2[:self.config.num_kpt]
            # normalize kpt
            if self.config.input_normalize == 'intrinsic':
                # back-project pixel coordinates with the camera intrinsics
                x1 = np.concatenate([kpt1, np.ones([kpt1.shape[0], 1])], axis=-1)
                x2 = np.concatenate([kpt2, np.ones([kpt2.shape[0], 1])], axis=-1)
                x1, x2 = np.matmul(np.linalg.inv(K1), x1.T).T[:, :2], np.matmul(np.linalg.inv(K2), x2.T).T[:, :2]
            elif self.config.input_normalize == 'img':
                # normalize by image size and fold the frame change into E
                x1, x2 = (kpt1 - size1 / 2) / size1, (kpt2 - size2 / 2) / size2
                S1_inv = np.asarray([[size1[0], 0, 0.5 * size1[0]],
                                     [0, size1[1], 0.5 * size1[1]],
                                     [0, 0, 1]])
                S2_inv = np.asarray([[size2[0], 0, 0.5 * size2[0]],
                                     [0, size2[1], 0.5 * size2[1]],
                                     [0, 0, 1]])
                M1, M2 = np.matmul(np.linalg.inv(K1), S1_inv), np.matmul(np.linalg.inv(K2), S2_inv)
                egt = np.matmul(np.matmul(M2.transpose(), egt), M1)
                egt = egt / np.linalg.norm(egt)
            else:
                raise NotImplementedError
            corr = info['corr'][str(index_within_seq)][()]
            incorr1 = info['incorr1'][str(index_within_seq)][()]
            incorr2 = info['incorr2'][str(index_within_seq)][()]
        # permute kpt: keep only annotations that survive the top-num_kpt cut
        valid_corr = corr[corr.max(axis=-1) < self.config.num_kpt]
        valid_incorr1, valid_incorr2 = incorr1[incorr1 < self.config.num_kpt], incorr2[incorr2 < self.config.num_kpt]
        num_corr, num_incorr1, num_incorr2 = len(valid_corr), len(valid_incorr1), len(valid_incorr2)
        # mark keypoints that carry no valid annotation at all
        mask1_invalid, mask2_invalid = np.ones(x1.shape[0]).astype(bool), np.ones(x2.shape[0]).astype(bool)
        mask1_invalid[valid_corr[:, 0]] = False
        mask2_invalid[valid_corr[:, 1]] = False
        mask1_invalid[valid_incorr1] = False
        mask2_invalid[valid_incorr2] = False
        invalid_index1, invalid_index2 = np.nonzero(mask1_invalid)[0], np.nonzero(mask2_invalid)[0]
        # random sample from points w/o valid annotation to pad each side
        # back up to exactly num_kpt rows
        cur_kpt1 = self.config.num_kpt - num_corr - num_incorr1
        cur_kpt2 = self.config.num_kpt - num_corr - num_incorr2
        if invalid_index1.shape[0] < cur_kpt1:
            # not enough unannotated points: take all, then reuse some
            sub_idx1 = np.concatenate([np.arange(len(invalid_index1)),
                                       np.random.randint(len(invalid_index1), size=cur_kpt1 - len(invalid_index1))])
        else:
            sub_idx1 = np.random.choice(len(invalid_index1), cur_kpt1, replace=False)
        if invalid_index2.shape[0] < cur_kpt2:
            sub_idx2 = np.concatenate([np.arange(len(invalid_index2)),
                                       np.random.randint(len(invalid_index2), size=cur_kpt2 - len(invalid_index2))])
        else:
            sub_idx2 = np.random.choice(len(invalid_index2), cur_kpt2, replace=False)
        per_idx1 = np.concatenate([valid_corr[:, 0], valid_incorr1, invalid_index1[sub_idx1]])
        per_idx2 = np.concatenate([valid_corr[:, 1], valid_incorr2, invalid_index2[sub_idx2]])
        pscore1, pscore2 = pscore1[per_idx1][:, np.newaxis], pscore2[per_idx2][:, np.newaxis]
        x1, x2 = x1[per_idx1][:, :2], x2[per_idx2][:, :2]
        desc1, desc2 = desc1[per_idx1], desc2[per_idx2]
        kpt1, kpt2 = kpt1[per_idx1], kpt2[per_idx2]
        return {'x1': x1, 'x2': x2, 'kpt1': kpt1, 'kpt2': kpt2, 'desc1': desc1, 'desc2': desc2,
                'num_corr': num_corr, 'num_incorr1': num_incorr1, 'num_incorr2': num_incorr2, 'e_gt': egt,
                'pscore1': pscore1, 'pscore2': pscore2, 'img_path1': img_path1, 'img_path2': img_path2}

    def __len__(self):
        """Total number of pairs across all sequences."""
        return self.total_pairs