not save so often
Browse files- flax_model.msgpack +1 -1
- run_clm_mp.py +1 -1
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5262371934
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4af9ac77cf162bc4703db2637e08a4dc5861b3e5fcaefe2ef38d8edfda1898e
|
3 |
size 5262371934
|
run_clm_mp.py
CHANGED
@@ -632,7 +632,7 @@ def main():
|
|
632 |
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
|
633 |
)
|
634 |
|
635 |
-
if cur_step %
|
636 |
# save checkpoint after each epoch and push checkpoint to the hub
|
637 |
if jax.process_index() == 0:
|
638 |
params = jax.device_get(params)
|
|
|
632 |
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
|
633 |
)
|
634 |
|
635 |
+
if cur_step % steps_per_epoch == 0 and cur_step > 0:
|
636 |
# save checkpoint after each epoch and push checkpoint to the hub
|
637 |
if jax.process_index() == 0:
|
638 |
params = jax.device_get(params)
|