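# Web-page header residue captured together with the file (not part of the
# module); kept as comments so the module can be parsed.
# Vincentqyw
# fix: roma
# 8b973ee
# raw history blame
# No virus
# 4.73 kB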
from abc import ABCMeta, abstractmethod
import os
import h5py
import numpy as np
from tqdm import trange
from torch.multiprocessing import Pool, set_start_method

# Select the "spawn" start method for worker processes at import time;
# force=True asks for it even if a start method was already chosen.
# NOTE(review): presumably required because the extractor cannot be used in
# forked workers (e.g. CUDA state) - confirm.
set_start_method("spawn", force=True)
import sys

# Directory two levels above the one containing this file.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
# Must run before the import below: `components` is resolved through
# sys.path, and ROOT_DIR is put first so it takes precedence.
sys.path.insert(0, ROOT_DIR)
from components import load_component
class BaseDumper(metaclass=ABCMeta):
    """Abstract base class for dumping extracted image features to HDF5.

    The concrete workflow implemented here is:

    1. ``initialize`` loads the extractor component and asks the subclass
       for its image / dump path lists and output folders.
    2. ``dump_feature`` runs ``extract`` over every index in a process pool,
       writing one HDF5 file per image via ``write_feature``.
    3. ``form_standard_dataset`` merges pair metadata and the per-image
       feature files into a single HDF5 dataset file.

    Subclasses must implement ``get_seqs``, ``format_dump_folder`` and
    ``format_dump_data``.

    Config keys read by this class: ``config["extractor"]`` with ``name``,
    ``overwrite``, ``num_process`` and ``num_kpt``; ``config["dataset_dump_dir"]``
    and ``config["data_name"]``.

    NOTE(review): ``self.data`` is read by ``form_standard_dataset`` but never
    assigned in this class; presumably ``format_dump_data`` in the subclass
    populates it - confirm.
    """

    def __init__(self, config):
        """Store ``config`` and start with empty image / dump path lists."""
        self.config = config
        self.img_seq = []
        self.dump_seq = []  # feature dump seq

    @abstractmethod
    def get_seqs(self):
        """Subclass hook called by ``initialize``.

        Expected to fill ``self.img_seq`` and ``self.dump_seq`` (they are
        indexed in parallel by ``extract``) - TODO confirm in subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def format_dump_folder(self):
        """Subclass hook called by ``initialize`` after ``get_seqs``."""
        raise NotImplementedError

    @abstractmethod
    def format_dump_data(self):
        """Subclass hook; not called from this base class."""
        raise NotImplementedError

    def initialize(self):
        """Load the configured extractor, then run the subclass setup hooks."""
        extractor_cfg = self.config["extractor"]
        self.extractor = load_component(
            "extractor", extractor_cfg["name"], extractor_cfg
        )
        self.get_seqs()
        self.format_dump_folder()

    def extract(self, index):
        """Extract and store features for the image at position ``index``.

        Does nothing when ``config["extractor"]["overwrite"]`` is falsy and
        the dump file already exists. Always returns ``None``.
        """
        img_path, dump_path = self.img_seq[index], self.dump_seq[index]
        if not self.config["extractor"]["overwrite"] and os.path.exists(dump_path):
            return
        kp, desc = self.extractor.run(img_path)
        self.write_feature(kp, desc, dump_path)

    def dump_feature(self):
        """Run ``extract`` for every entry of ``self.dump_seq`` in a process pool.

        Indices are submitted in consecutive chunks of ``num_process`` items so
        the progress bar advances once per chunk. Sets ``self.num_img``.

        The pool is always shut down: closed and joined on success, terminated
        and joined if any chunk raises (the exception is re-raised).
        """
        print("Extrating features...")
        num_process = self.config["extractor"]["num_process"]
        self.num_img = len(self.dump_seq)
        # Ceiling division: number of chunks of size num_process.
        iteration_num = (self.num_img + num_process - 1) // num_process

        pool = Pool(num_process)
        try:
            for index in trange(iteration_num):
                start = index * num_process
                stop = min(start + num_process, self.num_img)
                pool.map(self.extract, range(start, stop))
        except BaseException:
            # Do not leave worker processes running after a failure.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            # join() is valid after either close() or terminate().
            pool.join()

    def write_feature(self, pts, desc, filename):
        """Write keypoints and descriptors to ``filename`` as float32 datasets.

        The file is created/truncated and receives two datasets,
        ``"keypoints"`` and ``"descriptors"``, shaped like ``pts`` and ``desc``.
        """
        with h5py.File(filename, "w") as ifp:
            ifp.create_dataset("keypoints", pts.shape, dtype=np.float32)
            ifp.create_dataset("descriptors", desc.shape, dtype=np.float32)
            ifp["keypoints"][:] = pts
            ifp["descriptors"][:] = desc

    @staticmethod
    def _read_feature(filename):
        """Return ``(keypoints, descriptors)`` arrays read from ``filename``.

        The file is closed before returning, so no handle outlives the call.
        """
        with h5py.File(filename, "r") as ifp:
            return ifp["keypoints"][()], ifp["descriptors"][()]

    def form_standard_dataset(self):
        """Merge pair info and per-image features into one HDF5 dataset file.

        The output path is
        ``<dataset_dump_dir>/<data_name>_<extractor name>_<num_kpt>.hdf5``.
        For every pair index ``i`` (count taken from ``len(self.data["K1"])``)
        a dataset named ``str(i)`` is written in each of these groups:

        * ``K1, K2, R, T, e, f`` - ``self.data[key][i]`` converted to an array;
        * ``img_path1, img_path2`` - the path string, ASCII encoded;
        * ``desc1, desc2, kpt1, kpt2`` - arrays loaded from the feature files
          listed in ``self.data["fea_path1"]`` / ``self.data["fea_path2"]``.
        """
        extractor_cfg = self.config["extractor"]
        dataset_name = "{}_{}_{}.hdf5".format(
            self.config["data_name"], extractor_cfg["name"], extractor_cfg["num_kpt"]
        )
        dataset_path = os.path.join(self.config["dataset_dump_dir"], dataset_name)

        num_pairs = len(self.data["K1"])
        with h5py.File(dataset_path, "w") as out:
            print("collecting pair info...")
            for key in ("K1", "K2", "R", "T", "e", "f"):
                group = out.create_group(key)
                for idx in range(num_pairs):
                    item = np.asarray(self.data[key][idx])
                    group.create_dataset(str(idx), item.shape, item.dtype, data=item)

            ascii_dtype = h5py.string_dtype(encoding="ascii")
            for key in ("img_path1", "img_path2"):
                group = out.create_group(key)
                for idx in range(num_pairs):
                    group.create_dataset(
                        str(idx),
                        [1],
                        ascii_dtype,
                        data=self.data[key][idx].encode("ascii"),
                    )

            # dump desc
            print("collecting desc and kpt...")
            desc1_g = out.create_group("desc1")
            desc2_g = out.create_group("desc2")
            kpt1_g = out.create_group("kpt1")
            kpt2_g = out.create_group("kpt2")
            for idx in trange(num_pairs):
                # Each feature file is opened and closed inside the helper so
                # handles do not accumulate over the pair loop.
                kpt1, desc1 = self._read_feature(self.data["fea_path1"][idx])
                kpt2, desc2 = self._read_feature(self.data["fea_path2"][idx])
                name = str(idx)
                desc1_g.create_dataset(name, desc1.shape, desc1.dtype, data=desc1)
                desc2_g.create_dataset(name, desc2.shape, desc2.dtype, data=desc2)
                kpt1_g.create_dataset(name, kpt1.shape, kpt1.dtype, data=kpt1)
                kpt2_g.create_dataset(name, kpt2.shape, kpt2.dtype, data=kpt2)