Spaces:
Running
Running
# -*- coding: UTF-8 -*- | |
'''================================================= | |
@Project -> File pram -> loc_by_rec | |
@IDE PyCharm | |
@Author fx221@cam.ac.uk | |
@Date 08/02/2024 15:26 | |
==================================================''' | |
import torch | |
from torch.autograd import Variable | |
from localization.multimap3d import MultiMap3D | |
from localization.frame import Frame | |
import yaml, cv2, time | |
import numpy as np | |
import os.path as osp | |
import threading | |
import os | |
from tqdm import tqdm | |
from recognition.vis_seg import vis_seg_point, generate_color_dic | |
from tools.metrics import compute_iou, compute_precision | |
from localization.tracker import Tracker | |
from localization.utils import read_query_info | |
from localization.camera import Camera | |
def loc_by_rec_eval(rec_model, loader, config, local_feat, img_transforms=None): | |
n_epoch = int(config['weight_path'].split('.')[1]) | |
save_fn = osp.join(config['localization']['save_path'], | |
config['weight_path'].split('/')[0] + '_{:d}'.format(n_epoch) + '_{:d}'.format( | |
config['feat_dim'])) | |
tag = 'k{:d}_th{:d}_mm{:d}_mi{:d}'.format(config['localization']['seg_k'], config['localization']['threshold'], | |
config['localization']['min_matches'], | |
config['localization']['min_inliers']) | |
if config['localization']['do_refinement']: | |
tag += '_op{:d}'.format(config['localization']['covisibility_frame']) | |
if config['localization']['with_compress']: | |
tag += '_comp' | |
save_fn = save_fn + '_' + tag | |
save = config['localization']['save'] | |
save = config['localization']['save'] | |
if save: | |
save_dir = save_fn | |
os.makedirs(save_dir, exist_ok=True) | |
else: | |
save_dir = None | |
seg_color = generate_color_dic(n_seg=2000) | |
dataset_path = config['dataset_path'] | |
show = config['localization']['show'] | |
if show: | |
cv2.namedWindow('img', cv2.WINDOW_NORMAL) | |
locMap = MultiMap3D(config=config, save_dir=None) | |
# start tracker | |
mTracker = Tracker(locMap=locMap, matcher=locMap.matcher, config=config) | |
dataset_name = config['dataset'][0] | |
all_scene_query_info = {} | |
with open(osp.join(config['config_path'], '{:s}.yaml'.format(dataset_name)), 'r') as f: | |
scene_config = yaml.load(f, Loader=yaml.Loader) | |
scenes = scene_config['scenes'] | |
for scene in scenes: | |
query_path = osp.join(config['dataset_path'], dataset_name, scene, scene_config[scene]['query_path']) | |
query_info = read_query_info(query_fn=query_path) | |
all_scene_query_info[dataset_name + '/' + scene] = query_info | |
# print(scene, query_info.keys()) | |
tracking = False | |
full_log = '' | |
failed_cases = [] | |
success_cases = [] | |
poses = {} | |
err_ths_cnt = [0, 0, 0, 0] | |
seg_results = {} | |
time_results = { | |
'feat': [], | |
'rec': [], | |
'loc': [], | |
'ref': [], | |
'total': [], | |
} | |
n_total = 0 | |
loc_scene_names = config['localization']['loc_scene_name'] | |
# loader = loader[8990:] | |
for bid, pred in tqdm(enumerate(loader), total=len(loader)): | |
pred = loader[bid] | |
image_name = pred['file_name'] # [0] | |
scene_name = pred['scene_name'] # [0] # dataset_scene | |
if len(loc_scene_names) > 0: | |
skip = True | |
for loc_scene in loc_scene_names: | |
if scene_name.find(loc_scene) > 0: | |
skip = False | |
break | |
if skip: | |
continue | |
with torch.no_grad(): | |
for k in pred: | |
if k.find('name') >= 0: | |
continue | |
if k != 'image0' and k != 'image1' and k != 'depth0' and k != 'depth1': | |
if type(pred[k]) == np.ndarray: | |
pred[k] = Variable(torch.from_numpy(pred[k]).float().cuda())[None] | |
elif type(pred[k]) == torch.Tensor: | |
pred[k] = Variable(pred[k].float().cuda()) | |
elif type(pred[k]) == list: | |
continue | |
else: | |
pred[k] = Variable(torch.stack(pred[k]).float().cuda()) | |
print('scene: ', scene_name, image_name) | |
n_total += 1 | |
with torch.no_grad(): | |
img = pred['image'] | |
while isinstance(img, list): | |
img = img[0] | |
new_im = torch.from_numpy(img).permute(2, 0, 1).cuda().float() | |
if img_transforms is not None: | |
new_im = img_transforms(new_im)[None] | |
else: | |
new_im = new_im[None] | |
img = (img * 255).astype(np.uint8) | |
fn = image_name | |
camera_model, width, height, params = all_scene_query_info[scene_name][fn] | |
camera = Camera(id=-1, model=camera_model, width=width, height=height, params=params) | |
curr_frame = Frame(image=img, camera=camera, id=0, name=fn, scene_name=scene_name) | |
gt_sub_map = locMap.sub_maps[curr_frame.scene_name] | |
if gt_sub_map.gt_poses is not None and curr_frame.name in gt_sub_map.gt_poses.keys(): | |
curr_frame.gt_qvec = gt_sub_map.gt_poses[curr_frame.name]['qvec'] | |
curr_frame.gt_tvec = gt_sub_map.gt_poses[curr_frame.name]['tvec'] | |
t_start = time.time() | |
encoder_out = local_feat.extract_local_global(data={'image': new_im}, | |
config= | |
{ | |
# 'min_keypoints': 128, | |
'max_keypoints': config['eval_max_keypoints'], | |
} | |
) | |
t_feat = time.time() - t_start | |
# global_descriptors_cuda = encoder_out['global_descriptors'] | |
# scores_cuda = encoder_out['scores'][0][None] | |
# kpts_cuda = encoder_out['keypoints'][0][None] | |
# descriptors_cuda = encoder_out['descriptors'][0][None].permute(0, 2, 1) | |
sparse_scores = pred['scores'] | |
sparse_descs = pred['descriptors'] | |
sparse_kpts = pred['keypoints'] | |
gt_seg = pred['gt_seg'] | |
curr_frame.add_keypoints(keypoints=np.hstack([sparse_kpts[0].cpu().numpy(), | |
sparse_scores[0].cpu().numpy().reshape(-1, 1)]), | |
descriptors=sparse_descs[0].cpu().numpy()) | |
curr_frame.time_feat = t_feat | |
t_start = time.time() | |
_, seg_descriptors = local_feat.sample(score_map=encoder_out['score_map'], | |
semi_descs=encoder_out['mid_features'], | |
# kpts=kpts_cuda[0], | |
kpts=sparse_kpts[0], | |
norm_desc=config['norm_desc']) | |
rec_out = rec_model({'scores': sparse_scores, | |
'seg_descriptors': seg_descriptors[None].permute(0, 2, 1), | |
'keypoints': sparse_kpts, | |
'image': new_im}) | |
t_rec = time.time() - t_start | |
curr_frame.time_rec = t_rec | |
pred = { | |
# 'scores': scores_cuda, | |
# 'keypoints': kpts_cuda, | |
# 'descriptors': descriptors_cuda, | |
# 'global_descriptors': global_descriptors_cuda, | |
'image_size': np.array([img.shape[1], img.shape[0]])[None], | |
} | |
pred = {**pred, **rec_out} | |
pred_seg = torch.max(pred['prediction'], dim=2)[1] # [B, N, C] | |
pred_seg = pred_seg[0].cpu().numpy() | |
kpts = sparse_kpts[0].cpu().numpy() | |
img_pred_seg = vis_seg_point(img=img, kpts=kpts, segs=pred_seg, seg_color=seg_color, radius=9) | |
show_text = 'kpts: {:d}'.format(kpts.shape[0]) | |
img_pred_seg = cv2.putText(img=img_pred_seg, text=show_text, | |
org=(50, 30), | |
fontFace=cv2.FONT_HERSHEY_SIMPLEX, | |
fontScale=1, color=(0, 0, 255), | |
thickness=2, lineType=cv2.LINE_AA) | |
curr_frame.image_rec = img_pred_seg | |
if show: | |
cv2.imshow('img', img) | |
key = cv2.waitKey(1) | |
if key == ord('q'): | |
exit(0) | |
elif key == ord('s'): | |
show_time = -1 | |
elif key == ord('c'): | |
show_time = 1 | |
segmentations = pred['prediction'][0] # .cpu().numpy() # [N, C] | |
curr_frame.add_segmentations(segmentations=segmentations, | |
filtering_threshold=config['localization']['pre_filtering_th']) | |
# Step1: do tracker first | |
success = not mTracker.lost and tracking | |
if success: | |
success = mTracker.run(frame=curr_frame) | |
if not success: | |
success = locMap.run(q_frame=curr_frame) | |
if success: | |
curr_frame.update_point3ds() | |
if tracking: | |
mTracker.lost = False | |
mTracker.last_frame = curr_frame | |
# ''' | |
pred_seg = torch.max(pred['prediction'], dim=-1)[1] # [B, N, C] | |
pred_seg = pred_seg[0].cpu().numpy() | |
gt_seg = gt_seg[0].cpu().numpy() | |
iou = compute_iou(pred=pred_seg, target=gt_seg, n_class=pred_seg.shape[0], | |
ignored_ids=[0]) # 0 - background | |
prec = compute_precision(pred=pred_seg, target=gt_seg, ignored_ids=[0]) | |
kpts = sparse_kpts[0].cpu().numpy() | |
if scene not in seg_results.keys(): | |
seg_results[scene] = { | |
'day': { | |
'prec': [], | |
'iou': [], | |
'kpts': [], | |
}, | |
'night': { | |
'prec': [], | |
'iou': [], | |
'kpts': [], | |
} | |
} | |
if fn.find('night') >= 0: | |
seg_results[scene]['night']['prec'].append(prec) | |
seg_results[scene]['night']['iou'].append(iou) | |
seg_results[scene]['night']['kpts'].append(kpts.shape[0]) | |
else: | |
seg_results[scene]['day']['prec'].append(prec) | |
seg_results[scene]['day']['iou'].append(iou) | |
seg_results[scene]['day']['kpts'].append(kpts.shape[0]) | |
print_text = 'name: {:s}, kpts: {:d}, iou: {:.3f}, prec: {:.3f}'.format(fn, kpts.shape[0], iou, | |
prec) | |
print(print_text) | |
# ''' | |
t_feat = curr_frame.time_feat | |
t_rec = curr_frame.time_rec | |
t_loc = curr_frame.time_loc | |
t_ref = curr_frame.time_ref | |
t_total = t_feat + t_rec + t_loc + t_ref | |
time_results['feat'].append(t_feat) | |
time_results['rec'].append(t_rec) | |
time_results['loc'].append(t_loc) | |
time_results['ref'].append(t_ref) | |
time_results['total'].append(t_total) | |
poses[scene + '/' + fn] = (curr_frame.qvec, curr_frame.tvec) | |
q_err, t_err = curr_frame.compute_pose_error() | |
if q_err <= 5 and t_err <= 0.05: | |
err_ths_cnt[0] = err_ths_cnt[0] + 1 | |
if q_err <= 2 and t_err <= 0.25: | |
err_ths_cnt[1] = err_ths_cnt[1] + 1 | |
if q_err <= 5 and t_err <= 0.5: | |
err_ths_cnt[2] = err_ths_cnt[2] + 1 | |
if q_err <= 10 and t_err <= 5: | |
err_ths_cnt[3] = err_ths_cnt[3] + 1 | |
if success: | |
success_cases.append(scene + '/' + fn) | |
print_text = 'qname: {:s} localization success {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( | |
scene + '/' + fn, len(success_cases), n_total, q_err, t_err, err_ths_cnt[0], | |
err_ths_cnt[1], | |
err_ths_cnt[2], | |
err_ths_cnt[3], | |
n_total, | |
t_feat, t_rec, t_loc, t_ref, t_total | |
) | |
else: | |
failed_cases.append(scene + '/' + fn) | |
print_text = 'qname: {:s} localization fail {:d}/{:d}, q_err: {:.2f}, t_err: {:.2f}, {:d}/{:d}/{:d}/{:d}/{:d}, time: {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}'.format( | |
scene + '/' + fn, len(failed_cases), n_total, q_err, t_err, err_ths_cnt[0], | |
err_ths_cnt[1], | |
err_ths_cnt[2], | |
err_ths_cnt[3], | |
n_total, t_feat, t_rec, t_loc, t_ref, t_total) | |
print(print_text) | |