# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> multimap3d
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 04/03/2024 13:47
=================================================='''
import numpy as np
import os
import os.path as osp
import time
import cv2
import torch
import yaml
from copy import deepcopy
from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches
from localization.base_model import dynamic_load
import localization.matchers as matchers
from localization.match_features_batch import confs as matcher_confs
from tools.common import resize_img
from localization.singlemap3d import SingleMap3D
from localization.frame import Frame
class MultiMap3D:
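    """Multi-scene localization over a collection of SingleMap3D sub-maps.

    One sub-map is built per scene; a query Frame is routed to the sub-map
    predicted by its segmentation labels, localized against a reference
    frame there, and its pose is optionally refined.
    """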
def __init__(self, config, viewer=None, save_dir=None):
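        """
        :param config: full configuration dict; must contain the dataset
            entries and a 'localization' section with matcher and options
        :param viewer: optional viewer handle (stored on the instance)
        :param save_dir: optional directory where results are written
        """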
self.config = config
self.save_dir = save_dir
self.scenes = []
self.sid_scene_name = []
self.sub_maps = {}
self.scene_name_start_sid = {}
self.loc_config = config['localization']
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
self.matching_method = config['localization']['matching_method']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Model = dynamic_load(matchers, self.matching_method)
self.matcher = Model(matcher_confs[self.matching_method]['model']).eval().to(device)
self.initialize_map(config=config)
self.viewer = viewer
# options
self.do_refinement = self.loc_config['do_refinement']
self.refinement_method = self.loc_config['refinement_method']
self.semantic_matching = self.loc_config['semantic_matching']
self.do_pre_filtering = self.loc_config['pre_filtering_th'] > 0
self.pre_filtering_th = self.loc_config['pre_filtering_th']
def initialize_map(self, config):
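        """Build one SingleMap3D per scene and record each scene's global
        segment-id offset in self.scene_name_start_sid.

        For each dataset name, a '<name>.yaml' is read from
        config['config_path']. A sketch of the expected shape, inferred from
        the keys accessed below (values are placeholders, not from the source):

            scenes: [scene_a]
            scene_a:
              n_cluster: 512
              cluster_mode: xyz
              cluster_method: birch
              gt_pose_path: queries_poses.txt
              image_path_prefix: images
        """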
n_class = 0
datasets = config['dataset']
for name in datasets:
config_path = osp.join(config['config_path'], '{:s}.yaml'.format(name))
dataset_name = name
with open(config_path, 'r') as f:
scene_config = yaml.load(f, Loader=yaml.Loader)
scenes = scene_config['scenes']
for sid, scene in enumerate(scenes):
self.scenes.append(name + '/' + scene)
new_config = deepcopy(config)
new_config['dataset_path'] = osp.join(config['dataset_path'], dataset_name, scene)
new_config['landmark_path'] = osp.join(config['landmark_path'], dataset_name, scene)
new_config['n_cluster'] = scene_config[scene]['n_cluster']
new_config['cluster_mode'] = scene_config[scene]['cluster_mode']
new_config['cluster_method'] = scene_config[scene]['cluster_method']
new_config['gt_pose_path'] = scene_config[scene]['gt_pose_path']
new_config['image_path_prefix'] = scene_config[scene]['image_path_prefix']
sub_map = SingleMap3D(config=new_config,
matcher=self.matcher,
with_compress=config['localization']['with_compress'],
start_sid=n_class)
self.sub_maps[dataset_name + '/' + scene] = sub_map
n_scene_class = scene_config[scene]['n_cluster']
                self.sid_scene_name.extend([dataset_name + '/' + scene] * n_scene_class)
                self.scene_name_start_sid[dataset_name + '/' + scene] = n_class
                n_class += n_scene_class
        print('Loaded {} sub-maps from {} datasets'.format(len(self.sub_maps), len(datasets)))
def run(self, q_frame: Frame):
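        """Localize a query frame against the multi-map.

        Candidate segments from process_segmentations are tried in order;
        each one predicts a scene whose sub-map is queried until a result
        passes verify_and_update, after which the pose is optionally refined.

        :return: True on success, False otherwise
        """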
show = self.loc_config['show']
seg_color = generate_color_dic(n_seg=2000)
if show:
cv2.namedWindow('loc', cv2.WINDOW_NORMAL)
q_loc_segs = self.process_segmentations(segs=torch.from_numpy(q_frame.segmentations),
topk=self.loc_config['seg_k'])
q_scene_name = q_frame.scene_name
q_name = q_frame.name
q_full_name = osp.join(q_scene_name, q_name)
        q_loc_sids = {sid: (kpt_ids, score) for (sid, kpt_ids, score) in q_loc_segs}
query_sids = list(q_loc_sids.keys())
for i, sid in enumerate(query_sids):
t_start = time.time()
q_kpt_ids = q_loc_sids[sid][0]
print(q_scene_name, q_name, sid)
            sid = sid - 1  # segment ids are 1-based (0 is background); shift to 0-based
pred_scene_name = self.sid_scene_name[sid]
            start_seg_id = self.scene_name_start_sid[pred_scene_name]
            pred_sid_in_sub_scene = sid - start_seg_id
pred_sub_map = self.sub_maps[pred_scene_name]
pred_image_path_prefix = pred_sub_map.image_path_prefix
print('pred/gt scene: {:s}, {:s}, sid: {:d}'.format(pred_scene_name, q_scene_name, pred_sid_in_sub_scene))
print('{:s}/{:s}, pred: {:s}, sid: {:d}, order: {:d}'.format(q_scene_name, q_name, pred_scene_name, sid,
i))
if (q_kpt_ids.shape[0] >= self.loc_config['min_kpts']
and self.semantic_matching
and pred_sub_map.check_semantic_consistency(q_frame=q_frame,
sid=pred_sid_in_sub_scene,
overlap_ratio=0.5)):
semantic_matching = True
else:
q_kpt_ids = np.arange(q_frame.keypoints.shape[0])
semantic_matching = False
print_text = f'Semantic matching - {semantic_matching}! Query kpts {q_kpt_ids.shape[0]} for {i}th seg {sid}'
print(print_text)
ret = pred_sub_map.localize_with_ref_frame(q_frame=q_frame,
q_kpt_ids=q_kpt_ids,
sid=pred_sid_in_sub_scene,
semantic_matching=semantic_matching)
            q_frame.time_loc = q_frame.time_loc + time.time() - t_start  # accumulate localization time
if show:
reference_frame = pred_sub_map.reference_frames[ret['reference_frame_id']]
ref_img = cv2.imread(osp.join(self.config['dataset_path'], pred_scene_name, pred_image_path_prefix,
reference_frame.name))
q_img_seg = vis_seg_point(img=q_frame.image, kpts=q_frame.keypoints[q_kpt_ids, :2],
segs=q_frame.seg_ids[q_kpt_ids] + 1,
seg_color=seg_color)
matched_points3D_ids = ret['matched_point3D_ids']
ref_sids = np.array([pred_sub_map.point3Ds[v].seg_id for v in matched_points3D_ids]) + \
self.scene_name_start_sid[pred_scene_name] + 1 # start from 1 as bg is 0
ref_img_seg = vis_seg_point(img=ref_img, kpts=ret['matched_ref_keypoints'], segs=ref_sids,
seg_color=seg_color)
q_matched_kpts = ret['matched_keypoints']
ref_matched_kpts = ret['matched_ref_keypoints']
                img_loc_matching = plot_matches(img1=q_img_seg, img2=ref_img_seg,
                                                pts1=q_matched_kpts, pts2=ref_matched_kpts,
                                                inliers=np.ones(q_matched_kpts.shape[0], dtype=bool),
                                                radius=9, line_thickness=3)
q_frame.image_matching_tmp = img_loc_matching
q_frame.reference_frame_name_tmp = osp.join(self.config['dataset_path'],
pred_scene_name,
pred_image_path_prefix,
reference_frame.name)
q_ref_img_matching = np.hstack([resize_img(q_img_seg, nh=512),
resize_img(ref_img_seg, nh=512),
resize_img(img_loc_matching, nh=512)])
ret['order'] = i
ret['matched_scene_name'] = pred_scene_name
if not ret['success']:
num_matches = ret['matched_keypoints'].shape[0]
num_inliers = ret['num_inliers']
print_text = f'Localization failed with {num_matches}/{q_kpt_ids.shape[0]} matches and {num_inliers} inliers, order {i}'
print(print_text)
if show:
show_text = 'FAIL! order: {:d}/{:d}-{:d}/{:d}'.format(i, len(q_loc_segs),
num_matches,
q_kpt_ids.shape[0])
q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
radius=9 + 2, thickness=2)
q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
thickness=2, lineType=cv2.LINE_AA)
q_frame.image_inlier_tmp = q_img_inlier
q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
cv2.imshow('loc', q_img_loc)
key = cv2.waitKey(self.loc_config['show_time'])
if key == ord('q'):
cv2.destroyAllWindows()
exit(0)
continue
if show:
q_err, t_err = q_frame.compute_pose_error()
num_matches = ret['matched_keypoints'].shape[0]
num_inliers = ret['num_inliers']
show_text = 'order: {:d}/{:d}, k/m/i: {:d}/{:d}/{:d}'.format(
i, len(q_loc_segs), q_kpt_ids.shape[0], num_matches, num_inliers)
q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
radius=9 + 2, thickness=2)
q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
thickness=2, lineType=cv2.LINE_AA)
show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err)
q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
thickness=2, lineType=cv2.LINE_AA)
q_frame.image_inlier_tmp = q_img_inlier
q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
cv2.imshow('loc', q_img_loc)
key = cv2.waitKey(self.loc_config['show_time'])
if key == ord('q'):
cv2.destroyAllWindows()
exit(0)
            if self.verify_and_update(q_frame=q_frame, ret=ret):
                break
if q_frame.tracking_status is None:
print('Failed to find a proper reference frame.')
return False
# do refinement
if not self.do_refinement:
return True
else:
t_start = time.time()
pred_sub_map = self.sub_maps[q_frame.matched_scene_name]
if q_frame.tracking_status is True and np.sum(q_frame.matched_inliers) >= 64:
ret = pred_sub_map.refine_pose(q_frame=q_frame, refinement_method=self.loc_config['refinement_method'])
else:
ret = pred_sub_map.refine_pose(q_frame=q_frame,
refinement_method='matching') # do not trust the pose for projection
q_frame.time_ref = time.time() - t_start
inlier_mask = np.array(ret['inliers'])
q_frame.qvec = ret['qvec']
q_frame.tvec = ret['tvec']
q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask]
q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask]
q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask]
q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask]
q_frame.matched_sids = ret['matched_sids'][inlier_mask]
q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask]
q_frame.refinement_reference_frame_ids = ret['refinement_reference_frame_ids']
q_frame.reference_frame_id = ret['reference_frame_id']
q_err, t_err = q_frame.compute_pose_error()
ref_full_name = q_frame.matched_scene_name + '/' + pred_sub_map.reference_frames[
q_frame.reference_frame_id].name
        print_text = 'Localization of {:s} succeeded with {:d}/{:d} inliers, ref_name: {:s}, order: {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
            q_full_name, ret['num_inliers'], len(ret['inliers']), ref_full_name, q_frame.matched_order, q_err, t_err)
print(print_text)
if show:
q_err, t_err = q_frame.compute_pose_error()
num_matches = ret['matched_keypoints'].shape[0]
num_inliers = ret['num_inliers']
show_text = 'Ref:{:d}/{:d},r_err:{:.2f}/t_err:{:.2f}'.format(num_matches, num_inliers, q_err,
t_err)
q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 130),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
thickness=2, lineType=cv2.LINE_AA)
q_frame.image_inlier = q_img_inlier
return True
def verify_and_update(self, q_frame: Frame, ret: dict):
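        """Keep ret on the query frame if it beats the best result seen so
        far, then check it against the min_inliers threshold.

        :return: True if ret has enough inliers, False otherwise
        """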
num_matches = ret['matched_keypoints'].shape[0]
num_inliers = ret['num_inliers']
if q_frame.matched_keypoints is None or np.sum(q_frame.matched_inliers) < num_inliers:
self.update_query_frame(q_frame=q_frame, ret=ret)
q_err, t_err = q_frame.compute_pose_error(pred_qvec=ret['qvec'], pred_tvec=ret['tvec'])
        if num_inliers < self.loc_config['min_inliers']:
            print_text = 'Failed due to insufficient inliers ({:d}), order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
                ret['num_inliers'], ret['order'], q_err, t_err)
print(print_text)
q_frame.tracking_status = False
return False
else:
            print_text = 'Succeeded! Found {}/{} 2D-3D inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
                num_inliers, num_matches, ret['order'], q_err, t_err)
print(print_text)
q_frame.tracking_status = True
return True
def update_query_frame(self, q_frame, ret):
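        """Copy pose, matches, and cached visualizations from ret onto the query frame."""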
q_frame.matched_scene_name = ret['matched_scene_name']
q_frame.reference_frame_id = ret['reference_frame_id']
q_frame.qvec = ret['qvec']
q_frame.tvec = ret['tvec']
q_frame.matched_keypoints = ret['matched_keypoints']
q_frame.matched_keypoint_ids = ret['matched_keypoint_ids']
q_frame.matched_xyzs = ret['matched_xyzs']
q_frame.matched_point3D_ids = ret['matched_point3D_ids']
q_frame.matched_sids = ret['matched_sids']
q_frame.matched_inliers = np.array(ret['inliers'])
q_frame.matched_order = ret['order']
if q_frame.image_inlier_tmp is not None:
q_frame.image_inlier = deepcopy(q_frame.image_inlier_tmp)
if q_frame.image_matching_tmp is not None:
q_frame.image_matching = deepcopy(q_frame.image_matching_tmp)
if q_frame.reference_frame_name_tmp is not None:
q_frame.reference_frame_name = q_frame.reference_frame_name_tmp
def process_segmentations(self, segs, topk=10):
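        """Rank predicted segments for localization.

        :param segs: [N, C] per-keypoint segmentation scores (torch tensor)
        :param topk: maximum number of segments to return
        :return: list of up to topk tuples (sid, keypoint_ids, score), best
            first; sid 0 (background) and previously used sids are skipped
        """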
        pred_values, pred_ids = torch.topk(segs, k=segs.shape[-1], largest=True, dim=-1)  # full descending sort per keypoint: [N, C]
pred_values = pred_values.numpy()
pred_ids = pred_ids.numpy()
out = []
used_sids = []
for k in range(segs.shape[-1]):
values_k = pred_values[:, k]
ids_k = pred_ids[:, k]
uids = np.unique(ids_k)
out_k = []
for sid in uids:
if sid == 0:
continue
if sid in used_sids:
continue
used_sids.append(sid)
ids = np.where(ids_k == sid)[0]
score = np.mean(values_k[ids])
out_k.append((ids.shape[0], sid, ids, score))
out_k = sorted(out_k, key=lambda item: item[0], reverse=True)
for v in out_k:
out.append((v[1], v[2], v[3])) # [sid, ids, score]
if len(out) >= topk:
return out
return out
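if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: the config path
    # below is hypothetical, and the full config schema and Frame construction
    # live elsewhere in the PRAM codebase.
    with open('configs/config_localization.yaml', 'r') as f:  # hypothetical path
        config = yaml.load(f, Loader=yaml.Loader)
    multi_map = MultiMap3D(config=config, viewer=None, save_dir=None)
    # q_frame would be a localization.frame.Frame with image, keypoints,
    # segmentations, seg_ids, scene_name, and name already populated:
    # success = multi_map.run(q_frame)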