docred / sample.py
bowdbeg's picture
debugged
46417fa
raw
history blame
722 Bytes
import datasets
import evaluate
from docred import docred
# Sample run of the local `docred` metric: predictions whose label fields have
# all been emptied are scored against unmodified validation examples.
train_data = datasets.load_dataset("docred", split="train_annotated[:100]").to_list()
pred_data = datasets.load_dataset("docred", split="validation[:10]").to_list()
gold_data = datasets.load_dataset("docred", split="validation[:10]").to_list()
metric = docred()

# Replace each prediction's "labels" dict with one that keeps the same keys
# but maps every key to an empty list.  Only pred_data's entries are
# reassigned; gold_data was built by its own load/to_list call above.
# To try the opposite case, empty gold_data's labels the same way instead.
for example in pred_data:
    example["labels"] = {key: [] for key in example["labels"]}

print(metric.compute(predictions=pred_data, references=gold_data, train_data=train_data))