m3hrdadfi commited on
Commit
904a484
1 Parent(s): 34783fc

Fix prediction metric

Browse files
Files changed (1) hide show
  1. src/run_ed_recipe_nlg.py +5 -5
src/run_ed_recipe_nlg.py CHANGED
@@ -832,14 +832,14 @@ def main():
832
  pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
833
 
834
  # compute ROUGE metrics
835
- rouge_desc = ""
836
  if data_args.predict_with_generate:
837
- rouge_metrics = compute_metrics(pred_generations, pred_labels)
838
- pred_metrics.update(rouge_metrics)
839
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
840
 
841
  # Print metrics
842
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
843
  logger.info(desc)
844
 
845
  # save checkpoint after each epoch and push checkpoint to the hub
 
832
  pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
833
 
834
  # compute ROUGE metrics
835
+ mix_desc = ""
836
  if data_args.predict_with_generate:
837
+ mix_metrics = compute_metrics(pred_generations, pred_labels)
838
+ pred_metrics.update(mix_metrics)
839
+ mix_desc = " ".join([f"Predict {key}: {value} |" for key, value in mix_metrics.items()])
840
 
841
  # Print metrics
842
+ desc = f"Predict Loss: {pred_metrics['loss']} | {mix_desc})"
843
  logger.info(desc)
844
 
845
  # save checkpoint after each epoch and push checkpoint to the hub