xiaoyuxi
gradio_app
a51c6d2
from typing import *
import math
from collections import namedtuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.types
import utils3d
def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
    """
    Group the entries of `src` along `dim` by `index` and take the minimum of every group.

    ### Parameters:
    - `size`: length of the output along `dim` (number of groups)
    - `dim`: dimension to reduce along; negative values count from the last dimension
    - `index`: long tensor, expected to have the same shape as `src`; group id of every entry along `dim`
    - `src`: values to reduce

    ### Returns:
    A `torch.return_types.min` pair:
    - `values`: shaped like `src` with length `size` along `dim`; groups that receive no entry hold `inf`
    - `indices`: position along `dim` (in `src`) of an entry attaining the minimum; `-1` for empty groups.
      When several entries tie for the minimum, which one is reported is not specified.
    """
    # The slicing below (`shape[:dim]`, `minimum_where[:dim]`) is only valid for a non-negative dim.
    if dim < 0:
        dim += src.dim()
    shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
    # include_self=False makes the `inf` fill act purely as "no entry" marker for untouched slots.
    minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
    # Coordinates (in `src`) of all entries equal to the minimum of their own group.
    minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
    indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
    # Write the source position into the group's slot: same coordinates, except `dim` is replaced by the group id.
    indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
    return torch.return_types.min((minimum, indices))
def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
    """
    Call `fn` on chunks of the batch (first) dimension and concatenate the outputs along dimension 0.

    ### Parameters:
    - `fn`: callable returning a tensor or a tuple of tensors whose first dimension is the batch dimension
    - `chunk_size`: maximum number of batch elements per call
    - `*args`, `**kwargs`: arguments for `fn`. Every `torch.Tensor` argument is split along dimension 0;
      every other argument is passed unchanged to each call.

    ### Returns:
    - A tensor, or a tuple of tensors if `fn` returns a tuple, concatenated over all chunks.

    ### Raises:
    - `ValueError` if no tensor argument is given (the batch size cannot be determined).
    """
    first_tensor = next((x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)), None)
    if first_tensor is None:
        raise ValueError("split_batch_fwd requires at least one tensor argument to determine the batch size")
    batch_size = first_tensor.shape[0]
    if batch_size == 0:
        # There is nothing to split; one call on the empty inputs avoids indexing into an empty result list.
        return fn(*args, **kwargs)
    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
    splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
    # Each value must be a flat sequence of per-chunk items, exactly like the positional arguments.
    splited_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
    results = []
    for i in range(n_chunks):
        chunk_args = tuple(arg[i] for arg in splited_args)
        chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
        results.append(fn(*chunk_args, **chunk_kwargs))
    if isinstance(results[0], tuple):
        return tuple(torch.cat(r, dim=0) for r in zip(*results))
    return torch.cat(results, dim=0)
def _pad_inf(x_: torch.Tensor):
return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
def _pad_cumsum(cumsum: torch.Tensor):
return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
    """
    If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.

    w_i must be >= 0.

    ### Parameters:
    - `x`: tensor of shape (..., n)
    - `y`: tensor of shape (..., n)
    - `w`: tensor of shape (..., n)
    - `trunc`: optional, float or tensor broadcastable to shape (..., n), or None
    - `eps`: lower clamp applied to `|x|` (and `w * |x|`) before dividing

    `x`, `y` and `w` are broadcast against each other.

    ### Returns:
    - `a`: tensor of shape (...), differentiable
    - `loss`: tensor of shape (...), value of loss function at `a` (computed under `torch.no_grad()` in the truncated branch)
    - `index`: tensor of shape (...), where a = y[idx] / x[idx]
    """
    if trunc is None:
        x, y, w = torch.broadcast_tensors(x, y, w)
        # Flip signs so that x >= 0; |a * x - y| is unchanged.
        sign = torch.sign(x)
        x, y = x * sign, y * sign
        y_div_x = y / x.clamp_min(eps)
        y_div_x, argsort = y_div_x.sort(dim=-1)

        # The objective is piecewise linear in `a` with breakpoints y_i / x_i; its slope right of the k-th
        # sorted breakpoint is 2 * cumsum(wx)[k] - sum(wx). The minimizer is the first breakpoint with slope >= 0.
        wx = torch.gather(x * w, dim=-1, index=argsort)
        derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
        search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)

        a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
        index = argsort.gather(dim=-1, index=search).squeeze(-1)
        loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
    else:
        # Reshape to (batch_size, n) for simplicity
        x, y, w = torch.broadcast_tensors(x, y, w)
        batch_shape = x.shape[:-1]
        batch_size = math.prod(batch_shape)
        n = x.shape[-1]
        per_row_trunc = isinstance(trunc, torch.Tensor)
        if per_row_trunc:
            # Flatten a tensor threshold together with the data so that it lines up row by row
            # with the (batch_size, n) layout used below.
            trunc = torch.broadcast_to(trunc.to(dtype=x.dtype, device=x.device), x.shape).reshape(-1, n)
        x, y, w = x.reshape(-1, n), y.reshape(-1, n), w.reshape(-1, n)

        sign = torch.sign(x)
        x, y = x * sign, y * sign
        wx, wy = w * x, w * y
        xyw = torch.stack([x, y, w], dim=-1)  # Stacked for convenient gathering

        # A: kinks of each term (residual is zero); B, C: where a term enters / leaves the truncated regime.
        y_div_x = A = y / x.clamp_min(eps)
        B = (wy - trunc) / wx.clamp_min(eps)
        C = (wy + trunc) / wx.clamp_min(eps)
        with torch.no_grad():
            # Calculate prefix sums in the sorted orders of A, B, C
            A, A_argsort = A.sort(dim=-1)
            Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
            A, Q_A = _pad_inf(A), _pad_cumsum(Q_A)  # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.

            B, B_argsort = B.sort(dim=-1)
            Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
            B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)

            C, C_argsort = C.sort(dim=-1)
            Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
            C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)

            # Calculate left and right derivatives at every candidate a = y_i / x_i
            j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
            left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
            j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
            right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)

            # Find extrema
            is_extrema = (left_derivative < 0) & (right_derivative >= 0)
            is_extrema[..., 0] |= ~is_extrema.any(dim=-1)  # In case all derivatives are zero, take the first one as extrema.
            where_extrema_batch, where_extrema_index = torch.where(is_extrema)

            # Calculate objective value at extrema
            extrema_a = y_div_x[where_extrema_batch, where_extrema_index]  # (num_extrema,)
            MAX_ELEMENTS = 4096 ** 2  # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
            SPLIT_SIZE = max(1, MAX_ELEMENTS // max(n, 1))  # never 0, which `Tensor.split` would reject
            extrema_value = torch.cat([
                _compute_residual(
                    extrema_a_split[:, None],
                    xyw[extrema_i_split, :, :],
                    # A tensor threshold is gathered per extremum so it matches the gathered rows of `xyw`.
                    trunc[extrema_i_split] if per_row_trunc else trunc,
                )
                for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
            ])  # (num_extrema,)

            # Find minima among corresponding extrema
            minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value)  # (batch_size,)
            index = where_extrema_index[indices]

        # Recompute `a` outside no_grad so that it stays differentiable w.r.t. x and y.
        a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
        a = a.reshape(batch_shape)
        loss = minima.reshape(batch_shape)
        index = index.reshape(batch_shape)
    return a, loss, index
def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Find the scale aligning `depth_src` to `depth_tgt` by forwarding to `align` with the given constant weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight`: passed to `align` as `w`
    - `trunc`: passed to `align` unchanged

    ### Returns:
    - `scale`: the first output of `align`
    """
    return align(depth_src, depth_tgt, weight, trunc)[0]
def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with a scale and a shift, using the given constant weights.

    Every point with positive weight is tried as an anchor: both depths are re-centered on the anchor,
    `align` solves for the scale, and the anchor with the lowest loss is kept per batch element.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc: float` or tensor or None, forwarded to `align`

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (...).
    """
    # Work on a flat (rows, points) layout.
    lead_shape = depth_src.shape[:-1]
    num_points = depth_src.shape[-1]
    num_rows = math.prod(lead_shape)
    depth_src = depth_src.reshape(num_rows, num_points)
    depth_tgt = depth_tgt.reshape(num_rows, num_points)
    weight = weight.reshape(num_rows, num_points)

    # Only points with non-zero weight serve as anchors; zero-weight anchors would waste computation.
    anchor_row, anchor_col = torch.where(weight > 0)

    # The anchor search itself is not differentiated.
    with torch.no_grad():
        src_centered = depth_src[anchor_row, :] - depth_src[anchor_row, anchor_col][..., None]  # (anchors, n)
        tgt_centered = depth_tgt[anchor_row, :] - depth_tgt[anchor_row, anchor_col][..., None]  # (anchors, n)
        _, anchor_loss, partner_col = align(src_centered, tgt_centered, weight[anchor_row, :], trunc)  # (anchors,)
        best_anchor = scatter_min(size=num_rows, dim=0, index=anchor_row, src=anchor_loss)[1]  # (batch_size,)

    def _take(values: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
        # Pick one column per row.
        return torch.gather(values, dim=1, index=cols[..., None]).squeeze(-1)

    # Rebuild scale and shift from the two selected points so the compute graph stays short.
    first_col = anchor_col[best_anchor]  # (batch_size,)
    second_col = partner_col[best_anchor]  # (batch_size,)
    tgt_a, src_a = _take(depth_tgt, first_col), _take(depth_src, first_col)
    tgt_b, src_b = _take(depth_tgt, second_col), _take(depth_src, second_col)

    src_delta = torch.where(src_b != src_a, src_b - src_a, 1e-7)  # avoid dividing by exactly zero
    scale = (tgt_b - tgt_a) / src_delta
    shift = tgt_a - scale * src_a
    return scale.reshape(lead_shape), shift.reshape(lead_shape)
def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
    """
    Align `depth_src` to `depth_tgt` with a scale and a shift using iteratively reweighted least squares (IRLS).

    The first iteration solves a least-squares problem weighted by `weight`; every later iteration uses
    the weights `1 / max(|residual|, eps)` of the previous solution.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N), or None for uniform weights
    - `max_iter`: number of IRLS iterations, must be >= 1
    - `eps`: lower clamp on the absolute residual when computing the new weights

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (...).

    ### Raises:
    - `ValueError` if `max_iter < 1`.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")
    w = torch.ones_like(depth_src) if weight is None else weight
    x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)  # (..., N, 2) design matrix [src, 1]
    y = depth_tgt
    for _ in range(max_iter):
        # Weighted normal equations (X^T W X) beta = X^T W y, written so that leading batch dimensions broadcast.
        xtw = x.transpose(-1, -2) * w[..., None, :]  # (..., 2, N)
        beta = torch.linalg.solve(xtw @ x, xtw @ y[..., None]).squeeze(-1)  # (..., 2)
        w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
    return beta[..., 0], beta[..., 1]
def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Find one shared scale aligning `points_src` to `points_tgt`.

    The point and coordinate dimensions are flattened into a single axis and handed to `align`;
    each point's weight is repeated for its coordinates.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: forwarded to `align`

    ### Returns:
    - `scale: torch.Tensor` of shape (...). The value is a ratio `y / x` picked by `align` and may be negative;
      filter out negative scales before using it.
    """
    flat_src = points_src.flatten(-2)
    flat_tgt = points_tgt.flatten(-2)
    flat_weight = weight[..., None].expand_as(points_src).flatten(-2)
    return align(flat_src, flat_tgt, flat_weight, trunc)[0]
def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
    It is similar to `align_affine` but scale and shift are applied to different dimensions.

    Every point with positive weight is tried as an anchor: its z coordinate is subtracted from all z
    coordinates (x and y are left as they are), `align` solves for the scale on the flattened
    coordinates, and the anchor with the lowest loss is kept per batch element.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: forwarded to `align`

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device
    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
    # Take anchors
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
        # Anchor offsets are (0, 0, z): only the z coordinate is re-centered.
        points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)
        points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)
        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)
        # Solve optimal scale and shift for each anchor
        # NOTE(review): the chunk size becomes 0 when n > MAX_ELEMENTS, which `split_batch_fwd` cannot handle.
        MAX_ELEMENTS = 2 ** 20
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)
        # Keep the anchor with the lowest loss for every batch element
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)
    # Reproduce by indexing for shorter compute graph
    # index_2: flat coordinate (point * 3 + axis) chosen by `align`;
    # index_1: the same axis of the anchor point, looked up in the (0, 0, z) tensors below.
    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)
    zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
    points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
    tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
    # Fall back to a denominator of 1.0 when both selected source values coincide
    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    # shift = anchor_tgt(0, 0, z) - scale * anchor_src(0, 0, z); `index_1 // 3` is the anchor's point index
    shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
    return scale, shift
def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and an xyz shift.

    Every point with positive weight is tried as an anchor: it is subtracted from all points,
    `align` solves for the scale on the flattened coordinates, and the anchor with the lowest loss
    is kept per batch element.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: forwarded to `align`
    - `max_iters`, `eps`: accepted but not used by this implementation

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3)
    """
    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        points_src_anchor = points_src[anchor_where_batch, anchor_where_n]  # (anchors, 3)
        points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n]  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor.
        # The chunk size is the number of anchors per call, so it must shrink with n to keep the
        # number of elements per chunk bounded (same rule as `align_points_scale_z_shift`); at least 1.
        MAX_ELEMENTS = 2 ** 20
        chunk_size = max(1, MAX_ELEMENTS // max(n, 1))
        scale, loss, index = split_batch_fwd(align, chunk_size, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        # Get optimal scale and shift for each batch element
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    # index_2: flat coordinate (point * 3 + axis) chosen by `align`; index_1: the same axis of the anchor point.
    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    # Fall back to a denominator of 1.0 when both selected source values coincide
    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    # shift = anchor_tgt - scale * anchor_src; `index_1 // 3` is the anchor's point index
    shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
    return scale, shift
def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a Z-axis shift.

    The shift is obtained from `align` with `x = 1` and `y = z_tgt - z_src`.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: forwarded to `align`
    - `max_iters`, `eps`: accepted but not used by this implementation

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    src_z, tgt_z = points_src[..., 2], points_tgt[..., 2]
    z_shift = align(torch.ones_like(src_z), tgt_z - src_z, weight, trunc)[0]
    no_shift = torch.zeros_like(z_shift)
    return torch.stack((no_shift, no_shift, z_shift), dim=-1)
def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to an independent shift along each of the x, y and z axes.

    The coordinate axis is moved in front of the point axis, so `align` (with `x = 1` and
    `y = tgt - src`) solves one problem per axis.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc`: forwarded to `align`
    - `max_iters`, `eps`: accepted but not used by this implementation

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3)
    """
    unit = torch.ones_like(points_src).swapaxes(-2, -1)  # (..., 3, N)
    offset = (points_tgt - points_src).swapaxes(-2, -1)  # (..., 3, N)
    per_axis_weight = weight[..., None, :]  # (..., 1, N), broadcast over the three axes
    return align(unit, offset, per_axis_weight, trunc)[0]
def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.

    ### Parameters:
    - `x: torch.Tensor` of shape (..., N)
    - `y: torch.Tensor` of shape (..., N)
    - `w: torch.Tensor` of shape (..., N), non-negative; `None` means uniform weights

    ### Returns:
    - `a: torch.Tensor` of shape (...,)
    - `b: torch.Tensor` of shape (...,)
    """
    w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
    # The weighted objective equals || sqrt(w) * (a * x + b - y) ||^2, so every column of the design
    # matrix, including the intercept column, has to be scaled by sqrt(w).
    slope_col, intercept_col = torch.broadcast_tensors(w_sqrt * x, w_sqrt)
    A = torch.stack([slope_col, intercept_col], dim=-1)
    B = (w_sqrt * y)[..., None]
    a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
    return a, b