|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class FrequencyPositionalEmbedding(nn.Module):
    """Sin/cosine positional embedding.

    Given an input tensor `x` of shape [..., input_dim], every scalar feature
    `x[..., i]` is multiplied by each frequency `f_0, ..., f_{N-1}`
    (N = num_freqs) and passed through sin and cos. The output is the
    concatenation along the last axis of:

        [
            x,                               # only if include_input is True
            sin(f_0 * x[..., 0]), ..., sin(f_{N-1} * x[..., 0]),
            sin(f_0 * x[..., 1]), ..., sin(f_{N-1} * x[..., 1]),
            ...
            cos(f_0 * x[..., 0]), ..., cos(f_{N-1} * x[..., 0]),
            cos(f_0 * x[..., 1]), ..., cos(f_{N-1} * x[..., 1]),
            ...
        ]

    i.e. the raw input comes first, then one sin block, then one cos block;
    inside each block the frequencies of one input feature are contiguous.

    If `num_freqs` is 0 the module is the identity: `forward` returns `x`
    unchanged regardless of `include_input`.

    Args:
        num_freqs (int): the number of frequencies, default is 6.
        logspace (bool): if True, the frequencies are the powers of two
            [2^0, 2^1, ..., 2^(num_freqs - 1)]; otherwise they are
            `num_freqs` values linearly spaced in [1.0, 2^(num_freqs - 1)].
            Default is True.
        input_dim (int): the size of the last axis of the input, used only
            to compute `out_dim`; default is 3.
        include_input (bool): prepend the raw input to the embedding or not,
            default is True.
        include_pi (bool): multiply every frequency by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): float32 tensor of shape [num_freqs]
            holding the frequencies described above (already scaled by pi
            when `include_pi` is True). Registered as a non-persistent
            buffer, so it follows the module across devices but is not
            stored in the state dict.
        include_input (bool): copy of the constructor argument.
        num_freqs (int): copy of the constructor argument.
        out_dim (int): the embedding size, see `get_dims`.
    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """Build the frequency table and pre-compute the output size."""
        super().__init__()

        if logspace:
            # Powers of two: 1, 2, 4, ..., 2^(num_freqs - 1).
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            # Same end points as the log-spaced table, evenly spaced between.
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        # Non-persistent: the table is fully determined by the constructor
        # arguments, so it is rebuilt on construction rather than saved.
        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim: int) -> int:
        """Return the size of the last axis produced by `forward`.

        Args:
            input_dim: size of the last axis of the input tensor.

        Returns:
            `input_dim * (2 * num_freqs + 1)` when the raw input is part of
            the output (include_input is True, or num_freqs is 0 so that
            `forward` is the identity), otherwise `input_dim * 2 * num_freqs`.
        """
        # With zero frequencies `forward` returns `x` itself, so the raw input
        # counts even when include_input is False.
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape
                [..., dim * (num_freqs * 2 + temp)], where temp is 1 if
                include_input is True (or num_freqs is 0) and 0 otherwise.
        """
        if self.num_freqs <= 0:
            return x

        # [..., dim, 1] * [num_freqs] -> [..., dim, num_freqs], then merge the
        # last two axes. `flatten` computes the merged size explicitly, so it
        # also works for zero-element inputs (e.g. an empty batch), where
        # `view(..., -1)` cannot infer the size and raises.
        embed = (x.unsqueeze(-1) * self.frequencies).flatten(-2)

        parts = (embed.sin(), embed.cos())
        if self.include_input:
            parts = (x,) + parts
        return torch.cat(parts, dim=-1)
|
|