boris committed
Commit 862924a
Parents: 499ddb2 6d252e9

Merge pull request #29 from borisdayma/load_checkpoint
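This change adds a `from_checkpoint` option to `ModelArguments`, so a run can resume from a model checkpoint stored as a wandb artifact instead of always initializing from the pretrained base model. A hypothetical invocation (the artifact reference and other arguments are placeholders; HfArgumentParser exposes the new dataclass field as a `--from_checkpoint` flag):

    python seq2seq/run_seq2seq_flax.py \
        --from_checkpoint "entity/project/model:latest" \
        --output_dir output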

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +44 -23
seq2seq/run_seq2seq_flax.py CHANGED
@@ -125,6 +125,12 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
+        },
+    )
 
 
 @dataclass
@@ -424,36 +430,51 @@ def main():
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
     # Load pretrained model and tokenizer
-    base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-    )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
     )
 
-    # Set up our new model config
-    config = BartConfig.from_pretrained(model_args.model_name_or_path)
-    config.tie_word_embeddings = False
-    config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
-    config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
-    config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
-    config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
-    config.forced_bos_token_id = None  # we don't need this token
-    config.forced_eos_token_id = None  # we don't need this token
-    config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
-    config.min_length = data_args.max_target_length
-    config.max_length = data_args.max_target_length
-
-    print(f"TPUs: {jax.device_count()}")
-    assert jax.device_count() == 8, "TPUs in use, please check running processes"
-
-    # Create a custom model and initialize it randomly
-    model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
-
-    # Use pre-trained weights for encoder
-    model.params['model']['encoder'] = base_model.params['model']['encoder']
-    model.params['model']['shared'] = base_model.params['model']['shared']
-    del base_model
+    if model_args.from_checkpoint is not None:
+        artifact = wandb.run.use_artifact(model_args.from_checkpoint)
+        artifact_dir = artifact.download()
+        model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+
+        # some models will try to change bos (because of force_bos_token_to_be_generated)
+        # we ensure bos and eos are not forced
+        model.config.force_bos_token_to_be_generated = False
+        model.config.forced_bos_token_id = None
+        model.config.forced_eos_token_id = None
+
+        # used in the preprocessing function
+        config = model.config
+
+    else:
+        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+            model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+        # Set up our new model config
+        config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        config.tie_word_embeddings = False
+        config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
+        config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
+        config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
+        config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
+        config.forced_bos_token_id = None  # we don't need this token
+        config.forced_eos_token_id = None  # we don't need this token
+        config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
+        config.min_length = data_args.max_target_length
+        config.max_length = data_args.max_target_length
+
+        # Create a custom model and initialize it randomly
+        model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+
+        # Use pre-trained weights for encoder
+        model.params['model']['encoder'] = base_model.params['model']['encoder']
+        model.params['model']['shared'] = base_model.params['model']['shared']
+        del base_model
+
+    print(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
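The new branch relies on the standard wandb artifact API: `use_artifact` registers the artifact as an input of the current run, and `download` fetches its files into a local directory, which `from_pretrained` can then consume. A minimal standalone sketch of that path, assuming an active wandb run and an artifact saved with `save_pretrained` (the stock `FlaxBartForConditionalGeneration` stands in for the script's `CustomFlaxBartForConditionalGeneration`, and all names are placeholders):

    import wandb
    from transformers import FlaxBartForConditionalGeneration

    # an active run is required; the training script initializes one earlier
    run = wandb.init(project="seq2seq-demo")  # placeholder project

    # record the artifact as a run input, then pull its files locally
    artifact = run.use_artifact("entity/project/model:latest")  # value of --from_checkpoint
    artifact_dir = artifact.download()

    # load the Flax weights and config from the downloaded directory
    model = FlaxBartForConditionalGeneration.from_pretrained(artifact_dir)

    # as in the diff: never force bos/eos at generation time,
    # whatever the saved config says
    model.config.force_bos_token_to_be_generated = False
    model.config.forced_bos_token_id = None
    model.config.forced_eos_token_id = None

In the `else` branch, the freshly initialized model keeps its random decoder but receives the pretrained encoder and shared embeddings by direct assignment into `model.params`; this works because Flax stores parameters as nested dictionaries, so subtrees with matching shapes can simply be swapped in.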