sakares committed
Commit 6338ce3
Parent: 6577ec1

adjust batch_size and update run_mlm_flax script

events.out.tfevents.1625581996.t1v-n-bf8aeee7-w-0.10129.3.v2 DELETED
Binary file (883 kB)
 
run.sh CHANGED
@@ -9,8 +9,8 @@ python3 run_mlm_flax.py \
     --dataset_config_name="unshuffled_deduplicated_th" \
     --max_seq_length="128" \
     --preprocessing_num_workers="64" \
-    --per_device_train_batch_size="32" \
-    --per_device_eval_batch_size="32" \
+    --per_device_train_batch_size="64" \
+    --per_device_eval_batch_size="64" \
     --learning_rate="2e-4" \
     --warmup_steps="1000" \
     --overwrite_output_dir \
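
For scale: under Flax data parallelism the effective batch is the per-device batch times the local device count. A back-of-envelope sketch of the change above, assuming an 8-core TPU v3-8 (suggested by the "t1v-n-..." hostname in the deleted event file, but not stated in the commit):

import jax

# The device count of 8 is an assumption (TPU v3-8), not confirmed here.
per_device_train_batch_size = 64  # raised from 32 in this commit
train_batch_size = per_device_train_batch_size * jax.device_count()

# With 8 devices the global batch doubles from 256 to 512 samples per
# optimizer step, which also halves the number of steps per epoch.
print(train_batch_size)
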
run_mlm_flax.py CHANGED
@@ -606,7 +606,7 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)

-            cur_step = epoch * num_train_samples + step
+            cur_step = epoch * (num_train_samples // train_batch_size) + step

             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                 # Save metrics
@@ -621,43 +621,42 @@ if __name__ == "__main__":

                 train_metrics = []

-        # ======================== Evaluating ==============================
-        num_eval_samples = len(tokenized_datasets["validation"])
-        eval_samples_idx = jnp.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-
-        eval_metrics = []
-        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
-            model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
-            # Model forward
-            model_inputs = shard(model_inputs.data)
-            metrics = p_eval_step(state.params, model_inputs)
-            eval_metrics.append(metrics)
-
-        # normalize eval metrics
-        eval_metrics = get_metrics(eval_metrics)
-        eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
-        eval_normalizer = eval_metrics.pop("normalizer")
-        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
-
-        # Update progress bar
-        epochs.desc = (
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
-        )
-
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_eval_metric(summary_writer, eval_metrics, cur_step)
-
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs of epoch {epoch+1}",
-            )
+            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+                # ======================== Evaluating ==============================
+                num_eval_samples = len(tokenized_datasets["validation"])
+                eval_samples_idx = jnp.arange(num_eval_samples)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+                eval_metrics = []
+                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+                    model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+                    # Model forward
+                    model_inputs = shard(model_inputs.data)
+                    metrics = p_eval_step(state.params, model_inputs)
+                    eval_metrics.append(metrics)
+
+                # normalize eval metrics
+                eval_metrics = get_metrics(eval_metrics)
+                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                eval_normalizer = eval_metrics.pop("normalizer")
+                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+                # Update progress bar
+                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+                # Save metrics
+                if has_tensorboard and jax.process_index() == 0:
+                    write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+            if cur_step % training_args.save_steps == 0 and cur_step > 0:
+                # save checkpoint and push checkpoint to the hub
+                if jax.process_index() == 0:
+                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    model.save_pretrained(
+                        training_args.output_dir,
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )
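
Taken together, the run_mlm_flax.py hunks move the script from per-epoch to step-based evaluation and checkpointing: cur_step now counts optimizer steps (the old "epoch * num_train_samples + step" inflated the counter by the sample count at every epoch boundary), and the eval and save blocks are gated on training_args.eval_steps / training_args.save_steps. A minimal sketch of the resulting schedule; the numbers and stub bodies are illustrative, not taken from this run:

from dataclasses import dataclass

# logging_steps / eval_steps / save_steps mirror HuggingFace
# TrainingArguments fields; the concrete values here are made up.
@dataclass
class Args:
    logging_steps: int = 500
    eval_steps: int = 2500
    save_steps: int = 2500

training_args = Args()
num_train_samples, train_batch_size, num_epochs = 1_000_000, 512, 3

# The first hunk's fix: count steps, not samples.
steps_per_epoch = num_train_samples // train_batch_size  # 1953 here

for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        cur_step = epoch * steps_per_epoch + step  # global optimizer step

        if cur_step % training_args.logging_steps == 0 and cur_step > 0:
            pass  # write train metrics (TensorBoard) at cur_step
        if cur_step % training_args.eval_steps == 0 and cur_step > 0:
            pass  # full pass over tokenized_datasets["validation"]
        if cur_step % training_args.save_steps == 0 and cur_step > 0:
            pass  # save_pretrained(...) / push_to_hub, tagged with cur_step

When eval_steps and save_steps coincide (as in the stub values above), each Hub checkpoint lands at the same global step as an eval pass, and the commit message now records that step instead of the epoch.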