qitaoz's picture
Upload 57 files
4562a06 verified
import numpy as np
import torch
from pytorch3d.renderer import FoVPerspectiveCameras
from pytorch3d.transforms import quaternion_to_matrix
def generate_random_rotations(N=1, device="cpu"):
q = torch.randn(N, 4, device=device)
q = q / q.norm(dim=-1, keepdim=True)
return quaternion_to_matrix(q)
def symmetric_orthogonalization(x):
"""Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
x: should have size [batch_size, 9]
Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3).
"""
m = x.view(-1, 3, 3)
u, s, v = torch.svd(m)
vt = torch.transpose(v, 1, 2)
det = torch.det(torch.matmul(u, vt))
det = det.view(-1, 1, 1)
vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1)
r = torch.matmul(u, vt)
return r
def get_permutations(num_images):
permutations = []
for i in range(0, num_images):
for j in range(0, num_images):
if i != j:
permutations.append((j, i))
return permutations
def n_to_np_rotations(num_frames, n_rots):
R_pred_rel = []
permutations = get_permutations(num_frames)
for i, j in permutations:
R_pred_rel.append(n_rots[i].T @ n_rots[j])
R_pred_rel = torch.stack(R_pred_rel)
return R_pred_rel
def compute_angular_error_batch(rotation1, rotation2):
R_rel = np.einsum("Bij,Bjk ->Bik", rotation2, rotation1.transpose(0, 2, 1))
t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
theta = np.arccos(np.clip(t, -1, 1))
return theta * 180 / np.pi
# A should be GT, B should be predicted
def compute_optimal_alignment(A, B):
"""
Compute the optimal scale s, rotation R, and translation t that minimizes:
|| A - (s * B @ R + T) || ^ 2
Reference: Umeyama (TPAMI 91)
Args:
A (torch.Tensor): (N, 3).
B (torch.Tensor): (N, 3).
Returns:
s (float): scale.
R (torch.Tensor): rotation matrix (3, 3).
t (torch.Tensor): translation (3,).
"""
A_bar = A.mean(0)
B_bar = B.mean(0)
# normally with R @ B, this would be A @ B.T
H = (B - B_bar).T @ (A - A_bar)
U, S, Vh = torch.linalg.svd(H, full_matrices=True)
s = torch.linalg.det(U @ Vh)
S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
variance = torch.sum((B - B_bar) ** 2)
scale = 1 / variance * torch.trace(torch.diag(S) @ S_prime)
R = U @ S_prime @ Vh
t = A_bar - scale * B_bar @ R
A_hat = scale * B @ R + t
return A_hat, scale, R, t
def compute_optimal_translation_alignment(T_A, T_B, R_B):
"""
Assuming right-multiplied rotation matrices.
E.g., for world2cam R and T, a world coordinate is transformed to camera coordinate
system using X_cam = X_world.T @ R + T = R.T @ X_world + T
Finds s, t that minimizes || T_A - (s * T_B + R_B.T @ t) ||^2
Args:
T_A (torch.Tensor): Target translation (N, 3).
T_B (torch.Tensor): Initial translation (N, 3).
R_B (torch.Tensor): Initial rotation (N, 3, 3).
Returns:
T_A_hat (torch.Tensor): s * T_B + t @ R_B (N, 3).
scale s (torch.Tensor): (1,).
translation t (torch.Tensor): (1, 3).
"""
n = len(T_A)
T_A = T_A.unsqueeze(2)
T_B = T_B.unsqueeze(2)
A = torch.sum(T_B * T_A)
B = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) @ (R_B @ T_A).sum(0) / n
C = torch.sum(T_B * T_B)
D = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0)
E = (D * D).sum() / n
s = (A - B.sum()) / (C - E.sum())
t = (R_B @ (T_A - s * T_B)).sum(0) / n
T_A_hat = s * T_B + R_B.transpose(1, 2) @ t
return T_A_hat.squeeze(2), s, t.transpose(1, 0)
def get_error(predict_rotations, R_pred, T_pred, R_gt, T_gt, gt_scene_scale):
if predict_rotations:
cameras_gt = FoVPerspectiveCameras(R=R_gt, T=T_gt)
cc_gt = cameras_gt.get_camera_center()
cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred)
cc_pred = cameras_pred.get_camera_center()
A_hat, _, _, _ = compute_optimal_alignment(cc_gt, cc_pred)
norm = torch.linalg.norm(cc_gt - A_hat, dim=1) / gt_scene_scale
norms = np.ndarray.tolist(norm.detach().cpu().numpy())
return norms, A_hat
else:
T_A_hat, _, _ = compute_optimal_translation_alignment(T_gt, T_pred, R_pred)
norm = torch.linalg.norm(T_gt - T_A_hat, dim=1) / gt_scene_scale
norms = np.ndarray.tolist(norm.detach().cpu().numpy())
return norms, T_A_hat