|
from abc import ABCMeta, abstractmethod |
|
import os |
|
import h5py |
|
import numpy as np |
|
from tqdm import trange |
|
from torch.multiprocessing import Pool,set_start_method |
|
# NOTE(review): this runs at import time and force-sets the process-wide
# multiprocessing start method to 'spawn' (force=True overrides any earlier
# choice), affecting every module that uses multiprocessing after this import.
# Presumably chosen because the feature extractor may use CUDA, which cannot be
# re-initialized in forked children — confirm.
set_start_method('spawn',force=True)

import sys

# Absolute path of the directory two levels above this file's directory
# (presumably the repository root). It is prepended to sys.path so the
# `components` import that follows resolves; keep it before that import.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))

sys.path.insert(0, ROOT_DIR)
|
from components import load_component |
|
|
|
|
|
class BaseDumper(metaclass=ABCMeta):
    """Abstract base for dumpers that extract local features and pack them into HDF5.

    Workflow visible in this class:

    * :meth:`initialize` builds ``self.extractor`` from the config, then calls the
      subclass hooks :meth:`get_seqs` and :meth:`format_dump_folder`.
    * :meth:`dump_feature` runs :meth:`extract` for every index of
      ``self.img_seq`` / ``self.dump_seq`` in a process pool; each call writes one
      per-image feature file via :meth:`write_feature`.
    * :meth:`form_standard_dataset` reads ``self.data`` and the per-image feature
      files and writes a single combined HDF5 dataset.

    ``self.data`` is never assigned here; presumably a subclass populates it
    (e.g. in :meth:`format_dump_data`) — confirm against the subclasses.
    """

    def __init__(self, config):
        """Store the configuration and start with empty image / dump path lists.

        Args:
            config: nested mapping. Keys read by this class: ``config['extractor']``
                (passed to ``load_component`` and read for ``name``, ``overwrite``,
                ``num_process`` and ``num_kpt``), ``config['dataset_dump_dir']`` and
                ``config['data_name']``.
        """
        self.config = config
        # Parallel lists: img_seq[i] is the image handed to the extractor and
        # dump_seq[i] is the HDF5 file its features are written to (see extract()).
        self.img_seq = []
        self.dump_seq = []

    @abstractmethod
    def get_seqs(self):
        """Subclass hook called by :meth:`initialize`.

        Presumably fills ``self.img_seq`` and ``self.dump_seq`` — confirm in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def format_dump_folder(self):
        """Subclass hook called by :meth:`initialize` after :meth:`get_seqs`."""
        raise NotImplementedError

    @abstractmethod
    def format_dump_data(self):
        """Subclass hook; not called by any method of this base class."""
        raise NotImplementedError

    def initialize(self):
        """Build the feature extractor, then run the path-collection and folder hooks."""
        self.extractor = load_component('extractor', self.config['extractor']['name'],
                                        self.config['extractor'])
        self.get_seqs()
        self.format_dump_folder()

    def extract(self, index):
        """Extract features for image ``index`` and write them to its dump path.

        Does nothing when the dump file already exists and
        ``config['extractor']['overwrite']`` is falsy. Called inside pool workers
        by :meth:`dump_feature`.
        """
        img_path, dump_path = self.img_seq[index], self.dump_seq[index]
        if not self.config['extractor']['overwrite'] and os.path.exists(dump_path):
            return
        kp, desc = self.extractor.run(img_path)
        self.write_feature(kp, desc, dump_path)

    def dump_feature(self):
        """Run :meth:`extract` for every image in a worker pool, with a progress bar.

        The pool size is ``config['extractor']['num_process']``. Sets
        ``self.num_img`` to the number of dump paths. An exception raised by any
        extraction propagates to the caller and the workers are terminated.
        """
        print('Extracting features...')
        self.num_img = len(self.dump_seq)
        num_process = self.config['extractor']['num_process']
        # All indices are queued at once rather than in lock-step batches of
        # num_process, so one slow image no longer leaves the other workers idle.
        # The context manager terminates the workers even when a task raises.
        with Pool(num_process) as pool:
            results = pool.imap_unordered(self.extract, range(self.num_img))
            for _ in trange(self.num_img):
                # Wait for the next finished task; re-raises a worker's exception.
                next(results)

    def write_feature(self, pts, desc, filename):
        """Write keypoints and descriptors to ``filename`` as float32 HDF5 datasets.

        The file is created (truncated if it exists) with datasets ``keypoints`` and
        ``descriptors`` shaped like ``pts`` and ``desc``.
        """
        with h5py.File(filename, "w") as ifp:
            ifp.create_dataset('keypoints', pts.shape, dtype=np.float32)
            ifp.create_dataset('descriptors', desc.shape, dtype=np.float32)
            ifp["keypoints"][:] = pts
            ifp["descriptors"][:] = desc

    @staticmethod
    def _read_feature(path):
        """Return ``(keypoints, descriptors)`` read from a feature file, closing it afterwards."""
        with h5py.File(path, 'r') as fea_file:
            return fea_file['keypoints'][()], fea_file['descriptors'][()]

    def form_standard_dataset(self):
        """Pack per-pair metadata and features into one HDF5 file.

        Output path:
        ``<dataset_dump_dir>/<data_name>_<extractor name>_<num_kpt>.hdf5``.
        The file gets groups ``K1``, ``K2``, ``R``, ``T``, ``e``, ``f``,
        ``img_path1``, ``img_path2``, ``desc1``, ``desc2``, ``kpt1`` and ``kpt2``.
        Each group holds one dataset per pair, named by the pair index.

        Reads ``self.data``, which must provide per-pair sequences under the keys
        ``K1``, ``K2``, ``R``, ``T``, ``e``, ``f``, ``img_path1``, ``img_path2``
        (ASCII-encodable strings), ``fea_path1`` and ``fea_path2`` (feature files
        written by :meth:`write_feature`). The pair count is ``len(self.data['K1'])``.
        """
        extractor_cfg = self.config['extractor']
        dataset_path = os.path.join(
            self.config['dataset_dump_dir'],
            self.config['data_name'] + '_' + extractor_cfg['name'] + '_'
            + str(extractor_cfg['num_kpt']) + '.hdf5')

        pair_data_type = ['K1', 'K2', 'R', 'T', 'e', 'f']
        num_pairs = len(self.data['K1'])
        with h5py.File(dataset_path, 'w') as f:
            print('collecting pair info...')
            for key in pair_data_type:
                group = f.create_group(key)
                for idx in range(num_pairs):
                    data_item = np.asarray(self.data[key][idx])
                    group.create_dataset(str(idx), data_item.shape, data_item.dtype,
                                         data=data_item)

            # Image paths are stored as length-1 ASCII string datasets.
            str_dtype = h5py.string_dtype(encoding='ascii')
            for key in ['img_path1', 'img_path2']:
                group = f.create_group(key)
                for idx in range(num_pairs):
                    group.create_dataset(str(idx), [1], str_dtype,
                                         data=self.data[key][idx].encode('ascii'))

            print('collecting desc and kpt...')
            desc1_g, desc2_g = f.create_group('desc1'), f.create_group('desc2')
            kpt1_g, kpt2_g = f.create_group('kpt1'), f.create_group('kpt2')
            for idx in trange(num_pairs):
                # Each feature file is closed right after reading, so open handles
                # do not accumulate across thousands of pairs.
                kpt1, desc1 = self._read_feature(self.data['fea_path1'][idx])
                kpt2, desc2 = self._read_feature(self.data['fea_path2'][idx])
                name = str(idx)
                desc1_g.create_dataset(name, desc1.shape, desc1.dtype, data=desc1)
                desc2_g.create_dataset(name, desc2.shape, desc2.dtype, data=desc2)
                kpt1_g.create_dataset(name, kpt1.shape, kpt1.dtype, data=kpt1)
                kpt2_g.create_dataset(name, kpt2.shape, kpt2.dtype, data=kpt2)
|
|
|
|
|
|
|
|