# File: tldr_eval/run.py
# Web-page header captured with this file: avatar "shuyanzh", commit message "add metric", revision 8ab7e68.
import json
from tldr_eval import TLDREval
from collections import Counter
# Load the evaluation records: the results file holds one JSON object per line.
# UTF-8 is requested explicitly so that decoding does not depend on the
# platform's default locale encoding.
with open('test_results_test_same.json', 'r', encoding='utf-8') as f:
    # Whitespace-only lines (e.g. a stray blank line at the end of the file)
    # carry no record and would make json.loads raise, so they are skipped.
    data = [json.loads(line) for line in f if line.strip()]
# Question ids whose records are left out of the evaluation entirely.
# NOTE(review): the reason these ids are excluded is not recorded in this file.
EXCLUDED_QUESTION_IDS = frozenset(
    ['9931', '7895', '3740', '8077', '4737', '7057', '9530']
)
# Bucket sizes accepted by the sanity check below (0 means an unused bucket).
EXPECTED_SPLIT_SIZES = frozenset([918, 1845, 0])

# Distribute the records over 10 buckets by occurrence index: the k-th record
# seen for a given question id (counting from 0) goes to split_data[k], so
# split_data[0] holds the first record of every kept question id.
split_data = [[] for _ in range(10)]
qid_counter = Counter()
for item in data:
    qid = item['question_id']
    if qid in EXCLUDED_QUESTION_IDS:
        continue
    # Only 10 buckets exist: an 11th record for the same id raises IndexError.
    split_idx = qid_counter[qid]
    split_data[split_idx].append(item)
    qid_counter[qid] += 1

# Validate the bucket sizes with an explicit exception rather than `assert`,
# so the check still runs under `python -O` and reports the offending sizes.
split_sizes = [len(x) for x in split_data]
if not all(size in EXPECTED_SPLIT_SIZES for size in split_sizes):
    raise ValueError(
        'unexpected split sizes {}; every split must hold 0, 918 or 1845 '
        'records'.format(split_sizes)
    )
print(split_sizes)
# Only the first bucket is scored. For every record in it, the reference comes
# from the 'gold' field and the prediction from the 'clean_code' field, with
# all newline characters removed (deleted, not replaced by a space).
first_split = split_data[0]
cleaned_pairs = [
    (entry['gold'].replace('\n', ""), entry['clean_code'].replace('\n', ""))
    for entry in first_split
]
refs = [gold for gold, _ in cleaned_pairs]
preds = [code for _, code in cleaned_pairs]

# Predictions are passed first, references second.
# NOTE(review): `_compute` is underscore-prefixed, i.e. non-public by
# convention; confirm whether TLDREval offers a public entry point that should
# be used instead.
evaluator = TLDREval()
metrics = evaluator._compute(preds, refs)
print(metrics)