NoteDance committed
Commit 2558fcc
1 Parent(s): bb3339b

Upload Gemma.py

Files changed (1):
  1. Gemma.py +306 -0
Gemma.py ADDED
@@ -0,0 +1,306 @@
+ # Copyright 2024 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Inference-only Gemma model implementation."""
+ 
+ import tensorflow as tf
+ from tensorflow.keras.layers import Dense
+ from tensorflow.keras import Model
+ import dataclasses
+ 
+ 
+ @dataclasses.dataclass
+ class GemmaConfig:
+     # The number of tokens in the vocabulary.
+     vocab_size: int = 256000
+     # The maximum sequence length that this model might ever be used with.
+     max_position_embeddings: int = 8192
+     # The number of blocks in the model.
+     num_hidden_layers: int = 28
+     # The number of attention heads used in the attention layers of the model.
+     num_attention_heads: int = 16
+     # The number of key-value heads for implementing attention.
+     num_key_value_heads: int = 16
+     # The hidden size of the model.
+     hidden_size: int = 3072
+     # The dimension of the MLP representations.
+     intermediate_size: int = 24576
+     # The dimension of each attention head.
+     head_dim: int = 256
+     # The epsilon used by the rms normalization layers.
+     rms_norm_eps: float = 1e-6
+ 
+ 
+ def precompute_freqs_cis(dim: int,
+                          end: int,
+                          theta: float = 10000.0):
+     """Precomputes the frequency cis."""
+     freqs = 1.0 / (theta**(tf.cast(tf.range(0, dim, 2)[:(dim // 2)], 'float32') / dim))
+     t = tf.range(end, dtype=tf.float32)  # positions
+     freqs = tf.cast(tf.experimental.numpy.outer(t, freqs), 'float32')
+     # Unit-magnitude rotations exp(i * freqs).
+     freqs_cis = tf.complex(tf.math.cos(freqs), tf.math.sin(freqs))  # complex64
+     return freqs_cis
+ 
+ 
+ def apply_rotary_emb(x, freqs_cis):
+     """Applies the rotary embedding to the query and key tensors."""
+     x_ = tf.complex(
+         *tf.split(tf.cast(tf.transpose(x, [0, 2, 1, 3]), 'float32'), num_or_size_splits=2, axis=-1),
+     )
+     x_ = x_ * tf.cast(freqs_cis, x_.dtype)
+     x_out = tf.cast(tf.stack([tf.math.real(x_),
+                               tf.math.imag(x_)], axis=-1), x.dtype)
+     x_out = tf.concat(tf.split(x_out, num_or_size_splits=2, axis=-1), axis=-2)
+     x_out = tf.transpose(tf.reshape(x_out, (x_out.shape[0], x_out.shape[1], x_out.shape[2],
+                                             -1)), (0, 2, 1, 3))
+     return x_out
+ 
+ 
+ class Embedder:
+     """Embedder module."""
+     def __init__(self, config: GemmaConfig):
+         self.vocab_size = config.vocab_size
+         self.embed_dim = config.hidden_size
+         self.input_embedding_table = tf.Variable(tf.random.normal((self.vocab_size, self.embed_dim)))
+ 
+     def encode(self, x):
+         x = tf.gather(self.input_embedding_table, x)
+         x *= tf.math.sqrt(tf.cast(self.embed_dim, x.dtype))  # scale embeddings by sqrt(hidden_size)
+         return x
+ 
+     def decode(self, x):
+         # Tied weights: project back onto the vocabulary with the embedding table.
+         return tf.matmul(x, tf.transpose(self.input_embedding_table))
+ 
+ 
+ class RMSNorm:
+ 
+     def __init__(
+         self,
+         dim: int,
+         eps: float = 1e-6,
+         add_unit_offset: bool = True,
+     ):
+         self.eps = eps
+         self.add_unit_offset = add_unit_offset
+         self.weight = tf.Variable(tf.zeros((dim,)))  # learned scale, zero-initialized
+ 
+     def _norm(self, x):
+         return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), axis=-1, keepdims=True) + self.eps)
+ 
+     def __call__(self, x):
+         x = tf.cast(self._norm(tf.cast(x, 'float32')), x.dtype)
+         if self.add_unit_offset:
+             output = x * (1 + self.weight)
+         else:
+             output = x * self.weight
+         return output
+ 
+ 
+ class GemmaMLP:
+ 
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+     ):
+         self.gate_proj = Dense(intermediate_size)
+         self.up_proj = Dense(intermediate_size)
+         self.down_proj = Dense(hidden_size)
+ 
+     def __call__(self, x):
+         gate = self.gate_proj(x)
+         gate = tf.nn.gelu(gate)  # gated GELU: the gate branch modulates the up projection
+         up = self.up_proj(x)
+         fuse = gate * up
+         outputs = self.down_proj(fuse)
+         return outputs
+ 
+ 
+ class GemmaAttention:
+ 
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         head_dim: int,
+     ):
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads
+ 
+         assert self.num_heads % self.num_kv_heads == 0
+         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+ 
+         self.hidden_size = hidden_size
+         self.head_dim = head_dim
+ 
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+ 
+         self.scaling = self.head_dim**-0.5
+ 
+         self.qkv_proj = Dense(
+             (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
+         )
+         self.o_proj = Dense(
+             self.hidden_size,
+         )
+ 
+     def __call__(
+         self,
+         hidden_states,
+         freqs_cis,
+         kv_write_indices,
+         kv_cache,
+         mask,
+     ):
+         hidden_states_shape = hidden_states.shape
+         assert len(hidden_states_shape) == 3
+ 
+         batch_size, input_len, _ = hidden_states_shape
+ 
+         qkv = self.qkv_proj(hidden_states)
+         xq, xk, xv = tf.split(qkv, [self.q_size, self.kv_size, self.kv_size],
+                               axis=-1)
+ 
+         xq = tf.reshape(xq, (batch_size, -1, self.num_heads, self.head_dim))
+         xk = tf.reshape(xk, (batch_size, -1, self.num_kv_heads, self.head_dim))
+         xv = tf.reshape(xv, (batch_size, -1, self.num_kv_heads, self.head_dim))
+ 
+         # Positional embedding.
+         xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
+         xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
+ 
+         # Write new kv cache.
+         # [batch_size, input_len, n_local_kv_heads, head_dim]
+         k_cache, v_cache = kv_cache
+         k_cache.assign(tf.tensor_scatter_nd_update(k_cache, kv_write_indices, xk))
+         v_cache.assign(tf.tensor_scatter_nd_update(v_cache, kv_write_indices, xv))
+ 
+         key = k_cache
+         value = v_cache
+         if self.num_kv_heads != self.num_heads:
+             # [batch_size, max_seq_len, n_local_heads, head_dim]
+             batch_size, seq_len, num_heads, head_dim = key.shape
+             key = tf.reshape(tf.tile(key[:, :, :, None, :], [1, 1, 1, self.num_queries_per_kv, 1]),
+                              [batch_size, seq_len, num_heads * self.num_queries_per_kv, head_dim])
+             batch_size, seq_len, num_heads, head_dim = value.shape
+             value = tf.reshape(tf.tile(value[:, :, :, None, :], [1, 1, 1, self.num_queries_per_kv, 1]),
+                                [batch_size, seq_len, num_heads * self.num_queries_per_kv, head_dim])
+ 
+         # [batch_size, n_local_heads, input_len, head_dim]
+         q = tf.transpose(xq, (0, 2, 1, 3))
+         # [batch_size, n_local_heads, max_seq_len, head_dim]
+         k = tf.transpose(key, (0, 2, 1, 3))
+         v = tf.transpose(value, (0, 2, 1, 3))
+ 
+         # [batch_size, n_local_heads, input_len, max_seq_len]
+         scores = tf.matmul(q, tf.transpose(k, (0, 1, 3, 2))) * self.scaling
+         scores = scores + mask
+         scores = tf.cast(tf.nn.softmax(tf.cast(scores, 'float32'), axis=-1), q.dtype)
+ 
+         # [batch_size, n_local_heads, input_len, head_dim]
+         output = tf.matmul(scores, v)
+ 
+         # [batch_size, input_len, hidden_dim]
+         output = tf.reshape(tf.transpose(output, (0, 2, 1, 3)),
+                             (batch_size, input_len, -1))
+         output = self.o_proj(output)
+         return output
+ 
+ 
+ class GemmaDecoderLayer:
+ 
+     def __init__(
+         self,
+         config: GemmaConfig,
+     ):
+         self.self_attn = GemmaAttention(
+             hidden_size=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             head_dim=config.head_dim,
+         )
+         self.mlp = GemmaMLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size,
+                                        eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                 eps=config.rms_norm_eps)
+ 
+     def __call__(
+         self,
+         hidden_states,
+         freqs_cis,
+         kv_write_indices,
+         kv_cache,
+         mask,
+     ):
+         # Self Attention
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states = self.self_attn(
+             hidden_states=hidden_states,
+             freqs_cis=freqs_cis,
+             kv_write_indices=kv_write_indices,
+             kv_cache=kv_cache,
+             mask=mask,
+         )
+         hidden_states = residual + hidden_states
+ 
+         # MLP
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+ 
+         return hidden_states
+ 
+ 
+ class Gemma(Model):
+ 
+     def __init__(self, config: GemmaConfig):
+         super(Gemma, self).__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+ 
+         self.embedder = Embedder(config)
+         self.decoder_layers = []  # 'layers' is a reserved read-only property on keras.Model
+         for _ in range(config.num_hidden_layers):
+             self.decoder_layers.append(GemmaDecoderLayer(config))
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.lm_head = Dense(config.vocab_size)  # unused: logits come from the tied embedding ('output' is also reserved on keras.Model)
+ 
+     def __call__(
+         self,
+         data,
+         freqs_cis,
+         kv_write_indices,
+         kv_caches,
+         mask
+     ):
+         hidden_states = self.embedder.encode(data)
+         for i in range(len(self.decoder_layers)):
+             layer = self.decoder_layers[i]
+             hidden_states = layer(
+                 hidden_states=hidden_states,
+                 freqs_cis=freqs_cis,
+                 kv_write_indices=kv_write_indices,
+                 kv_cache=kv_caches[i],
+                 mask=mask,
+             )
+         hidden_states = self.norm(hidden_states)
+         logits = self.embedder.decode(hidden_states)
+         return logits
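
Usage note (not part of the committed file): the sketch below shows one way to drive the forward pass above on a deliberately tiny configuration. The per-layer (k, v) cache variables, the [batch, position] index pairs passed as kv_write_indices to tf.tensor_scatter_nd_update, and the additive causal mask are assumptions inferred from GemmaAttention; the shrunken GemmaConfig values and the -1e9 mask constant are illustrative choices, not part of the upload.

import tensorflow as tf

# Tiny config for a smoke test (real Gemma uses the dataclass defaults).
config = GemmaConfig(vocab_size=1000, num_hidden_layers=2, hidden_size=256,
                     intermediate_size=1024, num_attention_heads=4,
                     num_key_value_heads=4, head_dim=64)
model = Gemma(config)

batch_size, input_len, max_seq_len = 2, 8, 32
tokens = tf.random.uniform((batch_size, input_len), maxval=config.vocab_size, dtype=tf.int32)

# Rotary table for the positions being written (prefill: positions 0..input_len-1).
freqs_cis = precompute_freqs_cis(config.head_dim, max_seq_len)[:input_len]

# One (k, v) cache pair per layer: [batch, max_seq_len, num_kv_heads, head_dim].
kv_caches = [
    (tf.Variable(tf.zeros((batch_size, max_seq_len, config.num_key_value_heads, config.head_dim))),
     tf.Variable(tf.zeros((batch_size, max_seq_len, config.num_key_value_heads, config.head_dim))))
    for _ in range(config.num_hidden_layers)
]

# Scatter indices: one [batch, position] pair per written token,
# shaped [batch, input_len, 2] so updates of shape [batch, input_len, heads, dim] fit.
positions = tf.range(input_len)
kv_write_indices = tf.stack(
    [tf.repeat(tf.range(batch_size), input_len),
     tf.tile(positions, [batch_size])], axis=-1)
kv_write_indices = tf.reshape(kv_write_indices, (batch_size, input_len, 2))

# Additive causal mask over the cache: 0 where attention is allowed, a large
# negative value elsewhere; broadcast over batch and heads.
row = tf.range(input_len)[:, None]    # query positions
col = tf.range(max_seq_len)[None, :]  # cache positions
mask = tf.where(col <= row, 0.0, -1e9)[None, None, :, :]

logits = model(tokens, freqs_cis, kv_write_indices, kv_caches, mask)
print(logits.shape)  # (batch_size, input_len, vocab_size)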