# Stable-Dreamfusion / nerf / renderer.py
# Author: ashawkey
# Last commit 30e1aa8: "fix: background net should condition on rays_d"
# (original file size: 26.9 kB)
import os
import math
import cv2
import trimesh
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import mcubes
import raymarching
from .utils import custom_meshgrid, safe_normalize
def sample_pdf(bins, weights, n_samples, det=False):
    """Draw new positions along each ray by inverse-transform sampling.

    Adapted from the hierarchical sampling routine of the original NeRF.

    Args:
        bins: [B, T] bin edges (the old z values).
        weights: [B, T - 1] non-negative weight of each bin.
        n_samples: number of new samples to draw per row.
        det: if True, use evenly spaced quantiles instead of uniform random ones.

    Returns:
        [B, n_samples] new z values drawn from the piecewise-constant pdf.
    """
    device = weights.device

    # a small floor keeps all-zero rows from producing NaNs
    padded = weights + 1e-5
    pdf = padded / padded.sum(dim=-1, keepdim=True)
    cdf = torch.cat([torch.zeros_like(pdf[..., :1]), torch.cumsum(pdf, dim=-1)], dim=-1)  # [B, T]

    # quantiles to invert
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if not det:
        u = torch.rand(sample_shape).to(device)
    else:
        half = 0.5 / n_samples
        u = torch.linspace(half, 1. - half, steps=n_samples).to(device).expand(sample_shape)
    u = u.contiguous()

    # locate the cdf interval [lo, hi] containing each quantile
    idx = torch.searchsorted(cdf, u, right=True)
    lo = (idx - 1).clamp(min=0)
    hi = idx.clamp(max=cdf.shape[-1] - 1)
    idx_pair = torch.stack([lo, hi], dim=-1)  # [B, n_samples, 2]

    gather_shape = [idx_pair.shape[0], idx_pair.shape[1], cdf.shape[-1]]
    cdf_lo, cdf_hi = torch.gather(cdf.unsqueeze(1).expand(gather_shape), 2, idx_pair).unbind(-1)
    bin_lo, bin_hi = torch.gather(bins.unsqueeze(1).expand(gather_shape), 2, idx_pair).unbind(-1)

    # near-zero-width cdf intervals would blow up the interpolation factor
    span = cdf_hi - cdf_lo
    span = torch.where(span < 1e-5, torch.ones_like(span), span)
    frac = (u - cdf_lo) / span
    return bin_lo + frac * (bin_hi - bin_lo)
def plot_pointcloud(pc, color=None):
    """Show a point cloud in an interactive trimesh viewer.

    Coordinate axes (length 4) and a unit icosphere are drawn alongside the
    points for reference.

    Args:
        pc: [N, 3] array of point positions.
        color: optional [N, 3] or [N, 4] per-point colors.
    """
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
    scene_items = [
        trimesh.PointCloud(pc, color),
        trimesh.creation.axis(axis_length=4),  # axis
        trimesh.creation.icosphere(radius=1),  # sphere
    ]
    trimesh.Scene(scene_items).show()
class NeRFRenderer(nn.Module):
    """Base NeRF renderer: ray sampling, volume compositing and mesh export.

    Subclasses must implement ``forward``, ``density`` and ``color``. From the
    calls made in this class:
      * ``density(x)`` returns a dict with at least ``'sigma'`` (and
        ``'albedo'``, used when baking the mesh texture);
      * ``forward(x, d, light_d, ratio=..., shading=...)`` returns
        ``(sigmas, rgbs, normals)`` where ``normals`` may be ``None``;
      * when ``opt.bg_radius > 0`` the subclass must also provide
        ``self.background(rays_d)`` returning per-ray background colors.
    """

    def __init__(self, opt):
        super().__init__()

        self.opt = opt
        self.bound = opt.bound
        # number of density-grid cascades needed to cover [-bound, bound]
        self.cascade = 1 + math.ceil(math.log2(opt.bound))
        self.grid_size = 128
        self.cuda_ray = opt.cuda_ray
        self.min_near = opt.min_near
        self.density_thresh = opt.density_thresh
        self.bg_radius = opt.bg_radius

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer('aabb_train', aabb_train)
        self.register_buffer('aabb_infer', aabb_infer)

        # extra state for cuda raymarching (only exists when cuda_ray is enabled)
        if self.cuda_ray:
            # density grid
            density_grid = torch.zeros([self.cascade, self.grid_size ** 3])  # [CAS, H * H * H]
            density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8)  # [CAS * H * H * H // 8]
            self.register_buffer('density_grid', density_grid)
            self.register_buffer('density_bitfield', density_bitfield)
            self.mean_density = 0
            self.iter_density = 0
            # step counter
            step_counter = torch.zeros(16, 2, dtype=torch.int32)  # 16 is hardcoded for averaging...
            self.register_buffer('step_counter', step_counter)
            self.mean_count = 0
            self.local_step = 0

    def forward(self, x, d):
        raise NotImplementedError()

    def density(self, x):
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        raise NotImplementedError()

    def reset_extra_state(self):
        """Zero the density grid and step counter; no-op unless cuda_ray is enabled."""
        if not self.cuda_ray:
            return
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0

    @torch.no_grad()
    def export_mesh(self, path, resolution=None, S=128):
        """Extract a textured mesh from the density field and write it into ``path``.

        Density is sampled on a ``resolution ** 3`` grid over [-1, 1]^3 (in
        chunks of ``S`` points per axis), iso-surfaced with marching cubes,
        UV-unwrapped with xatlas and textured by querying ``density()['albedo']``
        at rasterized surface points. Writes ``mesh.obj``, ``mesh.mtl`` and
        ``albedo.png``; ``path`` is created if it does not exist.

        Args:
            path: output directory.
            resolution: grid resolution per axis (defaults to ``self.grid_size``).
            S: chunk size per axis used when querying the density.
        """
        if resolution is None:
            resolution = self.grid_size

        # density_grid / density_bitfield / mean_density only exist with cuda_ray,
        # so take the device from an always-registered buffer and fall back to the
        # configured threshold when no running mean is tracked.
        device = self.aabb_train.device
        if self.cuda_ray:
            density_thresh = min(self.mean_density, self.density_thresh)
        else:
            density_thresh = self.density_thresh

        sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)

        # query
        X = torch.linspace(-1, 1, resolution).split(S)
        Y = torch.linspace(-1, 1, resolution).split(S)
        Z = torch.linspace(-1, 1, resolution).split(S)

        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [S, 3]
                    val = self.density(pts.to(device))
                    sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()  # [S, 1] --> [x, y, z]

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)

        # map grid indices back to [-1, 1]
        vertices = vertices / (resolution - 1.0) * 2 - 1
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)

        v = torch.from_numpy(vertices).to(device)
        f = torch.from_numpy(triangles).int().to(device)

        # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
        # mesh.export(os.path.join(path, f'mesh.ply'))

        def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
            """Unwrap UVs, bake an albedo texture and write obj/mtl/png files."""
            # v, f: torch Tensor
            device = v.device
            v_np = v.cpu().numpy()  # [N, 3]
            f_np = f.cpu().numpy()  # [M, 3]

            print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')

            # unwrap uvs
            import xatlas
            import nvdiffrast.torch as dr
            from sklearn.neighbors import NearestNeighbors
            from scipy.ndimage import binary_dilation, binary_erosion

            glctx = dr.RasterizeCudaContext()

            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            chart_options.max_iterations = 0  # disable merge_chart for faster unwrap...
            atlas.generate(chart_options=chart_options)
            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

            # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]

            vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
            ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)

            # render uv maps
            uv = vt * 2.0 - 1.0  # uvs to range [-1, 1]
            uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]

            if ssaa > 1:
                h = int(h0 * ssaa)
                w = int(w0 * ssaa)
            else:
                h, w = h0, w0

            rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w))  # [1, h, w, 4]
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, h, w, 3]
            mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f)  # [1, h, w, 1]

            # masked query
            xyzs = xyzs.view(-1, 3)
            mask = (mask > 0).view(-1)

            sigmas = torch.zeros(h * w, device=device, dtype=torch.float32)
            feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)

            if mask.any():
                xyzs = xyzs[mask]  # [M, 3]

                # batched inference to avoid OOM
                chunk = 640000
                all_sigmas = []
                all_feats = []
                head = 0
                while head < xyzs.shape[0]:
                    tail = min(head + chunk, xyzs.shape[0])
                    results_ = self.density(xyzs[head:tail])
                    # flatten in case density() returns sigma as [M, 1]
                    all_sigmas.append(results_['sigma'].float().reshape(-1))
                    all_feats.append(results_['albedo'].float())
                    head += chunk

                sigmas[mask] = torch.cat(all_sigmas, dim=0)
                feats[mask] = torch.cat(all_feats, dim=0)

            sigmas = sigmas.view(h, w, 1)
            feats = feats.view(h, w, -1)
            mask = mask.view(h, w)

            ### alpha mask
            # deltas = 2 * np.sqrt(3) / 1024
            # alphas = 1 - torch.exp(-sigmas * deltas)
            # alphas_mask = alphas > 0.5
            # feats = feats * alphas_mask

            # quantize [0.0, 1.0] to [0, 255]
            feats = feats.cpu().numpy()
            feats = (feats * 255).astype(np.uint8)

            ### NN search as an antialiasing ...
            # fill a thin band outside the covered texels with their nearest covered texel
            mask = mask.cpu().numpy()

            inpaint_region = binary_dilation(mask, iterations=3)
            inpaint_region[mask] = 0

            search_region = mask.copy()
            not_search_region = binary_erosion(search_region, iterations=2)
            search_region[not_search_region] = 0

            search_coords = np.stack(np.nonzero(search_region), axis=-1)
            inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

            # NearestNeighbors cannot be fit on / queried with zero points (e.g. an empty mask)
            if len(search_coords) > 0 and len(inpaint_coords) > 0:
                knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
                _, indices = knn.kneighbors(inpaint_coords)
                feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]

            # do ssaa after the NN search, in numpy
            feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
            if ssaa > 1:
                feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)

            # cv2.imwrite silently returns False (no exception) when the directory is missing
            os.makedirs(path, exist_ok=True)
            cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)

            # save obj (v, vt, f /)
            obj_file = os.path.join(path, f'{name}mesh.obj')
            mtl_file = os.path.join(path, f'{name}mesh.mtl')

            print(f'[INFO] writing obj mesh to {obj_file}')
            with open(obj_file, "w") as fp:
                fp.write(f'mtllib {name}mesh.mtl \n')

                print(f'[INFO] writing vertices {v_np.shape}')
                for vert in v_np:
                    fp.write(f'v {vert[0]} {vert[1]} {vert[2]} \n')

                print(f'[INFO] writing vertices texture coords {vt_np.shape}')
                for tex in vt_np:
                    # OBJ texture v axis points up, image rows point down
                    fp.write(f'vt {tex[0]} {1 - tex[1]} \n')

                print(f'[INFO] writing faces {f_np.shape}')
                fp.write(f'usemtl mat0 \n')
                for i in range(len(f_np)):
                    # OBJ indices are 1-based
                    fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")

            with open(mtl_file, "w") as fp:
                fp.write(f'newmtl mat0 \n')
                fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
                fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
                fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
                fp.write(f'Tr 1.000000 \n')
                fp.write(f'illum 1 \n')
                fp.write(f'Ns 0.000000 \n')
                fp.write(f'map_Kd {name}albedo.png \n')

        _export(v, f)

    def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
        """Render rays in pure PyTorch with uniform + importance (pdf) sampling.

        Args:
            rays_o, rays_d: [B, N, 3] ray origins / directions (assumes B == 1).
            num_steps: uniform samples per ray between near and far.
            upsample_steps: extra samples drawn from the coarse weights (0 disables).
            light_d: light direction; if None, rays_o[0] plus Gaussian noise, normalized.
            ambient_ratio, shading: forwarded to ``self(...)``.
            bg_color: [BN, 3] background color in [0, 1]; defaults to 1 (white)
                when no background network is used (``bg_radius <= 0``).
            perturb: jitter the uniform samples.

        Returns:
            dict with 'image' [B, N, 3], 'depth' [B, N] (z normalized by the
            near/far interval), 'weights_sum' [B, N], 'mask' [B, N], and
            'loss_orient' when the model returns normals.
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        results = {}

        # choose aabb
        aabb = self.aabb_train if self.training else self.aabb_infer

        # sample steps
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)

        # random sample light_d if not provided
        if light_d is None:
            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
            light_d = safe_normalize(light_d)

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)  # [1, T]
        z_vals = z_vals.expand((N, num_steps))  # [N, T]
        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

        # perturb z_vals
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist

        # generate xyzs
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        # query density
        density_outputs = self.density(xyzs.reshape(-1, 3))
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():
                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]
                deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)

                alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1))  # [N, T]
                alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+1]
                weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T]

                # sample new z_vals
                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1])  # [N, T-1]
                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach()  # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:])  # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order
            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

            for k in density_outputs:
                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]
        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
        alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1))  # [N, T+t]
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+t+1]
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]

        # orientation loss: penalize normals that face away from the camera
        if normals is not None:
            normals = normals.view(N, -1, 3)
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.mean()

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1)  # [N]

        # calculate depth
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)
        # same shape as run_cuda's weights_sum
        weights_sum = weights_sum.reshape(*prefix)
        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results

    def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        """Render rays with the CUDA raymarching kernels and the occupancy bitfield.

        Training marches all rays at once; inference marches the still-alive
        rays step by step until ``max_steps`` or all rays terminate.

        Args:
            rays_o, rays_d: [B, N, 3] ray origins / directions (assumes B == 1).
            dt_gamma, perturb, force_all_rays, max_steps, T_thresh: forwarded to
                the raymarching kernels.
            light_d: light direction; if None, rays_o[0] plus Gaussian noise, normalized.
            ambient_ratio, shading: forwarded to ``self(...)``.
            bg_color: background color; defaults to 1 (white) when no
                background network is used (``bg_radius <= 0``).

        Returns:
            dict with 'image' [B, N, 3], 'depth' [B, N], 'weights_sum' [B, N],
            'mask' [B, N], and (training only) 'loss_orient' when the model
            returns normals.
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far (honor opt.min_near, as the pure-PyTorch `run` does)
        aabb = self.aabb_train if self.training else self.aabb_infer
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)

        # random sample light_d if not provided
        if light_d is None:
            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
            light_d = safe_normalize(light_d)

        results = {}

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)

            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

            weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)

            # orientation loss
            if normals is not None:
                weights = 1 - torch.exp(-sigmas)
                loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
                results['loss_orient'] = loss_orient.mean()

        else:
            # allocate outputs
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
            rays_t = nears.clone()  # [N]

            step = 0

            while step < max_steps:  # hard coded max step
                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps: march more steps per iteration as rays die off
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)

                # composite_rays marks terminated rays with negative ids
                rays_alive = rays_alive[rays_alive >= 0]

                step += n_step

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        image = image.view(*prefix, 3)

        depth = torch.clamp(depth - nears, min=0) / (fars - nears)
        depth = depth.view(*prefix)

        weights_sum = weights_sum.reshape(*prefix)

        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results

    @torch.no_grad()
    def update_extra_state(self, decay=0.95, S=128):
        """Refresh the cascaded density grid, its bitfield and the mean step count.

        Call before each epoch. Each grid cell is re-queried at a jittered
        position; the stored value becomes ``max(old * decay, new)``. No-op
        unless cuda_ray is enabled.

        Args:
            decay: EMA decay applied to the previous grid values.
            S: chunk size per axis used when querying the density.
        """
        if not self.cuda_ray:
            return

        ### update density grid
        tmp_grid = - torch.ones_like(self.density_grid)

        device = self.density_bitfield.device
        X = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)

        for xs in X:
            for ys in Y:
                for zs in Z:
                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]
                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_xyzs = xyzs * (bound - half_grid_size)
                        # add noise in [-hgs, hgs]
                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                        # query density
                        sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                        # assign
                        tmp_grid[cas, indices] = sigmas

        # ema update
        valid_mask = self.density_grid >= 0
        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
        self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0

    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        """Render rays, dispatching to ``run_cuda`` or ``run``.

        With ``staged=True`` (ignored when cuda_ray is enabled) rays are
        rendered in chunks of ``max_ray_batch`` and only 'depth', 'image' and
        'weights_sum' are returned.

        Args:
            rays_o, rays_d: [B, N, 3] ray origins / directions.
            staged: render in chunks to bound memory (pure-PyTorch path only).
            max_ray_batch: rays per chunk when staged.
            **kwargs: forwarded to the underlying run function.

        Returns:
            dict of rendered outputs (see ``run`` / ``run_cuda``).
        """
        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        B, N = rays_o.shape[:2]
        device = rays_o.device

        # never stage when cuda_ray
        if staged and not self.cuda_ray:
            depth = torch.empty((B, N), device=device)
            image = torch.empty((B, N, 3), device=device)
            weights_sum = torch.empty((B, N), device=device)

            for b in range(B):
                head = 0
                while head < N:
                    tail = min(head + max_ray_batch, N)
                    results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
                    depth[b:b+1, head:tail] = results_['depth']
                    weights_sum[b:b+1, head:tail] = results_['weights_sum']
                    image[b:b+1, head:tail] = results_['image']
                    head += max_ray_batch

            results = {}
            results['depth'] = depth
            results['image'] = image
            results['weights_sum'] = weights_sum

        else:
            results = _run(rays_o, rays_d, **kwargs)

        return results