miwojc committed on
Commit
233af6c
1 Parent(s): 0c42b31

Saving weights and logs of step 30000

Browse files
Files changed (3) hide show
  1. flax_model.msgpack +1 -1
  2. pretrain.sh +5 -3
  3. run_clm_flax.py +2 -2
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c95d9bb38c83f3a8dd2a2300154d85d29447ff3d924a9327a8b732bfb55c0a66
3
  size 3096134690
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d95ccbc5b88a04092b2ba965a20d6670dc94adfcef72afcf8d59d702a4e382b
3
  size 3096134690
pretrain.sh CHANGED
@@ -9,10 +9,12 @@
9
  --block_size="512" \
10
  --per_device_train_batch_size="4" \
11
  --per_device_eval_batch_size="4" \
12
- --learning_rate="2e-5" --warmup_steps="16000" \
13
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
14
  --overwrite_output_dir \
15
- --num_train_epochs="5" \
16
- --logging_steps="80000" \
 
 
17
  --adafactor \
18
  --push_to_hub
 
9
  --block_size="512" \
10
  --per_device_train_batch_size="4" \
11
  --per_device_eval_batch_size="4" \
12
+ --learning_rate="3e-4" --warmup_steps="16000" \
13
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
14
  --overwrite_output_dir \
15
+ --num_train_epochs="2" \
16
+ --logging_steps="15000" \
17
+ --save_steps="30000" \
18
+ --eval_steps="30000" \
19
  --adafactor \
20
  --push_to_hub
run_clm_flax.py CHANGED
@@ -398,7 +398,8 @@ def main():
398
  total_length = len(concatenated_examples[list(examples.keys())[0]])
399
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
400
  # customize this part to your needs.
401
- total_length = (total_length // block_size) * block_size
 
402
  # Split by chunks of max_len.
403
  result = {
404
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
@@ -621,7 +622,6 @@ def main():
621
 
622
  # Save metrics
623
  if has_tensorboard and jax.process_index() == 0:
624
- cur_step = epoch * (len(train_dataset) // train_batch_size)
625
  write_eval_metric(summary_writer, eval_metrics, cur_step)
626
 
627
  if cur_step % training_args.save_steps == 0 and cur_step > 0:
 
398
  total_length = len(concatenated_examples[list(examples.keys())[0]])
399
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
400
  # customize this part to your needs.
401
+ if total_length >= block_size:
402
+ total_length = (total_length // block_size) * block_size
403
  # Split by chunks of max_len.
404
  result = {
405
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 
622
 
623
  # Save metrics
624
  if has_tensorboard and jax.process_index() == 0:
 
625
  write_eval_metric(summary_writer, eval_metrics, cur_step)
626
 
627
  if cur_step % training_args.save_steps == 0 and cur_step > 0: