boris committed
Commit: b798ed3
Parent: 79557f9

feat: no gradient checkpointing for params init

Files changed (1):
  tools/train/train.py (+7 -6)
tools/train/train.py CHANGED
@@ -531,8 +531,6 @@ def main():
     # Set up our new model config
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
-        # initializing params with gradient checkpointing create issues
-        config.gradient_checkpointing = False
     else:
         config = None
 
@@ -545,6 +543,9 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
             load_on_cpu=True,
+            # initializing params with gradient checkpointing creates issues
+            # we correctly set it later per training_args
+            gradient_checkpointing=False,
         )
     else:
         model = DalleBart(
@@ -552,6 +553,7 @@ def main():
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             load_on_cpu=True,
+            gradient_checkpointing=False,
         )
 
     # update model config per training args
@@ -559,11 +561,10 @@ def main():
     # This is still considered correctly during training as function is pjitted
     model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
-    # eval model cannot use remat
-    eval_config = copy.deepcopy(model.config)
-    eval_config.gradient_checkpointing = False
-
     if training_args.gradient_checkpointing:
+        # eval model cannot use remat
+        eval_config = copy.deepcopy(model.config)
+        eval_config.gradient_checkpointing = False
         eval_model = DalleBart(
             eval_config,
             seed=training_args.seed_model,
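
For context, a minimal sketch of the initialization flow that results from this change, assembled from the hunks above. The import paths, the init_models wrapper, and the trailing else branch are assumptions for illustration, not taken verbatim from tools/train/train.py:

import copy

import jax.numpy as jnp
from dalle_mini.model import DalleBart  # assumed import path

def init_models(config, model_args, training_args):
    # Initialize parameters with gradient checkpointing (remat) disabled,
    # since enabling it at param-init time creates issues.
    model = DalleBart(
        config,
        seed=training_args.seed_model,
        dtype=getattr(jnp, model_args.dtype),
        load_on_cpu=True,
        gradient_checkpointing=False,
    )

    # Enable checkpointing afterwards per training_args; the pjitted train
    # function reads the config, so this is still applied during training.
    model.config.gradient_checkpointing = training_args.gradient_checkpointing

    if training_args.gradient_checkpointing:
        # The eval model cannot use remat, so give it a config with
        # checkpointing turned off.
        eval_config = copy.deepcopy(model.config)
        eval_config.gradient_checkpointing = False
        eval_model = DalleBart(
            eval_config,
            seed=training_args.seed_model,
            dtype=getattr(jnp, model_args.dtype),
            load_on_cpu=True,
        )
    else:
        # assumption: reuse the train model for eval when remat is off
        eval_model = model

    return model, eval_model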