pere committed on
Commit
6925281
1 Parent(s): ad7df07

Saving weights and logs of step 200949

Browse files
__pycache__/partitions.cpython-38.pyc CHANGED
Binary files a/__pycache__/partitions.cpython-38.pyc and b/__pycache__/partitions.cpython-38.pyc differ
 
events.out.tfevents.1631437573.t1v-n-1a0a7c50-w-0.2668975.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c91cd6fd5f940d9fd43ca0a987b067a45de3926cb3b2c74491b72c278a7e734
3
+ size 29767300
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dbe8bc88b4653f67693fc1db73c5db19a982a7574e3663c485c9229ef18ee3a9
3
  size 5262371934
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cd155c5ef1aa1fe765db196793aef1ec57df9d9c9388d93f7bc24381f7c5be1
3
  size 5262371934
run_clm_mp.py CHANGED
@@ -263,7 +263,6 @@ def main():
263
  #
264
  # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
265
  # 'text' is found. You can easily tweak this behavior (see below).
266
-
267
  if data_args.dataset_name is not None:
268
  # Downloading and loading a dataset from the hub.
269
  dataset = load_dataset(
@@ -633,7 +632,7 @@ def main():
633
  f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
634
  )
635
 
636
- if cur_step % training_args.save_steps == 0 and cur_step > 0:
637
  # save checkpoint after each epoch and push checkpoint to the hub
638
  if jax.process_index() == 0:
639
  params = jax.device_get(params)
 
263
  #
264
  # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
265
  # 'text' is found. You can easily tweak this behavior (see below).
 
266
  if data_args.dataset_name is not None:
267
  # Downloading and loading a dataset from the hub.
268
  dataset = load_dataset(
 
632
  f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
633
  )
634
 
635
+ if cur_step % steps_per_epoch == 0 and cur_step > 0:
636
  # save checkpoint after each epoch and push checkpoint to the hub
637
  if jax.process_index() == 0:
638
  params = jax.device_get(params)