|
import sys |
|
from collections import Counter |
|
from Levenshtein import ratio |
|
|
|
|
|
def anls_compute(prediction, ground_truth):
    """Return the token-level F1 overlap between a prediction and its reference(s).

    Note: despite the name, no edit-distance similarity is computed here.
    Both strings are split on whitespace and compared as bags of tokens.

    Args:
        prediction: Predicted answer string.
        ground_truth: Reference answer string, or a list/tuple of reference
            strings. When several references are given, the best score
            over all of them is returned.

    Returns:
        The F1 score as a float in [0, 1]. The result is 0.0 when the two
        sides share no token (including when either side is empty) and
        when an empty list/tuple of references is given.
    """
    if isinstance(ground_truth, (list, tuple)):
        # Multiple references: score against each one and keep the best.
        return max(
            (anls_compute(prediction, reference) for reference in ground_truth),
            default=0.0,
        )

    prediction_tokens = prediction.split()
    ground_truth_tokens = ground_truth.split()

    # Counter intersection keeps the minimum count per token, i.e. the
    # multiset overlap, so repeated tokens are only matched as often as
    # they occur on both sides.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # Also covers empty inputs, so the divisions below are never by zero.
        return 0.0

    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)
|
|
|
|
|
def compute_score(dataset, predictions):
    """Average the per-question answer score over a nested QA dataset.

    Args:
        dataset: Iterable of articles. Each article is indexed by
            "paragraphs", each paragraph by "qas", and each qa by "id" and
            "answers" (an iterable of dicts with a "text" entry).
        predictions: Mapping from question id to the predicted answer string.

    Returns:
        A dict ``{"anls_score": mean}`` where ``mean`` is the sum of the
        per-question scores divided by the number of questions. Questions
        missing from ``predictions`` count as 0 and are reported on stderr.
        An empty dataset yields 0.0.
    """
    anls_score = 0.0
    total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue

                prediction = predictions[qa["id"]]
                ground_truths = [answer["text"] for answer in qa["answers"]]
                # Score against each reference separately and keep the best;
                # a question with no reference answers contributes 0.
                anls_score += max(
                    (
                        anls_compute(prediction=prediction, ground_truth=reference)
                        for reference in ground_truths
                    ),
                    default=0.0,
                )

    if total == 0:
        # Nothing to average over; avoid dividing by zero.
        return {"anls_score": 0.0}
    # NOTE(review): the mean is reported on a 0-1 scale — confirm whether
    # downstream consumers expect a percentage instead.
    return {"anls_score": anls_score / total}
|
|