File size: 6,082 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import logging
from typing import Dict

import torch

from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import deg2rad, focal2fov

logger = logging.getLogger(__name__)

# flake8: noqa
# mypy: ignore-errors


def get_initial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera, trivial_init: bool = True
) -> BaseCamera:
    """Pick the initialization strategy and return its starting estimate.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary, forwarded unchanged.
        camera_model (BaseCamera): Camera model to use, forwarded unchanged.
        trivial_init (bool, optional): If True use `get_trivial_estimation`, otherwise
            use `get_heuristic_estimation`. Defaults to True.

    Returns:
        The value returned by the selected initializer; both initializers in this
        module return a `(camera, gravity)` pair.
    """
    if not trivial_init:
        return get_heuristic_estimation(data, camera_model)
    return get_trivial_estimation(data, camera_model)


def get_heuristic_estimation(data: Dict[str, torch.Tensor], camera_model: BaseCamera) -> BaseCamera:
    """Derive a starting camera and gravity from the up and latitude fields.

    Heuristics used for the initial values:
    - roll: angle of the up vector sampled at the image center, clamped to +-45 degrees
    - pitch: latitude sampled at the image center, clamped to +-45 degrees
    - vfov: absolute latitude difference between the top-center and bottom-center
      pixels, clamped to [20, 120] degrees
    - distortions: not set here unless a prior is given

    Priors found in `data` take precedence: "prior_focal" replaces the vfov
    heuristic, "prior_k1" and "scales" are forwarded to the camera, and
    "prior_gravity" replaces the roll/pitch-based gravity.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary; must contain
            "up_field" and "latitude_field".
        camera_model (BaseCamera): Camera model to use.

    Returns:
        A `(camera, gravity)` pair, both converted to float and moved to the
        device of the up field.
    """
    up_field = data["up_field"].detach()
    lat_field = data["latitude_field"].detach()
    device = data["up_field"].device

    height, width = up_field.shape[-2:]
    batch_size = up_field.shape[0]
    heights = up_field.new_ones((batch_size,)) * height
    widths = up_field.new_ones((batch_size,)) * width

    # Pixel index of the image center, shared by the roll/pitch/vfov heuristics.
    cy, cx = int(height / 2), int(width / 2)
    angle_limit = deg2rad(45)

    # Roll: orientation of the up vector at the image center.
    up_x = up_field[:, 0, cy, cx]
    up_y = up_field[:, 1, cy, cx]
    roll = -torch.atan2(up_x, -up_y)
    roll = roll.clamp(min=-angle_limit, max=angle_limit)

    # Pitch: latitude at the image center.
    pitch = lat_field[:, 0, cy, cx]
    pitch = pitch.clamp(min=-angle_limit, max=angle_limit)

    # Vertical FoV: latitude span between the top and bottom of the central column.
    top_latitude = lat_field[:, 0, 0, cx]
    bottom_latitude = lat_field[:, 0, -1, cx]
    vfov = torch.abs(top_latitude - bottom_latitude)
    vfov = vfov.clamp(min=deg2rad(20), max=deg2rad(120))

    # A focal prior overrides the heuristic field of view.
    prior_focal = data.get("prior_focal")
    if prior_focal is not None:
        vfov = focal2fov(prior_focal, height)

    camera_params = {"width": widths, "height": heights, "vfov": vfov}
    if "scales" in data:
        camera_params["scales"] = data["scales"]
    if "prior_k1" in data:
        camera_params["k1"] = data["prior_k1"]
    camera = camera_model.from_dict(camera_params).float().to(device)

    gravity = Gravity.from_rp(roll, pitch).float().to(device)
    # A gravity prior overrides the roll/pitch heuristic.
    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(up_field.device)

    return camera, gravity


def get_trivial_estimation(data: Dict[str, torch.Tensor], camera_model: BaseCamera) -> BaseCamera:
    """Get initial camera for optimization with roll=0, pitch=0, focal=0.7 * max(h, w).

    The focal length is converted to a vertical field of view with `focal2fov`.
    Use the prior values if available: "prior_focal" replaces the default focal
    length, "prior_k1" and "scales" are forwarded to the camera, and
    "prior_gravity" replaces the zero roll/pitch gravity.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary. Must contain
            "up_field" or "latitude_field" (only the shape, dtype and device of
            that tensor are used).
        camera_model (BaseCamera): Camera model to use.

    Returns:
        A `(camera, gravity)` pair, both converted to float and moved to the
        device of the reference field.
    """
    # Resolve the fallback lazily: `data.get(key, data[other])` would raise a
    # KeyError for a missing "latitude_field" even when "up_field" is present.
    ref = data["up_field"] if "up_field" in data else data["latitude_field"]
    ref = ref.detach()

    h, w = ref.shape[-2:]
    batch_h, batch_w = (
        ref.new_ones((ref.shape[0],)) * h,
        ref.new_ones((ref.shape[0],)) * w,
    )

    init_r = ref.new_zeros((ref.shape[0],))
    init_p = ref.new_zeros((ref.shape[0],))

    # A missing or None focal prior falls back to the trivial focal length.
    focal = data.get("prior_focal")
    if focal is None:
        focal = 0.7 * torch.max(batch_h, batch_w)
    init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    params |= {"scales": data["scales"]} if "scales" in data else {}
    params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {}
    camera = camera_model.from_dict(params)
    camera = camera.float().to(ref.device)

    gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)

    # A gravity prior overrides the zero roll/pitch initialization.
    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(ref.device)

    return camera, gravity


def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
    """Report whether the cost has converged between two iterations.

    Args:
        new_cost (torch.Tensor): Cost after the current step.
        prev_cost (torch.Tensor): Cost before the current step.
        atol (float): Absolute tolerance of the comparison.
        rtol (float): Relative tolerance of the comparison.

    Returns:
        bool: True if every element of `new_cost` is close to `prev_cost`.
    """
    elementwise_close = torch.isclose(new_cost, prev_cost, rtol=rtol, atol=atol)
    return bool(elementwise_close.all())


def update_lambda(
    lamb: torch.Tensor,
    prev_cost: torch.Tensor,
    new_cost: torch.Tensor,
    lambda_min: float = 1e-6,
    lambda_max: float = 1e2,
) -> torch.Tensor:
    """Update damping factor for Levenberg-Marquardt optimization.

    The damping is multiplied by 10 where the cost increased and by 0.1
    elsewhere, then clamped to `[lambda_min, lambda_max]`. The input tensor is
    not modified.

    Args:
        lamb (torch.Tensor): Current damping factor.
        prev_cost (torch.Tensor): Cost before the step.
        new_cost (torch.Tensor): Cost after the step.
        lambda_min (float, optional): Lower bound of the damping. Defaults to 1e-6.
        lambda_max (float, optional): Upper bound of the damping. Defaults to 1e2.

    Returns:
        torch.Tensor: Updated damping factor.
    """
    # Increase damping where the step made things worse, relax it otherwise.
    scale = torch.where(new_cost > prev_cost, 10, 0.1)
    return torch.clamp(lamb * scale, lambda_min, lambda_max)


def optimizer_step(
    G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """One optimization step with Gauss-Newton or Levenberg-Marquardt.

    The Hessian diagonal is scaled by `lambda_` (floored at `eps`) and added to
    the Hessian before solving the linear system with a Cholesky factorization
    on the CPU. If the factorization fails, a warning is logged and a zero
    update is returned.

    Args:
        G (torch.Tensor): Batched gradient tensor of size (..., N).
        H (torch.Tensor): Batched hessian tensor of size (..., N, N).
        lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,).
        eps (float, optional): Epsilon for damping. Defaults to 1e-6.

    Returns:
        torch.Tensor: Batched update tensor of size (..., N).
    """
    diag = H.diagonal(dim1=-2, dim2=-1)
    diag = diag * lambda_.unsqueeze(-1)  # (..., N)

    # Flooring the damping at eps keeps a small positive term on the diagonal
    # even when lambda_ is zero.
    H = H + diag.clamp(min=eps).diag_embed()

    # The factorization and solve run on the CPU; the result is moved back below.
    H_, G_ = H.cpu(), G.cpu()
    try:
        # torch.linalg.cholesky returns the lower-triangular factor, which is what
        # torch.cholesky_solve expects by default (upper=False).
        chol = torch.linalg.cholesky(H_)
    except RuntimeError:
        logger.warning("Cholesky decomposition failed. Stopping.")
        # Zero update matching the documented (..., N) shape for any number of
        # batch dimensions.
        delta = H.new_zeros(H.shape[:-1])
    else:
        delta = torch.cholesky_solve(G_[..., None], chol)[..., 0]

    return delta.to(H.device)