pere commited on
Commit
9aff105
1 Parent(s): 5db3c6b

not save so often

Browse files
Files changed (2) hide show
  1. flax_model.msgpack +1 -1
  2. run_clm_mp.py +1 -1
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fe397b86f6a483191db26e2778ff9e645c14107b76758b09db78a9338e7a8a6
3
  size 5262371934
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4af9ac77cf162bc4703db2637e08a4dc5861b3e5fcaefe2ef38d8edfda1898e
3
  size 5262371934
run_clm_mp.py CHANGED
@@ -632,7 +632,7 @@ def main():
632
  f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
633
  )
634
 
635
- if cur_step % training_args.save_steps == 0 and cur_step > 0:
636
  # save checkpoint after each epoch and push checkpoint to the hub
637
  if jax.process_index() == 0:
638
  params = jax.device_get(params)
 
632
  f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
633
  )
634
 
635
+ if cur_step % steps_per_epoch == 0 and cur_step > 0:
636
  # save checkpoint after each epoch and push checkpoint to the hub
637
  if jax.process_index() == 0:
638
  params = jax.device_get(params)