Spaces:
Running
on
Zero
Running
on
Zero
from typing import * | |
import math | |
from collections import namedtuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.types | |
import utils3d | |
def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
    """
    Scatter-reduce `src` along `dim` into an output whose extent along `dim` is `size`, keeping the minimum.

    `index` has the same shape as `src`; element `src[..., i, ...]` is sent to slot `index[..., i, ...]`.
    Returns a `torch.return_types.min` pair:
    - `values`: the per-slot minimum (`inf` for slots that receive no element),
    - `indices`: the position along `dim` in `src` of an element achieving that minimum
      (-1 for empty slots; when several elements tie, one of their positions is stored).
    """
    out_shape = (*src.shape[:dim], size, *src.shape[dim + 1:])
    init = torch.full(out_shape, float('inf'), dtype=src.dtype, device=src.device)
    values = init.scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
    # Locate every source element equal to the minimum of the slot it was scattered into.
    hits = torch.where(src == values.gather(dim, index))
    argmins = torch.full(out_shape, -1, dtype=torch.long, device=src.device)
    target = hits[:dim] + (index[hits],) + hits[dim + 1:]
    argmins[target] = hits[dim]
    return torch.return_types.min((values, argmins))
def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
    """
    Call `fn` on chunks of at most `chunk_size` rows (along dim 0) and concatenate the results.

    Tensor arguments (positional or keyword) are split along dim 0; any other argument is passed
    unchanged to every chunk. The batch size is taken from the first tensor argument found.

    ### Parameters:
    - `fn`: callable returning a tensor or a tuple of tensors (each with the batch along dim 0)
    - `chunk_size`: maximum number of rows per call; values < 1 are treated as 1
    ### Returns:
    - the dim-0 concatenation of the per-chunk results (a tuple if `fn` returns tuples)
    """
    batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
    chunk_size = max(chunk_size, 1)  # avoid ZeroDivisionError / split(0) for tiny budgets
    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)

    def _split(value):
        # One entry per chunk: tensor pieces, or the same non-tensor object repeated.
        return value.split(chunk_size, dim=0) if isinstance(value, torch.Tensor) else [value] * n_chunks

    if n_chunks == 0:
        # Empty batch: run once on the (empty) inputs so output shapes/dtypes still come from `fn`.
        results = [fn(*args, **kwargs)]
    else:
        split_args = [_split(arg) for arg in args]
        split_kwargs = {k: _split(v) for k, v in kwargs.items()}
        results = [
            fn(*(arg[i] for arg in split_args), **{k: v[i] for k, v in split_kwargs.items()})
            for i in range(n_chunks)
        ]

    if isinstance(results[0], tuple):
        return tuple(torch.cat(r, dim=0) for r in zip(*results))
    return torch.cat(results, dim=0)
def _pad_inf(x_: torch.Tensor): | |
return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1) | |
def _pad_cumsum(cumsum: torch.Tensor): | |
return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1) | |
def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float): | |
return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1) | |
def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
    """
    If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
    w_i must be >= 0.

    The objective is piecewise linear in `a`, and the returned `a` is always one of the candidates `y_i / x_i`.

    ### Parameters:
    - `x`: tensor of shape (..., n)
    - `y`: tensor of shape (..., n)
    - `w`: tensor of shape (..., n)
    - `trunc`: optional, float or tensor broadcastable with (..., n), or None
    - `eps`: lower bound applied to `|x_i|` (and `w_i * |x_i|`) before dividing
    ### Returns:
    - `a`: tensor of shape (...), differentiable
    - `loss`: tensor of shape (...), value of loss function at `a` (detached when `trunc` is given)
    - `index`: tensor of shape (...), where a = y[idx] / x[idx]
    """
    if trunc is None:
        x, y, w = torch.broadcast_tensors(x, y, w)
        # Flip signs so that every x_i >= 0; the candidates y_i / x_i are unchanged.
        sign = torch.sign(x)
        x_pos, y_pos = x * sign, y * sign
        y_div_x = y_pos / x_pos.clamp_min(eps)
        y_div_x, argsort = y_div_x.sort(dim=-1)
        wx = torch.gather(x_pos * w, dim=-1, index=argsort)
        # Derivative of the objective just right of each sorted candidate; the minimizer is the
        # first candidate where it becomes non-negative (a weighted median).
        derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
        search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
        a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
        index = argsort.gather(dim=-1, index=search).squeeze(-1)
        # Evaluate on the un-flipped x, y: |a*x - y| is unchanged by the flip for x != 0, while entries
        # with x == 0 would otherwise have y zeroed by sign(0) and silently drop their residual |y|.
        loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
    else:
        # Broadcast all inputs (including a tensor-valued `trunc`) to one shape, then flatten the
        # batch dimensions to (batch_size, n) for simplicity.
        trunc_is_tensor = isinstance(trunc, torch.Tensor)
        if trunc_is_tensor:
            x, y, w, trunc = torch.broadcast_tensors(x, y, w, trunc)
        else:
            x, y, w = torch.broadcast_tensors(x, y, w)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_size = math.prod(batch_shape)
        x, y, w = x.reshape(batch_size, n), y.reshape(batch_size, n), w.reshape(batch_size, n)
        if trunc_is_tensor:
            trunc = trunc.reshape(batch_size, n)

        sign = torch.sign(x)
        x_pos, y_pos = x * sign, y * sign
        wx, wy = w * x_pos, w * y_pos
        # Residuals are evaluated on the un-flipped values (sign-invariant), so entries with x == 0
        # keep their truncated residual min(trunc, w * |y|) instead of being zeroed by sign(0).
        xyw = torch.stack([x, y, w], dim=-1)  # Stacked for convenient gathering
        # Kinks of the objective: A where a residual vanishes, B / C where it reaches `trunc`.
        y_div_x = A = y_pos / x_pos.clamp_min(eps)
        B = (wy - trunc) / wx.clamp_min(eps)
        C = (wy + trunc) / wx.clamp_min(eps)
        with torch.no_grad():
            # Calculate prefix sums of w * x in the sorted orders of A, B and C
            A, A_argsort = A.sort(dim=-1)
            Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
            A, Q_A = _pad_inf(A), _pad_cumsum(Q_A)  # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
            B, B_argsort = B.sort(dim=-1)
            Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
            B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
            C, C_argsort = C.sort(dim=-1)
            Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
            C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)

            # Left and right derivatives at every candidate a = A_i:
            # slope = 2 * sum_{A_j < a} wx_j - sum_{B_j < a} wx_j - sum_{C_j < a} wx_j
            j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
            left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
            j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
            right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)

            # Find extrema (local minima of the objective)
            is_extrema = (left_derivative < 0) & (right_derivative >= 0)
            is_extrema[..., 0] |= ~is_extrema.any(dim=-1)  # If no candidate qualifies, take the first one.
            where_extrema_batch, where_extrema_index = torch.where(is_extrema)

            # Calculate objective value at extrema
            extrema_a = y_div_x[where_extrema_batch, where_extrema_index]  # (num_extrema,)
            MAX_ELEMENTS = 4096 ** 2  # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
            SPLIT_SIZE = max(MAX_ELEMENTS // n, 1)  # at least one extremum per split
            extrema_value = torch.cat([
                _compute_residual(
                    extrema_a_split[:, None],
                    xyw[extrema_i_split, :, :],
                    trunc[extrema_i_split] if trunc_is_tensor else trunc,  # per-row thresholds for tensor `trunc`
                )
                for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
            ])  # (num_extrema,)

            # Find minima among corresponding extrema
            minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value)  # (batch_size,)
            index = where_extrema_index[indices]

        # Recompute `a` outside no_grad so it stays differentiable w.r.t. x and y.
        a = torch.gather(y_pos, dim=-1, index=index[..., None]) / torch.gather(x_pos, dim=-1, index=index[..., None]).clamp_min(eps)
        a = a.reshape(batch_shape)
        loss = minima.reshape(batch_shape)
        index = index.reshape(batch_shape)

    return a, loss, index
def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Estimate the single scale that best maps `depth_src` onto `depth_tgt` under the weighted
    (optionally truncated) L1 objective solved by `align`.
    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N), non-negative
    - `trunc`: optional truncation threshold forwarded to `align`
    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    """
    solution = align(depth_src, depth_tgt, weight, trunc)
    return solution[0]
def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with an affine map `scale * depth_src + shift`, using the given
    constant weights (weighted L1 objective, truncated when `trunc` is given; see `align`).

    Every point with positive weight is tried as an anchor that the fitted line passes through; each
    anchor reduces to a scale-only problem solved by `align`, and the anchor with the smallest loss is
    kept for each batch element.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align` for the (anchors, N) anchored problems
    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (...).
    """
    # Flatten batch dimensions for simplicity
    batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
    batch_size = math.prod(batch_shape)
    depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)

    # Here, we take anchors only for non-zero weights.
    # Although the results will be still correct even anchor points have zero weight,
    # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
    # NOTE: a batch element with no positive weight gets index_anchor == -1 from scatter_min below,
    # which silently selects the last anchor of the whole batch; callers should avoid such rows.
    anchors_where_batch, anchors_where_n = torch.where(weight > 0)

    # Stop gradient when solving optimal anchors
    with torch.no_grad():
        depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n]  # (anchors)
        depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n]  # (anchors)

        depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None]  # (anchors, n)
        depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None]  # (anchors, n)
        weight_anchored = weight[anchors_where_batch, :]  # (anchors, n)

        _, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc)  # (anchors)

        _, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for shorter compute graph
    index_1 = anchors_where_n[index_anchor]  # (batch_size,)
    index_2 = index[index_anchor]  # (batch_size,)

    tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)

    # `align` gives scale 0 to entries whose anchored source value is exactly zero (sign(0) zeroes them),
    # so a degenerate pair (src_2 == src_1) reproduces scale 0. The safe inner denominator keeps the
    # unselected branch finite so no inf/NaN reaches the gradients.
    nondegenerate = src_2 != src_1
    scale = torch.where(nondegenerate, (tgt_2 - tgt_1) / torch.where(nondegenerate, src_2 - src_1, 1.0), 0.0)
    shift = tgt_1 - scale * src_1

    scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)

    return scale, shift
def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
    """
    Align `depth_src` to `depth_tgt` with given constant weights using IRLS.

    Approximately solves `min_{scale, shift} sum_i weight_i * |scale * depth_src_i + shift - depth_tgt_i|`
    by iteratively re-weighted least squares, using weights `weight_i / |residual_i|`.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N), non-negative, or None for uniform weights
    - `max_iter: int` number of re-weighting iterations (at least one is always performed)
    - `eps: float` lower bound on |residual| when re-weighting, avoids division by zero
    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    - `shift: torch.Tensor` of shape (...)
    """
    if weight is None:
        weight = torch.ones_like(depth_src)
    x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)  # (..., N, 2)
    y = depth_tgt
    w = weight
    for _ in range(max(max_iter, 1)):
        # Weighted normal equations (X^T W X) beta = X^T W y, batched over leading dims.
        xtw = x.transpose(-1, -2) * w[..., None, :]  # (..., 2, N)
        beta = torch.linalg.solve(xtw @ x, xtw @ y[..., None])[..., 0]  # (..., 2)
        residual = y - (x @ beta[..., None])[..., 0]
        # Keep the user weights constant across iterations; only the 1/|r| factor is re-estimated.
        w = weight / residual.abs().clamp_min(eps)
    return beta[..., 0], beta[..., 1]
def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Find a single scale shared by x, y and z that best maps `points_src` onto `points_tgt`
    (weighted L1 objective, truncated when `trunc` is given; see `align`).
    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N), shared by the three coordinates of each point
    - `trunc`: optional truncation threshold forwarded to `align`
    ### Returns:
    - `scale: torch.Tensor` of shape (...). Only positive solutions are guaranteed. You should filter out negative scales before using it.
    """
    # Treat the 3N coordinates as independent samples that share their point's weight.
    weight_xyz = weight[..., None].expand_as(points_src)
    scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight_xyz.flatten(-2), trunc)
    return scale
def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift, i.e. fit
    `points_tgt ≈ scale * points_src + (0, 0, shift_z)`.
    It is similar to `align_depth_affine` but scale and shift are applied to different dimensions.
    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`
    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors: every point with positive weight; only its z coordinate is subtracted, since
    # the shift applies to z alone.
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
        points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)
        points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale for each anchor (each anchored row holds 3n values; at least one row per chunk)
        MAX_ELEMENTS = 2 ** 20
        _, loss, index = split_batch_fwd(align, max(MAX_ELEMENTS // n, 1), points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        _, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for shorter compute graph
    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
    points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
    tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    # `align` gives scale 0 to entries whose anchored source value is exactly zero, so a degenerate
    # pair (src_2 == src_1) reproduces scale 0; the inner fallback keeps the unused branch finite.
    nondegenerate = src_2 != src_1
    scale = torch.where(nondegenerate, (tgt_2 - tgt_1) / torch.where(nondegenerate, src_2 - src_1, 1.0), 0.0)
    shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and an xyz shift, i.e. fit
    `points_tgt ≈ scale * points_src + shift`.
    It is similar to `align_depth_affine`, with one scale shared by the three coordinates and a
    separate shift per coordinate.
    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`
    - `max_iters`, `eps`: unused; kept for signature compatibility
    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3)
    """
    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors: every point with positive weight; the fitted map passes through it.
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)

    with torch.no_grad():
        points_src_anchor = points_src[anchor_where_batch, anchor_where_n]  # (anchors, 3)
        points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n]  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale for each anchor. Each anchored row holds 3n values, so the chunk size
        # must shrink with n (a fixed row count would blow up memory for large n).
        MAX_ELEMENTS = 2 ** 20
        _, loss, index = split_batch_fwd(align, max(MAX_ELEMENTS // n, 1), points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        # Get optimal anchor for each batch element
        _, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for shorter compute graph
    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    # `align` gives scale 0 to entries whose anchored source value is exactly zero, so a degenerate
    # pair (src_2 == src_1) reproduces scale 0; the inner fallback keeps the unused branch finite.
    nondegenerate = src_2 != src_1
    scale = torch.where(nondegenerate, (tgt_2 - tgt_1) / torch.where(nondegenerate, src_2 - src_1, 1.0), 0.0)
    shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a Z-axis shift only, i.e. fit
    `points_tgt ≈ points_src + (0, 0, shift_z)` (weighted L1, truncated when `trunc` is given; see `align`).
    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: optional truncation threshold forwarded to `align`
    - `max_iters`, `eps`: unused; kept for signature compatibility
    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    # A pure shift is a scale problem with x_i = 1: minimize sum w_i |shift - (z_tgt_i - z_src_i)|.
    shift_z, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
    zeros = torch.zeros_like(shift_z)
    return torch.stack([zeros, zeros, shift_z], dim=-1)
def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to an independent shift per axis (x, y and z), i.e. fit
    `points_tgt ≈ points_src + shift` (weighted L1, truncated when `trunc` is given; see `align`).
    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N), shared by the three axes
    - `trunc`: optional truncation threshold forwarded to `align`
    - `max_iters`, `eps`: unused; kept for signature compatibility
    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3)
    """
    # Move the coordinate axis in front of N so each axis is solved as its own 1-D problem with x_i = 1.
    ones = torch.ones_like(points_src).swapaxes(-2, -1)  # (..., 3, N)
    delta = (points_tgt - points_src).swapaxes(-2, -1)  # (..., 3, N)
    shift, _, _ = align(ones, delta, weight[..., None, :], trunc)
    return shift
def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
    ### Parameters:
    - `x: torch.Tensor` of shape (..., N)
    - `y: torch.Tensor` of shape (..., N)
    - `w: torch.Tensor` of shape (..., N), non-negative; None means uniform weights
    ### Returns:
    - `a: torch.Tensor` of shape (...,)
    - `b: torch.Tensor` of shape (...,)
    """
    w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
    # Every row, including the intercept column, is scaled by sqrt(w_i), so the plain least-squares
    # residual of the scaled system equals the weighted residual sqrt(w_i) * (a * x_i + b - y_i).
    A = torch.stack([w_sqrt * x, w_sqrt.expand_as(x)], dim=-1)
    B = (w_sqrt * y)[..., None]
    a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
    return a, b