|
import json |
|
from tldr_eval import TLDREval |
|
from collections import Counter |
|
|
|
# Path of the JSON-lines results file read by this script.
RESULTS_PATH = 'test_results_test_same.json'

# Question ids that are dropped before splitting.
# NOTE(review): the reason for excluding these ids is not recorded in this
# script -- confirm with the author before changing the set.
EXCLUDED_QUESTION_IDS = frozenset(
    {'9931', '7895', '3740', '8077', '4737', '7057', '9530'}
)

# Number of occurrence buckets allocated up front.
NUM_SPLITS = 10

# Bucket sizes accepted by the sanity check.
# NOTE(review): presumably tied to the size of this particular test set --
# verify when pointing the script at a different results file.
EXPECTED_SPLIT_SIZES = frozenset({0, 918, 1845})


def parse_jsonl(lines):
    """Decode an iterable of JSON-lines text into a list of objects.

    Blank / whitespace-only lines are skipped so that a stray empty line
    (e.g. at the end of the file) does not abort the run.
    """
    return [json.loads(line) for line in lines if line.strip()]


def load_jsonl(path):
    """Read the JSON-lines file at *path* and return the decoded records."""
    # Explicit encoding: do not depend on the platform's locale default.
    with open(path, 'r', encoding='utf-8') as f:
        return parse_jsonl(f)


def split_by_occurrence(items, excluded=EXCLUDED_QUESTION_IDS,
                        num_splits=NUM_SPLITS):
    """Bucket *items* by how many times their ``question_id`` was seen before.

    The first record seen for a given ``question_id`` goes to bucket 0, the
    second record with that id to bucket 1, and so on.  Records whose
    ``question_id`` is in *excluded* are dropped.

    Returns a list of *num_splits* lists (some may be empty).

    Raises:
        ValueError: if a ``question_id`` occurs more than *num_splits* times.
    """
    split_data = [[] for _ in range(num_splits)]
    qid_counter = Counter()
    for item in items:
        qid = item['question_id']
        if qid in excluded:
            continue
        split_idx = qid_counter[qid]
        if split_idx >= num_splits:
            # Previously surfaced as a bare IndexError with no context.
            raise ValueError(
                'question_id %r occurs more than %d times'
                % (qid, num_splits)
            )
        split_data[split_idx].append(item)
        qid_counter[qid] += 1
    return split_data


def split_sizes_ok(split_data, expected=EXPECTED_SPLIT_SIZES):
    """Return True when every bucket length is one of the *expected* sizes."""
    return all(len(bucket) in expected for bucket in split_data)


def collect_pairs(items):
    """Return ``(refs, preds)`` built from *items* with newlines removed.

    ``refs`` comes from each record's ``'gold'`` field and ``preds`` from its
    ``'clean_code'`` field; every ``'\\n'`` is stripped from both.
    """
    refs = [item['gold'].replace('\n', '') for item in items]
    preds = [item['clean_code'].replace('\n', '') for item in items]
    return refs, preds


def main(path=RESULTS_PATH):
    """Load the results file, split it, and print metrics for bucket 0."""
    data = load_jsonl(path)
    split_data = split_by_occurrence(data)
    sizes = [len(bucket) for bucket in split_data]

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not split_sizes_ok(split_data):
        raise ValueError('unexpected split sizes: %r' % (sizes,))
    print(sizes)

    # Only the first bucket is evaluated.
    refs, preds = collect_pairs(split_data[0])

    evaluator = TLDREval()
    # NOTE(review): ``_compute`` is a private method of TLDREval; kept as in
    # the original script, argument order (preds, refs) unchanged.
    metrics = evaluator._compute(preds, refs)
    print(metrics)


if __name__ == '__main__':
    main()
|
|