# File: multi-hmr/blocks/camera_embed.py
# Revision: 934fdee ("first commit", author: fbaradel)
# Multi-HMR
# Copyright (c) 2024-present NAVER Corp.
# CC BY-NC-SA 4.0 license
import torch
from torch import nn
import numpy as np
class FourierPositionEncoding(nn.Module):
    """Fixed Fourier positional encoding; the module registers no learnable parameters."""

    def __init__(self, n, num_bands, max_resolution):
        """
        Store the encoding configuration.

        n: number of position coordinates to encode.
        num_bands: number of frequency bands used for each coordinate.
        max_resolution: maximum resolution, repeated for each of the n coordinates.
        """
        super().__init__()
        self.num_bands = num_bands
        # One resolution entry per coordinate; all coordinates share the same value.
        self.max_resolution = [max_resolution] * n

    @property
    def channels(self):
        """
        Size of the last dimension produced by forward().
        """
        dims = len(self.max_resolution)
        # Per coordinate: a sin and a cos value for every band, plus the raw coordinate itself.
        return dims * (2 * self.num_bands + 1)

    def forward(self, pos):
        """
        Encode the given positions (rays, per the original author's note) with Fourier features.
        """
        return _generate_fourier_features(
            pos,
            num_bands=self.num_bands,
            max_resolution=self.max_resolution,
        )
def _generate_fourier_features(pos, num_bands, max_resolution):
    """Generate Fourier features from a given set of positions and frequencies.

    Every coordinate of ``pos`` is multiplied by ``num_bands`` linearly spaced
    frequencies going from 1 to ``res / 2`` (one ``res`` per coordinate, taken
    from ``max_resolution``). The sine and cosine of ``pi`` times those
    products are concatenated after the raw positions along the last axis.

    Args:
        pos: tensor of shape (..., d); the last axis holds the coordinates and
            is expected to have d == len(max_resolution) entries. The usual
            input is (batch, n, d), but any number of leading dimensions
            (including an empty batch) is accepted.
        num_bands: number of frequencies sampled per coordinate.
        max_resolution: sequence holding one maximum resolution per coordinate.

    Returns:
        Tensor of shape (..., d + 2 * d * num_bands), laid out as
        [pos, sin features, cos features] on the last axis.
    """
    device = pos.device

    # Linear frequency sampling: one row of num_bands frequencies per coordinate.
    min_freq = 1.0
    freq_bands = torch.stack(
        [
            torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=device)
            for res in max_resolution
        ],
        dim=0,
    )

    # Broadcast (..., d, 1) against (d, num_bands) instead of looping over the
    # batch in Python; this also works when the batch dimension is empty.
    per_pos_features = pos.unsqueeze(-1) * freq_bands
    # Merge the (d, num_bands) axes. flatten() is used rather than
    # reshape(..., -1) because -1 cannot be inferred for zero-element tensors.
    per_pos_features = per_pos_features.flatten(-2)

    # Sin-Cos
    scaled = np.pi * per_pos_features
    per_pos_features = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)

    # Concat with initial pos
    return torch.cat([pos, per_pos_features], dim=-1)