pinnnn committed
Commit
ee800e9
1 Parent(s): 1a7e1ee

add Evaluator class


Former-commit-id: 6ba881680089ff333c7c47bf3a369f8655b1a511

Files changed (1)
  1. evaluation/evaluation.py +57 -0
evaluation/evaluation.py CHANGED
@@ -0,0 +1,57 @@
+import argparse
+import pandas as pd
+from evaluation.alignment import alignment
+from evaluation.scores.multi_scores import multi_scores
+from src.srt_util.srt import SrtScript
+
+class Evaluator:
+    def __init__(self, src_path, pred_path, gt_path, eval_path, conclusion_path):
+        self.src_path = src_path
+        self.pred_path = pred_path
+        self.gt_path = gt_path
+        self.eval_path = eval_path
+        self.conclusion_path = conclusion_path
+
+    def eval(self):
+        # Align the predicted SRT against the ground-truth SRT
+        aligned_srt = alignment(self.pred_path, self.gt_path)
+
+        # Parse the source SRT into a list of source sentences
+        src_sentences = [s.source_text for s in SrtScript.parse_from_srt_file(self.src_path).segments]
+
+        # Score each aligned (prediction, ground truth) pair against its source sentence
+        scorer = multi_scores()
+        result_data = []
+        for (pred_s, gt_s), src_s in zip(aligned_srt, src_sentences):
+            scores_dict = scorer.get(src_s, pred_s, gt_s)
+            scores_dict['Prediction'] = pred_s
+            scores_dict['Ground Truth'] = gt_s
+            result_data.append(scores_dict)
+
+        eval_df = pd.DataFrame(result_data)
+        eval_df.to_csv(self.eval_path, index=False, columns=['Prediction', 'Ground Truth', 'llm', 'bleu', 'comet'])
+
+        # Average each metric over all sentences
+        avg_llm = eval_df['llm'].mean()
+        avg_bleu = eval_df['bleu'].mean()
+        avg_comet = eval_df['comet'].mean()
+
+        conclusion_data = {
+            'Metric': ['Avg LLM', 'Avg BLEU', 'Avg COMET'],
+            'Score': [avg_llm, avg_bleu, avg_comet]
+        }
+        conclusion_df = pd.DataFrame(conclusion_data)
+        conclusion_df.to_csv(self.conclusion_path, index=False)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Evaluate SRT files.')
+    parser.add_argument('-src', default='test/short_src', help='Path to source SRT file')
+    parser.add_argument('-pred', default='test/short_pred', help='Path to predicted SRT file')
+    parser.add_argument('-gt', default='test/short_gt', help='Path to ground truth SRT file')
+    parser.add_argument('-eval', default='eval.csv', help='Path to evaluation output CSV file')
+    parser.add_argument('-conclusion', default='conclusion.csv', help='Path to conclusion CSV file')
+    args = parser.parse_args()
+
+    evaluator = Evaluator(args.src, args.pred, args.gt, args.eval, args.conclusion)
+    evaluator.eval()
+
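For reference, a typical invocation using the argparse defaults shown above (the test/short_* paths are the placeholder defaults from the diff, not files guaranteed to exist) would be:

python -m evaluation.evaluation -src test/short_src -pred test/short_pred -gt test/short_gt -eval eval.csv -conclusion conclusion.csv

Running it as a module from the repository root is assumed here, since the script imports from both the evaluation and src packages; executing the file directly would require the repository root to be on PYTHONPATH.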