boris committed
Commit 6016fc0
Parents: 2816f98, 2be9847

Merge pull request #109 from borisdayma/feat-model_pretrained

Files changed (1):
  dev/seq2seq/run_seq2seq_flax.py  +56 -19
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -68,6 +68,12 @@ class ModelArguments:
             "Don't set if you want to train a model from scratch."
         },
     )
+    config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
+    )
     image_vocab_size: Optional[int] = field(
         default=None,
         metadata={"help": "Vocab size of image encoder"},
@@ -82,9 +88,11 @@ class ModelArguments:
             "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
         },
     )
-    normalize_text: bool = field(
-        default=False,
-        metadata={"help": "Whether to normalize text or not."},
+    normalize_text: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
+        },
     )
     dtype: Optional[str] = field(
         default="float32",
@@ -125,8 +133,9 @@ class DataTrainingArguments:
             "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
         },
     )
+    # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "Whether to stream the dataset."},
     )
     use_auth_token: bool = field(
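
On the new streaming default: with the datasets library, streaming=True yields a lazily iterated dataset instead of downloading and caching everything up front. A rough sketch of what the flag presumably toggles downstream; the dataset name is a placeholder, not necessarily what the training script loads.

# Sketch of dataset streaming with the `datasets` library; dataset name is a placeholder.
from datasets import load_dataset

streaming = True  # mirrors DataTrainingArguments.streaming, now True by default
dataset = load_dataset(
    "oscar", "unshuffled_deduplicated_en", split="train", streaming=streaming
)

# streaming datasets are iterated on the fly; no full download/caching step
for i, example in enumerate(dataset):
    print(example.keys())
    if i == 2:
        break
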
@@ -283,9 +292,9 @@ class TrainingArguments:
         },
     )

-    resume_from_wandb_checkpoint: Optional[str] = field(
+    resume_from_checkpoint: Optional[str] = field(
         default=None,
-        metadata={"help": "The reference to a wandb artifact for resuming training."},
+        metadata={"help": "Reference to a wandb artifact for resuming training."},
     )

@@ -460,12 +469,14 @@ def main():
         config=parser.parse_args(),
     )

-    if training_args.resume_from_wandb_checkpoint is not None:
-        artifact = wandb.run.use_artifact(training_args.resume_from_wandb_checkpoint)
+    if training_args.resume_from_checkpoint is not None:
+        artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
         artifact_dir = artifact.download()

         # load model
         model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+        # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
+        print(model.params)

         # load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
@@ -476,9 +487,20 @@ def main():
     else:
         # Set up our new model config
         # TODO: simplify with custom config class
-        config = BartConfig.from_pretrained(model_args.model_name_or_path)
-        config.image_vocab_size = model_args.image_vocab_size
-        config.image_length = model_args.image_length
+        if model_args.config_name:
+            config = BartConfig.from_pretrained(model_args.config_name)
+        else:
+            config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        if model_args.image_vocab_size:
+            config.image_vocab_size = model_args.image_vocab_size
+        assert (
+            getattr(config, "image_vocab_size") is not None
+        ), "image_vocab_size must be specified when not present in base model/config"
+        if model_args.image_length:
+            config.image_length = model_args.image_length
+        assert (
+            getattr(config, "image_length") is not None
+        ), "image_length must be specified when not present in base model/config"
         # we append decoder bos to image vocab
         config.decoder_start_token_id = config.image_vocab_size
         # ensure we don't generate bos (in addition to decoder start token)
@@ -487,8 +509,8 @@ def main():
         config.forced_eos_token_id = None  # we don't need this token

         config.tie_word_embeddings = False
-        config.min_length = model_args.image_length + 1
-        config.max_length = model_args.image_length + 1
+        config.min_length = config.image_length + 1
+        config.max_length = config.image_length + 1

         # below tokens need to be set to avoid error during generation (converted to jnp.array)
         # they are not expected to be used and are set to unreachable token id
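
To make the arithmetic in this hunk concrete: the decoder vocabulary is the image codebook plus one extra BOS entry, and generated sequences are one token longer than the image because of that decoder start token. A small illustration with example numbers only, not values taken from any real config.

# Illustration only; the numbers are examples, not values from a real checkpoint.
image_vocab_size = 16384  # ids 0 .. 16383 are image codes from the image encoder
image_length = 256        # number of image tokens per sample

decoder_start_token_id = image_vocab_size    # 16384: BOS appended after the image vocab
unreachable_token_id = image_vocab_size + 1  # 16385: assigned to bos/eos so they are never generated

# generated sequence = decoder start token + image tokens
min_length = max_length = image_length + 1   # 257
print(decoder_start_token_id, unreachable_token_id, max_length)
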
@@ -497,12 +519,27 @@ def main():
         config.eos_token_id = config.image_vocab_size + 1

         # save whether we normalize the text
-        config.normalize_text = model_args.normalize_text
+        if model_args.normalize_text is not None:
+            config.normalize_text = model_args.normalize_text
+        else:
+            config.normalize_text = getattr(config, "normalize_text", False)

-        # Create a custom model and initialize it randomly
-        model = CustomFlaxBartForConditionalGeneration(
-            config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype)
-        )
+        # Load or create new model
+        if model_args.model_name_or_path:
+            model = CustomFlaxBartForConditionalGeneration.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                seed=training_args.seed_model,
+                dtype=getattr(jnp, model_args.dtype),
+            )
+            # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
+            print(model.params)
+        else:
+            model = CustomFlaxBartForConditionalGeneration(
+                config,
+                seed=training_args.seed_model,
+                dtype=getattr(jnp, model_args.dtype),
+            )

         # Load tokenizer
         if model_args.tokenizer_name is not None:
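
The load-or-create branch above is the core of this PR. A rough sketch of the same control flow with the stock FlaxBartForConditionalGeneration standing in for the repo's CustomFlaxBartForConditionalGeneration (the stock class does not resize its head for the image vocab); the checkpoint name, seed, and config overrides are placeholders.

# Sketch of the load-or-create pattern; FlaxBartForConditionalGeneration stands in
# for the custom subclass, and "facebook/bart-base" is a placeholder checkpoint.
import jax.numpy as jnp
from transformers import BartConfig, FlaxBartForConditionalGeneration

model_name_or_path = "facebook/bart-base"  # set to None to initialize from scratch
config = BartConfig.from_pretrained(model_name_or_path or "facebook/bart-base")
config.image_vocab_size = getattr(config, "image_vocab_size", 16384)  # example override
config.image_length = getattr(config, "image_length", 256)            # example override

if model_name_or_path:
    # reuse pretrained weights but attach the updated config
    model = FlaxBartForConditionalGeneration.from_pretrained(
        model_name_or_path, config=config, seed=0, dtype=jnp.float32
    )
else:
    # random initialization with the same config
    model = FlaxBartForConditionalGeneration(config, seed=0, dtype=jnp.float32)

# access the params once, mirroring the print(model.params) workaround referenced
# in the diff (https://github.com/google/flax/issues/1658)
_ = model.params
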
@@ -741,7 +778,7 @@ def main():
         tx=optimizer,
         dropout_rng=dropout_rng,
     )
-    if training_args.resume_from_wandb_checkpoint is not None:
+    if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters
         # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
         state = state.restore_state(artifact_dir)
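
Finally, a rough sketch of the wandb artifact round-trip driven by the renamed --resume_from_checkpoint flag; the project name and artifact reference are placeholders, and the repo-specific calls are only referenced in comments.

# Sketch of resuming from a wandb model artifact; project and reference are placeholders.
import wandb

run = wandb.init(project="dalle-mini")
artifact = run.use_artifact("your-entity/dalle-mini/model:latest")  # placeholder reference
artifact_dir = artifact.download()  # local folder with the saved checkpoint files

# the training script then does roughly:
#   model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
#   ... build optimizer / TrainState ...
#   state = state.restore_state(artifact_dir)
print(artifact_dir)
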