Spaces:
Sleeping
Sleeping
add Evaluator class
Browse files
Former-commit-id: 6ba881680089ff333c7c47bf3a369f8655b1a511
- evaluation/evaluation.py +57 -0
evaluation/evaluation.py
CHANGED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import pandas as pd
|
3 |
+
from evaluation.alignment import alignment
|
4 |
+
from evaluation.scores.multi_scores import multi_scores
|
5 |
+
from src.srt_util.srt import SrtScript
|
6 |
+
|
7 |
+
class Evaluator:
    """Evaluate predicted SRT subtitles against a ground-truth SRT file.

    Aligns the predicted and ground-truth SRT files sentence-by-sentence,
    scores each aligned pair against the corresponding source sentence
    (``llm`` / ``bleu`` / ``comet`` columns, as produced by ``multi_scores``),
    and writes two CSVs: one with per-sentence scores and one with the
    averaged scores.
    """

    def __init__(self, src_path, pred_path, gt_path, eval_path, conclusion_path):
        # src_path: source-language SRT file
        # pred_path: predicted (translated) SRT file
        # gt_path: ground-truth SRT file
        # eval_path: output CSV path for per-sentence scores
        # conclusion_path: output CSV path for averaged scores
        self.src_path = src_path
        self.pred_path = pred_path
        self.gt_path = gt_path
        self.eval_path = eval_path
        self.conclusion_path = conclusion_path

    def eval(self):
        """Run the evaluation and write the per-sentence and summary CSVs."""
        # Align the two SRT files into (prediction, ground-truth) sentence pairs.
        aligned_srt = alignment(self.pred_path, self.gt_path)

        # Parse the source SRT into a flat list of source sentences.
        src_sentences = [
            s.source_text
            for s in SrtScript.parse_from_srt_file(self.src_path).segments
        ]

        # Score each aligned (source, prediction, ground-truth) triple.
        # NOTE: the original reused the name `src_s` for both the source list
        # and the loop variable; distinct names avoid that shadowing.
        scorer = multi_scores()
        result_data = []
        for (pred_s, gt_s), src_s in zip(aligned_srt, src_sentences):
            scores_dict = scorer.get(src_s, pred_s, gt_s)
            scores_dict['Prediction'] = pred_s
            scores_dict['Ground Truth'] = gt_s
            result_data.append(scores_dict)

        eval_df = pd.DataFrame(result_data)
        # BUG FIX: the original wrote to `self.output_path`, which is never set
        # in __init__ (AttributeError at runtime); the intended attribute is
        # `self.eval_path`.
        eval_df.to_csv(
            self.eval_path,
            index=False,
            columns=['Prediction', 'Ground Truth', 'llm', 'bleu', 'comet'],
        )

        # Summary CSV: average of each metric over all sentences.
        conclusion_df = pd.DataFrame({
            'Metric': ['Avg LLM', 'Avg BLEU', 'Avg COMET'],
            'Score': [
                eval_df['llm'].mean(),
                eval_df['bleu'].mean(),
                eval_df['comet'].mean(),
            ],
        })
        conclusion_df.to_csv(self.conclusion_path, index=False)
46 |
+
if __name__ == "__main__":
|
47 |
+
parser = argparse.ArgumentParser(description='Evaluate SRT files.')
|
48 |
+
parser.add_argument('-src', default='test/short_src', help='Path to source SRT file')
|
49 |
+
parser.add_argument('-pred', default='test/short_pred', help='Path to predicted SRT file')
|
50 |
+
parser.add_argument('-gt', default='test/short_gt', help='Path to ground truth SRT file')
|
51 |
+
parser.add_argument('-eval', default='eval.csv', help='Path to output CSV file')
|
52 |
+
parser.add_argument('-conclusion', default='conclusion.csv', help='Path to conclusion CSV file')
|
53 |
+
args = parser.parse_args()
|
54 |
+
|
55 |
+
evaluator = Evaluator(args.src, args.pred, args.gt, args.eval, args.conclusion)
|
56 |
+
evaluator.eval()
|
57 |
+
|