""" Official evaluation script for CUAD dataset. """
import argparse
import json
import re
import string
import sys
import numpy as np
IOU_THRESH = 0.5
def get_jaccard(prediction, ground_truth):
remove_tokens = [".", ",", ";", ":"]
for token in remove_tokens:
ground_truth = ground_truth.replace(token, "")
prediction = prediction.replace(token, "")
ground_truth, prediction = ground_truth.lower(), prediction.lower()
ground_truth, prediction = ground_truth.replace("/", " "), prediction.replace("/", " ")
ground_truth, prediction = set(ground_truth.split(" ")), set(prediction.split(" "))
intersection = ground_truth.intersection(prediction)
union = ground_truth.union(prediction)
jaccard = len(intersection) / len(union)
return jaccard
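
# Illustrative example (not part of the original script): token-level Jaccard
# after stripping ".,;:" and lowercasing.
#   get_jaccard("quick brown fox", "the quick brown fox")
#   -> |{quick, brown, fox}| / |{the, quick, brown, fox}| = 3 / 4 = 0.75
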
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
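
# Illustrative example (not part of the original script): the normalization used for exact match.
#   normalize_answer("The Term of this Agreement.")
#   -> "term of this agreement"   (lowercased, punctuation and articles removed, whitespace collapsed)
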
def compute_precision_recall(predictions, ground_truths, qa_id):
tp, fp, fn = 0, 0, 0
substr_ok = "Parties" in qa_id
# first check if ground truth is empty
if len(ground_truths) == 0:
if len(predictions) > 0:
fp += len(predictions) # false positive for each one
else:
for ground_truth in ground_truths:
assert len(ground_truth) > 0
# check if there is a match
match_found = False
for pred in predictions:
if substr_ok:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
else:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
if is_match:
match_found = True
if match_found:
tp += 1
else:
fn += 1
        # now also count any false positives by looping over the predictions
for pred in predictions:
            # Check whether this prediction matches any ground truth. If it does, skip it here:
            # true positives were already counted in the loop above, so counting them again here
            # would double count multiple predictions matched to the same answer.
            # If it matches no ground truth, it is a false positive.
match_found = False
for ground_truth in ground_truths:
assert len(ground_truth) > 0
if substr_ok:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred
else:
is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH
if is_match:
match_found = True
if not match_found:
fp += 1
precision = tp / (tp + fp) if tp + fp > 0 else np.nan
recall = tp / (tp + fn) if tp + fn > 0 else np.nan
return precision, recall
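
# Illustrative example (not part of the original script); the qa_id below is hypothetical.
# With predictions = ["new york law governs", "all notices must be in writing"] and
# ground_truths = ["governing law: new york"], the first prediction has Jaccard
# 3/5 = 0.6 >= IOU_THRESH (a true positive) and the second matches nothing (a false positive),
# so compute_precision_recall(predictions, ground_truths, "Governing Law")
# returns precision = 0.5, recall = 1.0.
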
def process_precisions(precisions):
"""
Processes precisions to ensure that precision and recall don't both get worse.
Assumes the list precision is sorted in order of recalls
"""
precision_best = precisions[::-1]
for i in range(1, len(precision_best)):
precision_best[i] = max(precision_best[i - 1], precision_best[i])
precisions = precision_best[::-1]
return precisions
def get_aupr(precisions, recalls):
processed_precisions = process_precisions(precisions)
    # Area under the interpolated precision-recall curve via the trapezoidal rule.
    # (np.trapz has been deprecated in newer NumPy releases in favor of np.trapezoid.)
    aupr = np.trapz(processed_precisions, recalls)
if np.isnan(aupr):
return 0
return aupr
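
# Illustrative example (not part of the original script): with recalls = [0.0, 0.5, 1.0] and
# precisions = [1.0, 0.8, 0.6] (already non-increasing, so interpolation changes nothing),
# the trapezoidal area is 0.5 * (1.0 + 0.8) / 2 + 0.5 * (0.8 + 0.6) / 2 = 0.8.
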
def get_prec_at_recall(precisions, recalls, recall_thresh):
"""Assumes recalls are sorted in increasing order"""
processed_precisions = process_precisions(precisions)
prec_at_recall = 0
for prec, recall in zip(processed_precisions, recalls):
if recall >= recall_thresh:
prec_at_recall = prec
break
return prec_at_recall
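
# Illustrative example (not part of the original script): for the same curve as above,
# get_prec_at_recall([1.0, 0.8, 0.6], [0.0, 0.5, 1.0], recall_thresh=0.8) returns 0.6,
# the interpolated precision at the first recall value >= 0.8.
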
def exact_match_score(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)
def metric_max_over_ground_truths(metric_fn, predictions, ground_truths):
score = 0
for pred in predictions:
for ground_truth in ground_truths:
score = metric_fn(pred, ground_truth)
if score == 1: # break the loop when one prediction matches the ground truth
break
if score == 1:
break
return score
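
# Illustrative example (not part of the original script): exact match across all pairs.
#   metric_max_over_ground_truths(exact_match_score, ["New York law."], ["new york law", "Delaware law"])
#   -> True (both sides normalize to "new york law"), which counts as 1 when accumulated
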
def compute_score(dataset, predictions):
f1 = exact_match = total = 0
precisions = []
recalls = []
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
print(message, file=sys.stderr)
continue
ground_truths = list(map(lambda x: x["text"], qa["answers"]))
prediction = predictions[qa["id"]]
precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"])
precisions.append(precision)
recalls.append(recall)
if precision == 0 and recall == 0:
f1 += 0
else:
f1 += 2 * (precision * recall) / (precision + recall)
exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
precisions = [x for _, x in sorted(zip(recalls, precisions))]
recalls.sort()
f1 = 100.0 * f1 / total
exact_match = 100.0 * exact_match / total
aupr = get_aupr(precisions, recalls)
prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9)
prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8)
return {
"exact_match": exact_match,
"f1": f1,
"aupr": aupr,
"prec_at_80_recall": prec_at_80_recall,
"prec_at_90_recall": prec_at_90_recall,
}
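
# Illustrative call (not part of the original script); the id and texts below are hypothetical.
# `dataset` follows the SQuAD-style nesting of the CUAD data, and `predictions` maps each
# question id to a *list* of predicted answer spans:
#   dataset = [{"paragraphs": [{"qas": [{"id": "Demo__Parties",
#                                        "answers": [{"text": "Company A"}]}]}]}]
#   predictions = {"Demo__Parties": ["Company A"]}
#   compute_score(dataset, predictions)
#   -> exact_match = 100.0 and f1 = 100.0 (AUPR is 0.0 here, since the curve has only one point)
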
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluation for CUAD")
parser.add_argument("dataset_file", help="Dataset file")
parser.add_argument("prediction_file", help="Prediction File")
args = parser.parse_args()
with open(args.dataset_file) as dataset_file:
dataset_json = json.load(dataset_file)
dataset = dataset_json["data"]
with open(args.prediction_file) as prediction_file:
predictions = json.load(prediction_file)
print(json.dumps(compute_score(dataset, predictions)))
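
# Illustrative command-line usage (file names are hypothetical):
#   python compute_score.py CUADv1.json predictions.json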