Saving weights and logs of step 200949
Browse files
__pycache__/partitions.cpython-38.pyc
CHANGED
Binary files a/__pycache__/partitions.cpython-38.pyc and b/__pycache__/partitions.cpython-38.pyc differ
|
|
events.out.tfevents.1631437573.t1v-n-1a0a7c50-w-0.2668975.0.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c91cd6fd5f940d9fd43ca0a987b067a45de3926cb3b2c74491b72c278a7e734
|
3 |
+
size 29767300
|
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5262371934
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5cd155c5ef1aa1fe765db196793aef1ec57df9d9c9388d93f7bc24381f7c5be1
|
3 |
size 5262371934
|
run_clm_mp.py
CHANGED
@@ -263,7 +263,6 @@ def main():
|
|
263 |
#
|
264 |
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
265 |
# 'text' is found. You can easily tweak this behavior (see below).
|
266 |
-
|
267 |
if data_args.dataset_name is not None:
|
268 |
# Downloading and loading a dataset from the hub.
|
269 |
dataset = load_dataset(
|
@@ -633,7 +632,7 @@ def main():
|
|
633 |
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
|
634 |
)
|
635 |
|
636 |
-
if cur_step %
|
637 |
# save checkpoint after each epoch and push checkpoint to the hub
|
638 |
if jax.process_index() == 0:
|
639 |
params = jax.device_get(params)
|
|
|
263 |
#
|
264 |
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
265 |
# 'text' is found. You can easily tweak this behavior (see below).
|
|
|
266 |
if data_args.dataset_name is not None:
|
267 |
# Downloading and loading a dataset from the hub.
|
268 |
dataset = load_dataset(
|
|
|
632 |
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
|
633 |
)
|
634 |
|
635 |
+
if cur_step % steps_per_epoch == 0 and cur_step > 0:
|
636 |
# save checkpoint after each epoch and push checkpoint to the hub
|
637 |
if jax.process_index() == 0:
|
638 |
params = jax.device_get(params)
|