# Phi2-Keras / Phi2.py
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, Embedding
from tensorflow.keras import Model
import math
from dataclasses import dataclass

@dataclass
class ModelArgs:
    n_positions: int = 2048
    vocab_size: int = 51200
    n_embd: int = 2560
    n_head: int = 32
    n_layer: int = 32
    rotary_dim: int = 32
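
# Each of the 32 heads works on n_embd // n_head = 2560 // 32 = 80 dims, while
# rotary_dim = 32 means RoPE rotates only the first 32 dims of every head and
# passes the remaining 48 through unchanged (the partial-rotary scheme Phi-2
# uses; see RoPE._compute_rope below).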

class RoPEAttention:
    def __init__(self, dims: int, n_head: int, rotary_dim: int):
        self.n_head = n_head
        self.q_proj = Dense(dims)
        self.k_proj = Dense(dims)
        self.v_proj = Dense(dims)
        self.dense = Dense(dims)
        self.rope = RoPE(rotary_dim, traditional=False)

    def __call__(self, x, mask=None, cache=None):
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Extract some shapes
        n_head = self.n_head
        B, L, D = queries.shape

        # Prepare the queries, keys and values for the attention computation:
        # (B, L, D) -> (B, n_head, L, D // n_head)
        queries = tf.transpose(tf.reshape(queries, (B, L, n_head, -1)), (0, 2, 1, 3))
        keys = tf.transpose(tf.reshape(keys, (B, L, n_head, -1)), (0, 2, 1, 3))
        values = tf.transpose(tf.reshape(values, (B, L, n_head, -1)), (0, 2, 1, 3))

        # Add RoPE to the queries and keys and combine them with the cache.
        # When a cache is present, the offset shifts the rotary positions past
        # the tokens that were already processed.
        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = tf.concat([key_cache, keys], axis=2)
            values = tf.concat([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Compute attention scores in float32 for numerical stability
        queries = tf.cast(queries, tf.float32)
        keys = tf.cast(keys, tf.float32)

        # Finally perform the attention computation
        scale = math.sqrt(1 / queries.shape[-1])
        scores = tf.matmul(queries * scale, tf.transpose(keys, (0, 1, 3, 2)))
        if mask is not None:
            scores = scores + mask
        scores = tf.cast(tf.nn.softmax(scores, axis=-1), values.dtype)
        values_hat = tf.reshape(
            tf.transpose(tf.matmul(scores, values), (0, 2, 1, 3)), (B, L, -1)
        )

        return self.dense(values_hat), (keys, values)
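
# The returned cache is a (keys, values) pair of shape
# (B, n_head, total_len, head_dim); feeding it back on the next call lets
# incremental decoding attend over all earlier tokens while only projecting
# the new ones. Keys are cached in float32 because of the cast above.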

class MLP:
    def __init__(self, dim, hidden_dim):
        self.fc1 = Dense(hidden_dim)
        self.fc2 = Dense(dim)

    def __call__(self, x):
        # tf.nn.gelu takes a boolean `approximate` flag; True selects the tanh
        # GELU variant ("gelu_new"), the activation Phi-2 was trained with.
        return self.fc2(tf.nn.gelu(self.fc1(x), approximate=True))

class ParallelBlock:
    def __init__(self, config: ModelArgs):
        dims = config.n_embd
        mlp_dims = dims * 4
        self.self_attn = RoPEAttention(dims, config.n_head, config.rotary_dim)
        self.input_layernorm = LayerNormalization()
        self.mlp = MLP(dims, mlp_dims)

    def __call__(self, x, mask, cache):
        h = self.input_layernorm(x)
        attn_h, cache = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x, cache
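
# Attention and the MLP both read the same pre-norm hidden state `h`, and
# their outputs are summed with the residual in a single step (a "parallel"
# block), rather than being applied sequentially as in a standard pre-norm
# transformer layer.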

class Transformer:
    def __init__(self, config: ModelArgs):
        self.embed_tokens = Embedding(config.vocab_size, config.n_embd)
        self.layers = [ParallelBlock(config) for _ in range(config.n_layer)]
        self.final_layernorm = LayerNormalization()

    def __call__(self, x, mask, cache):
        x = self.embed_tokens(x)
        if cache is None:
            cache = [None] * len(self.layers)
        for e, layer in enumerate(self.layers):
            x, cache[e] = layer(x, mask, cache[e])
        return self.final_layernorm(x), cache

class Phi2(Model):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.model = Transformer(config)
        self.lm_head = Dense(config.vocab_size)

    def __call__(self, x, mask=None, cache=None):
        # Any caller-supplied mask is ignored; the additive causal mask is
        # rebuilt here: -inf strictly above the diagonal, 0 elsewhere, so
        # softmax zeroes out attention to future positions. A single-token
        # step (incremental decoding) needs no mask.
        mask = None
        if x.shape[1] > 1:
            mask = tf.fill((x.shape[1], x.shape[1]), float("-inf"))
            mask = tf.linalg.band_part(mask, 0, -1)
            mask = tf.linalg.set_diag(mask, tf.zeros(x.shape[1]))
            # The scores this mask is added to are computed in float32, so the
            # mask must be float32 too (x holds integer token ids here).
            mask = tf.cast(mask, tf.float32)
        y, cache = self.model(x, mask, cache)
        return self.lm_head(y), cache

class RoPE:
    # `base` must default to a number: RoPEAttention constructs RoPE without a
    # base, and create_cos_sin_theta takes the logarithm of this value.
    def __init__(self, dims: int, traditional: bool = False, base: float = 10000.0):
        self.dims = dims
        self.traditional = traditional
        self.base = base

    def _compute_rope(self, costheta, sintheta, x):
        # Rotate the first `dims` features and pass the rest through unchanged.
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        if self.dims < x.shape[-1]:
            rx = tf.concat([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = tf.concat([rx1, rx2], axis=-1)
        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        # Traditional RoPE rotates adjacent feature pairs (x[0], x[1]), ...
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta
        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )
        rx = tf.concat([rx1[..., None], rx2[..., None]], axis=-1)
        return rx

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = tf.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, base=self.base, dtype=x.dtype
        )
        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)
        return tf.reshape(rx, shape)
    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000.0,
        dtype=tf.float32,
    ):
        # theta[p, i] = p * base^(-2i / D) for positions p in [offset, N)
        D = D // 2
        positions = tf.range(offset, N, dtype=dtype)
        # Build log(base) as a tensor of the requested dtype; tf.math.log does
        # not accept a plain Python number of a mismatched dtype.
        freqs = tf.math.exp(
            -tf.range(0, D, dtype=dtype)
            * (tf.math.log(tf.constant(base, dtype=dtype)) / D)
        )
        theta = tf.reshape(positions, (-1, 1)) * tf.reshape(freqs, (1, -1))
        costheta = tf.math.cos(theta)
        sintheta = tf.math.sin(theta)
        return costheta, sintheta
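

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal smoke test showing
# how the pieces fit together -- build the model, prefill the KV cache with a
# prompt, then decode token by token. The config is scaled down and the token
# ids are random placeholders; a real run would use the full ModelArgs(),
# pretrained Phi-2 weights, and a tokenizer.
if __name__ == "__main__":
    config = ModelArgs(vocab_size=1000, n_embd=256, n_head=8, n_layer=2, rotary_dim=32)
    model = Phi2(config)

    # A fake prompt of 8 token ids, batch size 1.
    prompt = tf.random.uniform((1, 8), maxval=config.vocab_size, dtype=tf.int32)

    # Prefill: process the whole prompt at once; the causal mask is built
    # internally and the per-layer (keys, values) caches are returned.
    logits, cache = model(prompt)
    next_token = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)

    # Incremental decoding: feed one token per step and reuse the cache.
    for _ in range(4):
        logits, cache = model(next_token[:, None], cache=cache)
        next_token = tf.argmax(logits[:, -1, :], axis=-1, output_type=tf.int32)
    print("decoded token id:", int(next_token[0]))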