m3hrdadfi committed
Commit 0f70e83 (1 parent: 14f41dd)

Handling states-steps

Files changed (2):
  1. src/run.sh +9 -5
  2. src/run_ed_recipe_nlg.py +87 -102
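
The change this commit makes: logging, evaluation, and checkpointing move from epoch boundaries onto a global step counter. A minimal sketch of that cadence, assuming a made-up steps_per_epoch (the real value is len(train_dataset) // train_batch_size) while the interval values mirror run.sh below:

# Sketch only: steps_per_epoch is hypothetical; intervals mirror run.sh.
logging_steps, eval_steps, save_steps = 500, 2500, 2500
steps_per_epoch = 6000
num_epochs = 5

for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        cur_step = epoch * steps_per_epoch + step  # keeps growing across epochs
        if cur_step > 0 and cur_step % logging_steps == 0:
            pass  # flush buffered train metrics to TensorBoard
        if cur_step > 0 and cur_step % eval_steps == 0:
            pass  # run the eval loop on the validation set
        if cur_step > 0 and cur_step % save_steps == 0:
            pass  # save_pretrained + optional push to hub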
src/run.sh CHANGED
@@ -12,16 +12,19 @@ export VALIDATION_FILE=/to/../dev.csv
 export TEST_FILE=/to/../test.csv
 export TEXT_COLUMN=inputs
 export TARGET_COLUMN=targets
-export MAX_SOURCE_LENGTH=128
+export MAX_SOURCE_LENGTH=256
 export MAX_TARGET_LENGTH=1024
 export SOURCE_PREFIX=ingredients

 export PER_DEVICE_TRAIN_BATCH_SIZE=8
 export PER_DEVICE_EVAL_BATCH_SIZE=8
 export GRADIENT_ACCUMULATION_STEPS=2
-export NUM_TRAIN_EPOCHS=3.0
-export LEARNING_RATE=5e-4
+export NUM_TRAIN_EPOCHS=5.0
+export LEARNING_RATE=1e-4
 export WARMUP_STEPS=5000
+export LOGGING_STEPS=500
+export EVAL_STEPS=2500
+export SAVE_STEPS=2500

 python run_ed_recipe_nlg.py \
     --output_dir="$OUTPUT_DIR" \
@@ -42,10 +45,11 @@ python run_ed_recipe_nlg.py \
     --num_train_epochs=$NUM_TRAIN_EPOCHS \
     --learning_rate=$LEARNING_RATE \
     --warmup_steps=$WARMUP_STEPS \
-    --preprocessing_num_workers=4 \
+    --logging_steps=$LOGGING_STEPS \
+    --eval_steps=$EVAL_STEPS \
+    --save_steps=$SAVE_STEPS \
     --prediction_debug \
     --do_train \
     --do_eval \
-    --do_predict \
     --overwrite_output_dir \
     --predict_with_generate
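
For orientation, the flags above reach the script through dataclass parsing: the HF example scripts this one follows use HfArgumentParser, which derives each flag from a dataclass field name, so --logging_steps has to match the training_args.logging_steps attribute the training loop reads. A hedged sketch with a stand-in dataclass (SketchTrainingArguments is not the one defined in run_ed_recipe_nlg.py):

# Sketch only; run as: python sketch.py --output_dir out --logging_steps 500
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class SketchTrainingArguments:
    output_dir: str = field(metadata={"help": "Where checkpoints are written."})
    logging_steps: int = field(default=500, metadata={"help": "Log train metrics every N steps."})
    eval_steps: int = field(default=2500, metadata={"help": "Run evaluation every N steps."})
    save_steps: int = field(default=2500, metadata={"help": "Save a checkpoint every N steps."})


parser = HfArgumentParser(SketchTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()
print(training_args.logging_steps, training_args.eval_steps, training_args.save_steps)

Because the parser rejects flags it does not know, a misspelling such as the singular --logging_step would abort the run before training starts, which is why the flag names above line up exactly with the dataclass fields.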
src/run_ed_recipe_nlg.py CHANGED
@@ -258,7 +258,20 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch


-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+# def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+#     summary_writer.scalar("train_time", train_time, step)
+#
+#     train_metrics = get_metrics(train_metrics)
+#     for key, vals in train_metrics.items():
+#         tag = f"train_{key}"
+#         for i, val in enumerate(vals):
+#             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+#
+#     for metric_name, value in eval_metrics.items():
+#         summary_writer.scalar(f"eval_{metric_name}", value, step)
+#
+
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)

     train_metrics = get_metrics(train_metrics)
@@ -267,6 +280,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)

+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)

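The hunk above retires the combined write_metric in favor of write_train_metric and write_eval_metric, so the two can fire on independent logging_steps/eval_steps schedules. Both helpers rely only on the writer's scalar(tag, value, step) method; a minimal sketch against Flax's TensorBoard writer, with ./runs as an arbitrary log dir:

# Sketch only: the log dir and values are arbitrary.
from flax.metrics.tensorboard import SummaryWriter

summary_writer = SummaryWriter(log_dir="./runs")
summary_writer.scalar("train_time", 12.5, 500)   # the kind of tag write_train_metric emits
summary_writer.scalar("eval_loss", 1.08, 2500)   # the kind of tag write_eval_metric emits
summary_writer.flush()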
 
@@ -553,7 +568,7 @@ def main():
     result = {}

     try:
-        result_blue = bleu.compute(predictions=decoded_preds, references=decoded_labels_wer)
+        result_blue = bleu.compute(predictions=decoded_preds, references=decoded_labels_bleu)
         result_blue = result_blue["score"]
     except Exception as e:
         logger.info(f'Error occurred during bleu {e}')
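
The rename above points the BLEU call at a BLEU-specific label variable instead of the WER-oriented one. Assuming bleu is sacreBLEU loaded through datasets.load_metric (consistent with the result_blue["score"] access), references must be one list of reference strings per prediction; a small sketch with made-up recipe text:

# Sketch only: assumes sacreBLEU; the strings are invented.
from datasets import load_metric

bleu = load_metric("sacrebleu")
decoded_preds = ["mix the flour with the water", "bake for ten minutes"]
decoded_labels_bleu = [["mix the flour with the water"], ["bake it for ten minutes"]]
result_blue = bleu.compute(predictions=decoded_preds, references=decoded_labels_bleu)
print(result_blue["score"])  # corpus-level BLEU as a single float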
@@ -734,6 +749,7 @@ def main():
     logger.info(f" Total optimization steps = {total_train_steps}")

     train_time = 0
+    train_metrics = []
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
         # ======================== Training ================================
@@ -741,115 +757,84 @@ def main():

         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []

         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
         steps_per_epoch = len(train_dataset) // train_batch_size
         # train
-        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)

-        train_time += time.time() - train_start
-
-        train_metric = unreplicate(train_metric)
-
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
-
-        # ======================== Evaluating ==============================
-        eval_metrics = []
-        eval_preds = []
-        eval_labels = []
-
-        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
-        eval_steps = len(eval_dataset) // eval_batch_size
-        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
-            # Model forward
-            batch = next(eval_loader)
-            labels = batch["labels"]
-
-            metrics = p_eval_step(state.params, batch)
-            eval_metrics.append(metrics)
-
-            # generation
-            if data_args.predict_with_generate:
-                generated_ids = p_generate_step(state.params, batch)
-                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
-
-        # normalize eval metrics
-        eval_metrics = get_metrics(eval_metrics)
-        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
-
-        # compute ROUGE metrics
-        rouge_desc = ""
-        if data_args.predict_with_generate:
-            rouge_metrics = compute_metrics(eval_preds, eval_labels)
-            eval_metrics.update(rouge_metrics)
-            rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
-
-        # Print metrics and update progress bar
-        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
-        epochs.write(desc)
-        epochs.desc = desc
-
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
-        # ======================== Prediction loop ==============================
-        if training_args.do_predict:
-            logger.info("*** Predict ***")
-
-            pred_metrics = []
-            pred_generations = []
-            pred_labels = []
-
-            pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
-            pred_steps = len(predict_dataset) // eval_batch_size
-            for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
-                # Model forward
-                batch = next(pred_loader)
-                labels = batch["labels"]
-
-                metrics = p_eval_step(state.params, batch)
-                pred_metrics.append(metrics)
-
-                # generation
-                if data_args.predict_with_generate:
-                    generated_ids = p_generate_step(state.params, batch)
-                    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                    pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
-
-            # normalize prediction metrics
-            pred_metrics = get_metrics(pred_metrics)
-            pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
-
-            # compute ROUGE metrics
-            mix_desc = ""
-            if data_args.predict_with_generate:
-                mix_metrics = compute_metrics(pred_generations, pred_labels)
-                pred_metrics.update(mix_metrics)
-                mix_desc = " ".join([f"Predict {key}: {value} |" for key, value in mix_metrics.items()])
-
-            # Print metrics
-            desc = f"Predict Loss: {pred_metrics['loss']} | {mix_desc})"
-            logger.info(desc)
-
-        # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs of epoch {epoch + 1}",
-            )
+            cur_step = epoch * (len(train_dataset) // train_batch_size) + step
+
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+                )
+
+                train_metrics = []
+
+            if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
+                eval_metrics = []
+                eval_preds = []
+                eval_labels = []
+
+                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+                eval_steps = len(eval_dataset) // eval_batch_size
+                for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+                    # Model forward
+                    batch = next(eval_loader)
+                    labels = batch["labels"]
+
+                    metrics = p_eval_step(state.params, batch)
+                    eval_metrics.append(metrics)
+
+                    # generation
+                    if data_args.predict_with_generate:
+                        generated_ids = p_generate_step(state.params, batch)
+                        eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                        eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+                # normalize eval metrics
+                eval_metrics = get_metrics(eval_metrics)
+                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+                # compute MIX metrics
+                mix_desc = ""
+                if data_args.predict_with_generate:
+                    mix_metrics = compute_metrics(eval_preds, eval_labels)
+                    eval_metrics.update(mix_metrics)
+                    mix_desc = " ".join([f"Eval {key}: {value} |" for key, value in mix_metrics.items()])
+
+                # Print metrics and update progress bar
+                desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {mix_desc})"
+                epochs.write(desc)
+                epochs.desc = desc
+
+                # Save metrics
+                if has_tensorboard and jax.process_index() == 0:
+                    cur_step = epoch * (len(train_dataset) // train_batch_size) + step
+                    write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+            if cur_step % training_args.save_steps == 0 and cur_step > 0:
+                # save checkpoint at each save step and push checkpoint to the hub
+                if jax.process_index() == 0:
+                    # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    params = jax.device_get(unreplicate(state.params))
+                    model.save_pretrained(
+                        training_args.output_dir,
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )


 if __name__ == "__main__":
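
One detail from the final hunk worth isolating: the checkpoint path now calls jax.device_get(unreplicate(state.params)) instead of indexing replica 0 by hand, and the commit message switches from epoch to step. For a pmap-replicated tree the two forms are equivalent; a self-contained sketch with a toy parameter tree standing in for state.params:

# Sketch only: the toy params tree stands in for state.params.
import jax
import jax.numpy as jnp
from flax import jax_utils

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
replicated = jax_utils.replicate(params)  # leading device axis, as pmap keeps it

by_hand = jax.device_get(jax.tree_map(lambda x: x[0], replicated))  # old line
by_unreplicate = jax.device_get(jax_utils.unreplicate(replicated))  # new line

for a, b in zip(jax.tree_util.tree_leaves(by_hand), jax.tree_util.tree_leaves(by_unreplicate)):
    assert (a == b).all()

Functionally identical; unreplicate simply states the intent and mirrors how the state was replicated in the first place.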