Pedro Cuenca commited on
Commit
bb3f53e
1 Parent(s): 08dd098

Update `resume_from_checkpoint` to use `from_pretrained`.

Browse files
Files changed (1) hide show
  1. tools/train/train.py +3 -9
tools/train/train.py CHANGED
@@ -434,22 +434,16 @@ def main():
434
  )
435
 
436
  if training_args.resume_from_checkpoint is not None:
437
- if jax.process_index() == 0:
438
- artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
439
- else:
440
- artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
441
- artifact_dir = artifact.download()
442
-
443
  # load model
444
  model = DalleBart.from_pretrained(
445
- artifact_dir, dtype=getattr(jnp, model_args.dtype), abstract_init=True
446
  )
447
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
448
  print(model.params)
449
 
450
  # load tokenizer
451
  tokenizer = AutoTokenizer.from_pretrained(
452
- artifact_dir,
453
  use_fast=True,
454
  )
455
 
@@ -624,7 +618,7 @@ def main():
624
  if training_args.resume_from_checkpoint is not None:
625
  # restore optimizer state and other parameters
626
  # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
627
- state = state.restore_state(artifact_dir)
628
 
629
  # label smoothed cross entropy
630
  def loss_fn(logits, labels):
 
434
  )
435
 
436
  if training_args.resume_from_checkpoint is not None:
 
 
 
 
 
 
437
  # load model
438
  model = DalleBart.from_pretrained(
439
+ training_args.resume_from_checkpoint, dtype=getattr(jnp, model_args.dtype), abstract_init=True
440
  )
441
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
442
  print(model.params)
443
 
444
  # load tokenizer
445
  tokenizer = AutoTokenizer.from_pretrained(
446
+ model.config.resolved_name_or_path,
447
  use_fast=True,
448
  )
449
 
 
618
  if training_args.resume_from_checkpoint is not None:
619
  # restore optimizer state and other parameters
620
  # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
621
+ state = state.restore_state(model.config.resolved_name_or_path)
622
 
623
  # label smoothed cross entropy
624
  def loss_fn(logits, labels):