pinnnn committed
Commit
325bf75
1 Parent(s): c447410

Former-commit-id: 0a964e3f26e607fa7a85bbed9bff7ff13c13299a

Files changed (1)
  1. evaluation/evaluation.py +25 -26
evaluation/evaluation.py CHANGED
@@ -1,57 +1,56 @@
 import argparse
 import pandas as pd
-from evaluation.alignment import alignment
-from evaluation.scores.multi_scores import multi_scores
-from src.srt_util.srt import SrtScript
+from alignment import alignment
+from scores.multi_scores import multi_scores
 
 class Evaluator:
-    def __init__(self, src_path, pred_path, gt_path, eval_path, conclusion_path):
-        self.src_path = src_path
+    def __init__(self, pred_path, gt_path, eval_path, res_path):
         self.pred_path = pred_path
         self.gt_path = gt_path
         self.eval_path = eval_path
-        self.conclusion_path = conclusion_path
+        self.res_path = res_path
 
     def eval(self):
         # Align two SRT files
         aligned_srt = alignment(self.pred_path, self.gt_path)
 
-        # Parse src
-        src_s = [s.source_text for s in SrtScript.parse_from_srt_file(self.src_path).segments]
-
         # Get sentence scores
         scorer = multi_scores()
         result_data = []
-        for ((prd_s, gt_s), src_s) in zip(aligned_srt, src_s):
-            scores_dict = scorer.get(src_s, prd_s, gt_s)
-            scores_dict['Prediction'] = prd_s
-            scores_dict['Ground Truth'] = gt_s
+        for (pred_s, gt_s) in aligned_srt:
+            print("pred_s.source_text: ", pred_s.source_text)
+            print("pred_s.translation: ", pred_s.translation)
+            print("gt_s.source_text: ", gt_s.translation)
+
+            scores_dict = scorer.get_scores(pred_s.source_text, pred_s.translation, gt_s.translation)
+            scores_dict['Source'] = pred_s.source_text
+            scores_dict['Prediction'] = pred_s.translation
+            scores_dict['Ground Truth'] = gt_s.translation
             result_data.append(scores_dict)
 
         eval_df = pd.DataFrame(result_data)
-        eval_df.to_csv(self.output_path, index=False, columns=['Prediction', 'Ground Truth', 'llm', 'bleu', 'comet'])
+        eval_df.to_csv(self.eval_path, index=False, columns=['Source', 'Prediction', 'Ground Truth', 'bleu_score', 'comet_score', 'llm_score', 'llm_explanation'])
 
         # Get average scores
-        avg_llm = eval_df['llm'].mean()
-        avg_bleu = eval_df['bleu'].mean()
-        avg_comet = eval_df['comet'].mean()
+        avg_llm = eval_df['llm_score'].mean()
+        avg_bleu = eval_df['bleu_score'].mean()
+        avg_comet = eval_df['comet_score'].mean()
 
-        conclusion_data = {
+        res_data = {
             'Metric': ['Avg LLM', 'Avg BLEU', 'Avg COMET'],
             'Score': [avg_llm, avg_bleu, avg_comet]
         }
-        conclusion_df = pd.DataFrame(conclusion_data)
-        conclusion_df.to_csv(self.conclusion_path, index=False)
+        res_df = pd.DataFrame(res_data)
+        res_df.to_csv(self.res_path, index=False)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Evaluate SRT files.')
-    parser.add_argument('-src', default='test/short_src', help='Path to source SRT file')
-    parser.add_argument('-pred', default='test/short_pred', help='Path to predicted SRT file')
-    parser.add_argument('-gt', default='test/short_gt', help='Path to ground truth SRT file')
-    parser.add_argument('-eval', default='eval.csv', help='Path to output CSV file')
-    parser.add_argument('-conclusion', default='conclusion.csv', help='Path to conclusion CSV file')
+    parser.add_argument('-bi_path', default='evaluation/test5_tiny/test5_bi.srt', help='Path to predicted SRT file')
+    parser.add_argument('-zh_path', default='evaluation/test5_tiny/test5_gt.srt', help='Path to ground truth SRT file')
+    parser.add_argument('-eval_output', default='evaluation/test5_tiny/eval.csv', help='Path to eval CSV file')
+    parser.add_argument('-res_output', default='evaluation/test5_tiny/res.csv', help='Path to result CSV file')
     args = parser.parse_args()
 
-    evaluator = Evaluator(args.src, args.pred, args.gt, args.eval, args.conclusion)
+    evaluator = Evaluator(args.bi_path, args.zh_path, args.eval_output, args.res_output)
     evaluator.eval()
 
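
The rewritten eval() relies on two interfaces that are not shown in this diff: alignment(pred_path, gt_path) is expected to yield (pred_s, gt_s) pairs of segment objects exposing .source_text and .translation, and multi_scores.get_scores() is expected to return a dict whose keys match the columns selected in eval_df.to_csv(). A minimal sketch of that assumed contract (illustrative only; the actual alignment and multi_scores code in this repository may differ):

# Assumed contract only -- names and keys are inferred from the diff above,
# not copied from the real alignment / multi_scores implementations.
class Segment:
    def __init__(self, source_text: str, translation: str):
        self.source_text = source_text   # original-language subtitle text
        self.translation = translation   # translated subtitle text

def alignment(pred_path: str, gt_path: str):
    """Return a list of (pred_segment, gt_segment) pairs, one per aligned subtitle."""
    ...

class multi_scores:
    def get_scores(self, src: str, pred: str, gt: str) -> dict:
        # Keys must match the columns written to the eval CSV,
        # otherwise eval_df.to_csv(..., columns=[...]) raises a KeyError.
        return {
            'bleu_score': 0.0,
            'comet_score': 0.0,
            'llm_score': 0.0,
            'llm_explanation': '...',
        }

Note that the bare imports (from alignment import alignment) resolve because Python puts the script's own directory (evaluation/) on sys.path when the file is executed directly, so running python evaluation/evaluation.py from the repository root should work with the new evaluation/test5_tiny defaults.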