|
from typing import * |
|
from numbers import Number |
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
from .utils import image_uv |
|
|
|
|
|
__all__ = [ |
|
'get_rays', |
|
'get_image_rays', |
|
'get_mipnerf_cones', |
|
'volume_rendering', |
|
'bin_sample', |
|
'importance_sample', |
|
'nerf_render_rays', |
|
'mipnerf_render_rays', |
|
'nerf_render_view', |
|
'mipnerf_render_view', |
|
'InstantNGP', |
|
] |
|
|
|
|
|
def get_rays(extrinsics: Tensor, intrinsics: Tensor, uv: Tensor) -> Tuple[Tensor, Tensor]: |
|
""" |
|
Args: |
|
extrinsics: (..., 4, 4) extrinsics matrices. |
|
intrinsics: (..., 3, 3) intrinsics matrices. |
|
uv: (..., n_rays, 2) uv coordinates of the rays. |
|
|
|
Returns: |
|
rays_o: (..., 1, 3) ray origins |
|
rays_d: (..., n_rays, 3) ray directions. |
|
NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. |
|
""" |
|
uvz = torch.cat([uv, torch.ones_like(uv[..., :1])], dim=-1).to(extrinsics) |
|
|
|
with torch.cuda.amp.autocast(enabled=False): |
|
inv_transformation = (intrinsics @ extrinsics[..., :3, :3]).inverse() |
|
inv_extrinsics = extrinsics.inverse() |
|
rays_d = uvz @ inv_transformation.transpose(-1, -2) |
|
rays_o = inv_extrinsics[..., None, :3, 3] |
|
return rays_o, rays_d |
|
|
|
|
|
def get_image_rays(extrinsics: Tensor, intrinsics: Tensor, width: int, height: int) -> Tuple[Tensor, Tensor]: |
|
""" |
|
Args: |
|
extrinsics: (..., 4, 4) extrinsics matrices. |
|
intrinsics: (..., 3, 3) intrinsics matrices. |
|
width: width of the image. |
|
height: height of the image. |
|
|
|
Returns: |
|
rays_o: (..., 1, 1, 3) ray origins |
|
rays_d: (..., height, width, 3) ray directions. |
|
NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. |
|
""" |
|
uv = image_uv(height, width).to(extrinsics).flatten(0, 1) |
|
rays_o, rays_d = get_rays(extrinsics, intrinsics, uv) |
|
rays_o = rays_o.unflatten(-2, (1, 1)) |
|
rays_d = rays_d.unflatten(-2, (height, width)) |
|
return rays_o, rays_d |
|
|
|
|
|
def get_mipnerf_cones(rays_o: Tensor, rays_d: Tensor, z_vals: Tensor, pixel_width: Tensor) -> Tuple[Tensor, Tensor]: |
|
""" |
|
Args: |
|
rays_o: (..., n_rays, 3) ray origins |
|
rays_d: (..., n_rays, 3) ray directions. |
|
z_vals: (..., n_rays, n_samples) z values. |
|
pixel_width: (...) pixel width. = 1 / (normalized focal length * width) |
|
|
|
Returns: |
|
mu: (..., n_rays, n_samples, 3) cone mu. |
|
sigma: (..., n_rays, n_samples, 3, 3) cone sigma. |
|
""" |
|
t_mu = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) |
|
t_delta = (z_vals[..., 1:] - z_vals[..., :-1]).mul_(0.5) |
|
t_mu_square = t_mu.square() |
|
t_delta_square = t_delta.square() |
|
t_delta_quad = t_delta_square.square() |
|
mu_t = t_mu + 2.0 * t_mu * t_delta_square / (3.0 * t_mu_square + t_delta_square) |
|
sigma_t = t_delta_square / 3.0 - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square).square() * (12.0 * t_mu_square - t_delta_square) |
|
sigma_r = (pixel_width[..., None, None].square() / 3.0) * (t_mu_square / 4.0 + (5.0 / 12.0) * t_delta_square - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square)) |
|
points_mu = rays_o[:, :, :, None, :] + rays_d[:, :, :, None, :] * mu_t[..., None] |
|
d_dt = rays_d[..., :, None] * rays_d[..., None, :] |
|
points_sigma = sigma_t[..., None, None] * d_dt[..., None, :, :] + sigma_r[..., None, None] * (torch.eye(3).to(rays_o) - d_dt[..., None, :, :]) |
|
return points_mu, points_sigma |
|
|
|
|
|
def get_pixel_width(intrinsics: Tensor, width: int, height: int) -> Tensor: |
|
""" |
|
Args: |
|
intrinsics: (..., 3, 3) intrinsics matrices. |
|
width: width of the image. |
|
height: height of the image. |
|
|
|
Returns: |
|
pixel_width: (...) pixel width. = 1 / (normalized focal length * width) |
|
""" |
|
assert width == height, "Currently, only square images are supported." |
|
pixel_width = torch.reciprocal((intrinsics[..., 0, 0] * intrinsics[..., 1, 1]).sqrt() * width) |
|
return pixel_width |
|
|
|
|
|
def volume_rendering(color: Tensor, sigma: Tensor, z_vals: Tensor, ray_length: Tensor, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: |
|
""" |
|
Given color, sigma and z_vals (linear depth of the sampling points), render the volume. |
|
|
|
NOTE: By default, color and sigma should have one less sample than z_vals, in correspondence with the average value in intervals. |
|
If queried color are aligned with z_vals, we use trapezoidal rule to calculate the average values in intervals. |
|
|
|
Args: |
|
color: (..., n_samples or n_samples - 1, 3) color values. |
|
sigma: (..., n_samples or n_samples - 1) density values. |
|
z_vals: (..., n_samples) z values. |
|
ray_length: (...) length of the ray |
|
|
|
Returns: |
|
rgb: (..., 3) rendered color values. |
|
depth: (...) rendered depth values. |
|
weights (..., n_samples) weights. |
|
""" |
|
dists = (z_vals[..., 1:] - z_vals[..., :-1]) * ray_length[..., None] |
|
if color.shape[-2] == z_vals.shape[-1]: |
|
color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) |
|
sigma = (sigma[..., 1:] + sigma[..., :-1]).mul_(0.5) |
|
sigma_delta = sigma * dists |
|
transparancy = (-torch.cat([torch.zeros_like(sigma_delta[..., :1]), sigma_delta[..., :-1]], dim=-1).cumsum(dim=-1)).exp_() |
|
alpha = 1.0 - (-sigma_delta).exp_() |
|
weights = alpha * transparancy |
|
if rgb: |
|
rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None |
|
if depth: |
|
z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) |
|
depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None |
|
return rgb, depth, weights |
|
|
|
|
|
def neus_volume_rendering(color: Tensor, sdf: Tensor, s: torch.Tensor, z_vals: Tensor = None, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: |
|
""" |
|
Given color, sdf values and z_vals (linear depth of the sampling points), do volume rendering. (NeuS) |
|
|
|
Args: |
|
color: (..., n_samples or n_samples - 1, 3) color values. |
|
sdf: (..., n_samples) sdf values. |
|
s: (..., n_samples) S values of S-density function in NeuS. The standard deviation of such S-density distribution is 1 / s. |
|
z_vals: (..., n_samples) z values. |
|
ray_length: (...) length of the ray |
|
|
|
Returns: |
|
rgb: (..., 3) rendered color values. |
|
depth: (...) rendered depth values. |
|
weights (..., n_samples) weights. |
|
""" |
|
|
|
if color.shape[-2] == z_vals.shape[-1]: |
|
color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) |
|
|
|
sigmoid_sdf = torch.sigmoid(s * sdf) |
|
alpha = F.relu(1 - sigmoid_sdf[..., :-1] / sigmoid_sdf[..., :-1]) |
|
transparancy = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), alpha], dim=-1), dim=-1) |
|
weights = alpha * transparancy |
|
|
|
if rgb: |
|
rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None |
|
if depth: |
|
z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) |
|
depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None |
|
return rgb, depth, weights |
|
|
|
|
|
def bin_sample(size: Union[torch.Size, Tuple[int, ...]], n_samples: int, min_value: Number, max_value: Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch.dtype = None, device: torch.device = None) -> Tensor: |
|
""" |
|
Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value]. |
|
Args: |
|
size: size of the rays |
|
n_samples: number of samples to be sampled, also the number of bins |
|
min_value: minimum value of the range |
|
max_value: maximum value of the range |
|
space: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse space. |
|
|
|
Returns: |
|
z_rand: (*size, n_samples) sampled z values, sorted in ascending order. |
|
""" |
|
if spacing == 'linear': |
|
pass |
|
elif spacing == 'inverse_linear': |
|
min_value = 1.0 / min_value |
|
max_value = 1.0 / max_value |
|
bin_length = (max_value - min_value) / n_samples |
|
z_rand = (torch.rand(*size, n_samples, device=device, dtype=dtype) - 0.5) * bin_length + torch.linspace(min_value + bin_length * 0.5, max_value - bin_length * 0.5, n_samples, device=device, dtype=dtype) |
|
if spacing == 'inverse_linear': |
|
z_rand = 1.0 / z_rand |
|
return z_rand |
|
|
|
|
|
def importance_sample(z_vals: Tensor, weights: Tensor, n_samples: int) -> Tuple[Tensor, Tensor]: |
|
""" |
|
Importance sample z values. |
|
|
|
NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals. |
|
If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals. |
|
|
|
Args: |
|
z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order. |
|
weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights. |
|
n_samples: number of output samples for importance sampling. |
|
|
|
Returns: |
|
z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted. |
|
""" |
|
if weights.shape[-1] == z_vals.shape[-1]: |
|
weights = (weights[..., 1:] + weights[..., :-1]).mul_(0.5) |
|
weights = weights / torch.sum(weights, dim=-1, keepdim=True) |
|
bins_a, bins_b = z_vals[..., :-1], z_vals[..., 1:] |
|
|
|
pdf = weights / torch.sum(weights, dim=-1, keepdim=True) |
|
cdf = torch.cumsum(pdf, dim=-1) |
|
u = torch.rand(*z_vals.shape[:-1], n_samples, device=z_vals.device, dtype=z_vals.dtype) |
|
|
|
inds = torch.searchsorted(cdf, u, right=True).clamp(0, cdf.shape[-1] - 1) |
|
|
|
bins_a = torch.gather(bins_a, dim=-1, index=inds) |
|
bins_b = torch.gather(bins_b, dim=-1, index=inds) |
|
z_importance = bins_a + (bins_b - bins_a) * torch.rand_like(u) |
|
return z_importance |
|
|
|
|
|
def nerf_render_rays( |
|
nerf: Union[Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], Tuple[Callable[[Tensor], Tuple[Tensor, Tensor]], Callable[[Tensor], Tuple[Tensor, Tensor]]]], |
|
rays_o: Tensor, rays_d: Tensor, |
|
*, |
|
return_dict: bool = False, |
|
n_coarse: int = 64, n_fine: int = 64, |
|
near: float = 0.1, far: float = 100.0, |
|
z_spacing: Literal['linear', 'inverse_linear'] = 'linear', |
|
): |
|
""" |
|
NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) |
|
|
|
Args: |
|
nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output. |
|
If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively. |
|
|
|
nerf args: |
|
points: (..., n_rays, n_samples, 3) |
|
directions: (..., n_rays, n_samples, 3) |
|
nerf returns: |
|
color: (..., n_rays, n_samples, 3) color values. |
|
density: (..., n_rays, n_samples) density values. |
|
|
|
rays_o: (..., n_rays, 3) ray origins |
|
rays_d: (..., n_rays, 3) ray directions. |
|
pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) |
|
|
|
Returns |
|
if return_dict is False, return rendered rgb and depth for short cut. (If there are separate coarse and fine results, return fine results) |
|
rgb: (..., n_rays, 3) rendered color values. |
|
depth: (..., n_rays) rendered depth values. |
|
else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results: |
|
``` |
|
{'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} |
|
``` |
|
If there are two models for coarse and fine stages, the dict contains both coarse and fine results: |
|
``` |
|
{ |
|
"coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, |
|
"fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} |
|
} |
|
``` |
|
""" |
|
if isinstance(nerf, tuple): |
|
nerf_coarse, nerf_fine = nerf |
|
else: |
|
nerf_coarse = nerf_fine = nerf |
|
|
|
z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) |
|
points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] |
|
ray_length = rays_d.norm(dim=-1) |
|
|
|
|
|
color_coarse, density_coarse = nerf_coarse(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) |
|
|
|
|
|
with torch.no_grad(): |
|
rgb_coarse, depth_coarse, weights = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) |
|
|
|
if n_fine == 0: |
|
if return_dict: |
|
return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} |
|
else: |
|
return rgb_coarse, depth_coarse |
|
|
|
|
|
if nerf_coarse is nerf_fine: |
|
|
|
|
|
z_fine = importance_sample(z_coarse, weights, n_fine) |
|
points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] |
|
color_fine, density_fine = nerf_fine(points_fine, rays_d[..., None, :].expand_as(points_fine)) |
|
|
|
|
|
z_vals = torch.cat([z_coarse, z_fine], dim=-1) |
|
color = torch.cat([color_coarse, color_fine], dim=-2) |
|
density = torch.cat([density_coarse, density_fine], dim=-1) |
|
z_vals, sort_inds = torch.sort(z_vals, dim=-1) |
|
color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) |
|
density = torch.gather(density, dim=-1, index=sort_inds) |
|
rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) |
|
|
|
if return_dict: |
|
return {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} |
|
else: |
|
return rgb, depth |
|
else: |
|
|
|
z_fine = importance_sample(z_coarse, weights, n_fine) |
|
z_vals = torch.cat([z_coarse, z_fine], dim=-1) |
|
points = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None] |
|
color, density = nerf_fine(points) |
|
rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) |
|
|
|
if return_dict: |
|
return { |
|
'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, |
|
'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} |
|
} |
|
else: |
|
return rgb, depth |
|
|
|
|
|
def mipnerf_render_rays( |
|
mipnerf: Callable[[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]], |
|
rays_o: Tensor, rays_d: Tensor, pixel_width: Tensor, |
|
*, |
|
return_dict: bool = False, |
|
n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4, |
|
near: float = 0.1, far: float = 100.0, |
|
z_spacing: Literal['linear', 'inverse_linear'] = 'linear', |
|
) -> Union[Tuple[Tensor, Tensor], Dict[str, Tensor]]: |
|
""" |
|
MipNeRF rendering. |
|
|
|
Args: |
|
mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output. |
|
|
|
mipnerf args: |
|
points_mu: (..., n_rays, n_samples, 3) cone mu. |
|
points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma. |
|
directions: (..., n_rays, n_samples, 3) |
|
mipnerf returns: |
|
color: (..., n_rays, n_samples, 3) color values. |
|
density: (..., n_rays, n_samples) density values. |
|
|
|
rays_o: (..., n_rays, 3) ray origins |
|
rays_d: (..., n_rays, 3) ray directions. |
|
pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) |
|
|
|
Returns |
|
if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) |
|
rgb: (..., n_rays, 3) rendered color values. |
|
depth: (..., n_rays) rendered depth values. |
|
else, return a dict. If `n_fine == 0`, the dict only contains coarse results: |
|
``` |
|
{'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} |
|
``` |
|
If n_fine > 0, the dict contains both coarse and fine results : |
|
``` |
|
{ |
|
"coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, |
|
"fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} |
|
} |
|
``` |
|
""" |
|
|
|
z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, spacing=z_spacing, device=rays_o.device, dtype=rays_o.dtype) |
|
points_mu_coarse, points_sigma_coarse = get_mipnerf_cones(rays_o, rays_d, z_coarse, pixel_width) |
|
ray_length = rays_d.norm(dim=-1) |
|
|
|
|
|
color_coarse, density_coarse = mipnerf(points_mu_coarse, points_sigma_coarse, rays_d[..., None, :].expand_as(points_mu_coarse)) |
|
|
|
|
|
rgb_coarse, depth_coarse, weights_coarse = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) |
|
|
|
if n_fine == 0: |
|
if return_dict: |
|
return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} |
|
else: |
|
return rgb_coarse, depth_coarse |
|
|
|
|
|
with torch.no_grad(): |
|
weights_coarse = (1.0 - uniform_ratio) * weights_coarse + uniform_ratio / weights_coarse.shape[-1] |
|
z_fine = importance_sample(z_coarse, weights_coarse, n_fine) |
|
z_fine, _ = torch.sort(z_fine, dim=-2) |
|
points_mu_fine, points_sigma_fine = get_mipnerf_cones(rays_o, rays_d, z_fine, pixel_width) |
|
color_fine, density_fine = mipnerf(points_mu_fine, points_sigma_fine, rays_d[..., None, :].expand_as(points_mu_fine)) |
|
|
|
|
|
rgb_fine, depth_fine, weights_fine = volume_rendering(color_fine, density_fine, z_fine, ray_length) |
|
|
|
if return_dict: |
|
return { |
|
'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, |
|
'fine': {'rgb': rgb_fine, 'depth': depth_fine, 'weights': weights_fine, 'z_vals': z_fine, 'color': color_fine, 'density': density_fine} |
|
} |
|
else: |
|
return rgb_fine, depth_fine |
|
|
|
|
|
def neus_render_rays( |
|
neus: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], |
|
s: Union[Number, Tensor], |
|
rays_o: Tensor, rays_d: Tensor, |
|
*, |
|
compute_normal: bool = True, |
|
return_dict: bool = False, |
|
n_coarse: int = 64, n_fine: int = 64, |
|
near: float = 0.1, far: float = 100.0, |
|
z_spacing: Literal['linear', 'inverse_linear'] = 'linear', |
|
): |
|
""" |
|
TODO |
|
NeuS rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) |
|
|
|
Args: |
|
neus: neus model, which takes (points, directions) as input and returns (color, density) as output. |
|
|
|
nerf args: |
|
points: (..., n_rays, n_samples, 3) |
|
directions: (..., n_rays, n_samples, 3) |
|
nerf returns: |
|
color: (..., n_rays, n_samples, 3) color values. |
|
density: (..., n_rays, n_samples) density values. |
|
|
|
rays_o: (..., n_rays, 3) ray origins |
|
rays_d: (..., n_rays, 3) ray directions. |
|
pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) |
|
|
|
Returns |
|
if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) |
|
rgb: (..., n_rays, 3) rendered color values. |
|
depth: (..., n_rays) rendered depth values. |
|
else, return a dict. If `n_fine == 0`, the dict only contains coarse results: |
|
``` |
|
{'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'sdf': ..., 'normal': ...} |
|
``` |
|
If n_fine > 0, the dict contains both coarse and fine results: |
|
``` |
|
{ |
|
"coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, |
|
"fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} |
|
} |
|
``` |
|
""" |
|
|
|
|
|
z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) |
|
points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] |
|
|
|
|
|
color_coarse, sdf_coarse = neus(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) |
|
|
|
|
|
with torch.no_grad(): |
|
rgb_coarse, depth_coarse, weights = neus_volume_rendering(color_coarse, sdf_coarse, s, z_coarse) |
|
|
|
if n_fine == 0: |
|
if return_dict: |
|
return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse} |
|
else: |
|
return rgb_coarse, depth_coarse |
|
|
|
|
|
|
|
z_fine = importance_sample(z_coarse, weights, n_fine) |
|
points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] |
|
color_fine, sdf_fine = neus(points_fine, rays_d[..., None, :].expand_as(points_fine)) |
|
|
|
|
|
z_vals = torch.cat([z_coarse, z_fine], dim=-1) |
|
color = torch.cat([color_coarse, color_fine], dim=-2) |
|
sdf = torch.cat([sdf_coarse, sdf_fine], dim=-1) |
|
z_vals, sort_inds = torch.sort(z_vals, dim=-1) |
|
color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) |
|
sdf = torch.gather(sdf, dim=-1, index=sort_inds) |
|
rgb, depth, weights = neus_volume_rendering(color, sdf, s, z_vals) |
|
|
|
if return_dict: |
|
return { |
|
'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse}, |
|
'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'sdf': sdf} |
|
} |
|
else: |
|
return rgb, depth |
|
|
|
|
|
def nerf_render_view( |
|
nerf: Tensor, |
|
extrinsics: Tensor, |
|
intrinsics: Tensor, |
|
width: int, |
|
height: int, |
|
*, |
|
patchify: bool = False, |
|
patch_size: Tuple[int, int] = (64, 64), |
|
**options: Dict[str, Any] |
|
) -> Tuple[Tensor, Tensor]: |
|
""" |
|
NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) |
|
|
|
Args: |
|
extrinsics: (..., 4, 4) extrinsics matrice of the rendered views |
|
intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. |
|
width (optional): image width of the rendered views. |
|
height (optional): image height of the rendered views. |
|
patchify (optional): If the image is too large, render it patch by patch |
|
**options: rendering options. |
|
|
|
Returns: |
|
rgb: (..., channels, height, width) rendered color values. |
|
depth: (..., height, width) rendered depth values. |
|
""" |
|
if patchify: |
|
|
|
max_patch_width, max_patch_height = patch_size |
|
n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) |
|
|
|
rgb_rows, depth_rows = [], [] |
|
for i_row in range(n_rows): |
|
rgb_row, depth_row = [], [] |
|
for i_column in range(n_columns): |
|
patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) |
|
uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) |
|
uv = uv.flatten(0, 1) |
|
ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) |
|
rgb_, depth_ = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) |
|
rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) |
|
depth_ = depth_.unflatten(-1, patch_shape) |
|
|
|
rgb_row.append(rgb_) |
|
depth_row.append(depth_) |
|
rgb_rows.append(torch.cat(rgb_row, dim=-1)) |
|
depth_rows.append(torch.cat(depth_row, dim=-1)) |
|
rgb = torch.cat(rgb_rows, dim=-2) |
|
depth = torch.cat(depth_rows, dim=-2) |
|
|
|
return rgb, depth |
|
else: |
|
|
|
uv = image_uv(height, width).to(extrinsics) |
|
uv = uv.flatten(0, 1) |
|
ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) |
|
rgb, depth = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) |
|
rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) |
|
depth = depth.unflatten(-1, (height, width)) |
|
|
|
return rgb, depth |
|
|
|
|
|
def mipnerf_render_view( |
|
mipnerf: Tensor, |
|
extrinsics: Tensor, |
|
intrinsics: Tensor, |
|
width: int, |
|
height: int, |
|
*, |
|
patchify: bool = False, |
|
patch_size: Tuple[int, int] = (64, 64), |
|
**options: Dict[str, Any] |
|
) -> Tuple[Tensor, Tensor]: |
|
""" |
|
MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) |
|
|
|
Args: |
|
extrinsics: (..., 4, 4) extrinsics matrice of the rendered views |
|
intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. |
|
width (optional): image width of the rendered views. |
|
height (optional): image height of the rendered views. |
|
patchify (optional): If the image is too large, render it patch by patch |
|
**options: rendering options. |
|
|
|
Returns: |
|
rgb: (..., 3, height, width) rendered color values. |
|
depth: (..., height, width) rendered depth values. |
|
""" |
|
pixel_width = get_pixel_width(intrinsics, width, height) |
|
|
|
if patchify: |
|
|
|
max_patch_width, max_patch_height = patch_size |
|
n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) |
|
|
|
rgb_rows, depth_rows = [], [] |
|
for i_row in range(n_rows): |
|
rgb_row, depth_row = [], [] |
|
for i_column in range(n_columns): |
|
patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) |
|
uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) |
|
uv = uv.flatten(0, 1) |
|
ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) |
|
rgb_, depth_ = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) |
|
rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) |
|
depth_ = depth_.unflatten(-1, patch_shape) |
|
|
|
rgb_row.append(rgb_) |
|
depth_row.append(depth_) |
|
rgb_rows.append(torch.cat(rgb_row, dim=-1)) |
|
depth_rows.append(torch.cat(depth_row, dim=-1)) |
|
rgb = torch.cat(rgb_rows, dim=-2) |
|
depth = torch.cat(depth_rows, dim=-2) |
|
|
|
return rgb, depth |
|
else: |
|
|
|
uv = image_uv(height, width).to(extrinsics) |
|
uv = uv.flatten(0, 1) |
|
ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) |
|
rgb, depth = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) |
|
rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) |
|
depth = depth.unflatten(-1, (height, width)) |
|
|
|
return rgb, depth |
|
|
|
|
|
class InstantNGP(nn.Module): |
|
""" |
|
An implementation of InstantNGP, Müller et. al., https://nvlabs.github.io/instant-ngp/. |
|
Requires `tinycudann` package. |
|
Install it by: |
|
``` |
|
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch |
|
``` |
|
""" |
|
def __init__(self, |
|
view_dependent: bool = True, |
|
base_resolution: int = 16, |
|
finest_resolution: int = 2048, |
|
n_levels: int = 16, |
|
num_layers_density: int = 2, |
|
hidden_dim_density: int = 64, |
|
num_layers_color: int = 3, |
|
hidden_dim_color: int = 64, |
|
log2_hashmap_size: int = 19, |
|
bound: float = 1.0, |
|
color_channels: int = 3, |
|
): |
|
super().__init__() |
|
import tinycudann |
|
N_FEATURES_PER_LEVEL = 2 |
|
GEO_FEAT_DIM = 15 |
|
|
|
self.bound = bound |
|
self.color_channels = color_channels |
|
|
|
|
|
self.num_layers_density = num_layers_density |
|
self.hidden_dim_density = hidden_dim_density |
|
|
|
per_level_scale = (finest_resolution / base_resolution) ** (1 / (n_levels - 1)) |
|
|
|
self.encoder = tinycudann.Encoding( |
|
n_input_dims=3, |
|
encoding_config={ |
|
"otype": "HashGrid", |
|
"n_levels": n_levels, |
|
"n_features_per_level": N_FEATURES_PER_LEVEL, |
|
"log2_hashmap_size": log2_hashmap_size, |
|
"base_resolution": base_resolution, |
|
"per_level_scale": per_level_scale, |
|
}, |
|
) |
|
|
|
self.density_net = tinycudann.Network( |
|
n_input_dims=N_FEATURES_PER_LEVEL * n_levels, |
|
n_output_dims=1 + GEO_FEAT_DIM, |
|
network_config={ |
|
"otype": "FullyFusedMLP", |
|
"activation": "ReLU", |
|
"output_activation": "None", |
|
"n_neurons": hidden_dim_density, |
|
"n_hidden_layers": num_layers_density - 1, |
|
}, |
|
) |
|
|
|
|
|
self.num_layers_color = num_layers_color |
|
self.hidden_dim_color = hidden_dim_color |
|
|
|
self.view_dependent = view_dependent |
|
if view_dependent: |
|
self.encoder_dir = tinycudann.Encoding( |
|
n_input_dims=3, |
|
encoding_config={ |
|
"otype": "SphericalHarmonics", |
|
"degree": 4, |
|
}, |
|
) |
|
self.in_dim_color = self.encoder_dir.n_output_dims + GEO_FEAT_DIM |
|
else: |
|
self.in_dim_color = GEO_FEAT_DIM |
|
|
|
self.color_net = tinycudann.Network( |
|
n_input_dims=self.in_dim_color, |
|
n_output_dims=color_channels, |
|
network_config={ |
|
"otype": "FullyFusedMLP", |
|
"activation": "ReLU", |
|
"output_activation": "None", |
|
"n_neurons": hidden_dim_color, |
|
"n_hidden_layers": num_layers_color - 1, |
|
}, |
|
) |
|
|
|
def forward(self, x: torch.Tensor, d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Args: |
|
x: (..., 3) points |
|
d: (..., 3) directions |
|
Returns: |
|
color: (..., 3) color values. |
|
density: (..., 1) density values. |
|
""" |
|
batch_shape = x.shape[:-1] |
|
x, d = x.reshape(-1, 3), d.reshape(-1, 3) |
|
|
|
|
|
x = (x + self.bound) / (2 * self.bound) |
|
x = self.encoder(x) |
|
density, geo_feat = self.density_net(x).split([1, 15], dim=-1) |
|
density = F.softplus(density).squeeze(-1) |
|
|
|
|
|
if self.view_dependent: |
|
d = (F.normalize(d, dim=-1) + 1) / 2 |
|
d = self.encoder_dir(d) |
|
h = torch.cat([d, geo_feat], dim=-1) |
|
else: |
|
h = geo_feat |
|
color = self.color_net(h) |
|
|
|
return color.reshape(*batch_shape, self.color_channels), density.reshape(*batch_shape) |
|
|
|
|