Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: UTF-8 -*- | |
| '''================================================= | |
| @Project -> File pram -> loc_by_rec | |
| @IDE PyCharm | |
| @Author fx221@cam.ac.uk | |
| @Date 08/02/2024 15:26 | |
| ==================================================''' | |
| import torch | |
| from torch.autograd import Variable | |
| from localization.multimap3d import MultiMap3D | |
| from localization.frame import Frame | |
| import yaml, cv2, time | |
| import numpy as np | |
| import os.path as osp | |
| import threading | |
| import os | |
| from tqdm import tqdm | |
| from recognition.vis_seg import vis_seg_point, generate_color_dic | |
| from tools.metrics import compute_iou, compute_precision | |
| from localization.tracker import Tracker | |
| from localization.utils import read_query_info | |
| from localization.camera import Camera | |
| def loc_by_rec_eval(rec_model, loader, config, local_feat, img_transforms=None): | |
| n_epoch = int(config['weight_path'].split('.')[1]) | |
| save_fn = osp.join(config['localization']['save_path'], | |
| config['weight_path'].split('/')[0] + '_{:d}'.format(n_epoch) + '_{:d}'.format( | |
| config['feat_dim'])) | |
| tag = 'k{:d}_th{:d}_mm{:d}_mi{:d}'.format(config['localization']['seg_k'], config['localization']['threshold'], | |
| config['localization']['min_matches'], | |
| config['localization']['min_inliers']) | |
| if config['localization']['do_refinement']: | |
| tag += '_op{:d}'.format(config['localization']['covisibility_frame']) | |
| if config['localization']['with_compress']: | |
| tag += '_comp' | |
| save_fn = save_fn + '_' + tag | |
| save = config['localization']['save'] | |
| save = config['localization']['save'] | |
| if save: | |
| save_dir = save_fn | |
| os.makedirs(save_dir, exist_ok=True) | |
| else: | |
| save_dir = None | |
| seg_color = generate_color_dic(n_seg=2000) | |
| dataset_path = config['dataset_path'] | |
| show = config['localization']['show'] | |
| if show: | |
| cv2.namedWindow('img', cv2.WINDOW_NORMAL) | |
| locMap = MultiMap3D(config=config, save_dir=None) | |
| # start tracker | |
| mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config) | |
| dataset_name = config['dataset'][0] | |
| all_scene_query_info = {} | |
| with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f: | |
| scene_config = yaml.load(f, Loader=yaml.Loader) | |
| scenes = scene_config['scenes'] | |
| for scene in scenes: | |
| query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path']) | |
| query_info = read_query_info(query_fn=query_path) | |
| all_scene_query_info[dataset_name + '/' + scene] = query_info | |
| # print(scene, query_info.keys()) | |
| tracking = False | |
| full_log = '' | |
| failed_cases = [] | |
| success_cases = [] | |
| poses = {} | |
| err_ths_cnt = [0, 0, 0, 0] | |
| seg_results = {} | |
| time_results = { | |
| 'feat': [], | |
| 'rec': [], | |
| 'loc': [], | |
| 'ref': [], | |
| 'total': [], | |
| } | |
| n_total = 0 | |
| loc_scene_names = config['localization']['loc_scene_name'] | |
| # loader = loader[8990:] | |
| for bid, pred in tqdm(enumerate(loader), total=len(loader)): | |
| pred = loader[bid] | |
| image_name = pred['file_name'] # [0] | |
| scene_name = pred['scene_name'] # [0] # dataset_scene | |
| if len(loc_scene_names) > 0: | |
| skip = True | |
| for loc_scene in loc_scene_names: | |
| if scene_name.find(loc_scene) > 0: | |
| skip = False | |
| break | |
| if skip: | |
| continue | |
| with torch.no_grad(): | |
| for k in pred: | |
| if k.find('name') >= 0: | |
| continue | |
| if k != 'image0' and k != 'image1' and k != 'depth0' and k != 'depth1': | |
| if type(pred[k]) == np.ndarray: | |
| pred[k] = Variable(torch.from_numpy(pred[k]).float().cuda())[None] | |
| elif type(pred[k]) == torch.Tensor: | |
| pred[k] = Variable(pred[k].float().cuda()) | |
| elif type(pred[k]) == list: | |
| continue | |
| else: | |
| pred[k] = Variable(torch.stack(pred[k]).float().cuda()) | |
| print('scene: ', scene_name, image_name) | |
| n_total += 1 | |
| with torch.no_grad(): | |
| img = pred['image'] | |
| while isinstance(img, list): | |
| img = img[0] | |
| new_im = torch.from_numpy(img).permute(2, 0, 1).cuda().float() | |
| if img_transforms is not None: | |
| new_im = img_transforms(new_im)[None] | |
| else: | |
| new_im = new_im[None] | |
| img = (img * 255).astype(np.uint8) | |
| fn = image_name | |
| camera_model, width, height, params = all_scene_query_info[scene_name][fn] | |
| camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params) | |
| curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=scene_name) | |
| gt_sub_map = locMap.sub_maps[curr_frame.scene_name] | |
| if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys(): | |
| curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec'] | |
| curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec'] | |
| t_start = time.time() | |
| encoder_out = local_feat.extract_local_global(data={'image': new_im}, | |
| config= | |
| { | |
| # 'min_keypoints': 128, | |
| 'max_keypoints': config['eval_max_keypoints'], | |
| } | |
| ) | |
| t_feat = time.time() - t_start | |
| # global_descriptors_cuda = encoder_out['global_descriptors'] | |
| # scores_cuda = encoder_out['scores'][0][None] | |
| # kpts_cuda = encoder_out['keypoints'][0][None] | |
| # descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1) | |
| sparse_scores = pred['scores'] | |
| sparse_descs = pred['descriptors'] | |
| sparse_kpts = pred['keypoints'] | |
| gt_seg = pred['gt_seg'] | |
| curr_frame.add_keypoints(keypoints=np.hstack([sparse_kpts[0].cpu().numpy(), | |
| sparse_scores[0].cpu().numpy().reshape(-1, 1)]), | |
| descriptors=sparse_descs[0].cpu().numpy()) | |
| curr_frame.time_feat = t_feat | |
| t_start = time.time() | |
| _, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'], | |
| semi_descs=encoder_out['mid_features'], | |
| # kpts=kpts_cuda[0], | |
| kpts=sparse_kpts[0], | |
| norm_desc=config['norm_desc']) | |
| rec_out = rec_model({'scores': sparse_scores, | |
| 'seg_descriptors': seg_descriptors[None].permute(0, 2, 1), | |
| 'keypoints': sparse_kpts, | |
| 'image': new_im}) | |
| t_rec = time.time() - t_start | |
| curr_frame.time_rec = t_rec | |
| pred = { | |
| # 'scores': scores_cuda, | |
| # 'keypoints': kpts_cuda, | |
| # 'descriptors': descriptors_cuda, | |
| # 'global_descriptors': global_descriptors_cuda, | |
| 'image_size': np.array([img.shape[1], img.shape[0]])[None], | |
| } | |
| pred = {**pred, **rec_out} | |
| pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C] | |
| pred_seg = pred_seg[0].cpu().numpy() | |
| kpts = sparse_kpts[0].cpu().numpy() | |
| img_pred_seg = vis_seg_point(img=img, kpts=kpts, segs=pred_seg, seg_color=seg_color, radius=9) | |
| show_text = 'kpts: {:d}'.format(kpts.shape[0]) | |
| img_pred_seg = cv2.putText(img=img_pred_seg, text=show_text, | |
| org=(50, 30), | |
| fontFace=cv2.FONT_HERSHEY_SIMPLEX, | |
| fontScale=1, color=(0, 0, 255), | |
| thickness=2, lineType=cv2.LINE_AA) | |
| curr_frame.image_rec = img_pred_seg | |
| if show: | |
| cv2.imshow('img', img) | |
| key = cv2.waitKey(1) | |
| if key == ord('q'): | |
| exit(0) | |
| elif key == ord('s'): | |
| show_time = -1 | |
| elif key == ord('c'): | |
| show_time = 1 | |
| segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C] | |
| curr_frame.add_segmentations(segmentations=segmentations, | |
| filtering_threshold=config['localization']['pre_filtering_th']) | |
| # Step1: do tracker first | |
| success = not mTracker.lost and tracking | |
| if success: | |
| success = mTracker.run(frame=curr_frame) | |
| if not success: | |
| success = locMap.run(q_frame=curr_frame) | |
| if success: | |
| curr_frame.update_point3ds() | |
| if tracking: | |
| mTracker.lost = False | |
| mTracker.last_frame = curr_frame | |
| # ''' | |
| pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C] | |
| pred_seg = pred_seg[0].cpu().numpy() | |
| gt_seg = gt_seg[0].cpu().numpy() | |
| iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=pred_seg.shape[0], | |
| ignored_ids=[0]) # 0 - background | |
| prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0]) | |
| kpts = sparse_kpts[0].cpu().numpy() | |
| if scene not in seg_results.keys(): | |
| seg_results[scene] = { | |
| 'day': { | |
| 'prec': [], | |
| 'iou': [], | |
| 'kpts': [], | |
| }, | |
| 'night': { | |
| 'prec': [], | |
| 'iou': [], | |
| 'kpts': [], | |
| } | |
| } | |
| if fn.find('night') >= 0: | |
| seg_results[scene]['night']['prec'].append(prec) | |
| seg_results[scene]['night']['iou'].append(iou) | |
| seg_results[scene]['night']['kpts'].append(kpts.shape[0]) | |
| else: | |
| seg_results[scene]['day']['prec'].append(prec) | |
| seg_results[scene]['day']['iou'].append(iou) | |
| seg_results[scene]['day']['kpts'].append(kpts.shape[0]) | |
| print_text = 'name: {:s}, kpts: {:d}, iou: {:.3f}, prec: {:.3f}'.format(fn, kpts.shape[0], iou, | |
| prec) | |
| print(print_text) | |
| # ''' | |
| t_feat = curr_frame.time_feat | |
| t_rec = curr_frame.time_rec | |
| t_loc = curr_frame.time_loc | |
| t_ref = curr_frame.time_ref | |
| t_total = t_feat + t_rec + t_loc + t_ref | |
| time_results['feat'].append(t_feat) | |
| time_results['rec'].append(t_rec) | |
| time_results['loc'].append(t_loc) | |
| time_results['ref'].append(t_ref) | |
| time_results['total'].append(t_total) | |
| poses[scene + '/' + fn] = (curr_frame.qvec, curr_frame.tvec) | |
| q_err, t_err = curr_frame.compute_pose_error() | |
| if q_err <= 5 and t_err <= 0.05: | |
| err_ths_cnt[0] = err_ths_cnt[0] + 1 | |
| if q_err <= 2 and t_err <= 0.25: | |
| err_ths_cnt[1] = err_ths_cnt[1] + 1 | |
| if q_err <= 5 and t_err <= 0.5: | |
| err_ths_cnt[2] = err_ths_cnt[2] + 1 | |
| if q_err <= 10 and t_err <= 5: | |
| err_ths_cnt[3] = err_ths_cnt[3] + 1 | |
| if success: | |
| success_cases.append(scene + '/' + fn) | |
| print_text = 'qname: {:s} localization success {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( | |
| scene + '/' + fn, len(success_cases), n_total, q_err, t_err, err_ths_cnt[0], | |
| err_ths_cnt[1], | |
| err_ths_cnt[2], | |
| err_ths_cnt[3], | |
| n_total, | |
| t_feat, t_rec, t_loc, t_ref, t_total | |
| ) | |
| else: | |
| failed_cases.append(scene + '/' + fn) | |
| print_text = 'qname: {:s} localization fail {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( | |
| scene + '/' + fn, len(failed_cases), n_total, q_err, t_err, err_ths_cnt[0], | |
| err_ths_cnt[1], | |
| err_ths_cnt[2], | |
| err_ths_cnt[3], | |
| n_total, t_feat, t_rec, t_loc, t_ref, t_total) | |
| print(print_text) | |