File size: 1,510 Bytes
6a4fac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys
from collections import Counter
from Levenshtein import ratio


def anls_compute(prediction, ground_truth):
    """Return the token-overlap F1 between a prediction and its reference(s).

    Despite the name, no edit distance is involved: both strings are split on
    whitespace and compared as bags of tokens (multiset intersection).

    Args:
        prediction: Predicted answer string.
        ground_truth: Either a single reference string, or an iterable of
            reference strings. For an iterable, the best score over all
            references is returned (0.0 when the iterable is empty).

    Returns:
        The F1 score as a float in [0, 1]; 0.0 when no tokens are shared.
    """
    if not isinstance(ground_truth, str):
        # Several accepted answers may be supplied at once; a prediction only
        # has to match one of them, so keep the best score.
        return max(
            (anls_compute(prediction, reference) for reference in ground_truth),
            default=0.0,
        )

    prediction_tokens = prediction.split()
    ground_truth_tokens = ground_truth.split()
    # Counter & Counter keeps each token with the smaller of its two counts.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # Also covers empty inputs, so the divisions below never see a zero.
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def compute_score(dataset, predictions):
    """Average the best per-question score over every question in a dataset.

    Args:
        dataset: Iterable of articles. Each article has a "paragraphs" list,
            each paragraph a "qas" list, and each qa an "id" plus an
            "answers" list whose items carry the reference string under
            "text".
        predictions: Mapping from question id to the predicted answer string.

    Returns:
        A dict ``{"anls_score": mean}`` where ``mean`` is the sum of the
        per-question scores divided by the number of questions. A question
        scores the maximum of ``anls_compute`` over its reference answers.
        Unanswered questions count towards the total and contribute 0. The
        mean is 0.0 for a dataset with no questions.
    """
    anls_score = 0.0
    total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                prediction = predictions[qa["id"]]
                # Score against each reference separately (anls_compute
                # works on one string pair) and credit the best match.
                anls_score += max(
                    (
                        anls_compute(prediction=prediction, ground_truth=answer["text"])
                        for answer in qa["answers"]
                    ),
                    default=0.0,
                )

    if total:
        # Guard against an empty dataset, which would otherwise divide by zero.
        anls_score /= total

    return {"anls_score": anls_score}