boris committed on
Commit
5bd4c20
1 Parent(s): 503d6b4

feat: allow more configurations

src/dalle_mini/model/configuration.py CHANGED
@@ -58,13 +58,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tie_word_embeddings=False, # different modalities and sizes
         do_sample=True,
         # transformer variants
-        head_scale=False, # used in NormFormer
         ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
-        ln_positions="deepnet", # layer normalization positions, "normformer", "swinv2", "cogview", "deepnet" (same as post-ln)
+        ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
+        head_scale=True, # used in NormFormer
         use_cosine_attention=False, # used in Swin v2
         tau_init=0.05, # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False, # used in Deepnet
-        use_glu=False, # "GLU Variants Improve Transformer"
+        use_glu=True, # "GLU Variants Improve Transformer"
+        use_all_scale=True, # use scale in layernorm even when seemingly unnecessary
         **kwargs,
     ):
         # text normalizer
@@ -83,11 +84,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
             "cogview",
             "deepnet",
         ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
+        if ln_positions == "deepnet":
+            ln_positions = "postln"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
         self.tau_init = tau_init
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
+        self.use_all_scale = use_all_scale

         # common parameters
         self.encoder_vocab_size = encoder_vocab_size
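For reference, a minimal sketch of how the options touched by this commit can be passed when constructing a config. This is illustrative only: the import path is inferred from the file path above, and all remaining constructor arguments are assumed to keep their defaults.

from dalle_mini.model.configuration import DalleBartConfig

# Hypothetical example: only the transformer-variant options changed or added
# by this commit are passed explicitly; everything else uses the defaults.
config = DalleBartConfig(
    ln_type="layernorm",        # or "rmsnorm"
    ln_positions="normformer",  # "swinv2", "cogview", "postln", "deepnet" also accepted
    head_scale=True,            # NormFormer head scaling
    use_glu=True,               # "GLU Variants Improve Transformer"
    use_all_scale=True,         # keep the scale parameter in every layernorm
)

# "deepnet" is still accepted and is remapped to "postln" in __init__.
assert DalleBartConfig(ln_positions="deepnet").ln_positions == "postln"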
src/dalle_mini/model/modeling.py CHANGED
@@ -375,7 +375,10 @@ class GLU(nn.Module):

         if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
-                self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         w = nn.Dense(
             self.ffn_dim,
@@ -397,7 +400,10 @@ class GLU(nn.Module):
         x = w * v
         if self.config.ln_positions in ["normformer"]:
             x = norm(
-                self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -434,7 +440,10 @@ class FFN(nn.Module):
         )
         if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
-                self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dense(
             self.ffn_dim,
@@ -447,7 +456,10 @@ class FFN(nn.Module):
         x = ACT2FN[self.config.activation_function](x)
         if self.config.ln_positions in ["normformer"]:
             x = norm(
-                self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -495,10 +507,13 @@ class FlaxBartEncoderLayer(nn.Module):

         embed_dim = self.config.d_model
         residual = hidden_states
-        if self.config.ln_positions in ["normformer"]:
-            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
-                hidden_states
-            )
+        if self.config.ln_positions in ["normformer", "cogview"]:
+            hidden_states = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
+            )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
             embed_dim=embed_dim,
@@ -509,7 +524,7 @@ class FlaxBartEncoderLayer(nn.Module):
             is_encoder=True,
         )(hidden_states=hidden_states, attention_mask=attention_mask)

-        if self.config.ln_positions in ["normformer", "swinv2"]:
+        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -517,7 +532,7 @@ class FlaxBartEncoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["deepnet"]:
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -542,8 +557,12 @@ class FlaxBartEncoderLayer(nn.Module):
         )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["deepnet"]:
-            use_scale = self.use_scale or self.config.ln_positions == "deepnet"
+        if self.add_norm or self.config.ln_positions in ["postln"]:
+            use_scale = (
+                self.use_scale
+                or self.config.ln_positions == "postln"
+                or self.config.use_all_scale
+            )
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
@@ -598,7 +617,7 @@ class FlaxBartDecoderLayer(nn.Module):
                 self.config.ln_type,
                 dtype=self.dtype,
                 epsilon=1e-05,
-                use_scale=False,
+                use_scale=self.config.use_all_scale,
             )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -623,7 +642,7 @@ class FlaxBartDecoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["deepnet"]:
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -637,7 +656,7 @@ class FlaxBartDecoderLayer(nn.Module):
                 self.config.ln_type,
                 dtype=self.dtype,
                 epsilon=1e-05,
-                use_scale=False,
+                use_scale=self.config.use_all_scale,
             )(hidden_states)
         hidden_states, cross_attn_weights = FlaxBartAttention(
             config=self.config,
@@ -660,7 +679,7 @@ class FlaxBartDecoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["deepnet"]:
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(
                 self.config.ln_type, dtype=self.dtype, epsilon=1e-05
             )(hidden_states)
@@ -686,8 +705,12 @@ class FlaxBartDecoderLayer(nn.Module):
         )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["deepnet"]:
-            use_scale = self.use_scale or self.config.ln_positions == "deepnet"
+        if self.add_norm or self.config.ln_positions in ["postln"]:
+            use_scale = (
+                self.use_scale
+                or self.config.ln_positions == "postln"
+                or self.config.use_all_scale
+            )
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
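The norm factory itself is not part of this diff. Based on how it is called throughout modeling.py, norm(ln_type, dtype=..., epsilon=..., use_scale=...)(x), it presumably dispatches between a standard LayerNorm and an RMSNorm depending on ln_type. A hedged sketch of such a helper, written against Flax's public API rather than the repository's actual implementation:

import flax.linen as nn

def norm(ln_type, *, dtype, epsilon=1e-05, use_scale=True):
    # Hypothetical sketch only: the real helper lives elsewhere in
    # src/dalle_mini/model/modeling.py and may define its own RMSNorm.
    if ln_type == "rmsnorm":
        # flax.linen.RMSNorm is available in recent Flax releases.
        return nn.RMSNorm(dtype=dtype, epsilon=epsilon, use_scale=use_scale)
    # "layernorm": use_scale=False drops the learned gain, which is what the
    # old code did after NormFormer/CogView sub-layer norms; with
    # use_all_scale=True those norms now keep it.
    return nn.LayerNorm(dtype=dtype, epsilon=epsilon, use_scale=use_scale)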