Yuchan committed
Update AlphaS2S.py

AlphaS2S.py  CHANGED  (+6 -2)
@@ -166,13 +166,15 @@ class EncoderBlock(layers.Layer):
         super().__init__()
         self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.ffn = SwiGLU(d_model, dff)
+        self.proj = layers.Dense(d_model)
         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
         self.dropout1 = layers.Dropout(dropout)
         self.dropout2 = layers.Dropout(dropout)
     def call(self, x, mask=None, training=False):
+        x = self.proj(x)
         attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
-        out1 = self.norm1(attn_out)
+        out1 = self.norm1(attn_out + x)
         ffn_out = self.dropout2(self.ffn(out1), training=training)
         return self.norm2(out1 + ffn_out)

@@ -182,6 +184,7 @@ class DecoderBlock(layers.Layer):
         self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.ffn = SwiGLU(d_model, dff)
+        self.proj = layers.Dense(d_model)
         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
         self.norm3 = layers.LayerNormalization(epsilon=1e-6)
@@ -189,8 +192,9 @@ class DecoderBlock(layers.Layer):
         self.dropout2 = layers.Dropout(dropout)
         self.dropout3 = layers.Dropout(dropout)
     def call(self, x, enc_out, training=False):
+        x = self.proj(x)
         attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
-        out1 = self.norm1(attn1)
+        out1 = self.norm1(attn1 + x)
         attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
         out2 = self.norm2(out1 + attn2)
         ffn_out = self.dropout3(self.ffn(out2), training=training)
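For reference, below is a minimal runnable sketch of EncoderBlock as it stands after this commit: the input is first passed through the new proj Dense layer and then added back as a residual around self-attention (the DecoderBlock change follows the same pattern). The SwiGLU definition and the constructor signature (d_model, num_heads, dff, dropout) are assumptions reconstructed from the hunk context, not part of the diff.

# Minimal sketch of the post-commit EncoderBlock. Only the lines shown in the
# diff are verbatim; SwiGLU and the constructor signature are assumptions.
import tensorflow as tf
from tensorflow.keras import layers

class SwiGLU(layers.Layer):
    # Assumed stand-in for the repo's SwiGLU: gated feed-forward with dff hidden
    # units, projected back to d_model so its output can be added residually.
    def __init__(self, d_model, dff):
        super().__init__()
        self.w_gate = layers.Dense(dff)
        self.w_up = layers.Dense(dff)
        self.w_down = layers.Dense(d_model)

    def call(self, x):
        return self.w_down(tf.nn.silu(self.w_gate(x)) * self.w_up(x))

class EncoderBlock(layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout=0.1):  # assumed signature
        super().__init__()
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = SwiGLU(d_model, dff)
        self.proj = layers.Dense(d_model)      # added in this commit
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        x = self.proj(x)                       # added: project input to d_model
        attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
        out1 = self.norm1(attn_out + x)        # added residual: attention output + projected input
        ffn_out = self.dropout2(self.ffn(out1), training=training)
        return self.norm2(out1 + ffn_out)

# Quick shape check with assumed dimensions:
# blk = EncoderBlock(d_model=128, num_heads=4, dff=256)
# y = blk(tf.random.normal([2, 10, 128]))     # -> (2, 10, 128)

In effect, the commit turns the first sub-layer from a plain post-norm attention, norm1(attn_out), into a post-norm residual block, norm1(attn_out + x), with proj ensuring the input's last dimension matches d_model so the addition is well defined.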