|
from internal import math |
|
from internal import utils |
|
import numpy as np |
|
import torch |
|
|
|
|
|
|
|
def contract(x):
    """Warp points toward the origin (Eq 10 of arxiv.org/abs/2111.12077).

    Points whose squared norm (taken over the last axis) is at most 1 are
    returned unchanged; every other point is rescaled by
    ``(2 * ||x|| - 1) / ||x||**2``.

    Args:
      x: floating-point tensor of points; the last axis is the coordinate axis.

    Returns:
      A tensor shaped like `x` holding the contracted points.
    """
    tiny = torch.finfo(x.dtype).eps
    # Keep the squared norm strictly positive so the division below is safe.
    sq_norm = (x ** 2).sum(dim=-1, keepdim=True).clamp_min(tiny)
    scale = (2 * torch.sqrt(sq_norm) - 1) / sq_norm
    return torch.where(sq_norm <= 1, x, scale * x)
|
|
|
|
|
def inv_contract(z):
    """Map contracted points back to their original location (inverse of contract()).

    Points whose squared norm is at most 1 are returned unchanged; the others
    are divided by ``2 * ||z|| - ||z||**2`` (clamped from below by machine
    epsilon).

    Args:
      z: floating-point tensor of contracted points; the last axis is the
        coordinate axis.

    Returns:
      A tensor shaped like `z` holding the un-contracted points.
    """
    tiny = torch.finfo(z.dtype).eps
    sq_norm = (z ** 2).sum(dim=-1, keepdim=True).clamp_min(tiny)
    # The denominator approaches zero as ||z|| approaches 2, hence the clamp.
    denom = (2 * torch.sqrt(sq_norm) - sq_norm).clamp_min(tiny)
    return torch.where(sq_norm <= 1, z, z / denom)
|
|
|
|
|
def inv_contract_np(z):
    """NumPy version of the inverse of contract().

    Args:
      z: floating-point ndarray of contracted points; the last axis is the
        coordinate axis.

    Returns:
      An ndarray shaped like `z` holding the un-contracted points.
    """
    tiny = np.finfo(z.dtype).eps
    sq_norm = np.maximum(np.sum(z ** 2, axis=-1, keepdims=True), tiny)
    # The denominator approaches zero as ||z|| approaches 2, hence the floor.
    denom = np.maximum(2 * np.sqrt(sq_norm) - sq_norm, tiny)
    return np.where(sq_norm <= 1, z, z / denom)
|
|
|
|
|
def contract_tuple(x):
    """Contract `x` and return the result twice, as a 2-tuple.

    Both tuple entries are the very same object returned by contract().
    """
    contracted = contract(x)
    return contracted, contracted
|
|
|
|
|
def contract_mean_jacobi(x):
    """Contract points and evaluate the Jacobian of the contraction at each point.

    The contraction is the same map as in contract(): points with squared norm
    at most 1 are unchanged, the others are scaled by
    ``(2 * ||x|| - 1) / ||x||**2``.

    Args:
      x: floating-point tensor of points; the last axis is the coordinate axis.

    Returns:
      z: the contracted points, shaped like `x`.
      jacobi: tensor of shape ``x.shape[:-1] + (d, d)`` with ``d = x.shape[-1]``.
        It is the identity where the squared norm is at most 1 and
        ``(2 x x^T (1 - r) + (2 r^3 - r^2) I) / r^4`` (with ``r = ||x||``)
        elsewhere.
    """
    eps = torch.finfo(x.dtype).eps
    dim = x.shape[-1]

    x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
    x_mag = torch.sqrt(x_mag_sq)
    inside = x_mag_sq <= 1

    z = torch.where(inside, x, ((2 * x_mag - 1) / x_mag_sq) * x)

    # Outer product x x^T through the project's math.matmul helper
    # (assumed to behave like a batched matrix product -- TODO confirm).
    x_xT = math.matmul(x[..., None], x[..., None, :])

    # Build the identity from the input's own size and dtype instead of a
    # hard-coded float32 3x3, so the Jacobian keeps the dtype of `x`.
    eye = torch.broadcast_to(
        torch.eye(dim, dtype=x.dtype, device=x.device),
        x.shape[:-1] + (dim, dim))

    r = x_mag[..., None]
    jacobi = (2 * x_xT * (1 - r) + (2 * r ** 3 - r ** 2) * eye) / r ** 4
    jacobi = torch.where(inside[..., None], eye, jacobi)
    return z, jacobi
|
|
|
|
|
def contract_mean_std(x, std):
    """Contract points and rescale their isotropic standard deviations.

    The means go through the same map as contract(). For points outside the
    unit ball the standard deviation is multiplied by
    ``(2 r - 1) ** (2 / 3) / r ** 2`` with ``r = ||x||``; for 3-D points this
    is the cube root of the determinant of the contraction's Jacobian. Points
    with squared norm at most 1 keep their `std` unchanged.

    Args:
      x: floating-point tensor of means; the last axis is the coordinate axis.
      std: tensor with one value per mean, i.e. shaped like ``x.shape[:-1]``.

    Returns:
      z: the contracted means, shaped like `x`.
      std: the rescaled standard deviations, shaped like the input `std`.
    """
    eps = torch.finfo(x.dtype).eps

    x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
    x_mag = torch.sqrt(x_mag_sq)
    mask = x_mag_sq <= 1

    z = torch.where(mask, x, ((2 * x_mag - 1) / x_mag_sq) * x)

    # 2r - 1 is negative for r < 0.5, and a fractional power of a negative
    # number is NaN. Those entries are discarded by the `where` below, but the
    # NaN would still leak into gradients w.r.t. `x`. Clamping is a no-op on
    # the entries that are actually used (there 2r - 1 >= 1).
    base = (2 * x_mag - 1).clamp_min(eps)
    det_13 = (torch.pow(base, 1 / 3) / x_mag) ** 2

    std = torch.where(mask[..., 0], std, det_13[..., 0] * std)
    return z, std
|
|
|
|
|
@torch.no_grad()
def track_linearize(fn, mean, std):
    """Push isotropic Gaussians (mean plus scalar std) through a warp.

    Only the scene contraction is supported. The work is delegated to
    contract_mean_std(), which contracts each mean and rescales its standard
    deviation; no full covariance matrix is formed here. The whole function
    runs with gradient tracking disabled.

    Args:
      fn: name of the warp; must be the string 'contract'.
      mean: a tensor of means whose last axis has size 3.
      std: a tensor holding one standard deviation per mean, i.e. with as many
        elements as ``mean.shape[:-1]`` describes.

    Returns:
      fn_mean: the transformed means, with the same shape as `mean`.
      fn_std: the transformed standard deviations, shaped ``mean.shape[:-1]``.

    Raises:
      NotImplementedError: if `fn` is anything other than 'contract'.
    """
    if fn != 'contract':
        raise NotImplementedError(
            f"track_linearize only supports fn='contract', got {fn!r}")

    pre_shape = mean.shape[:-1]
    # Flatten all leading axes so the helper sees (N, 3) means and (N,) stds.
    mean, std = contract_mean_std(mean.reshape(-1, 3), std.reshape(-1))
    return mean.reshape(*pre_shape, 3), std.reshape(*pre_shape)
|
|
|
|
|
def power_transformation(x, lam):
    """Power transformation, Eq(4) in zip-nerf.

    Evaluates ``|lam - 1| / lam * ((x / |lam - 1| + 1) ** lam - 1)``.

    Args:
      x: the value(s) to transform.
      lam: the exponent parameter of the transformation.

    Returns:
      The transformed value(s).
    """
    offset = np.abs(lam - 1)
    warped = (x / offset + 1) ** lam - 1
    return offset / lam * warped
|
|
|
|
|
def inv_power_transformation(x, lam):
    """Inverse of the power transformation.

    Evaluates ``((x * lam / |lam - 1| + 1 + eps) ** (1 / lam) - 1) * |lam - 1|``
    where ``eps`` is the machine epsilon of ``x.dtype``.

    Args:
      x: floating-point tensor of transformed values.
      lam: the exponent parameter used by the forward transformation.

    Returns:
      A tensor with the inverse-transformed values.
    """
    offset = np.abs(lam - 1)
    tiny = torch.finfo(x.dtype).eps
    # The small epsilon is added to the base before the fractional power.
    base = x * lam / offset + 1 + tiny
    return (base ** (1 / lam) - 1) * offset
|
|
|
|
|
def construct_ray_warps(fn, t_near, t_far, lam=None):
    """Build a bijection between metric distances and normalized distances.

    See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
    detailed explanation.

    Args:
      fn: the warp applied to ray distances. May be None (identity), the string
        'piecewise', the string 'power_transformation', or a callable whose
        `__name__` is one of 'reciprocal', 'log', 'exp', 'sqrt', 'square'.
      t_near: a tensor of near-plane distances.
      t_far: a tensor of far-plane distances.
      lam: the lam of Eq(4) in zip-nerf; only used for 'power_transformation'.

    Returns:
      t_to_s: a function that maps distances to normalized distances in [0, 1].
      s_to_t: the inverse of t_to_s.
    """
    if fn is None:
        def fn_fwd(x):
            return x

        def fn_inv(x):
            return x
    elif fn == 'piecewise':
        def fn_fwd(x):
            return torch.where(x < 1, .5 * x, 1 - .5 / x)

        def fn_inv(x):
            return torch.where(x < .5, 2 * x, .5 / (1 - x))
    elif fn == 'power_transformation':
        def fn_fwd(x):
            return power_transformation(x * 2, lam=lam)

        def fn_inv(y):
            return inv_power_transformation(y, lam=lam) / 2
    else:
        # Look the inverse up by the forward callable's name.
        inverse_by_name = {
            'reciprocal': torch.reciprocal,
            'log': torch.exp,
            'exp': torch.log,
            'sqrt': torch.square,
            'square': torch.sqrt,
        }
        fn_fwd = fn
        fn_inv = inverse_by_name[fn.__name__]

    s_near = fn_fwd(t_near)
    s_far = fn_fwd(t_far)

    def t_to_s(t):
        return (fn_fwd(t) - s_near) / (s_far - s_near)

    def s_to_t(s):
        return fn_inv(s * s_far + (1 - s) * s_near)

    return t_to_s, s_to_t
|
|
|
|
|
def expected_sin(mean, var):
    """Return the mean of sin(x) for x ~ N(mean, var).

    Args:
      mean: tensor of Gaussian means.
      var: tensor of Gaussian variances.

    Returns:
      ``exp(-0.5 * var) * safe_sin(mean)``.
    """
    attenuation = torch.exp(-0.5 * var)
    return attenuation * math.safe_sin(mean)
|
|
|
|
|
def integrated_pos_enc(mean, var, min_deg, max_deg):
    """Encode Gaussians with sinusoids scaled by 2^[min_deg, max_deg).

    Args:
      mean: tensor, the mean coordinates to be encoded.
      var: tensor, the variance of the coordinates to be encoded.
      min_deg: int, the min degree of the encoding.
      max_deg: int, the max degree of the encoding.

    Returns:
      encoded: tensor, encoded variables.
    """
    freqs = 2 ** torch.arange(min_deg, max_deg, device=mean.device)
    freq_col = freqs[:, None]
    flat_shape = mean.shape[:-1] + (-1,)

    # Means scale with the frequency, variances with its square.
    scaled_mean = (mean[..., None, :] * freq_col).reshape(*flat_shape)
    scaled_var = (var[..., None, :] * freq_col ** 2).reshape(*flat_shape)

    # The second half is phase-shifted by pi/2.
    phases = torch.cat([scaled_mean, scaled_mean + 0.5 * torch.pi], dim=-1)
    variances = torch.cat([scaled_var, scaled_var], dim=-1)
    return expected_sin(phases, variances)
|
|
|
|
|
def lift_and_diagonalize(mean, cov, basis):
    """Project `mean` and `cov` onto `basis`, keeping only the diagonal of the cov.

    Args:
      mean: the means to project.
      cov: the covariances to project.
      basis: the basis to project onto.

    Returns:
      fn_mean: ``matmul(mean, basis)``.
      fn_cov_diag: ``basis * matmul(cov, basis)`` summed over the
        second-to-last axis.
    """
    projected_mean = math.matmul(mean, basis)
    cov_times_basis = math.matmul(cov, basis)
    projected_cov_diag = torch.sum(basis * cov_times_basis, dim=-2)
    return projected_mean, projected_cov_diag
|
|
|
|
|
def pos_enc(x, min_deg, max_deg, append_identity=True):
    """The positional encoding used by the original NeRF paper.

    Args:
      x: tensor of coordinates; the last axis is the coordinate axis.
      min_deg: int, the min degree of the encoding.
      max_deg: int, the max degree of the encoding (exclusive).
      append_identity: if True, prepend `x` itself to the sinusoidal features.

    Returns:
      The encoded tensor: sines of the scaled coordinates followed by sines of
      the same values shifted by pi/2, optionally preceded by `x`.
    """
    freqs = 2 ** torch.arange(min_deg, max_deg, device=x.device)
    flat_shape = x.shape[:-1] + (-1,)
    scaled = (x[..., None, :] * freqs[:, None]).reshape(*flat_shape)

    shifted = scaled + 0.5 * torch.pi
    features = torch.sin(torch.cat([scaled, shifted], dim=-1))
    if not append_identity:
        return features
    return torch.cat([x, features], dim=-1)
|
|