File size: 722 Bytes
ba888e1
 
 
 
 
46417fa
 
 
ba888e1
 
46417fa
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import copy

import datasets
import evaluate

from docred import docred

# Sanity check for the DocRED metric: score a prediction set in which every
# document's label lists are emptied (no relations predicted) against the
# untouched gold validation annotations, and print the metric's result.
NUM_TRAIN_DOCS = 100  # size of the train_annotated slice passed as `train_data`
NUM_EVAL_DOCS = 10  # size of the validation slice used for gold and predictions

train_data = datasets.load_dataset(
    "docred", split=f"train_annotated[:{NUM_TRAIN_DOCS}]"
).to_list()
gold_data = datasets.load_dataset(
    "docred", split=f"validation[:{NUM_EVAL_DOCS}]"
).to_list()
metric = docred()

# Deep-copy instead of loading the same split a second time: the predictions
# must be fully independent of gold_data so that emptying their labels (or
# anything metric.compute might do to them) cannot alter the references.
pred_data = copy.deepcopy(gold_data)
for doc in pred_data:
    # Keep the same label keys but give each an empty list, i.e. the
    # "model" predicts no relations for this document.
    doc["labels"] = {key: [] for key in doc["labels"]}

print(
    metric.compute(
        predictions=pred_data, references=gold_data, train_data=train_data
    )
)