from abc import ABCMeta, abstractmethod
import os
import h5py
import numpy as np
from tqdm import trange
from torch.multiprocessing import Pool,set_start_method
set_start_method('spawn',force=True)

import sys
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
sys.path.insert(0, ROOT_DIR)
from components import load_component


class BaseDumper(metaclass=ABCMeta):
    def __init__(self,config):
        self.config=config
        self.img_seq=[]
        self.dump_seq=[]  # feature dump paths, aligned index-for-index with img_seq
    
    @abstractmethod
    def get_seqs(self):
        raise NotImplementedError
    
    @abstractmethod
    def format_dump_folder(self):
        raise NotImplementedError

    @abstractmethod
    def format_dump_data(self):
        raise NotImplementedError
    
    def initialize(self):
        self.extractor=load_component('extractor',self.config['extractor']['name'],self.config['extractor'])
        self.get_seqs()
        self.format_dump_folder()


    def extract(self,index):
        img_path,dump_path=self.img_seq[index],self.dump_seq[index]
        if not self.config['extractor']['overwrite'] and os.path.exists(dump_path):
            return
        kp, desc = self.extractor.run(img_path)
        self.write_feature(kp,desc,dump_path)

    def dump_feature(self):
        print('Extracting features...')
        self.num_img=len(self.dump_seq)
        num_process=self.config['extractor']['num_process']
        pool=Pool(num_process)
        # ceil(num_img/num_process): each iteration maps one batch of at most num_process images
        iteration_num=(self.num_img+num_process-1)//num_process
        for index in trange(iteration_num):
            indices_list=range(index*num_process,min((index+1)*num_process,self.num_img))
            pool.map(self.extract,indices_list)
        pool.close()
        pool.join()

    def write_feature(self,pts, desc, filename):
        with h5py.File(filename, "w") as ifp:
            ifp.create_dataset('keypoints', pts.shape, dtype=np.float32)
            ifp.create_dataset('descriptors', desc.shape, dtype=np.float32)
            ifp["keypoints"][:] = pts
            ifp["descriptors"][:] = desc

    def form_standard_dataset(self):
        dataset_path=os.path.join(self.config['dataset_dump_dir'],self.config['data_name']+\
                                '_'+self.config['extractor']['name']+'_'+str(self.config['extractor']['num_kpt'])+'.hdf5')
        
        pair_data_type=['K1','K2','R','T','e','f']
        num_pairs=len(self.data['K1'])
        with h5py.File(dataset_path, 'w') as f:
            print('collecting pair info...')
            # loop variable renamed so it no longer shadows the built-in `type`
            for data_type in pair_data_type:
                dg=f.create_group(data_type)
                for idx in range(num_pairs):
                    data_item=np.asarray(self.data[data_type][idx])
                    dg.create_dataset(str(idx),data_item.shape,data_item.dtype,data=data_item)

            for path_type in ['img_path1','img_path2']:
                dg=f.create_group(path_type)
                for idx in range(num_pairs):
                    dg.create_dataset(str(idx),[1],h5py.string_dtype(encoding='ascii'),data=self.data[path_type][idx].encode('ascii'))

            #dump desc
            print('collecting desc and kpt...')
            desc1_g,desc2_g,kpt1_g,kpt2_g=f.create_group('desc1'),f.create_group('desc2'),f.create_group('kpt1'),f.create_group('kpt2')
            for idx in trange(num_pairs):
                # open both feature files in a context manager so their handles are closed after each pair
                with h5py.File(self.data['fea_path1'][idx],'r') as desc_file1,h5py.File(self.data['fea_path2'][idx],'r') as desc_file2:
                    desc1,kpt1=desc_file1['descriptors'][()],desc_file1['keypoints'][()]
                    desc2,kpt2=desc_file2['descriptors'][()],desc_file2['keypoints'][()]
                desc1_g.create_dataset(str(idx),desc1.shape,desc1.dtype,data=desc1)
                desc2_g.create_dataset(str(idx),desc2.shape,desc2.dtype,data=desc2)
                kpt1_g.create_dataset(str(idx),kpt1.shape,kpt1.dtype,data=kpt1)
                kpt2_g.create_dataset(str(idx),kpt2.shape,kpt2.dtype,data=kpt2)
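

# Illustrative sketch, not part of the original pipeline: the batching arithmetic
# that BaseDumper.dump_feature applies to image indices, written as a standalone
# helper so it can be checked in isolation. The name `batch_ranges` and its
# parameters are hypothetical; nothing above calls this function.
def batch_ranges(num_items, batch_size):
    """Yield consecutive index ranges of at most `batch_size` items covering `num_items`."""
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    # Ceiling division: one extra, shorter batch when num_items is not a multiple of batch_size.
    num_batches = (num_items + batch_size - 1) // batch_size
    for batch in range(num_batches):
        start = batch * batch_size
        yield range(start, min(start + batch_size, num_items))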