boris committed
Commit 5173ec7
1 Parent(s): 1c4e839

feat: handle gradient checkpointing

src/dalle_mini/model/modeling.py CHANGED
@@ -144,7 +144,7 @@ class FlaxBartEncoderLayerCollection(FlaxBartEncoderLayerCollection):
 
     def setup(self):
         layer_module = (
-            nn.remat(FlaxBartEncoderLayer)
+            nn.remat(FlaxBartEncoderLayer, concrete=True)
             if self.config.gradient_checkpointing
             else FlaxBartEncoderLayer
         )
@@ -211,7 +211,7 @@ class FlaxBartDecoderLayerCollection(FlaxBartDecoderLayerCollection):
 
     def setup(self):
         layer_module = (
-            nn.remat(FlaxBartDecoderLayer)
+            nn.remat(FlaxBartDecoderLayer, concrete=True)
             if self.config.gradient_checkpointing
             else FlaxBartDecoderLayer
         )
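
For context, `nn.remat` is Flax's gradient-checkpointing (rematerialization) transform: wrapping a module class makes its intermediate activations be recomputed during the backward pass instead of stored, trading compute for memory. The sketch below shows the same wrap-the-layer-class-conditionally pattern in isolation, assuming a toy `DenseBlock`/`Stack` module rather than the repo's `FlaxBartEncoderLayer`; the `concrete=True` flag added by this commit belongs to the older JAX remat API and is omitted here.

# Illustrative sketch (not repo code): conditionally wrap a layer class in
# nn.remat, as FlaxBart{Encoder,Decoder}LayerCollection.setup does above.
import flax.linen as nn
import jax
import jax.numpy as jnp


class DenseBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.gelu(nn.Dense(self.features)(x))


class Stack(nn.Module):
    features: int
    n_layers: int
    gradient_checkpointing: bool = False

    def setup(self):
        # Wrap the *class* in nn.remat only when checkpointing is enabled,
        # then instantiate it as usual; the parameter structure is unchanged.
        layer_module = (
            nn.remat(DenseBlock)
            if self.gradient_checkpointing
            else DenseBlock
        )
        self.layers = [layer_module(features=self.features) for _ in range(self.n_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


x = jnp.ones((2, 16))
model = Stack(features=16, n_layers=4, gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), x)
# Each DenseBlock's activations are rematerialized inside the backward pass,
# so peak memory scales with one layer instead of the whole stack.
grads = jax.grad(lambda p: model.apply(p, x).sum())(params)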
tools/train/train.py CHANGED
@@ -18,6 +18,7 @@ Training DALL·E Mini.
 Script adapted from run_summarization_flax.py
 """
 
+import copy
 import io
 import logging
 import os
@@ -531,6 +532,8 @@ def main():
     # Set up our new model config
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
+        # initializing params with gradient checkpointing create issues
+        config.gradient_checkpointing = False
     else:
         config = None
 
@@ -553,8 +556,27 @@ def main():
     )
 
     # update model config per training args
+    # Done after initialization of weights to avoid issues with remat
+    # This is still considered correctly during training as function is pjitted
     model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
+    # eval model cannot use remat
+    eval_config = copy.deepcopy(model.config)
+    eval_config.gradient_checkpointing = False
+
+    if training_args.gradient_checkpointing:
+        eval_model = DalleBart(
+            eval_config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            abstract_init=True,
+            load_on_cpu=True,
+        )
+        del eval_model._params
+        eval_fn = eval_model.__call__
+    else:
+        eval_fn = model.__call__
+
     # get model metadata
     model_metadata = model_args.get_metadata()
 
@@ -967,7 +989,7 @@ def main():
 
     def compute_eval_loss(batch):
         batch, labels = batch.pop("labels")
-        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
+        logits = eval_fn(**batch, params=state.params, train=False)[0]
         return loss_fn(logits, labels)
 
     # calculate loss independently per dp_device
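
The training-script changes rest on one property worth spelling out: wrapping a layer in `nn.remat` does not change which parameters the model owns, so params can be initialized with checkpointing off, trained with it on, and evaluated through a separate non-remat model (`eval_fn`) that reuses the same `state.params`. Below is a minimal sketch of that property, assuming a toy `Block` module with an explicitly named submodule (the real code likewise fixes submodule names); it is not the dalle-mini `DalleBart` class.

# Illustrative sketch (not repo code): the same parameter tree works for a
# checkpointed "training" module and a plain "eval" module, as eval_fn assumes.
import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):
    features: int
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # Explicit name keeps the parameter path identical with or without remat.
        dense = nn.remat(nn.Dense) if self.gradient_checkpointing else nn.Dense
        return nn.gelu(dense(self.features, name="dense")(x))


x = jnp.ones((2, 8))
train_model = Block(features=8, gradient_checkpointing=True)
params = train_model.init(jax.random.PRNGKey(0), x)

# "eval model cannot use remat": build a second instance with checkpointing
# off and apply it with the parameters produced by the checkpointed model.
eval_model = Block(features=8, gradient_checkpointing=False)
assert jnp.allclose(train_model.apply(params, x), eval_model.apply(params, x))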