boris committed on
Commit
2e02683
1 Parent(s): b798ed3

fix: no gradient checkpointing for new model

Browse files
Files changed (1) hide show
  1. tools/train/train.py +3 -1
tools/train/train.py CHANGED
@@ -531,6 +531,9 @@ def main():
531
  # Set up our new model config
532
  if model_args.config_name:
533
  config = DalleBartConfig.from_pretrained(model_args.config_name)
 
 
 
534
  else:
535
  config = None
536
 
@@ -553,7 +556,6 @@ def main():
553
  seed=training_args.seed_model,
554
  dtype=getattr(jnp, model_args.dtype),
555
  load_on_cpu=True,
556
- gradient_checkpointing=False,
557
  )
558
 
559
  # update model config per training args
 
531
  # Set up our new model config
532
  if model_args.config_name:
533
  config = DalleBartConfig.from_pretrained(model_args.config_name)
534
+ # initializing params with gradient checkpointing creates issues
535
+ # we correctly set it later per training_args
536
+ config.gradient_checkpointing = False
537
  else:
538
  config = None
539
 
 
556
  seed=training_args.seed_model,
557
  dtype=getattr(jnp, model_args.dtype),
558
  load_on_cpu=True,
 
559
  )
560
 
561
  # update model config per training args