# SyncDreamer/renderer/neus_networks.py
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tinycudann as tcnn
# Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, input_dims=3):
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    def embed(x, eo=embedder_obj): return eo.embed(x)
    return embed, embedder_obj.out_dim
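
# Illustrative usage sketch (not part of the original file): with multires=6 and
# input_dims=3, the embedding is the raw point plus sin/cos at 6 frequencies,
# i.e. 3 + 3*2*6 = 39 output dims.
def _demo_get_embedder():
    embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
    xyz = torch.rand(1024, 3)
    enc = embed_fn(xyz)
    assert enc.shape == (1024, out_dim) and out_dim == 39
    return enc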
class SDFNetwork(nn.Module):
    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False):
        super(SDFNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                # Geometric initialization (cf. SAL, Atzmon & Lipman 2020): the
                # untrained network approximates the SDF of a sphere of radius `bias`.
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, inputs):
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                # skip connection; divide by sqrt(2) to keep the activation variance stable
                x = torch.cat([x, inputs], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)
        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
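
# Illustrative usage sketch (the configuration values are assumptions, not taken
# from SyncDreamer's configs): the first output channel is the SDF value and the
# remaining channels are a feature vector consumed by the rendering network.
def _demo_sdf_network():
    net = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8,
                     skip_in=(4,), multires=6)
    pts = torch.rand(4096, 3) * 2 - 1                  # points in [-1, 1]^3
    sdf = net.sdf(pts)                                 # (4096, 1) signed distances
    normals = F.normalize(net.gradient(pts), dim=-1)   # unit normals via autograd
    return sdf, normals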
class SDFNetworkWithFeature(nn.Module):
    def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5):
        super().__init__()
        self.register_buffer("cube", cube)
        self.cube_length = cube_length

        dims = [dp_in + df_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch + df_in

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, points):
        points = points * self.scale
        # Dividing by cube_length (0.5) maps points from the [-0.5, 0.5] cube
        # into the [-1, 1] range expected by F.grid_sample.
        with torch.no_grad():
            feats = F.grid_sample(self.cube, points.view(1, -1, 1, 1, 3) / self.cube_length,
                                  mode='bilinear', align_corners=True, padding_mode='zeros').detach()
            feats = feats.view(self.cube.shape[1], -1).permute(1, 0).view(*points.shape[:-1], -1)

        if self.embed_fn_fine is not None:
            points = self.embed_fn_fine(points)

        x = torch.cat([points, feats], -1)
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, points, feats], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        # concatenate the sampled cube features to the output
        x = torch.cat([x, feats], -1)
        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
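
# Illustrative usage sketch (the cube shape and all dimensions are assumptions):
# `cube` is a dense feature volume that is sampled trilinearly at each query point
# and concatenated with the positional encoding.
def _demo_sdf_network_with_feature():
    cube = torch.rand(1, 8, 32, 32, 32)        # (1, C, D, H, W) feature volume
    net = SDFNetworkWithFeature(cube, dp_in=3, df_in=8, d_out=129, d_hidden=128,
                                n_layers=4, skip_in=(2,), multires=6)
    pts = torch.rand(1024, 3) - 0.5            # points inside the [-0.5, 0.5] cube
    out = net(pts)                              # SDF + features, cube feats appended
    return out[..., :1]                         # signed distance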
class VanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
        self.sphere_init, self.weight_norm = True, True
        self.sphere_init_radius = 0.5
        self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
        for i in range(self.n_hidden_layers - 1):
            self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
        self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
        self.layers = nn.Sequential(*self.layers)

    @torch.cuda.amp.autocast(False)
    def forward(self, x):
        x = self.layers(x.float())
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=True)  # a network without bias will degrade quality
        if self.sphere_init:
            if is_last:
                torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
                torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
            elif is_first:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
                torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
            else:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        if self.sphere_init:
            return nn.Softplus(beta=100)
        else:
            return nn.ReLU(inplace=True)
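
# Illustrative sketch (dimensions match the SDFHashGridNetwork usage below, but
# are otherwise arbitrary): with sphere_init, a freshly initialized network's
# first output channel approximates the SDF of a sphere of radius sphere_init_radius.
def _demo_vanilla_mlp():
    mlp = VanillaMLP(dim_in=35, dim_out=13, n_neurons=64, n_hidden_layers=1)
    x = torch.rand(16, 35)
    return mlp(x)   # (16, 13)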
class SDFHashGridNetwork(nn.Module):
    def __init__(self, bound=0.5, feats_dim=13):
        super().__init__()
        self.bound = bound
        # alternative low-resolution settings:
        # max_resolution = 32
        # base_resolution = 16
        # n_levels = 4
        # log2_hashmap_size = 16
        # n_features_per_level = 8
        max_resolution = 2048
        base_resolution = 16
        n_levels = 16
        log2_hashmap_size = 19
        n_features_per_level = 2
        # per-level growth factor t, from max_res = base_res * t^(n_levels - 1)
        per_level_scale = (max_resolution / base_resolution) ** (1 / (n_levels - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )
        self.sdf_mlp = VanillaMLP(n_levels * n_features_per_level + 3, feats_dim, 64, 1)

    def forward(self, x):
        shape = x.shape[:-1]
        x = x.reshape(-1, 3)
        x_ = (x + self.bound) / (2 * self.bound)   # map [-bound, bound] to [0, 1]

        feats = self.encoder(x_)
        feats = torch.cat([x, feats], 1)
        feats = self.sdf_mlp(feats)
        feats = feats.reshape(*shape, -1)
        return feats

    def sdf(self, x):
        return self(x)[..., :1]

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
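
# Illustrative usage sketch (requires a CUDA device, since the tinycudann hash
# grid only runs on GPU):
def _demo_sdf_hashgrid():
    net = SDFHashGridNetwork(bound=0.5, feats_dim=13).cuda()
    pts = torch.rand(4096, 3, device='cuda') - 0.5   # points in [-0.5, 0.5]^3
    sdf = net.sdf(pts)                               # (4096, 1) signed distances
    return sdf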
class RenderingFFNetwork(nn.Module):
    def __init__(self, in_feats_dim=12):
        super().__init__()
        self.dir_encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )
        self.color_mlp = tcnn.Network(
            n_input_dims=in_feats_dim + 3 + self.dir_encoder.n_output_dims,
            n_output_dims=3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 2,
            },
        )

    def forward(self, points, normals, view_dirs, feature_vectors):
        normals = F.normalize(normals, dim=-1)
        view_dirs = F.normalize(view_dirs, dim=-1)
        reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs

        x = torch.cat([feature_vectors, normals, self.dir_encoder(reflective)], -1)
        colors = self.color_mlp(x).float()
        colors = torch.sigmoid(colors)
        return colors
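
# Illustrative usage sketch (requires a CUDA device for tinycudann): the color MLP
# is conditioned on per-point features, the surface normal, and a spherical-harmonics
# encoding of the reflection direction.
def _demo_rendering_ff_network():
    net = RenderingFFNetwork(in_feats_dim=12).cuda()
    n = 4096
    points = torch.rand(n, 3, device='cuda')
    normals = torch.randn(n, 3, device='cuda')
    view_dirs = torch.randn(n, 3, device='cuda')
    feats = torch.rand(n, 12, device='cuda')
    return net(points, normals, view_dirs, feats)   # (n, 3) RGB in [0, 1]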
# This implementation is borrowed from IDR: https://github.com/lioryariv/idr
class RenderingNetwork(nn.Module):
    def __init__(self, d_feature, d_in, d_out, d_hidden,
                 n_layers, weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
        super().__init__()

        self.squeeze_out = squeeze_out
        self.rgb_act = torch.sigmoid
        self.use_view_dir = use_view_dir

        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        if self.use_view_dir:
            view_dirs = F.normalize(view_dirs, dim=-1)
            normals = F.normalize(normals, dim=-1)
            reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs
            if self.embedview_fn is not None:
                reflective = self.embedview_fn(reflective)
            rendering_input = torch.cat([points, reflective, normals, feature_vectors], dim=-1)
        else:
            rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)

        x = rendering_input
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.relu(x)

        if self.squeeze_out:
            x = self.rgb_act(x)
        return x
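
# Illustrative usage sketch (the dimensions are assumptions): an IDR-style color
# MLP taking the point, reflection direction, normal, and SDF feature vector.
def _demo_rendering_network():
    net = RenderingNetwork(d_feature=256, d_in=9, d_out=3, d_hidden=256,
                           n_layers=4, multires_view=4)
    n = 2048
    points = torch.rand(n, 3)
    normals = F.normalize(torch.randn(n, 3), dim=-1)
    view_dirs = F.normalize(torch.randn(n, 3), dim=-1)
    feats = torch.rand(n, 256)
    return net(points, normals, view_dirs, feats)   # (n, 3) RGB in [0, 1]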
class SingleVarianceNetwork(nn.Module):
    def __init__(self, init_val, activation='exp'):
        super(SingleVarianceNetwork, self).__init__()
        self.act = activation
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        device = x.device
        if self.act == 'exp':
            return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * torch.exp(self.variance * 10.0)
        else:
            raise NotImplementedError

    def warp(self, x, inv_s):
        device = x.device
        return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * inv_s
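
# Illustrative usage sketch: in NeuS, the single learnable `variance` parameter is
# mapped to the sharpness inv_s = exp(10 * variance) of the logistic density used
# to convert SDF values into opacity; the network broadcasts it per sample point.
def _demo_single_variance():
    var_net = SingleVarianceNetwork(init_val=0.3)
    pts = torch.rand(4096, 3)
    inv_s = var_net(pts)   # (4096, 1), every entry equals exp(10 * 0.3)
    return inv_s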