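"""Entry point for running the question-answering pipeline end to end.

Loads the GroNLP/ik-nlp-22_slp paragraphs and questions, retrieves candidate
contexts for each question (FAISS by default, Elasticsearch optional),
extracts answer spans with the DPR reader, prints the answers, and reports
per-experiment timing.
"""
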
import random
from typing import Dict, cast

import torch
import transformers
from datasets import DatasetDict, load_dataset
from dotenv import load_dotenv
from query import print_answers

from src.evaluation import evaluate
from src.readers.dpr_reader import DprReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.log import get_logger
from src.utils.preprocessing import context_to_reader_input
from src.utils.timing import get_times, timeit

logger = get_logger()

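# Pull environment variables from a local .env file and silence transformers'
# info/warning output so only errors are shown.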
load_dotenv()
transformers.logging.set_verbosity_error()

if __name__ == '__main__':
    dataset_name = "GroNLP/ik-nlp-22_slp"
    paragraphs = cast(DatasetDict, load_dataset(
        "GroNLP/ik-nlp-22_slp", "paragraphs"))
    questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

    # Only run on the first few questions, for speed
    n_questions = 3
    questions_test = questions["test"][:n_questions]

    experiments: Dict[str, Retriever] = {
        "faiss": FaissRetriever(paragraphs),
        # "es": ESRetriever(paragraphs),
    }

    for experiment_name, retriever in experiments.items():
        reader = DprReader()

        for idx in range(n_questions):
            question = questions_test["question"][idx]
            answer = questions_test["answer"][idx]

            scores, context = retriever.retrieve(question, 5)
            reader_input = context_to_reader_input(context)

            # Workaround so the timing decorator gets a dynamic, per-experiment label
            time_wrapper = timeit(f"{experiment_name}.read")
            answers = time_wrapper(reader.read)(question, reader_input, 5)
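            # The call above is equivalent to decorating reader.read with
            # @timeit(f"{experiment_name}.read"); it is applied by hand here
            # because the label is only known inside this loop.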

            # Softmax the document and span scores for readable output
            # (not yet used further below)
            sm = torch.nn.Softmax(dim=0)
            document_scores = sm(torch.Tensor(
                [pred.relevance_score for pred in answers]))
            span_scores = sm(torch.Tensor(
                [pred.span_score for pred in answers]))

            print_answers(answers, scores, context)

            # TODO evaluation and storing of results
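            # One possible shape for that step, sketched from the disabled code
            # further down: take the answer with the highest combined document
            # and span probability and score it against the gold answer, e.g.
            #
            #     best = max(range(len(answers)),
            #                key=lambda i: document_scores[i] + span_scores[i])
            #     exact_match, f1_score = evaluate(answer, answers[best].text)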

    times = get_times()
    print(times)
    # TODO evaluation and storing of results

    # # Initialize retriever
    # retriever = FaissRetriever(paragraphs)
    # # retriever = ESRetriever(paragraphs)

    # # Retrieve example
    # # random.seed(111)
    # random_index = random.randint(0, len(questions_test["question"])-1)
    # example_q = questions_test["question"][random_index]
    # example_a = questions_test["answer"][random_index]

    # scores, result = retriever.retrieve(example_q)
    # reader_input = context_to_reader_input(result)

    # # TODO: use new code from query.py to clean this up
    # # Initialize reader
    # answers = reader.read(example_q, reader_input)

    # # Calculate softmaxed scores for readable output
    # sm = torch.nn.Softmax(dim=0)
    # document_scores = sm(torch.Tensor(
    #     [pred.relevance_score for pred in answers]))
    # span_scores = sm(torch.Tensor(
    #     [pred.span_score for pred in answers]))

    # print(example_q)
    # for answer_i, answer in enumerate(answers):
    #     print(f"[{answer_i + 1}]: {answer.text}")
    #     print(f"\tDocument {answer.doc_id}", end='')
    #     print(f"\t(score {document_scores[answer_i] * 100:.02f})")
    #     print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
    #     print(f"\t(score {span_scores[answer_i] * 100:.02f})")
    #     print()  # Newline

    # # print(f"Example q: {example_q} answer: {result['text'][0]}")

    # # for i, score in enumerate(scores):
    # #     print(f"Result {i+1} (score: {score:.02f}):")
    # #     print(result['text'][i])

    # # Determine best answer we want to evaluate
    # highest, highest_index = 0, 0
    # for i, value in enumerate(span_scores):
    #     if value + document_scores[i] > highest:
    #         highest = value + document_scores[i]
    #         highest_index = i

    # # Retrieve exact match and F1-score
    # exact_match, f1_score = evaluate(
    #     example_a, answers[highest_index].text)
    # print(f"Gold answer: {example_a}\n"
    #       f"Predicted answer: {answers[highest_index].text}\n"
    #       f"Exact match: {exact_match:.02f}\n"
    #       f"F1-score: {f1_score:.02f}")

    # Calculate overall performance
    # total_f1 = 0
    # total_exact = 0
    # total_len = len(questions_test["question"])
    # start_time = time.time()
    # for i, question in enumerate(questions_test["question"]):
    #     print(question)
    #     answer = questions_test["answer"][i]
    #     print(answer)
    #
    #     scores, result = retriever.retrieve(question)
    #     reader_input = context_to_reader_input(result)
    #     answers = reader.read(question, reader_input)
    #
    #     document_scores = sm(torch.Tensor(
    #         [pred.relevance_score for pred in answers]))
    #     span_scores = sm(torch.Tensor(
    #         [pred.span_score for pred in answers]))
    #
    #     highest, highest_index = 0, 0
    #     for j, value in enumerate(span_scores):
    #         if value + document_scores[j] > highest:
    #             highest = value + document_scores[j]
    #             highest_index = j
    #     print(answers[highest_index])
    #     exact_match, f1_score = evaluate(answer, answers[highest_index].text)
    #     total_f1 += f1_score
    #     total_exact += exact_match
    # print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
    # print(total_f1)
    # print(total_exact)
    # print(total_f1/total_len)