cahya committed
Commit 6f9afb3
1 Parent(s): a9c5d2c

fixed the save_steps, make test

Files changed (2):
  1. run_clm_flax.py    +2 -2
  2. run_pretraining.sh  +6 -3
run_clm_flax.py CHANGED
@@ -413,7 +413,8 @@ def main():
         total_length = len(concatenated_examples[list(examples.keys())[0]])
         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
         # customize this part to your needs.
-        total_length = (total_length // block_size) * block_size
+        if total_length >= block_size:
+            total_length = (total_length // block_size) * block_size
         # Split by chunks of max_len.
         result = {
             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
@@ -636,7 +637,6 @@ def main():
 
             # Save metrics
             if has_tensorboard and jax.process_index() == 0:
-                cur_step = epoch * (len(train_dataset) // train_batch_size)
                 write_eval_metric(summary_writer, eval_metrics, cur_step)
 
             if cur_step % training_args.save_steps == 0 and cur_step > 0:
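For the first hunk, here is a minimal, self-contained sketch of the chunking logic (the standalone `group_texts` wrapper and the `examples` layout are assumptions for illustration, not copied verbatim from run_clm_flax.py). It shows why the new guard matters: without it, a batch shorter than `block_size` is rounded down to zero tokens and silently dropped.

# Minimal sketch of the chunking step touched by this commit; the helper
# signature and input layout are assumed for illustration only.
def group_texts(examples, block_size=512):
    # Concatenate all token lists per key.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Guard added in this commit: only round down when there is at least
    # one full block, so a short batch is not truncated to zero tokens.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

# A batch with fewer than block_size tokens: the unguarded version would
# return {'input_ids': []}; the guarded version keeps the short chunk.
print(group_texts({"input_ids": [[1, 2, 3]]}, block_size=4))
# -> {'input_ids': [[1, 2, 3]]}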
run_pretraining.sh CHANGED
@@ -1,3 +1,4 @@
+export MODEL_DIR=`pwd`
 export WANDB_ENTITY="wandb"
 export WANDB_PROJECT="hf-flax-gpt2-indonesian"
 export WANDB_LOG_MODEL="true"
@@ -13,12 +14,14 @@ export WANDB_LOG_MODEL="true"
     --block_size="512" \
     --per_device_train_batch_size="24" \
     --per_device_eval_batch_size="24" \
-    --learning_rate="5e-3" --warmup_steps="1000" \
+    --learning_rate="0.0024" --warmup_steps="1000" \
     --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
     --overwrite_output_dir \
     --num_train_epochs="20" \
     --dataloader_num_workers="64" \
     --preprocessing_num_workers="64" \
-    --save_steps="2000" \
-    --eval_steps="2000" \
+    --save_steps="10" \
+    --eval_steps="10" \
+    --max_train_samples="10000" \
+    --max_eval_samples="1000" \
     --push_to_hub
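As a rough sanity check of the test settings above (the device count below is a hypothetical assumption, not part of this commit; adjust it to the actual TPU/GPU setup):

# Back-of-the-envelope step count for the smoke-test configuration.
per_device_train_batch_size = 24
num_devices = 8                    # assumption, e.g. one TPU v3-8
max_train_samples = 10_000

train_batch_size = per_device_train_batch_size * num_devices  # 192
steps_per_epoch = max_train_samples // train_batch_size       # 52
print(train_batch_size, steps_per_epoch)

# With save_steps=10 and eval_steps=10, checkpointing and evaluation each
# trigger a few times per epoch, which is what a quick test run needs in
# order to exercise the newly fixed save_steps path.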