boris commited on
Commit
e2400cc
1 Parent(s): 36cb737

fix: OOM with checkpoints

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +4 -3
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -262,15 +262,15 @@ class TrainState(train_state.TrainState):
262
  def restore_state(self, artifact_dir):
263
  # restore optimizer state
264
  with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
265
- opt_state = from_bytes(self.opt_state, f.read())
266
 
267
  # restore steps
268
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
269
  training_state = json.load(f)
270
- step = training_state["step"]
271
 
272
  # replace state
273
- return self.replace(step=step, opt_state=opt_state)
274
 
275
 
276
  class CustomFlaxBartModule(FlaxBartModule):
@@ -802,6 +802,7 @@ def main():
802
 
803
  # Replicate the train state on each device
804
  state = state.replicate()
 
805
 
806
  logger.info("***** Running training *****")
807
  logger.info(f" Num examples = {len_train_dataset}")
 
262
  def restore_state(self, artifact_dir):
263
  # restore optimizer state
264
  with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
265
+ new_opt_state = from_bytes(self.opt_state, f.read())
266
 
267
  # restore steps
268
  with (Path(artifact_dir) / "training_state.json").open("r") as f:
269
  training_state = json.load(f)
270
+ new_step = training_state["step"]
271
 
272
  # replace state
273
+ return self.replace(step=new_step, opt_state=new_opt_state)
274
 
275
 
276
  class CustomFlaxBartModule(FlaxBartModule):
 
802
 
803
  # Replicate the train state on each device
804
  state = state.replicate()
805
+ del model._params
806
 
807
  logger.info("***** Running training *****")
808
  logger.info(f" Num examples = {len_train_dataset}")