boris committed
Commit a11892f
1 Parent(s): f234ccf

fix(model): use correct params

dalle_mini/configuration_bart.py CHANGED
@@ -21,16 +21,11 @@ from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
-BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
-    # See all BART models at https://huggingface.co/models?filter=bart
-}
 
-
-class BartConfig(PretrainedConfig):
+class DalleBartConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a :class:`~transformers.BartModel`. It is used to
-    instantiate a BART model according to the specified arguments, defining the model architecture. Instantiating a
+    This is the configuration class to store the configuration of a `DalleBartModel`. It is used to
+    instantiate a DalleBart model according to the specified arguments, defining the model architecture. Instantiating a
     configuration with the defaults will yield a similar configuration to that of the BART `facebook/bart-large
     <https://huggingface.co/facebook/bart-large>`__ architecture.
 
@@ -39,7 +34,7 @@ class BartConfig(PretrainedConfig):
 
 
     Args:
-        vocab_size (:obj:`int`, `optional`, defaults to 50265):
+        encoder_vocab_size (:obj:`int`, `optional`, defaults to 50265):
            Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.BartModel` or
            :class:`~transformers.TFBartModel`.
@@ -90,30 +85,18 @@ class BartConfig(PretrainedConfig):
        forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
            The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
            :obj:`eos_token_id`.
-
-    Example::
-
-        >>> from transformers import BartModel, BartConfig
-
-        >>> # Initializing a BART facebook/bart-large style configuration
-        >>> configuration = BartConfig()
-
-        >>> # Initializing a model from the facebook/bart-large style configuration
-        >>> model = BartModel(configuration)
-
-        >>> # Accessing the model configuration
-        >>> configuration = model.config
    """
-    model_type = "bart"
+    model_type = "dallebart"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
 
    def __init__(
        self,
-        vocab_size=50265,
-        decoder_vocab_size=16384 + 1,  # encoded image token space + 1 for bos
-        max_position_embeddings=1024,
-        decoder_max_position_embeddings=256 + 1,  # number of encoded tokens + 1 for bos,
+        normalize_text=False,
+        encoder_vocab_size=50264,
+        image_vocab_size=16384,  # encoded image token space
+        image_length=256,  # number of encoded tokens
+        max_text_length=64,  # max number of text tokens
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
@@ -133,19 +116,16 @@ class BartConfig(PretrainedConfig):
        gradient_checkpointing=False,
        use_cache=True,
        num_labels=3,
-        pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
        is_encoder_decoder=True,
-        decoder_start_token_id=16384,
-        forced_eos_token_id=2,
-        tie_word_embeddings=False,  # don't tie for scaling reasons
+        forced_eos_token_id=None,
+        tie_word_embeddings=False,  # don't tie for scaling reasons and due to different modalities and sizes
        **kwargs,
    ):
-        self.vocab_size = vocab_size
-        self.decoder_vocab_size = decoder_vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.decoder_max_position_embeddings = decoder_max_position_embeddings
+        self.normalize_text = normalize_text
+        self.encoder_vocab_size = encoder_vocab_size
+        self.decoder_vocab_size = image_vocab_size
+        self.image_length = image_length
+        self.max_text_length = max_text_length
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
@@ -165,12 +145,15 @@ class BartConfig(PretrainedConfig):
        self.num_hidden_layers = encoder_layers
        self.gradient_checkpointing = gradient_checkpointing
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.decoder_start_token_id = image_vocab_size  # BOS appended to vocab
+        self.min_length = image_length + 1
+        self.max_length = image_length + 1
 
        super().__init__(
            num_labels=num_labels,
+            pad_token_id=image_vocab_size + 1,  # needed to avoid errors during generation (converted to jnp.array)
+            bos_token_id=image_vocab_size + 1,  # set to unreachable values
+            eos_token_id=image_vocab_size + 1,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
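
The renamed parameters make the image-token layout explicit. As a quick reference, here is a short sketch (illustrative only, not part of the commit) of the arithmetic behind the values derived in __init__ above; the numbers are the defaults from the diff.

# Illustrative sketch, not from the commit: token-id layout implied by the new
# config defaults (image_vocab_size=16384, image_length=256).
image_vocab_size = 16384   # encoded image token space
image_length = 256         # number of encoded image tokens per sample

decoder_start_token_id = image_vocab_size                          # BOS appended to the image vocab
pad_token_id = bos_token_id = eos_token_id = image_vocab_size + 1  # deliberately unreachable ids
min_length = max_length = image_length + 1                         # BOS + 256 image tokens

assert (decoder_start_token_id, pad_token_id, max_length) == (16384, 16385, 257)

In other words, the decoder vocabulary has image_vocab_size + 1 usable rows (image codes plus BOS), and generation always produces exactly image_length tokens after the start token.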
dalle_mini/modeling_bart_flax.py CHANGED
@@ -93,7 +93,7 @@ class FlaxBartAttention(nn.Module):
 
        if self.causal:
            self.causal_mask = make_causal_mask(
-                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+                jnp.ones((1, embed_dim), dtype="bool"), dtype="bool"
            )
 
    def _split_heads(self, hidden_states):
@@ -431,11 +431,10 @@ class FlaxBartEncoder(nn.Module):
 
        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
-        self.max_source_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
 
        self.embed_tokens = nn.Embed(
-            self.config.vocab_size,
+            self.config.encoder_vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
@@ -444,7 +443,7 @@ class FlaxBartEncoder(nn.Module):
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 0
        self.embed_positions = nn.Embed(
-            self.config.max_position_embeddings + self.offset,
+            self.config.max_text_length + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
@@ -489,11 +488,10 @@ class FlaxBartDecoder(nn.Module):
 
        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
-        self.max_target_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
 
        self.embed_tokens = nn.Embed(
-            self.config.decoder_vocab_size,
+            self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
@@ -502,7 +500,7 @@ class FlaxBartDecoder(nn.Module):
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 0
        self.embed_positions = nn.Embed(
-            self.config.decoder_max_position_embeddings + self.offset,
+            self.config.image_length + 1 + self.offset,  # image length + 1 for BOS
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
@@ -802,11 +800,14 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
-            self.config.decoder_vocab_size,
+            self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
+        self.final_logits_bias = self.param(
+            "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
+        )
 
    def _get_encoder_module(self):
        return self.model.encoder
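
To make the new sizes concrete, here is a small self-contained Flax sketch (illustrative only, not taken from the repository) that builds decoder-side layers the same way the modified FlaxBartDecoder and LM head do; d_model=1024 is an assumed value and not part of this diff.

# Illustrative sketch, not from the commit: decoder-side layer sizes implied by
# the changes above. d_model is an assumption for this sketch (1024).
import jax
import jax.numpy as jnp
import flax.linen as nn

image_vocab_size = 16384
image_length = 256
d_model = 1024  # assumed hidden size

embed_tokens = nn.Embed(num_embeddings=image_vocab_size + 1, features=d_model)  # image vocab + 1 for BOS
embed_positions = nn.Embed(num_embeddings=image_length + 1, features=d_model)   # positions incl. BOS
lm_head = nn.Dense(features=image_vocab_size + 1, use_bias=False)               # logits over image vocab + BOS

decoder_input_ids = jnp.zeros((1, image_length + 1), dtype=jnp.int32)  # dummy BOS + image tokens
tok_params = embed_tokens.init(jax.random.PRNGKey(0), decoder_input_ids)
hidden = embed_tokens.apply(tok_params, decoder_input_ids)
assert hidden.shape == (1, image_length + 1, d_model)

pos_ids = jnp.arange(image_length + 1)[None, :]
pos_params = embed_positions.init(jax.random.PRNGKey(1), pos_ids)
assert embed_positions.apply(pos_params, pos_ids).shape == (1, image_length + 1, d_model)

head_params = lm_head.init(jax.random.PRNGKey(2), hidden)
logits = lm_head.apply(head_params, hidden)
assert logits.shape == (1, image_length + 1, image_vocab_size + 1)

The shapes mirror the image_vocab_size + 1 and image_length + 1 sizing used above, where the extra row and extra position account for the BOS token appended to the image vocabulary.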