""" Official evaluation script for CUAD dataset. """ import argparse import json import re import string import sys import numpy as np IOU_THRESH = 0.5 def get_jaccard(prediction, ground_truth): remove_tokens = [".", ",", ";", ":"] for token in remove_tokens: ground_truth = ground_truth.replace(token, "") prediction = prediction.replace(token, "") ground_truth, prediction = ground_truth.lower(), prediction.lower() ground_truth, prediction = ground_truth.replace("/", " "), prediction.replace("/", " ") ground_truth, prediction = set(ground_truth.split(" ")), set(prediction.split(" ")) intersection = ground_truth.intersection(prediction) union = ground_truth.union(prediction) jaccard = len(intersection) / len(union) return jaccard def normalize_answer(s): """Lower text and remove punctuation, articles and extra whitespace.""" def remove_articles(text): return re.sub(r"\b(a|an|the)\b", " ", text) def white_space_fix(text): return " ".join(text.split()) def remove_punc(text): exclude = set(string.punctuation) return "".join(ch for ch in text if ch not in exclude) def lower(text): return text.lower() return white_space_fix(remove_articles(remove_punc(lower(s)))) def compute_precision_recall(predictions, ground_truths, qa_id): tp, fp, fn = 0, 0, 0 substr_ok = "Parties" in qa_id # first check if ground truth is empty if len(ground_truths) == 0: if len(predictions) > 0: fp += len(predictions) # false positive for each one else: for ground_truth in ground_truths: assert len(ground_truth) > 0 # check if there is a match match_found = False for pred in predictions: if substr_ok: is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred else: is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH if is_match: match_found = True if match_found: tp += 1 else: fn += 1 # now also get any fps by looping through preds for pred in predictions: # Check if there's a match. if so, don't count (don't want to double count based on the above) # but if there's no match, then this is a false positive. # (Note: we get the true positives in the above loop instead of this loop so that we don't double count # multiple predictions that are matched with the same answer.) match_found = False for ground_truth in ground_truths: assert len(ground_truth) > 0 if substr_ok: is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH or ground_truth in pred else: is_match = get_jaccard(pred, ground_truth) >= IOU_THRESH if is_match: match_found = True if not match_found: fp += 1 precision = tp / (tp + fp) if tp + fp > 0 else np.nan recall = tp / (tp + fn) if tp + fn > 0 else np.nan return precision, recall def process_precisions(precisions): """ Processes precisions to ensure that precision and recall don't both get worse. 
Assumes the list precision is sorted in order of recalls """ precision_best = precisions[::-1] for i in range(1, len(precision_best)): precision_best[i] = max(precision_best[i - 1], precision_best[i]) precisions = precision_best[::-1] return precisions def get_aupr(precisions, recalls): processed_precisions = process_precisions(precisions) aupr = np.trapz(processed_precisions, recalls) if np.isnan(aupr): return 0 return aupr def get_prec_at_recall(precisions, recalls, recall_thresh): """Assumes recalls are sorted in increasing order""" processed_precisions = process_precisions(precisions) prec_at_recall = 0 for prec, recall in zip(processed_precisions, recalls): if recall >= recall_thresh: prec_at_recall = prec break return prec_at_recall def exact_match_score(prediction, ground_truth): return normalize_answer(prediction) == normalize_answer(ground_truth) def metric_max_over_ground_truths(metric_fn, predictions, ground_truths): score = 0 for pred in predictions: for ground_truth in ground_truths: score = metric_fn(pred, ground_truth) if score == 1: # break the loop when one prediction matches the ground truth break if score == 1: break return score def compute_score(dataset, predictions): f1 = exact_match = total = 0 precisions = [] recalls = [] for article in dataset: for paragraph in article["paragraphs"]: for qa in paragraph["qas"]: total += 1 if qa["id"] not in predictions: message = "Unanswered question " + qa["id"] + " will receive score 0." print(message, file=sys.stderr) continue ground_truths = list(map(lambda x: x["text"], qa["answers"])) prediction = predictions[qa["id"]] precision, recall = compute_precision_recall(prediction, ground_truths, qa["id"]) precisions.append(precision) recalls.append(recall) if precision == 0 and recall == 0: f1 += 0 else: f1 += 2 * (precision * recall) / (precision + recall) exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) precisions = [x for _, x in sorted(zip(recalls, precisions))] recalls.sort() f1 = 100.0 * f1 / total exact_match = 100.0 * exact_match / total aupr = get_aupr(precisions, recalls) prec_at_90_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.9) prec_at_80_recall = get_prec_at_recall(precisions, recalls, recall_thresh=0.8) return { "exact_match": exact_match, "f1": f1, "aupr": aupr, "prec_at_80_recall": prec_at_80_recall, "prec_at_90_recall": prec_at_90_recall, } if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluation for CUAD") parser.add_argument("dataset_file", help="Dataset file") parser.add_argument("prediction_file", help="Prediction File") args = parser.parse_args() with open(args.dataset_file) as dataset_file: dataset_json = json.load(dataset_file) dataset = dataset_json["data"] with open(args.prediction_file) as prediction_file: predictions = json.load(prediction_file) print(json.dumps(compute_score(dataset, predictions)))
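
# ---------------------------------------------------------------------------
# Example usage (a minimal sketch; the script and file names below are
# illustrative, not prescribed by this file). The dataset file is expected to
# be SQuAD-style JSON with a top-level "data" key, and the prediction file is
# assumed to map each CUAD question id to a list of predicted answer strings,
# which is how compute_score reads it above:
#
#     python evaluate.py CUADv1.json predictions.json
#
# where predictions.json might look like (question ids are hypothetical):
#
#     {"Contract_A__Parties": ["Acme Corp", "Beta LLC"],
#      "Contract_A__Governing Law": []}
# ---------------------------------------------------------------------------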