Pedro Cuenca committed on
Commit
566d5f2
·
1 Parent(s): 835ea55

Add eval_interval to evaluate and log every so often.

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +54 -34
seq2seq/run_seq2seq_flax.py CHANGED
@@ -225,6 +225,12 @@ class DataTrainingArguments:
225
  "value if set."
226
  },
227
  )
 
 
 
 
 
 
228
  log_model: bool = field(
229
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
230
  )
@@ -738,37 +744,8 @@ def main():
738
  train_time = 0
739
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
740
  global_step = 0
741
- for epoch in epochs:
742
- # ======================== Training ================================
743
- train_start = time.time()
744
-
745
- # Create sampling rng
746
- rng, input_rng = jax.random.split(rng)
747
- train_metrics = []
748
-
749
- # Generate an epoch by shuffling sampling indices from the train dataset
750
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
751
- steps_per_epoch = len(train_dataset) // train_batch_size
752
- # train
753
- for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
754
- global_step +=1
755
- batch = next(train_loader)
756
- state, train_metric = p_train_step(state, batch)
757
- train_metrics.append(train_metric)
758
-
759
- if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
760
- for k, v in unreplicate(train_metric).items():
761
- wandb.log({"train/step": global_step})
762
- wandb.log({f"train/{k}": jax.device_get(v)})
763
-
764
- train_time += time.time() - train_start
765
-
766
- train_metric = unreplicate(train_metric)
767
-
768
- epochs.write(
769
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
770
- )
771
 
 
772
  # ======================== Evaluating ==============================
773
  eval_metrics = []
774
  if training_args.do_eval:
@@ -795,17 +772,60 @@ def main():
795
  eval_metrics = get_metrics(eval_metrics)
796
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
797
 
 
 
 
 
 
798
  # compute ROUGE metrics
799
  rouge_desc = ""
800
- # if data_args.predict_with_generate:
801
- # rouge_metrics = compute_metrics(eval_preds, eval_labels)
802
- # eval_metrics.update(rouge_metrics)
803
- # rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
804
 
805
  # Print metrics and update progress bar
806
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
807
  epochs.write(desc)
808
  epochs.desc = desc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
809
 
810
  # Save metrics
811
  if has_tensorboard and jax.process_index() == 0:
 
225
  "value if set."
226
  },
227
  )
228
+ eval_interval: Optional[int] = field(
229
+ default=40,
230
+ metadata={
231
+ "help": "Evaluation will be performed every eval_interval steps"
232
+ },
233
+ )
234
  log_model: bool = field(
235
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
236
  )
 
744
  train_time = 0
745
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
746
  global_step = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
747
 
748
+ def run_evaluation():
749
  # ======================== Evaluating ==============================
750
  eval_metrics = []
751
  if training_args.do_eval:
 
772
  eval_metrics = get_metrics(eval_metrics)
773
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
774
 
775
+ if jax.process_index() == 0:
776
+ for k, v in eval_metrics.items():
777
+ wandb.log({"eval/step": global_step})
778
+ wandb.log({f"eval/{k}": jax.device_get(v)})
779
+
780
  # compute ROUGE metrics
781
  rouge_desc = ""
782
+ # if data_args.predict_with_generate:
783
+ # rouge_metrics = compute_metrics(eval_preds, eval_labels)
784
+ # eval_metrics.update(rouge_metrics)
785
+ # rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
786
 
787
  # Print metrics and update progress bar
788
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
789
  epochs.write(desc)
790
  epochs.desc = desc
791
+ return eval_metrics
792
+
793
+ for epoch in epochs:
794
+ # ======================== Training ================================
795
+ train_start = time.time()
796
+
797
+ # Create sampling rng
798
+ rng, input_rng = jax.random.split(rng)
799
+ train_metrics = []
800
+
801
+ # Generate an epoch by shuffling sampling indices from the train dataset
802
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
803
+ steps_per_epoch = len(train_dataset) // train_batch_size
804
+ # train
805
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
806
+ global_step +=1
807
+ batch = next(train_loader)
808
+ state, train_metric = p_train_step(state, batch)
809
+ train_metrics.append(train_metric)
810
+
811
+ if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
812
+ print("logging train loss")
813
+ for k, v in unreplicate(train_metric).items():
814
+ wandb.log({"train/step": global_step})
815
+ wandb.log({f"train/{k}": jax.device_get(v)})
816
+
817
+ if global_step % data_args.eval_interval == 0 and jax.process_index() == 0:
818
+ run_evaluation()
819
+
820
+ train_time += time.time() - train_start
821
+
822
+ train_metric = unreplicate(train_metric)
823
+
824
+ epochs.write(
825
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
826
+ )
827
+
828
+ eval_metrics = run_evaluation()
829
 
830
  # Save metrics
831
  if has_tensorboard and jax.process_index() == 0: