Spaces:
Running
Running
File size: 4,157 Bytes
a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
from torch.multiprocessing import Process,Manager,set_start_method,Pool
import functools
import argparse
import yaml
import numpy as np
import sys
import cv2
from tqdm import trange
# Start worker processes with 'spawn'; force=True overrides any start method
# that may already have been selected.
set_start_method('spawn', force=True)
# Repository root = parent of the directory containing this script. It is put
# at the front of sys.path so the project-local imports that follow resolve.
_SCRIPT_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir))
sys.path.insert(0, ROOT_DIR)
from components import load_component
from utils import evaluation_utils,metrics
# Command-line options as (flag, type, default).
# NOTE(review): the handler functions read the module-level `args`; with the
# 'spawn' start method, child processes presumably re-run this module's top
# level and therefore re-parse argv here -- confirm before moving this into
# the __main__ guard.
_CLI_OPTIONS = (
    ('--config_path', str, 'configs/eval/scannet_eval_sgm.yaml'),
    ('--num_process_match', int, 4),
    ('--num_process_eval', int, 4),
    ('--vis_folder', str, None),
)
parser = argparse.ArgumentParser(description='dump eval data.')
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)
args = parser.parse_args()
def feed_match(info, matcher):
    """Run ``matcher`` on one image pair and return its correspondences.

    Args:
        info: dict with keypoints ``'x1'``/``'x2'``, descriptors
            ``'desc1'``/``'desc2'`` and images ``'img1'``/``'img2'``.
        matcher: object exposing ``run(test_data)`` that returns a pair
            ``(corr1, corr2)``.

    Returns:
        ``[corr1, corr2]`` as a two-element list.
    """
    def _reversed_hw(img):
        # shape[:2] reversed, i.e. (shape[1], shape[0]) -- (width, height)
        # assuming img is laid out H x W [x C].
        return np.flip(np.asarray(img.shape[:2]))

    test_data = {
        'x1': info['x1'],
        'x2': info['x2'],
        'desc1': info['desc1'],
        'desc2': info['desc2'],
        'size1': _reversed_hw(info['img1']),
        'size2': _reversed_hw(info['img2']),
    }
    corr1, corr2 = matcher.run(test_data)
    return [corr1, corr2]
def reader_handler(config, read_que):
    """Read every sample from the configured reader and push it onto ``read_que``.

    The string ``'over'`` is always put on ``read_que`` last, even if loading
    the reader or reading a sample raises. ``match_handler`` blocks on
    ``read_que.get()`` until it sees that sentinel, so it would otherwise hang.
    The exception itself still propagates and terminates this process.

    Args:
        config: reader config; ``config['name']`` selects the component.
        read_que: queue consumed by ``match_handler``.
    """
    try:
        reader = load_component('reader', config['name'], config)
        for index in range(len(reader)):
            read_que.put(reader.run(index))
    finally:
        read_que.put('over')
def match_handler(config, read_que, match_que):
    """Match every pair arriving on ``read_que`` and forward it to ``match_que``.

    Items are collected into batches of ``args.num_process_match`` and matched
    in parallel with a worker pool via ``feed_match``. Each item gets
    ``'corr1'``/``'corr2'`` set from the result and is forwarded in arrival
    order. Input ends when the string ``'over'`` is read. A final ``'over'``
    is always put on ``match_que``, even if loading or matching raises, so
    ``evaluate_handler`` is not left waiting forever. The pool is always
    closed and joined.

    Args:
        config: matcher config; ``config['name']`` selects the component.
        read_que: queue filled by ``reader_handler``.
        match_que: queue consumed by ``evaluate_handler``.
    """
    pool = None

    def _match_batch(batch, match_func):
        # Match one batch in parallel; pool.map returns results in input order.
        results = pool.map(match_func, batch)
        for cur_item, cur_result in zip(batch, results):
            cur_item['corr1'], cur_item['corr2'] = cur_result[0], cur_result[1]
            match_que.put(cur_item)

    try:
        matcher = load_component('matcher', config['name'], config)
        match_func = functools.partial(feed_match, matcher=matcher)
        pool = Pool(args.num_process_match)
        cache = []
        while True:
            item = read_que.get()
            if item == 'over':
                # End of input: match whatever partial batch is left.
                if cache:
                    _match_batch(cache, match_func)
                break
            cache.append(item)
            if len(cache) == args.num_process_match:
                _match_batch(cache, match_func)
                cache = []
    finally:
        match_que.put('over')
        if pool is not None:
            pool.close()
            pool.join()
def evaluate_handler(config, match_que):
    """Evaluate up to ``config['num_pair']`` matched pairs, then report.

    Items from ``match_que`` are scored in batches of ``args.num_process_eval``
    with a worker pool (``evaluator.run``), and each result is handed to
    ``evaluator.res_inqueue``. Reading stops on the string ``'over'`` or after
    ``num_pair`` items, whichever comes first. Any trailing partial batch is
    scored in both cases. When ``args.vis_folder`` is set, a match
    visualization is written per item as ``<item['index']>.png``. Finally
    ``evaluator.parse()`` is called.

    Args:
        config: evaluator config; reads ``'name'``, ``'num_pair'`` and
            (for visualization) ``'inlier_th'``.
        match_que: queue filled by ``match_handler``.
    """
    evaluator = load_component('evaluator', config['name'], config)
    pool = Pool(args.num_process_eval)

    def _evaluate_batch(batch):
        # Score one batch in parallel and record results in input order.
        for cur_res in pool.map(evaluator.run, batch):
            evaluator.res_inqueue(cur_res)

    try:
        cache = []
        for _ in trange(config['num_pair']):
            item = match_que.get()
            if item == 'over':
                break
            cache.append(item)
            if len(cache) == args.num_process_eval:
                _evaluate_batch(cache)
                cache = []
            if args.vis_folder is not None:
                # Dump a visualization of this pair's matches. The inlier mask
                # comes from metrics.compute_epi_inlier on intrinsics-normalized
                # correspondences; item['e'] is presumably the ground-truth
                # essential matrix -- confirm against the reader.
                corr1_norm = evaluation_utils.normalize_intrinsic(item['corr1'], item['K1'])
                corr2_norm = evaluation_utils.normalize_intrinsic(item['corr2'], item['K2'])
                inlier_mask = metrics.compute_epi_inlier(
                    corr1_norm, corr2_norm, item['e'], config['inlier_th'])
                display = evaluation_utils.draw_match(
                    item['img1'], item['img2'], item['corr1'], item['corr2'], inlier_mask)
                cv2.imwrite(os.path.join(args.vis_folder, str(item['index']) + '.png'), display)
        # The loop can also end by count (exactly num_pair items received)
        # without ever seeing 'over'; a partial batch may remain either way
        # and must still be scored.
        if cache:
            _evaluate_batch(cache)
    finally:
        pool.close()
        pool.join()
    evaluator.parse()
if __name__ == '__main__':
    # Load the YAML config. yaml.load() without an explicit Loader raises
    # TypeError on PyYAML >= 6; safe_load assumes the config uses only plain
    # YAML tags (no Python-object tags) -- NOTE(review): confirm for all configs.
    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    if args.vis_folder is not None:
        # makedirs also creates missing parent directories.
        os.makedirs(args.vis_folder, exist_ok=True)
    # One manager process serves both bounded queues. The manager is shut
    # down only after all pipeline stages have been joined.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)
        stages = [
            Process(target=reader_handler, args=(config['reader'], read_que)),
            Process(target=match_handler, args=(config['matcher'], read_que, match_que)),
            Process(target=evaluate_handler, args=(config['evaluator'], match_que)),
        ]
        for stage in stages:
            stage.start()
        for stage in stages:
            stage.join()