import torch import torch.nn as nn import torch.nn.functional as F from activation import trunc_exp from .renderer import NeRFRenderer from encoding import get_encoder import numpy as np import tinycudann as tcnn class MLP(nn.Module): def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): super().__init__() self.dim_in = dim_in self.dim_out = dim_out self.dim_hidden = dim_hidden self.num_layers = num_layers net = [] for l in range(num_layers): net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) self.net = nn.ModuleList(net) def forward(self, x): for l in range(self.num_layers): x = self.net[l](x) if l != self.num_layers - 1: x = F.relu(x, inplace=True) return x class NeRFNetwork(NeRFRenderer): def __init__(self, opt, num_layers=3, hidden_dim=64, num_layers_bg=2, hidden_dim_bg=64, ): super().__init__(opt) self.num_layers = num_layers self.hidden_dim = hidden_dim per_level_scale = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)) self.encoder = tcnn.Encoding( n_input_dims=3, encoding_config={ "otype": "HashGrid", "n_levels": 16, "n_features_per_level": 2, "log2_hashmap_size": 19, "base_resolution": 16, "per_level_scale": per_level_scale, }, ) self.sigma_net = MLP(32, 4, hidden_dim, num_layers, bias=True) # background network if self.bg_radius > 0: self.num_layers_bg = num_layers_bg self.hidden_dim_bg = hidden_dim_bg self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3) self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) else: self.bg_net = None def gaussian(self, x): # x: [B, N, 3] d = (x ** 2).sum(-1) g = 5 * torch.exp(-d / (2 * 0.2 ** 2)) return g def common_forward(self, x): # x: [N, 3], in [-bound, bound] # sigma h = (x + self.bound) / (2 * self.bound) # to [0, 1] h = self.encoder(h) h = self.sigma_net(h) sigma = trunc_exp(h[..., 0] + self.gaussian(x)) albedo = torch.sigmoid(h[..., 1:]) return sigma, albedo def forward(self, x, d, l=None, ratio=1, shading='albedo'): # x: [N, 3], in [-bound, bound] # d: [N, 3], view direction, nomalized in [-1, 1] # l: [3], plane light direction, nomalized in [-1, 1] # ratio: scalar, ambient ratio, 1 == no shading (albedo only) if shading == 'albedo': # no need to query normal sigma, color = self.common_forward(x) normal = None else: # query normal has_grad = torch.is_grad_enabled() with torch.enable_grad(): x.requires_grad_(True) sigma, albedo = self.common_forward(x) # query gradient normal = torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3] # normalize... normal = normal / (torch.norm(normal, dim=-1, keepdim=True) + 1e-9) normal[torch.isnan(normal)] = 0 if not has_grad: normal = normal.detach() # lambertian shading lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0) # [N,] if shading == 'textureless': color = lambertian.unsqueeze(-1).repeat(1, 3) elif shading == 'normal': color = (normal + 1) / 2 else: # 'lambertian' color = albedo * lambertian.unsqueeze(-1) return sigma, color, normal def density(self, x): # x: [N, 3], in [-bound, bound] sigma, _ = self.common_forward(x) return { 'sigma': sigma } def background(self, d): # x: [N, 2], in [-1, 1] h = self.encoder_bg(d) # [N, C] h = self.bg_net(h) # sigmoid activation for rgb rgbs = torch.sigmoid(h) return rgbs # optimizer utils def get_params(self, lr): params = [ {'params': self.encoder.parameters(), 'lr': lr * 10}, {'params': self.sigma_net.parameters(), 'lr': lr}, ] if self.bg_radius > 0: params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) params.append({'params': self.bg_net.parameters(), 'lr': lr}) return params