from typing import Callable from numpy import pi import torch import torch.nn as nn import numpy as np import torch.autograd.profiler as profiler # TODO: rethink encoding mode def encoding_mode( encoding_mode: str, d_min: float, d_max: float, inv_z: bool, EPS: float ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: def _z(xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor) -> torch.Tensor: if inv_z: z = (1 / z.clamp_min(EPS) - 1 / d_max) / (1 / d_min - 1 / d_max) else: z = (z - d_min) / (d_max - d_min) z = 2 * z - 1 return torch.cat( (xy, z), dim=-1 ) ## concatenates the normalized x, y, and z coordinates def _distance(xy: torch.Tensor, z: torch.Tensor, distance: torch.Tensor): if inv_z: distance = (1 / distance.clamp_min(EPS) - 1 / d_max) / ( 1 / d_min - 1 / d_max ) else: distance = (distance - d_min) / (d_max - d_min) distance = 2 * distance - 1 return torch.cat( (xy, distance), dim=-1 ) ## Apply the positional encoder to the concatenated xy and depth/distance coordinates (it enables the model to capture more complex spatial dependencies without a significant increase in model complexity or training data) match encoding_mode: case "z": return _z case "distance": return _distance case _: return _z class PositionalEncoding(torch.nn.Module): """ Implement NeRF's positional encoding """ def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True): super().__init__() self.num_freqs = num_freqs self.d_in = d_in self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs) self.d_out = self.num_freqs * 2 * d_in self.include_input = include_input if include_input: self.d_out += d_in # f1 f1 f2 f2 ... to multiply x by self.register_buffer( "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1) ) # 0 pi/2 0 pi/2 ... so that # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...) _phases = torch.zeros(2 * self.num_freqs) _phases[1::2] = np.pi * 0.5 self.register_buffer("_phases", _phases.view(1, -1, 1)) def forward(self, x): """ Apply positional encoding (new implementation) :param x (batch, self.d_in) :return (batch, self.d_out) """ with profiler.record_function("positional_enc"): embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1) embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs)) embed = embed.view(x.shape[0], -1) if self.include_input: embed = torch.cat((x, embed), dim=-1) return embed @classmethod def from_conf(cls, conf, d_in=3): # PyHocon construction return cls( conf.get("num_freqs", 6), d_in, conf.get("freq_factor", np.pi), conf.get("include_input", True), ) def token_decoding(filter: nn.Module, pos_offset: float = 0.0): def _decode(xyz: torch.Tensor, tokens: torch.Tensor): """Decode tokens into density for given points Args: x (torch.Tensor): points in xyz n_pts, 3 tokens (torch.Tensor): tokens n_pts, n_tokens, d_in + 2 """ n_pts, n_tokens = tokens.shape with profiler.record_function("positional_enc"): z = xyz[..., 3] scale = tokens[..., 0] # n_pts, n_tokens token_pos_offset = tokens[..., 1] # n_pts, n_tokens weights = tokens[..., 2:] # n_pts, n_tokens, d_in positions = ( 2.0 * (z.unsqueeze(1).unsqueeze(2).repeat(1, n_tokens) - token_pos_offset) / scale - 1.0 ) # n_pts, n_tokens ((z - t_o) / s) * 2.0 - 1.0 t_o => -1.0 t_o + s => 1.0 individual_densities = filter(positions, weights) # n_pts, n_tokens densities = individual_densities.sum(-1) # n_pts return densities return _decode class FourierFilter(nn.Module): # TODO: add filter functions def __init__( self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True, filter_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, ): super().__init__() self.num_freqs = num_freqs self.d_in = d_in self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs) self.register_buffer( "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1) ) # 0 pi/2 0 pi/2 ... so that # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...) _phases = torch.zeros(2 * self.num_freqs) _phases[1::2] = np.pi * 0.5 self.register_buffer("_phases", _phases.view(1, -1, 1)) self.filter_fn = filter_fn def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: """Predict density for given normalized points using Fourier features Args: positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens) weights (torch.Tensor): weights for each point (n_pts, n_tokens, num_freqs * 2) Returns: torch.Tensor: aggregated density for each point (n_pts) """ with profiler.record_function("positional_enc"): positions = positions.unsqueeze(1).repeat( 1, self.num_freqs * 2, 1 ) # n_pts, num_freqs * 2, n_tokens densities = weights.permute(0, 2, 1) * torch.sin( torch.addcmul(self._phases, positions, self._freqs) ) # n_pts, num_freqs * 2, n_tokens if self.filter_fn is not None: densities = self.filter_fn(densities, positions) return densities.sum(-2) # n_pts, n_tokens @classmethod def from_conf(cls, conf, d_in=3): # PyHocon construction return cls( conf.get("num_freqs", 6), d_in, conf.get("freq_factor", np.pi), ) class LogisticFilter(nn.Module): def __init__(self, slope: float) -> None: super().__init__() self.slope = slope def forward(self, positions: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: """Predict the density as sum of weighted logistic functions Args: positions (torch.Tensor): normalized positions between -1 and 1, (n_pts, n_tokens) weights (torch.Tensor): weights for each point (n_pts, n_tokens, d_in) Returns: torch.Tensor: density for each point (n_pts, n_tokens) """ with profiler.record_function("positional_enc"): weights = weights.squeeze(-1) # n_pts, n_tokens sigmoid_pos = self.slope * positions + 1.0 return ( weights * torch.sigmoid(sigmoid_pos) * torch.sigmoid(-sigmoid_pos) ) # n_pts, n_tokens @classmethod def from_conf(cls, conf): # PyHocon construction return cls(conf.get("slope", 10.0))