fix: eval
Former-commit-id: 0a964e3f26e607fa7a85bbed9bff7ff13c13299a
- evaluation/evaluation.py +25 -26
evaluation/evaluation.py
CHANGED
@@ -1,57 +1,56 @@
 import argparse
 import pandas as pd
-from …
-from …
-from src.srt_util.srt import SrtScript
+from alignment import alignment
+from scores.multi_scores import multi_scores
 
 class Evaluator:
-    def __init__(self, …
-        self.src_path = src_path
+    def __init__(self, pred_path, gt_path, eval_path, res_path):
         self.pred_path = pred_path
         self.gt_path = gt_path
         self.eval_path = eval_path
-        self.…
+        self.res_path = res_path
 
     def eval(self):
         # Align two SRT files
         aligned_srt = alignment(self.pred_path, self.gt_path)
 
-        # Parse src
-        src_s = [s.source_text for s in SrtScript.parse_from_srt_file(self.src_path).segments]
-
         # Get sentence scores
         scorer = multi_scores()
         result_data = []
-        for (…
-        …
-        …
-        …
+        for (pred_s, gt_s) in aligned_srt:
+            print("pred_s.source_text: ", pred_s.source_text)
+            print("pred_s.translation: ", pred_s.translation)
+            print("gt_s.source_text: ", gt_s.translation)
+
+            scores_dict = scorer.get_scores(pred_s.source_text, pred_s.translation, gt_s.translation)
+            scores_dict['Source'] = pred_s.source_text
+            scores_dict['Prediction'] = pred_s.translation
+            scores_dict['Ground Truth'] = gt_s.translation
             result_data.append(scores_dict)
 
         eval_df = pd.DataFrame(result_data)
-        eval_df.to_csv(self.…
+        eval_df.to_csv(self.eval_path, index=False, columns=['Source', 'Prediction', 'Ground Truth', 'bleu_score', 'comet_score', 'llm_score', 'llm_explanation'])
 
         # Get average scores
-        avg_llm = eval_df['…
-        avg_bleu = eval_df['…
-        avg_comet = eval_df['…
+        avg_llm = eval_df['llm_score'].mean()
+        avg_bleu = eval_df['bleu_score'].mean()
+        avg_comet = eval_df['comet_score'].mean()
 
-        …
+        res_data = {
             'Metric': ['Avg LLM', 'Avg BLEU', 'Avg COMET'],
             'Score': [avg_llm, avg_bleu, avg_comet]
         }
-        …
-        …
+        res_df = pd.DataFrame(res_data)
+        res_df.to_csv(self.res_path, index=False)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Evaluate SRT files.')
-    parser.add_argument('-…
-    parser.add_argument('-…
-    parser.add_argument('-…
-    parser.add_argument('-…
-    parser.add_argument('-conclusion', default='conclusion.csv', help='Path to conclusion CSV file')
+    parser.add_argument('-bi_path', default='evaluation/test5_tiny/test5_bi.srt', help='Path to predicted SRT file')
+    parser.add_argument('-zh_path', default='evaluation/test5_tiny/test5_gt.srt', help='Path to ground truth SRT file')
+    parser.add_argument('-eval_output', default='evaluation/test5_tiny/eval.csv', help='Path to eval CSV file')
+    parser.add_argument('-res_output', default='evaluation/test5_tiny/res.csv', help='Path to result CSV file')
     args = parser.parse_args()
 
-    evaluator = Evaluator(args.…
+    evaluator = Evaluator(args.bi_path, args.zh_path, args.eval_output, args.res_output)
     evaluator.eval()
|
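Note for readers of this diff: the reworked eval() leans on two helpers imported at the top of the file, alignment() and multi_scores, which live elsewhere in the repo and are not touched by this commit. The sketch below is only a guess at their contracts, inferred from the call sites above; the Segment stand-in, the type hints, and the placeholder bodies are illustrative assumptions, not the project's actual code. The score keys ('bleu_score', 'comet_score', 'llm_score', 'llm_explanation') are the ones eval() writes to the eval CSV.

# Assumed contracts for the two imports used by evaluation.py; inferred from call sites only.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Segment:
    # Minimal stand-in for an aligned SRT segment as eval() uses it.
    source_text: str   # subtitle text in the source language
    translation: str   # subtitle text in the target language


def alignment(pred_path: str, gt_path: str) -> List[Tuple[Segment, Segment]]:
    # Assumed: parses both SRT files and pairs predicted segments with
    # ground-truth segments; eval() iterates over (pred_s, gt_s) tuples.
    raise NotImplementedError("provided by the repo's alignment module")


class multi_scores:
    # Assumed: bundles BLEU, COMET and LLM-based scoring behind one object.
    def get_scores(self, source: str, prediction: str, reference: str) -> dict:
        # eval() expects at least these keys in the returned dict:
        # 'bleu_score', 'comet_score', 'llm_score', 'llm_explanation'.
        raise NotImplementedError("provided by the repo's scores.multi_scores module")

With the renamed flags, the script would be invoked roughly as
python evaluation/evaluation.py -bi_path pred.srt -zh_path gt.srt -eval_output eval.csv -res_output res.csv
and with no arguments it falls back to the evaluation/test5_tiny/ defaults shown in the diff.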