import warnings
import math

import numpy as np
import cv2
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
import kornia


def _normalize_keypoints(kpts, K):
    """Map (N, 2) pixel keypoints to normalized camera coordinates.

    Subtracts the principal point ``K[:2, 2]`` and applies the inverse of the
    upper-left 2x2 block of ``K``; the last row of ``K`` is not read.
    """
    K_inv = np.linalg.inv(K[:2, :2])
    return (K_inv @ (kpts - K[None, :2, 2]).T).T


def _select_best_pose(E, kpts0_n, kpts1_n, mask):
    """Pick the pose with the most inliers among stacked essential matrices.

    Args:
        E: ``(3k, 3)`` array holding ``k`` vertically stacked 3x3 candidates.
        kpts0_n, kpts1_n: keypoints already in normalized camera coordinates
            (an identity camera matrix is passed to OpenCV).
        mask: inlier mask handed to ``cv2.recoverPose``; it is read back after
            each call to build the returned inlier mask.

    Returns:
        ``(R, t, inliers)`` for the best candidate, or ``None`` when ``E`` is
        missing/empty or no candidate reports a positive inlier count.
    """
    if E is None or len(E) < 3:
        return None
    best_num_inliers = 0
    ret = None  # stays None if no candidate has inliers
    # np.split needs an integer number of sections: one per 3x3 candidate.
    for _E in np.split(E, len(E) // 3):
        # NOTE(review): ``mask`` is not rebound from the returned value, so the
        # inlier mask below relies on OpenCV writing into ``mask`` in place, and
        # the 4-value unpack must match the recoverPose overload these arguments
        # bind to; confirm against the installed OpenCV version.
        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t, mask.ravel() > 0)
    return ret


def recover_pose(E, kpts0, kpts1, K0, K1, mask):
    """Recover the relative pose from essential-matrix candidates.

    Args:
        E: ``(3k, 3)`` stacked essential-matrix candidates.
        kpts0, kpts1: ``(N, 2)`` pixel keypoints in image 0 / image 1.
        K0, K1: ``(3, 3)`` intrinsics of image 0 / image 1.
        mask: inlier mask passed through to ``cv2.recoverPose``.

    Returns:
        ``(R, t, inliers)`` for the candidate with most inliers, or ``None``
        if no candidate has any inlier.
    """
    kpts0_n = _normalize_keypoints(kpts0, K0)
    kpts1_n = _normalize_keypoints(kpts1, K1)
    return _select_best_pose(E, kpts0_n, kpts1_n, mask)


# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
# --- GEOMETRY ---
def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate relative pose with a calibrated essential-matrix RANSAC.

    Args:
        kpts0, kpts1: ``(N, 2)`` pixel keypoints in image 0 / image 1.
        K0, K1: ``(3, 3)`` intrinsics.
        norm_thresh: RANSAC threshold in *normalized* camera coordinates
            (keypoints are normalized before ``cv2.findEssentialMat``).
        conf: RANSAC confidence.

    Returns:
        ``(R, t, inliers)`` or ``None`` if fewer than 5 matches are given or
        no essential matrix / pose could be found.
    """
    if len(kpts0) < 5:
        return None
    kpts0 = _normalize_keypoints(kpts0, K0)
    kpts1 = _normalize_keypoints(kpts1, K1)
    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
    )
    return _select_best_pose(E, kpts0, kpts1, mask)


def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate relative pose via a fundamental matrix, then upgrade with K.

    Args:
        kpts0, kpts1: ``(N, 2)`` pixel keypoints in image 0 / image 1.
        K0, K1: ``(3, 3)`` intrinsics used to turn F into E = K1^T F K0.
        norm_thresh: passed as ``ransacReprojThreshold`` to
            ``cv2.findFundamentalMat`` on the *pixel* keypoints.
        conf: RANSAC confidence.

    Returns:
        ``(R, t, inliers)`` or ``None`` if fewer than 5 matches are given or
        no fundamental matrix / pose could be found.
    """
    if len(kpts0) < 5:
        return None
    # Named Fm so it does not shadow the module-level torch.nn.functional alias.
    Fm, mask = cv2.findFundamentalMat(
        kpts0,
        kpts1,
        ransacReprojThreshold=norm_thresh,
        confidence=conf,
        method=cv2.USAC_ACCURATE,
        maxIters=10000,
    )
    # Check before the matmul: K1.T @ None would raise TypeError.
    if Fm is None or Fm.size == 0:
        return None
    E = K1.T @ Fm @ K0
    kpts0_n = _normalize_keypoints(kpts0, K0)
    kpts1_n = _normalize_keypoints(kpts1, K1)
    return _select_best_pose(E, kpts0_n, kpts1_n, mask)


def unnormalize_coords(x_n, h, w):
    """Convert normalized coords in [-1, 1] to pixel coords of an h x w image."""
    x = torch.stack(
        (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    return x


def rotate_intrinsic(K, n):
    """Left-multiply ``K`` by the n-th power of a 90-degree in-plane rotation."""
    base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
    rot = np.linalg.matrix_power(base_rot, n)
    return rot @ K


def rotate_pose_inplane(i_T_w, rot):
    """Left-multiply a 4x4 pose by a rotation about the z axis.

    ``rot`` indexes the angle table ``(0, 270, 180, 90)`` degrees.
    """
    r = np.deg2rad((0, 270, 180, 90)[rot])
    rotation = np.array(
        [
            [np.cos(r), -np.sin(r), 0.0, 0.0],
            [np.sin(r), np.cos(r), 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    return np.dot(rotation, i_T_w)


def scale_intrinsics(K, scales):
    """Divide the first two rows of ``K`` by ``scales[0]`` and ``scales[1]``."""
    S = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
    return np.dot(S, K)


def to_homogeneous(points):
    """Append a column of ones to an ``(N, D)`` point array."""
    return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)


def angle_error_mat(R1, R2):
    """Geodesic angle (degrees) between two rotation matrices."""
    cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
    return np.rad2deg(np.abs(np.arccos(cos)))


def angle_error_vec(v1, v2):
    """Angle (degrees) between two vectors."""
    n = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))


def compute_pose_error(T_0to1, R, t):
    """Return ``(translation_angle_error, rotation_angle_error)`` in degrees.

    The translation error is folded with ``180 - err`` because the sign of t
    recovered from an essential matrix is ambiguous.
    """
    R_gt = T_0to1[:3, :3]
    t_gt = T_0to1[:3, 3]
    error_t = angle_error_vec(t.squeeze(), t_gt)
    error_t = np.minimum(error_t, 180 - error_t)  # ambiguity of E estimation
    error_R = angle_error_mat(R, R_gt)
    return error_t, error_R


def pose_auc(errors, thresholds):
    """Normalized area under the cumulative-recall curve, per threshold.

    Args:
        errors: 1-D sequence of per-pair errors (assumed non-negative).
        thresholds: iterable of positive thresholds.

    Returns:
        List with one AUC (area divided by the threshold) per threshold.
    """
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    errors = np.sort(np.asarray(errors))  # sorted copy; input is not modified
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        # Extend the curve horizontally up to exactly t before integrating.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(trapezoid(r, x=e) / t)
    return aucs
# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
def get_depth_tuple_transform_ops_nearest_exact(resize=None):
    """Build a TupleCompose that resizes depth tensors with nearest-exact sampling."""
    ops = []
    if resize:
        ops.append(TupleResizeNearestExact(resize))
    return TupleCompose(ops)


def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
    """Build a TupleCompose that bilinearly resizes depth maps.

    ``normalize`` and ``unscale`` are accepted for signature compatibility but
    are not used.
    """
    ops = []
    if resize:
        ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
    return TupleCompose(ops)


def get_tuple_transform_ops(
    resize=None, normalize=True, unscale=False, clahe=False, colorjiggle_params=None
):
    """Build the image pipeline: optional resize, [0, 1] tensor, ImageNet norm.

    ``unscale``, ``clahe`` and ``colorjiggle_params`` are accepted for
    signature compatibility but are currently ignored.
    """
    ops = []
    if resize:
        ops.append(TupleResize(resize))
    ops.append(TupleToTensorScaled())
    if normalize:
        ops.append(
            TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )  # Imagenet mean/std
    return TupleCompose(ops)


class ToTensorScaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""

    def __call__(self, im):
        # Inputs that are already tensors are returned unchanged.
        if isinstance(im, torch.Tensor):
            return im
        im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
        im /= 255.0
        return torch.from_numpy(im)

    def __repr__(self):
        return "ToTensorScaled(./255)"


class TupleToTensorScaled(object):
    """Apply ToTensorScaled to every element of a sequence of images."""

    def __init__(self):
        self.to_tensor = ToTensorScaled()

    def __call__(self, im_tuple):
        return [self.to_tensor(im) for im in im_tuple]

    def __repr__(self):
        return "TupleToTensorScaled(./255)"


class ToTensorUnscaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor"""

    def __call__(self, im):
        return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))

    def __repr__(self):
        return "ToTensorUnscaled()"


class TupleToTensorUnscaled(object):
    """Convert a RGB PIL Image to a CHW ordered Tensor"""

    def __init__(self):
        self.to_tensor = ToTensorUnscaled()

    def __call__(self, im_tuple):
        return [self.to_tensor(im) for im in im_tuple]

    def __repr__(self):
        return "TupleToTensorUnscaled()"


class TupleResizeNearestExact:
    """Resize every tensor in a sequence with F.interpolate(mode="nearest-exact")."""

    def __init__(self, size):
        self.size = size

    def __call__(self, im_tuple):
        return [
            F.interpolate(im, size=self.size, mode="nearest-exact") for im in im_tuple
        ]

    def __repr__(self):
        return "TupleResizeNearestExact(size={})".format(self.size)


class TupleResize(object):
    """Apply torchvision ``transforms.Resize`` to every element of a sequence."""

    def __init__(self, size, mode=InterpolationMode.BICUBIC):
        self.size = size
        self.resize = transforms.Resize(size, mode)

    def __call__(self, im_tuple):
        return [self.resize(im) for im in im_tuple]

    def __repr__(self):
        return "TupleResize(size={})".format(self.size)


class Normalize:
    """Standardise a tensor to zero mean / unit std over dims 1 and 2.

    Statistics are computed per index along dim 0 (presumably the channels of
    a (C, H, W) tensor).
    """

    def __call__(self, im):
        mean = im.mean(dim=(1, 2), keepdim=True)
        std = im.std(dim=(1, 2), keepdim=True)
        return (im - mean) / std


class TupleNormalize(object):
    """Apply ``transforms.Normalize`` to the first 3 channels of each tensor."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
        self.normalize = transforms.Normalize(mean=mean, std=std)

    def __call__(self, im_tuple):
        c, _, _ = im_tuple[0].shape
        if c > 3:
            warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
        return [self.normalize(im[:3]) for im in im_tuple]

    def __repr__(self):
        return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)


class TupleCompose(object):
    """Chain tuple transforms, feeding each one's output into the next."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, im_tuple):
        for t in self.transforms:
            im_tuple = t(im_tuple)
        return im_tuple

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += " {0}".format(t)
        format_string += "\n)"
        return format_string


def _cls_grid(C, device):
    """Return the (C, 2) normalized (x, y) cell centres of a res x res grid, and res.

    ``res = round(sqrt(C))``; the reshape requires ``res * res == C``.
    """
    res = round(math.sqrt(C))
    ys, xs = torch.meshgrid(
        *[
            torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device)
            for _ in range(2)
        ],
        indexing="ij",
    )
    return torch.stack([xs, ys], dim=-1).reshape(C, 2), res


@torch.no_grad()
def cls_to_flow(cls, deterministic_sampling=True):
    """Turn (B, C, H, W) class scores over a grid of anchors into (B, H, W, 2) coords.

    With ``deterministic_sampling`` the arg-max anchor is taken; otherwise an
    anchor is sampled from ``softmax(cls)`` per pixel.
    """
    B, C, H, W = cls.shape
    G, _ = _cls_grid(C, cls.device)
    if deterministic_sampling:
        sampled_cls = cls.max(dim=1).indices
    else:
        sampled_cls = torch.multinomial(
            cls.permute(0, 2, 3, 1).reshape(B * H * W, C).softmax(dim=-1), 1
        ).reshape(B, H, W)
    flow = G[sampled_cls]
    return flow


@torch.no_grad()
def cls_to_flow_refine(cls):
    """Refine the arg-max anchor by a probability-weighted average of its neighbours.

    The neighbours are the flattened indices mode-1, mode, mode+1, mode-res and
    mode+res, clamped to [0, C-1] (so +-1 can step into the adjacent grid row).
    Returns (B, H, W, 2) normalized coordinates.
    """
    B, C, H, W = cls.shape
    G, res = _cls_grid(C, cls.device)
    cls = cls.softmax(dim=1)
    mode = cls.max(dim=1).indices
    index = (
        torch.stack((mode - 1, mode, mode + 1, mode - res, mode + res), dim=1)
        .clamp(0, C - 1)
        .long()
    )  # (B, 5, H, W)
    neighbours = torch.gather(cls, dim=1, index=index)[..., None]  # (B, 5, H, W, 1)
    flow = (neighbours * G[index]).sum(dim=1)  # (B, H, W, 2)
    tot_prob = neighbours.sum(dim=1)
    return flow / tot_prob


def get_gt_warp(
    depth1,
    depth2,
    T_1to2,
    K1,
    K2,
    depth_interpolation_mode="bilinear",
    relative_depth_error_threshold=0.05,
    H=None,
    W=None,
):
    """Warp a dense H x W grid of image-1 pixels into image 2 using depth.

    Returns:
        ``(x2, prob)``: (B, H, W, 2) warped normalized coords and a (B, H, W)
        float validity mask from :func:`warp_kpts`.
    """
    if H is None:
        B, H, W = depth1.shape
    else:
        B = depth1.shape[0]
    with torch.no_grad():
        x1_n = get_grid(B, H, W, depth1.device).reshape(B, H * W, 2)
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode=depth_interpolation_mode,
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
        x2 = x2.reshape(B, H, W, 2)
    return x2, prob


@torch.no_grad()
def warp_kpts(
    kpts0,
    depth0,
    depth1,
    T_0to1,
    K0,
    K1,
    smooth_mask=False,
    return_relative_depth_error=False,
    depth_interpolation_mode="bilinear",
    relative_depth_error_threshold=0.05,
):
    """Warp kpts0 from I0 to I1 with depth, K and Rt
    Also check covisibility and depth consistency.
    Depth is consistent if the relative depth error is below
    ``relative_depth_error_threshold`` (default 0.05).
    # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
    Args:
        kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1)
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4] (only [:, :3, :4] is read),
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
        smooth_mask: if truthy, used as a temperature and the consistency
            mask becomes exp(-relative_error / smooth_mask) instead of a
            hard threshold.
        return_relative_depth_error: return the relative depth error in place
            of the validity mask.
        depth_interpolation_mode: grid_sample mode, or "combined" to fill
            bilinear holes with nearest-exact results.
    Returns:
        calculable_mask (torch.Tensor): [N, L] (or relative depth error)
        warped_keypoints0 (torch.Tensor): [N, L, 2], normalized coords in I1
    Raises:
        NotImplementedError: "combined" mode with smooth_mask or
            return_relative_depth_error.
    """
    _, h, w = depth0.shape
    if depth_interpolation_mode == "combined":
        # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        if return_relative_depth_error:
            # The merge below needs boolean masks, not float errors.
            raise NotImplementedError(
                "Combined bilinear and NN warp does not support return_relative_depth_error"
            )
        valid_bilinear, warp_bilinear = warp_kpts(
            kpts0,
            depth0,
            depth1,
            T_0to1,
            K0,
            K1,
            depth_interpolation_mode="bilinear",
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        valid_nearest, warp_nearest = warp_kpts(
            kpts0,
            depth0,
            depth1,
            T_0to1,
            K0,
            K1,
            depth_interpolation_mode="nearest-exact",
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[nearest_valid_bilinear_invalid] = warp_nearest[
            nearest_valid_bilinear_invalid
        ]
        valid = valid_bilinear | valid_nearest
        return valid, warp

    # Sample depth at kpts0 (normalized coords, align_corners=False convention).
    kpts0_depth = F.grid_sample(
        depth0[:, None],
        kpts0[:, :, None],
        mode=depth_interpolation_mode,
        align_corners=False,
    )[:, 0, :, 0]
    kpts0 = torch.stack(
        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    # Points with zero sampled depth cannot be unprojected.
    nonzero_mask = kpts0_depth != 0

    # Unproject
    kpts0_h = (
        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )  # (N, L, 3)
    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)

    # Rigid Transform
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
    w_kpts0 = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )  # (N, L, 2), +1e-4 to avoid zero depth

    # Covisible Check
    h, w = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0[:, :, 0] > 0)
        * (w_kpts0[:, :, 0] < w - 1)
        * (w_kpts0[:, :, 1] > 0)
        * (w_kpts0[:, :, 1] < h - 1)
    )
    w_kpts0 = torch.stack(
        (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]

    # Depth consistency: compare projected depth with depth1 sampled at the warp.
    w_kpts0_depth = F.grid_sample(
        depth1[:, None],
        w_kpts0[:, :, None],
        mode=depth_interpolation_mode,
        align_corners=False,
    )[:, 0, :, 0]
    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        consistent_mask = (-relative_depth_error / smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    return valid_mask, w_kpts0


imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
imagenet_std = torch.tensor([0.229, 0.224, 0.225])


def numpy_to_pil(x: np.ndarray):
    """Convert an image array (or tensor) to a PIL Image.

    Args:
        x: Assumed to be of shape (h,w,c). Values with max <= 1.01 are scaled
            by 255 before the cast to uint8. The input is not modified.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    if x.max() <= 1.01:
        # Out-of-place: `x *= 255` would rescale the caller's array (and, via
        # .numpy() memory sharing, a caller's CPU tensor).
        x = x * 255
    x = x.astype(np.uint8)
    return Image.fromarray(x)


def tensor_to_pil(x, unnormalize=False):
    """Convert a (C, H, W) tensor to PIL, optionally undoing ImageNet normalization."""
    if unnormalize:
        x = x * (imagenet_std[:, None, None].to(x.device)) + (
            imagenet_mean[:, None, None].to(x.device)
        )
    x = x.detach().permute(1, 2, 0).cpu().numpy()
    x = np.clip(x, 0.0, 1.0)
    return numpy_to_pil(x)


def to_cuda(batch):
    """Move every tensor value of ``batch`` to CUDA in place; returns ``batch``."""
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.cuda()
    return batch


def to_cpu(batch):
    """Move every tensor value of ``batch`` to CPU in place; returns ``batch``."""
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.cpu()
    return batch


def get_pose(calib):
    """Return ``(K, R, T.T, h, w)`` from a calibration dict."""
    w, h = np.array(calib["imsize"])[0]
    return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w


def compute_relative_pose(R1, t1, R2, t2):
    """Relative pose from camera 1 to camera 2: R = R2 R1^T, t = t2 - R t1."""
    rots = R2 @ (R1.T)
    trans = -rots @ t1 + t2
    return rots, trans


@torch.no_grad()
def reset_opt(opt):
    """Zero the exp_avg / exp_avg_sq / exp_avg_diff state of trainable params.

    Other optimizer state entries are left untouched.
    """
    for group in opt.param_groups:
        for p in group["params"]:
            if p.requires_grad:
                state = opt.state[p]
                # State initialization

                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p)
                # Exponential moving average of gradient difference
                state["exp_avg_diff"] = torch.zeros_like(p)


def flow_to_pixel_coords(flow, h1, w1):
    """Map normalized (x, y) coords in [-1, 1] to pixel coords of an h1 x w1 image."""
    flow = torch.stack(
        (
            w1 * (flow[..., 0] + 1) / 2,
            h1 * (flow[..., 1] + 1) / 2,
        ),
        dim=-1,
    )
    return flow


def flow_to_normalized_coords(flow, h1, w1):
    """Inverse of :func:`flow_to_pixel_coords`."""
    flow = torch.stack(
        (
            2 * (flow[..., 0]) / w1 - 1,
            2 * (flow[..., 1]) / h1 - 1,
        ),
        dim=-1,
    )
    return flow


def warp_to_pixel_coords(warp, h1, w1, h2, w2):
    """Convert a 4-channel (x1, y1, x2, y2) normalized warp to pixel coords."""
    warp1 = flow_to_pixel_coords(warp[..., :2], h1, w1)
    warp2 = flow_to_pixel_coords(warp[..., 2:], h2, w2)
    return torch.cat((warp1, warp2), dim=-1)


def signed_point_line_distance(point, line, eps: float = 1e-9):
    r"""Return the distance from points to lines.

    Args:
       point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
       line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
       eps: Small constant for safe sqrt.

    Returns:
        the computed distance with shape :math:`(*, N)`.
    """
    if point.shape[-1] not in (2, 3):
        raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")

    if not line.shape[-1] == 3:
        raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")

    numerator = (
        line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]
    )
    denominator = line[..., :2].norm(dim=-1)

    return numerator / (denominator + eps)


def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
    r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.

    This method measures the distance from points in the right images to the epilines
    of the corresponding points in the left images as they reflect in the right images.

    Args:
       pts1: correspondences from the left images with shape
         :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
       pts2: correspondences from the right images with shape
         :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
       Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
         avoid ambiguity with torch.nn.functional.

    Returns:
        the computed signed one-sided distance with shape :math:`(*, N)`.
    """
    import kornia

    if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
        raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")

    if pts1.shape[-1] == 2:
        pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)

    F_t = Fm.transpose(dim0=-2, dim1=-1)
    line1_in_2 = pts1 @ F_t

    return signed_point_line_distance(pts2, line1_in_2)


def get_grid(b, h, w, device):
    """Return a (b, h, w, 2) grid of normalized (x, y) pixel-centre coordinates."""
    grid = torch.meshgrid(
        *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (b, h, w)],
        indexing="ij",
    )
    grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
    return grid