valhalla committed on
Commit
7774483
1 Parent(s): 2856356

fix layernorm

Browse files
Files changed (1) hide show
  1. dalle_mini/modeling_bart_flax.py +1 -1
dalle_mini/modeling_bart_flax.py CHANGED
@@ -329,7 +329,7 @@ class FlaxBartDecoderLayer(nn.Module):
329
  dropout=self.config.attention_dropout,
330
  dtype=self.dtype,
331
  )
332
- self.encoder_attn_layer_norm = nn
333
  self.fc1 = nn.Dense(
334
  self.config.encoder_ffn_dim,
335
  dtype=self.dtype,
 
329
  dropout=self.config.attention_dropout,
330
  dtype=self.dtype,
331
  )
332
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
333
  self.fc1 = nn.Dense(
334
  self.config.encoder_ffn_dim,
335
  dtype=self.dtype,