# NeuralBody / lib / networks / renderer / volume_mesh_renderer.py
# pengsida
# initial commit
# 1ba539f
# raw
# history blame
# 3.94 kB
import torch
from lib.config import cfg
from .nerf_net_utils import *
import numpy as np
import mcubes
import trimesh
class Renderer:
    """Evaluate a network on a volume of points and extract a surface mesh.

    The wrapped ``net`` is called with a tensor of points (plus an optional
    extra argument) and returns a per-point ``alpha`` value. ``render`` writes
    those values into a dense grid and runs marching cubes on it.
    """

    # Set to True (on the class or on an instance) to print a warning whenever
    # a network output contains NaN or Inf. Disabled by default so the check,
    # which has to inspect every output element, costs nothing in normal use.
    DEBUG = False

    def __init__(self, net):
        """Keep a reference to the network that is queried for ``alpha``."""
        self.net = net

    def _query_net(self, pts, net_c, **kwargs):
        """Call the network, passing ``net_c`` only when it is not ``None``."""
        if net_c is None:
            return self.net(pts, **kwargs)
        return self.net(pts, net_c, **kwargs)

    def render_rays(self, ray_batch, net_c=None, pytest=False):
        """Query the network for the ``alpha`` values of one chunk of points.

        Args:
            ray_batch: points, forwarded to the network unchanged. ``render``
                supplies a tensor viewed as ``[batch, n_points, 1, 3]``.
            net_c: optional extra positional argument forwarded to the network
                right after the points; left out of the call when ``None``.
            pytest: unused; kept so existing callers keep working.

        Returns:
            dict with key ``'alpha'``. When ``cfg.N_importance > 0`` the
            network is called a second time with ``model='fine'``; that result
            is stored under ``'alpha'`` and the first result under
            ``'alpha0'``.
        """
        alpha = self._query_net(ray_batch, net_c)
        ret = {'alpha': alpha}

        if cfg.N_importance > 0:
            # The first (default-model) prediction is kept next to the
            # 'fine' one.
            ret['alpha'] = self._query_net(ray_batch, net_c, model='fine')
            ret['alpha0'] = alpha

        if self.DEBUG:
            # Test the flag first: reducing a tensor to a Python bool has to
            # read every element, which is wasted work when nothing would be
            # printed anyway.
            for k, value in ret.items():
                if not torch.isfinite(value).all():
                    print(f"! [Numerical Error] {k} contains nan or inf.")

        return ret

    def batchify_rays(self, rays_flat, chunk=1024 * 32):
        """Run ``render_rays`` on slices of ``rays_flat`` and concatenate.

        Args:
            rays_flat: tensor that is sliced along its first dimension.
            chunk: maximum number of first-dimension entries per call.

        Returns:
            dict mapping each key returned by ``render_rays`` to the per-slice
            results concatenated along dimension 0. The dict is empty when
            ``rays_flat.shape[0]`` is 0.

        NOTE(review): ``render`` passes a tensor shaped
        ``[batch, n_points, 1, 3]``, so the slicing here splits the batch
        dimension rather than the points -- confirm this is intended.
        """
        collected = {}
        for start in range(0, rays_flat.shape[0], chunk):
            chunk_ret = self.render_rays(rays_flat[start:start + chunk])
            for key, value in chunk_ret.items():
                collected.setdefault(key, []).append(value)
        return {key: torch.cat(parts, 0) for key, parts in collected.items()}

    def render(self, batch):
        """Build a dense alpha grid from the network and mesh it.

        Args:
            batch: dict with
                ``'pts'``: tensor whose last dimension is 3; the dimensions
                between the first and the last define the grid shape.
                ``'inside'``: mask whose first element selects the grid cells
                that are sent to the network.
                Only the first batch element of both entries is used.

        Returns:
            dict with ``'cube'`` (the zero-padded alpha grid as a NumPy array)
            and ``'mesh'`` (a ``trimesh.Trimesh`` built from the marching
            cubes output at level ``cfg.mesh_th``).
        """
        pts = batch['pts']
        sh = pts.shape
        inside = batch['inside'][0].bool()

        # Only the selected points are evaluated; every other cell stays 0.
        selected = pts[0][inside][None]
        selected = selected.view(sh[0], -1, 1, 3)

        ret = self.batchify_rays(selected, cfg.chunk)
        alpha = ret['alpha'][0, :, 0, 0].detach().cpu().numpy()

        # Scatter the predictions back to their positions in the dense grid.
        mask = inside.detach().cpu().numpy()
        cube = np.zeros(sh[1:-1])
        cube[mask] = alpha

        # Surround the grid with 10 zero-valued cells on every side before
        # meshing (presumably so surfaces reaching the border get closed --
        # TODO confirm). The mesh is extracted from the padded grid.
        cube = np.pad(cube, 10, mode='constant')
        vertices, triangles = mcubes.marching_cubes(cube, cfg.mesh_th)
        mesh = trimesh.Trimesh(vertices, triangles)

        return {'cube': cube, 'mesh': mesh}