Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +2 -4
AlphaS2S.py
CHANGED
|
@@ -230,7 +230,7 @@ class DecoderBlock(layers.Layer):
|
|
| 230 |
return self.norm3(out2 + ffn_out)
|
| 231 |
|
| 232 |
class Transformer(tf.keras.Model):
|
| 233 |
-
def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=
|
| 234 |
super().__init__()
|
| 235 |
self.max_len = max_len
|
| 236 |
self.d_model = d_model
|
|
@@ -253,7 +253,6 @@ class Transformer(tf.keras.Model):
|
|
| 253 |
for layer in self.dec_layers: y = layer(y, enc_out, training=training)
|
| 254 |
return self.final_layer(y)
|
| 255 |
|
| 256 |
-
|
| 257 |
# 5) νμ΅ μ€μ λ° μ€ν
|
| 258 |
# =======================
|
| 259 |
|
|
@@ -284,8 +283,7 @@ def create_lr_schedule(initial_lr=5e-5, decay_steps=10000, decay_rate=0.9):
|
|
| 284 |
|
| 285 |
with strategy.scope():
|
| 286 |
# β οΈ μμ : chat_vocab_size λμ μ μλ vocab_size μ¬μ©
|
| 287 |
-
chat_model = Transformer(num_layers=4, d_model=
|
| 288 |
-
input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=max_len)
|
| 289 |
|
| 290 |
dummy_input = {
|
| 291 |
"enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
|
|
|
|
| 230 |
return self.norm3(out2 + ffn_out)
|
| 231 |
|
| 232 |
class Transformer(tf.keras.Model):
|
| 233 |
+
def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=256, dropout=0.1):
|
| 234 |
super().__init__()
|
| 235 |
self.max_len = max_len
|
| 236 |
self.d_model = d_model
|
|
|
|
| 253 |
for layer in self.dec_layers: y = layer(y, enc_out, training=training)
|
| 254 |
return self.final_layer(y)
|
| 255 |
|
|
|
|
| 256 |
# 5) νμ΅ μ€μ λ° μ€ν
|
| 257 |
# =======================
|
| 258 |
|
|
|
|
| 283 |
|
| 284 |
with strategy.scope():
|
| 285 |
# β οΈ μμ : chat_vocab_size λμ μ μλ vocab_size μ¬μ©
|
| 286 |
+
chat_model = Transformer(num_layers=4, d_model=512, num_heads=8, dff=2048, input_vocab_size=vocab_size, target_vocab_size=vocab_size, max_len=256, dropout=0.1)
|
|
|
|
| 287 |
|
| 288 |
dummy_input = {
|
| 289 |
"enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
|