"""Score one bucket of predictions from a JSON Lines results file with TLDREval.

Records are read from ``test_results_test_same.json`` (one JSON object per
line), bucketed by how many times their ``question_id`` has already been seen,
and the bucket holding the first record of every question is scored.
"""
import json
from collections import Counter

from tldr_eval import TLDREval

RESULTS_PATH = 'test_results_test_same.json'

# Number of buckets; bucket k receives the (k+1)-th record seen for a question.
NUM_SPLITS = 10

# Question ids that are left out of the evaluation.
# NOTE(review): the reason for excluding these ids is not recorded here -- confirm.
EXCLUDED_QUESTION_IDS = frozenset(
    {'9931', '7895', '3740', '8077', '4737', '7057', '9530'}
)

# Bucket sizes accepted by the sanity check below.
# NOTE(review): presumably the known sizes of the evaluation sets -- confirm.
EXPECTED_SPLIT_SIZES = frozenset({918, 1845, 0})

# The file is parsed line by line (JSON Lines) despite its ``.json`` suffix.
# Whitespace-only lines are skipped so a trailing blank line does not make
# ``json.loads`` fail.
with open(RESULTS_PATH, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

# Distribute records so that the k-th occurrence of a question id lands in
# ``split_data[k]``; ``qid_counter`` tracks how many records each id has
# contributed so far.
split_data = [[] for _ in range(NUM_SPLITS)]
qid_counter = Counter()
for item in data:
    qid = item['question_id']
    if qid in EXCLUDED_QUESTION_IDS:
        continue
    split_idx = qid_counter[qid]
    if split_idx >= NUM_SPLITS:
        # Same exception type as the bare list lookup would raise, but with
        # enough context to locate the offending record.
        raise IndexError(
            "question_id {!r} occurs more than {} times in {}".format(
                qid, NUM_SPLITS, RESULTS_PATH
            )
        )
    split_data[split_idx].append(item)
    qid_counter[qid] += 1

# Sanity check on the bucket sizes. Raised explicitly rather than via
# ``assert`` so the check still runs under ``python -O``.
split_sizes = [len(split) for split in split_data]
if not all(size in EXPECTED_SPLIT_SIZES for size in split_sizes):
    raise AssertionError(
        "unexpected split sizes {} (allowed: {})".format(
            split_sizes, sorted(EXPECTED_SPLIT_SIZES)
        )
    )
print(split_sizes)

# Only the first bucket is scored; newlines are removed from both the gold
# text and the prediction before they are handed to the evaluator.
refs = [item['gold'].replace('\n', "") for item in split_data[0]]
preds = [item['clean_code'].replace('\n', "") for item in split_data[0]]

evaluator = TLDREval()
# NOTE(review): ``_compute`` is underscore-prefixed (non-public by convention);
# the call is kept as-is -- confirm whether a public entry point is preferred.
metrics = evaluator._compute(preds, refs)
print(metrics)