File size: 4,157 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from torch.multiprocessing import Process,Manager,set_start_method,Pool
import functools
import argparse
import yaml
import numpy as np
import sys
import cv2
from tqdm import trange
# Select the 'spawn' start method for the processes and pools created in this
# script; force=True overrides a start method that was already set.
set_start_method('spawn',force=True)


# Parent directory of the folder that contains this script, placed at the
# front of sys.path -- presumably so the project-local `components` and
# `utils` imports that follow resolve regardless of the working directory.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from components import load_component
from utils import evaluation_utils,metrics

# Command-line options. They are parsed at module level so that `args` is a
# module global the handler functions in this file can read directly.
parser = argparse.ArgumentParser(description='dump eval data.')
# YAML file whose 'reader', 'matcher' and 'evaluator' sections configure the three stages.
parser.add_argument('--config_path', type=str, default='configs/eval/scannet_eval_sgm.yaml')
# Worker-pool size, and also the batch size, used by match_handler.
parser.add_argument('--num_process_match', type=int, default=4)
# Worker-pool size, and also the batch size, used by evaluate_handler.
parser.add_argument('--num_process_eval', type=int, default=4)
# Optional folder for per-pair match visualizations; None disables the dump.
parser.add_argument('--vis_folder',type=str,default=None)
args=parser.parse_args()

def feed_match(info,matcher):
    """Run `matcher` on one sample and return its two correspondence sets.

    Args:
        info: mapping providing 'x1', 'x2', 'desc1', 'desc2' plus 'img1' and
            'img2', each of which must expose a `.shape`.
        matcher: object with a `run(dict)` method returning a pair.

    Returns:
        A two-element list `[corr1, corr2]` holding what `matcher.run` returned.
    """
    def reversed_size(image):
        # First two shape entries in reverse order
        # (width, height if the shape is rows x cols -- TODO confirm).
        return np.flip(np.asarray(image.shape[:2]))

    payload = {key: info[key] for key in ('x1', 'x2', 'desc1', 'desc2')}
    payload['size1'] = reversed_size(info['img1'])
    payload['size2'] = reversed_size(info['img2'])
    first, second = matcher.run(payload)
    return [first, second]


def reader_handler(config,read_que):
  """Producer stage: load every sample and push it onto `read_que`.

  The reader component is built from `config` (its 'name' entry selects the
  implementation). Samples are emitted in index order, from 0 up to
  len(reader)-1, and the string 'over' is queued last as an end-of-stream
  marker for the consumer.
  """
  reader=load_component('reader',config['name'],config)
  for sample_id in range(len(reader)):
    read_que.put(reader.run(sample_id))
  # Sentinel: tells the next stage that no more samples will arrive.
  read_que.put('over')


def match_handler(config,read_que,match_que):
  """Matching stage: batch samples from `read_que`, match each batch in
  parallel and forward the augmented samples to `match_que`.

  Samples are collected until `args.num_process_match` of them are cached,
  then matched with a worker pool of the same size. Each sample dict gains
  the keys 'corr1' and 'corr2' before it is forwarded, in arrival order.
  When the 'over' sentinel is read, the remaining (possibly partial) batch is
  flushed and 'over' is forwarded to the next stage.

  Args:
    config: matcher configuration; config['name'] selects the component.
    read_que: queue yielding sample dicts followed by the string 'over'.
    match_que: queue receiving the matched samples followed by 'over'.
  """
  matcher=load_component('matcher',config['name'],config)
  match_func=functools.partial(feed_match,matcher=matcher)
  pool = Pool(args.num_process_match)

  def flush(batch):
    # Match the whole batch in parallel, then forward every sample in order.
    results=pool.map(match_func,batch)
    for cur_item,cur_result in zip(batch,results):
      cur_item['corr1'],cur_item['corr2']=cur_result[0],cur_result[1]
      match_que.put(cur_item)

  cache=[]
  try:
    while True:
      item=read_que.get()
      if item=='over':
        break
      cache.append(item)
      if len(cache)==args.num_process_match:
        flush(cache)
        cache=[]
    # Flush the final, possibly partial batch before signalling completion.
    if cache:
      flush(cache)
    match_que.put('over')
  finally:
    # Always release the worker processes, even when matching raised.
    pool.close()
    pool.join()


def evaluate_handler(config,match_que):
  """Evaluation stage: consume matched samples, evaluate them in parallel
  batches and report the aggregated result.

  At most config['num_pair'] items are read from `match_que`; reading stops
  early if the 'over' sentinel arrives. Samples are evaluated in batches of
  `args.num_process_eval` via `evaluator.run`, and each result is handed to
  `evaluator.res_inqueue`. Whatever is still cached when the loop ends --
  because of the sentinel or because the pair budget was reached -- is
  evaluated as well, so a trailing partial batch is not lost when
  `num_pair` is not a multiple of the batch size. `evaluator.parse()` is
  called once at the end.

  When `args.vis_folder` is set, a match visualization named '<index>.png' is
  written into that folder for every sample.

  Args:
    config: evaluator configuration; uses 'name', 'num_pair' and, for
      visualization only, 'inlier_th'.
    match_que: queue yielding matched sample dicts followed by 'over'.
  """
  evaluator=load_component('evaluator',config['name'],config)
  pool = Pool(args.num_process_eval)

  def flush(batch):
    # Evaluate one batch in parallel and record each result with the evaluator.
    for cur_res in pool.map(evaluator.run,batch):
      evaluator.res_inqueue(cur_res)

  cache=[]
  try:
    for _ in trange(config['num_pair']):
      item=match_que.get()
      if item=='over':
        break
      cache.append(item)
      if len(cache)==args.num_process_eval:
        flush(cache)
        cache=[]
      if args.vis_folder is not None:
        #dump visualization
        corr1_norm=evaluation_utils.normalize_intrinsic(item['corr1'],item['K1'])
        corr2_norm=evaluation_utils.normalize_intrinsic(item['corr2'],item['K2'])
        inlier_mask=metrics.compute_epi_inlier(corr1_norm,corr2_norm,item['e'],config['inlier_th'])
        display=evaluation_utils.draw_match(item['img1'],item['img2'],item['corr1'],item['corr2'],inlier_mask)
        cv2.imwrite(os.path.join(args.vis_folder,str(item['index'])+'.png'),display)
    # The loop usually ends by count, not by the sentinel, so the leftover
    # batch has to be evaluated here rather than only in the 'over' branch.
    if cache:
      flush(cache)
  finally:
    # Release the worker processes; the pool is not needed past this point.
    pool.close()
    pool.join()
  evaluator.parse()


if __name__=='__main__':
  # Entry point: load the YAML config, prepare the optional visualization
  # folder, then run the reader -> matcher -> evaluator pipeline as three
  # processes connected by two bounded queues.
  with open(args.config_path, 'r') as f:
    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1
    # and raises TypeError from 6.0 on; safe_load parses plain YAML on all
    # versions without constructing arbitrary Python objects.
    config = yaml.safe_load(f)
  if args.vis_folder is not None:
    # makedirs also creates missing parent folders and tolerates an existing one.
    os.makedirs(args.vis_folder,exist_ok=True)

  # maxsize=100 bounds each queue, so a fast producer blocks instead of
  # buffering an unbounded number of samples.
  read_que,match_que=Manager().Queue(maxsize=100),Manager().Queue(maxsize=100)

  read_process=Process(target=reader_handler,args=(config['reader'],read_que))
  match_process=Process(target=match_handler,args=(config['matcher'],read_que,match_que))
  evaluate_process=Process(target=evaluate_handler,args=(config['evaluator'],match_que))

  read_process.start()
  match_process.start()
  evaluate_process.start()

  read_process.join()
  match_process.join()
  evaluate_process.join()