Pedro Cuenca commited on
Commit
ae983d7
·
1 Parent(s): 7e48337

Use DalleBartTokenizer. State restoration reverted to previous method:

Browse files

explicitly download artifact and use the download directory.

A better solution will be addressed in #120.

Files changed (1) hide show
  1. tools/train/train.py +13 -8
tools/train/train.py CHANGED
@@ -44,7 +44,7 @@ from tqdm import tqdm
44
  from transformers import AutoTokenizer, HfArgumentParser
45
 
46
  from dalle_mini.data import Dataset
47
- from dalle_mini.model import DalleBart, DalleBartConfig
48
 
49
  logger = logging.getLogger(__name__)
50
 
@@ -435,9 +435,15 @@ def main():
435
  )
436
 
437
  if training_args.resume_from_checkpoint is not None:
 
 
 
 
 
 
438
  # load model
439
  model = DalleBart.from_pretrained(
440
- training_args.resume_from_checkpoint,
441
  dtype=getattr(jnp, model_args.dtype),
442
  abstract_init=True,
443
  )
@@ -445,8 +451,8 @@ def main():
445
  print(model.params)
446
 
447
  # load tokenizer
448
- tokenizer = AutoTokenizer.from_pretrained(
449
- model.config.resolved_name_or_path,
450
  use_fast=True,
451
  )
452
 
@@ -481,9 +487,8 @@ def main():
481
  model_args.tokenizer_name, use_fast=True
482
  )
483
  else:
484
- # Use non-standard configuration property set by `DalleBart.from_pretrained`
485
- tokenizer = AutoTokenizer.from_pretrained(
486
- model.config.resolved_name_or_path,
487
  use_fast=True,
488
  )
489
 
@@ -621,7 +626,7 @@ def main():
621
  if training_args.resume_from_checkpoint is not None:
622
  # restore optimizer state and other parameters
623
  # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
624
- state = state.restore_state(model.config.resolved_name_or_path)
625
 
626
  # label smoothed cross entropy
627
  def loss_fn(logits, labels):
 
44
  from transformers import AutoTokenizer, HfArgumentParser
45
 
46
  from dalle_mini.data import Dataset
47
+ from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer
48
 
49
  logger = logging.getLogger(__name__)
50
 
 
435
  )
436
 
437
  if training_args.resume_from_checkpoint is not None:
438
+ if jax.process_index() == 0:
439
+ artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
440
+ else:
441
+ artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
442
+ artifact_dir = artifact.download()
443
+
444
  # load model
445
  model = DalleBart.from_pretrained(
446
+ artifact_dir,
447
  dtype=getattr(jnp, model_args.dtype),
448
  abstract_init=True,
449
  )
 
451
  print(model.params)
452
 
453
  # load tokenizer
454
+ tokenizer = DalleBartTokenizer.from_pretrained(
455
+ artifact_dir,
456
  use_fast=True,
457
  )
458
 
 
487
  model_args.tokenizer_name, use_fast=True
488
  )
489
  else:
490
+ tokenizer = DalleBartTokenizer.from_pretrained(
491
+ model_args.model_name_or_path,
 
492
  use_fast=True,
493
  )
494
 
 
626
  if training_args.resume_from_checkpoint is not None:
627
  # restore optimizer state and other parameters
628
  # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
629
+ state = state.restore_state(artifact_dir)
630
 
631
  # label smoothed cross entropy
632
  def loss_fn(logits, labels):