|
import jax.numpy as jnp |
|
import numpy as np |
|
from einops import rearrange, repeat |
|
|
|
|
|
|
|
|
|
|
|
def fixed_pos_embedding(x, seq_dim=1, seq_len=None, position_ids=None, base=10000):
    """Build the sine/cosine tables used by rotary position embeddings.

    Frequencies follow ``inv_freq[k] = 1 / base ** (2k / dim)`` for
    ``k = 0 .. ceil(dim / 2) - 1``, where ``dim = x.shape[-1]``. Each table
    entry is ``position * inv_freq[k]`` passed through ``sin`` / ``cos``.

    Args:
        x: Array whose shape supplies ``dim`` (last axis) and, when needed,
            the sequence length. Its values are never read.
        seq_dim: Axis of ``x`` taken as the sequence length when neither
            ``seq_len`` nor ``position_ids`` is given.
        seq_len: Explicit sequence length; only used when ``position_ids``
            is omitted.
        position_ids: Positions to embed. Defaults to ``arange(seq_len)``
            (1-D), which yields unbatched ``(seq_len, ceil(dim / 2))`` tables.
            A ``(batch, seq)`` array yields ``(batch, seq, ceil(dim / 2))``.
        base: Base of the geometric frequency progression. The default
            ``10000`` matches the value this function previously hard-coded.

    Returns:
        Tuple ``(sin, cos)``, each of shape
        ``position_ids.shape + (ceil(dim / 2),)``.
    """
    dim = x.shape[-1]

    if position_ids is None:
        # The sequence length is only needed to synthesise default positions.
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        position_ids = np.arange(seq_len)

    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))

    # Outer product over the trailing axis: keeps all position axes ("i...")
    # and appends one frequency axis ("j").
    sinusoid_inp = jnp.einsum("i... , j -> i... j", position_ids, inv_freq)

    return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)
|
|
|
|
|
def rotate_every_two(x):
    """Rotate each adjacent feature pair on the last axis by 90 degrees.

    Every pair ``(x[..., 2k], x[..., 2k + 1])`` becomes
    ``(-x[..., 2k + 1], x[..., 2k])``. Works for inputs of any rank; the
    last axis must have even length (otherwise ``jnp.stack`` rejects the
    mismatched halves).

    Args:
        x: Array whose last axis holds interleaved feature pairs.

    Returns:
        Array with the same shape as ``x``.
    """
    evens = x[..., ::2]
    odds = x[..., 1::2]
    # Stack on a new trailing axis, then merge it back in C order so the
    # output interleaves as [-o0, e0, -o1, e1, ...].
    return jnp.stack((-odds, evens), axis=-1).reshape(x.shape)
|
|
|
|
|
def apply_rotary_pos_emb(x, sincos, offset=0):
    """Apply a rotary position embedding to ``x``.

    Computes ``x * cos + rotate_every_two(x) * sin``, where each table is
    sliced to the ``x.shape[1]`` positions starting at ``offset`` and
    expanded to broadcast against ``x``.

    Args:
        x: Array broadcast against tables of layout
            ``(batch, seq, 1, dim)`` — presumably ``(batch, seq, heads,
            dim)``; TODO confirm against callers.
        sincos: ``(sin, cos)`` pair, e.g. from ``fixed_pos_embedding``. Each
            table is either batched ``(batch, positions, dim // 2)`` or
            unbatched ``(positions, dim // 2)``.
        offset: Index of the first table position to use for ``x[:, 0]``.

    Returns:
        Array broadcast from ``x`` and the expanded tables.
    """
    seq_len = x.shape[1]

    def _expand(table):
        # Reshape one sin/cos table to (batch, seq, 1, dim) for broadcasting.
        if table.ndim == 2:
            # fixed_pos_embedding's default positions produce an unbatched
            # table; add a size-1 batch axis that broadcasts over x's batch.
            table = table[None, ...]
        table = table[:, offset : offset + seq_len, :]
        # Duplicate each frequency for both members of its feature pair
        # ([a, b] -> [a, a, b, b]) and insert a singleton axis at position 2.
        return jnp.repeat(table, 2, axis=-1)[:, :, None, :]

    sin, cos = (_expand(table) for table in sincos)
    return (x * cos) + (rotate_every_two(x) * sin)
|
|