"""Various metrics for evaluating predictions.""" import logging import torch from torch.nn import functional as F from siclib.geometry.base_camera import BaseCamera from siclib.geometry.gravity import Gravity from siclib.utils.conversions import rad2deg logger = logging.getLogger(__name__) def pitch_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor: """Computes the pitch error between two gravities. Args: pred_gravity (Gravity): Predicted camera. target_gravity (Gravity): Ground truth camera. Returns: torch.Tensor: Pitch error in degrees. """ return rad2deg(torch.abs(pred_gravity.pitch - target_gravity.pitch)) def roll_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor: """Computes the roll error between two gravities. Args: pred_gravity (Gravity): Predicted Gravity. target_gravity (Gravity): Ground truth Gravity. Returns: torch.Tensor: Roll error in degrees. """ return rad2deg(torch.abs(pred_gravity.roll - target_gravity.roll)) def gravity_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor: """Computes the gravity error between two gravities. Args: pred_gravity (Gravity): Predicted Gravity. target_gravity (Gravity): Ground truth Gravity. Returns: torch.Tensor: Gravity error in degrees. """ assert ( pred_gravity.vec3d.shape == target_gravity.vec3d.shape ), f"{pred_gravity.vec3d.shape} != {target_gravity.vec3d.shape}" assert pred_gravity.vec3d.ndim == 2, f"{pred_gravity.vec3d.ndim} != 2" assert pred_gravity.vec3d.shape[1] == 3, f"{pred_gravity.vec3d.shape[1]} != 3" cossim = F.cosine_similarity(pred_gravity.vec3d, target_gravity.vec3d, dim=-1).clamp(-1, 1) return rad2deg(torch.acos(cossim)) def vfov_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor: """Computes the vertical field of view error between two cameras. Args: pred_cam (Camera): Predicted camera. target_cam (Camera): Ground truth camera. Returns: torch.Tensor: Vertical field of view error in degrees. """ return rad2deg(torch.abs(pred_cam.vfov - target_cam.vfov)) def dist_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor: """Computes the distortion parameter error between two cameras. Returns zero if the cameras do not have distortion parameters. Args: pred_cam (Camera): Predicted camera. target_cam (Camera): Ground truth camera. Returns: torch.Tensor: distortion error. """ if hasattr(pred_cam, "dist") and hasattr(target_cam, "dist"): return torch.abs(pred_cam.dist[..., 0] - target_cam.dist[..., 0]) logger.debug( f"Predicted / target camera doesn't have distortion parameters: {pred_cam}/{target_cam}" ) return pred_cam.new_zeros(pred_cam.f.shape[0]) def latitude_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """Computes the latitude error between two tensors. Args: predictions (torch.Tensor): Predicted latitude field of shape (B, 1, H, W). targets (torch.Tensor): Ground truth latitude field of shape (B, 1, H, W). Returns: torch.Tensor: Latitude error in degrees of shape (B, H, W). """ return rad2deg(torch.abs(predictions - targets)).squeeze(1) def up_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """Computes the up error between two tensors. Args: predictions (torch.Tensor): Predicted up field of shape (B, 2, H, W). targets (torch.Tensor): Ground truth up field of shape (B, 2, H, W). Returns: torch.Tensor: Up error in degrees of shape (B, H, W). """ assert predictions.shape == targets.shape, f"{predictions.shape} != {targets.shape}" assert predictions.ndim == 4, f"{predictions.ndim} != 4" assert predictions.shape[1] == 2, f"{predictions.shape[1]} != 2" angle = F.cosine_similarity(predictions, targets, dim=1).clamp(-1, 1) return rad2deg(torch.acos(angle))