# The codes are from NeuS
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.embedder import get_embedder


class SDFNetwork(nn.Module):
    """MLP mapping points to a signed distance value plus extra output channels.

    Output column 0 is the SDF (divided by ``scale``); columns 1..d_out-1 are
    returned unchanged (presumably consumed as feature vectors by a rendering
    network -- confirm with callers).
    """

    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0,
                 bias=0.5, scale=1, geometric_init=True, weight_norm=True,
                 activation='softplus', conditional_type='multiply'):
        """Build the layers ``lin0 .. lin{num_layers-2}``.

        Args:
            d_in: width of the raw input points.
            d_out: number of output channels (SDF + extra channels).
            d_hidden: width of every hidden layer.
            n_layers: number of hidden layers.
            skip_in: layer indices whose input is ``cat([hidden, inputs])``.
            multires: number of positional-encoding frequencies; 0 disables it.
            bias: under geometric init, the last layer's bias is set to ``-bias``.
            scale: inputs are multiplied by ``scale``; the SDF is divided by it.
            geometric_init: apply the custom initialisation below.
            weight_norm: wrap every linear layer in ``nn.utils.weight_norm``.
            activation: ``'softplus'`` (beta=100) or ``'relu'``.
            conditional_type: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``activation`` is neither ``'softplus'`` nor ``'relu'``.
        """
        super(SDFNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                # The next layer concatenates the (embedded) input, so this layer
                # outputs dims[0] fewer channels to keep the next input at dims[l + 1].
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                elif multires > 0 and l == 0:
                    # Only the first 3 input columns get random weights; the other
                    # columns start at zero. NOTE(review): assumes the embedder puts
                    # the raw coordinates first and that d_in == 3 -- confirm in
                    # models/embedder.py.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    # This layer's input is cat([hidden, inputs]); its last dims[0]
                    # columns are the embedded input. Zeroing the last dims[0] - 3 of
                    # them keeps only the first 3 embedded columns active, mirroring
                    # the l == 0 branch (same raw-coordinates-first assumption).
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        elif activation == 'relu':
            self.activation = nn.ReLU()
        else:
            # An `assert` here would be stripped under `python -O`.
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'softplus' or 'relu'"
            )

    def forward(self, inputs):
        """Evaluate the MLP.

        ``inputs`` is indexed as a 2-D tensor (presumably (N, d_in)). Returns a
        tensor whose column 0 is the SDF divided by ``scale`` and whose remaining
        columns are the raw extra outputs.
        """
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                # Re-inject the (embedded) input and rescale by 1/sqrt(2).
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)
        return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)

    def sdf(self, x):
        """Return only the SDF column (shape (N, 1))."""
        return self.forward(x)[:, :1]

    def sdf_hidden_appearance(self, x):
        """Return the full network output (SDF column followed by extra channels)."""
        return self.forward(x)

    def gradient(self, x):
        """Return d(sdf)/dx with an extra axis inserted at dim 1.

        Sets ``x.requires_grad_(True)`` in place. ``create_graph=True`` keeps the
        result differentiable (presumably for a gradient-based loss -- confirm).
        """
        x.requires_grad_(True)
        y = self.sdf(x)
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)


class VarianceNetwork(nn.Module):
    """Plain MLP whose output is mapped to ``1 / (softplus(x + 0.5) + 1e-3)`` (> 0).

    ``coarse`` returns output column 0; ``fine`` returns the remaining columns.
    """

    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0):
        super(VarianceNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            # Pass input_dims like SDFNetwork/NeRF do, so the reported embedding
            # width matches d_in rather than the embedder's default input size.
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                # Leave room for the input concatenated at the next layer.
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)
            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()
        self.softplus = nn.Softplus(beta=100)

    def forward(self, inputs):
        """Run the MLP and map its output to strictly positive values."""
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)

        # softplus(.) > 0 and the +1e-3 floor bound the reciprocal above by 1e3.
        return 1.0 / (self.softplus(x + 0.5) + 1e-3)

    def coarse(self, inputs):
        """Return output column 0."""
        return self.forward(inputs)[:, :1]

    def fine(self, inputs):
        """Return output columns 1 onwards."""
        return self.forward(inputs)[:, 1:]


class FixVarianceNetwork(nn.Module):
    """Parameter-free schedule: every row gets ``exp(-iter_step / base)``."""

    def __init__(self, base):
        super(FixVarianceNetwork, self).__init__()
        self.base = base
        self.iter_step = 0

    def set_iter_step(self, iter_step):
        """Update the step used by ``forward``."""
        self.iter_step = iter_step

    def forward(self, x):
        """Return a (len(x), 1) tensor filled with ``exp(-iter_step / base)``."""
        # Allocate on x's device (the original always allocated on CPU, unlike
        # SingleVarianceNetwork), so GPU callers get a GPU tensor.
        return torch.ones([len(x), 1], device=x.device) * np.exp(-self.iter_step / self.base)


class SingleVarianceNetwork(nn.Module):
    """Single learnable scalar ``variance``; every row gets ``exp(10 * variance)``."""

    def __init__(self, init_val=1.0):
        super(SingleVarianceNetwork, self).__init__()
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        """Return a (len(x), 1) tensor on ``x.device`` filled with ``exp(10 * variance)``."""
        return torch.ones([len(x), 1], device=x.device) * torch.exp(self.variance * 10.0)


class RenderingNetwork(nn.Module):
    """MLP producing rendering outputs (presumably RGB) from per-point inputs.

    ``mode`` selects which of points / view_dirs / normals are concatenated with
    the feature vectors: 'idr', 'no_view_dir', 'no_normal', 'no_points' or
    'no_points_no_view_dir'.
    """

    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
            d_conditional_colors=0
    ):
        """Build the layers ``lin0 .. lin{num_layers-2}``.

        ``d_in`` must equal the total width of the non-feature inputs that
        ``mode`` concatenates (view dirs counted as 3 before embedding), since
        the first layer's input width is ``d_in + d_feature`` plus the extra
        embedding channels. ``d_conditional_colors`` is accepted but unused.
        """
        super().__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out
        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            # d_in already counts 3 raw view-direction channels.
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        """Concatenate the inputs selected by ``self.mode`` and run the MLP.

        Applies a sigmoid to the output when ``squeeze_out`` is set.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        if self.embedview_fn is not None:
            view_dirs = self.embedview_fn(view_dirs)

        if self.mode == 'idr':
            rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_view_dir':
            rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_normal':
            rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
        elif self.mode == 'no_points':
            rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_points_no_view_dir':
            rendering_input = torch.cat([normals, feature_vectors], dim=-1)
        else:
            # Previously fell through with rendering_input = None and failed with
            # an opaque error inside the first linear layer.
            raise ValueError(f"Unknown RenderingNetwork mode: {self.mode!r}")

        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)

        if self.squeeze_out:
            x = torch.sigmoid(x)
        return x


# Code from nerf-pytorch
class NeRF(nn.Module):
    """NeRF-style MLP; only the ``use_viewdirs=True`` path is implemented."""

    def __init__(self,
                 D=8,
                 W=256,
                 d_in=3,
                 d_in_view=3,
                 multires=0,
                 multires_view=0,
                 output_ch=4,
                 skips=(4,),
                 use_viewdirs=False):
        """Build the point MLP and the view-dependent head.

        Args:
            D: number of point-MLP layers.
            W: hidden width.
            d_in: width of the raw point input.
            d_in_view: width of the raw view-direction input.
            multires: point positional-encoding frequencies (0 disables).
            multires_view: view positional-encoding frequencies (0 disables).
            output_ch: output width of ``output_linear`` (only built when
                ``use_viewdirs`` is False).
            skips: point-MLP layer indices after which the input is re-concatenated.
                A tuple default avoids the shared-mutable-default pitfall.
            use_viewdirs: build the alpha / feature / rgb heads.
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.d_in = d_in
        self.d_in_view = d_in_view
        # Without positional encoding the raw inputs feed the MLP directly, so the
        # widths are d_in / d_in_view (previously hard-coded to 3).
        self.input_ch = d_in
        self.input_ch_view = d_in_view
        self.embed_fn = None
        self.embed_fn_view = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn = embed_fn
            self.input_ch = input_ch

        if multires_view > 0:
            embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view, normalize=False)
            self.embed_fn_view = embed_fn_view
            self.input_ch_view = input_ch_view

        self.skips = list(skips)
        self.use_viewdirs = use_viewdirs

        # Layer i + 1 takes W + input_ch inputs when i is a skip index, matching
        # the concatenation done after layer i in forward().
        self.pts_linears = nn.ModuleList(
            [nn.Linear(self.input_ch, W)] +
            [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in
             range(D - 1)])

        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        # (The paper instead stacks D // 2 further W // 2-wide layers here.)
        self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, input_pts, input_views):
        """Return ``(alpha + 1.0, rgb)`` for the given points and view directions.

        Raises:
            NotImplementedError: when the model was built with ``use_viewdirs=False``.
        """
        if not self.use_viewdirs:
            # The original `assert False` is stripped under `python -O`, after which
            # this method silently returned None; raise explicitly instead.
            raise NotImplementedError("NeRF.forward is only implemented for use_viewdirs=True")

        if self.embed_fn is not None:
            input_pts = self.embed_fn(input_pts)
        if self.embed_fn_view is not None:
            input_views = self.embed_fn_view(input_views)

        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        alpha = self.alpha_linear(h)
        feature = self.feature_linear(h)
        h = torch.cat([feature, input_views], -1)

        for layer in self.views_linears:
            h = F.relu(layer(h))

        rgb = self.rgb_linear(h)
        # The alpha output is shifted by +1.0 before being returned.
        return alpha + 1.0, rgb