boris committed
Commit 1d04ab3
1 Parent(s): 0c9ff65

fix: actually replace state

Files changed (1):
  dev/seq2seq/run_seq2seq_flax.py (+10 -11)

dev/seq2seq/run_seq2seq_flax.py:
@@ -435,18 +435,16 @@ def main():
 
     def restore_state(state, artifact_dir):
         # restore optimizer state
-        if (Path(artifact_dir) / 'opt_state.msgpack').exists():
-            with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
-                opt_state = from_bytes(state.opt_state, f.read())
-                state.replace(opt_state=opt_state)
+        with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
+            opt_state = from_bytes(state.opt_state, f.read())
 
         # restore steps
-        if (Path(artifact_dir) / 'training_state.json').exists():
-            with (Path(artifact_dir) / 'training_state.json').open('r') as f:
-                training_state = json.load(f)
-            step = training_state['step']
-            optimizer_step = step // training_args.gradient_accumulation_steps
-            state.replace(step=step, optimizer_step=optimizer_step)
+        with (Path(artifact_dir) / 'training_state.json').open('r') as f:
+            training_state = json.load(f)
+        step = training_state['step']
+        optimizer_step = step // training_args.gradient_accumulation_steps
+
+        return step, optimizer_step, opt_state
 
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
@@ -668,7 +666,8 @@ def main():
     )
     if model_args.from_checkpoint is not None:
         # restore optimizer state, step and optimizer_step
-        restore_state(state, artifact_dir)
+        step, optimizer_step, opt_state = restore_state(state, artifact_dir)
+        state = state.replace(step=step, optimizer_step=optimizer_step, opt_state=opt_state)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
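
Note on the fix: flax.struct dataclasses (such as the train state used here) are immutable, so state.replace(...) returns a new object instead of mutating state in place. The old restore_state discarded that return value, so the restored opt_state, step, and optimizer_step were silently lost. The commit therefore returns the restored values and rebinds state at the call site. A minimal sketch of the pitfall, using a hypothetical State class (not from this repo), assuming flax is installed:

    # flax.struct dataclasses are frozen; .replace() returns a copy.
    from flax import struct

    @struct.dataclass
    class State:
        step: int
        opt_state: int

    s = State(step=0, opt_state=0)
    s.replace(step=5, opt_state=1)      # old pattern: the returned copy is discarded
    assert s.step == 0                   # s is unchanged

    s = s.replace(step=5, opt_state=1)  # fixed pattern: rebind the name to the copy
    assert s.step == 5 and s.opt_state == 1

This is also why the replace call moves out of restore_state and into the caller: the rebound state has to be the one the training loop goes on to use.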