alverciito
fix positional encoding error
6faa82b
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
class PositionalEncoding(torch.nn.Module):
    """
    Sinusoidal positional encoding module for Transformer models.

    This module injects information about the absolute position of tokens in
    a sequence by adding fixed sinusoidal embeddings to the input embeddings.
    The positional encodings are non-learnable and follow the formulation
    introduced in the original Transformer architecture:

        PE[pos, 2i]     = sin(pos / 10000^(2i / emb_dim))
        PE[pos, 2i + 1] = cos(pos / 10000^(2i / emb_dim))

    The table is stored in the ``positional_encoding`` buffer with shape
    (1, max_len, emb_dim).
    """

    def __init__(self, emb_dim: int, max_len: int = 5000, **kwargs):
        """
        Initialize the positional encoding module.

        Parameters
        ----------
        emb_dim : int
            Dimensionality of the embedding space. Odd values are supported;
            the last column then carries a sine term only.
        max_len : int, optional
            Maximum supported sequence length for which positional encodings
            are precomputed.
        **kwargs
            Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        # Create positional encodings:
        pe = torch.zeros(max_len, emb_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(emb_dim / 2) entries,
        # div_term[i] = 10000^(-2i / emb_dim).
        div_term = torch.exp(
            torch.arange(0, emb_dim, 2).float()
            * -(torch.log(torch.tensor(10000.0)) / emb_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only emb_dim // 2 odd columns. When emb_dim is odd, this is
        # one fewer than the number of frequencies, so trim div_term to match.
        # Without the trim the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: emb_dim // 2])
        pe = pe.unsqueeze(0)
        # Register as a buffer. It follows .to()/.cuda() and is saved in the
        # state_dict, but it is not a trainable parameter.
        self.register_buffer('positional_encoding', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to the input embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (..., sequence_length, emb_dim). Typical
            layouts are batched (batch_size, sequence_length, emb_dim) or
            unbatched (sequence_length, emb_dim).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as the input with positional encodings
            added.

        Raises
        ------
        ValueError
            If sequence_length exceeds the ``max_len`` the module was built
            with.
        """
        seq_len = x.size(-2)
        max_len = self.positional_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported "
                f"length {max_len} of PositionalEncoding."
            )
        # Slice to (seq_len, emb_dim) so the table broadcasts over any number
        # of leading dims. This avoids adding a batch dim to unbatched input.
        return x + self.positional_encoding[0, :seq_len]
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #