|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from activation import trunc_exp |
|
from .renderer import NeRFRenderer |
|
|
|
import numpy as np |
|
from encoding import get_encoder |
|
|
|
from .utils import safe_normalize |
|
|
|
class MLP(nn.Module):
    """Fully-connected stack of ``num_layers`` Linear layers with ReLU in between.

    Layer ``i`` maps ``dim_in`` (first layer) or ``dim_hidden`` (later layers)
    to ``dim_out`` (last layer) or ``dim_hidden`` (earlier layers). With a
    single layer this is just ``Linear(dim_in, dim_out)``; with zero layers the
    forward pass is the identity. No activation follows the final layer.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: width of every intermediate layer.
        num_layers: number of Linear layers.
        bias: whether each Linear layer has a bias term.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        last = num_layers - 1
        layers = [
            nn.Linear(
                dim_in if idx == 0 else dim_hidden,
                dim_out if idx == last else dim_hidden,
                bias=bias,
            )
            for idx in range(num_layers)
        ]
        self.net = nn.ModuleList(layers)

    def forward(self, x):
        """Apply the layers in order, with an in-place ReLU after all but the last."""
        last = self.num_layers - 1
        for idx, layer in enumerate(self.net):
            x = layer(x)
            if idx < last:
                # In-place is safe here: Linear's backward does not need its output.
                x = F.relu(x, inplace=True)
        return x
|
|
|
|
|
class NeRFNetwork(NeRFRenderer):
    """Grid-encoded radiance field with an optional background network.

    Foreground: a ``'tiledgrid'`` encoder over 3-D points feeds ``sigma_net``,
    an MLP with 4 outputs. Channel 0 is the pre-activation density and
    channels 1-3 are the pre-sigmoid albedo.

    Background (only when ``self.bg_radius > 0``): a ``'frequency'`` encoder
    with ``input_dim=2`` feeds ``bg_net``, an MLP with 3 outputs passed
    through a sigmoid.

    NOTE(review): ``self.bound`` and ``self.bg_radius`` are read but not set
    here. Presumably ``NeRFRenderer.__init__`` sets them from ``opt``; confirm.
    """

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=64,
                 ):
        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # The finest grid resolution scales linearly with the scene bound.
        self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, desired_resolution=2048 * self.bound)
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)

        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=2)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            # encoder_bg is intentionally not created; background() must not be called.
            self.bg_net = None

    def gaussian(self, x, amplitude=5, std=0.2):
        """Isotropic Gaussian bump centred at the origin.

        Computes ``amplitude * exp(-|x|^2 / (2 * std^2))`` over the last axis
        of ``x``. ``common_forward`` adds it to the raw density logit, which
        biases the density toward the origin.

        Args:
            x: points; the squared norm is taken over the last axis.
            amplitude: peak value of the bump (default 5).
            std: standard deviation of the bump (default 0.2).
        """
        d = (x ** 2).sum(-1)
        return amplitude * torch.exp(-d / (2 * std ** 2))

    def common_forward(self, x):
        """Evaluate density and albedo at points ``x``.

        Args:
            x: points with a trailing coordinate axis of size 3. They are
                encoded with ``bound=self.bound`` and are presumably expected
                to lie in ``[-bound, bound]``; confirm against the encoder.

        Returns:
            ``(sigma, albedo)``. ``sigma`` is ``trunc_exp`` of
            (channel 0 + Gaussian bump). ``albedo`` is the sigmoid of
            channels 1-3, so it lies in (0, 1) with a trailing axis of 3.
        """
        h = self.encoder(x, bound=self.bound)
        h = self.sigma_net(h)

        sigma = trunc_exp(h[..., 0] + self.gaussian(x))
        albedo = torch.sigmoid(h[..., 1:])
        return sigma, albedo

    def finite_difference_normal(self, x, epsilon=5e-4):
        """Estimate the density gradient at ``x`` with central differences.

        Each coordinate is perturbed by ``+epsilon`` and ``-epsilon``. The
        perturbed points are clamped to ``[-bound, bound]`` before querying
        the field.

        The result approximates ``d sigma / d x``, so it points toward
        increasing density. It is not normalized here.

        Args:
            x: points with a trailing coordinate axis of size 3.
            epsilon: finite-difference step size.

        Returns:
            Tensor shaped like the density output with a trailing axis of 3
            (x, y, z components).
        """
        # Rows are epsilon * e_x, epsilon * e_y, epsilon * e_z. The default
        # dtype is used, matching the previous per-axis tensor literals.
        offsets = torch.eye(3, device=x.device) * epsilon

        components = []
        for offset in offsets:
            sigma_pos, _ = self.common_forward((x + offset).clamp(-self.bound, self.bound))
            sigma_neg, _ = self.common_forward((x - offset).clamp(-self.bound, self.bound))
            components.append(0.5 * (sigma_pos - sigma_neg) / epsilon)

        return torch.stack(components, dim=-1)

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        """Query density and colour at ``x`` under the requested shading mode.

        Args:
            x: sample points with a trailing axis of size 3.
            d: unused by this network; kept for the renderer's call signature.
            l: 3-vector used in the diffuse term ``max(0, normal @ -l)``.
                It is presumably the light direction. Required unless
                ``shading == 'albedo'``.
            ratio: blend of the shading term. ``lambertian = ratio +
                (1 - ratio) * diffuse``, so ``ratio=1`` leaves albedo
                unshaded and ``ratio=0`` uses the pure diffuse term.
            shading: selects the output colour:
                ``'albedo'`` returns the raw albedo and skips normals;
                ``'textureless'`` returns a grey image from the shading term;
                ``'normal'`` returns normals mapped from [-1, 1] to [0, 1];
                any other value returns albedo times the shading term.

        Returns:
            ``(sigma, color, normal)``. ``normal`` is ``None`` in ``'albedo'``
            mode; otherwise it is the unit-length finite-difference gradient
            with NaNs replaced by 0.

        Raises:
            ValueError: if a shaded mode is requested without ``l``.
        """
        if shading == 'albedo':
            # Unshaded: avoid the six extra field queries needed for normals.
            sigma, color = self.common_forward(x)
            return sigma, color, None

        if l is None:
            # Fail before the expensive normal computation, with a clear message.
            raise ValueError(f"`l` is required when shading={shading!r}")

        sigma, albedo = self.common_forward(x)
        normal = self.finite_difference_normal(x)

        normal = safe_normalize(normal)
        # Zero out any NaNs that normalization may have produced.
        normal[torch.isnan(normal)] = 0

        lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0)

        if shading == 'textureless':
            # Replicate the shading term into 3 channels. repeat_interleave on
            # the last axis works for any number of leading dims, unlike the
            # 2-D-only repeat(1, 3).
            color = lambertian.unsqueeze(-1).repeat_interleave(3, dim=-1)
        elif shading == 'normal':
            color = (normal + 1) / 2
        else:
            color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Return ``{'sigma': ..., 'albedo': ...}`` at points ``x``, without shading."""
        sigma, albedo = self.common_forward(x)
        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, x, d):
        """Predict background RGB in (0, 1) from ``x``.

        ``x`` goes through ``encoder_bg``, which was built with
        ``input_dim=2``. ``d`` is unused.

        Only valid when ``bg_radius > 0``; otherwise ``encoder_bg`` does not
        exist and ``bg_net`` is ``None``.
        """
        h = self.encoder_bg(x)
        h = self.bg_net(h)
        rgbs = torch.sigmoid(h)
        return rgbs

    def get_params(self, lr):
        """Build optimizer parameter groups.

        Encoders use ``10 * lr`` and MLPs use ``lr``. The background groups
        are included only when ``bg_radius > 0``.
        """
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

        if self.bg_radius > 0:
            params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        return params