jbochi commited on
Commit
1765fef
1 Parent(s): a4dd515

Fix skip connection

Browse files
Files changed (1) hide show
  1. decoder_only_t5/modeling.py +3 -3
decoder_only_t5/modeling.py CHANGED
@@ -532,9 +532,9 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
532
 
533
  if self.parallel_layers:
534
  # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
535
- hidden_states = x + ff_output
536
- hidden_states *= 2**-0.5
537
- hidden_states = hidden_states + self.layer[0].dropout(hidden_states)
538
  else:
539
  hidden_states = ff_layer(x)
540
 
 
532
 
533
  if self.parallel_layers:
534
  # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
535
+ x = x + ff_output
536
+ x *= 2**-0.5
537
+ hidden_states = hidden_states + self.layer[0].dropout(x)
538
  else:
539
  hidden_states = ff_layer(x)
540