import torch


class RotaryPositionalEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half):
        """Rotary positional embedding.

        Reference: https://blog.eleuther.ai/rotary-embeddings/
        Paper: https://arxiv.org/pdf/2104.09864.pdf

        Args:
            dim: Dimension of embedding. The cached tables have a last
                dimension of ``2 * ceil(dim / 2)`` (equal to ``dim`` when
                ``dim`` is even).
            base: Base value for the exponential frequency schedule.
            precision: Stored as ``self.precision``. NOTE(review): it is not
                applied anywhere in this module; the cos/sin tables keep the
                dtype of ``inv_freq``. Confirm whether callers rely on it.
        """
        super().__init__()
        # One inverse frequency per channel pair: base ** (-2i / dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = 0
        # These are plain attributes, not buffers, so Module.to() does not
        # move them; forward() rebuilds them when they are on the wrong device.
        self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
        self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
        self.precision = precision

    def _build_cache(self, seq_len, device):
        """Recompute the cos/sin tables for positions ``0..seq_len-1`` on ``device``.

        Both tables are stored with shape ``(seq_len, 1, 1, emb_dim)``.
        """
        inv_freq = self.inv_freq.to(device)
        t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        # Outer product: freqs[i, j] = t[i] * inv_freq[j].
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate so each frequency covers both halves used by rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
        self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))

    def forward(self, x, seq_len: int = 0):
        """Return the cached ``(cos, sin)`` tables, growing or moving them if needed.

        Args:
            x: Input x with T X B X C. Only its device is read here.
            seq_len: Sequence length of input x. The cache only grows, so the
                returned tables may be longer than ``seq_len``; callers slice
                them (see ``apply_rotary_pos_emb``).

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len_cached, 1, 1, emb_dim)``
            on ``x.device``.
        """
        stale_device = self.cos_cached.device != x.device
        if seq_len > self.seq_len_cached or stale_device:
            # Never shrink: keep serving previously requested lengths.
            self.seq_len_cached = max(seq_len, self.seq_len_cached)
            self._build_cache(self.seq_len_cached, x.device)
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:


def rotate_half(x):
    """Split the last dimension into halves ``(x1, x2)`` and return ``cat(-x2, x1)``."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in earlier torch versions


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Rotate ``q`` and ``k`` by the positional ``cos``/``sin`` tables.

    The tables are sliced along dim 0 to rows ``[offset, offset + q.shape[0])``.
    The same slice is used for ``k``; presumably ``k`` has the same leading
    length as ``q``. TODO confirm with callers.

    Returns:
        Tuple ``(q_rotated, k_rotated)``.
    """
    cos, sin = (
        cos[offset : q.shape[0] + offset, ...],
        sin[offset : q.shape[0] + offset, ...],
    )
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)