File size: 4,725 Bytes
a80d6bb
 
 
 
 
c74a070
 
 
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
 
c74a070
 
 
 
 
a80d6bb
 
 
c74a070
a80d6bb
 
 
 
 
 
 
c74a070
a80d6bb
c74a070
 
 
a80d6bb
 
 
c74a070
 
 
a80d6bb
 
c74a070
a80d6bb
 
c74a070
 
 
 
 
 
a80d6bb
c74a070
 
 
 
 
 
 
a80d6bb
 
 
c74a070
a80d6bb
c74a070
 
a80d6bb
 
 
 
c74a070
 
 
 
 
 
 
 
 
 
 
 
 
 
a80d6bb
c74a070
a80d6bb
c74a070
 
 
 
a80d6bb
c74a070
 
a80d6bb
c74a070
 
 
 
 
 
a80d6bb
c74a070
 
 
 
 
 
 
 
a80d6bb
c74a070
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from abc import ABCMeta, abstractmethod
import os
import h5py
import numpy as np
from tqdm import trange
from torch.multiprocessing import Pool, set_start_method

# Force the "spawn" start method for every process pool created after this
# point (force=True overrides a start method that was already set).
# NOTE(review): presumably needed so extractor workers can initialise their
# own device context instead of inheriting a forked one — confirm.
set_start_method("spawn", force=True)

import sys

# Directory two levels above this file; it is put at the front of sys.path so
# that the project-level `components` package below can be imported.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
sys.path.insert(0, ROOT_DIR)
from components import load_component


class BaseDumper(metaclass=ABCMeta):
    """Abstract base class that extracts per-image features and packs them.

    Subclasses implement :meth:`get_seqs`, :meth:`format_dump_folder` and
    :meth:`format_dump_data`.  The base class provides the shared machinery:
    running the configured extractor over ``img_seq`` in a process pool,
    writing one HDF5 feature file per image, and merging pair information
    plus features into a single HDF5 dataset file.

    ``config`` is read as a mapping with (at least) the keys used below:
    ``config["extractor"]`` with ``name``, ``overwrite``, ``num_process`` and
    ``num_kpt``, plus ``dataset_dump_dir`` and ``data_name``.
    """

    def __init__(self, config):
        """Store ``config`` and start with empty image / dump sequences."""
        self.config = config
        self.img_seq = []  # input image paths, index-aligned with dump_seq
        self.dump_seq = []  # feature dump seq

    @abstractmethod
    def get_seqs(self):
        """Subclass hook, called by :meth:`initialize`.

        NOTE(review): presumably fills ``img_seq`` / ``dump_seq`` — confirm
        against the concrete subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def format_dump_folder(self):
        """Subclass hook, called by :meth:`initialize` after :meth:`get_seqs`."""
        raise NotImplementedError

    @abstractmethod
    def format_dump_data(self):
        """Subclass hook; not called by this base class.

        NOTE(review): presumably populates ``self.data``, which
        :meth:`form_standard_dataset` reads — confirm.
        """
        raise NotImplementedError

    def initialize(self):
        """Load the configured extractor, then run the subclass setup hooks."""
        self.extractor = load_component(
            "extractor", self.config["extractor"]["name"], self.config["extractor"]
        )
        self.get_seqs()
        self.format_dump_folder()

    def extract(self, index):
        """Extract and store the features of the image at position ``index``.

        Does nothing when the dump file already exists and the extractor
        config does not request overwriting.
        """
        img_path, dump_path = self.img_seq[index], self.dump_seq[index]
        if not self.config["extractor"]["overwrite"] and os.path.exists(dump_path):
            return
        kp, desc = self.extractor.run(img_path)
        self.write_feature(kp, desc, dump_path)

    def dump_feature(self):
        """Run :meth:`extract` for every entry of ``dump_seq`` in a process pool.

        Indices are submitted in batches of ``num_process`` so that the
        progress bar advances once per batch.  The pool is always joined; if
        a worker raises, the pool is terminated first and the exception is
        re-raised.
        """
        print("Extrating features...")
        self.num_img = len(self.dump_seq)
        num_process = self.config["extractor"]["num_process"]
        pool = Pool(num_process)
        try:
            # One iteration per batch: ceil(num_img / num_process) steps.
            for start in trange(0, self.num_img, num_process):
                batch = range(start, min(start + num_process, self.num_img))
                pool.map(self.extract, batch)
            pool.close()
        except Exception:
            # Do not leave worker processes behind when a batch fails.
            pool.terminate()
            raise
        finally:
            pool.join()

    def write_feature(self, pts, desc, filename):
        """Write keypoints and descriptors to ``filename`` as float32 HDF5.

        The file is truncated/created and receives two datasets,
        ``"keypoints"`` and ``"descriptors"``.
        """
        with h5py.File(filename, "w") as ifp:
            ifp.create_dataset("keypoints", data=np.asarray(pts, dtype=np.float32))
            ifp.create_dataset(
                "descriptors", data=np.asarray(desc, dtype=np.float32)
            )

    def form_standard_dataset(self):
        """Merge pair metadata and per-image features into one HDF5 file.

        The output is ``<dataset_dump_dir>/<data_name>_<extractor>_<num_kpt>.hdf5``.
        It contains one group per entry of ``K1, K2, R, T, e, f``,
        ``img_path1, img_path2`` and ``desc1, desc2, kpt1, kpt2``; inside each
        group there is one dataset per pair, named by the pair index.

        Reads ``self.data``, a mapping of per-pair sequences that must provide
        the keys above plus ``fea_path1`` / ``fea_path2`` (paths of the feature
        files produced by :meth:`write_feature`).
        """
        extractor_cfg = self.config["extractor"]
        dataset_path = os.path.join(
            self.config["dataset_dump_dir"],
            f"{self.config['data_name']}_{extractor_cfg['name']}"
            f"_{extractor_cfg['num_kpt']}.hdf5",
        )

        pair_data_type = ["K1", "K2", "R", "T", "e", "f"]
        num_pairs = len(self.data["K1"])
        with h5py.File(dataset_path, "w") as f:
            print("collecting pair info...")
            for key in pair_data_type:
                dg = f.create_group(key)
                for idx in range(num_pairs):
                    data_item = np.asarray(self.data[key][idx])
                    dg.create_dataset(
                        str(idx), data_item.shape, data_item.dtype, data=data_item
                    )

            for key in ["img_path1", "img_path2"]:
                dg = f.create_group(key)
                for idx in range(num_pairs):
                    dg.create_dataset(
                        str(idx),
                        [1],
                        h5py.string_dtype(encoding="ascii"),
                        data=self.data[key][idx].encode("ascii"),
                    )

            # dump desc
            print("collecting desc and kpt...")
            desc1_g = f.create_group("desc1")
            desc2_g = f.create_group("desc2")
            kpt1_g = f.create_group("kpt1")
            kpt2_g = f.create_group("kpt2")
            for idx in trange(num_pairs):
                # Close both feature files after each pair; otherwise one
                # pair of open handles is leaked per iteration.
                with h5py.File(
                    self.data["fea_path1"][idx], "r"
                ) as desc_file1, h5py.File(
                    self.data["fea_path2"][idx], "r"
                ) as desc_file2:
                    desc1 = desc_file1["descriptors"][()]
                    desc2 = desc_file2["descriptors"][()]
                    kpt1 = desc_file1["keypoints"][()]
                    kpt2 = desc_file2["keypoints"][()]
                name = str(idx)
                desc1_g.create_dataset(name, desc1.shape, desc1.dtype, data=desc1)
                desc2_g.create_dataset(name, desc2.shape, desc2.dtype, data=desc2)
                kpt1_g.create_dataset(name, kpt1.shape, kpt1.dtype, data=kpt1)
                kpt2_g.create_dataset(name, kpt2.shape, kpt2.dtype, data=kpt2)