Vincentqyw
fix: roma
c74a070
import os
from torch.multiprocessing import Process, Manager, set_start_method, Pool
import functools
import argparse
import yaml
import numpy as np
import sys
import cv2
from tqdm import trange
# Force the "spawn" start method for every Process/Pool created below
# (force=True overrides any start method that was already set).
# NOTE(review): presumably chosen because the loaded components may use
# CUDA, which does not work in forked children — confirm.
set_start_method("spawn", force=True)
# Repository root: the parent of the directory that contains this script.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Put the repository root first on the import path so that the project-local
# `components` and `utils` modules imported right after this can be found.
sys.path.insert(0, ROOT_DIR)
from components import load_component
from utils import evaluation_utils, metrics
# Command-line options, parsed at import time and exposed as the module-level
# ``args`` that the handler functions read. Because the "spawn" start method
# is used, each child process re-runs this module's top-level code as well.
parser = argparse.ArgumentParser(description="dump eval data.")
for _flag, _kind, _default in (
    ("--config_path", str, "configs/eval/scannet_eval_sgm.yaml"),
    ("--num_process_match", int, 4),
    ("--num_process_eval", int, 4),
    ("--vis_folder", str, None),
):
    parser.add_argument(_flag, type=_kind, default=_default)
# Drop the loop temporaries so they do not linger as module globals.
del _flag, _kind, _default
args = parser.parse_args()
def feed_match(info, matcher):
    """Run ``matcher`` on the keypoints and descriptors of one reader item.

    Args:
        info: dict produced by the reader; must provide the keys "x1", "x2",
            "desc1", "desc2", "img1" and "img2" (the two images must expose
            ``.shape``).
        matcher: object with a ``run(test_data)`` method returning a pair of
            correspondence arrays.

    Returns:
        A two-element list ``[corr1, corr2]`` taken from ``matcher.run``.
    """

    def _reversed_size(img):
        # The first two shape entries are handed to the matcher in reverse
        # order (for an HxW(xC) image this yields (width, height)).
        return np.flip(np.asarray(img.shape[:2]))

    test_data = {key: info[key] for key in ("x1", "x2", "desc1", "desc2")}
    test_data["size1"] = _reversed_size(info["img1"])
    test_data["size2"] = _reversed_size(info["img2"])
    corr1, corr2 = matcher.run(test_data)
    return [corr1, corr2]
def reader_handler(config, read_que):
    """Read every item from the configured reader and push it onto ``read_que``.

    Args:
        config: reader section of the evaluation config; ``config["name"]``
            selects the reader component passed to ``load_component``.
        read_que: queue consumed by ``match_handler``.

    The string sentinel ``"over"`` is always pushed last, including when
    creating the reader or reading an item raises. Without it, the
    downstream matcher would block forever on ``read_que.get()``. Any such
    exception still propagates once the sentinel has been sent.
    """
    try:
        reader = load_component("reader", config["name"], config)
        for index in range(len(reader)):
            read_que.put(reader.run(index))
    finally:
        read_que.put("over")
def _flush_matches(pool, match_func, cache, match_que):
    """Match the cached items in parallel and forward each one to ``match_que``."""
    if not cache:
        return
    results = pool.map(match_func, cache)
    for cur_item, cur_result in zip(cache, results):
        # Annotate the reader item in place with its correspondences.
        cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
        match_que.put(cur_item)


def match_handler(config, read_que, match_que):
    """Consume reader items, compute matches in batches and pass them on.

    Items are taken from ``read_que`` until the ``"over"`` sentinel arrives.
    They are grouped into batches of ``args.num_process_match``, which is
    also the worker-pool size, and matched in parallel via ``feed_match``.
    Each item gets its "corr1"/"corr2" keys filled in and is put on
    ``match_que``.

    Args:
        config: matcher section of the evaluation config; ``config["name"]``
            selects the matcher component passed to ``load_component``.
        read_que: queue filled by ``reader_handler``.
        match_que: queue consumed by ``evaluate_handler``.

    The ``"over"`` sentinel is always forwarded to ``match_que``, even when
    matching fails, so the evaluator is never left waiting forever.
    """
    try:
        matcher = load_component("matcher", config["name"], config)
        match_func = functools.partial(feed_match, matcher=matcher)
        pool = Pool(args.num_process_match)
        try:
            cache = []
            while True:
                item = read_que.get()
                # Items are dicts; only the string sentinel ends the stream.
                if isinstance(item, str) and item == "over":
                    break
                cache.append(item)
                if len(cache) == args.num_process_match:
                    _flush_matches(pool, match_func, cache, match_que)
                    cache = []
            # Match the final, possibly partial, batch.
            _flush_matches(pool, match_func, cache, match_que)
        finally:
            pool.close()
            pool.join()
    finally:
        match_que.put("over")
def _evaluate_batch(pool, evaluator, cache):
    """Run ``evaluator.run`` over the cached items in parallel and record results."""
    if not cache:
        return
    for cur_res in pool.map(evaluator.run, cache):
        evaluator.res_inqueue(cur_res)


def _dump_visualization(item, inlier_th):
    """Write a match visualization for ``item`` to ``args.vis_folder/<index>.png``.

    The correspondences are normalized with each image's intrinsics. An
    epipolar inlier mask is then computed with threshold ``inlier_th`` and
    passed to ``evaluation_utils.draw_match``.
    """
    corr1_norm = evaluation_utils.normalize_intrinsic(item["corr1"], item["K1"])
    corr2_norm = evaluation_utils.normalize_intrinsic(item["corr2"], item["K2"])
    inlier_mask = metrics.compute_epi_inlier(
        corr1_norm, corr2_norm, item["e"], inlier_th
    )
    display = evaluation_utils.draw_match(
        item["img1"], item["img2"], item["corr1"], item["corr2"], inlier_mask
    )
    cv2.imwrite(os.path.join(args.vis_folder, str(item["index"]) + ".png"), display)


def evaluate_handler(config, match_que):
    """Evaluate up to ``config["num_pair"]`` matched items and report the results.

    Items are read from ``match_que`` until ``num_pair`` have been consumed
    or the ``"over"`` sentinel arrives. They are evaluated in parallel
    batches of ``args.num_process_eval``, which is also the worker-pool
    size. When ``args.vis_folder`` is set, every item is also visualized.
    ``evaluator.parse()`` is called at the end.

    Args:
        config: evaluator section of the evaluation config; must contain
            "name" and "num_pair", plus "inlier_th" when visualizing.
        match_que: queue filled by ``match_handler``.
    """
    evaluator = load_component("evaluator", config["name"], config)
    pool = Pool(args.num_process_eval)
    try:
        cache = []
        for _ in trange(config["num_pair"]):
            item = match_que.get()
            # Items are dicts; only the string sentinel ends the stream.
            if isinstance(item, str) and item == "over":
                break
            cache.append(item)
            if len(cache) == args.num_process_eval:
                _evaluate_batch(pool, evaluator, cache)
                cache = []
            if args.vis_folder is not None:
                _dump_visualization(item, config["inlier_th"])
        # Flush the trailing partial batch. When exactly num_pair items arrive
        # the loop ends without ever seeing "over", so this must happen here
        # rather than only in the sentinel branch.
        _evaluate_batch(pool, evaluator, cache)
    finally:
        pool.close()
        pool.join()
    evaluator.parse()
if __name__ == "__main__":
    # Run the three-stage reader -> matcher -> evaluator pipeline, each stage
    # in its own process, connected by bounded manager queues.
    with open(args.config_path, "r") as f:
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6. The eval configs are plain YAML, so safe_load suffices.
        config = yaml.safe_load(f)
    if args.vis_folder is not None:
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(args.vis_folder, exist_ok=True)
    # A single manager server process hosts both queues and is shut down
    # once all stages have been joined.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)
        processes = [
            Process(
                target=reader_handler,
                args=(config["reader"], read_que),
                name="reader",
            ),
            Process(
                target=match_handler,
                args=(config["matcher"], read_que, match_que),
                name="matcher",
            ),
            Process(
                target=evaluate_handler,
                args=(config["evaluator"], match_que),
                name="evaluator",
            ),
        ]
        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()
    # Surface child failures through the script's exit status instead of
    # silently exiting 0.
    failed = [proc.name for proc in processes if proc.exitcode != 0]
    if failed:
        sys.exit("pipeline stage(s) failed: " + ", ".join(failed))