import math

import numpy as np
import torch
from torch import nn

# tinycudann is optional here: only the tcnn-based models below (BARF_PE, Deform_Hash3d)
# require it, so the import is guarded instead of failing for Siren-only usage.
try:
    import tinycudann as tcnn
except ImportError:
    tcnn = None


class SineLayer(nn.Module):
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.
    #
    # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations
    # before the nonlinearity. Different signals may require different omega_0 in the first
    # layer - this is a hyperparameter.
    #
    # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude
    # of activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5).

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                            1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                            np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        # For visualization of activation distributions
        intermediate = self.omega_0 * self.linear(input)
        return torch.sin(intermediate), intermediate


class Siren(nn.Module):
    def __init__(self, in_features, hidden_features, hidden_layers, out_features,
                 outermost_linear=False, first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        self.net = []
        self.net.append(SineLayer(in_features, hidden_features,
                                  is_first=True, omega_0=first_omega_0))

        for i in range(hidden_layers):
            self.net.append(SineLayer(hidden_features, hidden_features,
                                      is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Linear(hidden_features, out_features)
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                             np.sqrt(6 / hidden_features) / hidden_omega_0)
            self.net.append(final_linear)
        else:
            self.net.append(SineLayer(hidden_features, out_features,
                                      is_first=False, omega_0=hidden_omega_0))

        self.net = nn.Sequential(*self.net)

    def forward(self, coords):
        output = self.net(coords)
        return output


class Homography(nn.Module):
    def __init__(self, in_features=1, hidden_features=256, hidden_layers=1):
        super().__init__()
        out_features = 8

        self.net = []
        self.net.append(nn.Linear(in_features, hidden_features))
        self.net.append(nn.ReLU(inplace=True))
        for i in range(hidden_layers):
            self.net.append(nn.Linear(hidden_features, hidden_features))
            self.net.append(nn.ReLU(inplace=True))
        self.net.append(nn.Linear(hidden_features, out_features))

        self.net = nn.Sequential(*self.net)
        self.init_weights()

    def init_weights(self):
        # Initialize the final bias to the identity homography
        # (8 parameters of a 3x3 matrix with the last entry fixed to 1).
        with torch.no_grad():
            self.net[-1].bias.copy_(torch.Tensor([1., 0., 0., 0., 1., 0., 0., 0.]))

    def forward(self, coords):
        output = self.net(coords)
        return output
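# --- Usage sketch (illustrative, not part of the original module) --------------------
# A minimal example of fitting a Siren to a dense 2D signal. The hyperparameters
# (hidden size, depth, learning rate) and the random target below are assumptions,
# chosen only to show the expected coordinate format and a single optimization step.
def _siren_usage_example():
    # 64x64 grid of (x, y) coordinates normalized to [-1, 1], the range Siren expects.
    xs = torch.linspace(-1, 1, 64)
    coords = torch.stack(torch.meshgrid(xs, xs, indexing="ij"), dim=-1).reshape(-1, 2)
    target = torch.rand(coords.shape[0], 3)  # stand-in for RGB values at each coordinate

    model = Siren(in_features=2, hidden_features=256, hidden_layers=3,
                  out_features=3, outermost_linear=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    pred = model(coords)                      # (4096, 3)
    loss = ((pred - target) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()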
class Annealed(nn.Module):
    def __init__(self, in_channels, annealed_step, annealed_begin_step=0, identity=True):
        """
        Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
        in_channels: number of input channels (3 for both xyz and direction)
        """
        super(Annealed, self).__init__()
        self.N_freqs = 16
        self.in_channels = in_channels
        self.annealed = True
        self.annealed_step = annealed_step
        self.annealed_begin_step = annealed_begin_step

        self.index = torch.linspace(0, self.N_freqs - 1, self.N_freqs)
        self.identity = identity
        # Each frequency index is repeated twice so the weight vector matches an
        # embedding with two entries (e.g. sin/cos or two features) per frequency band.
        self.index_2 = self.index.view(-1, 1).repeat(1, 2).view(-1)

    def forward(self, x_embed, step):
        """
        Embeds x to (x, sin(2^k x), cos(2^k x), ...)
        Different from the paper, "x" is also in the output
        See https://github.com/bmild/nerf/issues/12

        Inputs:
            x: (B, self.in_channels)

        Outputs:
            out: (B, self.out_channels)
        """
        use_PE = False

        # alpha is the (fractional) number of frequency bands enabled at this step.
        if self.annealed_begin_step == 0:
            # calculate the w for each freq band
            alpha = self.N_freqs * step / float(self.annealed_step)
        else:
            if step <= self.annealed_begin_step:
                alpha = 0
            else:
                alpha = self.N_freqs * (step - self.annealed_begin_step) / float(self.annealed_step)

        # Cosine easing of each band's weight from 0 to 1 (coarse-to-fine schedule).
        w = (1 - torch.cos(math.pi * torch.clamp(alpha * torch.ones_like(self.index_2) - self.index_2,
                                                 0, 1))) / 2

        if use_PE:
            w[16:] = w[:16]

        out = x_embed * w.to(x_embed.device)

        return out


class BARF_PE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = tcnn.Encoding(n_input_dims=2,
                                     encoding_config=config["positional encoding"])
        self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + 2,
                                    n_output_dims=3,
                                    network_config=config["BARF network"])

    def forward(self, x, step=0, aneal_func=None):
        input = x
        input = self.encoder(input)
        # Concatenate the raw coordinates with the (optionally annealed) encoding.
        if aneal_func is not None:
            input = torch.cat([x, aneal_func(input, step)], dim=-1)
        else:
            input = torch.cat([x, input], dim=-1)

        weight = torch.ones(input.shape[-1], device=input.device)
        x = self.decoder(weight * input)

        return x


class Deform_Hash3d(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = tcnn.Encoding(n_input_dims=3,
                                     encoding_config=config["encoding_deform3d"])
        self.decoder = nn.Sequential(
            nn.Linear(self.encoder.n_output_dims + 3, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, 2),
        )

    def forward(self, x, step=0, aneal_func=None):
        input = x
        input = self.encoder(input)
        if aneal_func is not None:
            input = torch.cat([x, aneal_func(input, step)], dim=-1)
        else:
            input = torch.cat([x, input], dim=-1)

        weight = torch.ones(input.shape[-1], device=input.device)
        # Scale down the predicted 2D offsets to keep the deformation small.
        x = self.decoder(weight * input) / 5

        return x


class Deform_Hash3d_Warp(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.Deform_Hash3d = Deform_Hash3d(config)

    def forward(self, xyt_norm, step=0, aneal_func=None):
        x = self.Deform_Hash3d(xyt_norm, step=step, aneal_func=aneal_func)
        return x
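# --- Usage sketch (illustrative, not part of the original module) --------------------
# Demonstrates the coarse-to-fine schedule of Annealed. The 32-channel random embedding
# is an assumed stand-in for a 16-level encoding with 2 features per level (matching
# N_freqs = 16 and index_2, which repeats each level index twice); the step values are
# arbitrary and only illustrate how higher bands are enabled over time.
def _annealed_usage_example():
    annealer = Annealed(in_channels=3, annealed_step=1000, annealed_begin_step=0)
    x_embed = torch.randn(4, 32)  # (B, 16 bands * 2 channels per band)

    for step in [0, 250, 500, 1000]:
        out = annealer(x_embed, step)
        # Count how many of the 16 bands have non-zero weight at this step
        # (each band contributes 2 channels, hence the division by 2).
        active = int((out[0] != 0).sum().item()) // 2
        print(f"step {step:4d}: ~{active} of 16 frequency bands active")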