import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models.base import BaseModel
from models.utils import chunk_batch
from systems.utils import update_module_step
from nerfacc import (
    ContractionType,
    OccupancyGrid,
    ray_marching,
    render_weight_from_density,
    render_weight_from_alpha,
    accumulate_along_rays,
)
from nerfacc.intersection import ray_aabb_intersect
import pdb


class VarianceNetwork(nn.Module):
    """Single learnable scalar that sets the sharpness ``inv_s`` of the SDF-to-alpha map.

    ``inv_s = exp(10 * variance)``. When ``config.modulate`` is enabled, once the
    global step exceeds ``mod_start_steps`` the value is additionally clamped by a
    ceiling that grows linearly with the global step and saturates at
    ``max_inv_s``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.init_val = self.config.init_val
        self.register_parameter(
            "variance", nn.Parameter(torch.tensor(self.config.init_val))
        )
        self.modulate = self.config.get("modulate", False)
        # Initialised here so that `inv_s` can be read before the first
        # update_step() call (previously an AttributeError when modulating).
        self.do_mod = False
        self.prev_inv_s = None
        if self.modulate:
            self.mod_start_steps = self.config.mod_start_steps
            self.reach_max_steps = self.config.reach_max_steps
            self.max_inv_s = self.config.max_inv_s

    @property
    def inv_s(self):
        """Current inverse standard deviation, optionally clamped by the modulation ceiling."""
        val = torch.exp(self.variance * 10.0)
        if self.modulate and self.do_mod:
            val = val.clamp_max(self.mod_val)
        return val

    def forward(self, x):
        """Return ``inv_s`` broadcast to shape ``(len(x), 1)``; only ``len(x)`` is used."""
        return torch.ones([len(x), 1], device=self.variance.device) * self.inv_s

    def update_step(self, epoch, global_step):
        """Advance the modulation schedule (no-op unless ``modulate`` is enabled)."""
        if not self.modulate:
            return
        self.do_mod = global_step > self.mod_start_steps
        if not self.do_mod:
            # Remember the free-running value; it is the starting point of the ramp.
            self.prev_inv_s = self.inv_s.item()
            return
        if self.prev_inv_s is None:
            # First call is already past mod_start_steps (e.g. a resumed run):
            # start the ramp from the current unclamped value instead of crashing.
            self.prev_inv_s = torch.exp(self.variance * 10.0).item()
        self.mod_val = min(
            (global_step / self.reach_max_steps) * (self.max_inv_s - self.prev_inv_s)
            + self.prev_inv_s,
            self.max_inv_s,
        )


@models.register("neus")
class NeuSModel(BaseModel):
    """NeuS-style SDF renderer built on nerfacc ray marching.

    Foreground samples are marched inside ``scene_aabb`` (optionally pruned by an
    occupancy grid), converted from SDF to alpha, and composited. An optional
    learned background model is marched outside the box; otherwise a constant
    ``background_color`` is used.
    """

    def setup(self):
        self.geometry = models.make(self.config.geometry.name, self.config.geometry)
        self.texture = models.make(self.config.texture.name, self.config.texture)
        self.geometry.contraction_type = ContractionType.AABB

        if self.config.learned_background:
            self.geometry_bg = models.make(
                self.config.geometry_bg.name, self.config.geometry_bg
            )
            self.texture_bg = models.make(
                self.config.texture_bg.name, self.config.texture_bg
            )
            self.geometry_bg.contraction_type = ContractionType.UN_BOUNDED_SPHERE
            self.near_plane_bg, self.far_plane_bg = 0.1, 1e3
            # Cone angle chosen so num_samples_per_ray_bg exponentially growing
            # steps span [1, far_plane_bg].
            self.cone_angle_bg = (
                10 ** (math.log10(self.far_plane_bg) / self.config.num_samples_per_ray_bg)
                - 1.0
            )
            self.render_step_size_bg = 0.01

        self.variance = VarianceNetwork(self.config.variance)
        r = self.config.radius
        self.register_buffer(
            "scene_aabb",
            torch.as_tensor([-r, -r, -r, r, r, r], dtype=torch.float32),
        )
        if self.config.grid_prune:
            self.occupancy_grid = OccupancyGrid(
                roi_aabb=self.scene_aabb,
                resolution=128,
                contraction_type=ContractionType.AABB,
            )
            if self.config.learned_background:
                self.occupancy_grid_bg = OccupancyGrid(
                    roi_aabb=self.scene_aabb,
                    resolution=256,
                    contraction_type=ContractionType.UN_BOUNDED_SPHERE,
                )
        self.randomized = self.config.randomized
        self.background_color = None
        # 1.732 ~= sqrt(3): the cube diagonal divided into num_samples_per_ray steps.
        self.render_step_size = (
            1.732 * 2 * self.config.radius / self.config.num_samples_per_ray
        )
        # Same value update_step() would compute at global_step == 0, so
        # get_alpha() works even if forward runs before the first update_step().
        cos_anneal_end = self.config.get("cos_anneal_end", 0)
        self.cos_anneal_ratio = 1.0 if cos_anneal_end == 0 else 0.0

    def update_step(self, epoch, global_step):
        """Propagate the step to submodules, anneal cos ratio, and refresh occupancy grids."""
        update_module_step(self.geometry, epoch, global_step)
        update_module_step(self.texture, epoch, global_step)
        if self.config.learned_background:
            update_module_step(self.geometry_bg, epoch, global_step)
            update_module_step(self.texture_bg, epoch, global_step)
        update_module_step(self.variance, epoch, global_step)

        cos_anneal_end = self.config.get("cos_anneal_end", 0)
        self.cos_anneal_ratio = (
            1.0 if cos_anneal_end == 0 else min(1.0, global_step / cos_anneal_end)
        )

        def occ_eval_fn(x):
            # Alpha of a step of length render_step_size centred at x, assuming
            # the ray travels straight down the SDF gradient.
            sdf = self.geometry(x, with_grad=False, with_feature=False)
            inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
            inv_s = inv_s.expand(sdf.shape[0], 1)
            estimated_next_sdf = sdf[..., None] - self.render_step_size * 0.5
            estimated_prev_sdf = sdf[..., None] + self.render_step_size * 0.5
            prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
            next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
            p = prev_cdf - next_cdf
            c = prev_cdf
            alpha = ((p + 1e-5) / (c + 1e-5)).view(-1, 1).clip(0.0, 1.0)
            return alpha

        def occ_eval_fn_bg(x):
            density, _ = self.geometry_bg(x)
            # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size_bg) based on taylor series
            return density[..., None] * self.render_step_size_bg

        # Occupancy grids are only refreshed while training.
        if self.training and self.config.grid_prune:
            self.occupancy_grid.every_n_step(
                step=global_step,
                occ_eval_fn=occ_eval_fn,
                occ_thre=self.config.get("grid_prune_occ_thre", 0.01),
            )
            if self.config.learned_background:
                self.occupancy_grid_bg.every_n_step(
                    step=global_step,
                    occ_eval_fn=occ_eval_fn_bg,
                    occ_thre=self.config.get("grid_prune_occ_thre_bg", 0.01),
                )

    def isosurface(self):
        """Extract a mesh from the geometry network."""
        mesh = self.geometry.isosurface()
        return mesh

    def get_alpha(self, sdf, normal, dirs, dists):
        """Convert per-sample SDF values to opacity (alpha), NeuS style.

        Args:
            sdf: SDF value at each sample midpoint.
            normal: unit surface normal at each sample.
            dirs: ray direction of each sample.
            dists: length of each sample interval.

        Returns:
            Flattened alpha values clipped to [0, 1].
        """
        inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(
            1e-6, 1e6
        )  # Single parameter
        inv_s = inv_s.expand(sdf.shape[0], 1)

        true_cos = (dirs * normal).sum(-1, keepdim=True)

        # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
        # the cos value "not dead" at the beginning training iterations, for better convergence.
        iter_cos = -(
            F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
            + F.relu(-true_cos) * self.cos_anneal_ratio
        )  # always non-positive

        # Estimate signed distances at section points
        estimated_next_sdf = sdf[..., None] + iter_cos * dists.reshape(-1, 1) * 0.5
        estimated_prev_sdf = sdf[..., None] - iter_cos * dists.reshape(-1, 1) * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).view(-1).clip(0.0, 1.0)
        return alpha

    def forward_bg_(self, rays):
        """Render the learned background for a batch of rays (origin xyz, direction xyz)."""
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)

        def sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
            density, _ = self.geometry_bg(positions)
            return density[..., None]

        _, t_max = ray_aabb_intersect(rays_o, rays_d, self.scene_aabb)
        # if the ray intersects with the bounding box, start from the farther intersection point
        # otherwise start from self.near_plane_bg
        # note that in nerfacc t_max is set to 1e10 if there is no intersection
        near_plane = torch.where(t_max > 1e9, self.near_plane_bg, t_max)
        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o,
                rays_d,
                scene_aabb=None,
                grid=self.occupancy_grid_bg if self.config.grid_prune else None,
                sigma_fn=sigma_fn,
                near_plane=near_plane,
                far_plane=self.far_plane_bg,
                render_step_size=self.render_step_size_bg,
                stratified=self.randomized,
                cone_angle=self.cone_angle_bg,
                alpha_thre=0.0,
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.0
        positions = t_origins + t_dirs * midpoints
        intervals = t_ends - t_starts

        density, feature = self.geometry_bg(positions)
        rgb = self.texture_bg(feature, t_dirs)

        weights = render_weight_from_density(
            t_starts, t_ends, density[..., None], ray_indices=ray_indices, n_rays=n_rays
        )
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(
            weights, ray_indices, values=midpoints, n_rays=n_rays
        )
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)

        out = {
            "comp_rgb": comp_rgb,
            "opacity": opacity,
            "depth": depth,
            "rays_valid": opacity > 0,
            "num_samples": torch.as_tensor(
                [len(t_starts)], dtype=torch.int32, device=rays.device
            ),
        }

        # Per-sample tensors are only needed by training-time losses.
        if self.training:
            out.update(
                {
                    "weights": weights.view(-1),
                    "points": midpoints.view(-1),
                    "intervals": intervals.view(-1),
                    "ray_indices": ray_indices.view(-1),
                }
            )

        return out

    def _sample_random_sdf(self, dtype, device):
        """Evaluate the SDF and its gradient at random points and at jittered copies.

        Returns ``(random_sdf, random_sdf_grad, normal_perturb)``, where
        ``normal_perturb`` is the SDF gradient at the same random points plus
        Gaussian noise (std 1e-2), so it pairs element-wise with
        ``random_sdf_grad``. Points are uniform in (-1, 1)^3.
        NOTE(review): this range ignores ``config.radius`` — confirm intended.
        """
        pts_random = torch.rand([1024 * 2, 3], dtype=dtype, device=device) * 2 - 1
        pts_perturb = pts_random + torch.randn_like(pts_random) * 1e-2
        if self.config.geometry.grad_type == "finite_difference":
            random_sdf, random_sdf_grad, _ = self.geometry(
                pts_random, with_grad=True, with_feature=False, with_laplace=True
            )
            _, normal_perturb, _ = self.geometry(
                pts_perturb, with_grad=True, with_feature=False, with_laplace=True
            )
        else:
            random_sdf, random_sdf_grad = self.geometry(
                pts_random, with_grad=True, with_feature=False
            )
            _, normal_perturb = self.geometry(
                pts_perturb, with_grad=True, with_feature=False
            )
        return random_sdf, random_sdf_grad, normal_perturb

    def forward_(self, rays):
        """Render foreground (+ background) for a batch of rays (origin xyz, direction xyz)."""
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o,
                rays_d,
                scene_aabb=self.scene_aabb,
                grid=self.occupancy_grid if self.config.grid_prune else None,
                alpha_fn=None,
                near_plane=None,
                far_plane=None,
                render_step_size=self.render_step_size,
                stratified=self.randomized,
                cone_angle=0.0,
                alpha_thre=0.0,
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.0
        positions = t_origins + t_dirs * midpoints
        dists = t_ends - t_starts

        use_finite_difference = self.config.geometry.grad_type == "finite_difference"
        if use_finite_difference:
            sdf, sdf_grad, feature, sdf_laplace = self.geometry(
                positions, with_grad=True, with_feature=True, with_laplace=True
            )
        else:
            sdf, sdf_grad, feature = self.geometry(
                positions, with_grad=True, with_feature=True
            )
        normal = F.normalize(sdf_grad, p=2, dim=-1)
        alpha = self.get_alpha(sdf, normal, t_dirs, dists)[..., None]
        rgb = self.texture(feature, t_dirs, normal)

        weights = render_weight_from_alpha(alpha, ray_indices=ray_indices, n_rays=n_rays)
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(
            weights, ray_indices, values=midpoints, n_rays=n_rays
        )
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        comp_normal = accumulate_along_rays(
            weights, ray_indices, values=normal, n_rays=n_rays
        )
        comp_normal = F.normalize(comp_normal, p=2, dim=-1)

        out = {
            "comp_rgb": comp_rgb,
            "comp_normal": comp_normal,
            "opacity": opacity,
            "depth": depth,
            "rays_valid": opacity > 0,
            "num_samples": torch.as_tensor(
                [len(t_starts)], dtype=torch.int32, device=rays.device
            ),
        }

        # Regularizer inputs are only consumed during training, so the extra
        # random-point geometry evaluations are skipped in eval.
        if self.training:
            random_sdf, random_sdf_grad, normal_perturb = self._sample_random_sdf(
                sdf.dtype, sdf.device
            )
            out.update(
                {
                    "sdf_samples": sdf,
                    "sdf_grad_samples": sdf_grad,
                    "random_sdf": random_sdf,
                    "random_sdf_grad": random_sdf_grad,
                    "normal_perturb": normal_perturb,
                    "weights": weights.view(-1),
                    "points": midpoints.view(-1),
                    "intervals": dists.view(-1),
                    "ray_indices": ray_indices.view(-1),
                }
            )
            if use_finite_difference:
                out.update({"sdf_laplace_samples": sdf_laplace})

        if self.config.learned_background:
            out_bg = self.forward_bg_(rays)
        else:
            out_bg = {
                "comp_rgb": self.background_color[None, :].expand(*comp_rgb.shape),
                "num_samples": torch.zeros_like(out["num_samples"]),
                "rays_valid": torch.zeros_like(out["rays_valid"]),
            }

        out_full = {
            "comp_rgb": out["comp_rgb"] + out_bg["comp_rgb"] * (1.0 - out["opacity"]),
            "num_samples": out["num_samples"] + out_bg["num_samples"],
            "rays_valid": out["rays_valid"] | out_bg["rays_valid"],
        }

        return {
            **out,
            **{k + "_bg": v for k, v in out_bg.items()},
            **{k + "_full": v for k, v in out_full.items()},
        }

    def forward(self, rays):
        """Render rays; in eval mode rays are processed in chunks of ``config.ray_chunk``."""
        if self.training:
            out = self.forward_(rays)
        else:
            out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
        return {**out, "inv_s": self.variance.inv_s}

    def train(self, mode=True):
        self.randomized = mode and self.config.randomized
        return super().train(mode=mode)

    def eval(self):
        self.randomized = False
        return super().eval()

    def regularizations(self, out):
        """Collect regularization losses from the geometry and texture networks."""
        losses = {}
        losses.update(self.geometry.regularizations(out))
        losses.update(self.texture.regularizations(out))
        return losses

    @torch.no_grad()
    def export(self, export_config):
        """Extract a mesh; optionally attach per-vertex colors evaluated along the normal."""
        mesh = self.isosurface()
        if export_config.export_vertex_color:
            _, sdf_grad, feature = chunk_batch(
                self.geometry,
                export_config.chunk_size,
                False,
                mesh["v_pos"].to(self.rank),
                with_grad=True,
                with_feature=True,
            )
            normal = F.normalize(sdf_grad, p=2, dim=-1)
            rgb = self.texture(
                feature, -normal, normal
            )  # set the viewing directions to the normal to get "albedo"
            mesh["v_rgb"] = rgb.cpu()
        return mesh