Stable-Dreamfusion / nerf /renderer.py
Linn Bedroc
Jerb
1e9f296
raw
history blame
No virus
26.9 kB
import os
import math
import cv2
import trimesh
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import mcubes
import raymarching
from .utils import custom_meshgrid, safe_normalize
def sample_pdf(bins, weights, n_samples, det=False):
    """Inverse-transform sample new z values from a piecewise-constant PDF.

    Adapted from the original NeRF implementation.

    Args:
        bins: [B, T] bin edges (the old z values).
        weights: [B, T - 1] unnormalized weight of every bin.
        n_samples: number of samples to draw per row.
        det: when True use evenly spaced quantiles, otherwise uniform random ones.

    Returns:
        [B, n_samples] tensor with the new z values.
    """
    # Build the CDF; the small epsilon keeps all-zero rows from producing NaNs.
    weights = weights + 1e-5
    pdf = weights / weights.sum(dim=-1, keepdim=True)
    running = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(running[..., :1]), running], -1)

    # Quantiles at which the CDF is inverted.
    batch_shape = list(cdf.shape[:-1])
    if det:
        half_step = 0.5 / n_samples
        u = torch.linspace(0. + half_step, 1. - half_step, steps=n_samples).to(weights.device)
        u = u.expand(batch_shape + [n_samples])
    else:
        u = torch.rand(batch_shape + [n_samples]).to(weights.device)
    u = u.contiguous()

    # Locate the CDF interval [lower, upper] that contains every quantile.
    hit = torch.searchsorted(cdf, u, right=True)
    last = cdf.shape[-1] - 1
    lower = (hit - 1).clamp(min=0)
    upper = hit.clamp(max=last)
    pair_idx = torch.stack([lower, upper], -1)  # (B, n_samples, 2)

    gather_shape = [pair_idx.shape[0], pair_idx.shape[1], cdf.shape[-1]]
    cdf_pair = torch.gather(cdf.unsqueeze(1).expand(gather_shape), 2, pair_idx)
    bin_pair = torch.gather(bins.unsqueeze(1).expand(gather_shape), 2, pair_idx)

    # Linear interpolation inside the interval; degenerate intervals get width 1
    # so the division is safe.
    width = cdf_pair[..., 1] - cdf_pair[..., 0]
    width = torch.where(width < 1e-5, torch.ones_like(width), width)
    frac = (u - cdf_pair[..., 0]) / width

    return bin_pair[..., 0] + frac * (bin_pair[..., 1] - bin_pair[..., 0])
def plot_pointcloud(pc, color=None):
    """Print basic statistics of a point cloud and open it in a trimesh viewer.

    Args:
        pc: [N, 3] point positions.
        color: optional [N, 3/4] per-point colors.

    The scene also contains a set of axes (length 4) and an icosphere of
    radius 1 as visual references.
    """
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))

    cloud = trimesh.PointCloud(pc, color)
    reference_axes = trimesh.creation.axis(axis_length=4)
    reference_sphere = trimesh.creation.icosphere(radius=1)

    scene = trimesh.Scene([cloud, reference_axes, reference_sphere])
    scene.show()
class NeRFRenderer(nn.Module):
    """Base volume renderer shared by the NeRF networks.

    The class implements ray sampling, compositing, density-grid maintenance
    and mesh export. It calls `self(x, d, ...)`, `self.density(x)` and, when
    `bg_radius > 0`, `self.background(d)`; `forward`, `density` and `color`
    raise NotImplementedError here and `background` is not defined in this
    class, so subclasses have to provide them.
    """

    def __init__(self, opt):
        """Store the rendering options and allocate the renderer state.

        Reads `bound`, `cuda_ray`, `min_near`, `density_thresh` and
        `bg_radius` from `opt`.
        """
        super().__init__()

        self.opt = opt
        self.bound = opt.bound
        self.cascade = 1 + math.ceil(math.log2(opt.bound))
        self.grid_size = 128
        self.cuda_ray = opt.cuda_ray
        self.min_near = opt.min_near
        self.density_thresh = opt.density_thresh
        self.bg_radius = opt.bg_radius

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on
        # bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer('aabb_train', aabb_train)
        self.register_buffer('aabb_infer', aabb_infer)

        # extra state for cuda raymarching
        if self.cuda_ray:
            # density grid
            density_grid = torch.zeros([self.cascade, self.grid_size ** 3])  # [CAS, H * H * H]
            density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8)  # [CAS * H * H * H // 8]
            self.register_buffer('density_grid', density_grid)
            self.register_buffer('density_bitfield', density_bitfield)
            self.mean_density = 0
            self.iter_density = 0
            # step counter
            step_counter = torch.zeros(16, 2, dtype=torch.int32)  # 16 is hardcoded for averaging...
            self.register_buffer('step_counter', step_counter)
            self.mean_count = 0
            self.local_step = 0

    def forward(self, x, d):
        """Abstract; must be implemented by subclasses."""
        raise NotImplementedError()

    def density(self, x):
        """Abstract; must be implemented by subclasses."""
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        """Abstract; must be implemented by subclasses."""
        raise NotImplementedError()

    def reset_extra_state(self):
        """Zero the density grid and step counters (no-op without cuda_ray)."""
        if not self.cuda_ray:
            return
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0

    @torch.no_grad()
    def export_mesh(self, path, resolution=None, S=128):
        """Extract a mesh with marching cubes and write it with an albedo texture.

        Args:
            path: directory that receives `albedo.png`, `mesh.obj` and `mesh.mtl`.
            resolution: samples per axis of the density volume; defaults to
                `self.grid_size`.
            S: chunk length used to split every axis when querying densities.

        After the files are written, `blender -b -P obj2glb.py` is invoked
        through the shell.
        """
        if resolution is None:
            resolution = self.grid_size

        # mean_density is only tracked in cuda_ray mode; otherwise fall back to
        # the configured threshold instead of failing with an AttributeError.
        if self.cuda_ray:
            density_thresh = min(self.mean_density, self.density_thresh)
        else:
            density_thresh = self.density_thresh

        # aabb_train is always registered, whereas density_bitfield only exists
        # in cuda_ray mode, so take the device from the former.
        device = self.aabb_train.device

        sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)

        # query the density on a regular grid in [-1, 1]^3, chunk by chunk
        X = torch.linspace(-1, 1, resolution).split(S)
        Y = torch.linspace(-1, 1, resolution).split(S)
        Z = torch.linspace(-1, 1, resolution).split(S)

        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [S, 3]
                    val = self.density(pts.to(device))
                    sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()  # [S, 1] --> [x, y, z]

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)

        # map voxel indices back to [-1, 1]
        vertices = vertices / (resolution - 1.0) * 2 - 1
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)

        v = torch.from_numpy(vertices).to(device)
        f = torch.from_numpy(triangles).int().to(device)

        def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
            # v, f: torch Tensor
            device = v.device
            v_np = v.cpu().numpy()  # [N, 3]
            f_np = f.cpu().numpy()  # [M, 3]

            print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')

            # unwrap uvs
            import xatlas
            import nvdiffrast.torch as dr
            from sklearn.neighbors import NearestNeighbors
            from scipy.ndimage import binary_dilation, binary_erosion

            glctx = dr.RasterizeCudaContext()

            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            chart_options.max_iterations = 0  # disable merge_chart for faster unwrap...
            atlas.generate(chart_options=chart_options)
            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

            vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
            ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)

            # render uv maps
            uv = vt * 2.0 - 1.0  # uvs to range [-1, 1]
            uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]

            if ssaa > 1:
                h = int(h0 * ssaa)
                w = int(w0 * ssaa)
            else:
                h, w = h0, w0

            rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w))  # [1, h, w, 4]
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, h, w, 3]
            mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f)  # [1, h, w, 1]

            # masked query
            xyzs = xyzs.view(-1, 3)
            mask = (mask > 0).view(-1)

            sigmas = torch.zeros(h * w, device=device, dtype=torch.float32)
            feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)

            if mask.any():
                xyzs = xyzs[mask]  # [M, 3]

                # batched inference to avoid OOM
                all_sigmas = []
                all_feats = []
                head = 0
                while head < xyzs.shape[0]:
                    tail = min(head + 640000, xyzs.shape[0])
                    results_ = self.density(xyzs[head:tail])
                    all_sigmas.append(results_['sigma'].float())
                    all_feats.append(results_['albedo'].float())
                    head += 640000

                sigmas[mask] = torch.cat(all_sigmas, dim=0)
                feats[mask] = torch.cat(all_feats, dim=0)

            sigmas = sigmas.view(h, w, 1)
            feats = feats.view(h, w, -1)
            mask = mask.view(h, w)

            # quantize [0.0, 1.0] to [0, 255]
            feats = feats.cpu().numpy()
            feats = (feats * 255).astype(np.uint8)

            ### NN search as an antialiasing ...
            mask = mask.cpu().numpy()

            inpaint_region = binary_dilation(mask, iterations=3)
            inpaint_region[mask] = 0

            search_region = mask.copy()
            not_search_region = binary_erosion(search_region, iterations=2)
            search_region[not_search_region] = 0

            search_coords = np.stack(np.nonzero(search_region), axis=-1)
            inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

            # Fitting or querying the NN index with zero points raises, so skip
            # the inpainting when there is nothing to copy from or to.
            if len(search_coords) > 0 and len(inpaint_coords) > 0:
                knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
                _, indices = knn.kneighbors(inpaint_coords)
                feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]

            # do ssaa after the NN search, in numpy
            feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)

            if ssaa > 1:
                feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)

            cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)

            # save obj (v, vt, f /)
            obj_file = os.path.join(path, f'{name}mesh.obj')
            mtl_file = os.path.join(path, f'{name}mesh.mtl')

            print(f'[INFO] writing obj mesh to {obj_file}')
            with open(obj_file, "w") as fp:
                fp.write(f'mtllib {name}mesh.mtl \n')

                print(f'[INFO] writing vertices {v_np.shape}')
                for pos in v_np:
                    fp.write(f'v {pos[0]} {pos[1]} {pos[2]} \n')

                print(f'[INFO] writing vertices texture coords {vt_np.shape}')
                for tex in vt_np:
                    # the v texture coordinate is flipped when written
                    fp.write(f'vt {tex[0]} {1 - tex[1]} \n')

                print(f'[INFO] writing faces {f_np.shape}')
                fp.write('usemtl mat0 \n')
                # OBJ indices are 1-based
                for face, face_uv in zip(f_np, ft_np):
                    fp.write(f"f {face[0] + 1}/{face_uv[0] + 1} {face[1] + 1}/{face_uv[1] + 1} {face[2] + 1}/{face_uv[2] + 1} \n")

            with open(mtl_file, "w") as fp:
                fp.write('newmtl mat0 \n')
                fp.write('Ka 1.000000 1.000000 1.000000 \n')
                fp.write('Kd 1.000000 1.000000 1.000000 \n')
                fp.write('Ks 0.000000 0.000000 0.000000 \n')
                fp.write('Tr 1.000000 \n')
                fp.write('illum 1 \n')
                fp.write('Ns 0.000000 \n')
                fp.write(f'map_Kd {name}albedo.png \n')

            # Run only after both files are closed so Blender sees complete data.
            # NOTE(review): obj2glb.py is resolved relative to the current
            # working directory, and the return code is ignored — confirm intended.
            os.system("blender -b -P obj2glb.py")

        _export(v, f)

    def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
        """Render rays with uniform sampling plus optional importance resampling.

        Args:
            rays_o, rays_d: [B, N, 3], assumes B == 1.
            num_steps: number of uniformly spaced samples per ray.
            upsample_steps: extra samples drawn with `sample_pdf` (0 disables).
            light_d: light direction; sampled randomly around `rays_o[0]` when None.
            ambient_ratio, shading: forwarded to `self(...)`.
            bg_color: [BN, 3] in range [0, 1]; ignored when `bg_radius > 0`,
                defaults to 1 when None.
            perturb: jitter the uniform samples.

        Returns:
            dict with 'image' [B, N, 3], 'depth' [B, N], 'weights_sum' (flat,
            one value per ray), 'mask' [B, N] and, when normals are returned by
            the network, 'loss_orient'.
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        results = {}

        # choose aabb
        aabb = self.aabb_train if self.training else self.aabb_infer

        # sample steps
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)

        # random sample light_d if not provided
        if light_d is None:
            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
            light_d = safe_normalize(light_d)

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)  # [1, T]
        z_vals = z_vals.expand((N, num_steps))  # [N, T]
        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

        # perturb z_vals
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist

        # generate xyzs
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        # query SDF and RGB
        density_outputs = self.density(xyzs.reshape(-1, 3))

        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():
                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]
                deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)

                alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1))  # [N, T]
                alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+1]
                weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T]

                # sample new z_vals
                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1])  # [N, T-1]
                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach()  # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:])  # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order: merge old and new samples and sort them along the ray
            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

            for k in density_outputs:
                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]
        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
        alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1))  # [N, T+t]
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+t+1]
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]

        # orientation loss
        if normals is not None:
            normals = normals.view(N, -1, 3)
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.mean()

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1)  # [N]

        # calculate depth, normalized to [0, 1] between near and far
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results

    def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        """Render rays with the CUDA ray marcher and the density bitfield.

        Args:
            rays_o, rays_d: [B, N, 3], assumes B == 1.
            light_d: light direction; sampled randomly around `rays_o[0]` when None.
            bg_color: ignored when `bg_radius > 0`, defaults to 1 when None.
            dt_gamma, perturb, force_all_rays, max_steps, T_thresh: forwarded
                to the `raymarching` kernels.

        Returns:
            dict with 'image' [B, N, 3], 'depth' [B, N], 'weights_sum' [B, N],
            'mask' [B, N] and, in training mode when normals are returned by
            the network, 'loss_orient'.
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far; pass the configured min_near exactly like
        # `run` does, so both render paths honour opt.min_near.
        aabb = self.aabb_train if self.training else self.aabb_infer
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)

        # random sample light_d if not provided
        if light_d is None:
            # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
            light_d = safe_normalize(light_d)

        results = {}

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)

            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

            weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)

            # orientation loss
            if normals is not None:
                weights = 1 - torch.exp(-sigmas)
                loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
                results['loss_orient'] = loss_orient.mean()

        else:
            # allocate outputs
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
            rays_t = nears.clone()  # [N]

            step = 0

            while step < max_steps:  # hard coded max step
                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps: take more steps per call as fewer rays stay alive
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)

                # keep only the rays whose index is still non-negative
                rays_alive = rays_alive[rays_alive >= 0]

                step += n_step

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        image = image.view(*prefix, 3)

        depth = torch.clamp(depth - nears, min=0) / (fars - nears)
        depth = depth.view(*prefix)

        weights_sum = weights_sum.reshape(*prefix)

        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results

    @torch.no_grad()
    def update_extra_state(self, decay=0.95, S=128):
        """Refresh the density grid, its bitfield and the mean step count.

        Call before each epoch. No-op without cuda_ray.

        Args:
            decay: factor applied to the previous grid values before taking
                the maximum with the freshly queried densities.
            S: chunk length used to split every grid axis.
        """
        if not self.cuda_ray:
            return

        ### update density grid
        tmp_grid = - torch.ones_like(self.density_grid)

        grid_device = self.density_bitfield.device
        X = torch.arange(self.grid_size, dtype=torch.int32, device=grid_device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=grid_device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=grid_device).split(S)

        for xs in X:
            for ys in Y:
                for zs in Z:
                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]
                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_xyzs = xyzs * (bound - half_grid_size)
                        # add noise in [-hgs, hgs]
                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                        # query density
                        sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                        # assign
                        tmp_grid[cas, indices] = sigmas

        # ema update
        valid_mask = self.density_grid >= 0
        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
        self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0

    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        """Render rays, optionally in chunks of `max_ray_batch` rays.

        Args:
            rays_o, rays_d: [B, N, 3], assumes B == 1.
            staged: process the rays chunk by chunk; never applied with cuda_ray.
            max_ray_batch: chunk size of the staged path.
            **kwargs: forwarded to `run` / `run_cuda`.

        Returns:
            The result dict of the selected run function. The staged path
            returns only 'depth', 'image', 'weights_sum' and 'mask'.
        """
        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        B, N = rays_o.shape[:2]
        device = rays_o.device

        # never stage when cuda_ray
        if staged and not self.cuda_ray:
            depth = torch.empty((B, N), device=device)
            image = torch.empty((B, N, 3), device=device)
            weights_sum = torch.empty((B, N), device=device)
            # collected as well so the staged result has the same core keys
            # as the non-staged one
            mask = torch.empty((B, N), dtype=torch.bool, device=device)

            for b in range(B):
                head = 0
                while head < N:
                    tail = min(head + max_ray_batch, N)
                    results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
                    depth[b:b+1, head:tail] = results_['depth']
                    weights_sum[b:b+1, head:tail] = results_['weights_sum']
                    image[b:b+1, head:tail] = results_['image']
                    mask[b:b+1, head:tail] = results_['mask']
                    head += max_ray_batch

            results = {}
            results['depth'] = depth
            results['image'] = image
            results['weights_sum'] = weights_sum
            results['mask'] = mask

        else:
            results = _run(rays_o, rays_d, **kwargs)

        return results