anls / compute_score.py
Viona's picture
adding main description; skeleton of code
6a4fac9
raw
history blame
1.51 kB
import sys
from collections import Counter
from Levenshtein import ratio
def anls_compute(prediction, ground_truth):
    """Return the token-overlap F1 between a prediction and one reference.

    Both strings are split on whitespace and compared as bags of tokens:
    a token counts as shared as many times as it occurs in both strings.

    NOTE(review): despite the name, this does not compute a
    Levenshtein-based similarity; it is a token-level F1. Confirm whether
    a true ANLS implementation is meant to replace it.

    Args:
        prediction: Predicted answer string.
        ground_truth: Reference answer string.

    Returns:
        The integer 0 when the two strings share no token, otherwise the
        harmonic mean of token precision and recall as a float.
    """
    predicted = prediction.split()
    reference = ground_truth.split()

    # Multiset intersection keeps the minimum count of every shared token.
    overlap = sum((Counter(predicted) & Counter(reference)).values())
    if not overlap:
        return 0

    precision = overlap / len(predicted)
    recall = overlap / len(reference)
    return (2 * precision * recall) / (precision + recall)
def compute_score(dataset, predictions):
    """Average the best per-question score over every question in *dataset*.

    For each question, the prediction is scored against every reference
    answer with ``anls_compute`` and the highest score is kept. Questions
    with no prediction are reported on stderr and contribute a score of 0,
    but still count towards the average.

    Args:
        dataset: Iterable of articles. Each article has a ``"paragraphs"``
            list, each paragraph a ``"qas"`` list, and each qa an ``"id"``
            and an ``"answers"`` list whose items carry a ``"text"`` field.
        predictions: Mapping from question id to the predicted answer
            string.

    Returns:
        ``{"anls_score": value}`` where ``value`` is the mean of the
        per-question best scores, or 0.0 when the dataset has no questions.
    """
    anls_score = total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                ground_truths = [answer["text"] for answer in qa["answers"]]
                prediction = predictions[qa["id"]]
                # Score each reference on its own (the scorer takes a single
                # string) and keep the best; a question without reference
                # answers scores 0.
                anls_score += max(
                    (
                        anls_compute(prediction=prediction, ground_truth=ground_truth)
                        for ground_truth in ground_truths
                    ),
                    default=0,
                )

    # Normalise by the number of questions seen; guard the empty dataset.
    anls_score = anls_score / total if total else 0.0
    return {"anls_score": anls_score}