boris committed
Commit: a96f4dc
Parent: a11892f

fix: adjust training script + dataloader

dalle_mini/data.py CHANGED
@@ -15,12 +15,10 @@ class Dataset:
     dataset_repo_or_path: str
     train_file: str = None
     validation_file: str = None
-    dataset_type: str = "dataset"
     streaming: bool = True
     use_auth_token: bool = False
     text_column: str = "caption"
     encoding_column: str = "encoding"
-    max_source_length: int = 128
     max_train_samples: int = None
     max_eval_samples: int = None
     preprocessing_num_workers: int = None
@@ -70,7 +68,7 @@ class Dataset:
             else self.eval_dataset.select(range(self.max_eval_samples))
         )
 
-    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
+    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text, max_length):
         if self.streaming:
             # we need to shuffle early in streaming mode
             if hasattr(self, "train_dataset"):
@@ -112,7 +110,7 @@ class Dataset:
                 tokenizer=tokenizer,
                 text_column=self.text_column,
                 encoding_column=self.encoding_column,
-                max_source_length=self.max_source_length,
+                max_length=max_length,
                 decoder_start_token_id=decoder_start_token_id,
             )
             for ds in ["train_dataset", "eval_dataset"]:
@@ -232,14 +230,14 @@ def preprocess_function(
     tokenizer,
     text_column,
     encoding_column,
-    max_source_length,
+    max_length,
     decoder_start_token_id,
 ):
     inputs = examples[text_column]
     # Setting padding="max_length" as we need fixed length inputs for jitted functions
     model_inputs = tokenizer(
         inputs,
-        max_length=max_source_length,
+        max_length=max_length,
         padding="max_length",
         truncation=True,
         return_tensors="np",

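For reference, the tokenizer call that the threaded-through `max_length` now feeds looks like the following minimal sketch; the tokenizer name, captions, and length value are placeholders, not part of the commit:

```python
from transformers import AutoTokenizer

# hypothetical tokenizer and captions; in the training script, max_length
# comes from model.config.max_text_length rather than a separate CLI argument
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
captions = ["a watercolor painting of a fox", "city skyline at night"]

# padding="max_length" + truncation=True produce fixed-shape (batch, max_length)
# arrays, which jitted JAX functions need to avoid recompiling per batch shape
model_inputs = tokenizer(
    captions,
    max_length=64,
    padding="max_length",
    truncation=True,
    return_tensors="np",
)
print(model_inputs["input_ids"].shape)  # (2, 64)
```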
dalle_mini/model.py DELETED
@@ -1,64 +0,0 @@
-import flax.linen as nn
-import jax
-from transformers import BartConfig
-from transformers.models.bart.modeling_flax_bart import (
-    FlaxBartDecoder,
-    FlaxBartEncoder,
-    FlaxBartForConditionalGeneration,
-    FlaxBartForConditionalGenerationModule,
-    FlaxBartModule,
-)
-
-
-class CustomFlaxBartModule(FlaxBartModule):
-    def setup(self):
-        # we keep shared to easily load pre-trained weights
-        self.shared = nn.Embed(
-            self.config.vocab_size,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std),
-        )
-        # a separate embedding is used for the decoder
-        self.decoder_embed = nn.Embed(
-            self.config.image_vocab_size + 1,
-            self.config.d_model,
-            embedding_init=jax.nn.initializers.normal(self.config.init_std),
-        )
-        self.encoder = FlaxBartEncoder(
-            self.config, dtype=self.dtype, embed_tokens=self.shared
-        )
-
-        # the decoder has a different config
-        # TODO: should not be needed once we have custom config/module
-        decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = (
-            self.config.image_length + 1  # image tokens + BOS
-        )
-        decoder_config.vocab_size = self.config.image_vocab_size + 1
-        self.decoder = FlaxBartDecoder(
-            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
-        )
-
-
-class CustomFlaxBartForConditionalGenerationModule(
-    FlaxBartForConditionalGenerationModule
-):
-    def setup(self):
-        # set default config
-        self.config.normalize_text = getattr(self.config, "normalize_text", False)
-        self.config.image_length = getattr(self.config, "image_length", 256)
-        self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384)
-
-        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
-        self.lm_head = nn.Dense(
-            self.config.image_vocab_size + 1,  # encoded image token space + 1 for bos
-            use_bias=False,
-            kernel_init=jax.nn.initializers.normal(self.config.init_std),
-        )
-        self.final_logits_bias = self.param(
-            "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
-        )
-
-
-class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
-    module_class = CustomFlaxBartForConditionalGenerationModule

dalle_mini/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .configuration import DalleBartConfig
+from .modeling import DalleBartForConditionalGeneration
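With the new package layout, the config and the model are importable from a single namespace. A usage sketch, assuming a compatible checkpoint exists at the placeholder path:

```python
from dalle_mini.model import DalleBartConfig, DalleBartForConditionalGeneration

# "path/to/checkpoint" is a placeholder for a real local or Hub checkpoint
config = DalleBartConfig.from_pretrained("path/to/checkpoint")
model = DalleBartForConditionalGeneration.from_pretrained(
    "path/to/checkpoint", config=config
)
```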
dalle_mini/{configuration_bart.py → model/configuration.py} RENAMED
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" BART model configuration """
+""" DalleBart model configuration """
 import warnings
 
 from transformers.configuration_utils import PretrainedConfig
@@ -123,7 +123,7 @@ class DalleBartConfig(PretrainedConfig):
     ):
         self.normalize_text = normalize_text
         self.encoder_vocab_size = encoder_vocab_size
-        self.decoder_vocab_size = image_vocab_size
+        self.image_vocab_size = image_vocab_size
         self.image_length = image_length
         self.max_text_length = max_text_length
         self.d_model = d_model
@@ -145,17 +145,21 @@ class DalleBartConfig(PretrainedConfig):
         self.num_hidden_layers = encoder_layers
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
-        self.decoder_start_token_id = image_vocab_size,  # BOS appended to vocab
+        self.decoder_start_token_id = image_vocab_size  # BOS appended to vocab
         self.min_length = image_length + 1
         self.max_length = image_length + 1
 
+        # remove keys we are about to set to prevent errors
+        for k in ['bos_token_id', 'eos_token_id', 'pad_token_id', 'decoder_start_token_id', 'forced_eos_token_id']:
+            kwargs.pop(k, None)
+
         super().__init__(
             num_labels=num_labels,
             pad_token_id=image_vocab_size + 1,  # needed to avoid errors during generation (converted to jnp.array)
             bos_token_id=image_vocab_size + 1,  # set to unreachable values
             eos_token_id=image_vocab_size + 1,
             is_encoder_decoder=is_encoder_decoder,
-            decoder_start_token_id=decoder_start_token_id,
+            decoder_start_token_id=self.decoder_start_token_id,
             forced_eos_token_id=forced_eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,

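Two bugs are fixed here: the stray trailing comma previously made `self.decoder_start_token_id` a one-element tuple rather than an int, and token-id kwargs now get popped before `super().__init__`. The pop matters because `PretrainedConfig` round-trips every stored value through serialization, so reloading a saved config would pass `decoder_start_token_id` back in via `**kwargs` while `__init__` also sets it explicitly. A minimal sketch of that failure mode, with illustrative toy classes rather than the real ones:

```python
class BaseConfig:
    def __init__(self, **kwargs):
        # stands in for PretrainedConfig.__init__, which stores every kwarg
        for key, value in kwargs.items():
            setattr(self, key, value)

class ToyConfig(BaseConfig):
    def __init__(self, image_vocab_size=16384, **kwargs):
        decoder_start_token_id = image_vocab_size  # derived, never user-supplied

        # without this pop, reloading a saved config raises
        # TypeError: got multiple values for keyword argument 'decoder_start_token_id'
        kwargs.pop("decoder_start_token_id", None)
        super().__init__(decoder_start_token_id=decoder_start_token_id, **kwargs)

# simulates from_pretrained() feeding the serialized value back into __init__
cfg = ToyConfig(decoder_start_token_id=999)
assert cfg.decoder_start_token_id == 16384  # stale serialized value is discarded
```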
dalle_mini/{modeling_bart_flax.py → model/modeling.py} RENAMED
@@ -45,7 +45,7 @@ from transformers.modeling_flax_utils import (
 from transformers.utils import logging
 
 
-from .configuration_bart import BartConfig
+from .configuration import DalleBartConfig
 
 
 logger = logging.get_logger(__name__)
@@ -64,7 +64,7 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
 
 
 class FlaxBartAttention(nn.Module):
-    config: BartConfig
+    config: DalleBartConfig
     embed_dim: int
     num_heads: int
     dropout: float = 0.0
@@ -93,7 +93,7 @@ class FlaxBartAttention(nn.Module):
 
         if self.causal:
             self.causal_mask = make_causal_mask(
-                jnp.ones((1, embed_dim), dtype="bool"), dtype="bool"
+                jnp.ones((1, self.embed_dim), dtype="bool"), dtype="bool"
             )
 
     def _split_heads(self, hidden_states):
@@ -224,7 +224,7 @@ class FlaxBartAttention(nn.Module):
 
 
 class FlaxBartEncoderLayer(nn.Module):
-    config: BartConfig
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32
 
     def setup(self) -> None:
@@ -279,7 +279,7 @@ class FlaxBartEncoderLayer(nn.Module):
 
 
 class FlaxBartEncoderLayerCollection(nn.Module):
-    config: BartConfig
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
@@ -306,7 +306,7 @@ class FlaxBartEncoderLayerCollection(nn.Module):
 
 
 class FlaxBartDecoderLayer(nn.Module):
-    config: BartConfig
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32
 
     def setup(self) -> None:
@@ -390,7 +390,7 @@ class FlaxBartDecoderLayer(nn.Module):
 
 
 class FlaxBartDecoderLayerCollection(nn.Module):
-    config: BartConfig
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
@@ -422,8 +422,8 @@ class FlaxBartDecoderLayerCollection(nn.Module):
         return FlaxBaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states)
 
 
-class FlaxBartEncoder(nn.Module):
-    config: BartConfig
+class DalleBartEncoder(nn.Module):
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
@@ -479,8 +479,8 @@ class FlaxBartEncoder(nn.Module):
         )
 
 
-class FlaxBartDecoder(nn.Module):
-    config: BartConfig
+class DalleBartDecoder(nn.Module):
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
@@ -550,13 +550,13 @@ class FlaxBartDecoder(nn.Module):
         )
 
 
-class FlaxBartModule(nn.Module):
-    config: BartConfig
+class DalleBartModule(nn.Module):
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
-        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype)
-        self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype)
+        self.encoder = DalleBartEncoder(self.config, dtype=self.dtype)
+        self.decoder = DalleBartDecoder(self.config, dtype=self.dtype)
 
     def _get_encoder_module(self):
         return self.encoder
@@ -605,14 +605,14 @@ class FlaxBartModule(nn.Module):
         )
 
 
-class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
-    config_class = BartConfig
-    base_model_prefix: str = "model"
+class DalleBartPreTrainedModel(FlaxPreTrainedModel):
+    config_class = DalleBartConfig
+    base_model_prefix: str = "dallebart"
     module_class: nn.Module = None
 
     def __init__(
         self,
-        config: BartConfig,
+        config: DalleBartConfig,
         input_shape: Tuple[int] = (1, 1),
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
@@ -792,13 +792,13 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
         )
 
 
-class FlaxBartForConditionalGenerationModule(nn.Module):
-    config: BartConfig
+class DalleBartForConditionalGenerationModule(nn.Module):
+    config: DalleBartConfig
     dtype: jnp.dtype = jnp.float32
     bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
 
     def setup(self):
-        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
+        self.model = DalleBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
             self.config.image_vocab_size + 1,  # image vocab size + 1 for BOS
             use_bias=False,
@@ -854,8 +854,8 @@ class FlaxBartForConditionalGenerationModule(nn.Module):
         )
 
 
-class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
-    module_class = FlaxBartForConditionalGenerationModule
+class DalleBartForConditionalGeneration(DalleBartPreTrainedModel):
+    module_class = DalleBartForConditionalGenerationModule
     dtype: jnp.dtype = jnp.float32
 
     def decode(

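Besides the wholesale rename to `DalleBart*` classes typed with `DalleBartConfig`, the one behavioral fix in this file is `jnp.ones((1, embed_dim))` becoming `jnp.ones((1, self.embed_dim))`: inside a Flax module's `setup`, dataclass fields are only reachable through `self`, so the bare name raised `NameError`. A minimal repro of the fix on a toy module (not the real attention class):

```python
import flax.linen as nn
import jax.numpy as jnp
from flax.linen import make_causal_mask

class ToyCausalModule(nn.Module):
    embed_dim: int  # module field: visible inside setup() only as self.embed_dim

    def setup(self):
        # before the fix, `jnp.ones((1, embed_dim), ...)` raised NameError,
        # because `embed_dim` is a dataclass field, not a local variable
        self.causal_mask = make_causal_mask(
            jnp.ones((1, self.embed_dim), dtype="bool"), dtype="bool"
        )

    def __call__(self):
        return self.causal_mask

# no parameters are created, so an empty variables dict suffices
mask = ToyCausalModule(embed_dim=4).apply({})
print(mask.shape)  # (1, 1, 4, 4)
```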
dalle_mini/{partitions.py → model/partitions.py} RENAMED
@@ -5,7 +5,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.experimental import PartitionSpec as P
 
 
-# utils adapted from https://gitihub.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
+# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
 # Sentinels
 _unmatched = object()
 
tools/train/train.py CHANGED
@@ -44,7 +44,7 @@ from transformers import AutoTokenizer, HfArgumentParser
 from transformers.models.bart.modeling_flax_bart import BartConfig
 
 from dalle_mini.data import Dataset
-from dalle_mini.model import CustomFlaxBartForConditionalGeneration
+from dalle_mini.model import DalleBartConfig, DalleBartForConditionalGeneration
 
 logger = logging.getLogger(__name__)
 
@@ -68,26 +68,12 @@ class ModelArguments:
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    image_vocab_size: Optional[int] = field(
-        default=None,
-        metadata={"help": "Vocab size of image encoder"},
-    )
-    image_length: Optional[int] = field(
-        default=None,
-        metadata={"help": "Number of tokens per image"},
-    )
     tokenizer_name: Optional[str] = field(
         default=None,
         metadata={
             "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
         },
     )
-    normalize_text: Optional[bool] = field(
-        default=None,
-        metadata={
-            "help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
-        },
-    )
     dtype: Optional[str] = field(
         default="float32",
         metadata={
@@ -126,10 +112,6 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file (glob acceptable)."},
     )
-    dataset_type: str = field(
-        default="datasets",
-        metadata={"help": "Either 🤗 'dataset' (default) or 'webdataset'."},
-    )
     # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: bool = field(
         default=True,
@@ -141,13 +123,6 @@ class DataTrainingArguments:
             "help": "Whether to use the authentication token for private datasets."
         },
     )
-    max_source_length: Optional[int] = field(
-        default=128,
-        metadata={
-            "help": "The maximum total input sequence length after tokenization. Sequences longer "
-            "than this will be truncated, sequences shorter will be padded."
-        },
-    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -436,47 +411,14 @@ def main():
 
     else:
         # Set up our new model config
-        # TODO: simplify with custom config class
         if model_args.config_name:
-            config = BartConfig.from_pretrained(model_args.config_name)
-        else:
-            config = BartConfig.from_pretrained(model_args.model_name_or_path)
-        if model_args.image_vocab_size:
-            config.image_vocab_size = model_args.image_vocab_size
-        assert (
-            getattr(config, "image_vocab_size") is not None
-        ), "image_vocab_size must be specified when not present in base model/config"
-        if model_args.image_length:
-            config.image_length = model_args.image_length
-        assert (
-            getattr(config, "image_length") is not None
-        ), "image_length must be specified when not present in base model/config"
-        # we append decoder bos to image vocab
-        config.decoder_start_token_id = config.image_vocab_size
-        # ensure we don't generate bos (in addition to decoder start token)
-        config.force_bos_token_to_be_generated = False
-        config.forced_bos_token_id = None  # we don't need this token
-        config.forced_eos_token_id = None  # we don't need this token
-
-        config.tie_word_embeddings = False
-        config.min_length = config.image_length + 1
-        config.max_length = config.image_length + 1
-
-        # below tokens need to be set to avoid error during generation (converted to jnp.array)
-        # they are not expected to be used and are set to unreachable token id
-        config.bos_token_id = config.image_vocab_size + 1
-        config.pos_token_id = config.image_vocab_size + 1
-        config.eos_token_id = config.image_vocab_size + 1
-
-        # save whether we normalize the text
-        if model_args.normalize_text is not None:
-            config.normalize_text = model_args.normalize_text
+            config = DalleBartConfig.from_pretrained(model_args.config_name)
         else:
-            config.normalize_text = getattr(config, "normalize_text", False)
+            config = DalleBartConfig.from_pretrained(model_args.model_name_or_path)
 
     # Load or create new model
     if model_args.model_name_or_path:
-        model = CustomFlaxBartForConditionalGeneration.from_pretrained(
+        model = DalleBartForConditionalGeneration.from_pretrained(
             model_args.model_name_or_path,
            config=config,
            seed=training_args.seed_model,
@@ -485,7 +427,7 @@ def main():
         # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
         print(model.params)
     else:
-        model = CustomFlaxBartForConditionalGeneration(
+        model = DalleBartForConditionalGeneration(
             config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
@@ -512,6 +454,7 @@ def main():
         tokenizer=tokenizer,
         decoder_start_token_id=model.config.decoder_start_token_id,
         normalize_text=model.config.normalize_text,
+        max_length=model.config.max_text_length,
     )
 
     # Initialize our training
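The net effect on the training script: all token-id and length bookkeeping moves into `DalleBartConfig`, and preprocessing reads `max_text_length` from the model config instead of a removed `max_source_length` CLI argument. A minimal sketch of the simplified flow, assuming a compatible config/checkpoint and dataset exist at the placeholder paths (the tokenizer choice is also an assumption):

```python
import jax.numpy as jnp
from transformers import AutoTokenizer

from dalle_mini.data import Dataset
from dalle_mini.model import DalleBartConfig, DalleBartForConditionalGeneration

# placeholders: substitute a real config/checkpoint, tokenizer, and dataset repo
config = DalleBartConfig.from_pretrained("path/to/config")
model = DalleBartForConditionalGeneration(config, seed=0, dtype=jnp.float32)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

dataset = Dataset(dataset_repo_or_path="user/encoded-dataset", streaming=True)
dataset.preprocess(
    tokenizer=tokenizer,
    decoder_start_token_id=model.config.decoder_start_token_id,
    normalize_text=model.config.normalize_text,
    max_length=model.config.max_text_length,  # replaces the removed max_source_length
)
```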