File size: 722 Bytes
ba888e1 46417fa ba888e1 46417fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import copy

import datasets
import evaluate

from docred import docred
# Smoke test for the `docred` metric: score predictions whose relation labels
# are all empty against the gold annotations of the first 10 validation docs.

# First 100 annotated training documents, forwarded to the metric as
# `train_data`.
# NOTE(review): presumably used by the metric for train/dev overlap
# ("ignore train") scoring -- confirm against docred.compute.
train_data = datasets.load_dataset(
    "docred", split="train_annotated[:100]"
).to_list()

# Gold references. The split is loaded once; predictions are derived from it
# below instead of downloading/converting the same split a second time.
gold_data = datasets.load_dataset("docred", split="validation[:10]").to_list()

# Predictions: independent deep copies of the gold documents in which every
# key of the "labels" dict is kept but mapped to an empty list. Deep-copying
# keeps gold_data untouched even if the metric mutates its predictions.
pred_data = copy.deepcopy(gold_data)
for doc in pred_data:
    doc["labels"] = {key: [] for key in doc["labels"]}

metric = docred()
print(
    metric.compute(
        predictions=pred_data,
        references=gold_data,
        train_data=train_data,
    )
)
|