import os
import glob
import math
import re
import sys

import numpy as np
import h5py
from tqdm import trange
from torch.multiprocessing import Pool
import pyxis as px

from .base_dumper import BaseDumper

# Make the repository root importable so that the project-level ``utils``
# package resolves when this module is loaded from inside the dumper package.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
sys.path.insert(0, ROOT_DIR)
from utils import transformations, data_utils


class gl3d_train(BaseDumper):
    """Dump GL3D train/valid image pairs and their geometric supervision.

    Reads the GL3D layout under ``config['rawdata_dir']`` (sequence lists in
    ``list/comb``, per-sequence data in ``data/<seq>``), writes per-image
    feature files under ``config['feature_dump_dir']`` and per-sequence pair
    labels under ``config['dataset_dump_dir']``.

    NOTE(review): the file this was recovered from was corrupted by HTML-tag
    stripping (every ``<``...``>`` span deleted), which destroyed the body of
    ``load_geom`` and the first half of ``format_seq``.  Those parts below are
    a best-effort reconstruction from the surviving fragments and MUST be
    verified against the upstream project before being trusted.
    """

    def get_seqs(self):
        """Populate ``self.seq_list`` / ``train_list`` / ``valid_list`` and
        extend ``self.img_seq`` / ``self.dump_seq`` with image and feature
        paths for every non-excluded sequence.

        Raises ValueError (from ``np.concatenate`` on an empty list) if both
        ``dump_train`` and ``dump_valid`` are disabled in the config.
        """
        data_dir = os.path.join(self.config['rawdata_dir'], 'data')
        seq_train = np.loadtxt(os.path.join(self.config['rawdata_dir'], 'list', 'comb', 'imageset_train.txt'), dtype=str)
        seq_valid = np.loadtxt(os.path.join(self.config['rawdata_dir'], 'list', 'comb', 'imageset_test.txt'), dtype=str)

        # Filter out explicitly excluded sequences.
        self.seq_list, self.train_list, self.valid_list = [], [], []
        for seq in seq_train:
            if seq not in self.config['exclude_seq']:
                self.train_list.append(seq)
        for seq in seq_valid:
            if seq not in self.config['exclude_seq']:
                self.valid_list.append(seq)

        seq_list = []
        if self.config['dump_train']:
            seq_list.append(self.train_list)
        if self.config['dump_valid']:
            seq_list.append(self.valid_list)
        self.seq_list = np.concatenate(seq_list, axis=0)

        for seq in self.seq_list:
            dump_dir = os.path.join(self.config['feature_dump_dir'], seq)
            cur_img_seq = glob.glob(os.path.join(data_dir, seq, 'undist_images', '*.jpg'))
            # os.path.basename instead of path.split('/')[-1]: identical on
            # POSIX, and also correct on platforms with other separators.
            cur_dump_seq = [
                os.path.join(dump_dir, os.path.basename(path)) + '_'
                + self.config['extractor']['name'] + '_'
                + str(self.config['extractor']['num_kpt']) + '.hdf5'
                for path in cur_img_seq
            ]
            self.img_seq += cur_img_seq
            self.dump_seq += cur_dump_seq

    def format_dump_folder(self):
        """Create the output directory layout for feature and dataset dumps.

        ``exist_ok=True`` makes the calls idempotent and race-free, which
        matters because dumping runs under a multiprocessing Pool.
        """
        os.makedirs(self.config['feature_dump_dir'], exist_ok=True)
        for seq in self.seq_list:
            os.makedirs(os.path.join(self.config['feature_dump_dir'], seq), exist_ok=True)
        os.makedirs(self.config['dataset_dump_dir'], exist_ok=True)

    def load_geom(self, seq):
        """Load per-image camera geometry for *seq*.

        Returns a list aligned with ``basenames.txt``; entries for images
        that have no calibrated camera row in ``geolabel/cameras.txt`` are
        ``None``.

        NOTE(review): everything past the ``for`` header was destroyed in the
        corrupted source.  The loop body below is reconstructed (row layout of
        cameras.txt is assumed: [img_index, fx, fy, cx, cy, qw, qx, qy, qz,
        tx, ty, tz, w, h] -- TODO confirm against the GL3D format docs).
        """
        # load geometry file
        geom_file = os.path.join(self.config['rawdata_dir'], 'data', seq, 'geolabel', 'cameras.txt')
        basename_list = np.loadtxt(os.path.join(self.config['rawdata_dir'], 'data', seq, 'basenames.txt'), dtype=str)
        geom_dict = []
        cameras = np.loadtxt(geom_file)
        camera_index = 0
        for base_index in range(len(basename_list)):
            # Images without a camera entry get a None placeholder so the
            # list stays index-aligned with basename_list.
            if base_index < cameras[camera_index][0]:
                geom_dict.append(None)
                continue
            cur_cam = cameras[camera_index]
            camera_index += 1
            K = np.asarray([[cur_cam[1], 0.0, cur_cam[3]],
                            [0.0, cur_cam[2], cur_cam[4]],
                            [0.0, 0.0, 1.0]])
            # presumably a world-to-camera rotation stored as a quaternion;
            # verify both the slice offsets and the convention upstream.
            R = transformations.quaternion_matrix(cur_cam[5:9])[:3, :3]
            t = cur_cam[9:12]
            geom_dict.append({'K': K, 'R': R, 't': t, 'size': cur_cam[12:14]})
        return geom_dict

    def format_seq(self, index):
        """Select valid image pairs for ``self.seq_list[index]`` and dump
        their labels via ``self.dump_info``.

        NOTE(review): the first half of this method was destroyed in the
        corrupted source.  The pair-candidate loading, the two window masks
        and the relative-pose/correspondence computation below are
        reconstructed from the surviving fragments
        (``...angle_th[0],angle_list...``, ``...overlap_th[0],overlap_score...``,
        and the intact acceptance test / append tail) -- verify against the
        upstream project.  ``dump_info`` is assumed to be provided by
        BaseDumper or by a method lost to the corruption.
        """
        seq = self.seq_list[index]
        raw_dir = os.path.join(self.config['rawdata_dir'], 'data', seq)
        os.makedirs(os.path.join(self.config['dataset_dump_dir'], seq), exist_ok=True)

        basename_list = np.loadtxt(os.path.join(raw_dir, 'basenames.txt'), dtype=str)
        geom_dict = self.load_geom(seq)

        # Candidate pairs with overlap score and relative viewing angle.
        # NOTE(review): geolabel file names reconstructed -- TODO confirm.
        pair_data = np.loadtxt(os.path.join(raw_dir, 'geolabel', 'common_track.txt'), dtype=float)
        pair_list = pair_data[:, :2].astype(int)
        overlap_score = pair_data[:, 2]
        angle_list = np.loadtxt(os.path.join(raw_dir, 'geolabel', 'angles.txt'), dtype=float)

        # Keep pairs whose viewing angle and overlap both fall inside the
        # configured (low, high) windows.
        angle_mask = np.logical_and(angle_list > self.config['angle_th'][0],
                                    angle_list < self.config['angle_th'][1])
        overlap_mask = np.logical_and(overlap_score > self.config['overlap_th'][0],
                                      overlap_score < self.config['overlap_th'][1])
        candidate_indices = np.nonzero(np.logical_and(angle_mask, overlap_mask))[0]

        info = {key: [] for key in ['corr', 'incorr1', 'incorr2', 'dR', 'dt',
                                    'K1', 'K2', 'img_path1', 'img_path2',
                                    'fea_path1', 'fea_path2', 'size1', 'size2']}
        # NOTE(review): the config key for the per-sequence sample cap was
        # lost in the corruption -- 'sample_target' assumed; confirm upstream.
        sample_target = self.config['sample_target']
        sample_number = 0
        fea_suffix = '_' + self.config['extractor']['name'] + '_' \
                     + str(self.config['extractor']['num_kpt']) + '.hdf5'

        for cand in candidate_indices:
            index1, index2 = pair_list[cand]
            geom1, geom2 = geom_dict[index1], geom_dict[index2]
            if geom1 is None or geom2 is None:
                continue
            K1, K2 = geom1['K'], geom2['K']
            size1, size2 = geom1['size'], geom2['size']
            # Relative pose taking camera-1 coordinates to camera-2.
            dR = np.dot(geom2['R'], geom1['R'].T)
            dt = geom2['t'] - np.dot(dR, geom1['t'])

            img_path1 = os.path.join(raw_dir, 'undist_images', basename_list[index1] + '.jpg')
            img_path2 = os.path.join(raw_dir, 'undist_images', basename_list[index2] + '.jpg')
            fea_path1 = os.path.join(self.config['feature_dump_dir'], seq,
                                     basename_list[index1] + '.jpg' + fea_suffix)
            fea_path2 = os.path.join(self.config['feature_dump_dir'], seq,
                                     basename_list[index2] + '.jpg' + fea_suffix)

            # Ground-truth correspondence / non-correspondence indices between
            # the dumped keypoints.  NOTE(review): helper name and signature
            # reconstructed -- verify against utils.data_utils.
            corr_index, incorr_index1, incorr_index2 = data_utils.make_corr(
                fea_path1, fea_path2, dR, dt, K1, K2,
                self.config['corr_th'], self.config['incorr_th'])

            # Surviving acceptance test from the corrupted source: a pair is
            # kept only when it provides enough correspondences AND enough
            # unmatchable keypoints on both sides.
            if len(corr_index) > self.config['min_corr'] \
                    and len(incorr_index1) > self.config['min_incorr'] \
                    and len(incorr_index2) > self.config['min_incorr']:
                info['corr'].append(corr_index)
                info['incorr1'].append(incorr_index1)
                info['incorr2'].append(incorr_index2)
                info['dR'].append(dR)
                info['dt'].append(dt)
                info['K1'].append(K1)
                info['K2'].append(K2)
                info['img_path1'].append(img_path1)
                info['img_path2'].append(img_path2)
                info['fea_path1'].append(fea_path1)
                info['fea_path2'].append(fea_path2)
                info['size1'].append(size1)
                info['size2'].append(size2)
                sample_number += 1
                if sample_number == sample_target:
                    break

        info['pair_num'] = sample_number
        # dump info
        self.dump_info(seq, info)

    def collect_meta(self):
        """Aggregate the per-sequence pair counts into one ``pair_num.txt``
        per split (train/valid), with a leading "total" row."""
        print('collecting meta info...')
        dump_path, seq_list = [], []
        if self.config['dump_train']:
            dump_path.append(os.path.join(self.config['dataset_dump_dir'], 'train'))
            seq_list.append(self.train_list)
        if self.config['dump_valid']:
            dump_path.append(os.path.join(self.config['dataset_dump_dir'], 'valid'))
            seq_list.append(self.valid_list)
        for pth, seqs in zip(dump_path, seq_list):
            os.makedirs(pth, exist_ok=True)
            pair_num_list, total_pair = [], 0
            for seq in seqs:
                pair_num = np.loadtxt(os.path.join(self.config['dataset_dump_dir'], seq, 'pair_num.txt'), dtype=int)
                pair_num_list.append(str(pair_num))
                total_pair += pair_num
            pair_num_list = np.stack([np.asarray(seqs, dtype=str),
                                      np.asarray(pair_num_list, dtype=str)], axis=1)
            # Prepend the overall count so readers get it without summing.
            pair_num_list = np.concatenate([np.asarray([['total', str(total_pair)]]), pair_num_list], axis=0)
            np.savetxt(os.path.join(pth, 'pair_num.txt'), pair_num_list, fmt='%s')

    def format_dump_data(self):
        """Dump pair labels for every sequence with a process pool, then
        collect the split-level meta files.

        Sequences are processed in batches of ``num_process`` so the tqdm bar
        advances once per batch.
        """
        print('Formatting data...')
        iteration_num = math.ceil(len(self.seq_list) / self.config['num_process'])
        pool = Pool(self.config['num_process'])
        for index in trange(iteration_num):
            indices = range(index * self.config['num_process'],
                            min((index + 1) * self.config['num_process'], len(self.seq_list)))
            pool.map(self.format_seq, indices)
        pool.close()
        pool.join()
        self.collect_meta()