# zipnerf/internal/coord.py
from internal import math
from internal import utils
import numpy as np
import torch
# from torch.func import vmap, jacrev


def contract(x):
    """Contracts points towards the origin (Eq. 10 of arxiv.org/abs/2111.12077)."""
    eps = torch.finfo(x.dtype).eps
    # eps = 1e-3
    # Clamping to eps prevents non-finite gradients when x == 0.
    x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
    z = torch.where(x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
    return z


def inv_contract(z):
    """The inverse of contract()."""
    eps = torch.finfo(z.dtype).eps
    # Clamping to eps prevents non-finite gradients when z == 0.
    z_mag_sq = torch.sum(z ** 2, dim=-1, keepdim=True).clamp_min(eps)
    x = torch.where(z_mag_sq <= 1, z, z / (2 * torch.sqrt(z_mag_sq) - z_mag_sq).clamp_min(eps))
    return x


def inv_contract_np(z):
    """The inverse of contract(), for numpy arrays."""
    eps = np.finfo(z.dtype).eps
    # Clamping to eps prevents non-finite gradients when z == 0.
    z_mag_sq = np.maximum(np.sum(z ** 2, axis=-1, keepdims=True), eps)
    x = np.where(z_mag_sq <= 1, z, z / np.maximum(2 * np.sqrt(z_mag_sq) - z_mag_sq, eps))
    return x
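

# Hedged sanity check (illustrative, not part of the upstream module): inv_contract()
# should undo contract(), and the contracted space should lie in a ball of radius 2.
def _check_contract_roundtrip():
    torch.manual_seed(0)
    x = torch.randn(1024, 3) * 3.0  # points both inside and outside the unit ball
    z = contract(x)
    assert (z.norm(dim=-1) <= 2.0 + 1e-5).all()
    assert torch.allclose(inv_contract(z), x, atol=1e-4)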


def contract_tuple(x):
    """Returns contract(x) twice, as the (output, aux) pair jacrev(has_aux=True) expects."""
    res = contract(x)
    return res, res


def contract_mean_jacobi(x):
    """Returns contract(x) and the closed-form Jacobian of the contraction."""
    eps = torch.finfo(x.dtype).eps
    # eps = 1e-3
    # Clamping to eps prevents non-finite gradients when x == 0.
    x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
    x_mag_sqrt = torch.sqrt(x_mag_sq)
    x_xT = math.matmul(x[..., None], x[..., None, :])
    mask = x_mag_sq <= 1
    z = torch.where(mask, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
    eye = torch.broadcast_to(torch.eye(3, device=x.device), z.shape[:-1] + z.shape[-1:] * 2)
    # For |x| > 1, J = [2(1 - |x|) x x^T + (2|x|^3 - |x|^2) I] / |x|^4; inside the
    # unit ball the contraction is the identity, so J = I.
    jacobi = (2 * x_xT * (1 - x_mag_sqrt[..., None])
              + (2 * x_mag_sqrt[..., None] ** 3 - x_mag_sqrt[..., None] ** 2) * eye) / x_mag_sqrt[..., None] ** 4
    jacobi = torch.where(mask[..., None], eye, jacobi)
    return z, jacobi
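

# Hedged sanity check: the closed-form Jacobian from contract_mean_jacobi() should
# match torch's autograd Jacobian of contract() at a point outside the unit ball.
def _check_contract_jacobian():
    x = torch.tensor([0.8, -1.5, 2.0])
    _, jac_analytic = contract_mean_jacobi(x[None])
    jac_autograd = torch.autograd.functional.jacobian(contract, x)
    assert torch.allclose(jac_analytic[0], jac_autograd, atol=1e-5)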


def contract_mean_std(x, std):
    """Returns contract(x) and std scaled by det(J)^(1/3) of the contraction's Jacobian."""
    eps = torch.finfo(x.dtype).eps
    # eps = 1e-3
    # Clamping to eps prevents non-finite gradients when x == 0.
    x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps)
    x_mag_sqrt = torch.sqrt(x_mag_sq)
    mask = x_mag_sq <= 1
    z = torch.where(mask, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x)
    # The Jacobian's eigenvalues are 1/r^2 (radial) and (2r - 1)/r^2 (twice,
    # tangential), so det(J)^(1/3) = ((2r - 1)^(1/3) / r)^2 for r = |x| > 1.
    # det_13 = ((1 / x_mag_sq) * ((2 / x_mag_sqrt - 1 / x_mag_sq) ** 2)) ** (1 / 3)
    det_13 = (torch.pow(2 * x_mag_sqrt - 1, 1 / 3) / x_mag_sqrt) ** 2
    std = torch.where(mask[..., 0], std, det_13[..., 0] * std)
    return z, std
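

# Hedged sanity check: the eigenvalue-based scaling in contract_mean_std() should
# agree with det(J)^(1/3) computed from the closed-form Jacobian.
def _check_contract_det():
    torch.manual_seed(0)
    x = torch.randn(64, 3) * 3.0
    std = torch.rand(64)
    _, std_scaled = contract_mean_std(x, std)
    _, jac = contract_mean_jacobi(x)
    det_13 = torch.linalg.det(jac) ** (1 / 3)
    assert torch.allclose(std_scaled, det_13 * std, atol=1e-4)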


@torch.no_grad()
def track_linearize(fn, mean, std):
    """Apply function `fn` to a set of means and stds, a la a Kalman filter.

    We can analytically transform a Gaussian parameterized by `mean` and `cov`
    with a function `fn` by linearizing `fn` around `mean`, and taking advantage
    of the fact that Covar[Ax + y] = A(Covar[x])A^T (see
    https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details).

    Args:
      fn: the name of the function applied to the Gaussians; only 'contract'
        is currently supported.
      mean: a tensor of means, where the last axis is the dimension.
      std: a tensor of isotropic standard deviations, one per mean.

    Returns:
      fn_mean: the transformed means.
      fn_std: the transformed standard deviations.
    """
    if fn == 'contract':
        fn = contract_mean_jacobi
    else:
        raise NotImplementedError
    pre_shape = mean.shape[:-1]
    mean = mean.reshape(-1, 3)
    std = std.reshape(-1)
    # Equivalent alternatives, kept for reference:
    #   jvp, mean_ = vmap(jacrev(contract_tuple, has_aux=True))(mean)
    #   std_ = std * torch.linalg.det(jvp) ** (1 / mean.shape[-1])
    # or, using the closed-form Jacobian:
    #   mean_, jvp = fn(mean)
    #   std_ = std * torch.linalg.det(jvp) ** (1 / mean.shape[-1])
    # Both agree with contract_mean_std() to ~1e-7.
    mean, std = contract_mean_std(mean, std)  # det computed explicitly from the eigenvalues
    mean = mean.reshape(*pre_shape, 3)
    std = std.reshape(*pre_shape)
    return mean, std
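

# Hedged usage sketch: track_linearize() is called with the string 'contract' to
# push per-sample Gaussians (means plus isotropic stds) through the contraction.
def _example_track_linearize():
    means = torch.randn(4, 32, 3) * 5.0  # e.g. (num_rays, num_samples, 3)
    stds = torch.rand(4, 32)
    c_means, c_stds = track_linearize('contract', means, stds)
    print(c_means.shape, c_stds.shape)  # torch.Size([4, 32, 3]) torch.Size([4, 32])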


def power_transformation(x, lam):
    """The power transformation from Eq. (4) of Zip-NeRF (arxiv.org/abs/2304.06706)."""
    lam_1 = np.abs(lam - 1)
    return lam_1 / lam * ((x / lam_1 + 1) ** lam - 1)


def inv_power_transformation(x, lam):
    """The inverse of power_transformation()."""
    lam_1 = np.abs(lam - 1)
    # Machine eps keeps the base positive but may still overflow to inf near the
    # warp's saturation value; a larger eps (e.g. 1e-3) is a safer alternative.
    eps = torch.finfo(x.dtype).eps
    return ((x * lam / lam_1 + 1 + eps) ** (1 / lam) - 1) * lam_1
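

# Hedged sanity check: inv_power_transformation() should invert
# power_transformation() for non-negative inputs and a typical lam such as -1.5.
def _check_power_roundtrip():
    x = torch.linspace(0.0, 10.0, 101)
    y = power_transformation(x, lam=-1.5)
    assert torch.allclose(inv_power_transformation(y, lam=-1.5), x, atol=1e-3)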


def construct_ray_warps(fn, t_near, t_far, lam=None):
    """Construct a bijection between metric distances and normalized distances.

    See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a
    detailed explanation.

    Args:
      fn: the warp applied to ray distances: None (identity), 'piecewise',
        'power_transformation', or a torch function such as torch.reciprocal.
      t_near: a tensor of near-plane distances.
      t_far: a tensor of far-plane distances.
      lam: the lambda of the 'power_transformation' warp (Eq. (4) of Zip-NeRF).

    Returns:
      t_to_s: a function that maps distances to normalized distances in [0, 1].
      s_to_t: the inverse of t_to_s.
    """
    if fn is None:
        fn_fwd = lambda x: x
        fn_inv = lambda x: x
    elif fn == 'piecewise':
        # Piecewise spacing combining identity and 1/x functions to allow t_near=0.
        fn_fwd = lambda x: torch.where(x < 1, .5 * x, 1 - .5 / x)
        fn_inv = lambda x: torch.where(x < .5, 2 * x, .5 / (1 - x))
    elif fn == 'power_transformation':
        fn_fwd = lambda x: power_transformation(x * 2, lam=lam)
        fn_inv = lambda y: inv_power_transformation(y, lam=lam) / 2
    else:
        inv_mapping = {
            'reciprocal': torch.reciprocal,
            'log': torch.exp,
            'exp': torch.log,
            'sqrt': torch.square,
            'square': torch.sqrt,
        }
        fn_fwd = fn
        fn_inv = inv_mapping[fn.__name__]
    s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)]
    t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near)
    s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near)
    return t_to_s, s_to_t
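

# Hedged usage sketch: with fn=torch.reciprocal, t_to_s spaces samples linearly in
# disparity, the warp used for unbounded scenes in mip-NeRF 360.
def _example_ray_warp():
    t_to_s, s_to_t = construct_ray_warps(
        torch.reciprocal, torch.tensor(0.2), torch.tensor(1e3))
    s = t_to_s(torch.tensor([0.2, 1.0, 1e3]))
    print(s)  # ~[0.0, 0.8, 1.0]
    print(s_to_t(t_to_s(torch.tensor(10.0))))  # ~10.0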


def expected_sin(mean, var):
    """Compute the mean of sin(x), x ~ N(mean, var)."""
    return torch.exp(-0.5 * var) * math.safe_sin(mean)  # large var -> small value.
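

# Hedged sanity check: expected_sin() should match a Monte Carlo estimate of
# E[sin(x)] for x ~ N(mean, var).
def _check_expected_sin():
    torch.manual_seed(0)
    mean, var = torch.tensor(0.7), torch.tensor(0.3)
    mc = torch.sin(mean + var.sqrt() * torch.randn(200000)).mean()
    assert torch.allclose(expected_sin(mean, var), mc, atol=1e-2)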


def integrated_pos_enc(mean, var, min_deg, max_deg):
    """Encode `mean` with sinusoids scaled by 2^[min_deg, max_deg).

    Args:
      mean: tensor, the mean coordinates to be encoded.
      var: tensor, the variance of the coordinates to be encoded.
      min_deg: int, the min degree of the encoding.
      max_deg: int, the max degree of the encoding.

    Returns:
      encoded: tensor, encoded variables.
    """
    scales = 2 ** torch.arange(min_deg, max_deg, device=mean.device)
    shape = mean.shape[:-1] + (-1,)
    scaled_mean = (mean[..., None, :] * scales[:, None]).reshape(*shape)
    scaled_var = (var[..., None, :] * scales[:, None] ** 2).reshape(*shape)
    return expected_sin(
        torch.cat([scaled_mean, scaled_mean + 0.5 * torch.pi], dim=-1),
        torch.cat([scaled_var] * 2, dim=-1))
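

# Hedged usage sketch: for 3-D inputs and degrees [0, L), the encoding stacks a
# sin and a shifted-sin (i.e. cos) feature per scale, giving 2 * 3 * L channels.
def _example_integrated_pos_enc():
    mean, var = torch.randn(4, 3), torch.rand(4, 3)
    enc = integrated_pos_enc(mean, var, min_deg=0, max_deg=8)
    print(enc.shape)  # torch.Size([4, 48])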


def lift_and_diagonalize(mean, cov, basis):
    """Project `mean` and `cov` onto basis and diagonalize the projected cov."""
    fn_mean = math.matmul(mean, basis)
    fn_cov_diag = torch.sum(basis * math.matmul(cov, basis), dim=-2)
    return fn_mean, fn_cov_diag
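

# Hedged sanity check: the diagonal returned by lift_and_diagonalize() should
# equal diag(basis^T @ cov @ basis).
def _check_lift_and_diagonalize():
    torch.manual_seed(0)
    mean = torch.randn(3)
    a = torch.randn(3, 3)
    cov = a @ a.T  # a random PSD covariance
    basis = torch.randn(3, 8)
    _, diag = lift_and_diagonalize(mean, cov, basis)
    assert torch.allclose(diag, torch.diagonal(basis.T @ cov @ basis), atol=1e-5)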


def pos_enc(x, min_deg, max_deg, append_identity=True):
    """The positional encoding used by the original NeRF paper."""
    scales = 2 ** torch.arange(min_deg, max_deg, device=x.device)
    shape = x.shape[:-1] + (-1,)
    scaled_x = (x[..., None, :] * scales[:, None]).reshape(*shape)
    # Note that we're not using safe_sin, unlike IPE.
    four_feat = torch.sin(
        torch.cat([scaled_x, scaled_x + 0.5 * torch.pi], dim=-1))
    if append_identity:
        return torch.cat([x, four_feat], dim=-1)
    else:
        return four_feat
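

# Hedged usage sketch: with append_identity=True the output concatenates the raw
# coordinates with the sin/cos features, e.g. 3 + 2 * 3 * 10 = 63 channels for
# the original NeRF's L=10 encoding.
def _example_pos_enc():
    x = torch.randn(4, 3)
    print(pos_enc(x, min_deg=0, max_deg=10).shape)  # torch.Size([4, 63])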