Saving weights and logs of step 30000
Browse files- flax_model.msgpack +1 -1
- pretrain.sh +5 -3
- run_clm_flax.py +2 -2
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 3096134690
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2d95ccbc5b88a04092b2ba965a20d6670dc94adfcef72afcf8d59d702a4e382b
|
3 |
size 3096134690
|
pretrain.sh
CHANGED
@@ -9,10 +9,12 @@
|
|
9 |
--block_size="512" \
|
10 |
--per_device_train_batch_size="4" \
|
11 |
--per_device_eval_batch_size="4" \
|
12 |
-
--learning_rate="
|
13 |
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
|
14 |
--overwrite_output_dir \
|
15 |
-
--num_train_epochs="
|
16 |
-
--logging_steps="
|
|
|
|
|
17 |
--adafactor \
|
18 |
--push_to_hub
|
|
|
9 |
--block_size="512" \
|
10 |
--per_device_train_batch_size="4" \
|
11 |
--per_device_eval_batch_size="4" \
|
12 |
+
--learning_rate="3e-4" --warmup_steps="16000" \
|
13 |
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
|
14 |
--overwrite_output_dir \
|
15 |
+
--num_train_epochs="2" \
|
16 |
+
--logging_steps="15000" \
|
17 |
+
--save_steps="30000" \
|
18 |
+
--eval_steps="30000" \
|
19 |
--adafactor \
|
20 |
--push_to_hub
|
run_clm_flax.py
CHANGED
@@ -398,7 +398,8 @@ def main():
|
|
398 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
399 |
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
400 |
# customize this part to your needs.
|
401 |
-
|
|
|
402 |
# Split by chunks of max_len.
|
403 |
result = {
|
404 |
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
@@ -621,7 +622,6 @@ def main():
|
|
621 |
|
622 |
# Save metrics
|
623 |
if has_tensorboard and jax.process_index() == 0:
|
624 |
-
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
625 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
626 |
|
627 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
|
|
398 |
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
399 |
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
400 |
# customize this part to your needs.
|
401 |
+
if total_length >= block_size:
|
402 |
+
total_length = (total_length // block_size) * block_size
|
403 |
# Split by chunks of max_len.
|
404 |
result = {
|
405 |
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
|
|
622 |
|
623 |
# Save metrics
|
624 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
625 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
626 |
|
627 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|