# AnySplat/src/utils/pose.py
import torch
import numpy as np
from src.model.encoder.vggt.utils.rotation import mat_to_quat
from src.model.encoder.vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
def convert_pt3d_RT_to_opencv(Rot, Trans):
    """
    Convert a PyTorch3D camera extrinsic (R, T) to the OpenCV convention.
    Args:
        Rot: 3x3 rotation matrix in PyTorch3D format
        Trans: 3D translation vector in PyTorch3D format
    Returns:
        extri_opencv: 3x4 extrinsic matrix in OpenCV format
    """
    rot_pt3d = np.array(Rot)
    trans_pt3d = np.array(Trans)
    # PyTorch3D flips the x/y axes relative to OpenCV and stores the rotation
    # transposed, so negate the first two translation components and rotation
    # columns, transpose, and stack into a 3x4 [R | t] matrix.
    trans_pt3d[:2] *= -1
    rot_pt3d[:, :2] *= -1
    rot_pt3d = rot_pt3d.transpose(1, 0)
    extri_opencv = np.hstack((rot_pt3d, trans_pt3d[:, None]))
    return extri_opencv
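# Usage sketch (illustrative only): with Rot = np.eye(3) and Trans = np.zeros(3),
# convert_pt3d_RT_to_opencv returns a 3x4 matrix whose rotation block is
# diag(-1, -1, 1) and whose translation column is zero, reflecting the x/y axis
# flip between the PyTorch3D and OpenCV camera conventions.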
def build_pair_index(N, B=1):
    """
    Build indices for all possible pairs of frames.
    Args:
        N: Number of frames
        B: Batch size
    Returns:
        i1, i2: Indices for all possible pairs
    """
    i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
    i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]]
    return i1, i2
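# Example: for N=3 and B=1 this yields i1 = [0, 0, 1] and i2 = [1, 2, 2], i.e. the
# frame pairs (0, 1), (0, 2), (1, 2); with B > 1 the same pattern repeats with an
# offset of N per batch element.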
def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
    """
    Calculate the rotation angle error between ground-truth and predicted rotations.
    Args:
        rot_gt: Ground truth rotation matrices
        rot_pred: Predicted rotation matrices
        batch_size: Batch size for reshaping the result
        eps: Small value to avoid numerical issues
    Returns:
        Rotation angle error in degrees
    """
    q_pred = mat_to_quat(rot_pred)
    q_gt = mat_to_quat(rot_gt)
    # For unit quaternions, cos(theta) = 2 * <q_gt, q_pred>^2 - 1, where theta is the
    # relative rotation angle; clamping keeps the arccos argument strictly below 1.
    loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
    err_q = torch.arccos(1 - 2 * loss_q)
    rel_rangle_deg = err_q * 180 / np.pi
    if batch_size is not None:
        rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
    return rel_rangle_deg
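# Worked check: if rot_pred differs from rot_gt by an angle theta about any axis,
# then <q_gt, q_pred> = cos(theta / 2) (up to sign), so the expression above gives
# arccos(2 * cos(theta / 2) ** 2 - 1) = arccos(cos(theta)) = theta, as expected.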
def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
    """
    Calculate the translation angle error between ground-truth and predicted translations.
    Args:
        tvec_gt: Ground truth translation vectors
        tvec_pred: Predicted translation vectors
        batch_size: Batch size for reshaping the result
        ambiguity: Whether to handle direction (sign) ambiguity
    Returns:
        Translation angle error in degrees
    """
    rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
    rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
    if ambiguity:
        # Treat t and -t as equivalent directions by folding errors above 90 degrees.
        rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
    if batch_size is not None:
        rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
    return rel_tangle_deg
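# Note: with ambiguity=True an error of, say, 170 degrees would be folded to
# 10 degrees, since only the line spanned by the translation direction is compared.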
def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
    """
    Normalize the translation vectors and compute the angle between them.
    Args:
        t_gt: Ground truth translation vectors
        t: Predicted translation vectors
        eps: Small value to avoid division by zero
        default_err: Default error value for invalid cases
    Returns:
        Angular error between translation vectors in radians
    """
    t_norm = torch.norm(t, dim=1, keepdim=True)
    t = t / (t_norm + eps)
    t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
    t_gt = t_gt / (t_gt_norm + eps)
    # err_t = arccos(|cos(angle)|) between the unit vectors; squaring the dot product
    # discards the sign, and the clamp keeps sqrt/arccos numerically stable.
    loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
    err_t = torch.acos(torch.sqrt(1 - loss_t))
    err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
    return err_t
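# Worked check: orthogonal unit vectors give pi / 2, while parallel or anti-parallel
# vectors give (approximately) 0, because the squared dot product ignores sign.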
def calculate_auc(r_error, t_error, max_threshold=30, return_list=False):
    """
    Calculate the Area Under the Curve (AUC) for the given error arrays using PyTorch.
    Args:
        r_error: torch.Tensor of rotation error values (degrees)
        t_error: torch.Tensor of translation error values (degrees)
        max_threshold: Maximum threshold value for binning the histogram
        return_list: Whether to also return the normalized histogram
    Returns:
        AUC value, and optionally the normalized histogram
    """
    # Each pair is scored by the worse of its rotation and translation errors.
    error_matrix = torch.stack((r_error, t_error), dim=1)
    max_errors, _ = torch.max(error_matrix, dim=1)
    # Histogram the per-pair errors over [0, max_threshold]; the mean of the
    # normalized cumulative histogram is the AUC of the accuracy-vs-threshold curve.
    histogram = torch.histc(
        max_errors, bins=max_threshold + 1, min=0, max=max_threshold
    )
    num_pairs = float(max_errors.size(0))
    normalized_histogram = histogram / num_pairs
    if return_list:
        return (
            torch.cumsum(normalized_histogram, dim=0).mean(),
            normalized_histogram,
        )
    return torch.cumsum(normalized_histogram, dim=0).mean()
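# Sanity check: if every pair has (near-)zero error, the normalized cumulative
# histogram is all ones and the AUC is 1; if every error exceeds max_threshold,
# torch.histc ignores the values and the AUC is 0.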
def calculate_auc_np(r_error, t_error, max_threshold=30):
    """
    Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.
    Args:
        r_error: numpy array of rotation error values (degrees)
        t_error: numpy array of translation error values (degrees)
        max_threshold: Maximum threshold value for binning the histogram
    Returns:
        AUC value and the normalized histogram
    """
    error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
    max_errors = np.max(error_matrix, axis=1)
    bins = np.arange(max_threshold + 1)
    histogram, _ = np.histogram(max_errors, bins=bins)
    num_pairs = float(len(max_errors))
    normalized_histogram = histogram.astype(float) / num_pairs
    return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
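# Note: this NumPy variant bins errors with integer edges 0..max_threshold
# (max_threshold bins), whereas calculate_auc above uses max_threshold + 1 bins
# over the same range, so the two can differ very slightly for identical inputs.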
def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
    """
    Compute rotation and translation errors between predicted and ground truth poses.
    Args:
        pred_se3: Predicted SE(3) transformations
        gt_se3: Ground truth SE(3) transformations
        num_frames: Number of frames
    Returns:
        Rotation and translation angle errors in degrees
    """
    pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)

    # Compute relative camera poses between pairs
    # We use closed_form_inverse to avoid potential numerical loss by torch.inverse()
    relative_pose_gt = closed_form_inverse_se3(gt_se3[pair_idx_i1]).bmm(
        gt_se3[pair_idx_i2]
    )
    relative_pose_pred = closed_form_inverse_se3(pred_se3[pair_idx_i1]).bmm(
        pred_se3[pair_idx_i2]
    )

    # Compute the difference in rotation and translation
    rel_rangle_deg = rotation_angle(
        relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]
    )
    rel_tangle_deg = translation_angle(
        relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]
    )
    return rel_rangle_deg, rel_tangle_deg
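# Note: the returned error tensors have one entry per unordered frame pair,
# i.e. num_frames * (num_frames - 1) / 2 values each (for batch size 1).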
def align_to_first_camera(camera_poses):
    """
    Align all camera poses to the first camera's coordinate frame.
    Args:
        camera_poses: Tensor of shape (N, 4, 4) containing camera poses as SE(3) transformations
    Returns:
        Tensor of shape (N, 4, 4) containing aligned camera poses
    """
    # Right-multiplying by the inverse of the first pose expresses every pose
    # relative to the first camera (the first aligned pose becomes the identity).
    first_cam_extrinsic_inv = closed_form_inverse_se3(camera_poses[0][None])
    aligned_poses = torch.matmul(camera_poses, first_cam_extrinsic_inv)
    return aligned_poses
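
if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): exercises the relative
    # pose error and AUC helpers on synthetic data. Assumes the vggt utility
    # imports at the top of this file resolve and that poses follow the usual
    # column-vector [R | t; 0 1] convention.
    torch.manual_seed(0)
    N = 4

    # Random proper rotations via QR (sign-fixed so det(R) = +1).
    Q, _ = torch.linalg.qr(torch.randn(N, 3, 3))
    Q = Q * torch.det(Q).view(N, 1, 1)

    gt_se3 = torch.eye(4).repeat(N, 1, 1)
    gt_se3[:, :3, :3] = Q
    gt_se3[:, :3, 3] = torch.randn(N, 3)

    # Simulate a prediction by slightly perturbing the ground-truth translations.
    pred_se3 = gt_se3.clone()
    pred_se3[:, :3, 3] += 0.01 * torch.randn(N, 3)

    rel_rangle_deg, rel_tangle_deg = se3_to_relative_pose_error(pred_se3, gt_se3, N)
    auc = calculate_auc(rel_rangle_deg, rel_tangle_deg, max_threshold=30)
    print("per-pair rotation error (deg):", rel_rangle_deg)
    print("per-pair translation error (deg):", rel_tangle_deg)
    print("AUC@30:", auc.item())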