klue_mrc / compute_score.py
ingyu's picture
Add KLUE-MRC metric (F1 and EM)
0aed8fc
"""Unofficial evaluation script for KLUE-MRC.
Please note that as KLUE-MRC has the same task format as SQuAD 2.0,
this evaluation script follows almost the same format as the official evaluation script for SQuAD 2.0.
"""
import argparse
import collections
import json
import os
import string
import sys
import numpy as np
# Parsed command-line options. Stays None on import; the __main__ guard assigns
# the result of parse_args() to it, and main() reads its attributes.
OPTS = None
def parse_args():
    """Parse the command line and return the resulting namespace.

    Two positional arguments are required (the data file and the prediction
    file); everything else is optional. When the script is started without any
    arguments at all, the help text is printed and the process exits with
    status 1.
    """
    cli = argparse.ArgumentParser("Unofficial evaluation script for KLUE-MRC.")

    # Required positional inputs.
    cli.add_argument("data_file", metavar="data.json", help="Input data JSON file.")
    cli.add_argument("pred_file", metavar="pred.json", help="Model predictions.")

    # Optional flags, registered in the order they appear in the help output.
    optional_arguments = (
        (
            ("--out-file", "-o"),
            {"metavar": "eval.json", "help": "Write accuracy metrics to file (default is stdout)."},
        ),
        (
            ("--na-prob-file", "-n"),
            {"metavar": "na_prob.json", "help": "Model estimates of probability of no answer."},
        ),
        (
            ("--na-prob-thresh", "-t"),
            {
                "type": float,
                "default": 1.0,
                "help": 'Predict "" if no-answer probability exceeds this (default = 1.0).',
            },
        ),
        (
            ("--out-image-dir", "-p"),
            {"metavar": "out_images", "default": None, "help": "Save precision-recall curves to directory."},
        ),
        (
            ("--verbose", "-v"),
            {"action": "store_true"},
        ),
    )
    for flags, settings in optional_arguments:
        cli.add_argument(*flags, **settings)

    invoked_without_arguments = len(sys.argv) == 1
    if invoked_without_arguments:
        cli.print_help()
        sys.exit(1)
    return cli.parse_args()
def make_qid_to_has_ans(dataset):
    """Return a dict mapping each question id to True when it is answerable.

    ``dataset`` is iterated as articles -> ``"paragraphs"`` -> ``"qas"``; a
    question counts as answerable when its ``"unanswerable"`` field is falsy.
    """
    return {
        qa["id"]: not bool(qa["unanswerable"])
        for article in dataset
        for paragraph in article["paragraphs"]
        for qa in paragraph["qas"]
    }
def normalize_answer(s):
    """Lowercase ``s``, drop ASCII punctuation and collapse whitespace.

    The steps run in that order: the text is lowercased, every character in
    ``string.punctuation`` is removed, and runs of whitespace are squeezed to
    single spaces (leading/trailing whitespace disappears). No words are
    removed.
    """
    lowered = s.lower()
    without_punctuation = lowered.translate(str.maketrans("", "", string.punctuation))
    return " ".join(without_punctuation.split())
def get_tokens(s):
    """Return the whitespace-separated tokens of the normalized answer.

    A falsy ``s`` (empty string or None) yields an empty list without being
    normalized.
    """
    return normalize_answer(s).split() if s else []
def compute_exact(a_gold, a_pred):
    """Return 1 if both answers are equal after normalization, otherwise 0."""
    normalized_gold = normalize_answer(a_gold)
    normalized_pred = normalize_answer(a_pred)
    return 1 if normalized_gold == normalized_pred else 0
def compute_f1(a_gold, a_pred):
    """Return the token-level F1 between a gold answer and a prediction.

    Tokens come from ``get_tokens``. When either side has no tokens the result
    is 1 if both are empty and 0 otherwise; when the two sides share no token
    the result is 0. Otherwise the harmonic mean of token precision and recall
    is returned as a float.
    """
    reference = get_tokens(a_gold)
    candidate = get_tokens(a_pred)

    # A no-answer on either side is scored purely on agreement.
    if not reference or not candidate:
        return int(reference == candidate)

    # Multiset intersection counts each shared token at most as often as it
    # occurs on both sides.
    shared = collections.Counter(reference) & collections.Counter(candidate)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0

    precision = float(overlap) / len(candidate)
    recall = float(overlap) / len(reference)
    return (2 * precision * recall) / (precision + recall)
def get_raw_scores(dataset, preds):
    """Compute per-question exact-match and F1 scores.

    Args:
        dataset: iterable of articles, each with ``"paragraphs"`` whose
            ``"qas"`` entries carry ``"id"``, ``"unanswerable"`` and
            ``"answers"["text"]``.
        preds: mapping from question id to the predicted answer string.

    Returns:
        A pair ``(exact_scores, f1_scores)`` of dicts keyed by question id.
        Questions without a prediction are reported on stdout and left out of
        both dicts.
    """
    exact_scores = {}
    f1_scores = {}
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                qid = qa["id"]
                if qid not in preds:
                    print(f"Missing prediction for {qid}")
                    continue

                if qa["unanswerable"]:
                    # For unanswerable questions, only correct answer is empty string
                    gold_answers = [""]
                else:
                    # Keep only gold texts that are non-empty after normalization.
                    gold_answers = [t for t in qa["answers"]["text"] if normalize_answer(t)]
                    if not gold_answers:
                        # Every gold text normalized to nothing (or none was
                        # given): score against the empty string instead of
                        # letting max() fail on an empty sequence.
                        gold_answers = [""]

                a_pred = preds[qid]
                # Take max over all gold answers
                exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
                f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
    return exact_scores, f1_scores
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    """Return a copy of ``scores`` with the no-answer threshold applied.

    For every question whose no-answer probability is strictly greater than
    ``na_prob_thresh`` the prediction is treated as "no answer": the score
    becomes 1.0 if the question is unanswerable and 0.0 otherwise. All other
    questions keep their original score.
    """
    return {
        qid: float(not qid_to_has_ans[qid]) if na_probs[qid] > na_prob_thresh else raw_score
        for qid, raw_score in scores.items()
    }
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    """Aggregate per-question scores into percentage metrics.

    When ``qid_list`` is empty or None, all questions in ``exact_scores`` are
    used; otherwise only the listed question ids are averaged. The result is
    an ordered mapping with the keys ``"exact"``, ``"f1"`` and ``"total"``.
    """
    if qid_list:
        total = len(qid_list)
        exact_sum = sum(exact_scores[qid] for qid in qid_list)
        f1_sum = sum(f1_scores[qid] for qid in qid_list)
    else:
        total = len(exact_scores)
        exact_sum = sum(exact_scores.values())
        f1_sum = sum(f1_scores.values())

    summary = collections.OrderedDict()
    summary["exact"] = 100.0 * exact_sum / total
    summary["f1"] = 100.0 * f1_sum / total
    summary["total"] = total
    return summary
def merge_eval(main_eval, new_eval, prefix):
    """Copy every entry of ``new_eval`` into ``main_eval`` under ``<prefix>_<key>``.

    ``main_eval`` is modified in place; nothing is returned.
    """
    for metric, value in new_eval.items():
        prefixed_key = f"{prefix}_{metric}"
        main_eval[prefixed_key] = value
def plot_pr_curve(precisions, recalls, out_image, title):
    """Draw a precision-recall step curve and save it to ``out_image``.

    Args:
        precisions: precision values, one per curve point.
        recalls: recall values aligned with ``precisions``.
        out_image: path passed to ``savefig``.
        title: title placed on the figure.
    """
    # Import locally so the function does not depend on a module-level ``plt``
    # that is only bound when the file is executed as a script with an image
    # directory; without this, calling it from an importing module raises
    # NameError.
    import matplotlib.pyplot as plt

    plt.step(recalls, precisions, color="b", alpha=0.2, where="post")
    plt.fill_between(recalls, precisions, step="post", alpha=0.2, color="b")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.title(title)
    plt.savefig(out_image)
    # Clear the current figure so the next plot starts from a blank canvas.
    plt.clf()
def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, out_image=None, title=None):
    """Compute average precision over questions ranked by no-answer probability.

    Questions are visited in increasing order of ``na_probs``. The score of
    each answerable question is accumulated; precision is the accumulated score
    divided by the number of questions seen so far and recall is the
    accumulated score divided by ``num_true_pos``. A curve point is recorded
    only where the next question has a different no-answer probability (or at
    the very end), i.e. where a threshold could actually be placed.

    When ``out_image`` is truthy the curve is also plotted via
    ``plot_pr_curve``. Returns ``{"ap": <average precision in percent>}``.
    """
    ranked = sorted(na_probs, key=na_probs.__getitem__)
    final_index = len(ranked) - 1

    precisions, recalls = [1.0], [0.0]
    credited = 0.0
    average_precision = 0.0

    for position, qid in enumerate(ranked):
        if qid_to_has_ans[qid]:
            credited += scores[qid]
        precision = credited / float(position + 1)
        recall = credited / float(num_true_pos)

        # No threshold can separate this question from the next one when both
        # share the same no-answer probability.
        tied_with_next = position != final_index and na_probs[qid] == na_probs[ranked[position + 1]]
        if tied_with_next:
            continue

        average_precision += precision * (recall - recalls[-1])
        precisions.append(precision)
        recalls.append(recall)

    if out_image:
        plot_pr_curve(precisions, recalls, out_image, title)
    return {"ap": 100.0 * average_precision}
def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, out_image_dir):
    """Add average-precision metrics for exact match, F1 and the oracle to ``main_eval``.

    The output directory is created when it is given and does not exist yet.
    If the dataset has no answerable question the function returns without
    touching ``main_eval``. Otherwise three precision-recall evaluations are
    run (each saving a PNG inside ``out_image_dir``) and their results are
    merged under the prefixes ``pr_exact``, ``pr_f1`` and ``pr_oracle``.
    """
    if out_image_dir and not os.path.exists(out_image_dir):
        os.makedirs(out_image_dir)

    num_true_pos = len([flag for flag in qid_to_has_ans.values() if flag])
    if num_true_pos == 0:
        return

    # The oracle scores every answerable question 1.0 and every other 0.0.
    oracle_scores = {qid: float(flag) for qid, flag in qid_to_has_ans.items()}

    curves = (
        ("pr_exact", exact_raw, "pr_exact.png", "Precision-Recall curve for Exact Match score"),
        ("pr_f1", f1_raw, "pr_f1.png", "Precision-Recall curve for F1 score"),
        (
            "pr_oracle",
            oracle_scores,
            "pr_oracle.png",
            "Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)",
        ),
    )

    # Evaluate every curve first; main_eval is only updated once all succeeded.
    results = [
        (
            prefix,
            make_precision_recall_eval(
                curve_scores,
                na_probs,
                num_true_pos,
                qid_to_has_ans,
                out_image=os.path.join(out_image_dir, file_name),
                title=curve_title,
            ),
        )
        for prefix, curve_scores, file_name, curve_title in curves
    ]
    for prefix, result in results:
        merge_eval(main_eval, result, prefix)
def histogram_na_prob(na_probs, qid_list, image_dir, name):
    """Save a normalized histogram of no-answer probabilities for ``qid_list``.

    Args:
        na_probs: mapping from question id to no-answer probability.
        qid_list: question ids to include; nothing is drawn when it is empty.
        image_dir: directory that receives ``na_prob_hist_<name>.png``.
        name: label used in the plot title and the file name.
    """
    if not qid_list:
        return

    # Import locally so the function does not depend on a module-level ``plt``
    # that is only bound when the file is executed as a script with an image
    # directory; without this, calling it from an importing module raises
    # NameError.
    import matplotlib.pyplot as plt

    x = [na_probs[k] for k in qid_list]
    # Equal weights summing to 1 turn the bar heights into dataset proportions.
    weights = np.ones_like(x) / float(len(x))
    plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
    plt.xlabel("Model probability of no-answer")
    plt.ylabel("Proportion of dataset")
    plt.title(f"Histogram of no-answer probability: {name}")
    plt.savefig(os.path.join(image_dir, f"na_prob_hist_{name}.png"))
    # Clear the current figure so the next plot starts from a blank canvas.
    plt.clf()
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    """Find the no-answer threshold that maximizes the total score.

    The running score starts at the number of unanswerable questions, i.e. the
    score obtained when every question is predicted as "no answer". Questions
    are then visited in increasing order of no-answer probability (ids without
    a score are skipped): an answerable question adds its score, an
    unanswerable one with a non-empty prediction subtracts 1, and an
    unanswerable one with an empty prediction changes nothing.

    Returns a pair ``(best score in percent of len(scores), threshold)``.
    """
    running = sum(1 for has_ans in qid_to_has_ans.values() if not has_ans)
    best_score, best_thresh = running, 0.0

    for qid in sorted(na_probs, key=na_probs.__getitem__):
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            delta = scores[qid]
        elif preds[qid]:
            delta = -1
        else:
            delta = 0
        running += delta
        if running > best_score:
            best_score, best_thresh = running, na_probs[qid]

    return 100.0 * best_score / len(scores), best_thresh
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    """Store the best achievable exact-match and F1 scores and their thresholds.

    ``find_best_thresh`` is run once on the raw exact-match scores and once on
    the raw F1 scores; the four resulting values are written into
    ``main_eval`` in place.
    """
    exact_best, exact_cutoff = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    f1_best, f1_cutoff = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval.update(
        {
            "best_exact": exact_best,
            "best_exact_thresh": exact_cutoff,
            "best_f1": f1_best,
            "best_f1_thresh": f1_cutoff,
        }
    )
def _load_json(path):
    """Return the parsed contents of the JSON file at ``path``."""
    # Decode explicitly as UTF-8 instead of the locale default, so non-ASCII
    # (e.g. Korean) text loads the same way on every platform.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def main():
    """Score the predictions named in ``OPTS`` and emit the metrics as JSON.

    Reads the dataset, the predictions and (optionally) the no-answer
    probabilities, computes exact-match and F1 overall and separately for
    answerable / unanswerable questions, and adds best-threshold and
    precision-recall results when a no-answer probability file was supplied.
    The metrics are written to ``OPTS.out_file`` if set, otherwise printed.
    """
    dataset = _load_json(OPTS.data_file)["data"]
    preds = _load_json(OPTS.pred_file)
    if OPTS.na_prob_file:
        na_probs = _load_json(OPTS.na_prob_file)
    else:
        # Without model estimates, every prediction is treated as confident.
        na_probs = {k: 0.0 for k in preds}

    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]

    exact_raw, f1_raw = get_raw_scores(dataset, preds)
    exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)
    f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, OPTS.na_prob_thresh)

    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    if has_ans_qids:
        has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
        merge_eval(out_eval, has_ans_eval, "HasAns")
    if no_ans_qids:
        no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
        merge_eval(out_eval, no_ans_eval, "NoAns")

    if OPTS.na_prob_file:
        find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
    if OPTS.na_prob_file and OPTS.out_image_dir:
        run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans, OPTS.out_image_dir)
        histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, "hasAns")
        histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, "noAns")

    if OPTS.out_file:
        with open(OPTS.out_file, "w", encoding="utf-8") as f:
            json.dump(out_eval, f)
    else:
        print(json.dumps(out_eval, indent=2))
# Script entry point: parse the command line into the module-level OPTS, then
# run the evaluation.
if __name__ == "__main__":
    OPTS = parse_args()
    if OPTS.out_image_dir:
        # matplotlib is imported only when an image directory was requested.
        # The "Agg" backend is selected before pyplot is imported, and `plt`
        # is bound here as a module-level name.
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    main()