boris committed
Commit
3d61350
1 Parent(s): 499ddb2

feat: allow loading a model checkpoint

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +41 -23
seq2seq/run_seq2seq_flax.py CHANGED
@@ -125,6 +125,12 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
+        },
+    )
 
 
 @dataclass
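Note: a minimal sketch of how this new dataclass field surfaces as a command-line flag, assuming the script parses its arguments with transformers.HfArgumentParser as the other Flax examples do; the class below is trimmed to the new field for illustration.

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class ModelArguments:
        # trimmed to the new field; the real class defines many more
        from_checkpoint: Optional[str] = field(
            default=None,
            metadata={"help": "Loads a pretrained wandb checkpoint. Use artifact reference."},
        )

    # HfArgumentParser turns each field into a flag, e.g.
    #   python seq2seq/run_seq2seq_flax.py --from_checkpoint entity/project/artifact:alias
    parser = HfArgumentParser(ModelArguments)
    (model_args,) = parser.parse_args_into_dataclasses()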
@@ -424,36 +430,48 @@ def main():
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
     # Load pretrained model and tokenizer
-    base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-    )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
     )
 
-    # Set up our new model config
-    config = BartConfig.from_pretrained(model_args.model_name_or_path)
-    config.tie_word_embeddings = False
-    config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
-    config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
-    config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
-    config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
-    config.forced_bos_token_id = None  # we don't need this token
-    config.forced_eos_token_id = None  # we don't need this token
-    config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
-    config.min_length = data_args.max_target_length
-    config.max_length = data_args.max_target_length
+    if model_args.from_checkpoint is not None:
+        artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-3h3x3565:latest')
+        artifact_dir = artifact.download()
+        model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
-    print(f"TPUs: {jax.device_count()}")
-    assert jax.device_count() == 8, "TPUs in use, please check running processes"
+        # some models will try to change bos (because of force_bos_token_to_be_generated)
+        # we ensure bos and eos are not forced
+        model.config.force_bos_token_to_be_generated = False
+        model.config.forced_bos_token_id = None
+        model.config.forced_eos_token_id = None
 
-    # Create a custom model and initialize it randomly
-    model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    else:
+        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+            model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+        # Set up our new model config
+        config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        config.tie_word_embeddings = False
+        config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
+        config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
+        config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
+        config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
+        config.forced_bos_token_id = None  # we don't need this token
+        config.forced_eos_token_id = None  # we don't need this token
+        config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
+        config.min_length = data_args.max_target_length
+        config.max_length = data_args.max_target_length
 
-    # Use pre-trained weights for encoder
-    model.params['model']['encoder'] = base_model.params['model']['encoder']
-    model.params['model']['shared'] = base_model.params['model']['shared']
-    del base_model
+        # Create a custom model and initialize it randomly
+        model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+
+        # Use pre-trained weights for encoder
+        model.params['model']['encoder'] = base_model.params['model']['encoder']
+        model.params['model']['shared'] = base_model.params['model']['shared']
+        del base_model
+
+    print(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
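Note: the checkpoint branch assumes a wandb run object created earlier in the script, and it currently downloads a hardcoded artifact reference rather than the value passed via --from_checkpoint. A self-contained sketch of the loading flow, using the stock FlaxBartForConditionalGeneration in place of the repo's CustomFlaxBartForConditionalGeneration and a placeholder artifact reference:

    import wandb
    from transformers import FlaxBartForConditionalGeneration

    # a run must be active before artifacts can be consumed through it
    run = wandb.init(project="hf-flax-dalle-mini")

    # resolve the artifact and download its files to a local directory
    artifact = run.use_artifact("entity/project/model-xxxxxxxx:latest")  # placeholder reference
    artifact_dir = artifact.download()

    # the directory holds config.json plus Flax weights, so it loads
    # like any local checkpoint
    model = FlaxBartForConditionalGeneration.from_pretrained(artifact_dir)

    # as in the diff: make sure no bos/eos token is forced at generation time
    model.config.force_bos_token_to_be_generated = False
    model.config.forced_bos_token_id = None
    model.config.forced_eos_token_id = None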
 
 
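Note: the else branch is a warm start: only the encoder and the shared embedding table come from the pretrained BART, while the decoder (which will emit image tokens) keeps its random initialization. A sketch of that grafting pattern with stock transformers classes standing in for CustomFlaxBartForConditionalGeneration; the model name and seed are placeholders:

    import jax.numpy as jnp
    from transformers import BartConfig, FlaxBartForConditionalGeneration

    # donor model carrying the pretrained weights
    base_model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base")

    # fresh config; token ids and lengths would be customized as in the diff
    config = BartConfig.from_pretrained("facebook/bart-base")
    config.tie_word_embeddings = False

    # recipient model, randomly initialized from the seed
    model = FlaxBartForConditionalGeneration(config, seed=42, dtype=jnp.float32)

    # graft the pretrained encoder and shared embeddings into the new
    # parameter tree; the decoder parameters stay random
    params = model.params
    params["model"]["encoder"] = base_model.params["model"]["encoder"]
    params["model"]["shared"] = base_model.params["model"]["shared"]
    model.params = params
    del base_model  # release the donor's host memory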