Yuchan committed
Commit 859ea70 · verified · 1 parent: 1bf639d

Update AlphaS2S.py

Files changed (1): AlphaS2S.py (+6, -2)
AlphaS2S.py CHANGED
@@ -166,13 +166,15 @@ class EncoderBlock(layers.Layer):
         super().__init__()
         self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.ffn = SwiGLU(d_model, dff)
+        self.proj = layers.Dense(d_model)
         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
         self.dropout1 = layers.Dropout(dropout)
         self.dropout2 = layers.Dropout(dropout)
     def call(self, x, mask=None, training=False):
+        x = self.proj(x)
         attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
-        out1 = self.norm1(attn_out)
+        out1 = self.norm1(attn_out + x)
         ffn_out = self.dropout2(self.ffn(out1), training=training)
         return self.norm2(out1 + ffn_out)

@@ -182,6 +184,7 @@ class DecoderBlock(layers.Layer):
         self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.ffn = SwiGLU(d_model, dff)
+        self.proj = layers.Dense(d_model)
         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
         self.norm3 = layers.LayerNormalization(epsilon=1e-6)
@@ -189,8 +192,9 @@ class DecoderBlock(layers.Layer):
         self.dropout2 = layers.Dropout(dropout)
         self.dropout3 = layers.Dropout(dropout)
     def call(self, x, enc_out, training=False):
+        x = self.proj(x)
         attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
-        out1 = self.norm1(attn1)
+        out1 = self.norm1(attn1 + x)
         attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
         out2 = self.norm2(out1 + attn2)
         ffn_out = self.dropout3(self.ffn(out2), training=training)
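
In substance, the commit adds a Dense projection self.proj to both EncoderBlock and DecoderBlock and wires it into a residual connection around the first attention sub-layer: out1 = norm1(attn_out + x) instead of norm1(attn_out). Projecting x to d_model first makes that residual add shape-compatible even when the block's input width differs from d_model. Below is a minimal runnable sketch of the two blocks after this commit. The constructor signature (d_model, num_heads, dff, dropout), the SwiGLU stand-in, and the decoder's final return line are assumptions filled in for the sketch; only the lines marked "changed in this commit" come directly from the diff above.

# Minimal sketch, not the full AlphaS2S.py. The constructor arguments and the
# SwiGLU stand-in are assumptions made for this sketch; only the lines marked
# "changed in this commit" are taken directly from the diff above.
import tensorflow as tf
from tensorflow.keras import layers


class SwiGLU(layers.Layer):
    # Assumed stand-in: SiLU-gated feed-forward block projected back to d_model.
    def __init__(self, d_model, dff):
        super().__init__()
        self.gate = layers.Dense(dff, activation="silu")
        self.up = layers.Dense(dff)
        self.down = layers.Dense(d_model)

    def call(self, x):
        return self.down(self.gate(x) * self.up(x))


class EncoderBlock(layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = SwiGLU(d_model, dff)
        self.proj = layers.Dense(d_model)  # changed in this commit
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        x = self.proj(x)  # changed in this commit: lift input to d_model
        attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
        out1 = self.norm1(attn_out + x)  # changed in this commit: residual around self-attention
        ffn_out = self.dropout2(self.ffn(out1), training=training)
        return self.norm2(out1 + ffn_out)


class DecoderBlock(layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = SwiGLU(d_model, dff)
        self.proj = layers.Dense(d_model)  # changed in this commit
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.norm3 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)

    def call(self, x, enc_out, training=False):
        x = self.proj(x)  # changed in this commit
        attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
        out1 = self.norm1(attn1 + x)  # changed in this commit: residual around self-attention
        attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
        out2 = self.norm2(out1 + attn2)
        ffn_out = self.dropout3(self.ffn(out2), training=training)
        return self.norm3(out2 + ffn_out)  # assumed completion: final norm not shown in the diff


# Quick shape check with hypothetical sizes: inputs narrower than d_model now
# work because self.proj lifts them to d_model before the residual add.
enc = EncoderBlock(d_model=256, num_heads=4, dff=1024)
dec = DecoderBlock(d_model=256, num_heads=4, dff=1024)
memory = enc(tf.random.normal((2, 50, 80)))        # -> (2, 50, 256)
out = dec(tf.random.normal((2, 20, 64)), memory)   # -> (2, 20, 256)

Before this change the attention output was normalized without the residual term; the new form matches the standard post-norm Transformer sub-layer, LayerNorm(x + Sublayer(x)).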