Yuchan committed on
Commit 696479e · verified · 1 Parent(s): 2cdfb37

Update AlphaS2S.py

Files changed (1)
  1. AlphaS2S.py +52 -111
AlphaS2S.py CHANGED
@@ -183,136 +183,77 @@ class SwiGLU(layers.Layer):
          x_val, x_gate = tf.split(x_proj, 2, axis=-1)
          return self.out(x_val * tf.nn.silu(x_gate))

- class gMLPBlock(layers.Layer):
-     def __init__(self, d_model, seq_len, dropout=0.1):
-         super().__init__()
-         self.d_model = d_model
-         self.seq_len = seq_len
-         self.norm = layers.LayerNormalization(epsilon=1e-6)
-
-         # FFN: channel expansion
-         # expand to d_model * 4
-         self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
-         self.dropout = layers.Dropout(dropout)
-
-         # Spatial Gating Unit (SGU)
-         self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
-         self.sgu_proj = layers.Dense(seq_len, use_bias=False)
-
-         # set the output dimension to d_model * 2 (the dimension of U)
-         self.sgu_final = layers.Dense(d_model * 2, use_bias=True)
-
-         self.out_proj = layers.Dense(d_model, use_bias=True)

-     def call(self, x, training=False):
-         # 1. Norm and channel expansion
-         residual = x
-         x_norm = self.norm(x)
-         x_proj = self.channel_proj(x_norm)  # shape: (B, L, 4*D)
-
-         # 2. Split into U and V streams
-         u, v = tf.split(x_proj, 2, axis=-1)  # u, v shape: (B, L, 2*D)
-
-         # 3. Spatial Gating Unit (SGU)
-         v_norm = self.sgu_norm(v)
-         v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])  # (B, 2D, L)
-
-         # 💡 token mixing happens here (Dense applied along the sequence axis)
-         v_proj = self.sgu_proj(v_norm_T)  # (B, 2D, L)
-         v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])  # (B, L, 2D)
-
-         # 4. Activation and gate generation
-         # standard gMLP applies GELU to U and uses V as a linear gate;
-         # GELU is applied to U here as well
-         u_act = tf.nn.gelu(u)
-         v_gate = self.sgu_final(v_proj_T)  # shape: (B, L, 2*D)
-
-         # 5. Gating and contraction
-         z = u_act * v_gate  # gating
-         z = self.dropout(z, training=training)
-         out = self.out_proj(z)  # shape: (B, L, D)
-
-         # 6. Residual connection
-         return residual + out

- class CrossBlock(layers.Layer):
-     def __init__(self, clip_value=5.0, eps=1e-6):  # 💡 add a d_model argument
          super().__init__()
-         self.clip_value = clip_value
-         self.eps = eps
-         self.attn = layers.MultiHeadAttention(8, 20)
-
-     # 💡 fix: change the output dimension from 1 to d_model
-     def call(self, x, z):
-         y = self.attn(x, z, z)
-         return y
-
- class LoU(layers.Layer):
-     def __init__(self, d_model, clip_value=5.0, eps=1e-6):
          super().__init__()
-         self.d_model = d_model
-         self.clip_value = float(clip_value)
-         self.mha = layers.MultiHeadAttention(8, 20)
-         self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-         self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
-
-         self.glu = SwiGLU(d_model, 350)
-         self.cross = CrossBlock()
-
-     def call(self, x, z):
-         x_f32 = tf.cast(x, tf.float32)
-         residual = x_f32
-         x = self.norm1(x)
-
-         x_comb = self.mha(x, x, x, use_causal_mask=True)
-
-         out = self.norm(x_comb + residual)
-         out = self.cross(out, z)
-         out = self.glu(out)
-         return tf.cast(out, x.dtype)
-
- # =======================
- # 4) AlphaS2S model (existing code kept)
- # =======================
-
- class AlphaS2S(tf.keras.Model):
-     def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=200, dropout=0.1):
          super().__init__()
          self.max_len = max_len
          self.d_model = d_model
-
-         # the encoder and decoder token embeddings and positional embeddings all use max_len
          self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
          self.enc_pos_embedding = layers.Embedding(max_len, d_model)
          self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
          self.dec_pos_embedding = layers.Embedding(max_len, d_model)
-
-         # EncoderBlock and LoU have the same structure as the existing code
-         self.enc_layers = [gMLPBlock(d_model, seq_len=max_len) for _ in range(num_layers)]
-         self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
-
-         self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
-
      def call(self, inputs, training=False):
-         # enc_inputs and dec_inputs are the same sequence (unified input)
-         enc_inputs = inputs["enc_inputs"]
          dec_inputs = inputs["dec_inputs"]
-
          enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
          dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
-
-         # run the encoder
          x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
-         # Note: no mask -> bi-directional (BERT-like encoder)
          for layer in self.enc_layers: x = layer(x, training=training)
-         enc_out = x  # final encoder output (the decoder's 'z' input)
-
-         # run the decoder
          y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
-         for layer in self.dec_layers: y = layer(y, enc_out, training=training)
-
          return self.final_layer(y)

+ class SwiGLU(layers.Layer):
+     def __init__(self, d_model, d_ff):
+         super().__init__()
+         self.proj = layers.Dense(d_ff * 2)
+         self.out = layers.Dense(d_model)
+
+     def call(self, x):
+         x_proj = self.proj(x)
+         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
+         return self.out(x_val * tf.nn.silu(x_gate))
+
+ class EncoderBlock(layers.Layer):
+     def __init__(self, d_model, num_heads, dff, dropout=0.1):
          super().__init__()
+         self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
+         self.ffn = SwiGLU(d_model, dff)
+         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
+         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
+         self.dropout1 = layers.Dropout(dropout)
+         self.dropout2 = layers.Dropout(dropout)
+
+     def call(self, x, mask=None, training=False):
+         attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
+         out1 = self.norm1(x + attn_out)
+         ffn_out = self.dropout2(self.ffn(out1), training=training)
+         return self.norm2(out1 + ffn_out)
+
+ class DecoderBlock(layers.Layer):
+     def __init__(self, d_model, num_heads, dff, dropout=0.1):
          super().__init__()
+         self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
+         self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
+         self.ffn = SwiGLU(d_model, dff)
+         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
+         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
+         self.norm3 = layers.LayerNormalization(epsilon=1e-6)
+         self.dropout1 = layers.Dropout(dropout)
+         self.dropout2 = layers.Dropout(dropout)
+         self.dropout3 = layers.Dropout(dropout)
+
+     def call(self, x, enc_out, training=False):
+         attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
+         out1 = self.norm1(x + attn1)
+         attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
+         out2 = self.norm2(out1 + attn2)
+         ffn_out = self.dropout3(self.ffn(out2), training=training)
+         return self.norm3(out2 + ffn_out)
+
+ class Transformer(tf.keras.Model):
+     def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
          super().__init__()
          self.max_len = max_len
          self.d_model = d_model
          self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
          self.enc_pos_embedding = layers.Embedding(max_len, d_model)
          self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
          self.dec_pos_embedding = layers.Embedding(max_len, d_model)
+         self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
+         self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
+         self.final_layer = layers.Dense(target_vocab_size)

      def call(self, inputs, training=False):
+         enc_inputs = inputs["enc_inputs"]
          dec_inputs = inputs["dec_inputs"]
          enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
          dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]

          x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
          for layer in self.enc_layers: x = layer(x, training=training)
+         enc_out = x

          y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
+         for layer in self.dec_layers: y = layer(y, enc_out, training=training)

          return self.final_layer(y)

- # =======================
+
  # 5) Training setup and execution
  # =======================
 
@@ -343,7 +284,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
  with strategy.scope():
      # ⚠️ fix: use the defined vocab_size instead of chat_vocab_size
-     chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
+     chat_model = Transformer(num_layers=4, d_model=160, num_heads=8,
                    input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)

      dummy_input = {
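
For a quick sanity check of the added SwiGLU layer, here is a minimal shape-check sketch. It assumes it runs in the same module as the AlphaS2S.py definitions, and the batch size, sequence length, and d_ff values are illustrative, not taken from the commit:

import tensorflow as tf

# Standalone shape check for the added SwiGLU layer.
glu = SwiGLU(d_model=160, d_ff=640)
x = tf.random.normal((2, 100, 160))   # (batch, seq_len, d_model)
y = glu(x)                            # proj -> (2, 100, 1280), split into two (2, 100, 640) halves, gated, projected back
print(y.shape)                        # (2, 100, 160)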
 
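One practical note on the refactor: the new Transformer.__init__ requires a dff argument (the SwiGLU hidden width), but the updated chat_model call in the second hunk still passes only num_layers, d_model, num_heads, the vocab sizes, and max_len, so it would raise a TypeError as committed. Below is a minimal smoke-test sketch of a corrected call; the vocab_size, max_len, batch size, and dff values are illustrative assumptions (dff = 4 * d_model is a common default, not something this commit specifies), and the sketch assumes it runs where Transformer is defined:

import tensorflow as tf

vocab_size = 8000   # assumed for illustration; not shown in the diff
max_len = 100       # assumed; matches the new __init__ default

model = Transformer(num_layers=4, d_model=160, num_heads=8, dff=640,
                    input_vocab_size=vocab_size, target_vocab_size=vocab_size,
                    max_len=max_len)

# Dummy batch shaped like the {"enc_inputs", "dec_inputs"} dict that call() expects.
dummy_input = {
    "enc_inputs": tf.zeros((2, max_len), dtype=tf.int32),
    "dec_inputs": tf.zeros((2, max_len), dtype=tf.int32),
}

logits = model(dummy_input, training=False)
print(logits.shape)  # (2, max_len, vocab_size): per-position vocabulary logits

Because DecoderBlock passes use_causal_mask=True to its self-attention, each target position attends only to earlier positions, while the unmasked encoder stays fully bi-directional, so the standard teacher-forcing seq2seq setup is preserved.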