File size: 1,326 Bytes
6a4fac9
 
 
6be1f2a
29d0f05
 
95d4295
29d0f05
 
 
 
95d4295
 
29d0f05
 
 
 
 
6a4fac9
95d4295
29d0f05
 
 
95d4295
 
 
29d0f05
6be1f2a
29d0f05
 
6be1f2a
 
 
29d0f05
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from Levenshtein import distance, ratio


def compute_score(predictions, ground_truths, theta=0.5):
    """Compute the Average Normalized Levenshtein Similarity (ANLS).

    Each question id in ``predictions`` that also appears in
    ``ground_truths`` is scored as follows:

    - The prediction is compared case-insensitively with each accepted answer.
    - For each pair, ``NL`` is the Levenshtein edit distance divided by the
      length of the longer string.
    - A pair scores ``1 - NL`` when ``NL < theta``, and ``0`` otherwise.
    - The question's score is the best pair score over its answers.

    ANLS is the mean of these per-question scores.

    Args:
        predictions: Mapping of question id -> predicted answer string.
        ground_truths: Mapping of question id -> iterable of accepted answer
            strings.
        theta: Normalized-distance threshold; pairs with ``NL >= theta``
            contribute 0.

    Returns:
        The ANLS as a float in ``[0, 1]``, or ``0.0`` when no prediction has
        a matching ground-truth entry.
    """
    anls_total = 0.0
    num_questions = 0
    for qid, prediction in predictions.items():
        if qid not in ground_truths:
            # Predictions without a reference entry cannot be scored; skip.
            continue
        # Count each question once (not once per accepted answer) so the
        # result is an average over questions.
        num_questions += 1
        pred = prediction.lower()
        best = 0.0
        for answer in ground_truths[qid]:
            ans = answer.lower()
            longest = max(len(pred), len(ans))
            # Normalized *distance* (0 = identical). Levenshtein.ratio is a
            # similarity and must not be used here. Two empty strings are
            # identical, so NL = 0.
            nl = distance(pred, ans) / longest if longest else 0.0
            if nl < theta:
                best = max(best, 1.0 - nl)
        anls_total += best

    if num_questions == 0:
        return 0.0
    return anls_total / num_questions


if __name__ == "__main__":
    # Demo: score three sample predictions against their reference answers
    # and print the resulting ANLS.
    sample_predictions = [
        {'question_id': '10285', 'prediction_text': 'Denver R.'},
        {'question_id': '18601', 'prediction_text': '12'},
        {'question_id': '16734', 'prediction_text': 'dear'},
    ]
    references = [
        {"answers": ["Denver Broncos", "Denver R. Broncos"], 'question_id': '10285'},
        {'answers': ['12/15/88'], 'question_id': '18601'},
        {'answers': ['Dear Dr. Lobo', 'Dr. Lobo'], 'question_id': '16734'},
    ]

    # question id -> list of accepted answers
    ground_truths = {}
    for ref in references:
        ground_truths[ref['question_id']] = ref['answers']

    # question id -> predicted answer text
    predictions = {
        item['question_id']: item['prediction_text'] for item in sample_predictions
    }

    print(compute_score(predictions=predictions, ground_truths=ground_truths))