GGroenendaal committed · Commit b06298d · 1 Parent(s): 615dee0

add experiment code

Files changed:
- README.old.md +3 -3
- main.py +95 -59
- src/retrievers/base_retriever.py +11 -2
- src/retrievers/es_retriever.py +8 -4
- src/retrievers/faiss_retriever.py +4 -2
- test.py +20 -0
README.old.md CHANGED
@@ -6,12 +6,12 @@
 - [ ] Formules enzo eruit filteren
 - [ ] Splitsen op zinnen...?
 - [ ] Meer language models proberen
-- [ ] Elasticsearch
-- [ ] CLI voor vragen beantwoorden
+- [X] Elasticsearch
+- [X] CLI voor vragen beantwoorden
 
 ### Extra dingen
 
-- [ ] Huggingface spaces demo
+- [X] Huggingface spaces demo
 - [ ] Question generation voor finetuning
 - [ ] Language model finetunen
 
main.py CHANGED
@@ -1,19 +1,20 @@
-import os
 import random
-from typing import cast
-import time
+from typing import Dict, cast
 
 import torch
 import transformers
 from datasets import DatasetDict, load_dataset
 from dotenv import load_dotenv
+from query import print_answers
 
 from src.evaluation import evaluate
 from src.readers.dpr_reader import DprReader
+from src.retrievers.base_retriever import Retriever
 from src.retrievers.es_retriever import ESRetriever
 from src.retrievers.faiss_retriever import FaissRetriever
 from src.utils.log import get_logger
 from src.utils.preprocessing import context_to_reader_input
+from src.utils.timing import get_times, timeit
 
 logger = get_logger()
 
@@ -26,62 +27,97 @@ if __name__ == '__main__':
         "GroNLP/ik-nlp-22_slp", "paragraphs"))
     questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
 
[56 removed lines: content not captured in this view]
+    # Only doing a few questions for speed
+    subset_idx = 3
+    questions_test = questions["test"][:subset_idx]
+
+    experiments: Dict[str, Retriever] = {
+        "faiss": FaissRetriever(paragraphs),
+        # "es": ESRetriever(paragraphs),
+    }
+
+    for experiment_name, retriever in experiments.items():
+        reader = DprReader()
+
+        for idx in range(subset_idx):
+            question = questions_test["question"][idx]
+            answer = questions_test["answer"][idx]
+
+            scores, context = retriever.retrieve(question, 5)
+            reader_input = context_to_reader_input(context)
+
+            # workaround so we can use the decorator with a dynamic name for time recording
+            time_wrapper = timeit(f"{experiment_name}.read")
+            answers = time_wrapper(reader.read)(question, reader_input, 5)
+
+            # Calculate softmaxed scores for readable output
+            sm = torch.nn.Softmax(dim=0)
+            document_scores = sm(torch.Tensor(
+                [pred.relevance_score for pred in answers]))
+            span_scores = sm(torch.Tensor(
+                [pred.span_score for pred in answers]))
+
+            print_answers(answers, scores, context)
+
+            # TODO evaluation and storing of results
+
+    times = get_times()
+    print(times)
+    # TODO evaluation and storing of results
+
+    # # Initialize retriever
+    # retriever = FaissRetriever(paragraphs)
+    # # retriever = ESRetriever(paragraphs)
+
+    # # Retrieve example
+    # # random.seed(111)
+    # random_index = random.randint(0, len(questions_test["question"])-1)
+    # example_q = questions_test["question"][random_index]
+    # example_a = questions_test["answer"][random_index]
+
+    # scores, result = retriever.retrieve(example_q)
+    # reader_input = context_to_reader_input(result)
+
+    # # TODO: use new code from query.py to clean this up
+    # # Initialize reader
+    # answers = reader.read(example_q, reader_input)
+
+    # # Calculate softmaxed scores for readable output
+    # sm = torch.nn.Softmax(dim=0)
+    # document_scores = sm(torch.Tensor(
+    #     [pred.relevance_score for pred in answers]))
+    # span_scores = sm(torch.Tensor(
+    #     [pred.span_score for pred in answers]))
+
+    # print(example_q)
+    # for answer_i, answer in enumerate(answers):
+    #     print(f"[{answer_i + 1}]: {answer.text}")
+    #     print(f"\tDocument {answer.doc_id}", end='')
+    #     print(f"\t(score {document_scores[answer_i] * 100:.02f})")
+    #     print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
+    #     print(f"\t(score {span_scores[answer_i] * 100:.02f})")
+    #     print()  # Newline
+
+    # # print(f"Example q: {example_q} answer: {result['text'][0]}")
+
+    # # for i, score in enumerate(scores):
+    # #     print(f"Result {i+1} (score: {score:.02f}):")
+    # #     print(result['text'][i])
+
+    # # Determine best answer we want to evaluate
+    # highest, highest_index = 0, 0
+    # for i, value in enumerate(span_scores):
+    #     if value + document_scores[i] > highest:
+    #         highest = value + document_scores[i]
+    #         highest_index = i
+
+    # # Retrieve exact match and F1-score
+    # exact_match, f1_score = evaluate(
+    #     example_a, answers[highest_index].text)
+    # print(f"Gold answer: {example_a}\n"
+    #       f"Predicted answer: {answers[highest_index].text}\n"
+    #       f"Exact match: {exact_match:.02f}\n"
+    #       f"F1-score: {f1_score:.02f}")
 
     # Calculate overall performance
     # total_f1 = 0
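The experiment loop above depends on `timeit` and `get_times` from `src.utils.timing`, which this commit does not show. Based only on how they are called (as a plain decorator in the retriever diffs below, applied manually via `timeit(f"{experiment_name}.read")` here, and collected with `get_times()` at the end), a minimal sketch of such a module could look like the following; the repository's actual implementation may differ.

```python
# Hypothetical sketch of src/utils/timing.py; not the module shipped with this commit.
import time
from collections import defaultdict
from functools import wraps
from typing import Callable, DefaultDict, List

_times: DefaultDict[str, List[float]] = defaultdict(list)


def timeit(name: str) -> Callable:
    """Return a decorator that records the wall-clock duration of each call under `name`."""
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            _times[name].append(time.perf_counter() - start)
            return result
        return wrapper
    return decorator


def get_times() -> dict:
    """Return all recorded durations, keyed by the name passed to timeit()."""
    return dict(_times)
```

Under this reading, the in-code "workaround" comment makes sense: the reader is shared across experiments, so the timing key has to be built dynamically (`f"{experiment_name}.read"`) and the decorator applied by hand instead of at definition time.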
src/retrievers/base_retriever.py CHANGED
@@ -1,3 +1,12 @@
+from typing import Dict, List, Tuple
+
+import numpy as np
+
+RetrieveTypeResult = Dict[str, List[str]]
+RetrieveTypeScores = np.ndarray
+RetrieveType = Tuple[RetrieveTypeScores, RetrieveTypeResult]
+
+
 class Retriever():
-    def retrieve(self, query: str, k: int):
-        [removed method body not captured in this view]
+    def retrieve(self, query: str, k: int) -> RetrieveType:
+        raise NotImplementedError()
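The new aliases spell out the retriever contract: `retrieve` returns a tuple of scores (a NumPy array) and a dict mapping column names to lists of strings. Purely as an illustration (this class is not part of the commit, and the scoring is made up), a toy in-memory retriever satisfying that contract could look like this:

```python
# Hypothetical example, not part of the commit; it only illustrates the
# (scores, results) contract declared by RetrieveType.
import numpy as np

from src.retrievers.base_retriever import RetrieveType, Retriever


class DummyRetriever(Retriever):
    def __init__(self, texts):
        self.texts = list(texts)

    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        # Score each stored paragraph by naive term overlap with the query.
        query_terms = set(query.lower().split())
        scores = np.array([
            len(query_terms & set(text.lower().split()))
            for text in self.texts
        ], dtype=float)
        top = np.argsort(-scores)[:k]
        results = {"text": [self.texts[i] for i in top]}
        return scores[top], results
```

ESRetriever and FaissRetriever below adopt the same `-> RetrieveType` annotation, so the experiment loop in main.py can unpack `scores, context = retriever.retrieve(question, 5)` regardless of which backend is used.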
src/retrievers/es_retriever.py CHANGED
@@ -1,8 +1,11 @@
+import os
+
 from datasets import DatasetDict
-from src.utils.log import get_logger
-from src.retrievers.base_retriever import Retriever
 from elasticsearch import Elasticsearch
-[removed line not captured in this view]
+
+from src.retrievers.base_retriever import RetrieveType, Retriever
+from src.utils.log import get_logger
+from src.utils.timing import timeit
 
 logger = get_logger()
 
@@ -31,5 +34,6 @@ class ESRetriever(Retriever):
             es_index_name="paragraphs",
             es_client=self.client)
 
-[removed retrieve() signature not captured in this view]
+    @timeit("esretriever.retrieve")
+    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
         return self.paragraphs.get_nearest_examples("paragraphs", query, k)
src/retrievers/faiss_retriever.py CHANGED
@@ -10,9 +10,10 @@ from transformers import (
     DPRQuestionEncoderTokenizer,
 )
 
-from src.retrievers.base_retriever import Retriever
+from src.retrievers.base_retriever import RetrieveType, Retriever
 from src.utils.log import get_logger
 from src.utils.preprocessing import remove_formulas
+from src.utils.timing import timeit
 
 # Hacky fix for FAISS error on macOS
 # See https://stackoverflow.com/a/63374568/4545692
@@ -83,7 +84,8 @@ class FaissRetriever(Retriever):
 
         return index
 
-[removed retrieve() signature not captured in this view]
+    @timeit("faissretriever.retrieve")
+    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
         def embed(q):
             # Inline helper function to perform embedding
             tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
test.py ADDED
@@ -0,0 +1,20 @@
+# %%
+from datasets import load_dataset
+from src.retrievers.faiss_retriever import FaissRetriever
+
+
+data = load_dataset("GroNLP/ik-nlp-22_slp", "paragraphs")
+
+# # %%
+# x = data["test"][:3]
+
+# # %%
+# for y in x:
+
+#     print(y)
+# # %%
+# x.num_rows
+
+# # %%
+retriever = FaissRetriever(data)
+scores, result = retriever.retrieve("hello world")