# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> multimap3d
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 04/03/2024 13:47
=================================================='''
import numpy as np
import os
import os.path as osp
import sys
import time
import cv2
import torch
import yaml
from copy import deepcopy
from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches
from localization.base_model import dynamic_load
import localization.matchers as matchers
from localization.match_features_batch import confs as matcher_confs
from nets.gm import GM
from tools.common import resize_img
from localization.singlemap3d import SingleMap3D
from localization.frame import Frame


class MultiMap3D:
    """Set of per-scene ``SingleMap3D`` sub-maps queried through global segment ids.

    Every scene contributes ``n_cluster`` consecutive global segment ids.
    ``sid_scene_name`` maps a 0-based global id to its scene key
    (``'<dataset>/<scene>'``) and ``scene_name_start_sid`` stores the first
    global id of each scene.
    """

    def __init__(self, config, viewer=None, save_dir=None):
        """Create the matcher and load all configured sub-maps.

        Args:
            config: configuration mapping. This class reads the keys
                ``'localization'``, ``'dataset'``, ``'config_path'``,
                ``'dataset_path'`` and ``'landmark_path'``.
            viewer: optional viewer object; it is only stored on the instance.
            save_dir: optional output directory, created when it is missing.
        """
        self.config = config
        self.save_dir = save_dir
        self.viewer = viewer

        self.scenes = []
        self.sid_scene_name = []
        self.sub_maps = {}
        self.scene_name_start_sid = {}

        self.loc_config = config['localization']
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)

        self.matching_method = self.loc_config['matching_method']
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Model = dynamic_load(matchers, self.matching_method)
        self.matcher = Model(matcher_confs[self.matching_method]['model']).eval().to(device)

        # The matcher must exist first: it is handed to every sub-map.
        self.initialize_map(config=config)

        # options
        self.do_refinement = self.loc_config['do_refinement']
        self.refinement_method = self.loc_config['refinement_method']
        self.semantic_matching = self.loc_config['semantic_matching']
        self.pre_filtering_th = self.loc_config['pre_filtering_th']
        self.do_pre_filtering = self.pre_filtering_th > 0

    def initialize_map(self, config):
        """Load one ``SingleMap3D`` per scene listed in each dataset's YAML file.

        For every dataset name in ``config['dataset']`` the file
        ``<config_path>/<name>.yaml`` is read; each entry of its ``'scenes'``
        list becomes a sub-map stored under the key ``'<name>/<scene>'``.

        Args:
            config: the same configuration mapping passed to ``__init__``.
        """
        n_class = 0
        datasets = config['dataset']
        per_scene_keys = ('n_cluster', 'cluster_mode', 'cluster_method', 'gt_pose_path', 'image_path_prefix')

        for dataset_name in datasets:
            config_path = osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name))
            with open(config_path, 'r') as f:
                # NOTE(review): yaml.Loader can construct arbitrary Python objects.
                # Switch to yaml.safe_load if these files may come from an
                # untrusted source and do not rely on Python-specific tags.
                scene_config = yaml.load(f, Loader=yaml.Loader)

            for scene in scene_config['scenes']:
                scene_key = dataset_name + '/' + scene
                self.scenes.append(scene_key)

                new_config = deepcopy(config)
                new_config['dataset_path'] = osp.join(config['dataset_path'], dataset_name, scene)
                new_config['landmark_path'] = osp.join(config['landmark_path'], dataset_name, scene)
                for key in per_scene_keys:
                    new_config[key] = scene_config[scene][key]

                self.sub_maps[scene_key] = SingleMap3D(config=new_config,
                                                       matcher=self.matcher,
                                                       with_compress=config['localization']['with_compress'],
                                                       start_sid=n_class)

                # Reserve n_cluster consecutive global segment ids for this scene.
                n_scene_class = scene_config[scene]['n_cluster']
                self.sid_scene_name.extend([scene_key] * n_scene_class)
                self.scene_name_start_sid[scene_key] = n_class
                n_class = n_class + n_scene_class

        print('Load {} sub_maps from {} datasets'.format(len(self.sub_maps), len(datasets)))

    @staticmethod
    def _put_text(img, text, org):
        """Draw ``text`` on ``img`` at ``org`` with the overlay font settings used in this class."""
        return cv2.putText(img=img, text=text, org=org,
                           fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                           color=(0, 0, 255), thickness=2, lineType=cv2.LINE_AA)

    def _show_and_poll(self, image):
        """Show ``image`` in the 'loc' window and terminate the process when 'q' is pressed."""
        cv2.imshow('loc', image)
        key = cv2.waitKey(self.loc_config['show_time'])
        if key == ord('q'):
            cv2.destroyAllWindows()
            sys.exit(0)

    def run(self, q_frame: Frame):
        """Localize ``q_frame`` by trying its predicted segments in ranking order.

        Candidates come from ``process_segmentations``. Each one is localized
        against the sub-map that owns the segment, and the loop stops at the
        first candidate accepted by ``verify_and_update``. When
        ``self.do_refinement`` is set, the pose is then refined by the matched
        sub-map and the frame is overwritten with the refined inlier matches.

        Args:
            q_frame: query frame; it is updated in place.

        Returns:
            ``False`` if ``q_frame.tracking_status`` is still ``None`` after all
            candidates were tried, ``True`` otherwise.
        """
        show = self.loc_config['show']
        seg_color = generate_color_dic(n_seg=2000)  # only consumed by the visualisation below
        if show:
            cv2.namedWindow('loc', cv2.WINDOW_NORMAL)

        q_loc_segs = self.process_segmentations(segs=torch.from_numpy(q_frame.segmentations),
                                                topk=self.loc_config['seg_k'])

        q_scene_name = q_frame.scene_name
        q_name = q_frame.name
        q_full_name = osp.join(q_scene_name, q_name)

        # 1-based global segment id -> (query keypoint ids, score), in ranking order.
        q_loc_sids = {v[0]: (v[1], v[2]) for v in q_loc_segs}

        q_img_inlier = None  # last inlier visualisation, reused by the refinement overlay
        for i, (raw_sid, (q_kpt_ids, _)) in enumerate(q_loc_sids.items()):
            t_start = time.time()
            print(q_scene_name, q_name, raw_sid)
            sid = raw_sid - 1  # predictions are 1-based (0 is skipped); the lookup tables are 0-based

            pred_scene_name = self.sid_scene_name[sid]
            pred_sid_in_sub_scene = sid - self.scene_name_start_sid[pred_scene_name]
            pred_sub_map = self.sub_maps[pred_scene_name]
            pred_image_path_prefix = pred_sub_map.image_path_prefix

            print('pred/gt scene: {:s}, {:s}, sid: {:d}'.format(pred_scene_name, q_scene_name, pred_sid_in_sub_scene))
            print('{:s}/{:s}, pred: {:s}, sid: {:d}, order: {:d}'.format(q_scene_name, q_name, pred_scene_name, sid,
                                                                         i))

            if (q_kpt_ids.shape[0] >= self.loc_config['min_kpts'] and self.semantic_matching
                    and pred_sub_map.check_semantic_consistency(q_frame=q_frame,
                                                                sid=pred_sid_in_sub_scene,
                                                                overlap_ratio=0.5)):
                semantic_matching = True
            else:
                # Fall back to matching with every query keypoint.
                q_kpt_ids = np.arange(q_frame.keypoints.shape[0])
                semantic_matching = False
            print_text = f'Semantic matching - {semantic_matching}! Query kpts {q_kpt_ids.shape[0]} for {i}th seg {sid}'
            print(print_text)

            ret = pred_sub_map.localize_with_ref_frame(q_frame=q_frame,
                                                       q_kpt_ids=q_kpt_ids,
                                                       sid=pred_sid_in_sub_scene,
                                                       semantic_matching=semantic_matching)

            q_frame.time_loc = q_frame.time_loc + time.time() - t_start  # accumulate tracking time

            if show:
                reference_frame = pred_sub_map.reference_frames[ret['reference_frame_id']]
                ref_img_path = osp.join(self.config['dataset_path'], pred_scene_name, pred_image_path_prefix,
                                        reference_frame.name)
                ref_img = cv2.imread(ref_img_path)
                q_img_seg = vis_seg_point(img=q_frame.image, kpts=q_frame.keypoints[q_kpt_ids, :2],
                                          segs=q_frame.seg_ids[q_kpt_ids] + 1, seg_color=seg_color)
                matched_points3D_ids = ret['matched_point3D_ids']
                ref_sids = np.array([pred_sub_map.point3Ds[v].seg_id for v in matched_points3D_ids]) + \
                           self.scene_name_start_sid[pred_scene_name] + 1  # start from 1 as bg is 0
                ref_img_seg = vis_seg_point(img=ref_img, kpts=ret['matched_ref_keypoints'], segs=ref_sids,
                                            seg_color=seg_color)
                q_matched_kpts = ret['matched_keypoints']
                ref_matched_kpts = ret['matched_ref_keypoints']
                img_loc_matching = plot_matches(img1=q_img_seg, img2=ref_img_seg,
                                                pts1=q_matched_kpts, pts2=ref_matched_kpts,
                                                inliers=np.ones(q_matched_kpts.shape[0], dtype=bool),
                                                radius=9, line_thickness=3)

                q_frame.image_matching_tmp = img_loc_matching
                q_frame.reference_frame_name_tmp = ref_img_path
                q_ref_img_matching = np.hstack([resize_img(q_img_seg, nh=512),
                                                resize_img(ref_img_seg, nh=512),
                                                resize_img(img_loc_matching, nh=512)])

            ret['order'] = i
            ret['matched_scene_name'] = pred_scene_name

            if not ret['success']:
                num_matches = ret['matched_keypoints'].shape[0]
                num_inliers = ret['num_inliers']
                print_text = f'Localization failed with {num_matches}/{q_kpt_ids.shape[0]} matches and {num_inliers} inliers, order {i}'
                print(print_text)

                if show:
                    show_text = 'FAIL! order: {:d}/{:d}-{:d}/{:d}'.format(i, len(q_loc_segs),
                                                                          num_matches, q_kpt_ids.shape[0])
                    q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
                                              radius=9 + 2, thickness=2)
                    q_img_inlier = self._put_text(q_img_inlier, show_text, (30, 30))
                    q_frame.image_inlier_tmp = q_img_inlier
                    q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
                    self._show_and_poll(q_img_loc)
                continue

            if show:
                q_err, t_err = q_frame.compute_pose_error()
                num_matches = ret['matched_keypoints'].shape[0]
                num_inliers = ret['num_inliers']
                show_text = 'order: {:d}/{:d}, k/m/i: {:d}/{:d}/{:d}'.format(
                    i, len(q_loc_segs), q_kpt_ids.shape[0], num_matches, num_inliers)
                q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
                                          radius=9 + 2, thickness=2)
                q_img_inlier = self._put_text(q_img_inlier, show_text, (30, 30))
                show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err)
                q_img_inlier = self._put_text(q_img_inlier, show_text, (30, 80))
                q_frame.image_inlier_tmp = q_img_inlier

                q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
                self._show_and_poll(q_img_loc)

            if self.verify_and_update(q_frame=q_frame, ret=ret):
                break

        if q_frame.tracking_status is None:
            print('Failed to find a proper reference frame.')
            return False

        if not self.do_refinement:
            return True

        # do refinement
        t_start = time.time()
        pred_sub_map = self.sub_maps[q_frame.matched_scene_name]
        if q_frame.tracking_status is True and np.sum(q_frame.matched_inliers) >= 64:
            ret = pred_sub_map.refine_pose(q_frame=q_frame, refinement_method=self.loc_config['refinement_method'])
        else:
            ret = pred_sub_map.refine_pose(q_frame=q_frame,
                                           refinement_method='matching')  # do not trust the pose for projection

        q_frame.time_ref = time.time() - t_start

        # Force a boolean mask: an empty inlier list would otherwise become a
        # float64 array, which numpy rejects as an index.
        inlier_mask = np.asarray(ret['inliers'], dtype=bool)

        q_frame.qvec = ret['qvec']
        q_frame.tvec = ret['tvec']
        q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask]
        q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask]
        q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask]
        q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask]
        q_frame.matched_sids = ret['matched_sids'][inlier_mask]
        q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask]

        q_frame.refinement_reference_frame_ids = ret['refinement_reference_frame_ids']
        q_frame.reference_frame_id = ret['reference_frame_id']

        q_err, t_err = q_frame.compute_pose_error()
        ref_full_name = q_frame.matched_scene_name + '/' + pred_sub_map.reference_frames[
            q_frame.reference_frame_id].name
        print_text = 'Localization of {:s} success with inliers {:d}/{:d} with ref_name: {:s}, order: {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
            q_full_name, ret['num_inliers'], len(ret['inliers']), ref_full_name, q_frame.matched_order, q_err,
            t_err)
        print(print_text)

        # The overlay is drawn on the last image produced in the candidate loop.
        if show and q_img_inlier is not None:
            num_matches = ret['matched_keypoints'].shape[0]
            num_inliers = ret['num_inliers']
            show_text = 'Ref:{:d}/{:d},r_err:{:.2f}/t_err:{:.2f}'.format(num_matches, num_inliers, q_err, t_err)
            q_img_inlier = self._put_text(q_img_inlier, show_text, (30, 130))
            q_frame.image_inlier = q_img_inlier

        return True

    def verify_and_update(self, q_frame: Frame, ret: dict):
        """Keep the best result seen so far on ``q_frame`` and accept or reject ``ret``.

        The frame is overwritten with ``ret`` when it holds no matches yet or
        when ``ret`` has more inliers than the stored result. Acceptance only
        depends on ``ret['num_inliers']`` reaching
        ``self.loc_config['min_inliers']``.

        Args:
            q_frame: query frame, updated in place (including ``tracking_status``).
            ret: localization result; must already contain ``'order'`` and
                ``'matched_scene_name'``.

        Returns:
            ``True`` if the result has enough inliers, ``False`` otherwise.
        """
        num_matches = ret['matched_keypoints'].shape[0]
        num_inliers = ret['num_inliers']
        if q_frame.matched_keypoints is None or np.sum(q_frame.matched_inliers) < num_inliers:
            self.update_query_frame(q_frame=q_frame, ret=ret)

        q_err, t_err = q_frame.compute_pose_error(pred_qvec=ret['qvec'], pred_tvec=ret['tvec'])

        if num_inliers < self.loc_config['min_inliers']:
            print_text = 'Failed due to insufficient {:d} inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
                ret['num_inliers'], ret['order'], q_err, t_err)
            print(print_text)
            q_frame.tracking_status = False
            return False

        print_text = 'Succeed! Find {}/{} 2D-3D inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
            num_inliers, num_matches, ret['order'], q_err, t_err)
        print(print_text)
        q_frame.tracking_status = True
        return True

    def update_query_frame(self, q_frame, ret):
        """Copy pose, matches and pending visualisations from ``ret`` onto ``q_frame``.

        The match arrays are stored unfiltered, i.e. outliers are kept;
        ``q_frame.matched_inliers`` holds the corresponding inlier flags.

        Args:
            q_frame: query frame, updated in place.
            ret: localization result with ``'order'`` and
                ``'matched_scene_name'`` already set.
        """
        q_frame.matched_scene_name = ret['matched_scene_name']
        q_frame.reference_frame_id = ret['reference_frame_id']
        q_frame.qvec = ret['qvec']
        q_frame.tvec = ret['tvec']

        q_frame.matched_keypoints = ret['matched_keypoints']
        q_frame.matched_keypoint_ids = ret['matched_keypoint_ids']
        q_frame.matched_xyzs = ret['matched_xyzs']
        q_frame.matched_point3D_ids = ret['matched_point3D_ids']
        q_frame.matched_sids = ret['matched_sids']
        q_frame.matched_inliers = np.array(ret['inliers'])
        q_frame.matched_order = ret['order']

        # Promote the temporary visualisations of this attempt to the kept ones.
        if q_frame.image_inlier_tmp is not None:
            q_frame.image_inlier = deepcopy(q_frame.image_inlier_tmp)
        if q_frame.image_matching_tmp is not None:
            q_frame.image_matching = deepcopy(q_frame.image_matching_tmp)
        if q_frame.reference_frame_name_tmp is not None:
            q_frame.reference_frame_name = q_frame.reference_frame_name_tmp

    def process_segmentations(self, segs, topk=10):
        """Rank the segment ids predicted for a frame.

        The scores in each row of ``segs`` are sorted in descending order.
        Rank by rank (best first), the ids found at that rank are collected,
        ordered by how many rows carry them at that rank (most first). Id 0 is
        skipped and every id is reported once, at the first rank where it
        occurs.

        Args:
            segs: 2-D ``torch.Tensor`` of scores, one row per keypoint and one
                column per segment id.
            topk: the list is returned as soon as it holds this many entries.

        Returns:
            List of ``(sid, ids, score)`` tuples: ``sid`` is the segment id as
            found in ``segs`` (no offset applied), ``ids`` the row indices
            having ``sid`` at that rank, and ``score`` the mean of their
            scores at that rank.
        """
        pred_values, pred_ids = torch.topk(segs, k=segs.shape[-1], largest=True, dim=-1)  # [N, C]
        pred_values = pred_values.numpy()
        pred_ids = pred_ids.numpy()

        out = []
        used_sids = set()
        for k in range(segs.shape[-1]):
            values_k = pred_values[:, k]
            ids_k = pred_ids[:, k]

            out_k = []
            for sid in np.unique(ids_k):
                if sid == 0 or sid in used_sids:
                    continue
                used_sids.add(sid)
                ids = np.where(ids_k == sid)[0]
                score = np.mean(values_k[ids])
                out_k.append((ids.shape[0], sid, ids, score))

            # Stable sort: ids with equal counts keep ascending-id order.
            out_k.sort(key=lambda item: item[0], reverse=True)
            for _, sid, ids, score in out_k:
                out.append((sid, ids, score))  # [sid, ids, score]
                if len(out) >= topk:
                    return out
        return out