from collections import OrderedDict

import cv2
import numpy as np
import torch
from kornia.geometry.conversions import convert_points_to_homogeneous
from kornia.geometry.epipolar import numeric
from loguru import logger

# `np.trapz` is deprecated in NumPy 2.x in favour of `np.trapezoid`; pick
# whichever the installed NumPy provides without touching the deprecated name
# when the new one exists.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


# --- METRICS ---


def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
    """Return the angular translation and rotation errors, in degrees.

    Args:
        T_0to1 (np.ndarray): ground-truth transform; its top-left 3x3 block is
            read as the rotation and ``T_0to1[:3, 3]`` as the translation.
        R (np.ndarray): estimated 3x3 rotation matrix.
        t (np.ndarray): estimated translation vector (3 elements).
        ignore_gt_t_thr (float): when the norm of the ground-truth translation
            is below this value the translation error is reported as 0.

    Returns:
        tuple: ``(t_err, R_err)``. ``t_err`` is the angle between ``t`` and the
        ground-truth translation folded into [0, 90]; ``R_err`` is the rotation
        angle of ``R.T @ R_gt``. If either translation has zero norm the angle
        is undefined and ``t_err`` is NaN unless the ignore threshold applies.
    """
    # angle error between 2 vectors
    t_gt = T_0to1[:3, 3]
    n = np.linalg.norm(t) * np.linalg.norm(t_gt)
    t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
    t_err = np.minimum(t_err, 180 - t_err)  # handle E ambiguity
    if np.linalg.norm(t_gt) < ignore_gt_t_thr:  # pure rotation is challenging
        t_err = 0

    # angle error between 2 rotation matrices
    R_gt = T_0to1[:3, :3]
    cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # handle numerical errors
    R_err = np.rad2deg(np.abs(np.arccos(cos)))

    return t_err, R_err


def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
    """Squared symmetric epipolar distance.

    This can be seen as a biased estimation of the reprojection error.

    Args:
        pts0 (torch.Tensor): [N, 2] pixel coordinates in image 0.
        pts1 (torch.Tensor): [N, 2] pixel coordinates in image 1.
        E (torch.Tensor): [3, 3] essential matrix.
        K0 (torch.Tensor): [3, 3] intrinsics of image 0.
        K1 (torch.Tensor): [3, 3] intrinsics of image 1.

    Returns:
        torch.Tensor: [N] squared symmetric epipolar distances, measured in
        normalized image coordinates.
    """
    # Normalize: subtract (K[0, 2], K[1, 2]) and divide by (K[0, 0], K[1, 1]).
    pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
    pts0 = convert_points_to_homogeneous(pts0)
    pts1 = convert_points_to_homogeneous(pts1)

    Ep0 = pts0 @ E.T  # [N, 3]
    p1Ep0 = torch.sum(pts1 * Ep0, -1)  # [N,]
    Etp1 = pts1 @ E  # [N, 3]

    d = p1Ep0**2 * (
        1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2)
        + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2)
    )  # N
    return d


def compute_symmetrical_epipolar_errors(data):
    """Compute per-match epipolar errors from the ground-truth pose.

    Reads ``T_0to1``, ``m_bids``, ``mkpts0_f``, ``mkpts1_f``, ``K0`` and ``K1``
    from ``data`` and builds one essential matrix per batch item as
    ``[t]_x @ R``.

    Update:
        data (dict):{"epi_errs": [M]}
    """
    Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
    E_mat = Tx @ data["T_0to1"][:, :3, :3]

    m_bids = data["m_bids"]
    pts0 = data["mkpts0_f"]
    pts1 = data["mkpts1_f"]

    epi_errs = []
    for bs in range(Tx.size(0)):
        mask = m_bids == bs  # matches belonging to batch item `bs`
        epi_errs.append(
            symmetric_epipolar_distance(
                pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
            )
        )
    epi_errs = torch.cat(epi_errs, dim=0)

    data.update({"epi_errs": epi_errs})


def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
    """Estimate the relative pose from matched keypoints with RANSAC.

    Args:
        kpts0 (np.ndarray): [N, 2] pixel keypoints in image 0.
        kpts1 (np.ndarray): [N, 2] pixel keypoints in image 1.
        K0 (np.ndarray): [3, 3] intrinsics of image 0.
        K1 (np.ndarray): [3, 3] intrinsics of image 1.
        thresh (float): RANSAC inlier threshold in pixels.
        conf (float): RANSAC confidence passed to ``cv2.findEssentialMat``.

    Returns:
        tuple | None: ``(R, t, inliers)`` for the essential-matrix candidate
        with the most inliers reported by ``cv2.recoverPose``, or ``None`` when
        fewer than 5 matches are given, no essential matrix is found, or no
        candidate yields any inlier.
    """
    if len(kpts0) < 5:
        return None

    # normalize keypoints
    kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]

    # normalize ransac threshold by the mean of all four focal lengths
    # (K[0, 0] and K[1, 1] of both cameras)
    ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[1, 1], K1[0, 0]])

    # compute pose with cv2 (identity intrinsics: points are already normalized)
    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC
    )
    if E is None:
        print("\nE is None while trying to recover pose.\n")
        return None

    # recover pose from E; E may stack several 3x3 candidates vertically
    best_num_inliers = 0
    ret = None
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            ret = (R, t[:, 0], mask.ravel() > 0)
            best_num_inliers = n

    return ret


def compute_pose_errors(data, config=None, ransac_thr=0.5, ransac_conf=0.99999):
    """Estimate the pose of every batch item and record its errors.

    When ``config`` is given, ``config.TRAINER.RANSAC_PIXEL_THR`` and
    ``config.TRAINER.RANSAC_CONF`` take precedence over ``ransac_thr`` and
    ``ransac_conf``. Items for which no pose can be estimated get infinite
    errors and an empty boolean inlier array.

    Update:
        data (dict):{
            "R_errs" List[float]: [N]
            "t_errs" List[float]: [N]
            "inliers" List[np.ndarray]: [N]
        }
    """
    pixel_thr = (
        config.TRAINER.RANSAC_PIXEL_THR if config is not None else ransac_thr
    )  # 0.5
    conf = config.TRAINER.RANSAC_CONF if config is not None else ransac_conf  # 0.99999
    data.update({"R_errs": [], "t_errs": [], "inliers": []})

    m_bids = data["m_bids"].cpu().numpy()
    pts0 = data["mkpts0_f"].cpu().numpy()
    pts1 = data["mkpts1_f"].cpu().numpy()
    K0 = data["K0"].cpu().numpy()
    K1 = data["K1"].cpu().numpy()
    T_0to1 = data["T_0to1"].cpu().numpy()

    for bs in range(K0.shape[0]):
        mask = m_bids == bs
        ret = estimate_pose(
            pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf
        )

        if ret is None:
            data["R_errs"].append(np.inf)
            data["t_errs"].append(np.inf)
            # `np.bool` was removed in NumPy 1.24; the builtin is equivalent.
            data["inliers"].append(np.array([]).astype(bool))
        else:
            R, t, inliers = ret
            t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
            data["R_errs"].append(R_err)
            data["t_errs"].append(t_err)
            data["inliers"].append(inliers)


# --- METRIC AGGREGATION ---


def error_auc(errors, thresholds):
    """Area under the recall-vs-error curve, normalized by each threshold.

    Args:
        errors (list): [N,]
        thresholds (list): positive error thresholds at which to evaluate.

    Returns:
        dict: ``{"auc@<thr>": value}`` with one entry per threshold.
    """
    errors = [0] + sorted(errors)
    recall = list(np.linspace(0, 1, len(errors)))

    aucs = []
    for thr in thresholds:
        last_index = np.searchsorted(errors, thr)
        # Extend the curve horizontally up to the threshold before integrating.
        y = recall[:last_index] + [recall[last_index - 1]]
        x = errors[:last_index] + [thr]
        aucs.append(_trapezoid(y, x) / thr)

    return {f"auc@{t}": auc for t, auc in zip(thresholds, aucs)}


def epidist_prec(errors, thresholds, ret_dict=False):
    """Mean fraction of matches whose epipolar error is below each threshold.

    Args:
        errors: iterable of per-item error arrays.
        thresholds (list): error thresholds.
        ret_dict (bool): return ``{"prec@<thr>": value}`` instead of a list.

    Returns:
        list | dict: one precision per threshold; items without matches count
        as precision 0, and an empty ``errors`` gives 0 for every threshold.
    """
    precs = []
    for thr in thresholds:
        prec_ = []
        for errs in errors:
            correct_mask = errs < thr
            prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
        precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
    if ret_dict:
        return {f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
    else:
        return precs


def aggregate_metrics(metrics, epi_err_thr=5e-4):
    """Aggregate metrics for the whole dataset:
    (This method should be called once per dataset)
    1. AUC of the pose error (angular) at the threshold [5, 10, 20]
    2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)

    Args:
        metrics (dict): must provide "identifiers", "R_errs", "t_errs" and
            "epi_errs", indexed consistently per item.
        epi_err_thr (float): epipolar error threshold for matching precision.

    Returns:
        dict: the ``auc@...`` entries merged with the ``prec@...`` entry.
    """
    # filter duplicates: one index is kept per distinct identifier (the last
    # occurrence, since later entries overwrite the stored index)
    unq_ids = OrderedDict(
        (iden, idx) for idx, iden in enumerate(metrics["identifiers"])
    )
    unq_ids = list(unq_ids.values())
    logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...")

    # pose auc: an item's pose error is the larger of its R and t errors
    angular_thresholds = [5, 10, 20]
    pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[
        unq_ids
    ]
    aucs = error_auc(pose_errors, angular_thresholds)  # (auc@5, auc@10, auc@20)

    # matching precision
    dist_thresholds = [epi_err_thr]
    precs = epidist_prec(
        np.array(metrics["epi_errs"], dtype=object)[unq_ids], dist_thresholds, True
    )  # (prec@err_thr)

    return {**aucs, **precs}