import math
from dataclasses import dataclass

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization


@dataclass
class ModelArgs:
    # Default hyperparameters matching the released Phi-2 (2.7B) configuration.
    n_positions: int = 2048
    vocab_size: int = 51200
    n_embd: int = 2560
    n_head: int = 32
    n_layer: int = 32
    rotary_dim: int = 32


class RoPEAttention:
    # Multi-head attention with rotary position embeddings applied to the
    # first `rotary_dim` features of each head (partial rotary, as in Phi-2).
    def __init__(self, dims: int, n_head: int, rotary_dim: int):
        self.n_head = n_head

        self.q_proj = Dense(dims)
        self.k_proj = Dense(dims)
        self.v_proj = Dense(dims)
        self.dense = Dense(dims)

        self.rope = RoPE(rotary_dim, traditional=False)

    def __call__(self, x, mask=None, cache=None):
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        n_head = self.n_head
        B, L, D = queries.shape

        # Split the hidden dimension into heads: (B, L, D) -> (B, n_head, L, D // n_head).
        queries = tf.transpose(tf.reshape(queries, (B, L, n_head, -1)), (0, 2, 1, 3))
        keys = tf.transpose(tf.reshape(keys, (B, L, n_head, -1)), (0, 2, 1, 3))
        values = tf.transpose(tf.reshape(values, (B, L, n_head, -1)), (0, 2, 1, 3))

        if cache is not None:
            # Incremental decoding: rotate only the new positions (offset past
            # the cached length), then append to the cached keys and values.
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = tf.concat([key_cache, keys], axis=2)
            values = tf.concat([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Compute attention scores in float32 for numerical stability.
        queries = tf.cast(queries, tf.float32)
        keys = tf.cast(keys, tf.float32)

        scale = math.sqrt(1 / queries.shape[-1])
        scores = tf.matmul(queries * scale, tf.transpose(keys, (0, 1, 3, 2)))
        if mask is not None:
            scores = scores + mask

        scores = tf.cast(tf.nn.softmax(scores, axis=-1), values.dtype)
        values_hat = tf.reshape(tf.transpose(tf.matmul(scores, values), (0, 2, 1, 3)), (B, L, -1))

        return self.dense(values_hat), (keys, values)


class MLP:
    def __init__(self, dim, hidden_dim):
        self.fc1 = Dense(hidden_dim)
        self.fc2 = Dense(dim)

    def __call__(self, x):
        # tf.nn.gelu takes a boolean `approximate` flag; True selects the
        # tanh-approximated GELU ("gelu_new"), the activation used by Phi-2.
        return self.fc2(tf.nn.gelu(self.fc1(x), approximate=True))


class ParallelBlock:
    def __init__(self, config: ModelArgs):
        dims = config.n_embd
        mlp_dims = dims * 4
        self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
        self.input_layernorm = LayerNormalization()
        self.mlp = MLP(dims, mlp_dims)

    def __call__(self, x, mask, cache):
        # Phi-2 uses a parallel block: attention and MLP both read the same
        # layer-normalized input, and both outputs are added to the residual.
        h = self.input_layernorm(x)
        attn_h, cache = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x, cache


class Transformer:
    def __init__(self, config: ModelArgs):
        self.embed_tokens = Embedding(config.vocab_size, config.n_embd)
        self.layers = [ParallelBlock(config) for _ in range(config.n_layer)]
        self.final_layernorm = LayerNormalization()

    def __call__(self, x, mask, cache):
        x = self.embed_tokens(x)
        if cache is None:
            cache = [None] * len(self.layers)

        # Each layer returns its updated (keys, values) cache entry.
        for e, layer in enumerate(self.layers):
            x, cache[e] = layer(x, mask, cache[e])
        return self.final_layernorm(x), cache


class Phi2(Model):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.model = Transformer(config)
        self.lm_head = Dense(config.vocab_size)

    def __call__(self, x, mask=None, cache=None):
        # Build an additive causal mask when processing more than one token:
        # -inf above the diagonal, 0 on and below it. Any mask passed in is
        # ignored and rebuilt here. Since `x` holds integer token ids, the
        # mask is cast to float32, the dtype the attention scores use.
        mask = None
        if x.shape[1] > 1:
            mask = tf.fill((x.shape[1], x.shape[1]), float("-inf"))
            mask = tf.linalg.band_part(mask, 0, -1)
            mask = tf.linalg.set_diag(mask, tf.zeros(x.shape[1]))
            mask = tf.cast(mask, tf.float32)

        y, cache = self.model(x, mask, cache)
        return self.lm_head(y), cache


class RoPE:
    def __init__(self, dims: int, traditional: bool = False, base: float = 10000.0):
        self.dims = dims
        self.traditional = traditional
        self.base = base

    def _compute_rope(self, costheta, sintheta, x):
        # Rotate the first `dims` features; any remaining features pass
        # through unchanged (partial rotary application).
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = tf.concat([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = tf.concat([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = tf.concat([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    def __call__(self, x, offset: int = 0):
        # Flatten the leading (batch, head) dimensions, rotate, then restore
        # the original shape. `offset` shifts the positions for cached decoding.
        shape = x.shape
        x = tf.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return tf.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000.0,
        dtype=tf.float32,
    ):
        # Standard RoPE angles: theta[p, i] = p * base ** (-2 * i / D) for
        # positions p in [offset, N). math.log keeps the scale a Python float
        # so the expression works for any floating-point `dtype`.
        D = D // 2
        positions = tf.range(offset, N, dtype=dtype)
        freqs = tf.math.exp(
            -tf.range(0, D, dtype=dtype) * (math.log(base) / D)
        )
        theta = tf.reshape(positions, (-1, 1)) * tf.reshape(freqs, (1, -1))
        costheta = tf.math.cos(theta)
        sintheta = tf.math.sin(theta)

        return costheta, sintheta
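

if __name__ == "__main__":
    # Minimal usage sketch: build the model, run one forward pass, then take a
    # single decoding step reusing the returned KV cache. The scaled-down
    # ModelArgs values below are illustrative only (the defaults above match
    # Phi-2); weights are randomly initialized, so the logits are meaningless.
    config = ModelArgs(
        vocab_size=1000, n_embd=256, n_head=8, n_layer=2, rotary_dim=32
    )
    model = Phi2(config)

    # Batch of one sequence of 8 arbitrary token ids (each < config.vocab_size).
    tokens = tf.constant([[1, 15, 256, 7, 42, 3, 99, 5]], dtype=tf.int32)
    logits, cache = model(tokens)
    print(logits.shape)  # (1, 8, vocab_size): per-position vocabulary logits

    # Greedy next token, then one incremental step using the cache.
    next_token = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)
    logits, cache = model(next_token[:, None], cache=cache)
    print(logits.shape)  # (1, 1, vocab_size)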