File size: 2,254 Bytes
3060e5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import datasets

from llm.qa_agent import QnAAgent

validation_dataset = datasets.load_dataset(
    "trivia_qa", "rc", split="test"
)  # remove [:5%] to run on full validation set

PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))

qna_agent = QnAAgent()


def get_sub_answers(answers, begin=0, end=None):
    return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]


def expand_to_aliases(given_answers, make_sub_answers=False):
    if make_sub_answers:
        # if answers are longer than one word, make sure a predictions is correct if it coresponds to the complete 1: or :-1 sub word
        # *e.g.* if the correct answer contains a prefix such as "the", or "a"
        given_answers = (
            given_answers
            + get_sub_answers(given_answers, begin=1)
            + get_sub_answers(given_answers, end=-1)
        )
    answers = []
    for answer in given_answers:
        alias = answer.replace("_", " ").lower()
        alias = "".join(
            c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias
        )
        answers.append(" ".join(alias.split()).strip())
    return set(answers)


def evaluate(example):
    # get answer from QnA agent
    answer_without_context = qna_agent.get_answer(example["question"], use_context=False)
    answer_with_context = qna_agent.get_answer(example["question"], use_context=True)

    example["output"] = answer_without_context
    example["output_context"] = answer_with_context

    example["targets"] = example["answer"]["aliases"]
    answers = expand_to_aliases(example["targets"], make_sub_answers=True)

    predictions = expand_to_aliases([example["output"]])
    preditions_with_context = expand_to_aliases([example["output_context"]])

    # if there is a common element, it's a match
    example["match"] = len(list(answers & predictions)) > 0
    example["match_context"] = len(list(answers & preditions_with_context)) > 0

    return example


results = validation_dataset.map(evaluate)

print("Exact Match (EM) without context: {:.2f}".format(100 * sum(results['match'])/len(results)))
print("Exact Match (EM) with context: {:.2f}".format(100 * sum(results['match_context'])/len(results)))