Yuchan committed: Update AlphaS2S.py

AlphaS2S.py (+52 -111)
CHANGED
@@ -183,136 +183,77 @@ class SwiGLU(layers.Layer):
         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
         return self.out(x_val * tf.nn.silu(x_gate))
 
-class gMLPBlock(layers.Layer):
-    def __init__(self, d_model, seq_len, dropout=0.1):
-        super().__init__()
-        self.d_model = d_model
-        self.seq_len = seq_len
-        self.norm = layers.LayerNormalization(epsilon=1e-6)
-
-        # FFN: Channel Expansion
-        # Expand channels to d_model * 4
-        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)
-        self.dropout = layers.Dropout(dropout)
-
-        # Spatial Gating Unit (SGU)
-        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
-        self.sgu_proj = layers.Dense(seq_len, use_bias=False)
-
-        # Set the output dimension to d_model * 2 (the dimension of U)
-        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)
-
-        self.out_proj = layers.Dense(d_model, use_bias=True)
 
-
-
-
-
-
-
-
-
-
-        # 3. Spatial Gating Unit (SGU)
-        v_norm = self.sgu_norm(v)
-        v_norm_T = tf.transpose(v_norm, perm=[0, 2, 1])  # (B, 2D, L)
-
-        # Token mixing happens here (Dense applied along the sequence axis)
-        v_proj = self.sgu_proj(v_norm_T)  # (B, 2D, L)
-        v_proj_T = tf.transpose(v_proj, perm=[0, 2, 1])  # (B, L, 2D)
-
-        # 4. Activation and Gate Generation
-        # Standard gMLP applies GELU to U and uses V as a linear gate
-        # Here, GELU is applied to U
-        u_act = tf.nn.gelu(u)
-        v_gate = self.sgu_final(v_proj_T)  # Shape: (B, L, 2*D)
-
-        # 5. Gating and Contraction
-        z = u_act * v_gate  # gating
-        z = self.dropout(z, training=training)
-        out = self.out_proj(z)  # Shape: (B, L, D)
-
-        # 6. Residual Connection
-        return residual + out
 
-class
-    def __init__(self,
         super().__init__()
-        self.
-        self.
-        self.
-
-
-
-
-
-
-
         super().__init__()
-        self.
-        self.
-        self.
-        self.norm1 = layers.LayerNormalization(epsilon=1e-
-        self.
-
-        self.
-        self.
-
-    def call(self, x,
-
-
-
-
-
-
-
-
-
-        return tf.cast(out, x.dtype)
-
-# =======================
-# 4) AlphaS2S model (existing code kept)
-# =======================
-
-class AlphaS2S(tf.keras.Model):
-    def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=200, dropout=0.1):
         super().__init__()
         self.max_len = max_len
         self.d_model = d_model
-
-        # Encoder and decoder token and positional embeddings all use max_len
         self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
         self.enc_pos_embedding = layers.Embedding(max_len, d_model)
         self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
         self.dec_pos_embedding = layers.Embedding(max_len, d_model)
-
-
-        self.
-        self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
-
-        self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
-
     def call(self, inputs, training=False):
-
-        enc_inputs = inputs["enc_inputs"]
         dec_inputs = inputs["dec_inputs"]
-
         enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
         dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
-
-        # Encoder stack
         x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
-        # Note: no mask -> bi-directional (BERT-like encoder)
         for layer in self.enc_layers: x = layer(x, training=training)
-        enc_out = x
-
-        # Decoder stack
         y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
-        for layer in self.dec_layers: y = layer(y, enc_out, training=training)
-
         return self.final_layer(y)
 
-
 # 5) Training setup and execution
 # =======================
 
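The removed gMLPBlock above only survives in part: the opening of its call() (where the normalized input is expanded and split into u and v) is not captured in this hunk. Below is a minimal, self-contained sketch of the same spatial-gating pattern; the missing steps are filled in as assumptions, not the author's actual removed code.

import tensorflow as tf
from tensorflow.keras import layers

class GMLPBlockSketch(layers.Layer):
    # Hedged reconstruction of the removed gMLP/SGU block. Steps 1-2 of call()
    # (normalization, channel expansion, u/v split) are assumptions made to
    # produce a runnable example; steps 3-6 mirror the removed lines above.
    def __init__(self, d_model, seq_len, dropout=0.1):
        super().__init__()
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.channel_proj = layers.Dense(d_model * 4, use_bias=True)   # expand to 4*D
        self.sgu_norm = layers.LayerNormalization(epsilon=1e-6)
        self.sgu_proj = layers.Dense(seq_len, use_bias=False)          # mixes tokens along L
        self.sgu_final = layers.Dense(d_model * 2, use_bias=True)      # back to U's width (2*D)
        self.out_proj = layers.Dense(d_model, use_bias=True)
        self.dropout = layers.Dropout(dropout)

    def call(self, x, training=False):
        residual = x                                     # (B, L, D)
        h = self.channel_proj(self.norm(x))              # (B, L, 4*D)  -- assumed steps 1-2
        u, v = tf.split(h, 2, axis=-1)                   # each (B, L, 2*D)
        v_norm_T = tf.transpose(self.sgu_norm(v), perm=[0, 2, 1])          # (B, 2*D, L)
        v_proj_T = tf.transpose(self.sgu_proj(v_norm_T), perm=[0, 2, 1])   # (B, L, 2*D)
        z = tf.nn.gelu(u) * self.sgu_final(v_proj_T)     # gate U with the token-mixed V
        z = self.dropout(z, training=training)
        return residual + self.out_proj(z)               # (B, L, D)

# y = GMLPBlockSketch(d_model=160, seq_len=64)(tf.random.normal([2, 64, 160]))

The Dense(seq_len) applied after the transpose is what mixes information across token positions; in this commit that role is handed back to attention in the Transformer blocks added below.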
@@ -343,7 +284,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 
 with strategy.scope():
     # ⚠️ Fix: use the defined vocab_size instead of chat_vocab_size
-    chat_model =
                              input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
 
     dummy_input = {
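The hunk header above names create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9), whose body is unchanged by this commit and not shown. A plausible sketch, assuming it simply wraps Keras ExponentialDecay (the staircase flag is likewise an assumption):

import tensorflow as tf

def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
    # Assumed body: exponential learning-rate decay with the given arguments.
    return tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_lr,   # starting learning rate
        decay_steps=decay_steps,            # steps between decay applications
        decay_rate=decay_rate,              # multiplicative decay factor
        staircase=True)                     # assumption: discrete (staircase) decay

# optimizer = tf.keras.optimizers.Adam(learning_rate=create_lr_schedule())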
@@ -183,136 +183,77 @@ class SwiGLU(layers.Layer):
         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
         return self.out(x_val * tf.nn.silu(x_gate))
 
 
+class SwiGLU(layers.Layer):
+    def __init__(self, d_model, d_ff):
+        super().__init__()
+        self.proj = layers.Dense(d_ff*2)
+        self.out = layers.Dense(d_model)
+    def call(self, x):
+        x_proj = self.proj(x)
+        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
+        return self.out(x_val * tf.nn.silu(x_gate))
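The added SwiGLU projects to 2*d_ff channels, splits the result into a value half and a gate half, multiplies the value by SiLU(gate), and maps back to d_model. A quick shape check, assuming the class above is in scope (the sizes are illustrative, not the script's settings):

import tensorflow as tf
from tensorflow.keras import layers

swiglu = SwiGLU(d_model=160, d_ff=640)        # proj -> 1280 channels, split into value/gate
x = tf.random.normal([2, 16, 160])            # (batch, seq_len, d_model)
y = swiglu(x)
print(y.shape)                                # (2, 16, 160): gated value mapped back to d_model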
 
+class EncoderBlock(layers.Layer):
+    def __init__(self, d_model, num_heads, dff, dropout=0.1):
         super().__init__()
+        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
+        self.ffn = SwiGLU(d_model, dff)
+        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
+        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
+        self.dropout1 = layers.Dropout(dropout)
+        self.dropout2 = layers.Dropout(dropout)
+    def call(self, x, mask=None, training=False):
+        attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
+        out1 = self.norm1(x + attn_out)
+        ffn_out = self.dropout2(self.ffn(out1), training=training)
+        return self.norm2(out1 + ffn_out)
+
+class DecoderBlock(layers.Layer):
+    def __init__(self, d_model, num_heads, dff, dropout=0.1):
         super().__init__()
+        self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
+        self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
+        self.ffn = SwiGLU(d_model, dff)
+        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
+        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
+        self.norm3 = layers.LayerNormalization(epsilon=1e-6)
+        self.dropout1 = layers.Dropout(dropout)
+        self.dropout2 = layers.Dropout(dropout)
+        self.dropout3 = layers.Dropout(dropout)
+    def call(self, x, enc_out, training=False):
+        attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
+        out1 = self.norm1(x + attn1)
+        attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
+        out2 = self.norm2(out1 + attn2)
+        ffn_out = self.dropout3(self.ffn(out2), training=training)
+        return self.norm3(out2 + ffn_out)
+
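The added DecoderBlock relies on use_causal_mask=True in Keras MultiHeadAttention for autoregressive masking, while EncoderBlock takes an optional padding mask. A small wiring check with illustrative shapes, assuming the two classes above are in scope:

import tensorflow as tf

enc_block = EncoderBlock(d_model=160, num_heads=8, dff=640)   # sizes are illustrative
dec_block = DecoderBlock(d_model=160, num_heads=8, dff=640)

src = tf.random.normal([2, 20, 160])          # (batch, src_len, d_model)
tgt = tf.random.normal([2, 12, 160])          # (batch, tgt_len, d_model)

memory = enc_block(src, training=False)       # bidirectional self-attention + SwiGLU FFN
out = dec_block(tgt, memory, training=False)  # causal self-attention, then cross-attention
print(memory.shape, out.shape)                # (2, 20, 160) (2, 12, 160)

Note that key_dim=d_model gives every head a full-width projection; the more common setting is key_dim=d_model // num_heads, which keeps the attention parameter count in line with a standard Transformer layer.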
+class Transformer(tf.keras.Model):
+    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
         super().__init__()
         self.max_len = max_len
         self.d_model = d_model
         self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
         self.enc_pos_embedding = layers.Embedding(max_len, d_model)
         self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
         self.dec_pos_embedding = layers.Embedding(max_len, d_model)
+        self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
+        self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
+        self.final_layer = layers.Dense(target_vocab_size)
     def call(self, inputs, training=False):
+        enc_inputs = inputs["enc_inputs"]
         dec_inputs = inputs["dec_inputs"]
         enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
         dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
         x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
         for layer in self.enc_layers: x = layer(x, training=training)
+        enc_out = x
         y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
+        for layer in self.dec_layers: y = layer(y, enc_out, training=training)
         return self.final_layer(y)
 
+
 # 5) Training setup and execution
 # =======================
 
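Transformer.call builds position ids with tf.range(seq_len)[tf.newaxis, :], a (1, L) tensor; the positional Embedding lookup then yields (1, L, d_model), which broadcasts against the (B, L, d_model) token embeddings. Sequences longer than max_len would index past the positional table. A small illustration of the broadcast (sizes are illustrative):

import tensorflow as tf
from tensorflow.keras import layers

d_model, max_len = 160, 100
tok_emb = layers.Embedding(1000, d_model)
pos_emb = layers.Embedding(max_len, d_model)

tokens = tf.random.uniform([4, 30], maxval=1000, dtype=tf.int32)   # (B, L) token ids
pos = tf.range(tf.shape(tokens)[1])[tf.newaxis, :]                 # (1, L) position ids

x = tok_emb(tokens) + pos_emb(pos)   # (4, 30, 160) + (1, 30, 160) -> broadcasts to (4, 30, 160)
print(x.shape)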
@@ -343,7 +284,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
 
 with strategy.scope():
     # ⚠️ Fix: use the defined vocab_size instead of chat_vocab_size
+    chat_model = Transformer(num_layers=4, d_model=160, num_heads=8,
                              input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
 
     dummy_input = {
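The dummy_input literal is cut off in this view; per Transformer.call, its keys are "enc_inputs" and "dec_inputs". Note also that the new __init__ has no default for dff, so the instantiation above still needs a dff argument. A hedged sketch of building the model with a dummy forward pass, assuming vocab_size and max_len are defined earlier in the script (the dff value and shapes below are assumptions, not the file's actual values):

import tensorflow as tf

# Instantiate the model; dff=640 is an assumed value since the diff omits it.
chat_model = Transformer(num_layers=4, d_model=160, num_heads=8, dff=640,
                         input_vocab_size=vocab_size, target_vocab_size=vocab_size,
                         max_len=max_len)

dummy_input = {
    "enc_inputs": tf.zeros([1, max_len], dtype=tf.int32),   # hypothetical dummy batch
    "dec_inputs": tf.zeros([1, max_len], dtype=tf.int32),
}
_ = chat_model(dummy_input, training=False)   # one forward pass builds all weights
chat_model.summary()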