File size: 992 Bytes
e565538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import jax.numpy as jnp
import numpy as np
from einops import rearrange, repeat

# Taken from Ben Wang's mesh-transformer-jax implementation
# https://github.com/kingoflolz/mesh-transformer-jax


def fixed_pos_embedding(x, seq_dim=1, seq_len=None, position_ids=None):
    """Build the sinusoidal (sin, cos) tables used for rotary embeddings.

    Args:
        x: Array whose last axis size ``dim`` is the rotary feature size. When
            ``seq_len`` is None, its ``seq_dim`` axis gives the number of
            positions.
        seq_dim: Axis of ``x`` read for the sequence length when ``seq_len``
            is None.
        seq_len: Number of positions for the default ``position_ids``.
        position_ids: Positions to embed, with a leading axis plus any
            trailing axes. Defaults to ``np.arange(seq_len)``.

    Returns:
        ``(sin, cos)``. Each has shape
        ``position_ids.shape + (ceil(dim / 2),)``, and entry ``[..., k]`` is
        the sine/cosine of ``position * 10000 ** (-2k / dim)``.
    """
    rotary_dim = x.shape[-1]
    n_positions = x.shape[seq_dim] if seq_len is None else seq_len
    positions = np.arange(n_positions) if position_ids is None else position_ids

    # One inverse frequency per feature pair: 10000 ** (-2k / dim).
    exponents = np.arange(0, rotary_dim, 2) / rotary_dim
    inv_freq = 1.0 / (10000 ** exponents)
    # Outer product of every position with every inverse frequency.
    angles = jnp.einsum("i... , j -> i... j", positions, inv_freq)
    return jnp.sin(angles), jnp.cos(angles)


def rotate_every_two(x):
    """Rotate each interleaved feature pair of ``x`` by 90 degrees.

    Along the last axis, each pair ``(x[2k], x[2k + 1])`` becomes
    ``(-x[2k + 1], x[2k])``, e.g. ``[a, b, c, d] -> [-b, a, -d, c]``.

    Args:
        x: Array of any rank. Its last axis size must be even; an odd size
            makes the even/odd halves differ in length and the stack fails.

    Returns:
        Array with the same shape as ``x``.
    """
    # Ellipsis indexing works for any rank (the old form required rank 4).
    evens = x[..., ::2]
    odds = x[..., 1::2]
    # Stacking on a new trailing axis and then merging the last two axes
    # re-interleaves the pairs: (..., d/2, 2) -> (..., d).
    rotated = jnp.stack((-odds, evens), axis=-1)
    return rotated.reshape(rotated.shape[:-2] + (-1,))


def apply_rotary_pos_emb(x, sincos, offset=0):
    """Rotate the feature pairs of ``x`` by the angles encoded in ``sincos``.

    Args:
        x: Array assumed to be laid out as ``(batch, seq, heads, dim)``. The
            tables are sliced to ``x.shape[1]`` positions and broadcast over
            axis 2.
        sincos: ``(sin, cos)`` pair as produced by ``fixed_pos_embedding``.
            Each table has shape ``(batch, positions, dim // 2)``, or
            ``(positions, dim // 2)`` when built from the default
            ``position_ids``; a 2-D table is shared across the batch.
        offset: Index of the first table position applied to ``x[:, 0]``.

    Returns:
        ``x * cos + rotate_every_two(x) * sin``, broadcast from ``x`` and the
        tables.
    """
    seq_len = x.shape[1]

    def _expand(table):
        # Turn a (batch, positions, dim//2) table into (batch, seq, 1, dim).
        table = jnp.asarray(table)
        if table.ndim == 2:
            # Default position_ids give a table with no batch axis; add one so
            # the 3-index slice below is valid and broadcasts over the batch.
            table = table[None]
        table = table[:, offset : offset + seq_len, :]
        # Duplicate each frequency for both members of its pair
        # ([a, b] -> [a, a, b, b]) to match rotate_every_two's interleaving,
        # then insert a size-1 head axis for broadcasting against x.
        return jnp.repeat(table, 2, axis=-1)[:, :, None, :]

    sin, cos = (_expand(t) for t in sincos)
    return (x * cos) + (rotate_every_two(x) * sin)