File size: 4,939 Bytes
1f08ed2
 
 
 
be1f224
1f08ed2
 
 
be1f224
 
1f08ed2
 
 
be1f224
 
 
 
325e3c6
0157dfd
 
 
 
 
 
1f08ed2
 
be1f224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f08ed2
 
 
 
 
 
 
 
e9df5ab
1f08ed2
 
 
 
 
 
 
325e3c6
 
 
 
e9df5ab
325e3c6
1f08ed2
 
 
 
 
 
 
 
 
 
 
 
be1f224
 
 
 
 
325e3c6
 
 
 
 
 
 
 
 
e9df5ab
 
be1f224
325e3c6
 
 
 
 
1f08ed2
 
e9df5ab
 
1f08ed2
 
be1f224
 
325e3c6
be1f224
1f08ed2
 
be1f224
 
 
 
 
 
 
 
1f08ed2
 
 
 
 
 
 
be1f224
 
 
 
 
 
 
 
 
 
 
 
1f08ed2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import torch
import transformers

from typing import Dict, List, Literal, Tuple, cast
from datasets import load_dataset, DatasetDict
from dotenv import load_dotenv

from src.readers.base_reader import Reader
from src.readers.longformer_reader import LongformerReader
from src.readers.dpr_reader import DprReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import (
    FaissRetriever,
    FaissRetrieverOptions
)
from src.utils.preprocessing import context_to_reader_input
from src.utils.log import logger


# Setup environment
# Presumably reads a local .env file into the process environment
# (python-dotenv). Which settings the retrievers/readers actually consume is
# not visible in this file — TODO confirm (e.g. Elasticsearch connection
# settings used by ESRetriever?).
load_dotenv()
# Only error-level log messages from the transformers library are shown, so
# model-loading warnings do not clutter the CLI output below.
transformers.logging.set_verbosity_error()


def get_retriever(paragraphs: DatasetDict,
                  r: Literal["es", "faiss"],
                  lm: Literal["dpr", "longformer"]) -> Retriever:
    """Construct the retriever selected by ``r`` (and, for FAISS, ``lm``).

    Args:
        paragraphs: Paragraph dataset handed to the FAISS retriever; the
            Elasticsearch retriever does not receive it.
        r: ``"es"`` for ``ESRetriever`` or ``"faiss"`` for ``FaissRetriever``.
        lm: For FAISS, picks the DPR or Longformer index options; ignored
            when ``r == "es"``.

    Returns:
        A ready-to-use ``Retriever`` instance.

    Raises:
        ValueError: If the ``(r, lm)`` combination is not supported.
    """
    if r == "es":
        # Elasticsearch works for any language model.
        return ESRetriever()

    if r == "faiss":
        # Index paths are relative to the current working directory.
        if lm == "dpr":
            dpr_options = FaissRetrieverOptions.dpr("./src/models/dpr.faiss")
            return FaissRetriever(paragraphs, dpr_options)
        if lm == "longformer":
            lf_options = FaissRetrieverOptions.longformer(
                "./src/models/longformer.faiss")
            return FaissRetriever(paragraphs, lf_options)

    raise ValueError("Retriever options not recognized")


def get_reader(lm: Literal["dpr", "longformer"]) -> Reader:
    """Construct the reader for the given language model.

    Args:
        lm: ``"dpr"`` for ``DprReader`` or ``"longformer"`` for
            ``LongformerReader``.

    Returns:
        A new ``Reader`` instance.

    Raises:
        ValueError: If ``lm`` is any other value.
    """
    if lm == "dpr":
        return DprReader()
    if lm == "longformer":
        return LongformerReader()
    raise ValueError("Language model not recognized")


def print_name(contexts: dict, section: str, id: int):
    """Print the ``section`` heading of context ``id``.

    Headings stored as the literal string ``'nan'`` are treated as missing
    and produce no output.

    Args:
        contexts: Mapping from heading level (e.g. ``'chapter'``) to a
            sequence of heading titles indexed by context position.
        section: Heading level to look up and to use as the printed label.
        id: Index into ``contexts[section]``.
    """
    title = contexts[section][id]
    if title == 'nan':
        return
    print(f"      {section}: {title}")


def get_retrieval_span_scores(answers: List[tuple]):
    """Turn raw reader scores into probabilities across all answers.

    A softmax over the answer axis is applied separately to each answer's
    ``relevance_score`` and ``span_score`` attribute.

    Args:
        answers: Reader predictions exposing ``relevance_score`` and
            ``span_score`` attributes (despite the ``tuple`` hint).

    Returns:
        ``(document_scores, span_scores)`` as float tensors aligned with
        ``answers`` (one entry per answer, assuming scalar scores).
    """
    relevance = torch.Tensor([a.relevance_score for a in answers])
    spans = torch.Tensor([a.span_score for a in answers])
    return torch.softmax(relevance, dim=0), torch.softmax(spans, dim=0)


def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
    """Pretty-print ranked reader answers with their headings and scores.

    Args:
        answers: Reader predictions exposing ``text``, ``doc_id``,
            ``relevance_score`` and ``span_score``.
        scores: Retrieval scores, indexed by each answer's ``doc_id``.
        contexts: Retrieved contexts holding ``'chapter'``, ``'section'``
            and ``'subsection'`` title sequences, indexed by ``doc_id``.
    """
    doc_scores, span_scores = get_retrieval_span_scores(answers)
    ranked = zip(answers, doc_scores, span_scores)

    for rank, (answer, doc_score, span_score) in enumerate(ranked, start=1):
        print(f"{rank:>4}. {answer.text}")
        print(f"      {'-' * len(answer.text)}")
        for level in ('chapter', 'section', 'subsection'):
            print_name(contexts, level, answer.doc_id)
        # NOTE(review): the retrieval score gets a '%' suffix but, unlike the
        # two softmax scores, is not multiplied by 100 — confirm whether the
        # retriever already returns percentages.
        print(f"      retrieval score: {scores[answer.doc_id]:6.02f}%")
        print(f"      document score:  {doc_score * 100:6.02f}%")
        print(f"      span score:      {span_score * 100:6.02f}%")
        print()


def probe(
    query: str,
    retriever: Retriever,
    reader: Reader,
    num_answers: int = 5,
) -> Tuple[List[tuple], List[float], Dict[str, List[str]]]:
    """Run one question through retrieval followed by reading.

    Args:
        query: The question to answer.
        retriever: Supplies ``(scores, contexts)`` for ``query``.
        reader: Extracts answer spans from the retrieved contexts.
        num_answers: Number of answers requested from the reader.

    Returns:
        ``(answers, scores, contexts)``: the reader's answers plus the
        retriever's scores and contexts, so callers can report provenance.
    """
    scores, contexts = retriever.retrieve(query)
    answers = reader.read(
        query, context_to_reader_input(contexts), num_answers)
    return answers, scores, contexts


def default_probe(query: str, num_answers: int = 5):
    """Answer ``query`` with the default pipeline: FAISS retriever + DPR reader.

    Nothing is printed; the raw results are returned so the caller can
    format them (for example with ``print_answers``).

    Args:
        query: The question to answer.
        num_answers: Number of answers requested from the reader. The
            default of 5 keeps the previous behaviour.

    Returns:
        The ``(answers, scores, contexts)`` tuple produced by ``probe``.
    """
    paragraphs = cast(DatasetDict, load_dataset(
        "GroNLP/ik-nlp-22_slp", "paragraphs"))
    retriever = get_retriever(paragraphs, "faiss", "dpr")
    # Use the same factory as main() so reader construction stays in one
    # place.
    reader = get_reader("dpr")

    return probe(query, retriever, reader, num_answers)


def main(args: argparse.Namespace):
    """Answer one CLI question and print the results.

    Args:
        args: Parsed CLI arguments with ``query``, ``top``, ``retriever``
            and ``lm`` attributes (as defined by this module's parser).
    """
    # Initialize dataset
    paragraphs = cast(DatasetDict, load_dataset(
        "GroNLP/ik-nlp-22_slp", "paragraphs"))

    # Retrieve
    retriever = get_retriever(paragraphs, args.retriever, args.lm)
    reader = get_reader(args.lm)
    answers, scores, contexts = probe(
        args.query, retriever, reader, args.top)

    # Print output
    print("Question: " + args.query)
    print("Answer(s):")
    if args.lm == "dpr":
        print_answers(answers, scores, contexts)
        return

    # Non-DPR readers: answers are indexed like tuples whose first item is
    # the answer text; spans that are empty after stripping are skipped.
    for answer in answers:
        text = answer[0].strip()
        if text:
            print(f"    - {text}")


if __name__ == "__main__":
    # Set up CLI arguments
    parser = argparse.ArgumentParser(
        formatter_class=argparse.MetavarTypeHelpFormatter
    )
    parser.add_argument(
        "query", type=str, help="The question to feed to the QA system")
    parser.add_argument(
        "--top", "-t", type=int, default=1,
        help="The number of answers to retrieve")
    parser.add_argument(
        "--retriever", "-r", type=str.lower, choices=["faiss", "es"],
        default="faiss", help="The retrieval method to use")
    # --lm selects the reader for every retriever and, with FAISS, also
    # which document index is loaded (see get_reader / get_retriever).
    parser.add_argument(
        "--lm", "-l", type=str.lower,
        choices=["dpr", "longformer"], default="dpr",
        help="The language model to use for the reader (and for the "
             "document index when using the FAISS retriever)")

    args = parser.parse_args()
    # A non-positive answer count would be passed straight to the reader;
    # reject it up front with a normal argparse usage error.
    if args.top < 1:
        parser.error("--top must be a positive integer")
    main(args)