import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tinycudann as tcnn
# Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, input_dims=3):
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
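
# Example (not part of the original file; the numbers follow directly from get_embedder):
# with multires=6 and input_dims=3 the embedding keeps the raw input and adds sin/cos at
# 6 frequencies per coordinate, so out_dim = 3 + 3 * 2 * 6 = 39.
#
#   embed_fn, out_dim = get_embedder(6, input_dims=3)   # out_dim == 39
#   gamma = embed_fn(torch.rand(1024, 3))                # (1024, 39)
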
class SDFNetwork(nn.Module):
    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False):
        super(SDFNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            # Geometric initialization: start the network close to the SDF of a sphere
            # of radius `bias` (sign flipped when `inside_outside` is set).
            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    # only the raw xyz channels (first 3 columns) influence the initial SDF
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    # zero the weights reading the positional-encoding part of the re-injected input
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, inputs):
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                # skip connection: re-inject the (embedded) input
                x = torch.cat([x, inputs], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)
        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
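
# Usage sketch (not part of the original file; the layer sizes are illustrative, not values
# mandated by this file). The first output channel is the signed distance and the remaining
# channels are a geometry feature vector, which is why sdf() slices [..., :1].
#
#   sdf_network = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8, multires=6)
#   pts = torch.rand(1024, 3) * 2 - 1        # query points in [-1, 1]^3
#   sdf_vals = sdf_network.sdf(pts)          # (1024, 1) signed distances
#   grads = sdf_network.gradient(pts)        # (1024, 3) SDF gradients via autograd
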
class SDFNetworkWithFeature(nn.Module):
    def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5,
                 scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5):
        super().__init__()
        self.register_buffer("cube", cube)
        self.cube_length = cube_length

        dims = [dp_in + df_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=dp_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch + df_in

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            # Same geometric (sphere) initialization as SDFNetwork above.
            if geometric_init:
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, points):
        points = points * self.scale

        # The cube spans [-cube_length, cube_length] (0.5 by default), so dividing by
        # cube_length maps the points into the [-1, 1] range expected by F.grid_sample.
        with torch.no_grad():
            feats = F.grid_sample(self.cube, points.view(1, -1, 1, 1, 3) / self.cube_length,
                                  mode='bilinear', align_corners=True, padding_mode='zeros').detach()
            feats = feats.view(self.cube.shape[1], -1).permute(1, 0).view(*points.shape[:-1], -1)

        if self.embed_fn_fine is not None:
            points = self.embed_fn_fine(points)

        x = torch.cat([points, feats], -1)
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, points, feats], -1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        # concatenate the sampled cube features to the network output
        x = torch.cat([x, feats], -1)
        return x

    def sdf(self, x):
        return self.forward(x)[..., :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
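
# Usage sketch (not part of the original file; sizes are illustrative). Given how forward()
# calls F.grid_sample and reads self.cube.shape[1], `cube` is a dense feature volume of shape
# (1, C, D, H, W) with C == df_in, covering the box [-cube_length, cube_length]^3.
#
#   cube = torch.zeros(1, 8, 64, 64, 64)     # hypothetical 64^3 volume with 8 feature channels
#   net = SDFNetworkWithFeature(cube, dp_in=3, df_in=8, d_out=257, d_hidden=256,
#                               n_layers=8, multires=6, cube_length=0.5)
#   out = net(torch.rand(1024, 3) - 0.5)     # (1024, 257 + 8): SDF + features + cube features
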
class VanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
        self.sphere_init, self.weight_norm = True, True
        self.sphere_init_radius = 0.5
        self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
        for i in range(self.n_hidden_layers - 1):
            self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
        self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        x = self.layers(x.float())
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=True)  # network without bias will degrade quality
        if self.sphere_init:
            if is_last:
                torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
                torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
            elif is_first:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
                torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
            else:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        if self.sphere_init:
            return nn.Softplus(beta=100)
        else:
            return nn.ReLU(inplace=True)
class SDFHashGridNetwork(nn.Module):
    def __init__(self, bound=0.5, feats_dim=13):
        super().__init__()
        self.bound = bound

        # max_resolution = 32
        # base_resolution = 16
        # n_levels = 4
        # log2_hashmap_size = 16
        # n_features_per_level = 8
        max_resolution = 2048
        base_resolution = 16
        n_levels = 16
        log2_hashmap_size = 19
        n_features_per_level = 2

        # per-level growth factor t, chosen so that max_resolution = base_resolution * t^(n_levels - 1)
        per_level_scale = (max_resolution / base_resolution) ** (1 / (n_levels - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )
        self.sdf_mlp = VanillaMLP(n_levels * n_features_per_level + 3, feats_dim, 64, 1)

    def forward(self, x):
        shape = x.shape[:-1]
        x = x.reshape(-1, 3)
        # map x from [-bound, bound] to the [0, 1] range used by the hash-grid encoder
        x_ = (x + self.bound) / (2 * self.bound)

        feats = self.encoder(x_)
        feats = torch.cat([x, feats], 1)
        feats = self.sdf_mlp(feats)
        feats = feats.reshape(*shape, -1)
        return feats

    def sdf(self, x):
        return self(x)[..., :1]

    def gradient(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients

    def sdf_normal(self, x):
        x.requires_grad_(True)
        with torch.enable_grad():
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return y[..., :1].detach(), gradients.detach()
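
# Worked example (not part of the original file): with base_resolution=16, max_resolution=2048
# and n_levels=16 as configured above, the per-level growth factor is
#   per_level_scale = (2048 / 16) ** (1 / 15)   # = 128 ** (1 / 15) ≈ 1.3819
# and the encoder outputs n_levels * n_features_per_level = 32 features, so the SDF MLP
# consumes 32 + 3 = 35 inputs (hash-grid features plus the raw xyz coordinates).
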
class RenderingFFNetwork(nn.Module):
    def __init__(self, in_feats_dim=12):
        super().__init__()
        self.dir_encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )
        self.color_mlp = tcnn.Network(
            n_input_dims=in_feats_dim + 3 + self.dir_encoder.n_output_dims,
            n_output_dims=3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 2,
            },
        )

    def forward(self, points, normals, view_dirs, feature_vectors):
        normals = F.normalize(normals, dim=-1)
        view_dirs = F.normalize(view_dirs, dim=-1)
        # reflect the view direction about the surface normal: r = 2 (d . n) n - d
        reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs

        x = torch.cat([feature_vectors, normals, self.dir_encoder(reflective)], -1)
        colors = self.color_mlp(x).float()
        colors = torch.sigmoid(colors)
        return colors
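
# Shape sketch (not part of the original file; assumes tcnn's degree-4 spherical harmonics
# encoding yields 16 outputs, so with the default in_feats_dim=12 the color MLP sees
# 12 + 3 + 16 = 31 inputs per point):
#
#   shader = RenderingFFNetwork(in_feats_dim=12).cuda()           # tcnn modules run on the GPU
#   pts, nrm, dirs = (torch.rand(1024, 3, device='cuda') for _ in range(3))
#   feats = torch.rand(1024, 12, device='cuda')
#   rgb = shader(pts, nrm, dirs, feats)                           # (1024, 3), values in (0, 1)
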
# This implementation is borrowed from IDR: https://github.com/lioryariv/idr
class RenderingNetwork(nn.Module):
    def __init__(self, d_feature, d_in, d_out, d_hidden,
                 n_layers, weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
        super().__init__()

        self.squeeze_out = squeeze_out
        self.rgb_act = torch.sigmoid
        self.use_view_dir = use_view_dir

        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        if self.use_view_dir:
            view_dirs = F.normalize(view_dirs, dim=-1)
            normals = F.normalize(normals, dim=-1)
            # reflection direction, as in RenderingFFNetwork: r = 2 (d . n) n - d
            reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs
            if self.embedview_fn is not None:
                reflective = self.embedview_fn(reflective)
            rendering_input = torch.cat([points, reflective, normals, feature_vectors], dim=-1)
        else:
            rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)

        x = rendering_input
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.relu(x)

        if self.squeeze_out:
            x = self.rgb_act(x)
        return x
class SingleVarianceNetwork(nn.Module):
    def __init__(self, init_val, activation='exp'):
        super(SingleVarianceNetwork, self).__init__()
        self.act = activation
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        device = x.device
        if self.act == 'exp':
            # broadcast the single learned scalar exp(10 * variance) to one value per input point
            return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * torch.exp(self.variance * 10.0)
        else:
            raise NotImplementedError

    def warp(self, x, inv_s):
        device = x.device
        return torch.ones([*x.shape[:-1], 1], dtype=torch.float32, device=device) * inv_s
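
# Value sketch (not part of the original file; init_val=0.3 is illustrative): forward() only
# uses the batch shape of its input and broadcasts the single learned scalar, so an (N, 3)
# input yields an (N, 1) tensor filled with exp(10 * 0.3) ≈ 20.09 before any training updates.
#
#   s_net = SingleVarianceNetwork(init_val=0.3)
#   inv_s = s_net(torch.rand(8, 3))    # (8, 1), every entry ≈ exp(3.0)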