Shuming Ma Shuming Ma commited on
Commit
503d6b4
1 Parent(s): 02824a7

fix: DeepNet doesn't scale weights of embedding/output layers (#150)

Browse files
Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +5 -15
src/dalle_mini/model/modeling.py CHANGED
@@ -883,9 +883,7 @@ class FlaxBartEncoder(FlaxBartEncoder):
883
  self.embed_positions = nn.Embed(
884
  self.config.max_text_length + self.offset,
885
  embed_dim,
886
- embedding_init=deepnet_init()
887
- if self.config.use_deepnet_scaling
888
- else jax.nn.initializers.normal(self.config.init_std),
889
  )
890
  self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
891
  self.layernorm_embedding = norm(
@@ -917,9 +915,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
917
  self.embed_positions = nn.Embed(
918
  self.config.image_length + self.offset, # image length for BOS
919
  embed_dim,
920
- embedding_init=deepnet_init()
921
- if self.config.use_deepnet_scaling
922
- else jax.nn.initializers.normal(self.config.init_std),
923
  )
924
 
925
  self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
@@ -939,16 +935,12 @@ class FlaxBartModule(FlaxBartModule):
939
  encoder_embed_tokens = nn.Embed(
940
  self.config.encoder_vocab_size,
941
  self.config.d_model,
942
- embedding_init=deepnet_init()
943
- if self.config.use_deepnet_scaling
944
- else jax.nn.initializers.normal(self.config.init_std),
945
  )
946
  decoder_embed_tokens = nn.Embed(
947
  self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
948
  self.config.d_model,
949
- embedding_init=deepnet_init()
950
- if self.config.use_deepnet_scaling
951
- else jax.nn.initializers.normal(self.config.init_std),
952
  )
953
 
954
  self.encoder = FlaxBartEncoder(
@@ -1288,9 +1280,7 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
1288
  + 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
1289
  use_bias=False,
1290
  dtype=self.dtype,
1291
- kernel_init=deepnet_init()
1292
- if self.config.use_deepnet_scaling
1293
- else jax.nn.initializers.normal(self.config.init_std),
1294
  )
1295
 
1296
  def __call__(
 
883
  self.embed_positions = nn.Embed(
884
  self.config.max_text_length + self.offset,
885
  embed_dim,
886
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
887
  )
888
  self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
889
  self.layernorm_embedding = norm(
 
915
  self.embed_positions = nn.Embed(
916
  self.config.image_length + self.offset, # image length for BOS
917
  embed_dim,
918
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
919
  )
920
 
921
  self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
 
935
  encoder_embed_tokens = nn.Embed(
936
  self.config.encoder_vocab_size,
937
  self.config.d_model,
938
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
939
  )
940
  decoder_embed_tokens = nn.Embed(
941
  self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
942
  self.config.d_model,
943
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
944
  )
945
 
946
  self.encoder = FlaxBartEncoder(
 
1280
  + 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
1281
  use_bias=False,
1282
  dtype=self.dtype,
1283
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
 
1284
  )
1285
 
1286
  def __call__(