feat: remove unnecessary LN
src/dalle_mini/model/modeling.py
CHANGED
@@ -736,9 +736,10 @@ class FlaxBartEncoderLayerCollection(nn.Module):
             all_hidden_states += (hidden_states,)
             # final layernorm on the output of the last layer
             # or every 6 layers for Swin v2
+            # not needed for other models which use layernorm before x-attention
             # ignored args for deepnet which always add a norm with scale
-            add_norm =
-            (
+            add_norm = self.config.ln_positions == "swinv2" and (
+                (i == n_layers - 1) or ((i + 1) % 6 == 0)
             )
             # we don't need to scale the norm for the last layer
             use_scale = i != n_layers - 1
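For context, the flag computed in this hunk decides whether an extra LayerNorm is appended after a given encoder layer. Below is a minimal standalone sketch of that selection logic; the LNConfig dataclass, the norm_schedule helper, and the 12-layer example are illustrative assumptions, not the actual FlaxBartEncoderLayerCollection code.

from dataclasses import dataclass

@dataclass
class LNConfig:
    # hypothetical stand-in for the model config; only the field used here
    ln_positions: str = "swinv2"

def norm_schedule(config: LNConfig, n_layers: int):
    """Yield (layer index, add_norm, use_scale) following the Swin v2 rule:
    an extra LayerNorm after every 6th layer and after the last layer,
    with the last layer's norm left unscaled."""
    for i in range(n_layers):
        add_norm = config.ln_positions == "swinv2" and (
            (i == n_layers - 1) or ((i + 1) % 6 == 0)
        )
        # we don't need to scale the norm for the last layer
        use_scale = i != n_layers - 1
        yield i, add_norm, use_scale

# Example: a 12-layer encoder gets extra norms after layers 5 and 11 (0-indexed),
# and only the layer-11 norm is unscaled.
for i, add_norm, use_scale in norm_schedule(LNConfig(), 12):
    if add_norm:
        print(f"layer {i}: extra LayerNorm (scaled={use_scale})")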