import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import tinycudann as tcnn

from activation import trunc_exp
from .renderer import NeRFRenderer
from encoding import get_encoder

class MLP(nn.Module):
    # Plain fully-connected network: `num_layers` Linear layers with ReLU
    # between them and no activation on the output layer.
    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden,
                                 self.dim_out if l == num_layers - 1 else self.dim_hidden,
                                 bias=bias))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x
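
# Minimal usage sketch for MLP (the shapes below are illustrative assumptions,
# not part of this file): a 3-layer MLP mapping 32-dim features to 4 outputs.
#
#   mlp = MLP(dim_in=32, dim_out=4, dim_hidden=64, num_layers=3)
#   y = mlp(torch.randn(1024, 32))  # -> [1024, 4]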

class NeRFNetwork(NeRFRenderer):
    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=64,
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # growth factor between consecutive hash-grid levels, chosen so that
        # 16 levels span resolutions from 16 up to 2048 * bound
        per_level_scale = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": per_level_scale,
            },
        )
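
        # Worked numbers (assuming bound == 1, an assumption for illustration):
        # per_level_scale = exp2(log2(2048 / 16) / 15) = 2^(7/15) ≈ 1.382, so
        # level l has resolution 16 * 1.382^l, reaching 2048 at l = 15.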
        # 16 levels * 2 features/level = 32 input dims; 4 outputs = 1 density + 3 albedo
        self.sigma_net = MLP(32, 4, hidden_dim, num_layers, bias=True)

        # background network
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            # frequency (positional) encoding of the 3D view direction
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            self.bg_net = None

    def gaussian(self, x):
        # x: [B, N, 3]
        # density blob centered at the origin, added to the raw density below
        # to bias geometry toward the scene center
        d = (x ** 2).sum(-1)
        g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
        return g
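
    # Worked numbers for the blob above: at the origin g = 5; at distance 0.2,
    # g = 5 * exp(-0.04 / 0.08) = 5 * exp(-0.5) ≈ 3.03, so the added density
    # falls off quickly outside a radius of roughly 0.2.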

    def common_forward(self, x):
        # x: [N, 3], in [-bound, bound]

        # sigma
        h = (x + self.bound) / (2 * self.bound)  # map to [0, 1] for the hash grid
        h = self.encoder(h)
        h = self.sigma_net(h)

        # channel 0 (plus the origin blob) -> density via trunc_exp;
        # channels 1:4 -> albedo via sigmoid
        sigma = trunc_exp(h[..., 0] + self.gaussian(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo
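
    # Shape flow through common_forward (inferred from the code above):
    # x [N, 3] -> hash encoding [N, 32] -> sigma_net [N, 4] -> sigma [N], albedo [N, 3].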

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only)

        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None
        else:
            # query normal
            has_grad = torch.is_grad_enabled()

            with torch.enable_grad():
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient (points toward increasing density; some
                # implementations negate this to get an outward-facing normal)
                normal = torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]

            # normalize...
            normal = normal / (torch.norm(normal, dim=-1, keepdim=True) + 1e-9)
            normal[torch.isnan(normal)] = 0

            if not has_grad:
                normal = normal.detach()

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0)  # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else:  # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal
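
    # Hedged usage sketch (the tensor values and the `model` name below are
    # illustrative assumptions, not part of this file):
    #
    #   x = torch.rand(4096, 3) * 2 - 1                # points, assuming bound == 1
    #   d = F.normalize(torch.randn(4096, 3), dim=-1)  # view directions
    #   l = F.normalize(torch.randn(3), dim=-1)        # light direction
    #   sigma, color, normal = model(x, d, l=l, ratio=0.1, shading='lambertian')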

    def density(self, x):
        # x: [N, 3], in [-bound, bound]

        sigma, _ = self.common_forward(x)

        return {
            'sigma': sigma
        }

    def background(self, d):
        # d: [N, 3], view direction, normalized in [-1, 1]

        h = self.encoder_bg(d)  # [N, C]
        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):

        params = [
            # the hash-grid encoders train with a 10x higher learning rate
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

        if self.bg_radius > 0:
            params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        return params
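
# Hedged usage sketch for the parameter groups above (optimizer settings are
# illustrative assumptions; `opt` must supply whatever NeRFRenderer expects):
#
#   model = NeRFNetwork(opt)
#   optimizer = torch.optim.Adam(model.get_params(lr=1e-3))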