# File: SATO/models/pos_encoding.py
# Provenance (from the hosting page): uploaded by chencws in "Upload 36 files",
# commit 23586c3 (verified), 1.37 kB.
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
def PE1d_sincos(seq_length, dim):
    """
    Build a fixed sinusoidal (sin/cos) positional-encoding table.

    For position ``p`` and channel pair ``i`` (``0 <= i < dim // 2``)::

        pe[p, 0, 2*i]     = sin(p / 10000 ** (2*i / dim))
        pe[p, 0, 2*i + 1] = cos(p / 10000 ** (2*i / dim))

    :param seq_length: number of positions (rows) in the table
    :param dim: encoding dimension; must be even so sin/cos columns pair up
    :return: float tensor of shape ``(seq_length, 1, dim)``; the singleton
        middle axis lets the table broadcast over a batch dimension
    :raises ValueError: if ``dim`` is odd
    """
    if dim % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(dim))
    pe = torch.zeros(seq_length, dim)
    # Column vector of positions, shape (seq_length, 1).
    position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)
    # 1 / 10000^(2i/dim), computed in log space for numerical stability.
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float)
                         * -(math.log(10000.0) / dim))
    # Outer product -> (seq_length, dim // 2) angle matrix, shared by sin and cos.
    angles = position * div_term
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe.unsqueeze(1)
class PositionEmbedding(nn.Module):
    """
    Absolute positional embedding initialised from a sin/cos table.

    The table is stored as the parameter ``embed`` with shape
    ``(seq_length, 1, dim)``. With the default ``grad=False`` it is frozen,
    i.e. a fixed sinusoidal encoding; pass ``grad=True`` to make it learnable.

    :param seq_length: maximum sequence length ``forward`` can handle
    :param dim: feature dimension (must be even, see ``PE1d_sincos``)
    :param dropout: dropout probability applied after adding the embedding
    :param grad: whether the positional table receives gradients
    """

    def __init__(self, seq_length, dim, dropout, grad=False):
        super().__init__()
        # Kept as an nn.Parameter named "embed" so existing checkpoints still load.
        self.embed = nn.Parameter(data=PE1d_sincos(seq_length, dim), requires_grad=grad)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """
        Add the first ``x.shape[1]`` positional rows to ``x`` and apply dropout.

        :param x: tensor of shape (batch, seq_len, feat_dim)
        :return: tensor with the same shape as ``x``
        :raises ValueError: if ``seq_len`` exceeds the table's length
        """
        seq_len = x.shape[1]
        max_len = self.embed.shape[0]
        if seq_len > max_len:
            raise ValueError(
                "Sequence length {:d} exceeds positional table length {:d}".format(
                    seq_len, max_len))
        # embed[:seq_len] is (seq_len, 1, dim); transposing it to (1, seq_len, dim)
        # lets broadcasting add the same positional rows to every batch item.
        x = x + self.embed[:seq_len].transpose(0, 1)
        return self.dropout(x)