import math
import numpy as np

import torch
import torch.nn as nn
import tinycudann as tcnn

from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info

from utils.misc import config_to_primitive, get_rank
from models.utils import get_activation
from systems.utils import update_module_step


class VanillaFrequency(nn.Module):
    """NeRF-style sinusoidal frequency encoding with an optional progressive (coarse-to-fine) mask."""
    def __init__(self, in_channels, config):
        super().__init__()
        self.N_freqs = config['n_frequencies']
        self.in_channels, self.n_input_dims = in_channels, in_channels
        self.funcs = [torch.sin, torch.cos]
        self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs)
        self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs)
        self.n_masking_step = config.get('n_masking_step', 0)
        self.update_step(None, None)  # initialize self.mask

    def forward(self, x):
        out = []
        for freq, mask in zip(self.freq_bands, self.mask):
            for func in self.funcs:
                out += [func(freq * x) * mask]
        return torch.cat(out, -1)

    def update_step(self, epoch, global_step):
        if self.n_masking_step <= 0 or global_step is None:
            self.mask = torch.ones(self.N_freqs, dtype=torch.float32)
        else:
            # Coarse-to-fine schedule: each frequency band is smoothly unmasked over training,
            # following a cosine easing from 0 to 1.
            self.mask = (1. - torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2.
            rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}')
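
# Example for VanillaFrequency (illustrative values, not from any shipped config):
# with in_channels=3 and n_frequencies=4, freq_bands = [1, 2, 4, 8] and the encoding maps
# x in R^3 to [sin(x), cos(x), sin(2x), cos(2x), ...], i.e. 3 * 2 * 4 = 24 output channels,
# matching n_output_dims.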


class ProgressiveBandHashGrid(nn.Module):
    """tcnn HashGrid whose levels are progressively unmasked during training."""
    def __init__(self, in_channels, config):
        super().__init__()
        self.n_input_dims = in_channels
        encoding_config = config.copy()
        encoding_config['otype'] = 'HashGrid'
        with torch.cuda.device(get_rank()):
            self.encoding = tcnn.Encoding(in_channels, encoding_config)
        self.n_output_dims = self.encoding.n_output_dims
        self.n_level = config['n_levels']
        self.n_features_per_level = config['n_features_per_level']
        self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps']
        self.current_level = self.start_level
        self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank())

    def forward(self, x):
        enc = self.encoding(x)
        enc = enc * self.mask
        return enc

    def update_step(self, epoch, global_step):
        current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level)
        if current_level > self.current_level:
            rank_zero_info(f'Update grid level to {current_level}')
        self.current_level = current_level
        # unmask the features of all levels up to (and including) the current level
        self.mask[:self.current_level * self.n_features_per_level] = 1.
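
# Example schedule for ProgressiveBandHashGrid (illustrative values): with start_level=4,
# start_step=0, update_steps=1000 and n_level=16, the first 4 grid levels are active from
# the start and one additional level is unmasked every 1000 steps, so all 16 levels are
# active after 12000 steps.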


class CompositeEncoding(nn.Module):
    def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.):
        super(CompositeEncoding, self).__init__()
        self.encoding = encoding
        self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset
        self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims

    def forward(self, x, *args):
        return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1)

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)


def get_encoding(n_input_dims, config):
    # inputs are assumed to lie in [0, 1]; xyz_scale=2 and xyz_offset=-1 remap them to [-1, 1]
    # when the raw coordinates are concatenated to the encoding
    if config.otype == 'VanillaFrequency':
        encoding = VanillaFrequency(n_input_dims, config_to_primitive(config))
    elif config.otype == 'ProgressiveBandHashGrid':
        encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config))
    else:
        with torch.cuda.device(get_rank()):
            encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config))
    encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.)
    return encoding
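
# Example for get_encoding (illustrative figures): a tcnn HashGrid with 16 levels and
# 2 features per level produces a 32-dim feature; with include_xyz=True the composite
# encoding prepends the rescaled xyz, giving 3 + 32 = 35 output dims.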


class VanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, config):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers']
        self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False)
        self.sphere_init_radius = config.get('sphere_init_radius', 0.5)
        self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
        for i in range(self.n_hidden_layers - 1):
            self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
        self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
        self.layers = nn.Sequential(*self.layers)
        self.output_activation = get_activation(config['output_activation'])

    @torch.cuda.amp.autocast(False)
    def forward(self, x):
        x = self.layers(x.float())
        x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=True)
        if self.sphere_init:
            # Geometric initialization so the network starts out approximating the SDF of a
            # sphere with radius sphere_init_radius (SAL/IGR-style initialization).
            if is_last:
                torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
                torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
            elif is_first:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.constant_(layer.weight[:, 3:], 0.0)  # input channels beyond xyz start with zero weight
                torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
            else:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        if self.sphere_init:
            return nn.Softplus(beta=100)
        else:
            return nn.ReLU(inplace=True)


def sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network):
    """
    Initialize a tcnn MLP so that it approximately represents a sphere.

    Parameter layout (from https://github.com/NVlabs/tiny-cuda-nn/issues/96): the flat
    parameter vector is the weight matrices of each layer laid out in row-major order and
    then concatenated. Notably, input and output dimensions are padded to multiples of
    8 (CutlassMLP) or 16 (FullyFusedMLP). The padded input dimensions get a constant value
    of 1.0, whereas the padded output dimensions are simply ignored, so the weights
    pertaining to those can have any value.
    """
    rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.')
    padto = 16 if config.otype == 'FullyFusedMLP' else 8
    n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto
    n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto
    data = list(network.parameters())[0].data
    assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2
    new_data = []

    # first layer: only the xyz channels get non-zero weights
    weight = torch.zeros((config.n_neurons, n_input_dims)).to(data)
    torch.nn.init.constant_(weight[:, 3:], 0.0)
    torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
    new_data.append(weight.flatten())

    # hidden layers
    for i in range(config.n_hidden_layers - 1):
        weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data)
        torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
        new_data.append(weight.flatten())

    # last layer: small weights around a positive mean plus negative bias-free offset give a sphere-like SDF
    weight = torch.zeros((n_output_dims, config.n_neurons)).to(data)
    torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001)
    new_data.append(weight.flatten())

    new_data = torch.cat(new_data)
    data.copy_(new_data)


def get_mlp(n_input_dims, n_output_dims, config):
    if config.otype == 'VanillaMLP':
        network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config))
    else:
        with torch.cuda.device(get_rank()):
            network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config))
            if config.get('sphere_init', False):
                sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network)
    return network


class EncodingWithNetwork(nn.Module):
    def __init__(self, encoding, network):
        super().__init__()
        self.encoding, self.network = encoding, network

    def forward(self, x):
        return self.network(self.encoding(x))

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)
        update_module_step(self.network, epoch, global_step)


def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config):
    if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \
            or network_config.otype in ['VanillaMLP']:
        encoding = get_encoding(n_input_dims, encoding_config)
        network = get_mlp(encoding.n_output_dims, n_output_dims, network_config)
        encoding_with_network = EncodingWithNetwork(encoding, network)
    else:
        with torch.cuda.device(get_rank()):
            encoding_with_network = tcnn.NetworkWithInputEncoding(
                n_input_dims=n_input_dims,
                n_output_dims=n_output_dims,
                encoding_config=config_to_primitive(encoding_config),
                network_config=config_to_primitive(network_config)
            )
    return encoding_with_network
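

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the original module). The config
    # fields below are assumptions modelled on this repo's YAML configs; configs are assumed
    # to be OmegaConf nodes, since both attribute access (`config.otype`) and `config.get`
    # are used above. The pure-PyTorch branch (VanillaFrequency + VanillaMLP) is exercised so
    # no GPU kernels run, although tinycudann must still be importable. Run from the repo
    # root, e.g. `python -m models.network_utils` (module path assumed).
    from omegaconf import OmegaConf

    encoding_config = OmegaConf.create({'otype': 'VanillaFrequency', 'n_frequencies': 6})
    network_config = OmegaConf.create({
        'otype': 'VanillaMLP',
        'n_neurons': 64,
        'n_hidden_layers': 2,
        'output_activation': None,  # assumed to map to an identity in models.utils.get_activation
    })
    model = get_encoding_with_network(3, 1, encoding_config, network_config)
    out = model(torch.rand(8, 3))
    print(out.shape)  # expected: torch.Size([8, 1])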