File size: 4,143 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Various metrics for evaluating predictions."""

import logging

import torch
from torch.nn import functional as F

from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import rad2deg

logger = logging.getLogger(__name__)


def pitch_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
    """Computes the pitch error between two gravities.

    Args:
        pred_gravity (Gravity): Predicted camera.
        target_gravity (Gravity): Ground truth camera.

    Returns:
        torch.Tensor: Pitch error in degrees.
    """
    return rad2deg(torch.abs(pred_gravity.pitch - target_gravity.pitch))


def roll_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
    """Computes the roll error between two gravities.

    Args:
        pred_gravity (Gravity): Predicted Gravity.
        target_gravity (Gravity): Ground truth Gravity.

    Returns:
        torch.Tensor: Roll error in degrees.
    """
    return rad2deg(torch.abs(pred_gravity.roll - target_gravity.roll))


def gravity_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
    """Computes the gravity error between two gravities.

    Args:
        pred_gravity (Gravity): Predicted Gravity.
        target_gravity (Gravity): Ground truth Gravity.

    Returns:
        torch.Tensor: Gravity error in degrees.
    """
    assert (
        pred_gravity.vec3d.shape == target_gravity.vec3d.shape
    ), f"{pred_gravity.vec3d.shape} != {target_gravity.vec3d.shape}"
    assert pred_gravity.vec3d.ndim == 2, f"{pred_gravity.vec3d.ndim} != 2"
    assert pred_gravity.vec3d.shape[1] == 3, f"{pred_gravity.vec3d.shape[1]} != 3"

    cossim = F.cosine_similarity(pred_gravity.vec3d, target_gravity.vec3d, dim=-1).clamp(-1, 1)
    return rad2deg(torch.acos(cossim))


def vfov_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor:
    """Computes the vertical field of view error between two cameras.

    Args:
        pred_cam (Camera): Predicted camera.
        target_cam (Camera): Ground truth camera.

    Returns:
        torch.Tensor: Vertical field of view error in degrees.
    """
    return rad2deg(torch.abs(pred_cam.vfov - target_cam.vfov))


def dist_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor:
    """Computes the distortion parameter error between two cameras.

    Returns zero if the cameras do not have distortion parameters.

    Args:
        pred_cam (Camera): Predicted camera.
        target_cam (Camera): Ground truth camera.

    Returns:
        torch.Tensor: distortion error.
    """
    if hasattr(pred_cam, "dist") and hasattr(target_cam, "dist"):
        return torch.abs(pred_cam.dist[..., 0] - target_cam.dist[..., 0])

    logger.debug(
        f"Predicted / target camera doesn't have distortion parameters: {pred_cam}/{target_cam}"
    )
    return pred_cam.new_zeros(pred_cam.f.shape[0])


def latitude_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Computes the latitude error between two tensors.

    Args:
        predictions (torch.Tensor): Predicted latitude field of shape (B, 1, H, W).
        targets (torch.Tensor): Ground truth latitude field of shape (B, 1, H, W).

    Returns:
        torch.Tensor: Latitude error in degrees of shape (B, H, W).
    """
    return rad2deg(torch.abs(predictions - targets)).squeeze(1)


def up_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Computes the up error between two tensors.

    Args:
        predictions (torch.Tensor): Predicted up field of shape (B, 2, H, W).
        targets (torch.Tensor): Ground truth up field of shape (B, 2, H, W).

    Returns:
        torch.Tensor: Up error in degrees of shape (B, H, W).
    """
    assert predictions.shape == targets.shape, f"{predictions.shape} != {targets.shape}"
    assert predictions.ndim == 4, f"{predictions.ndim} != 4"
    assert predictions.shape[1] == 2, f"{predictions.shape[1]} != 2"

    angle = F.cosine_similarity(predictions, targets, dim=1).clamp(-1, 1)
    return rad2deg(torch.acos(angle))