Spaces:
Configuration error
Configuration error
import torch | |
from lib.config import cfg | |
from .nerf_net_utils import * | |
from .. import embedder | |
class Renderer:
    """Volume renderer that turns camera rays into per-pixel outputs.

    The wrapped ``net`` must expose ``training``,
    ``encode_sparse_voxels(sp_input)`` and
    ``calculate_density_color(pts, viewdir, feature_volume, sp_input)``.
    Sampling and compositing options are read from the global ``cfg``
    (``N_samples``, ``perturb``, ``raw_noise_std``, ``white_bkgd``).
    """

    def __init__(self, net):
        self.net = net

    def get_sampling_points(self, ray_o, ray_d, near, far):
        """Sample ``cfg.N_samples`` depths per ray and the matching points.

        Depths are spaced linearly between ``near`` and ``far``.  When
        ``cfg.perturb > 0`` and the network is in training mode, every depth
        is instead drawn uniformly from the bin surrounding its linear
        position (stratified sampling).

        Args:
            ray_o: ray origins, indexed as ``[batch, pixel, xyz]``.
            ray_d: ray directions, same layout as ``ray_o``.
            near: near depth per ray, indexed as ``[batch, pixel]``.
            far: far depth per ray, same layout as ``near``.

        Returns:
            ``(pts, z_vals)`` where ``z_vals`` has the sample axis appended
            to ``near``'s shape and ``pts = ray_o + ray_d * z`` has the
            sample axis inserted before the coordinate axis.
        """
        # calculate the steps for each ray
        t_vals = torch.linspace(0., 1., steps=cfg.N_samples).to(near)
        z_vals = near[..., None] * (1. - t_vals) + far[..., None] * t_vals

        if cfg.perturb > 0. and self.net.training:
            # get intervals between samples
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            # stratified samples in those intervals; rand_like draws the
            # noise directly in z_vals' dtype and on its device instead of
            # generating on the CPU and copying it over
            t_rand = torch.rand_like(z_vals)
            z_vals = lower + (upper - lower) * t_rand

        pts = ray_o[:, :, None] + ray_d[:, :, None] * z_vals[..., None]
        return pts, z_vals

    def prepare_sp_input(self, batch):
        """Build the input dict consumed by the sparse-voxel encoder.

        Reads ``coord``, ``out_sh``, ``bounds``, ``R``, ``Th`` and
        ``latent_index`` from ``batch``.

        Returns:
            A dict with:
              * ``coord``: the per-sample coordinates flattened over the
                first two axes, with the batch index prepended as the first
                column (``[N, 4]``: batch_idx, z, y, x);
              * ``out_sh``: element-wise maximum of ``batch['out_sh']`` over
                the batch axis, as a Python list;
              * ``batch_size``: size of the first axis of ``coord``;
              * ``bounds``, ``R``, ``Th``, ``latent_index``: passed through
                unchanged.
        """
        coord = batch['coord']
        n_batch, n_voxel = coord.shape[:2]

        # batch index of every voxel: 0,...,0,1,...,1,... built directly in
        # coord's dtype and on its device
        batch_idx = torch.arange(
            n_batch, dtype=coord.dtype, device=coord.device
        ).repeat_interleave(n_voxel)
        # reshape (not view) so a non-contiguous coord tensor is accepted
        flat_coord = coord.reshape(-1, coord.shape[-1])

        out_sh, _ = torch.max(batch['out_sh'], dim=0)

        return {
            # coordinate: [N, 4], batch_idx, z, y, x
            'coord': torch.cat([batch_idx[:, None], flat_coord], dim=1),
            'out_sh': out_sh.tolist(),
            'batch_size': n_batch,
            # used for feature interpolation
            'bounds': batch['bounds'],
            'R': batch['R'],
            'Th': batch['Th'],
            # used for color function
            'latent_index': batch['latent_index'],
        }

    def get_density_color(self, wpts, viewdir, raw_decoder):
        """Evaluate ``raw_decoder`` on every sample point.

        Args:
            wpts: sample points indexed as ``[batch, pixel, sample, xyz]``.
            viewdir: one viewing direction per ray, ``[batch, pixel, xyz]``.
            raw_decoder: callable ``(points, viewdirs) -> raw``; it receives
                both tensors flattened to ``[batch, pixel * sample, -1]``.

        Returns:
            Whatever ``raw_decoder`` returns.
        """
        n_batch, n_pixel, n_sample = wpts.shape[:3]
        n_point = n_pixel * n_sample

        wpts = wpts.reshape(n_batch, n_point, -1)
        # every sample along a ray shares that ray's viewing direction;
        # repeat() already yields a fresh contiguous tensor
        viewdir = viewdir[:, :, None].repeat(1, 1, n_sample, 1)
        viewdir = viewdir.reshape(n_batch, n_point, -1)

        return raw_decoder(wpts, viewdir)

    def get_pixel_value(self, ray_o, ray_d, near, far, feature_volume,
                        sp_input, batch):
        """Render one chunk of rays.

        Samples points along the rays, queries the network for the raw
        4-channel output per sample and composites it with ``raw2outputs``.
        ``batch`` is accepted for interface compatibility and is not read.

        Returns:
            A dict with ``rgb_map``, ``disp_map``, ``acc_map``, ``weights``
            and ``depth_map``, each reshaped to lead with
            ``[n_batch, n_pixel]``.
        """
        # sampling points along camera rays
        wpts, z_vals = self.get_sampling_points(ray_o, ray_d, near, far)

        # viewing direction
        viewdir = ray_d / torch.norm(ray_d, dim=2, keepdim=True)

        def raw_decoder(x_point, viewdir_val):
            return self.net.calculate_density_color(
                x_point, viewdir_val, feature_volume, sp_input)

        # compute the color and density
        wpts_raw = self.get_density_color(wpts, viewdir, raw_decoder)

        # volume rendering for wpts
        n_batch, n_pixel, n_sample = wpts.shape[:3]
        raw = wpts_raw.reshape(-1, n_sample, 4)
        z_vals = z_vals.reshape(-1, n_sample)
        # ray_d may be a non-contiguous slice taken along the pixel axis
        # (see render); .view() would raise on it when n_batch > 1
        ray_d = ray_d.reshape(-1, 3)
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
            raw, z_vals, ray_d, cfg.raw_noise_std, cfg.white_bkgd)

        return {
            'rgb_map': rgb_map.reshape(n_batch, n_pixel, -1),
            'disp_map': disp_map.reshape(n_batch, n_pixel),
            'acc_map': acc_map.reshape(n_batch, n_pixel),
            'weights': weights.reshape(n_batch, n_pixel, -1),
            'depth_map': depth_map.reshape(n_batch, n_pixel),
        }

    def render(self, batch, chunk=2048):
        """Render every ray in ``batch``, ``chunk`` rays at a time.

        Args:
            batch: dict providing ``ray_o``, ``ray_d``, ``near``, ``far``
                and the keys read by ``prepare_sp_input``.
            chunk: number of rays along the pixel axis processed per call to
                ``get_pixel_value``; the default matches the value that used
                to be hard-coded.

        Returns:
            A dict with the keys produced by ``get_pixel_value``, each
            concatenated over all chunks along the pixel axis (dim 1).

        Raises:
            IndexError: if ``batch`` contains no rays, since there is then
                no chunk result to take the output keys from.
        """
        ray_o = batch['ray_o']
        ray_d = batch['ray_d']
        near = batch['near']
        far = batch['far']

        # encode neural body
        sp_input = self.prepare_sp_input(batch)
        feature_volume = self.net.encode_sparse_voxels(sp_input)

        # volume rendering for each pixel
        n_pixel = ray_o.shape[1]
        ret_list = []
        for start in range(0, n_pixel, chunk):
            stop = start + chunk
            ret_list.append(
                self.get_pixel_value(ray_o[:, start:stop],
                                     ray_d[:, start:stop],
                                     near[:, start:stop],
                                     far[:, start:stop],
                                     feature_volume, sp_input, batch))

        keys = ret_list[0].keys()
        return {k: torch.cat([r[k] for r in ret_list], dim=1) for k in keys}