import math

import trimesh
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version as pver

import tinycudann as tcnn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

import raymarching


def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')
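

# Illustrative sketch (an addition, not original code): before PyTorch 1.10,
# torch.meshgrid always used 'ij' (matrix) indexing; newer versions warn unless
# `indexing` is passed explicitly. The wrapper above pins the old behavior.
def _demo_custom_meshgrid():
    xs = torch.arange(2)
    ys = torch.arange(3)
    xx, yy = custom_meshgrid(xs, ys)
    # with 'ij' indexing, both outputs have shape [len(xs), len(ys)]
    assert xx.shape == (2, 3) and yy.shape == (2, 3)
    # xx varies along dim 0, yy along dim 1
    assert torch.equal(xx[:, 0], xs) and torch.equal(yy[0], ys)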


def sample_pdf(bins, weights, n_samples, det=False):
    # This implementation is from NeRF
    # bins: [B, T], old_z_vals
    # weights: [B, T - 1], bin weights.
    # return: [B, n_samples], new_z_vals

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
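

# Illustrative sketch (an addition, not original code): inverse-transform
# sampling should concentrate the new z values in bins with large weights.
# Here all the mass sits in the two middle bins, so every sample lands near
# the middle of the [0, 1] range.
def _demo_sample_pdf():
    bins = torch.linspace(0., 1., 9).unsqueeze(0)                # [1, 9] bin edges
    weights = torch.tensor([[0., 0., 0., 1., 1., 0., 0., 0.]])   # [1, 8], mass in the middle
    samples = sample_pdf(bins, weights, 16, det=True)            # [1, 16]
    # essentially all samples fall inside the high-weight region [0.375, 0.625]
    assert ((samples > 0.3) & (samples < 0.7)).float().mean() > 0.9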


def plot_pointcloud(pc, color=None):
    # pc: [N, 3]
    # color: [N, 3/4]
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
    pc = trimesh.PointCloud(pc, color)
    # axis
    axes = trimesh.creation.axis(axis_length=4)
    # sphere
    sphere = trimesh.creation.icosphere(radius=1)
    trimesh.Scene([pc, axes, sphere]).show()


class NGPRenderer(nn.Module):
    def __init__(self,
                 bound=1,
                 cuda_ray=True,
                 density_scale=1,  # scale up deltas (or sigmas) to sharpen the density grid; values larger than 1 usually improve performance.
                 min_near=0.2,
                 density_thresh=0.01,
                 bg_radius=-1,
                 ):
        super().__init__()

        self.bound = bound
        self.cascade = 1
        self.grid_size = 128
        self.density_scale = density_scale
        self.min_near = min_near
        self.density_thresh = density_thresh
        self.bg_radius = bg_radius  # radius of the background sphere.

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points; we still rely on bound (always cubic) to calculate the density grid and hashing.
        aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer('aabb_train', aabb_train)
        self.register_buffer('aabb_infer', aabb_infer)

        # extra state for cuda raymarching
        self.cuda_ray = cuda_ray
        if cuda_ray:
            # density grid
            density_grid = torch.zeros([self.cascade, self.grid_size ** 3])  # [CAS, H * H * H]
            density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8)  # [CAS * H * H * H // 8]
            self.register_buffer('density_grid', density_grid)
            self.register_buffer('density_bitfield', density_bitfield)
            self.mean_density = 0
            self.iter_density = 0
            # step counter
            step_counter = torch.zeros(16, 2, dtype=torch.int32)  # 16 is hardcoded for averaging...
            self.register_buffer('step_counter', step_counter)
            self.mean_count = 0
            self.local_step = 0

    def forward(self, x, d):
        raise NotImplementedError()

    # separated density and color query (can accelerate non-cuda-ray mode.)
    def density(self, x):
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        raise NotImplementedError()

    def reset_extra_state(self):
        if not self.cuda_ray:
            return
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0

    def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # bg_color: [3] in range [0, 1]
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # choose aabb
        aabb = self.aabb_train if self.training else self.aabb_infer

        # sample steps
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)

        #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)  # [1, T]
        z_vals = z_vals.expand((N, num_steps))  # [N, T]
        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

        # perturb z_vals
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
            #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.

        # generate xyzs
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

        # query density (sigma) and geometry features
        density_outputs = self.density(xyzs.reshape(-1, 3))

        #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():

                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]
                deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
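                # note (added comment): standard volume-rendering quadrature.
                # With piecewise-constant density, alpha_i = 1 - exp(-sigma_i * delta_i)
                # and weight_i = alpha_i * T_i, where the transmittance
                # T_i = prod_{j<i} (1 - alpha_j) is realized below by a cumprod
                # over the shifted (1 - alpha) sequence.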
                alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1))  # [N, T]
                alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+1]
                weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T]

                # sample new z_vals
                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1])  # [N, T-1]
                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach()  # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:])  # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order
            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

            for k in density_outputs:
                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]
        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
        alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1))  # [N, T+t]
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+t+1]
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        mask = weights > 1e-4  # hard coded
        rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs)
        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]

        #print(xyzs.shape, 'valid_rgb:', mask.sum().item())

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1)  # [N]

        # calculate depth
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius)  # [N, 2] in [-1, 1]
            bg_color = self.background(sph, rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        # tmp: reg loss in mip-nerf 360
        # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
        # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
        # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals) * (weights ** 2)).sum()

        return {
            'depth': depth,
            'image': image,
            'weights_sum': weights_sum,
        }

    def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius)  # [N, 2] in [-1, 1]
            bg_color = self.background(sph, rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        results = {}

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)

            #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

            sigmas, rgbs = self(xyzs, dirs)
            sigmas = self.density_scale * sigmas

            weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
            depth = torch.clamp(depth - nears, min=0) / (fars - nears)
            image = image.view(*prefix, 3)
            depth = depth.view(*prefix)

        else:
            # allocate outputs
            # if use autocast, must init as half so it won't be autocasted and lose reference.
            #dtype = torch.half if torch.is_autocast_enabled() else torch.float32
            # output should always be float32! only network inference uses half.
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
            rays_t = nears.clone()  # [N]

            step = 0

            while step < max_steps:

                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps
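                # note (added comment): when few rays remain alive, march each
                # of them for more steps per kernel launch (up to 8) to amortize
                # launch overhead; with many alive rays, take a single step so
                # the sample buffers stay bounded by roughly N points.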
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs = self(xyzs, dirs)
                # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
                # sigmas = density_outputs['sigma']
                # rgbs = self.color(xyzs, dirs, **density_outputs)
                sigmas = self.density_scale * sigmas

                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
                rays_alive = rays_alive[rays_alive >= 0]

                #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

                step += n_step

            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
            depth = torch.clamp(depth - nears, min=0) / (fars - nears)
            image = image.view(*prefix, 3)
            depth = depth.view(*prefix)

        results['weights_sum'] = weights_sum
        results['depth'] = depth
        results['image'] = image

        return results

    def mark_untrained_grid(self, poses, intrinsic, S=64):
        # poses: [B, 4, 4]
        # intrinsic: [4], (fx, fy, cx, cy)

        if not self.cuda_ray:
            return

        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)

        B = poses.shape[0]

        fx, fy, cx, cy = intrinsic

        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        count = torch.zeros_like(self.density_grid)
        poses = poses.to(count.device)

        # 5-level loop, forgive me...
        for xs in X:
            for ys in Y:
                for zs in Z:

                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]
                    world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0)  # [1, N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)

                        # split batch to avoid OOM
                        head = 0
                        while head < B:
                            tail = min(head + S, B)

                            # world2cam transform (poses is c2w, so the inverse rotation is its transpose; batched matmul with row-vector points supplies the other transpose, so the rotation is used as-is.)
                            cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
                            cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3]  # [S, N, 3]

                            # query if point is covered by any camera
                            mask_z = cam_xyzs[:, :, 2] > 0  # [S, N]
                            mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1)  # [N]

                            # update count
                            count[cas, indices] += mask
                            head += S

        # mark untrained grid as -1
        self.density_grid[count == 0] = -1

        print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')

    def update_extra_state(self, decay=0.95, S=128):
        # call before each epoch to update extra states.

        if not self.cuda_ray:
            return

        ### update density grid
        tmp_grid = -torch.ones_like(self.density_grid)

        # full update.
        if self.iter_density < 16:
            #if True:
            X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

            for xs in X:
                for ys in Y:
                    for zs in Z:

                        # construct points
                        xx, yy, zz = custom_meshgrid(xs, ys, zs)
                        coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                        indices = raymarching.morton3D(coords).long()  # [N]
                        xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]

                        # cascading
                        for cas in range(self.cascade):
                            bound = min(2 ** cas, self.bound)
                            half_grid_size = bound / self.grid_size
                            # scale to current cascade's resolution
                            cas_xyzs = xyzs * (bound - half_grid_size)
                            # add noise in [-hgs, hgs]
                            cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                            # query density
                            sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                            sigmas *= self.density_scale
                            # assign
                            tmp_grid[cas, indices] = sigmas

        # partial update (half the computation)
        # TODO: why no need of maxpool ?
        else:
            N = self.grid_size ** 3 // 4  # H * H * H / 4
            for cas in range(self.cascade):
                # random sample some positions
                coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device)  # [N, 3], in [0, 128)
                indices = raymarching.morton3D(coords).long()  # [N]
                # random sample occupied positions
                occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1)  # [Nz]
                rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device)
                occ_indices = occ_indices[rand_mask]  # [Nz] --> [N], allow for duplication
                occ_coords = raymarching.morton3D_invert(occ_indices)  # [N, 3]
                # concat
                indices = torch.cat([indices, occ_indices], dim=0)
                coords = torch.cat([coords, occ_coords], dim=0)
                # same below
                xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]
                bound = min(2 ** cas, self.bound)
                half_grid_size = bound / self.grid_size
                # scale to current cascade's resolution
                cas_xyzs = xyzs * (bound - half_grid_size)
                # add noise in [-hgs, hgs]
                cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                # query density
                sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                sigmas *= self.density_scale
                # assign
                tmp_grid[cas, indices] = sigmas

        ## max-pool on tmp_grid for less aggressive culling [No significant improvement...]
        # invalid_mask = tmp_grid < 0
        # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)
        # tmp_grid[invalid_mask] = -1

        # ema update
        valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
        self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item()  # -1 regions are viewed as 0 density.
        #self.mean_density = torch.mean(self.density_grid[self.density_grid > 0]).item() # do not count -1 regions
        self.iter_density += 1

        # convert to bitfield
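        # note (added comment): the packing threshold adapts early in training,
        # when the running mean density is still low, by taking the smaller of
        # the mean and the configured density_thresh, so sparse early fields
        # are not culled away entirely.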
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0

        #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')

    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: pred_rgb: [B, N, 3]

        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        results = _run(rays_o, rays_d, **kwargs)
        return results
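

# Illustrative sketch (an addition, not original code): a pure-PyTorch stand-in
# for raymarching.packbits, packing one occupancy bit per grid cell into uint8
# bytes. The exact bit order used by the CUDA extension is an implementation
# detail; this sketch assumes cell i maps to byte i // 8, bit i % 8.
def _packbits_py(density_grid, density_thresh):
    # density_grid: [CAS, H**3] float, where -1 marks untrained cells
    occ = (density_grid > density_thresh).reshape(-1, 8).int()  # [CAS * H**3 // 8, 8]
    shifts = torch.arange(8, dtype=torch.int32, device=occ.device)
    bitfield = (occ << shifts).sum(dim=-1).to(torch.uint8)      # [CAS * H**3 // 8]
    return bitfield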


class _trunc_exp(Function):
    # cast inputs to float32 under autocast, and clamp the saved input in
    # backward so the gradient of exp cannot overflow in fp16 training.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))


trunc_exp = _trunc_exp.apply
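

# Illustrative sketch (an addition, not original code): trunc_exp behaves like
# exp in the forward pass but clamps the activation to [-15, 15] inside the
# backward pass, keeping gradients finite for very large density logits.
def _demo_trunc_exp():
    x = torch.tensor([20.0], requires_grad=True)
    y = trunc_exp(x)
    y.sum().backward()
    # d/dx exp(x) at x=20 would be exp(20) ~ 4.9e8; the clamped backward
    # uses exp(15) ~ 3.3e6 instead.
    assert torch.isfinite(x.grad).all() and x.grad.item() < math.exp(16)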


class NGPNetwork(NGPRenderer):
    def __init__(self,
                 num_layers=2,
                 hidden_dim=64,
                 geo_feat_dim=15,
                 num_layers_color=3,
                 hidden_dim_color=64,
                 bound=0.5,
                 max_resolution=128,
                 base_resolution=16,
                 n_levels=16,
                 **kwargs
                 ):
        super().__init__(bound, **kwargs)

        # sigma network
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.bound = bound

        log2_hashmap_size = 19
        n_features_per_level = 2
        per_level_scale = np.exp2(np.log2(max_resolution / base_resolution) / (n_levels - 1))
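        # note (added comment): per_level_scale is the Instant-NGP geometric
        # growth factor b = exp(ln(N_max / N_min) / (L - 1)), so level l has
        # resolution base_resolution * b**l, reaching max_resolution at the
        # finest level.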
        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )

        self.sigma_net = tcnn.Network(
            n_input_dims=n_levels * n_features_per_level,
            n_output_dims=1 + self.geo_feat_dim,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim,
                "n_hidden_layers": num_layers - 1,
            },
        )

        # color network
        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color

        self.encoder_dir = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )

        self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim

        self.color_net = tcnn.Network(
            n_input_dims=self.in_dim_color,
            n_output_dims=3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim_color,
                "n_hidden_layers": num_layers_color - 1,
            },
        )

        self.density_scale, self.density_std = 10.0, 0.25
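        # note (added comment): this assignment reuses (and overrides) the
        # renderer's density_scale, which also multiplies sigmas in run() and
        # run_cuda(); here it additionally scales the density bias below.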

    def forward(self, x, d):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], normalized in [-1, 1]

        # sigma
        x_raw = x
        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]
        x = self.encoder(x)
        h = self.sigma_net(x)

        # sigma = F.relu(h[..., 0])
        density = h[..., 0]

        # add density bias
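        # note (added comment): the bias (1 - |x| / density_std) * density_scale
        # is positive near the origin and falls off linearly with distance, so
        # the softplus density starts as a centered blob instead of empty space.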
        dist = torch.norm(x_raw, dim=-1)
        density_bias = (1 - dist / self.density_std) * self.density_scale
        density = density_bias + density
        sigma = F.softplus(density)
        geo_feat = h[..., 1:]

        # color
        d = (d + 1) / 2  # tcnn SH encoding requires inputs to be in [0, 1]
        d = self.encoder_dir(d)

        # p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
        h = torch.cat([d, geo_feat], dim=-1)
        h = self.color_net(h)

        # sigmoid activation for rgb
        color = torch.sigmoid(h)

        return sigma, color

    def density(self, x):
        # x: [N, 3], in [-bound, bound]

        x_raw = x
        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]
        x = self.encoder(x)
        h = self.sigma_net(x)

        # sigma = F.relu(h[..., 0])
        density = h[..., 0]

        # add density bias
        dist = torch.norm(x_raw, dim=-1)
        density_bias = (1 - dist / self.density_std) * self.density_scale
        density = density_bias + density
        sigma = F.softplus(density)
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
        }

    # allow masked inference
    def color(self, x, d, mask=None, geo_feat=None, **kwargs):
        # x: [N, 3] in [-bound, bound]
        # mask: [N,], bool, indicating where we actually need to compute rgb.

        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]

        if mask is not None:
            rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device)  # [N, 3]
            # in case of empty mask
            if not mask.any():
                return rgbs
            x = x[mask]
            d = d[mask]
            geo_feat = geo_feat[mask]

        # color
        d = (d + 1) / 2  # tcnn SH encoding requires inputs to be in [0, 1]
        d = self.encoder_dir(d)
        h = torch.cat([d, geo_feat], dim=-1)
        h = self.color_net(h)

        # sigmoid activation for rgb
        h = torch.sigmoid(h)

        if mask is not None:
            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32
        else:
            rgbs = h

        return rgbs
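

# Illustrative sketch (an addition, not original code): the density grid is
# indexed in Morton (Z-order) rather than row-major order, which keeps nearby
# cells close in memory. A pure-Python reference for the 3D Morton code that
# raymarching.morton3D is assumed to compute: interleave the bits of x, y, z
# (one plausible bit order; the CUDA extension's exact convention may differ).
def _morton3D_py(x, y, z):
    code = 0
    for i in range(10):  # 10 bits per axis covers grid sizes up to 1024
        code |= ((x >> i) & 1) << (3 * i)
        code |= ((y >> i) & 1) << (3 * i + 1)
        code |= ((z >> i) & 1) << (3 * i + 2)
    return code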