# Source snapshot: revision 8ab7e68 (884 bytes).
import json
from tldr_eval import TLDREval
from collections import Counter
# Evaluate model predictions stored as JSON Lines against gold references
# using TLDREval, on the first "repetition split" of the results.

# JSON Lines file with one prediction record per line.
RESULTS_PATH = 'test_results_test_same.json'

# Number of repetition buckets; a question_id may appear at most this many times.
NUM_SPLITS = 10

# Question IDs dropped from evaluation.
# NOTE(review): the reason these IDs are excluded is not recorded here — confirm and document.
# NOTE(review): IDs are compared as strings; confirm 'question_id' is a string in the data.
EXCLUDED_QIDS = frozenset({'9931', '7895', '3740', '8077', '4737', '7057', '9530'})

# Every bucket produced by split_by_repetition must have one of these sizes.
# NOTE(review): the meaning of 918 vs 1845 is not recorded here — confirm.
EXPECTED_SPLIT_SIZES = (918, 1845, 0)


def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank or whitespace-only lines are skipped rather than failing in json.loads.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def split_by_repetition(records, num_splits=NUM_SPLITS, excluded_qids=EXCLUDED_QIDS):
    """Bucket records by how many times their question_id has been seen so far.

    The k-th record (0-based) for a given question_id goes into bucket k, so
    bucket 0 holds the first record of every question, bucket 1 the second,
    and so on. Records whose question_id is in ``excluded_qids`` are dropped.

    Returns a list of ``num_splits`` lists.
    Raises ValueError if a question_id occurs more than ``num_splits`` times.
    """
    splits = [[] for _ in range(num_splits)]
    seen = Counter()
    for record in records:
        qid = record['question_id']
        if qid in excluded_qids:
            continue
        idx = seen[qid]
        if idx >= num_splits:
            raise ValueError(
                f'question_id {qid!r} occurs more than {num_splits} times'
            )
        splits[idx].append(record)
        seen[qid] += 1
    return splits


def check_split_sizes(splits, expected_sizes=EXPECTED_SPLIT_SIZES):
    """Raise ValueError unless every bucket's length is one of ``expected_sizes``.

    An explicit raise is used instead of ``assert`` so the check is not
    stripped when Python runs with -O.
    """
    sizes = [len(split) for split in splits]
    if any(size not in expected_sizes for size in sizes):
        raise ValueError(
            f'unexpected split sizes {sizes}; each must be one of {list(expected_sizes)}'
        )


def build_eval_pairs(records):
    """Return ``(preds, refs)`` from records' 'clean_code' and 'gold' fields.

    Newlines are removed from both strings.
    NOTE(review): newlines are deleted without inserting a space, which fuses
    tokens across lines (e.g. 'a\\nb' -> 'ab'); kept as-is so metrics stay
    comparable with earlier runs — confirm this is intended.
    """
    refs = [record['gold'].replace('\n', "") for record in records]
    preds = [record['clean_code'].replace('\n', "") for record in records]
    return preds, refs


def main():
    """Load results, split by repetition, sanity-check sizes, and evaluate split 0."""
    data = load_jsonl(RESULTS_PATH)
    split_data = split_by_repetition(data)
    # Print sizes before checking them so they are visible if the check fails.
    print([len(x) for x in split_data])
    check_split_sizes(split_data)
    preds, refs = build_eval_pairs(split_data[0])
    evaluator = TLDREval()
    # NOTE(review): _compute is a private TLDREval method; a public entry point may exist — confirm.
    metrics = evaluator._compute(preds, refs)
    print(metrics)


if __name__ == '__main__':
    main()