|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from models.embedder import get_embedder |
|
|
|
|
|
class SDFNetwork(nn.Module):
    """MLP mapping 3D points to a signed distance plus a geometry feature vector.

    The output has shape [N, d_out]: column 0 is the SDF value (rescaled back
    by ``scale``), and the remaining ``d_out - 1`` columns are a feature vector
    (presumably consumed by a rendering network — confirm with callers).
    """

    def __init__(self,
                 d_in,
                 d_out,
                 d_hidden,
                 n_layers,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 scale=1,
                 geometric_init=True,
                 weight_norm=True,
                 activation='softplus',
                 conditional_type='multiply'):
        """Build the SDF MLP.

        Args:
            d_in: raw input dimensionality (xyz coordinates).
            d_out: output width (1 SDF channel + d_out-1 feature channels).
            d_hidden: hidden-layer width.
            n_layers: number of hidden layers.
            skip_in: layer indices whose input is re-concatenated with the
                (embedded) network input.
            multires: positional-encoding frequency count; 0 disables encoding.
            bias: offset used by the geometric init of the last layer, so the
                initial SDF approximates a sphere of roughly this radius.
            scale: inputs are multiplied by ``scale`` before the network and
                the SDF output is divided by it afterwards.
            geometric_init: enable the sphere-like geometric initialization.
            weight_norm: wrap each linear layer in weight normalization.
            activation: 'softplus' (beta=100) or 'relu'.
            conditional_type: NOTE(review) accepted but never used in this
                class — possibly consumed elsewhere; confirm before removing.
        """
        super(SDFNetwork, self).__init__()

        # Per-layer widths: input, n_layers hidden layers, output.
        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            # Positional encoding replaces the raw input channels.
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            # A layer feeding a skip connection is narrowed so that after the
            # concatenation with the input the width equals dims[l + 1].
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    # Last layer: positive-mean weights and bias -bias so the
                    # initial SDF approximates |x| - bias (a sphere).
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                elif multires > 0 and l == 0:
                    # First layer under positional encoding: only the raw xyz
                    # channels (first 3 columns) receive non-zero weights.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    # Skip layer: zero the weights seeing the encoded part of
                    # the re-concatenated input (everything past the raw xyz).
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    # Plain hidden layer: He-style normal init.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            # Layers are stored as attributes lin0, lin1, ... and fetched by
            # name in forward().
            setattr(self, "lin" + str(l), lin)

        if activation == 'softplus':
            # High beta makes Softplus a smooth, close approximation of ReLU.
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

    def forward(self, inputs):
        """Run the MLP on [N, d_in] points; returns [N, d_out] with SDF in column 0."""
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                # Skip connection: re-concatenate the (embedded) input; the
                # 1/sqrt(2) factor keeps the activation magnitude roughly stable.
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)
        # Undo the input scaling on the SDF channel only; features pass through.
        return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)

    def sdf(self, x):
        """Signed distance only: column 0 of forward(), shape [N, 1]."""
        return self.forward(x)[:, :1]

    def sdf_hidden_appearance(self, x):
        """Full network output (SDF channel + feature vector)."""
        return self.forward(x)

    def gradient(self, x):
        """Gradient of the SDF w.r.t. ``x`` via autograd.

        Enables gradient tracking on ``x`` in place. Returns the gradient with
        an extra axis inserted at dim 1 (shape [N, 1, d_in] for [N, d_in]
        input — confirm with callers).
        """
        x.requires_grad_(True)
        y = self.sdf(x)
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        # create_graph=True keeps the graph so Eikonal-style losses on the
        # gradient can themselves be differentiated.
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)
|
|
|
|
|
class VarianceNetwork(nn.Module):
    """Small MLP producing positive per-point variance values.

    The raw network output x is mapped through 1 / (softplus(x + 0.5) + 1e-3),
    so every returned entry is strictly positive. ``coarse`` returns the first
    output channel, ``fine`` the remaining ones.
    """

    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0):
        super(VarianceNetwork, self).__init__()

        # Per-layer widths: input, hidden stack, output.
        dims = [d_in] + [d_hidden] * n_layers + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            # Positional encoding replaces the raw input channels.
            embed_fn, input_ch = get_embedder(multires, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in

        for idx in range(self.num_layers - 1):
            # A layer feeding a skip connection is narrowed so that the later
            # concatenation with the input restores the nominal width.
            width_out = dims[idx + 1] - (dims[0] if idx + 1 in self.skip_in else 0)
            setattr(self, "lin" + str(idx), nn.Linear(dims[idx], width_out))

        self.relu = nn.ReLU()
        self.softplus = nn.Softplus(beta=100)

    def forward(self, inputs):
        """Map [N, d_in] inputs to [N, d_out] positive values."""
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for idx in range(self.num_layers - 1):
            if idx in self.skip_in:
                # Re-inject the (embedded) input; 1/sqrt(2) keeps magnitudes stable.
                x = torch.cat([x, inputs], 1) / np.sqrt(2)
            x = getattr(self, "lin" + str(idx))(x)
            if idx < self.num_layers - 2:
                x = self.relu(x)

        # Strictly positive output with an upper bound set by the 1e-3 floor.
        return 1.0 / (self.softplus(x + 0.5) + 1e-3)

    def coarse(self, inputs):
        """First output channel, shape [N, 1]."""
        return self.forward(inputs)[:, :1]

    def fine(self, inputs):
        """Remaining output channels, shape [N, d_out - 1]."""
        return self.forward(inputs)[:, 1:]
|
|
|
|
|
class FixVarianceNetwork(nn.Module):
    """Non-learned variance schedule decaying exponentially with training step.

    forward() ignores the values in ``x`` and returns a [len(x), 1] tensor
    filled with exp(-iter_step / base).
    """

    def __init__(self, base):
        super(FixVarianceNetwork, self).__init__()
        # Decay time constant (in iterations) of the exponential schedule.
        self.base = base
        self.iter_step = 0

    def set_iter_step(self, iter_step):
        """Record the current training iteration used by forward()."""
        self.iter_step = iter_step

    def forward(self, x):
        """Return a [len(x), 1] tensor of exp(-iter_step / base).

        Fix: allocate the tensor on ``x``'s device (previously always CPU,
        inconsistent with SingleVarianceNetwork, and crashing when mixed with
        CUDA tensors downstream).
        """
        value = np.exp(-self.iter_step / self.base)
        return torch.ones([len(x), 1], device=x.device) * value
|
|
|
|
|
class SingleVarianceNetwork(nn.Module):
    """Single learnable global inverse-variance parameter (NeuS-style).

    forward() ignores the values in ``x`` and broadcasts
    exp(variance * 10) into a [len(x), 1] tensor on ``x``'s device.
    """

    def __init__(self, init_val=1.0):
        super(SingleVarianceNetwork, self).__init__()
        # Stored under the same parameter name ('variance') as before, so
        # state dicts remain compatible.
        self.variance = nn.Parameter(torch.tensor(init_val))

    def forward(self, x):
        # Exponential parameterization keeps the returned value positive.
        inv_s = torch.exp(self.variance * 10.0)
        return torch.ones([len(x), 1]).to(x.device) * inv_s
|
|
|
|
|
|
|
class RenderingNetwork(nn.Module):
    """MLP predicting a color-like output from point, normal, view-dir and feature inputs.

    ``mode`` selects which of the four inputs are concatenated and fed to the
    network. With ``squeeze_out`` the output is passed through a sigmoid.
    """

    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
            d_conditional_colors=0
    ):
        # NOTE(review): d_conditional_colors is accepted for interface
        # compatibility but unused in this class.
        super().__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out
        # Per-layer widths: concatenated input, hidden stack, output.
        dims = [d_in + d_feature] + [d_hidden] * n_layers + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            # Positional-encode the view directions; the encoding replaces the
            # raw 3 view-dir channels, hence the (input_ch - 3) adjustment.
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            dims[0] += input_ch - 3

        self.num_layers = len(dims)

        for idx in range(self.num_layers - 1):
            layer = nn.Linear(dims[idx], dims[idx + 1])
            if weight_norm:
                layer = nn.utils.weight_norm(layer)
            # Layers are stored as attributes lin0, lin1, ... and fetched by
            # name in forward().
            setattr(self, "lin" + str(idx), layer)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        """Run the MLP; returns [N, d_out] (sigmoid-squashed if squeeze_out)."""
        if self.embedview_fn is not None:
            view_dirs = self.embedview_fn(view_dirs)

        # Each mode concatenates a different subset of the available inputs.
        inputs_by_mode = {
            'idr': (points, view_dirs, normals, feature_vectors),
            'no_view_dir': (points, normals, feature_vectors),
            'no_normal': (points, view_dirs, feature_vectors),
            'no_points': (view_dirs, normals, feature_vectors),
            'no_points_no_view_dir': (normals, feature_vectors),
        }
        rendering_input = None
        if self.mode in inputs_by_mode:
            rendering_input = torch.cat(list(inputs_by_mode[self.mode]), dim=-1)

        x = rendering_input
        for idx in range(self.num_layers - 1):
            x = getattr(self, "lin" + str(idx))(x)
            if idx + 2 < self.num_layers:
                x = self.relu(x)

        if self.squeeze_out:
            x = torch.sigmoid(x)
        return x
|
|
|
|
|
|
|
class NeRF(nn.Module):
    """Standard NeRF MLP: positions -> density (alpha) and view-dependent RGB.

    Only the ``use_viewdirs=True`` path is implemented; forward() returns
    ``(alpha + 1.0, rgb)`` where alpha has shape [N, 1] and rgb [N, 3]
    (rgb is raw, not sigmoid-squashed).
    """

    def __init__(self, D=8, W=256, d_in=3, d_in_view=3, multires=0, multires_view=0, output_ch=4, skips=[4],
                 use_viewdirs=False):
        """Build the NeRF MLP.

        Args:
            D: number of position layers.
            W: hidden width.
            d_in / d_in_view: raw position / view-direction dimensionality.
            multires / multires_view: positional-encoding frequency counts;
                0 disables the respective encoding.
            output_ch: output width of the (unimplemented) no-viewdirs head.
            skips: layer indices after which the encoded input is re-concatenated.
            use_viewdirs: must be True; the other branch is not implemented.
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.d_in = d_in
        self.d_in_view = d_in_view
        # Fix: previously hard-coded to 3, which broke d_in != 3 (or
        # d_in_view != 3) whenever no positional encoding was applied.
        self.input_ch = d_in
        self.input_ch_view = d_in_view
        self.embed_fn = None
        self.embed_fn_view = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn = embed_fn
            self.input_ch = input_ch

        if multires_view > 0:
            embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view, normalize=False)
            self.embed_fn_view = embed_fn_view
            self.input_ch_view = input_ch_view

        self.skips = skips
        self.use_viewdirs = use_viewdirs

        # Position trunk: the layer right after a skip index takes the
        # concatenated (W + input_ch)-wide activation.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(self.input_ch, W)] +
            [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W)
             for i in range(D - 1)])

        # View-direction branch: consumes the trunk feature + encoded view dir.
        self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, input_pts, input_views):
        """Return (alpha + 1.0, rgb) for [N, d_in] points and [N, d_in_view] dirs.

        Raises:
            NotImplementedError: if constructed with use_viewdirs=False.
        """
        if self.embed_fn is not None:
            input_pts = self.embed_fn(input_pts)
        if self.embed_fn_view is not None:
            input_views = self.embed_fn_view(input_views)

        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                # Re-concatenate the encoded input after a skip layer.
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            # +1.0 shifts the raw alpha; downstream code presumably relies on
            # this offset — confirm before changing.
            return alpha + 1.0, rgb
        else:
            # Fix: `assert False` is stripped under `python -O`; raise an
            # explicit error instead.
            raise NotImplementedError("NeRF.forward only supports use_viewdirs=True")
|
|