import torch

from lib.config import cfg
from .nerf_net_utils import *
from .. import embedder


class Renderer:
    def __init__(self, net):
        self.net = net

    def get_sampling_points(self, ray_o, ray_d, near, far):
        # calculate the steps for each ray
        t_vals = torch.linspace(0., 1., steps=cfg.N_samples).to(near)
        z_vals = near[..., None] * (1. - t_vals) + far[..., None] * t_vals

        if cfg.perturb > 0. and self.net.training:
            # get intervals between samples
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            # stratified samples in those intervals
            t_rand = torch.rand(z_vals.shape).to(upper)
            z_vals = lower + (upper - lower) * t_rand

        pts = ray_o[:, :, None] + ray_d[:, :, None] * z_vals[..., None]

        return pts, z_vals

    def pts_to_can_pts(self, pts, batch):
        """transform pts from the world coordinate to the smpl coordinate"""
        Th = batch['Th'][:, None]
        pts = pts - Th
        R = batch['R']
        sh = pts.shape
        pts = torch.matmul(pts.view(sh[0], -1, sh[3]), R)
        pts = pts.view(*sh)
        return pts

    def transform_sampling_points(self, pts, batch):
        # training-time augmentation: rotate the sampled points around the
        # up axis (only the x and z coordinates change) and translate them
        if not self.net.training:
            return pts
        center = batch['center'][:, None, None]
        pts = pts - center
        rot = batch['rot']
        pts_ = pts[..., [0, 2]].clone()
        sh = pts_.shape
        pts_ = torch.matmul(pts_.view(sh[0], -1, sh[3]), rot.permute(0, 2, 1))
        pts[..., [0, 2]] = pts_.view(*sh)
        pts = pts + center
        trans = batch['trans'][:, None, None]
        pts = pts + trans
        return pts

    def prepare_sp_input(self, batch):
        # feature, coordinate, shape, batch size
        sp_input = {}

        # coordinate: [N, 4], batch_idx, x, y, z
        sh = batch['tcoord'].shape
        idx = [torch.full([sh[1]], i) for i in range(sh[0])]
        idx = torch.cat(idx).to(batch['tcoord'])
        coord = batch['tcoord'].view(-1, sh[-1])
        sp_input['coord'] = torch.cat([idx[:, None], coord], dim=1)

        out_sh, _ = torch.max(batch['tout_sh'], dim=0)
        sp_input['out_sh'] = out_sh.tolist()
        sp_input['batch_size'] = sh[0]
        sp_input['i'] = batch['i']

        return sp_input

    def get_ptot_grid_coords(self, pts, out_sh, bounds):
        # pts: [batch_size, x, y, z, 3], last dim ordered x, y, z
        min_xyz = bounds[:, 0]
        pts = pts - min_xyz[:, None, None, None]
        pts = pts / torch.tensor(cfg.voxel_size).to(pts)

        # convert the voxel coordinate to [-1, 1]
        out_sh = torch.tensor(out_sh).to(pts)
        pts = pts / out_sh * 2 - 1

        # convert xyz to zyx, since the occupancy volume is indexed by x, y, z
        grid_coords = pts[..., [2, 1, 0]]

        return grid_coords

    def get_grid_coords(self, pts, ptot_pts, bounds):
        out_sh = torch.tensor(ptot_pts.shape[1:-1]).to(pts)

        # pts: [batch_size, N, 3], last dim ordered x, y, z
        min_xyz = bounds[:, 0]
        pts = pts - min_xyz[:, None]
        pts = pts / torch.tensor(cfg.ptot_vsize).to(pts)

        # convert the voxel coordinate to [-1, 1]
        pts = pts / out_sh * 2 - 1

        # convert xyz to zyx, since the occupancy volume is indexed by x, y, z
        grid_coords = pts[..., [2, 1, 0]]

        return grid_coords

    def batchify_rays(self, sp_input, tgrid_coords, pgrid_coords, viewdir,
                      light_pts, chunk=1024 * 32, net_c=None):
        """Render rays in smaller minibatches to avoid OOM."""
""" all_ret = [] for i in range(0, tgrid_coords.shape[1], chunk): # ret = self.render_rays(rays_flat[i:i + chunk], net_c) ret = self.net(sp_input, tgrid_coords[:, i:i + chunk], pgrid_coords[:, i:i + chunk], viewdir[:, i:i + chunk], light_pts[:, i:i + chunk]) # for k in ret: # if k not in all_ret: # all_ret[k] = [] # all_ret[k].append(ret[k]) all_ret.append(ret) # all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} all_ret = torch.cat(all_ret, 1) return all_ret def render(self, batch): ray_o = batch['ray_o'] ray_d = batch['ray_d'] near = batch['near'] far = batch['far'] sh = ray_o.shape pts, z_vals = self.get_sampling_points(ray_o, ray_d, near, far) # light intensity varies with 3D location light_pts = embedder.xyz_embedder(pts) ppts = self.pts_to_can_pts(pts, batch) ray_d0 = batch['ray_d'] viewdir = ray_d0 / torch.norm(ray_d0, dim=2, keepdim=True) viewdir = embedder.view_embedder(viewdir) viewdir = viewdir[:, :, None].repeat(1, 1, pts.size(2), 1).contiguous() sp_input = self.prepare_sp_input(batch) # reshape to [batch_size, n, 3] light_pts = light_pts.view(sh[0], -1, embedder.xyz_dim) viewdir = viewdir.view(sh[0], -1, embedder.view_dim) ppts = ppts.view(sh[0], -1, 3) # create grid coords for sampling feature volume at t pose ptot_pts = batch['ptot_pts'] tgrid_coords = self.get_ptot_grid_coords(ptot_pts, sp_input['out_sh'], batch['tbounds']) # create grid coords for sampling feature volume at i-th frame pgrid_coords = self.get_grid_coords(ppts, ptot_pts, batch['pbounds']) if ray_o.size(1) <= 2048: raw = self.net(sp_input, tgrid_coords, pgrid_coords, viewdir, light_pts) else: raw = self.batchify_rays(sp_input, tgrid_coords, pgrid_coords, viewdir, light_pts, 1024 * 32, None) # reshape to [num_rays, num_samples along ray, 4] raw = raw.reshape(-1, z_vals.size(2), 4) z_vals = z_vals.view(-1, z_vals.size(2)) ray_d = ray_d.view(-1, 3) rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs( raw, z_vals, ray_d, cfg.raw_noise_std, cfg.white_bkgd) rgb_map = rgb_map.view(*sh[:-1], -1) acc_map = acc_map.view(*sh[:-1]) depth_map = depth_map.view(*sh[:-1]) ret = {'rgb_map': rgb_map, 'acc_map': acc_map, 'depth_map': depth_map} return ret