Upload Gemma.py
Browse files
Gemma.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Inference-only Gemma model implementation."""
|
15 |
+
|
16 |
+
import tensorflow as tf
|
17 |
+
from tensorflow.keras.layers import Dense
|
18 |
+
from tensorflow.keras import Model
|
19 |
+
import dataclasses
|
20 |
+
|
21 |
+
|
22 |
+
@dataclasses.dataclass
|
23 |
+
class GemmaConfig:
|
24 |
+
# The number of tokens in the vocabulary.
|
25 |
+
vocab_size: int = 256000
|
26 |
+
# The maximum sequence length that this model might ever be used with.
|
27 |
+
max_position_embeddings: int = 8192
|
28 |
+
# The number of blocks in the model.
|
29 |
+
num_hidden_layers: int = 28
|
30 |
+
# The number of attention heads used in the attention layers of the model.
|
31 |
+
num_attention_heads: int = 16
|
32 |
+
# The number of key-value heads for implementing attention.
|
33 |
+
num_key_value_heads: int = 16
|
34 |
+
# The hidden size of the model.
|
35 |
+
hidden_size: int = 3072
|
36 |
+
# The dimension of the MLP representations.
|
37 |
+
intermediate_size: int = 24576
|
38 |
+
# The number of head dimensions.
|
39 |
+
head_dim: int = 256
|
40 |
+
# The epsilon used by the rms normalization layers.
|
41 |
+
rms_norm_eps: float = 1e-6
|
42 |
+
|
43 |
+
|
44 |
+
def precompute_freqs_cis(dim: int,
|
45 |
+
end: int,
|
46 |
+
theta: float = 10000.0):
|
47 |
+
"""Precomputes the frequency cis."""
|
48 |
+
freqs = 1.0 / (theta**(tf.cast(tf.range(0, dim, 2)[:(dim // 2)], 'float32') / dim))
|
49 |
+
t = tf.range(end)
|
50 |
+
freqs = tf.cast(tf.experimental.numpy.outer(t, freqs), 'float32')
|
51 |
+
freqs_cis = tf.complex(tf.ones_like(freqs), freqs) # complex64
|
52 |
+
return freqs_cis
|
53 |
+
|
54 |
+
|
55 |
+
def apply_rotary_emb(x, freqs_cis):
|
56 |
+
"""Applies the rotary embedding to the query and key tensors."""
|
57 |
+
x_ = tf.complex(
|
58 |
+
*tf.split(tf.cast(tf.transpose(x, [0, 2, 1, 3]), 'float32'), num_or_size_splits=2, axis=-1),
|
59 |
+
)
|
60 |
+
x_ = x_ * tf.cast(freqs_cis, x_.dtype)
|
61 |
+
x_out = tf.cast(tf.stack(tf.math.real(x_),
|
62 |
+
tf.math.imag(x_), axis=-1), x.dtype)
|
63 |
+
x_out = tf.concat(tf.split(x_out, num_or_size_splits=2, axis=-1), axis=-2)
|
64 |
+
x_out = tf.transpose(tf.reshape(x_out, (x_out.shape[0], x_out.shape[1], x_out.shape[2],
|
65 |
+
-1)), (0, 2, 1, 3))
|
66 |
+
return x_out
|
67 |
+
|
68 |
+
|
69 |
+
class Embedder:
|
70 |
+
"""Embedder module."""
|
71 |
+
def __init__(self, config: GemmaConfig):
|
72 |
+
self.vocab_size = config.vocab_size
|
73 |
+
self.embed_dim = config.hidden_size
|
74 |
+
self.input_embedding_table = tf.Variable(tf.random.normal((self.vocab_size, self.embed_dim)))
|
75 |
+
|
76 |
+
def encode(self, x):
|
77 |
+
x = tf.gather(self.input_embedding_table, x)
|
78 |
+
x *= tf.cast(tf.math.sqrt(self.embed_dim), x.dtype)
|
79 |
+
return x
|
80 |
+
|
81 |
+
def decode(self, x):
|
82 |
+
return tf.matmul(x, tf.transpose(self.input_embedding_table))
|
83 |
+
|
84 |
+
|
85 |
+
class RMSNorm:
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
dim: int,
|
90 |
+
eps: float = 1e-6,
|
91 |
+
add_unit_offset: bool = True,
|
92 |
+
):
|
93 |
+
self.eps = eps
|
94 |
+
self.add_unit_offset = add_unit_offset
|
95 |
+
self.weight = tf.Variable(tf.random.zeros((dim)))
|
96 |
+
|
97 |
+
def _norm(self, x):
|
98 |
+
return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), axis=-1, keepdims=True) + self.eps)
|
99 |
+
|
100 |
+
def __call__(self, x):
|
101 |
+
x = tf.cast(self._norm(tf.cast(x, 'float32')), x.dtype)
|
102 |
+
if self.add_unit_offset:
|
103 |
+
output = x * (1 + self.weight)
|
104 |
+
else:
|
105 |
+
output = x * self.weight
|
106 |
+
return output
|
107 |
+
|
108 |
+
|
109 |
+
class GemmaMLP:
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
hidden_size: int,
|
114 |
+
intermediate_size: int,
|
115 |
+
):
|
116 |
+
self.gate_proj = Dense(intermediate_size)
|
117 |
+
self.up_proj = Dense(intermediate_size)
|
118 |
+
self.down_proj = Dense(hidden_size)
|
119 |
+
|
120 |
+
def __call__(self, x):
|
121 |
+
gate = self.gate_proj(x)
|
122 |
+
gate = tf.nn.gelu(gate)
|
123 |
+
up = self.up_proj(x)
|
124 |
+
fuse = gate * up
|
125 |
+
outputs = self.down_proj(fuse)
|
126 |
+
return outputs
|
127 |
+
|
128 |
+
|
129 |
+
class GemmaAttention:
|
130 |
+
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
hidden_size: int,
|
134 |
+
num_heads: int,
|
135 |
+
num_kv_heads: int,
|
136 |
+
head_dim: int,
|
137 |
+
):
|
138 |
+
self.num_heads = num_heads
|
139 |
+
self.num_kv_heads = num_kv_heads
|
140 |
+
|
141 |
+
assert self.num_heads % self.num_kv_heads == 0
|
142 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
143 |
+
|
144 |
+
self.hidden_size = hidden_size
|
145 |
+
self.head_dim = head_dim
|
146 |
+
|
147 |
+
self.q_size = self.num_heads * self.head_dim
|
148 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
149 |
+
|
150 |
+
self.scaling = self.head_dim**-0.5
|
151 |
+
|
152 |
+
self.qkv_proj = Dense(
|
153 |
+
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
|
154 |
+
)
|
155 |
+
self.o_proj = Dense(
|
156 |
+
self.hidden_size,
|
157 |
+
)
|
158 |
+
|
159 |
+
def __call__(
|
160 |
+
self,
|
161 |
+
hidden_states,
|
162 |
+
freqs_cis,
|
163 |
+
kv_write_indices,
|
164 |
+
kv_cache,
|
165 |
+
mask,
|
166 |
+
):
|
167 |
+
hidden_states_shape = hidden_states.shape
|
168 |
+
assert len(hidden_states_shape) == 3
|
169 |
+
|
170 |
+
batch_size, input_len, _ = hidden_states_shape
|
171 |
+
|
172 |
+
qkv = self.qkv_proj(hidden_states)
|
173 |
+
xq, xk, xv = tf.split(qkv, [self.q_size, self.kv_size, self.kv_size],
|
174 |
+
axis=-1)
|
175 |
+
|
176 |
+
xq = tf.reshape(xq, (batch_size, -1, self.num_heads, self.head_dim))
|
177 |
+
xk = tf.reshape(xk, (batch_size, -1, self.num_kv_heads, self.head_dim))
|
178 |
+
xv = tf.reshape(xv, (batch_size, -1, self.num_kv_heads, self.head_dim))
|
179 |
+
|
180 |
+
# Positional embedding.
|
181 |
+
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
182 |
+
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
183 |
+
|
184 |
+
# Write new kv cache.
|
185 |
+
# [batch_size, input_len, n_local_kv_heads, head_dim]
|
186 |
+
k_cache, v_cache = kv_cache
|
187 |
+
k_cache.assign(tf.tensor_scatter_nd_update(k_cache, kv_write_indices, xk))
|
188 |
+
v_cache.assign(tf.tensor_scatter_nd_update(v_cache, kv_write_indices, xv))
|
189 |
+
|
190 |
+
key = k_cache
|
191 |
+
value = v_cache
|
192 |
+
if self.num_kv_heads != self.num_heads:
|
193 |
+
# [batch_size, max_seq_len, n_local_heads, head_dim]
|
194 |
+
batch_size, seq_len, num_heads, head_dim = key.shape
|
195 |
+
key = tf.reshape(tf.tile(key[:, :, :, None, :], [1, 1, 1, self.num_queries_per_kv, 1]),
|
196 |
+
[batch_size, seq_len, num_heads * self.num_queries_per_kv, head_dim])
|
197 |
+
batch_size, seq_len, num_heads, head_dim = value.shape
|
198 |
+
value = tf.reshape(tf.tile(value[:, :, :, None, :], [1, 1, 1, self.num_queries_per_kv, 1]),
|
199 |
+
[batch_size, seq_len, num_heads * self.num_queries_per_kv, head_dim])
|
200 |
+
|
201 |
+
# [batch_size, n_local_heads, input_len, head_dim]
|
202 |
+
q = tf.transpose(xq, (0, 2, 1, 3))
|
203 |
+
# [batch_size, n_local_heads, max_seq_len, head_dim]
|
204 |
+
k = tf.transpose(key, (0, 2, 1, 3))
|
205 |
+
v = tf.transpose(value, (0, 2, 1, 3))
|
206 |
+
|
207 |
+
# [batch_size, n_local_heads, input_len, max_seq_len]
|
208 |
+
scores = tf.matmul(q, tf.transpose(k, (0, 1, 3, 2))) * self.scaling
|
209 |
+
scores = scores + mask
|
210 |
+
scores = tf.cast(tf.nn.softmax(tf.cast(scores, 'float32'), axis=-1), q.dtype)
|
211 |
+
|
212 |
+
# [batch_size, n_local_heads, input_len, head_dim]
|
213 |
+
output = tf.matmul(scores, v)
|
214 |
+
|
215 |
+
# [batch_size, input_len, hidden_dim]
|
216 |
+
output = tf.reshape((tf.transpose(output, (0, 2, 1, 3)),
|
217 |
+
(batch_size, input_len, -1)))
|
218 |
+
output = self.o_proj(output)
|
219 |
+
return output
|
220 |
+
|
221 |
+
|
222 |
+
class GemmaDecoderLayer:
|
223 |
+
|
224 |
+
def __init__(
|
225 |
+
self,
|
226 |
+
config: GemmaConfig,
|
227 |
+
):
|
228 |
+
self.self_attn = GemmaAttention(
|
229 |
+
hidden_size=config.hidden_size,
|
230 |
+
num_heads=config.num_attention_heads,
|
231 |
+
num_kv_heads=config.num_key_value_heads,
|
232 |
+
head_dim=config.head_dim,
|
233 |
+
)
|
234 |
+
self.mlp = GemmaMLP(
|
235 |
+
hidden_size=config.hidden_size,
|
236 |
+
intermediate_size=config.intermediate_size,
|
237 |
+
)
|
238 |
+
self.input_layernorm = RMSNorm(config.hidden_size,
|
239 |
+
eps=config.rms_norm_eps)
|
240 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
241 |
+
eps=config.rms_norm_eps)
|
242 |
+
|
243 |
+
def __call__(
|
244 |
+
self,
|
245 |
+
hidden_states,
|
246 |
+
freqs_cis,
|
247 |
+
kv_write_indices,
|
248 |
+
kv_cache,
|
249 |
+
mask,
|
250 |
+
):
|
251 |
+
# Self Attention
|
252 |
+
residual = hidden_states
|
253 |
+
hidden_states = self.input_layernorm(hidden_states)
|
254 |
+
hidden_states = self.self_attn(
|
255 |
+
hidden_states=hidden_states,
|
256 |
+
freqs_cis=freqs_cis,
|
257 |
+
kv_write_indices=kv_write_indices,
|
258 |
+
kv_cache=kv_cache,
|
259 |
+
mask=mask,
|
260 |
+
)
|
261 |
+
hidden_states = residual + hidden_states
|
262 |
+
|
263 |
+
# MLP
|
264 |
+
residual = hidden_states
|
265 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
266 |
+
hidden_states = self.mlp(hidden_states)
|
267 |
+
hidden_states = residual + hidden_states
|
268 |
+
|
269 |
+
return hidden_states
|
270 |
+
|
271 |
+
|
272 |
+
class Gemma(Model):
|
273 |
+
|
274 |
+
def __init__(self, config: GemmaConfig):
|
275 |
+
super(Gemma, self).__init__()
|
276 |
+
self.config = config
|
277 |
+
self.vocab_size = config.vocab_size
|
278 |
+
|
279 |
+
self.embedder = Embedder()
|
280 |
+
self.layers = []
|
281 |
+
for _ in range(config.num_hidden_layers):
|
282 |
+
self.layers.append(GemmaDecoderLayer(config))
|
283 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
284 |
+
self.output = Dense(config.vocab_size)
|
285 |
+
|
286 |
+
def __call__(
|
287 |
+
self,
|
288 |
+
data,
|
289 |
+
freqs_cis,
|
290 |
+
kv_write_indices,
|
291 |
+
kv_caches,
|
292 |
+
mask
|
293 |
+
):
|
294 |
+
hidden_states = self.embedder.encode(data)
|
295 |
+
for i in range(len(self.layers)):
|
296 |
+
layer = self.layers[i]
|
297 |
+
hidden_states = layer(
|
298 |
+
hidden_states=hidden_states,
|
299 |
+
freqs_cis=freqs_cis,
|
300 |
+
kv_write_indices=kv_write_indices,
|
301 |
+
kv_cache=kv_caches[i],
|
302 |
+
mask=mask,
|
303 |
+
)
|
304 |
+
hidden_states = self.norm(hidden_states)
|
305 |
+
logits = self.embedder.decode(hidden_states)
|
306 |
+
return logits
|