GGroenendaal committed on
Commit
b06298d
1 Parent(s): 615dee0

add experiment code

Browse files
README.old.md CHANGED
@@ -6,12 +6,12 @@
6
  - [ ] Formules enzo eruit filteren
7
  - [ ] Splitsen op zinnen...?
8
  - [ ] Meer language models proberen
9
- - [ ] Elasticsearch
10
- - [ ] CLI voor vragen beantwoorden
11
 
12
  ### Extra dingen
13
 
14
- - [ ] Huggingface spaces demo
15
  - [ ] Question generation voor finetuning
16
  - [ ] Language model finetunen
17
 
 
6
  - [ ] Formules enzo eruit filteren
7
  - [ ] Splitsen op zinnen...?
8
  - [ ] Meer language models proberen
9
+ - [X] Elasticsearch
10
+ - [X] CLI voor vragen beantwoorden
11
 
12
  ### Extra dingen
13
 
14
+ - [X] Huggingface spaces demo
15
  - [ ] Question generation voor finetuning
16
  - [ ] Language model finetunen
17
 
main.py CHANGED
@@ -1,19 +1,20 @@
1
- import os
2
  import random
3
- from typing import cast
4
- import time
5
 
6
  import torch
7
  import transformers
8
  from datasets import DatasetDict, load_dataset
9
  from dotenv import load_dotenv
 
10
 
11
  from src.evaluation import evaluate
12
  from src.readers.dpr_reader import DprReader
 
13
  from src.retrievers.es_retriever import ESRetriever
14
  from src.retrievers.faiss_retriever import FaissRetriever
15
  from src.utils.log import get_logger
16
  from src.utils.preprocessing import context_to_reader_input
 
17
 
18
  logger = get_logger()
19
 
@@ -26,62 +27,97 @@ if __name__ == '__main__':
26
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
27
  questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
28
 
29
- questions_test = questions["test"]
30
-
31
- # Initialize retriever
32
- retriever = FaissRetriever(paragraphs)
33
- #retriever = ESRetriever(paragraphs)
34
-
35
- # Retrieve example
36
- # random.seed(111)
37
- random_index = random.randint(0, len(questions_test["question"])-1)
38
- example_q = questions_test["question"][random_index]
39
- example_a = questions_test["answer"][random_index]
40
-
41
- scores, result = retriever.retrieve(example_q)
42
- reader_input = context_to_reader_input(result)
43
-
44
- # TODO: use new code from query.py to clean this up
45
- # Initialize reader
46
- reader = DprReader()
47
- answers = reader.read(example_q, reader_input)
48
-
49
- # Calculate softmaxed scores for readable output
50
- sm = torch.nn.Softmax(dim=0)
51
- document_scores = sm(torch.Tensor(
52
- [pred.relevance_score for pred in answers]))
53
- span_scores = sm(torch.Tensor(
54
- [pred.span_score for pred in answers]))
55
-
56
- print(example_q)
57
- for answer_i, answer in enumerate(answers):
58
- print(f"[{answer_i + 1}]: {answer.text}")
59
- print(f"\tDocument {answer.doc_id}", end='')
60
- print(f"\t(score {document_scores[answer_i] * 100:.02f})")
61
- print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
62
- print(f"\t(score {span_scores[answer_i] * 100:.02f})")
63
- print() # Newline
64
-
65
- # print(f"Example q: {example_q} answer: {result['text'][0]}")
66
-
67
- # for i, score in enumerate(scores):
68
- # print(f"Result {i+1} (score: {score:.02f}):")
69
- # print(result['text'][i])
70
-
71
- # Determine best answer we want to evaluate
72
- highest, highest_index = 0, 0
73
- for i, value in enumerate(span_scores):
74
- if value + document_scores[i] > highest:
75
- highest = value + document_scores[i]
76
- highest_index = i
77
-
78
- # Retrieve exact match and F1-score
79
- exact_match, f1_score = evaluate(
80
- example_a, answers[highest_index].text)
81
- print(f"Gold answer: {example_a}\n"
82
- f"Predicted answer: {answers[highest_index].text}\n"
83
- f"Exact match: {exact_match:.02f}\n"
84
- f"F1-score: {f1_score:.02f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  # Calculate overall performance
87
  # total_f1 = 0
 
 
1
  import random
2
+ from typing import Dict, cast
 
3
 
4
  import torch
5
  import transformers
6
  from datasets import DatasetDict, load_dataset
7
  from dotenv import load_dotenv
8
+ from query import print_answers
9
 
10
  from src.evaluation import evaluate
11
  from src.readers.dpr_reader import DprReader
12
+ from src.retrievers.base_retriever import Retriever
13
  from src.retrievers.es_retriever import ESRetriever
14
  from src.retrievers.faiss_retriever import FaissRetriever
15
  from src.utils.log import get_logger
16
  from src.utils.preprocessing import context_to_reader_input
17
+ from src.utils.timing import get_times, timeit
18
 
19
  logger = get_logger()
20
 
 
27
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
28
  questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
29
 
30
+ # Only doing a few questions for speed
31
+ subset_idx = 3
32
+ questions_test = questions["test"][:subset_idx]
33
+
34
+ experiments: Dict[str, Retriever] = {
35
+ "faiss": FaissRetriever(paragraphs),
36
+ # "es": ESRetriever(paragraphs),
37
+ }
38
+
39
+ for experiment_name, retriever in experiments.items():
40
+ reader = DprReader()
41
+
42
+ for idx in range(subset_idx):
43
+ question = questions_test["question"][idx]
44
+ answer = questions_test["answer"][idx]
45
+
46
+ scores, context = retriever.retrieve(question, 5)
47
+ reader_input = context_to_reader_input(context)
48
+
49
+ # workaround so we can use the decorator with a dynamic name for time recording
50
+ time_wrapper = timeit(f"{experiment_name}.read")
51
+ answers = time_wrapper(reader.read)(question, reader_input, 5)
52
+
53
+ # Calculate softmaxed scores for readable output
54
+ sm = torch.nn.Softmax(dim=0)
55
+ document_scores = sm(torch.Tensor(
56
+ [pred.relevance_score for pred in answers]))
57
+ span_scores = sm(torch.Tensor(
58
+ [pred.span_score for pred in answers]))
59
+
60
+ print_answers(answers, scores, context)
61
+
62
+ # TODO evaluation and storing of results
63
+
64
+ times = get_times()
65
+ print(times)
66
+ # TODO evaluation and storing of results
67
+
68
+ # # Initialize retriever
69
+ # retriever = FaissRetriever(paragraphs)
70
+ # # retriever = ESRetriever(paragraphs)
71
+
72
+ # # Retrieve example
73
+ # # random.seed(111)
74
+ # random_index = random.randint(0, len(questions_test["question"])-1)
75
+ # example_q = questions_test["question"][random_index]
76
+ # example_a = questions_test["answer"][random_index]
77
+
78
+ # scores, result = retriever.retrieve(example_q)
79
+ # reader_input = context_to_reader_input(result)
80
+
81
+ # # TODO: use new code from query.py to clean this up
82
+ # # Initialize reader
83
+ # answers = reader.read(example_q, reader_input)
84
+
85
+ # # Calculate softmaxed scores for readable output
86
+ # sm = torch.nn.Softmax(dim=0)
87
+ # document_scores = sm(torch.Tensor(
88
+ # [pred.relevance_score for pred in answers]))
89
+ # span_scores = sm(torch.Tensor(
90
+ # [pred.span_score for pred in answers]))
91
+
92
+ # print(example_q)
93
+ # for answer_i, answer in enumerate(answers):
94
+ # print(f"[{answer_i + 1}]: {answer.text}")
95
+ # print(f"\tDocument {answer.doc_id}", end='')
96
+ # print(f"\t(score {document_scores[answer_i] * 100:.02f})")
97
+ # print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
98
+ # print(f"\t(score {span_scores[answer_i] * 100:.02f})")
99
+ # print() # Newline
100
+
101
+ # # print(f"Example q: {example_q} answer: {result['text'][0]}")
102
+
103
+ # # for i, score in enumerate(scores):
104
+ # # print(f"Result {i+1} (score: {score:.02f}):")
105
+ # # print(result['text'][i])
106
+
107
+ # # Determine best answer we want to evaluate
108
+ # highest, highest_index = 0, 0
109
+ # for i, value in enumerate(span_scores):
110
+ # if value + document_scores[i] > highest:
111
+ # highest = value + document_scores[i]
112
+ # highest_index = i
113
+
114
+ # # Retrieve exact match and F1-score
115
+ # exact_match, f1_score = evaluate(
116
+ # example_a, answers[highest_index].text)
117
+ # print(f"Gold answer: {example_a}\n"
118
+ # f"Predicted answer: {answers[highest_index].text}\n"
119
+ # f"Exact match: {exact_match:.02f}\n"
120
+ # f"F1-score: {f1_score:.02f}")
121
 
122
  # Calculate overall performance
123
  # total_f1 = 0
src/retrievers/base_retriever.py CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  class Retriever():
2
- def retrieve(self, query: str, k: int):
3
- pass
 
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

import numpy as np

# A retrieval result maps column names (e.g. "text") to the k retrieved values.
RetrieveTypeResult = Dict[str, List[str]]
# Relevance scores, parallel to the lists in the result mapping.
RetrieveTypeScores = np.ndarray
# Retrievers return a (scores, result) pair.
RetrieveType = Tuple[RetrieveTypeScores, RetrieveTypeResult]


class Retriever(ABC):
    """Abstract interface for passage retrievers (e.g. FAISS, Elasticsearch).

    Subclasses must implement :meth:`retrieve`; declaring it abstract makes
    the contract explicit and prevents accidental instantiation of the bare
    base class (previously it was a silently-instantiable stub).
    """

    @abstractmethod
    def retrieve(self, query: str, k: int) -> RetrieveType:
        """Return the (scores, result) pair for the top-``k`` passages matching *query*.

        Args:
            query: Free-text question/query string.
            k: Number of passages to return.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()
src/retrievers/es_retriever.py CHANGED
@@ -1,8 +1,11 @@
 
 
1
  from datasets import DatasetDict
2
- from src.utils.log import get_logger
3
- from src.retrievers.base_retriever import Retriever
4
  from elasticsearch import Elasticsearch
5
- import os
 
 
 
6
 
7
  logger = get_logger()
8
 
@@ -31,5 +34,6 @@ class ESRetriever(Retriever):
31
  es_index_name="paragraphs",
32
  es_client=self.client)
33
 
34
- def retrieve(self, query: str, k: int = 5):
 
35
  return self.paragraphs.get_nearest_examples("paragraphs", query, k)
 
1
+ import os
2
+
3
  from datasets import DatasetDict
 
 
4
  from elasticsearch import Elasticsearch
5
+
6
+ from src.retrievers.base_retriever import RetrieveType, Retriever
7
+ from src.utils.log import get_logger
8
+ from src.utils.timing import timeit
9
 
10
  logger = get_logger()
11
 
 
34
  es_index_name="paragraphs",
35
  es_client=self.client)
36
 
37
+ @timeit("esretriever.retrieve")
38
+ def retrieve(self, query: str, k: int = 5) -> RetrieveType:
39
  return self.paragraphs.get_nearest_examples("paragraphs", query, k)
src/retrievers/faiss_retriever.py CHANGED
@@ -10,9 +10,10 @@ from transformers import (
10
  DPRQuestionEncoderTokenizer,
11
  )
12
 
13
- from src.retrievers.base_retriever import Retriever
14
  from src.utils.log import get_logger
15
  from src.utils.preprocessing import remove_formulas
 
16
 
17
  # Hacky fix for FAISS error on macOS
18
  # See https://stackoverflow.com/a/63374568/4545692
@@ -83,7 +84,8 @@ class FaissRetriever(Retriever):
83
 
84
  return index
85
 
86
- def retrieve(self, query: str, k: int = 50):
 
87
  def embed(q):
88
  # Inline helper function to perform embedding
89
  tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
 
10
  DPRQuestionEncoderTokenizer,
11
  )
12
 
13
+ from src.retrievers.base_retriever import RetrieveType, Retriever
14
  from src.utils.log import get_logger
15
  from src.utils.preprocessing import remove_formulas
16
+ from src.utils.timing import timeit
17
 
18
  # Hacky fix for FAISS error on macOS
19
  # See https://stackoverflow.com/a/63374568/4545692
 
84
 
85
  return index
86
 
87
+ @timeit("faissretriever.retrieve")
88
+ def retrieve(self, query: str, k: int = 5) -> RetrieveType:
89
  def embed(q):
90
  # Inline helper function to perform embedding
91
  tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
test.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# %%
# Scratch/experiment script: smoke-tests the FAISS retriever end to end.
# The "# %%" markers are Jupyter-style cell delimiters (VS Code / jupytext),
# so this file is meant to be run interactively cell by cell.
from datasets import load_dataset
from src.retrievers.faiss_retriever import FaissRetriever


# Downloads the paragraphs config of the course dataset (network I/O).
data = load_dataset("GroNLP/ik-nlp-22_slp", "paragraphs")

# # %%
# x = data["test"][:3]

# # %%
# for y in x:

# print(y)
# # %%
# x.num_rows

# # %%
# Build the index over the paragraphs and issue a throwaway query;
# only checks that retrieval runs, the results are not asserted on.
retriever = FaissRetriever(data)
scores, result = retriever.retrieve("hello world")