boris committed
Commit 02824a7
Parent: d9a16f2

feat: remove unnecessary LN

Files changed (1):
  1. src/dalle_mini/model/modeling.py +3 -2
src/dalle_mini/model/modeling.py

@@ -736,9 +736,10 @@ class FlaxBartEncoderLayerCollection(nn.Module):
                 all_hidden_states += (hidden_states,)
             # final layernorm on the output of the last layer
             # or every 6 layers for Swin v2
+            # not needed for other models which use layernorm before x-attention
             # ignored args for deepnet which always add a norm with scale
-            add_norm = (i == n_layers - 1) or (
-                (self.config.ln_positions == "swinv2") and ((i + 1) % 6 == 0)
+            add_norm = self.config.ln_positions == "swinv2" and (
+                (i == n_layers - 1) or ((i + 1) % 6 == 0)
             )
             # we don't need to scale the norm for the last layer
             use_scale = i != n_layers - 1
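
What changed: before this commit, `add_norm` inserted a final LayerNorm on the last layer for every configuration, plus every 6 layers for Swin v2. Afterwards the extra norm is only inserted when `ln_positions == "swinv2"`; per the new comment, other configurations already apply a layernorm before x-attention, so the trailing norm was redundant. `use_scale` is unchanged: the last layer's norm carries no learned scale.

Below is a minimal, self-contained sketch of how these two flags plausibly drive the norm inside the encoder loop. It is not the dalle_mini source: `SketchConfig`, the `nn.Dense` stand-in for the real `FlaxBartEncoderLayer`, and every config field except `ln_positions` are assumptions for illustration only.

```python
from dataclasses import dataclass

import flax.linen as nn
import jax
import jax.numpy as jnp


@dataclass
class SketchConfig:
    # hypothetical config; `ln_positions` mirrors self.config.ln_positions in the diff
    ln_positions: str = "swinv2"
    n_layers: int = 12
    d_model: int = 64


class SketchEncoderLayers(nn.Module):
    config: SketchConfig

    @nn.compact
    def __call__(self, hidden_states):
        n_layers = self.config.n_layers
        for i in range(n_layers):
            # stand-in for the real encoder layer
            hidden_states = nn.Dense(self.config.d_model)(hidden_states)
            # post-commit logic: a trailing LayerNorm only for Swin v2-style
            # placement, on the last layer or every 6th layer
            add_norm = self.config.ln_positions == "swinv2" and (
                (i == n_layers - 1) or ((i + 1) % 6 == 0)
            )
            # no learned scale on the last layer's norm
            use_scale = i != n_layers - 1
            if add_norm:
                hidden_states = nn.LayerNorm(use_scale=use_scale)(hidden_states)
        return hidden_states


# usage: run a dummy batch through the sketch
model = SketchEncoderLayers(SketchConfig())
x = jnp.ones((1, 16, 64))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)  # shape (1, 16, 64)
```

Note that with any `ln_positions` other than `"swinv2"`, the sketch adds no trailing norm at all, which is exactly the behavior this commit introduces.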