|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
import numpy as np |
|
|
|
class FourierPositionEncoding(nn.Module):
    """Fixed (parameter-free) Fourier positional encoding module.

    The module registers no parameters or buffers; it only stores the
    settings that are handed to ``_generate_fourier_features`` on every call.
    """

    def __init__(self, n, num_bands, max_resolution):
        """Store the encoding configuration.

        Args:
            n: Number of coordinate axes; ``max_resolution`` is repeated once
                per axis.
            num_bands: Number of frequency bands generated per axis.
            max_resolution: Resolution value shared by all ``n`` axes.
        """
        super().__init__()

        # The same resolution is used for every one of the n axes.
        self.max_resolution = [max_resolution for _ in range(n)]
        self.num_bands = num_bands

    @property
    def channels(self):
        """Size of the last axis of the encoding returned by ``forward``.

        Each axis contributes ``num_bands`` sine and ``num_bands`` cosine
        features, plus one channel for the raw coordinate itself.
        """
        axis_count = len(self.max_resolution)
        return 2 * self.num_bands * axis_count + axis_count

    def forward(self, pos):
        """Return the Fourier positional encoding of ``pos``.

        ``pos`` is forwarded unchanged to ``_generate_fourier_features``
        together with the stored ``num_bands`` and ``max_resolution``
        (the original documentation describes the inputs as rays).
        """
        return _generate_fourier_features(
            pos,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
        )
|
|
|
|
|
def _generate_fourier_features(pos, num_bands, max_resolution):
    """Generate Fourier features from a set of positions and frequencies.

    For every coordinate axis, ``num_bands`` frequencies are spaced linearly
    between 1.0 and half of that axis' entry in ``max_resolution``. Each
    coordinate is multiplied by the frequencies of its axis, and the sine and
    cosine of ``pi`` times those products are appended to the raw positions.

    Args:
        pos: Tensor whose last axis holds the coordinates; it is broadcast
            against the ``len(max_resolution)`` frequency rows. Any number of
            leading axes is accepted, and they may be empty.
        num_bands: Number of frequencies generated per coordinate axis.
        max_resolution: Iterable with one resolution value per coordinate
            axis.

    Returns:
        Tensor with the leading axes of ``pos`` and a last axis laid out as
        ``[pos, sin features, cos features]``. The sin and cos parts each hold
        ``len(max_resolution) * num_bands`` values, ordered axis by axis.
    """
    min_freq = 1.0

    # One row of linearly spaced frequencies per coordinate axis:
    # shape (num_dims, num_bands).
    freq_bands = torch.stack(
        [
            torch.linspace(
                start=min_freq,
                end=res / 2,
                steps=num_bands,
                device=pos.device,
            )
            for res in max_resolution
        ],
        dim=0,
    )

    # Broadcasting (..., num_dims, 1) * (num_dims, num_bands) replaces the
    # former per-sample Python loop, which failed on an empty batch because
    # torch.stack rejects an empty list.
    scaled = pos[..., None] * freq_bands

    # flatten() computes the merged size explicitly, so it also works for
    # zero-element tensors where reshape(..., -1) is ambiguous.
    scaled = scaled.flatten(start_dim=-2)

    # Compute the phase once and reuse it for both sin and cos.
    phase = np.pi * scaled

    return torch.cat([pos, torch.sin(phase), torch.cos(phase)], dim=-1)