Commit 9f522b8
Pedro Cuenca committed
1 Parent(s): 290e443

Accept changes suggested by linter.

src/dalle_mini/model/modeling.py CHANGED
@@ -569,14 +569,18 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
         """
         Initializes from a wandb artifact, or delegates loading to the superclass.
         """
-        if ':' in pretrained_model_name_or_path and not os.path.isdir(pretrained_model_name_or_path):
+        if ":" in pretrained_model_name_or_path and not os.path.isdir(
+            pretrained_model_name_or_path
+        ):
             # wandb artifact
             artifact = wandb.Api().artifact(pretrained_model_name_or_path)
-
+
             # we download everything, including opt_state, so we can resume training if needed
             # see also: #120
             pretrained_model_name_or_path = artifact.download()
 
-        model = super(DalleBart, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        model = super(DalleBart, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
         model.config.resolved_name_or_path = pretrained_model_name_or_path
         return model
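
For context on the modeling.py change: the overridden from_pretrained treats the identifier as a wandb artifact reference when it contains a ":" version qualifier and does not name a local directory; it downloads the whole artifact (including opt_state, so training can resume) and hands the resolved local path to the superclass. A minimal standalone sketch of that dispatch, using the real wandb client calls (wandb.Api().artifact() and Artifact.download()) but a hypothetical helper name:

import os

import wandb


def resolve_pretrained_path(name_or_path: str) -> str:
    # Hypothetical helper mirroring the logic in DalleBart.from_pretrained above.
    # An artifact reference such as "entity/project/model:v12" carries a
    # ":" version tag and is not a directory on disk.
    if ":" in name_or_path and not os.path.isdir(name_or_path):
        artifact = wandb.Api().artifact(name_or_path)
        # download() fetches every file in the artifact (weights as well as
        # opt_state) and returns the local directory it wrote them to
        name_or_path = artifact.download()
    return name_or_path

The superclass loader then only ever sees a local directory, which is also why the resolved path is recorded in model.config.resolved_name_or_path after the call.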
tools/train/train.py CHANGED
@@ -437,7 +437,9 @@ def main():
     if training_args.resume_from_checkpoint is not None:
         # load model
         model = DalleBart.from_pretrained(
-            training_args.resume_from_checkpoint, dtype=getattr(jnp, model_args.dtype), abstract_init=True
+            training_args.resume_from_checkpoint,
+            dtype=getattr(jnp, model_args.dtype),
+            abstract_init=True,
         )
         # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
         print(model.params)
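
On the train.py change: abstract_init=True is dalle-mini's own from_pretrained flag rather than standard transformers API, tied to the workaround for the linked flax issue, so that throwaway random initialization weights are not materialized on the TPU before the checkpoint is loaded. The usual trick behind such a flag is jax.eval_shape, which traces init abstractly and returns only shapes and dtypes; a sketch under that assumption, with a toy module:

import jax
import jax.numpy as jnp
import flax.linen as nn


class Tiny(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=8)(x)


rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 4))

# Concrete init allocates real random parameter buffers on device.
params = Tiny().init(rng, x)

# Abstract init traces the same function for shapes/dtypes only, so no
# parameter memory is allocated; useful when real weights are loaded from
# a checkpoint immediately afterwards.
abstract_params = jax.eval_shape(Tiny().init, rng, x)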