Spaces:
Running
Running
import os | |
from torch.multiprocessing import Process,Manager,set_start_method,Pool | |
import functools | |
import argparse | |
import yaml | |
import numpy as np | |
import sys | |
import cv2 | |
from tqdm import trange | |
# Use the 'spawn' start method for every process created by this script;
# force=True overrides any start method that was set earlier.
set_start_method("spawn", force=True)

# Put the parent directory of this file's directory at the front of sys.path
# so the project-local imports that follow can be resolved.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.insert(0, ROOT_DIR)
from components import load_component | |
from utils import evaluation_utils,metrics | |
# Command-line options. They are parsed at module level so that the handler
# functions in this file can read `args` as a module global.
parser = argparse.ArgumentParser(description='dump eval data.')
for _flag, _type, _default in (
    ('--config_path', str, 'configs/eval/scannet_eval_sgm.yaml'),
    ('--num_process_match', int, 4),   # pool size / batch size used for matching
    ('--num_process_eval', int, 4),    # pool size / batch size used for evaluation
    ('--vis_folder', str, None),       # when set, match visualizations are written here
):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()
def feed_match(info, matcher):
    """Run ``matcher`` on a single pair and return its two correspondence sets.

    Args:
        info: mapping that provides ``x1``, ``x2``, ``desc1``, ``desc2`` and
            the images ``img1``/``img2`` (only ``.shape[:2]`` of each image
            is read here).
        matcher: object exposing ``run(test_data)`` which returns a pair
            ``(corr1, corr2)``.

    Returns:
        The list ``[corr1, corr2]`` as produced by ``matcher.run``.
    """
    test_data = {key: info[key] for key in ('x1', 'x2', 'desc1', 'desc2')}
    # The first two shape entries are handed over in reversed order
    # (presumably (height, width) -> (width, height); confirm with the matcher).
    for size_key, img_key in (('size1', 'img1'), ('size2', 'img2')):
        test_data[size_key] = np.flip(np.asarray(info[img_key].shape[:2]))
    first_corr, second_corr = matcher.run(test_data)
    return [first_corr, second_corr]
def reader_handler(config, read_que):
    """Producer stage: read every sample and push it onto ``read_que``.

    The reader is built with ``load_component('reader', config['name'], config)``
    and queried as ``reader.run(index)`` for each ``index`` in
    ``range(len(reader))``. After the last sample the string ``'over'`` is put
    on the queue as an end-of-stream marker.

    Args:
        config: reader configuration; must contain the key ``'name'``.
        read_que: queue consumed by the matching stage.

    Raises:
        Whatever ``load_component`` or ``reader.run`` raise; the end-of-stream
        marker is still emitted in that case.
    """
    try:
        reader = load_component('reader', config['name'], config)
        for index in range(len(reader)):
            read_que.put(reader.run(index))
    finally:
        # Always signal end-of-stream: the consumer blocks on read_que.get()
        # and would otherwise wait forever if reading fails part-way.
        read_que.put('over')
def match_handler(config, read_que, match_que):
    """Matching stage: batch items from ``read_que``, match them, forward them.

    Items are collected until ``args.num_process_match`` of them are
    available, matched in parallel with ``feed_match`` through a process pool,
    annotated with the keys ``'corr1'`` and ``'corr2'`` and put on
    ``match_que`` in their original order. When the string ``'over'`` is
    received, any remaining items are matched and ``'over'`` is forwarded.

    Args:
        config: matcher configuration; must contain the key ``'name'``.
        read_que: queue filled by the reader stage.
        match_que: queue consumed by the evaluation stage.

    Raises:
        Whatever component loading or matching raise. The ``'over'`` marker is
        still forwarded and the pool is still released in that case, so the
        downstream stage does not block forever.
    """
    pool = None

    def _match_and_forward(batch):
        # Match one batch in parallel and forward the annotated items.
        results = pool.map(match_func, batch)
        for cur_item, (corr1, corr2) in zip(batch, results):
            cur_item['corr1'], cur_item['corr2'] = corr1, corr2
            match_que.put(cur_item)

    try:
        matcher = load_component('matcher', config['name'], config)
        match_func = functools.partial(feed_match, matcher=matcher)
        pool = Pool(args.num_process_match)
        cache = []
        while True:
            item = read_que.get()
            if item == 'over':
                break
            cache.append(item)
            if len(cache) == args.num_process_match:
                _match_and_forward(cache)
                cache = []
        # Flush the final, partially filled batch.
        if cache:
            _match_and_forward(cache)
    finally:
        # Forward end-of-stream even on failure, then release the workers.
        match_que.put('over')
        if pool is not None:
            pool.close()
            pool.join()
def evaluate_handler(config, match_que):
    """Evaluation stage: score matched pairs and report the aggregate result.

    Up to ``config['num_pair']`` items are taken from ``match_que``. They are
    evaluated in batches of ``args.num_process_eval`` via
    ``pool.map(evaluator.run, batch)``, and every result is handed to
    ``evaluator.res_inqueue``. Reading stops early when the string ``'over'``
    is received. Finally ``evaluator.parse()`` is called.

    When ``args.vis_folder`` is set, a match visualization is written to
    ``<vis_folder>/<item['index']>.png`` for every item received.

    Args:
        config: evaluator configuration; must contain ``'name'`` and
            ``'num_pair'`` (and ``'inlier_th'`` when visualizing).
        match_que: queue filled by the matching stage.
    """
    evaluator = load_component('evaluator', config['name'], config)
    pool = Pool(args.num_process_eval)

    def _evaluate_batch(batch):
        # Evaluate one batch in parallel and record each result.
        for cur_res in pool.map(evaluator.run, batch):
            evaluator.res_inqueue(cur_res)

    try:
        cache = []
        for _ in trange(config['num_pair']):
            item = match_que.get()
            if item == 'over':
                break
            cache.append(item)
            if len(cache) == args.num_process_eval:
                _evaluate_batch(cache)
                cache = []
            if args.vis_folder is not None:
                # dump visualization
                corr1_norm = evaluation_utils.normalize_intrinsic(item['corr1'], item['K1'])
                corr2_norm = evaluation_utils.normalize_intrinsic(item['corr2'], item['K2'])
                inlier_mask = metrics.compute_epi_inlier(
                    corr1_norm, corr2_norm, item['e'], config['inlier_th'])
                display = evaluation_utils.draw_match(
                    item['img1'], item['img2'], item['corr1'], item['corr2'], inlier_mask)
                cv2.imwrite(os.path.join(args.vis_folder, str(item['index']) + '.png'), display)
        # The loop can end after exactly num_pair items without ever seeing
        # 'over'; flush here so a partially filled last batch is not dropped.
        if cache:
            _evaluate_batch(cache)
    finally:
        pool.close()
        pool.join()
    evaluator.parse()
if __name__ == '__main__':
    # Entry point: load the YAML config, then run the reader, matcher and
    # evaluator stages as three processes connected by two bounded queues.
    with open(args.config_path, 'r') as f:
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6; safe_load parses plain config files on every version.
        config = yaml.safe_load(f)
    if args.vis_folder is not None:
        # Also creates missing parent directories; no error if it already exists.
        os.makedirs(args.vis_folder, exist_ok=True)

    # One manager process serves both queues and is shut down when the block exits.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)

        workers = [
            Process(target=reader_handler, args=(config['reader'], read_que)),
            Process(target=match_handler, args=(config['matcher'], read_que, match_que)),
            Process(target=evaluate_handler, args=(config['evaluator'], match_que)),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()