# NaRCan_demo / NaRCan_model.py
import torch
from torch import nn
import numpy as np
import math

try:
    import tinycudann as tcnn  # required by BARF_PE and Deform_Hash3d below
except ImportError:  # keep the pure-PyTorch modules usable without tinycudann
    tcnn = None


class SineLayer(nn.Module):
    """Linear layer followed by a sine nonlinearity: sin(omega_0 * (W x + b)).

    See the SIREN paper, Sec. 3.2 (final paragraph) and supplement Sec. 1.5
    for a discussion of omega_0. If is_first=True, omega_0 is a frequency
    factor that simply multiplies the activations before the nonlinearity;
    different signals may require a different omega_0 in the first layer,
    so it is a hyperparameter. If is_first=False, the weights are divided
    by omega_0 to keep the magnitude of activations constant while
    boosting gradients to the weight matrix (see supplement Sec. 1.5).
    """
def __init__(self, in_features, out_features, bias=True,
is_first=False, omega_0=30):
super().__init__()
self.omega_0 = omega_0
self.is_first = is_first
self.in_features = in_features
self.linear = nn.Linear(in_features, out_features, bias=bias)
self.init_weights()

    def init_weights(self):
with torch.no_grad():
if self.is_first:
self.linear.weight.uniform_(-1 / self.in_features,
1 / self.in_features)
else:
self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
# For visualization of activation distributions
intermediate = self.omega_0 * self.linear(input)
return torch.sin(intermediate), intermediate
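
# Illustrative usage sketch (not from the original repo; shapes are
# assumptions): a SineLayer maps features through sin(omega_0 * (W x + b)),
# so its outputs lie in [-1, 1].
#
#   layer = SineLayer(in_features=2, out_features=256, is_first=True, omega_0=30)
#   y = layer(torch.rand(8, 2))  # -> (8, 256)
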
class Siren(nn.Module):
def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
first_omega_0=30, hidden_omega_0=30.):
super().__init__()
self.net = []
self.net.append(SineLayer(in_features, hidden_features,
is_first=True, omega_0=first_omega_0))
for i in range(hidden_layers):
self.net.append(SineLayer(hidden_features, hidden_features,
is_first=False, omega_0=hidden_omega_0))
if outermost_linear:
final_linear = nn.Linear(hidden_features, out_features)
with torch.no_grad():
final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
np.sqrt(6 / hidden_features) / hidden_omega_0)
self.net.append(final_linear)
else:
self.net.append(SineLayer(hidden_features, out_features,
is_first=False, omega_0=hidden_omega_0))
self.net = nn.Sequential(*self.net)

    def forward(self, coords):
output = self.net(coords)
return output
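
# Usage sketch (assumed shapes, following the image-fitting setup of the
# SIREN paper, not taken from this repo): map 2-D coordinates in [-1, 1]
# to RGB values.
#
#   model = Siren(in_features=2, hidden_features=256, hidden_layers=3,
#                 out_features=3, outermost_linear=True)
#   rgb = model(torch.rand(1024, 2) * 2 - 1)  # -> (1024, 3)
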
class Homography(nn.Module):
def __init__(self, in_features=1, hidden_features=256, hidden_layers=1):
super().__init__()
        out_features = 8  # eight parameters of a 3x3 homography (H[2, 2] fixed to 1)
self.net = []
self.net.append(nn.Linear(in_features, hidden_features))
self.net.append(nn.ReLU(inplace=True))
for i in range(hidden_layers):
self.net.append(nn.Linear(hidden_features, hidden_features))
self.net.append(nn.ReLU(inplace=True))
self.net.append(nn.Linear(hidden_features, out_features))
self.net = nn.Sequential(*self.net)
self.init_weights()

    def init_weights(self):
        # Bias the final layer toward the identity homography:
        # the first eight row-major entries of the 3x3 identity matrix.
        with torch.no_grad():
            self.net[-1].bias.copy_(torch.tensor([1., 0., 0., 0., 1., 0., 0., 0.]))

    def forward(self, coords):
output = self.net(coords)
return output
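
# Sketch (an assumption consistent with the identity bias in init_weights):
# the eight outputs can be read as the first eight row-major entries of a
# 3x3 homography H with H[2, 2] fixed to 1, so the initial prediction is
# the identity transform.
#
#   params = Homography()(t)  # t: (B, 1) normalized time input
#   H = torch.cat([params, torch.ones_like(params[:, :1])], dim=-1).view(-1, 3, 3)
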
class Annealed(nn.Module):
def __init__(self, in_channels, annealed_step, annealed_begin_step=0, identity=True):
"""
Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
in_channels: number of input channels (3 for both xyz and direction)
"""
super(Annealed, self).__init__()
self.N_freqs = 16
self.in_channels = in_channels
self.annealed = True
self.annealed_step = annealed_step
self.annealed_begin_step = annealed_begin_step
        self.index = torch.linspace(0, self.N_freqs - 1, self.N_freqs)
        self.identity = identity
        # Each frequency index repeated twice, matching two features per band.
        self.index_2 = self.index.view(-1, 1).repeat(1, 2).view(-1)

    def forward(self, x_embed, step):
        """
        Weights an already-computed embedding band by band
        (cf. https://github.com/bmild/nerf/issues/12).
        Inputs:
            x_embed: (B, 2 * N_freqs) encoded features
            step: current training step
        Outputs:
            out: (B, 2 * N_freqs) annealed features
        """
        # Ramp alpha linearly from 0 to N_freqs once annealing has begun.
        if self.annealed_begin_step == 0:
            alpha = self.N_freqs * step / float(self.annealed_step)
        else:
            if step <= self.annealed_begin_step:
                alpha = 0
            else:
                alpha = self.N_freqs * (step - self.annealed_begin_step) / float(
                    self.annealed_step)
        # Cosine easing per band: w_k = (1 - cos(pi * clamp(alpha - k, 0, 1))) / 2.
        w = (1 - torch.cos(math.pi * torch.clamp(alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1))) / 2
        out = x_embed * w.to(x_embed.device)
return out
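
# Illustrative sketch of the schedule (values are assumptions): with
# N_freqs = 16 and annealed_step = 1000, alpha = 8 at step 500, so bands
# 0..7 receive weight 1 while bands 8..15 are still zeroed out.
#
#   ann = Annealed(in_channels=2, annealed_step=1000)
#   feats = torch.rand(8, 32)      # (B, 2 * N_freqs) encoded features
#   out = ann(feats, step=500)     # low-frequency bands pass, high ones are zeroed
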
class BARF_PE(nn.Module):
def __init__(self, config):
super().__init__()
        self.encoder = tcnn.Encoding(n_input_dims=2,
                                     encoding_config=config["positional encoding"])
        self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + 2,
                                    n_output_dims=3,
                                    network_config=config["BARF network"])

    def forward(self, x, step=0, aneal_func=None):
        feat = self.encoder(x)
        if aneal_func is not None:
            feat = torch.cat([x, aneal_func(feat, step)], dim=-1)
        else:
            feat = torch.cat([x, feat], dim=-1)
        weight = torch.ones(feat.shape[-1], device=feat.device)  # currently a no-op scaling
        x = self.decoder(weight * feat)
        return x
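
# Illustrative config sketch (keys follow the tiny-cuda-nn JSON config
# format; the values here are assumptions, not taken from the original
# training setup):
#
#   config = {
#       "positional encoding": {"otype": "HashGrid", "n_levels": 16,
#                               "n_features_per_level": 2,
#                               "log2_hashmap_size": 19,
#                               "base_resolution": 16,
#                               "per_level_scale": 1.5},
#       "BARF network": {"otype": "FullyFusedMLP", "activation": "ReLU",
#                        "output_activation": "None", "n_neurons": 64,
#                        "n_hidden_layers": 2},
#   }
#   model = BARF_PE(config).cuda()  # tcnn modules require a CUDA device
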
class Deform_Hash3d(nn.Module):
def __init__(self, config):
super().__init__()
self.encoder = tcnn.Encoding(n_input_dims=3,
encoding_config=config["encoding_deform3d"])
self.decoder = nn.Sequential(nn.Linear(self.encoder.n_output_dims + 3, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 2)
)

    def forward(self, x, step=0, aneal_func=None):
        feat = self.encoder(x)
        if aneal_func is not None:
            feat = torch.cat([x, aneal_func(feat, step)], dim=-1)
        else:
            feat = torch.cat([x, feat], dim=-1)
        weight = torch.ones(feat.shape[-1], device=feat.device)  # currently a no-op scaling
        x = self.decoder(weight * feat) / 5  # damp the predicted 2-D offsets
        return x


class Deform_Hash3d_Warp(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.Deform_Hash3d = Deform_Hash3d(config)

    def forward(self, xyt_norm, step=0, aneal_func=None):
        return self.Deform_Hash3d(xyt_norm, step=step, aneal_func=aneal_func)
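

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; shapes are assumptions).
    # Only the pure-PyTorch modules are exercised here, since BARF_PE and
    # Deform_Hash3d additionally require tinycudann and a CUDA device.
    coords = torch.rand(16, 2) * 2 - 1
    siren = Siren(in_features=2, hidden_features=64, hidden_layers=2,
                  out_features=3, outermost_linear=True)
    print("Siren output:", siren(coords).shape)  # torch.Size([16, 3])
    homography = Homography(in_features=1)
    print("Homography output:", homography(torch.rand(4, 1)).shape)  # torch.Size([4, 8])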