boris committed
Commit 0a77f72
1 Parent(s): 2816f98

feat: use pretrained weights

Files changed (1)
  1. dev/seq2seq/run_seq2seq_flax.py +52 -19
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -68,6 +68,12 @@ class ModelArguments:
             "Don't set if you want to train a model from scratch."
         },
     )
+    config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
+    )
     image_vocab_size: Optional[int] = field(
         default=None,
         metadata={"help": "Vocab size of image encoder"},
@@ -82,9 +88,11 @@ class ModelArguments:
             "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
         },
     )
-    normalize_text: bool = field(
-        default=False,
-        metadata={"help": "Whether to normalize text or not."},
+    normalize_text: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
+        },
     )
     dtype: Optional[str] = field(
         default="float32",
@@ -125,8 +133,9 @@ class DataTrainingArguments:
             "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
         },
     )
+    # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "Whether to stream the dataset."},
     )
     use_auth_token: bool = field(
@@ -283,9 +292,9 @@ class TrainingArguments:
         },
     )
 
-    resume_from_wandb_checkpoint: Optional[str] = field(
+    resume_from_checkpoint: Optional[str] = field(
         default=None,
-        metadata={"help": "The reference to a wandb artifact for resuming training."},
+        metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
 
@@ -460,8 +469,8 @@ def main():
         config=parser.parse_args(),
     )
 
-    if training_args.resume_from_wandb_checkpoint is not None:
-        artifact = wandb.run.use_artifact(training_args.resume_from_wandb_checkpoint)
+    if training_args.resume_from_checkpoint is not None:
+        artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
         artifact_dir = artifact.download()
 
         # load model
@@ -476,9 +485,20 @@ def main():
     else:
         # Set up our new model config
         # TODO: simplify with custom config class
-        config = BartConfig.from_pretrained(model_args.model_name_or_path)
-        config.image_vocab_size = model_args.image_vocab_size
-        config.image_length = model_args.image_length
+        if model_args.config_name:
+            config = BartConfig.from_pretrained(model_args.config_name)
+        else:
+            config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        if model_args.image_vocab_size:
+            config.image_vocab_size = model_args.image_vocab_size
+        assert (
+            getattr(config, "image_vocab_size") is not None
+        ), "image_vocab_size must be specified when not present in base model/config"
+        if model_args.image_length:
+            config.image_length = model_args.image_length
+        assert (
+            getattr(config, "image_length") is not None
+        ), "image_length must be specified when not present in base model/config"
         # we append decoder bos to image vocab
         config.decoder_start_token_id = config.image_vocab_size
         # ensure we don't generate bos (in addition to decoder start token)
@@ -487,8 +507,8 @@ def main():
         config.forced_eos_token_id = None  # we don't need this token
 
         config.tie_word_embeddings = False
-        config.min_length = model_args.image_length + 1
-        config.max_length = model_args.image_length + 1
+        config.min_length = config.image_length + 1
+        config.max_length = config.image_length + 1
 
         # below tokens need to be set to avoid error during generation (converted to jnp.array)
         # they are not expected to be used and are set to unreachable token id
@@ -497,12 +517,25 @@ def main():
         config.eos_token_id = config.image_vocab_size + 1
 
         # save whether we normalize the text
-        config.normalize_text = model_args.normalize_text
+        if model_args.normalize_text is not None:
+            config.normalize_text = model_args.normalize_text
+        else:
+            config.normalize_text = getattr(config, "normalize_text", False)
 
-        # Create a custom model and initialize it randomly
-        model = CustomFlaxBartForConditionalGeneration(
-            config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype)
-        )
+        # Load or create new model
+        if model_args.model_name_or_path:
+            model = CustomFlaxBartForConditionalGeneration.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                seed=training_args.seed_model,
+                dtype=getattr(jnp, model_args.dtype),
+            )
+        else:
+            model = CustomFlaxBartForConditionalGeneration(
+                config,
+                seed=training_args.seed_model,
+                dtype=getattr(jnp, model_args.dtype),
+            )
 
         # Load tokenizer
         if model_args.tokenizer_name is not None:
@@ -741,7 +774,7 @@ def main():
         tx=optimizer,
         dropout_rng=dropout_rng,
     )
-    if training_args.resume_from_wandb_checkpoint is not None:
+    if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters
         # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
         state = state.restore_state(artifact_dir)
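
The net effect of the commit: when `model_name_or_path` points to a checkpoint, the model is now loaded with `from_pretrained` on top of the resolved config instead of being initialized randomly, and `image_vocab_size`, `image_length`, and `normalize_text` only override what the base config already carries. The sketch below condenses that "resolve config, then load or create" flow outside the training script; it is illustrative only — the checkpoint name `facebook/bart-large`, the literal override values, and the use of the stock `FlaxBartForConditionalGeneration` (rather than the script's own `CustomFlaxBartForConditionalGeneration`) are assumptions for the example, not part of the commit.

```python
# Minimal sketch of the new "load or create" behaviour (illustrative values).
import jax.numpy as jnp
from transformers import BartConfig, FlaxBartForConditionalGeneration

model_name_or_path = "facebook/bart-large"  # set to None to train from scratch
config_name = None                          # optional separate config source
image_vocab_size = 16384                    # illustrative CLI overrides
image_length = 256

config = BartConfig.from_pretrained(config_name or model_name_or_path)
# CLI overrides win; otherwise keep whatever the base config already defines
config.image_vocab_size = image_vocab_size or getattr(config, "image_vocab_size", None)
config.image_length = image_length or getattr(config, "image_length", None)
config.normalize_text = getattr(config, "normalize_text", False)
# generation length is now derived from the (possibly inherited) image_length
config.min_length = config.image_length + 1
config.max_length = config.image_length + 1

if model_name_or_path:
    # start from pretrained weights
    model = FlaxBartForConditionalGeneration.from_pretrained(
        model_name_or_path, config=config, seed=42, dtype=jnp.float32
    )
else:
    # random initialization from the config alone
    model = FlaxBartForConditionalGeneration(config, seed=42, dtype=jnp.float32)
```

Resuming works as before but under the shorter flag: `resume_from_checkpoint` takes a wandb artifact reference, which is downloaded via `wandb.run.use_artifact(...).download()` and then passed to `state.restore_state(artifact_dir)` to restore the optimizer state.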