File size: 6,082 Bytes
205a7af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import logging
from typing import Dict, Tuple

import torch

from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import deg2rad, focal2fov
# Module-level logger; optimizer_step uses it to warn when the Cholesky factorization fails.
logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
def get_initial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera, trivial_init: bool = True
) -> Tuple[BaseCamera, Gravity]:
    """Get the initial camera and gravity for optimization.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary, forwarded unchanged to
            the selected estimator.
        camera_model (BaseCamera): Camera model whose ``from_dict`` builds the camera.
        trivial_init (bool, optional): If True, use ``get_trivial_estimation``
            (zero roll/pitch, default focal of 0.7 * max(h, w)); otherwise use
            ``get_heuristic_estimation`` (roll/pitch/vfov read from the predicted
            fields). Defaults to True.

    Returns:
        Tuple[BaseCamera, Gravity]: Initial camera and gravity estimates.
    """
    if trivial_init:
        return get_trivial_estimation(data, camera_model)
    return get_heuristic_estimation(data, camera_model)
def get_heuristic_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera
) -> Tuple[BaseCamera, Gravity]:
    """Get the initial camera and gravity for optimization using heuristics.

    The initial estimate is read from the predicted fields:
    - roll is the angle of the up vector at the center of the image,
      clamped to [-45 deg, 45 deg];
    - pitch is the latitude value at the center of the image,
      clamped to [-45 deg, 45 deg];
    - vfov is the absolute latitude difference between the top-center and
      bottom-center pixels, clamped to [20 deg, 120 deg];
    - distortion is not estimated: ``k1`` is only passed to the camera when
      ``prior_k1`` is given (otherwise the camera model's defaults apply,
      presumably zero -- TODO confirm in ``from_dict``).

    Prior values override the heuristics when present in ``data``:
    ``prior_focal`` (replaces the vfov heuristic via ``focal2fov``),
    ``prior_k1`` and ``prior_gravity``. ``scales`` is forwarded to the camera as-is.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary. Must contain
            "up_field" and "latitude_field", both indexed as 4-D
            (batch, channel, height, width) tensors below.
        camera_model (BaseCamera): Camera model whose ``from_dict`` builds the camera.

    Returns:
        Tuple[BaseCamera, Gravity]: Initial camera and gravity estimates.
    """
    up_ref = data["up_field"].detach()
    latitude_ref = data["latitude_field"].detach()
    device = up_ref.device

    h, w = up_ref.shape[-2:]
    cy, cx = h // 2, w // 2  # center pixel
    batch_h = up_ref.new_ones((up_ref.shape[0],)) * h
    batch_w = up_ref.new_ones((up_ref.shape[0],)) * w

    max_angle = deg2rad(45)

    # roll: angle of the up vector at the center of the image
    init_r = -torch.atan2(up_ref[:, 0, cy, cx], -up_ref[:, 1, cy, cx])
    init_r = init_r.clamp(min=-max_angle, max=max_angle)

    # pitch: latitude value at the center of the image
    init_p = latitude_ref[:, 0, cy, cx].clamp(min=-max_angle, max=max_angle)

    focal = data.get("prior_focal")
    if focal is None:
        # vfov: latitude span between the central top and bottom pixels
        init_vfov = torch.abs(latitude_ref[:, 0, 0, cx] - latitude_ref[:, 0, -1, cx])
        init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120))
    else:
        init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    if "scales" in data:
        params["scales"] = data["scales"]
    if "prior_k1" in data:
        params["k1"] = data["prior_k1"]
    camera = camera_model.from_dict(params).float().to(device)

    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(device)
    else:
        gravity = Gravity.from_rp(init_r, init_p).float().to(device)
    return camera, gravity
def get_trivial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera
) -> Tuple[BaseCamera, Gravity]:
    """Get initial camera and gravity with roll=0, pitch=0 and focal=0.7 * max(h, w).

    The vertical field of view is derived from the focal length via ``focal2fov``.
    Prior values override the defaults when present in ``data``: ``prior_focal``
    (used instead of the default focal unless it is None), ``prior_k1`` and
    ``prior_gravity``. ``scales`` is forwarded to the camera as-is.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary. Must contain
            "up_field" or "latitude_field"; only its batch size (first dim),
            spatial size (last two dims), dtype and device are used.
        camera_model (BaseCamera): Camera model whose ``from_dict`` builds the camera.

    Returns:
        Tuple[BaseCamera, Gravity]: Initial camera and gravity estimates.
    """
    # Choose the reference field lazily: ``data.get("up_field", data["latitude_field"])``
    # would evaluate the default eagerly and raise KeyError when only "up_field" exists.
    ref = data["up_field"] if "up_field" in data else data["latitude_field"]
    ref = ref.detach()

    h, w = ref.shape[-2:]
    batch_size = ref.shape[0]
    batch_h = ref.new_ones((batch_size,)) * h
    batch_w = ref.new_ones((batch_size,)) * w

    init_r = ref.new_zeros((batch_size,))
    init_p = ref.new_zeros((batch_size,))

    # Fall back to the default focal also when "prior_focal" is present but None.
    focal = data.get("prior_focal")
    if focal is None:
        focal = 0.7 * torch.max(batch_h, batch_w)
    init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    if "scales" in data:
        params["scales"] = data["scales"]
    if "prior_k1" in data:
        params["k1"] = data["prior_k1"]
    camera = camera_model.from_dict(params).float().to(ref.device)

    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(ref.device)
    else:
        gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)
    return camera, gravity
def early_stop(
    new_cost: torch.Tensor,
    prev_cost: torch.Tensor,
    atol: float,
    rtol: float,
) -> bool:
    """Return True once the cost has stopped changing.

    Convergence is declared when every element satisfies
    ``|new_cost - prev_cost| <= atol + rtol * |prev_cost|`` (``torch.allclose``
    semantics, which are not symmetric in the two costs).
    """
    has_converged = torch.allclose(new_cost, prev_cost, rtol=rtol, atol=atol)
    return has_converged
def update_lambda(
    lamb: torch.Tensor,
    prev_cost: torch.Tensor,
    new_cost: torch.Tensor,
    lambda_min: float = 1e-6,
    lambda_max: float = 1e2,
    increase_factor: float = 10.0,
    decrease_factor: float = 0.1,
) -> torch.Tensor:
    """Update the damping factor for Levenberg-Marquardt optimization.

    Element-wise: where the cost increased, the damping is multiplied by
    ``increase_factor``; otherwise (cost decreased or unchanged) it is multiplied
    by ``decrease_factor``. The result is clamped to [lambda_min, lambda_max].

    Args:
        lamb (torch.Tensor): Current damping factor(s).
        prev_cost (torch.Tensor): Cost before the step.
        new_cost (torch.Tensor): Cost after the step.
        lambda_min (float, optional): Lower clamp bound. Defaults to 1e-6.
        lambda_max (float, optional): Upper clamp bound. Defaults to 1e2.
        increase_factor (float, optional): Multiplier when the cost increased.
            Defaults to 10.0.
        decrease_factor (float, optional): Multiplier otherwise. Defaults to 0.1.

    Returns:
        torch.Tensor: Updated, clamped damping factor(s).
    """
    factor = torch.where(new_cost > prev_cost, increase_factor, decrease_factor)
    return torch.clamp(lamb * factor, lambda_min, lambda_max)
def optimizer_step(
    G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """One optimization step with Gauss-Newton or Levenberg-Marquardt.

    Solves ``(H + D) delta = G`` with ``D = diag(clamp(lambda_ * diag(H), min=eps))``
    using a Cholesky factorization computed on CPU copies of H and G.

    Args:
        G (torch.Tensor): Batched gradient tensor of size (..., N).
        H (torch.Tensor): Batched hessian tensor of size (..., N, N).
        lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0); must
            broadcast against the batch dims of G, e.g. shape (B,).
        eps (float, optional): Minimum value added to each diagonal entry of H.
            Defaults to 1e-6.

    Returns:
        torch.Tensor: Batched update tensor of size (..., N), on the device of H.
        If the Cholesky factorization fails, a warning is logged and a zero
        update is returned for the whole batch.
    """
    diag = H.diagonal(dim1=-2, dim2=-1) * lambda_.unsqueeze(-1)  # (..., N)
    # The clamp keeps a small positive damping even when lambda_ == 0 (pure GN).
    H = H + diag.clamp(min=eps).diag_embed()

    H_cpu, G_cpu = H.cpu(), G.cpu()
    try:
        L = torch.linalg.cholesky(H_cpu)  # lower-triangular factor
    except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
        logger.warning("Cholesky decomposition failed. Stopping.")
        # (..., N): unbatched inputs get a vector, matching the success path.
        delta = H.new_zeros(H.shape[:-1])
    else:
        delta = torch.cholesky_solve(G_cpu.unsqueeze(-1), L).squeeze(-1)
    return delta.to(H.device)
|