Patrick von Platen committed
Commit 313c50a
Parent: d9a2697

Saving weights and logs of step 8

events.out.tfevents.1625595098.t1v-n-71556209-w-0.22293.3.v2 ADDED
Binary file (40 Bytes)
 
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:79e9400819ada0aa172374be8a0a62667546a8db83f8483f337944a9eaf9cb19
+oid sha256:7b4812d4f42c82ee8a963f90fe45ce25a75b0283d71e093ff578ba3c65de9d6e
 size 498796983
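The weights file itself lives in Git LFS, so the diff can only show the pointer: the sha256 oid changes to the newly uploaded object while the size stays at 498796983 bytes. A pointer file is just three "key value" lines; a small illustrative parser (this helper is not part of the repo):

# Illustrative sketch: parse the "key value" lines of a Git LFS pointer.
# The path is the pointer as checked into git, not the real ~499 MB weights.
def read_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# read_lfs_pointer("flax_model.msgpack") ->
# {"version": "https://git-lfs.github.com/spec/v1",
#  "oid": "sha256:7b4812d4f42c82ee8a963f90fe45ce25a75b0283d71e093ff578ba3c65de9d6e",
#  "size": "498796983"}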
run.sh CHANGED
@@ -7,10 +7,13 @@
     --dataset_name="oscar" \
     --dataset_config_name="unshuffled_deduplicated_als" \
     --max_seq_length="128" \
-    --per_device_train_batch_size="4" \
-    --per_device_eval_batch_size="4" \
+    --per_device_train_batch_size="1" \
+    --per_device_eval_batch_size="1" \
     --learning_rate="3e-4" \
     --warmup_steps="1000" \
     --overwrite_output_dir \
     --num_train_epochs="8" \
+    --logging_steps="10" \
+    --save_steps="8" \
+    --eval_steps="15" \
     --push_to_hub
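Besides dropping the per-device batch size from 4 to 1, the run now passes three interval flags that the training loop checks against a global step counter. A minimal sketch of that gating, using the flag values above (the three functions are placeholders, not the script's real names):

# Minimal sketch of step-interval gating with the values from run.sh.
logging_steps, save_steps, eval_steps = 10, 8, 15
total_steps = 30  # illustrative

def log_metrics(step): print(f"log  at step {step}")
def save_and_push(step): print(f"save at step {step}")
def evaluate(step): print(f"eval at step {step}")

for cur_step in range(1, total_steps + 1):
    if cur_step % logging_steps == 0:
        log_metrics(cur_step)    # --logging_steps="10"
    if cur_step % save_steps == 0:
        save_and_push(cur_step)  # --save_steps="8": first push lands at step 8
    if cur_step % eval_steps == 0:
        evaluate(cur_step)       # --eval_steps="15"

With save_steps set to 8, the first checkpoint push happens at step 8, which is exactly the commit this page records.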
run_mlm_flax.py CHANGED
@@ -606,7 +606,7 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-            cur_step = epoch * num_train_samples + step
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
 
             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                 # Save metrics
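This one-line fix corrects the step counter's units: the old expression added a sample count to a per-batch step index, so after the first epoch cur_step jumped ahead by the dataset size rather than by the number of optimizer steps, and every modulo-based trigger below fired at the wrong times. A worked example with illustrative numbers:

# Illustrative numbers only: 1000 training samples, batch size 4.
num_train_samples, train_batch_size = 1000, 4
epoch, step = 1, 0  # first step of the second epoch

old = epoch * num_train_samples + step                        # 1000: counts samples
new = epoch * (num_train_samples // train_batch_size) + step  # 250: counts optimizer steps
assert (old, new) == (1000, 250)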
@@ -621,43 +621,42 @@
 
                 train_metrics = []
 
-        # ======================== Evaluating ==============================
-        num_eval_samples = len(tokenized_datasets["validation"])
-        eval_samples_idx = jnp.arange(num_eval_samples)
-        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
-
-        eval_metrics = []
-        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
-            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
-            model_inputs = data_collator(samples, pad_to_multiple_of=16)
-
-            # Model forward
-            model_inputs = shard(model_inputs.data)
-            metrics = p_eval_step(state.params, model_inputs)
-            eval_metrics.append(metrics)
-
-        # normalize eval metrics
-        eval_metrics = get_metrics(eval_metrics)
-        eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
-        eval_normalizer = eval_metrics.pop("normalizer")
-        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
-
-        # Update progress bar
-        epochs.desc = (
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
-        )
-
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_eval_metric(summary_writer, eval_metrics, cur_step)
-
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs of epoch {epoch+1}",
-            )
+            if cur_step % training_args.eval_steps == 0 and step > 0:
+                # ======================== Evaluating ==============================
+                num_eval_samples = len(tokenized_datasets["validation"])
+                eval_samples_idx = jnp.arange(num_eval_samples)
+                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+                eval_metrics = []
+                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+                    model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+                    # Model forward
+                    model_inputs = shard(model_inputs.data)
+                    metrics = p_eval_step(state.params, model_inputs)
+                    eval_metrics.append(metrics)
+
+                # normalize eval metrics
+                eval_metrics = get_metrics(eval_metrics)
+                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                eval_normalizer = eval_metrics.pop("normalizer")
+                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+                # Update progress bar
+                epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+                # Save metrics
+                if has_tensorboard and jax.process_index() == 0:
+                    write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+            if cur_step % training_args.save_steps == 0 and step > 0:
+                # save checkpoint and push it to the hub every save_steps
+                if jax.process_index() == 0:
+                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    model.save_pretrained(
+                        training_args.output_dir,
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )