from typing import Optional

import torch
from einops import rearrange
from torch import einsum, nn

__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']

|
class RotaryEmbedding(nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Args:
            dim (int): rotary embedding dimension
            seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
                by this factor via the trick in https://arxiv.org/abs/2306.15595.
            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before
                position interpolation.
        """
        super().__init__()
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        # Standard RoPE inverse frequencies: one frequency per pair of embedding channels.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
|
    def forward(self, max_seq_len, offset=0):
        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
        seq = seq.type_as(self.inv_freq)

        if self.pretrained_max_position_embeddings is not None and self.seq_len_interpolation_factor is not None:
            if max_seq_len > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor:
                # Dynamic scaling: the requested length exceeds the interpolated range,
                # so compress positions to fit the pretrained range.
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # Fixed position interpolation by the configured factor.
                seq *= 1 / self.seq_len_interpolation_factor

        freqs = einsum('i , j -> i j', seq, self.inv_freq)
        # Duplicate the angles so the embedding covers the full rotary dimension: [seq_len, dim].
        emb = torch.cat((freqs, freqs), dim=-1)
        # Shape [seq_len, 1, 1, dim]; apply_rotary_pos_emb permutes this so it broadcasts
        # over batch and heads.
        return rearrange(emb, 'n d -> n 1 1 d')
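    # Worked example of the interpolation above (illustrative numbers, not from the
    # original source): with pretrained_max_position_embeddings=2048 and
    # seq_len_interpolation_factor=4, any max_seq_len up to 2048 * 4 = 8192 takes the
    # fixed-scaling branch, mapping positions 0, 1, 2, ... to 0, 0.25, 0.5, ...;
    # beyond 8192 the dynamic branch rescales by 2048 / max_seq_len so positions
    # always land inside the pretrained range.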
|
def _rotate_half(x):
    """
    Change signs so the last dimension
    [A, B, C, D] -> [-C, -D, A, B]
    """
    # Split the last dimension into two halves, swap them, and negate the half that
    # moves to the front.
    x = rearrange(x, '... (j d) -> ... j d', j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
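# Concrete example (illustrative): _rotate_half(torch.tensor([1., 2., 3., 4.]))
# returns tensor([-3., -4., 1., 2.]): the last dimension is split into halves
# [1., 2.] and [3., 4.], and the second half is negated and moved in front.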
|
def apply_rotary_pos_emb(t, freqs):
    """
    Apply rotary positional embedding to an input tensor t of shape [..., seq_length, dim]
    (e.g. [batch, heads, seq_length, dim]). The rotary embedding tensor freqs comes from
    RotaryEmbedding.forward and has shape [seq_length, 1, 1, dim].
    Check https://kexue.fm/archives/8265 for detailed formulas.
    """
    # Move the sequence axis of freqs next to the feature dimension so it broadcasts
    # against t: [seq_length, 1, 1, dim] -> [1, 1, seq_length, dim].
    freqs = freqs.permute(1, 2, 0, 3)

    # freqs must cover at least as many positions as t; if it covers more, keep only the
    # trailing positions (e.g. when t holds just the newest tokens during incremental decoding).
    assert freqs.shape[-2] >= t.shape[-2]
    if freqs.shape[-2] != t.shape[-2]:
        freqs = freqs[:, :, -t.shape[-2]:, :]

    rot_dim = freqs.shape[-1]
    # Ideally t_pass is empty so the rotary embedding is applied to all of t.
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # Rotate by position: the first term is the cosine component, the second is the sine
    # component with signs flipped pairwise by _rotate_half.
    t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
    return torch.cat((t, t_pass), dim=-1)
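

# Minimal usage sketch (not part of the original module): builds the rotary table and
# applies it to a dummy query tensor laid out as [batch, heads, seq, head_dim], the
# layout expected by the permute/broadcast logic in apply_rotary_pos_emb. All shapes
# here are illustrative.
if __name__ == '__main__':
    rotary = RotaryEmbedding(dim=64)
    freqs = rotary(max_seq_len=128)            # [128, 1, 1, 64]
    q = torch.randn(2, 8, 128, 64)             # [batch, heads, seq, head_dim]
    q_rotated = apply_rotary_pos_emb(q, freqs)
    assert q_rotated.shape == q.shape          # rotation preserves the shape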
|
|