|
import os.path |
|
|
|
from internal import stepfun |
|
from internal import math |
|
from internal import utils |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
def lift_gaussian(d, t_mean, t_var, r_var, diag): |
|
"""Lift a Gaussian defined along a ray to 3D coordinates.""" |
|
mean = d[..., None, :] * t_mean[..., None] |
|
eps = torch.finfo(d.dtype).eps |
|
|
|
d_mag_sq = torch.sum(d ** 2, dim=-1, keepdim=True).clamp_min(eps) |
|
|
|
if diag: |
|
d_outer_diag = d ** 2 |
|
null_outer_diag = 1 - d_outer_diag / d_mag_sq |
|
t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :] |
|
xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :] |
|
cov_diag = t_cov_diag + xy_cov_diag |
|
return mean, cov_diag |
|
else: |
|
d_outer = d[..., :, None] * d[..., None, :] |
|
eye = torch.eye(d.shape[-1], device=d.device) |
|
null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :] |
|
t_cov = t_var[..., None, None] * d_outer[..., None, :, :] |
|
xy_cov = r_var[..., None, None] * null_outer[..., None, :, :] |
|
cov = t_cov + xy_cov |
|
return mean, cov |
|
|
|
|
|
def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):
    """Approximate a conical frustum as a Gaussian distribution (mean+cov).

    Assumes the ray is originating from the origin, and base_radius is the
    radius at dist=1. Doesn't assume `d` is normalized.

    Args:
      d: the axis of the cone.
      t0: the starting distance of the frustum.
      t1: the ending distance of the frustum.
      base_radius: the scale of the radius as a function of distance.
      diag: whether or not the Gaussian will be diagonal or full-covariance.
      stable: whether or not to use the stable computation, expressed in terms
        of the interval midpoint and half-width. The alternative divides by
        `t1 ** 3 - t0 ** 3`, which is ill-conditioned when `t0` and `t1` are
        close together.

    Returns:
      a Gaussian (mean and covariance), as returned by `lift_gaussian`.
    """
    if stable:
        mu = (t0 + t1) / 2  # Midpoint of the interval.
        hw = (t1 - t0) / 2  # Half-width of the interval.
        eps = torch.finfo(d.dtype).eps
        # Shared denominator, clamped so a degenerate interval at the origin
        # does not divide by zero. Computed once and reused below.
        denom = (3 * mu ** 2 + hw ** 2).clamp_min(eps)
        t_mean = mu + (2 * mu * hw ** 2) / denom
        t_var = (hw ** 2) / 3 - (4 / 15) * hw ** 4 * (12 * mu ** 2 - hw ** 2) / denom ** 2
        r_var = (mu ** 2) / 4 + (5 / 12) * hw ** 2 - (4 / 15) * (hw ** 4) / denom
    else:
        t_mean = (3 * (t1 ** 4 - t0 ** 4)) / (4 * (t1 ** 3 - t0 ** 3))
        r_var = 3 / 20 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
        # Second moment of t; the variance is E[t^2] - E[t]^2.
        t_mosq = 3 / 5 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
        t_var = t_mosq - t_mean ** 2
    # Out-of-place multiply: an in-place `*=` would raise whenever
    # `base_radius` broadcasts to a larger shape than `r_var`.
    r_var = r_var * base_radius ** 2
    return lift_gaussian(d, t_mean, t_var, r_var, diag)
|
|
|
|
|
def cylinder_to_gaussian(d, t0, t1, radius, diag):
    """Fit a Gaussian (mean and covariance) to a cylinder along a ray.

    The ray is taken to start at the origin and `d` is used as given, without
    renormalization.

    Args:
      d: the axis of the cylinder.
      t0: distance along the axis at which the cylinder starts.
      t1: distance along the axis at which the cylinder ends.
      radius: the radius of the cylinder.
      diag: True for a diagonal covariance, False for a full covariance.

    Returns:
      The `(mean, covariance)` pair produced by `lift_gaussian`.
    """
    # Along the axis: (t1 - t0)^2 / 12, the variance of a uniform distribution on [t0, t1].
    axial_var = (t1 - t0) ** 2 / 12
    # Across the axis: radius^2 / 4.
    radial_var = radius ** 2 / 4
    midpoint = (t0 + t1) / 2
    return lift_gaussian(d, midpoint, axial_var, radial_var, diag)
|
|
|
|
|
def cast_rays(tdist, origins, directions, cam_dirs, radii, rand=True, n=7, m=3, std_scale=0.5, **kwargs):
    """Cast rays and place six sample points inside every interval along them.

    For each interval between adjacent `tdist` fenceposts, six distances `t`
    are generated and each is paired with an angle. The points are first built
    in a local (x, y, t) frame, where the off-axis offset has magnitude
    `radii * t / sqrt(2)`, then mapped through a per-ray basis and shifted by
    the ray origin.

    Note that the two off-axis basis vectors are derived from a freshly drawn
    random vector on every call, so the returned `means` are not deterministic
    even when `rand` is False.

    Args:
      tdist: float array, the "fencepost" distances along the ray.
      origins: float array, the ray origin coordinates.
      directions: float array, the ray direction vectors; used as the third
        vector of the per-ray basis.
      cam_dirs: float array; the two off-axis basis vectors are built
        orthogonal to it (presumably the camera viewing direction for each
        ray -- TODO confirm against the caller).
      radii: float array, the radii of the rays (a trailing axis is appended
        here; presumably passed in as [..., 1] -- TODO confirm).
      rand: if True, each interval's angle pattern gets a random rotation and
        a random mirror; if False, every odd-indexed interval is rotated by
        pi / 6 and mirrored.
      n: not used in this function body.
      m: not used in this function body.
      std_scale: multiplier applied to the returned standard deviations.
      **kwargs: accepted and ignored.

    Returns:
      A tuple `(means, stds, t)`: the sample point coordinates, the per-sample
      standard deviation `std_scale * radii * t / sqrt(2)`, and the sample
      distances along the ray.
    """
    # Start / end fencepost of every interval, with a trailing axis for the
    # six samples per interval.
    t0 = tdist[..., :-1, None]
    t1 = tdist[..., 1:, None]
    radii = radii[..., None]

    # Interval midpoint and half-width.
    t_m = (t0 + t1) / 2
    t_d = (t1 - t0) / 2

    # Six distances per interval, indexed by j = 0..5; (2 * j / 5 - 1) sweeps
    # linearly from -1 to 1 across the six samples.
    j = torch.arange(6, device=tdist.device)
    t = t0 + t_d / (t_d ** 2 + 3 * t_m ** 2) * (t1 ** 2 + 2 * t_m ** 2 + 3 / 7 ** 0.5 * (2 * j / 5 - 1) * (
        (t_d ** 2 - t_m ** 2) ** 2 + 4 * t_m ** 4).sqrt())

    # One angle per sample: multiples of 60 degrees in the order
    # 0, 120, 240, 180, 300, 60. Always created as float32.
    deg = torch.pi / 3 * torch.tensor([0, 2, 4, 3, 5, 1], device=tdist.device, dtype=torch.float)
    deg = torch.broadcast_to(deg, t.shape)
    if rand:
        # Per interval: add one random offset in [0, 2*pi) to all six angles,
        # then mirror the angles (deg -> 5*pi/3 - deg) wherever mask is False.
        mask = torch.rand_like(t0[..., 0]) > 0.5
        deg = deg + 2 * torch.pi * torch.rand_like(deg[..., 0])[..., None]
        deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
    else:
        # Even-indexed intervals keep the base pattern; odd-indexed intervals
        # are rotated by pi / 6 and then mirrored.
        mask = torch.arange(t.shape[-2], device=tdist.device) % 2 == 0
        mask = torch.broadcast_to(mask, t.shape[:-1])
        deg = torch.where(mask[..., None], deg, deg + torch.pi / 6)
        deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
    # Sample positions in the local frame: (x, y) off the axis, t along it.
    means = torch.stack([
        radii * t * torch.cos(deg) / 2 ** 0.5,
        radii * t * torch.sin(deg) / 2 ** 0.5,
        t
    ], dim=-1)
    stds = std_scale * radii * t / 2 ** 0.5

    # Two unit vectors orthogonal to cam_dirs (and to each other), obtained
    # from cross products with a random vector drawn on every call.
    rand_vec = torch.randn_like(cam_dirs)
    ortho1 = F.normalize(torch.cross(cam_dirs, rand_vec, dim=-1), dim=-1)
    ortho2 = F.normalize(torch.cross(cam_dirs, ortho1, dim=-1), dim=-1)

    # The third basis vector is `directions` itself (not cam_dirs), so the
    # basis is only orthogonal where directions is parallel to cam_dirs.
    # `math` is the project module `internal.math`, not the standard library.
    basis_matrix = torch.stack([ortho1, ortho2, directions], dim=-1)
    means = math.matmul(means, basis_matrix[..., None, :, :].transpose(-1, -2))
    means = means + origins[..., None, None, :]

    return means, stds, t
|
|
|
|
|
def compute_alpha_weights(density, tdist, dirs, opaque_background=False):
    """Convert per-interval densities into alpha-compositing weights.

    Args:
      density: density of each interval, one value per pair of adjacent
        `tdist` fenceposts.
      tdist: fencepost distances along each ray (one more than `density` on
        the last axis).
      dirs: ray direction vectors; their length rescales the interval widths.
      opaque_background: if True, the last interval is treated as infinitely
        dense, so its alpha is exactly 1.

    Returns:
      A tuple `(weights, alpha, trans)` where `alpha` is the opacity of each
      interval, `trans` is the transmittance reaching the start of each
      interval, and `weights = alpha * trans`.
    """
    starts, ends = tdist[..., :-1], tdist[..., 1:]
    # Width of each interval, scaled by the (possibly non-unit) ray length.
    ray_scale = torch.norm(dirs, dim=-1, keepdim=True)
    optical_depth = density * ((ends - starts) * ray_scale)

    if opaque_background:
        # Swap the final interval for an infinite optical depth.
        blocker = torch.full_like(optical_depth[..., -1:], torch.inf)
        optical_depth = torch.cat([optical_depth[..., :-1], blocker], dim=-1)

    # Transmittance is exp(-sum of optical depth over all earlier intervals);
    # the first interval has nothing in front of it.
    nothing_in_front = torch.zeros_like(optical_depth[..., :1])
    accumulated = torch.cumsum(optical_depth[..., :-1], dim=-1)
    trans = torch.exp(-torch.cat([nothing_in_front, accumulated], dim=-1))

    alpha = 1 - torch.exp(-optical_depth)
    return alpha * trans, alpha, trans
|
|
|
|
|
def volumetric_rendering(rgbs,
                         weights,
                         tdist,
                         bg_rgbs,
                         t_far,
                         compute_extras,
                         extras=None):
    """Composite per-sample values along each ray into per-ray outputs.

    Args:
      rgbs: color, [batch_size, num_samples, 3]
      weights: weights, [batch_size, num_samples].
      tdist: fencepost distances along the ray; it has one more entry than
        `weights` on the last axis, since interval midpoints are paired with
        the weights.
      bg_rgbs: the color(s) to use for the background.
      t_far: [batch_size, 1], the distance of the far plane.
      compute_extras: bool, if True, compute extra quantities besides color.
      extras: dict, a set of values along rays to render by alpha compositing;
        entries whose value is None are skipped.

    Returns:
      rendering: a dict with 'rgb', 'depth' and 'acc'. When compute_extras is
      True it also holds the composited `extras`, 'distance_mean',
      'distance_percentile_5', 'distance_median' and 'distance_percentile_95'.
    """
    tiny = torch.finfo(rgbs.dtype).eps
    nearest, farthest = tdist[..., 0], tdist[..., -1]

    acc = weights.sum(dim=-1)
    # Whatever weight is left over after the samples goes to the background.
    background_weight = (1 - acc[..., None]).clamp_min(0.)
    interval_centers = 0.5 * (tdist[..., :-1] + tdist[..., 1:])

    def weighted_mean(values):
        # Weight-normalized average along the ray; the clamp avoids 0 / 0.
        return (weights * values).sum(dim=-1) / acc.clamp_min(tiny)

    def to_valid_distance(values):
        # NaNs become +inf and are then clipped into the ray's own t-range.
        return torch.clip(torch.nan_to_num(values, torch.inf), nearest, farthest)

    rendering = {
        'rgb': (weights[..., None] * rgbs).sum(dim=-2) + background_weight * bg_rgbs,
        'depth': to_valid_distance(weighted_mean(interval_centers)),
        'acc': acc,
    }
    if not compute_extras:
        return rendering

    if extras is not None:
        for name, values in extras.items():
            if values is None:
                continue
            rendering[name] = (weights[..., None] * values).sum(dim=-2)

    # Mean distance taken in log space (a weighted geometric mean).
    log_mean = weighted_mean(torch.log(interval_centers))
    rendering['distance_mean'] = to_valid_distance(torch.exp(log_mean))

    # Append the far plane as one more fencepost, carrying the background
    # weight, before taking weighted percentiles of distance.
    t_aug = torch.cat([tdist, t_far], dim=-1)
    weights_aug = torch.cat([weights, background_weight], dim=-1)

    ps = [5, 50, 95]
    distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)
    for column, p in enumerate(ps):
        suffix = 'median' if p == 50 else 'percentile_' + str(p)
        rendering['distance_' + suffix] = distance_percentiles[..., column]

    return rendering
|
|