import warnings
import numpy as np
import math
import cv2
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_grid(B, H, W, device=device):
    x1_n = torch.meshgrid(
        *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (B, H, W)],
        indexing="ij",
    )
x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
return x1_n
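# Note on the convention above: linspace(-1 + 1/n, 1 - 1/n, n) places the n
# samples at pixel centers in the align_corners=False normalized coordinates
# used by F.grid_sample, stacked in (x, y) order; the meshgrid over B merely
# broadcasts the same grid across the batch.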
@torch.no_grad()
def finite_diff_hessian(f: tuple["B", "H", "W"], device=device):
dxx = (
torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], device=device)[None, None] / 2
)
dxy = (
torch.tensor([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], device=device)[None, None] / 4
)
dyy = dxx.mT
Hxx = F.conv2d(f[:, None], dxx, padding=1)[:, 0]
Hxy = F.conv2d(f[:, None], dxy, padding=1)[:, 0]
Hyy = F.conv2d(f[:, None], dyy, padding=1)[:, 0]
H = torch.stack((Hxx, Hxy, Hxy, Hyy), dim=-1).reshape(*f.shape, 2, 2)
return H
def finite_diff_grad(f: tuple["B", "H", "W"], device=device):
dx = torch.tensor([[0, 0, 0], [-1, 0, 1], [0, 0, 0]], device=device)[None, None] / 2
dy = dx.mT
gx = F.conv2d(f[:, None], dx, padding=1)
gy = F.conv2d(f[:, None], dy, padding=1)
g = torch.cat((gx, gy), dim=1)
return g
def fast_inv_2x2(matrix: tuple[..., 2, 2], eps=1e-10):
return (
1
/ (torch.linalg.det(matrix)[..., None, None] + eps)
* torch.stack(
(
matrix[..., 1, 1],
-matrix[..., 0, 1],
-matrix[..., 1, 0],
matrix[..., 0, 0],
),
dim=-1,
).reshape(*matrix.shape)
)
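# Illustrative sanity check (not part of the original module): fast_inv_2x2
# implements the closed-form adjugate inverse
# inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / (ad - bc),
# so it should agree with torch.linalg.inv on well-conditioned batches.
def _demo_fast_inv_2x2():
    m = torch.randn(4, 2, 2) + 2 * torch.eye(2)  # hypothetical well-conditioned batch
    assert torch.allclose(fast_inv_2x2(m), torch.linalg.inv(m), atol=1e-4)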
def newton_step(f: tuple["B", "H", "W"], inds, device=device):
B, H, W = f.shape
Hess = finite_diff_hessian(f).reshape(B, H * W, 2, 2)
Hess = torch.gather(Hess, dim=1, index=inds[..., None].expand(B, -1, 2, 2))
grad = finite_diff_grad(f).reshape(B, H * W, 2)
grad = torch.gather(grad, dim=1, index=inds)
Hessinv = fast_inv_2x2(Hess - torch.eye(2, device=device)[None, None])
step = Hessinv @ grad[..., None]
return step[..., 0]
@torch.no_grad()
def sample_keypoints(
scoremap,
num_samples=8192,
device=device,
use_nms=True,
sample_topk=False,
return_scoremap=False,
sharpen=False,
upsample=False,
increase_coverage=False,
):
# scoremap = scoremap**2
log_scoremap = (scoremap + 1e-10).log()
if upsample:
log_scoremap = F.interpolate(
log_scoremap[:, None], scale_factor=3, mode="bicubic", align_corners=False
)[
:, 0
] # .clamp(min = 0)
scoremap = log_scoremap.exp()
B, H, W = scoremap.shape
if increase_coverage:
weights = (-torch.linspace(-2, 2, steps=51, device=device) ** 2).exp()[
None, None
]
        # The 10000 factor is an arbitrary scale for numerical stability; the result is invariant to it.
local_density_x = F.conv2d(
(scoremap[:, None] + 1e-6) * 10000,
weights[..., None, :],
padding=(0, 51 // 2),
)
local_density = F.conv2d(
local_density_x, weights[..., None], padding=(51 // 2, 0)
)[:, 0]
scoremap = scoremap * (local_density + 1e-8) ** (-1 / 2)
grid = get_grid(B, H, W, device=device).reshape(B, H * W, 2)
if sharpen:
laplace_operator = (
torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], device=device) / 4
)
scoremap = scoremap[:, None] - 0.5 * F.conv2d(
scoremap[:, None], weight=laplace_operator, padding=1
)
scoremap = scoremap[:, 0].clamp(min=0)
if use_nms:
scoremap = scoremap * (
scoremap == F.max_pool2d(scoremap, (3, 3), stride=1, padding=1)
)
if sample_topk:
inds = torch.topk(scoremap.reshape(B, H * W), k=num_samples).indices
else:
inds = torch.multinomial(
scoremap.reshape(B, H * W), num_samples=num_samples, replacement=False
)
kps = torch.gather(grid, dim=1, index=inds[..., None].expand(B, num_samples, 2))
if return_scoremap:
return kps, torch.gather(scoremap.reshape(B, H * W), dim=1, index=inds)
return kps
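# Illustrative usage sketch (hypothetical shapes, not part of the original
# module): keypoints are sampled from a dense score map and returned in the
# normalized (-1, 1) convention of get_grid; to_pixel_coords (defined below)
# converts them to pixels.
def _demo_sample_keypoints():
    scoremap = torch.rand(1, 64, 64, device=device)
    kps = sample_keypoints(scoremap, num_samples=128, use_nms=True)
    return to_pixel_coords(kps, 64, 64)  # (1, 128, 2) pixel coordinates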
@torch.no_grad()
def jacobi_determinant(warp, certainty, R=3, device=device, dtype=torch.float32):
*dims, _ = warp.shape
warp = warp.to(dtype)
certainty = certainty.to(dtype)
dtype = warp.dtype
match_regions = torch.zeros((*dims, 4, R, R), device=device).to(dtype)
match_regions[:, 1:-1, 1:-1] = warp.unfold(1, R, 1).unfold(2, R, 1)
match_regions = (
rearrange(match_regions, "B H W D R1 R2 -> B H W (R1 R2) D")
- warp[..., None, :]
)
match_regions_cert = torch.zeros((*dims, R, R), device=device).to(dtype)
match_regions_cert[:, 1:-1, 1:-1] = certainty.unfold(1, R, 1).unfold(2, R, 1)
match_regions_cert = rearrange(match_regions_cert, "B H W R1 R2 -> B H W (R1 R2)")[
..., None
]
# print("Time for unfold", perf_counter()-t)
# t = perf_counter()
*dims, N, D = match_regions.shape
# standardize:
mu, sigma = match_regions.mean(dim=(-2, -1), keepdim=True), match_regions.std(
dim=(-2, -1), keepdim=True
)
match_regions = (match_regions - mu) / (sigma + 1e-6)
x_a, x_b = match_regions.chunk(2, -1)
A = torch.zeros((*dims, 2 * x_a.shape[-2], 4), device=device).to(dtype)
A[..., ::2, :2] = x_a * match_regions_cert
A[..., 1::2, 2:] = x_a * match_regions_cert
a_block = A[..., ::2, :2]
ata = a_block.mT @ a_block
# print("Time for ata", perf_counter()-t)
# t = perf_counter()
# atainv = torch.linalg.inv(ata+1e-5*torch.eye(2,device=device).to(dtype))
atainv = fast_inv_2x2(ata)
ATA_inv = torch.zeros((*dims, 4, 4), device=device, dtype=dtype)
ATA_inv[..., :2, :2] = atainv
ATA_inv[..., 2:, 2:] = atainv
atb = A.mT @ (match_regions_cert * x_b).reshape(*dims, N * 2, 1)
theta = ATA_inv @ atb
# print("Time for theta", perf_counter()-t)
# t = perf_counter()
J = theta.reshape(*dims, 2, 2)
abs_J_det = torch.linalg.det(
J + 1e-8 * torch.eye(2, 2, device=device).expand(*dims, 2, 2)
).abs() # Note: This should always be positive for correct warps, but still taking abs here
abs_J_logdet = (abs_J_det + 1e-12).log()
    # Handle outliers: clamp the log-determinant to [-3, 3],
    # i.e. at most exp(3) ≈ 20x area change.
    robust_abs_J_logdet = abs_J_logdet.clamp(-3, 3)
    return robust_abs_J_logdet
def get_gt_warp(
depth1,
depth2,
T_1to2,
K1,
K2,
depth_interpolation_mode="bilinear",
relative_depth_error_threshold=0.05,
H=None,
W=None,
):
if H is None:
B, H, W = depth1.shape
else:
B = depth1.shape[0]
with torch.no_grad():
        x1_n = torch.meshgrid(
            *[
                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device)
                for n in (B, H, W)
            ],
            indexing="ij",
        )
x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
mask, x2 = warp_kpts(
x1_n.double(),
depth1.double(),
depth2.double(),
T_1to2.double(),
K1.double(),
K2.double(),
depth_interpolation_mode=depth_interpolation_mode,
relative_depth_error_threshold=relative_depth_error_threshold,
)
prob = mask.float().reshape(B, H, W)
x2 = x2.reshape(B, H, W, 2)
return torch.cat((x1_n.reshape(B, H, W, 2), x2), dim=-1), prob
def recover_pose(E, kpts0, kpts1, K0, K1, mask):
best_num_inliers = 0
K0inv = np.linalg.inv(K0[:2, :2])
K1inv = np.linalg.inv(K1[:2, :2])
kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
    for _E in np.split(E, len(E) // 3):
n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
if n > best_num_inliers:
best_num_inliers = n
ret = (R, t, mask.ravel() > 0)
return ret
# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
# --- GEOMETRY ---
def estimate_pose(
kpts0,
kpts1,
K0,
K1,
norm_thresh,
conf=0.99999,
):
if len(kpts0) < 5:
return None
K0inv = np.linalg.inv(K0[:2, :2])
K1inv = np.linalg.inv(K1[:2, :2])
kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
method = cv2.USAC_ACCURATE
E, mask = cv2.findEssentialMat(
kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf, method=method
)
ret = None
if E is not None:
best_num_inliers = 0
        for _E in np.split(E, len(E) // 3):
n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
if n > best_num_inliers:
best_num_inliers = n
ret = (R, t, mask.ravel() > 0)
return ret
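# Illustrative usage sketch (hypothetical matches and intrinsics, not part of
# the original module). Because the keypoints are premultiplied by the inverse
# calibration above, norm_thresh is a threshold in normalized image
# coordinates, e.g. a pixel threshold divided by the focal length.
def _demo_estimate_pose():
    rng = np.random.default_rng(0)
    kpts0 = rng.uniform(0, 640, (200, 2))  # hypothetical matched pixels, image 0
    kpts1 = rng.uniform(0, 640, (200, 2))  # hypothetical matched pixels, image 1
    K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
    ret = estimate_pose(kpts0, kpts1, K, K, norm_thresh=0.5 / 500)
    if ret is not None:
        R, t, inliers = ret  # rotation, unit-norm translation, inlier mask
    return ret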
def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
if len(kpts0) < 5:
return None
method = cv2.USAC_ACCURATE
    # Use a local name that does not shadow torch.nn.functional (imported as F)
    Fm, mask = cv2.findFundamentalMat(
        kpts0,
        kpts1,
        ransacReprojThreshold=norm_thresh,
        confidence=conf,
        method=method,
        maxIters=10000,
    )
    if Fm is None:
        return None
    E = K1.T @ Fm @ K0
ret = None
if E is not None:
best_num_inliers = 0
K0inv = np.linalg.inv(K0[:2, :2])
K1inv = np.linalg.inv(K1[:2, :2])
kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
        for _E in np.split(E, len(E) // 3):
n, R, t, _ = cv2.recoverPose(
_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask
)
if n > best_num_inliers:
best_num_inliers = n
ret = (R, t, mask.ravel() > 0)
return ret
def unnormalize_coords(x_n, h, w):
x = torch.stack(
(w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
return x
def rotate_intrinsic(K, n):
base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
rot = np.linalg.matrix_power(base_rot, n)
return rot @ K
def rotate_pose_inplane(i_T_w, rot):
rotation_matrices = [
np.array(
[
[np.cos(r), -np.sin(r), 0.0, 0.0],
[np.sin(r), np.cos(r), 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
],
dtype=np.float32,
)
for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
]
return np.dot(rotation_matrices[rot], i_T_w)
def scale_intrinsics(K, scales):
scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
return np.dot(scales, K)
def angle_error_mat(R1, R2):
cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
return np.rad2deg(np.abs(np.arccos(cos)))
def angle_error_vec(v1, v2):
n = np.linalg.norm(v1) * np.linalg.norm(v2)
return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
def compute_pose_error(T_0to1, R, t):
R_gt = T_0to1[:3, :3]
t_gt = T_0to1[:3, 3]
error_t = angle_error_vec(t.squeeze(), t_gt)
error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
error_R = angle_error_mat(R, R_gt)
return error_t, error_R
def pose_auc(errors, thresholds):
sort_idx = np.argsort(errors)
errors = np.array(errors.copy())[sort_idx]
recall = (np.arange(len(errors)) + 1) / len(errors)
errors = np.r_[0.0, errors]
recall = np.r_[0.0, recall]
aucs = []
for t in thresholds:
last_index = np.searchsorted(errors, t)
r = np.r_[recall[:last_index], recall[last_index - 1]]
e = np.r_[errors[:last_index], t]
aucs.append(np.trapz(r, x=e) / t)
return aucs
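# Illustrative usage sketch (hypothetical errors, not part of the original
# module): pose_auc integrates recall over the sorted error curve up to each
# threshold and normalizes by the threshold, as in the usual 5/10/20-degree
# relative-pose protocol.
def _demo_pose_auc():
    errors = np.array([0.5, 1.0, 2.0, 15.0, 90.0])  # per-pair pose errors in degrees
    return pose_auc(errors, thresholds=[5, 10, 20])  # one AUC in [0, 1] per threshold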
# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
ops = []
if resize:
ops.append(
TupleResize(resize, mode=InterpolationMode.BILINEAR, antialias=False)
)
return TupleCompose(ops)
def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe=False):
ops = []
if resize:
ops.append(TupleResize(resize, antialias=True))
if clahe:
ops.append(TupleClahe())
if normalize:
ops.append(TupleToTensorScaled())
ops.append(
TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
) # Imagenet mean/std
else:
if unscale:
ops.append(TupleToTensorUnscaled())
else:
ops.append(TupleToTensorScaled())
return TupleCompose(ops)
class Clahe:
def __init__(self, cliplimit=2, blocksize=8) -> None:
self.clahe = cv2.createCLAHE(cliplimit, (blocksize, blocksize))
def __call__(self, im):
im_hsv = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2HSV)
im_v = self.clahe.apply(im_hsv[:, :, 2])
im_hsv[..., 2] = im_v
im_clahe = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2RGB)
return Image.fromarray(im_clahe)
class TupleClahe:
def __init__(self, cliplimit=8, blocksize=8) -> None:
self.clahe = Clahe(cliplimit, blocksize)
def __call__(self, ims):
return [self.clahe(im) for im in ims]
class ToTensorScaled(object):
"""Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
def __call__(self, im):
if not isinstance(im, torch.Tensor):
im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
im /= 255.0
return torch.from_numpy(im)
else:
return im
def __repr__(self):
return "ToTensorScaled(./255)"
class TupleToTensorScaled(object):
def __init__(self):
self.to_tensor = ToTensorScaled()
def __call__(self, im_tuple):
return [self.to_tensor(im) for im in im_tuple]
def __repr__(self):
return "TupleToTensorScaled(./255)"
class ToTensorUnscaled(object):
"""Convert a RGB PIL Image to a CHW ordered Tensor"""
def __call__(self, im):
return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
def __repr__(self):
return "ToTensorUnscaled()"
class TupleToTensorUnscaled(object):
"""Convert a RGB PIL Image to a CHW ordered Tensor"""
def __init__(self):
self.to_tensor = ToTensorUnscaled()
def __call__(self, im_tuple):
return [self.to_tensor(im) for im in im_tuple]
def __repr__(self):
return "TupleToTensorUnscaled()"
class TupleResize(object):
def __init__(self, size, mode=InterpolationMode.BICUBIC, antialias=None):
self.size = size
self.resize = transforms.Resize(size, mode, antialias=antialias)
def __call__(self, im_tuple):
return [self.resize(im) for im in im_tuple]
def __repr__(self):
return "TupleResize(size={})".format(self.size)
class Normalize:
def __call__(self, im):
mean = im.mean(dim=(1, 2), keepdims=True)
std = im.std(dim=(1, 2), keepdims=True)
return (im - mean) / std
class TupleNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
self.normalize = transforms.Normalize(mean=mean, std=std)
def __call__(self, im_tuple):
c, h, w = im_tuple[0].shape
if c > 3:
warnings.warn(f"Number of channels {c=} > 3, assuming first 3 are rgb")
return [self.normalize(im[:3]) for im in im_tuple]
def __repr__(self):
return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
class TupleCompose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, im_tuple):
for t in self.transforms:
im_tuple = t(im_tuple)
return im_tuple
def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
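# Illustrative usage sketch (hypothetical image sizes, not part of the
# original module): the Tuple* transforms apply the same operation to every
# image in a tuple, keeping both views of a pair aligned in the pipeline.
def _demo_tuple_transforms():
    ims = [Image.new("RGB", (640, 480)), Image.new("RGB", (640, 480))]
    ops = get_tuple_transform_ops(resize=(448, 448), normalize=True)
    im0, im1 = ops(ims)  # two (3, 448, 448) ImageNet-normalized tensors
    return im0, im1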
@torch.no_grad()
def warp_kpts(
kpts0,
depth0,
depth1,
T_0to1,
K0,
K1,
smooth_mask=False,
return_relative_depth_error=False,
depth_interpolation_mode="bilinear",
relative_depth_error_threshold=0.05,
):
"""Warp kpts0 from I0 to I1 with depth, K and Rt
Also check covisibility and depth consistency.
    Depth is consistent if the relative error is below `relative_depth_error_threshold`.
    Adapted from https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py
Args:
kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
depth0 (torch.Tensor): [N, H, W],
depth1 (torch.Tensor): [N, H, W],
T_0to1 (torch.Tensor): [N, 3, 4],
K0 (torch.Tensor): [N, 3, 3],
K1 (torch.Tensor): [N, 3, 3],
Returns:
calculable_mask (torch.Tensor): [N, L]
warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
"""
    n, h, w = depth0.shape
if depth_interpolation_mode == "combined":
        # Inspired by the approach in InLoc: fill holes left by bilinear interpolation with nearest-neighbour interpolation
if smooth_mask:
raise NotImplementedError("Combined bilinear and NN warp not implemented")
valid_bilinear, warp_bilinear = warp_kpts(
kpts0,
depth0,
depth1,
T_0to1,
K0,
K1,
smooth_mask=smooth_mask,
return_relative_depth_error=return_relative_depth_error,
depth_interpolation_mode="bilinear",
relative_depth_error_threshold=relative_depth_error_threshold,
)
valid_nearest, warp_nearest = warp_kpts(
kpts0,
depth0,
depth1,
T_0to1,
K0,
K1,
smooth_mask=smooth_mask,
return_relative_depth_error=return_relative_depth_error,
depth_interpolation_mode="nearest-exact",
relative_depth_error_threshold=relative_depth_error_threshold,
)
nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
warp = warp_bilinear.clone()
warp[nearest_valid_bilinear_invalid] = warp_nearest[
nearest_valid_bilinear_invalid
]
valid = valid_bilinear | valid_nearest
return valid, warp
kpts0_depth = F.grid_sample(
depth0[:, None],
kpts0[:, :, None],
mode=depth_interpolation_mode,
align_corners=False,
)[:, 0, :, 0]
kpts0 = torch.stack(
(w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    # Keypoints are only calculable where the sampled depth is nonzero
    nonzero_mask = kpts0_depth != 0
# Unproject
kpts0_h = (
torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
* kpts0_depth[..., None]
) # (N, L, 3)
kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
kpts0_cam = kpts0_n
# Rigid Transform
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
# Project
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
w_kpts0 = w_kpts0_h[:, :, :2] / (
w_kpts0_h[:, :, [2]] + 1e-4
) # (N, L, 2), +1e-4 to avoid zero depth
# Covisible Check
h, w = depth1.shape[1:3]
covisible_mask = (
(w_kpts0[:, :, 0] > 0)
* (w_kpts0[:, :, 0] < w - 1)
* (w_kpts0[:, :, 1] > 0)
* (w_kpts0[:, :, 1] < h - 1)
)
w_kpts0 = torch.stack(
(2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
w_kpts0_depth = F.grid_sample(
depth1[:, None],
w_kpts0[:, :, None],
mode=depth_interpolation_mode,
align_corners=False,
)[:, 0, :, 0]
relative_depth_error = (
(w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
).abs()
if not smooth_mask:
consistent_mask = relative_depth_error < relative_depth_error_threshold
else:
consistent_mask = (-relative_depth_error / smooth_mask).exp()
valid_mask = nonzero_mask * covisible_mask * consistent_mask
if return_relative_depth_error:
return relative_depth_error, w_kpts0
else:
return valid_mask, w_kpts0
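# Illustrative sanity check (toy inputs, not part of the original module):
# with an identity relative pose, equal intrinsics, and constant depth, every
# covisible keypoint should warp onto itself.
def _demo_warp_kpts():
    B, H, W = 1, 8, 8
    depth = torch.ones(B, H, W, dtype=torch.float64)
    T_0to1 = torch.eye(4, dtype=torch.float64)[None, :3]  # (1, 3, 4) identity pose
    K = torch.tensor(
        [[8.0, 0.0, 4.0], [0.0, 8.0, 4.0], [0.0, 0.0, 1.0]], dtype=torch.float64
    )[None]
    kpts0 = get_grid(B, H, W, device=torch.device("cpu")).double()
    mask, w_kpts0 = warp_kpts(kpts0, depth, depth, T_0to1, K, K)
    assert torch.allclose(w_kpts0[mask], kpts0[mask], atol=1e-3)
    return mask, w_kpts0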
imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
imagenet_std = torch.tensor([0.229, 0.224, 0.225])
def numpy_to_pil(x: np.ndarray):
"""
Args:
x: Assumed to be of shape (h,w,c)
"""
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
if x.max() <= 1.01:
x *= 255
x = x.astype(np.uint8)
return Image.fromarray(x)
def tensor_to_pil(x, unnormalize=False, autoscale=False):
if unnormalize:
x = x * (imagenet_std[:, None, None].to(x.device)) + (
imagenet_mean[:, None, None].to(x.device)
)
if autoscale:
if x.max() == x.min():
warnings.warn("x max == x min, cant autoscale")
else:
x = (x - x.min()) / (x.max() - x.min())
x = x.detach().permute(1, 2, 0).cpu().numpy()
x = np.clip(x, 0.0, 1.0)
return numpy_to_pil(x)
def to_cuda(batch):
for key, value in batch.items():
if isinstance(value, torch.Tensor):
batch[key] = value.cuda()
return batch
def to_cpu(batch):
for key, value in batch.items():
if isinstance(value, torch.Tensor):
batch[key] = value.cpu()
return batch
def get_pose(calib):
w, h = np.array(calib["imsize"])[0]
return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
def compute_relative_pose(R1, t1, R2, t2):
rots = R2 @ (R1.T)
trans = -rots @ t1 + t2
return rots, trans
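# The composition above follows from world-to-camera extrinsics: with
# x_cam_i = R_i @ x_world + t_i, eliminating x_world gives
# x_cam_2 = (R2 @ R1.T) @ x_cam_1 + (t2 - R2 @ R1.T @ t1).
# Illustrative numeric check (not part of the original module):
def _demo_compute_relative_pose():
    R1, t1 = np.eye(3), np.zeros(3)
    R2 = cv2.Rodrigues(np.array([0.0, 0.1, 0.0]))[0]  # small rotation about y
    t2 = np.array([0.5, 0.0, 0.0])
    R, t = compute_relative_pose(R1, t1, R2, t2)
    assert np.allclose(R, R2) and np.allclose(t, t2)  # camera 1 at the identity
    return R, t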
def to_pixel_coords(flow, h1, w1):
flow = torch.stack(
(
w1 * (flow[..., 0] + 1) / 2,
h1 * (flow[..., 1] + 1) / 2,
),
        dim=-1,
)
return flow
def to_normalized_coords(flow, h1, w1):
flow = torch.stack(
(
2 * (flow[..., 0]) / w1 - 1,
2 * (flow[..., 1]) / h1 - 1,
),
        dim=-1,
)
return flow
def warp_to_pixel_coords(warp, h1, w1, h2, w2):
warp1 = warp[..., :2]
warp1 = torch.stack(
(
w1 * (warp1[..., 0] + 1) / 2,
h1 * (warp1[..., 1] + 1) / 2,
),
        dim=-1,
)
warp2 = warp[..., 2:]
warp2 = torch.stack(
(
w2 * (warp2[..., 0] + 1) / 2,
h2 * (warp2[..., 1] + 1) / 2,
),
        dim=-1,
)
return torch.cat((warp1, warp2), dim=-1)
def to_homogeneous(x):
ones = torch.ones_like(x[..., -1:])
return torch.cat((x, ones), dim=-1)
def from_homogeneous(xh, eps=1e-12):
return xh[..., :-1] / (xh[..., -1:] + eps)
def homog_transform(Homog, x):
xh = to_homogeneous(x)
yh = (Homog @ xh.mT).mT
y = from_homogeneous(yh)
return y
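# Illustrative sanity check (not part of the original module): an identity
# homography maps every point to itself, and a pure translation shifts it.
def _demo_homog_transform():
    x = torch.rand(1, 10, 2) * 2 - 1  # points in (-1, 1) normalized coordinates
    assert torch.allclose(homog_transform(torch.eye(3)[None], x), x, atol=1e-5)
    shift = torch.eye(3)[None].clone()
    shift[:, 0, 2] = 0.1  # translate x by 0.1 in normalized coordinates
    y = homog_transform(shift, x)
    assert torch.allclose(y, x + torch.tensor([0.1, 0.0]), atol=1e-5)
    return y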
def get_homog_warp(Homog, H, W, device=device):
    grid = torch.meshgrid(
        torch.linspace(-1 + 1 / H, 1 - 1 / H, H, device=device),
        torch.linspace(-1 + 1 / W, 1 - 1 / W, W, device=device),
        indexing="ij",
    )
x_A = torch.stack((grid[1], grid[0]), dim=-1)[None]
x_A_to_B = homog_transform(Homog, x_A)
mask = ((x_A_to_B > -1) * (x_A_to_B < 1)).prod(dim=-1).float()
return torch.cat((x_A.expand(*x_A_to_B.shape), x_A_to_B), dim=-1), mask
def dual_log_softmax_matcher(
desc_A: tuple["B", "N", "C"],
desc_B: tuple["B", "M", "C"],
inv_temperature=1,
normalize=False,
):
B, N, C = desc_A.shape
    if normalize:
        desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True)
        desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True)
    corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
logP = corr.log_softmax(dim=-2) + corr.log_softmax(dim=-1)
return logP
def dual_softmax_matcher(
desc_A: tuple["B", "N", "C"],
desc_B: tuple["B", "M", "C"],
inv_temperature=1,
normalize=False,
):
if len(desc_A.shape) < 3:
desc_A, desc_B = desc_A[None], desc_B[None]
B, N, C = desc_A.shape
    if normalize:
        desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True)
        desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True)
    corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
P = corr.softmax(dim=-2) * corr.softmax(dim=-1)
return P
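# Illustrative usage sketch (hypothetical descriptors, not part of the
# original module): the product of row- and column-wise softmaxes is large
# only for matches that are likely in both directions, so confident mutual
# matches can be read off with a simple threshold.
def _demo_dual_softmax_matcher():
    desc_A = torch.randn(1, 100, 128)  # hypothetical descriptors, image A
    desc_B = torch.randn(1, 120, 128)  # hypothetical descriptors, image B
    P = dual_softmax_matcher(desc_A, desc_B, inv_temperature=20, normalize=True)
    return P > 0.5  # (1, 100, 120) boolean mask of confident matches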
def conditional_softmax_matcher(
desc_A: tuple["B", "N", "C"],
desc_B: tuple["B", "M", "C"],
inv_temperature=1,
normalize=False,
):
if len(desc_A.shape) < 3:
desc_A, desc_B = desc_A[None], desc_B[None]
B, N, C = desc_A.shape
    if normalize:
        desc_A = desc_A / desc_A.norm(dim=-1, keepdim=True)
        desc_B = desc_B / desc_B.norm(dim=-1, keepdim=True)
    corr = torch.einsum("b n c, b m c -> b n m", desc_A, desc_B) * inv_temperature
P_B_cond_A = corr.softmax(dim=-1)
P_A_cond_B = corr.softmax(dim=-2)
return P_A_cond_B, P_B_cond_A