Robert committed on
Commit: b7158e7
1 Parent(s): 9889a50

- Remove useless paragraphs that only contain formulas

- Added some code to run the script over all questions to calculate overall performance

main.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import random
 from typing import cast
+import time
 
 import torch
 import transformers
@@ -32,8 +33,8 @@ if __name__ == '__main__':
         "GroNLP/ik-nlp-22_slp", "paragraphs"))
 
     # Initialize retriever
-    # retriever = FaissRetriever(dataset_paragraphs)
-    retriever = ESRetriever(dataset_paragraphs)
+    retriever = FaissRetriever(dataset_paragraphs)
+    #retriever = ESRetriever(dataset_paragraphs)
 
     # Retrieve example
     # random.seed(111)
@@ -84,3 +85,36 @@ if __name__ == '__main__':
           f"Predicted answer: {answers[highest_index].text}\n"
           f"Exact match: {exact_match:.02f}\n"
           f"F1-score: {f1_score:.02f}")
+
+    # Calculate overall performance
+    # total_f1 = 0
+    # total_exact = 0
+    # total_len = len(questions_test["question"])
+    # start_time = time.time()
+    # for i, question in enumerate(questions_test["question"]):
+    #     print(question)
+    #     answer = questions_test["answer"][i]
+    #     print(answer)
+    #
+    #     scores, result = retriever.retrieve(question)
+    #     reader_input = result_to_reader_input(result)
+    #     answers = reader.read(question, reader_input)
+    #
+    #     document_scores = sm(torch.Tensor(
+    #         [pred.relevance_score for pred in answers]))
+    #     span_scores = sm(torch.Tensor(
+    #         [pred.span_score for pred in answers]))
+    #
+    #     highest, highest_index = 0, 0
+    #     for j, value in enumerate(span_scores):
+    #         if value + document_scores[j] > highest:
+    #             highest = value + document_scores[j]
+    #             highest_index = j
+    #     print(answers[highest_index])
+    #     exact_match, f1_score = evaluate(answer, answers[highest_index].text)
+    #     total_f1 += f1_score
+    #     total_exact += exact_match
+    # print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
+    # print(total_f1)
+    # print(total_exact)
+    # print(total_f1/total_len)
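The overall-performance loop added above is committed commented out and only prints raw totals. Below is a minimal sketch (not part of the commit) of how the same loop could be wrapped in a reusable helper that reports averages. It reuses the names from main.py (retriever, reader, evaluate, result_to_reader_input, questions_test); the helper name evaluate_overall and the use of torch.softmax/torch.argmax in place of the script's sm helper are assumptions for illustration only.

import time

import torch


def evaluate_overall(questions_test, retriever, reader, evaluate,
                     result_to_reader_input):
    """Run the pipeline over every test question and report average EM / F1."""
    total_exact, total_f1 = 0.0, 0.0
    total_len = len(questions_test["question"])
    start_time = time.time()

    for question, answer in zip(questions_test["question"],
                                questions_test["answer"]):
        _, result = retriever.retrieve(question)
        reader_input = result_to_reader_input(result)
        answers = reader.read(question, reader_input)

        # Same selection rule as main.py: softmax the document relevance and
        # span scores, then take the answer with the highest combined score.
        document_scores = torch.softmax(
            torch.tensor([pred.relevance_score for pred in answers]), dim=0)
        span_scores = torch.softmax(
            torch.tensor([pred.span_score for pred in answers]), dim=0)
        best = int(torch.argmax(document_scores + span_scores))

        exact_match, f1_score = evaluate(answer, answers[best].text)
        total_exact += exact_match
        total_f1 += f1_score

    print(f"Total time: {time.time() - start_time:.2f} seconds.")
    print(f"Average exact match: {total_exact / total_len:.02f}")
    print(f"Average F1-score: {total_f1 / total_len:.02f}")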
src/retrievers/faiss_retriever.py CHANGED
@@ -12,6 +12,7 @@ from transformers import (
 
 from src.retrievers.base_retriever import Retriever
 from src.utils.log import get_logger
+from src.utils.preprocessing import remove_formulas
 
 # Hacky fix for FAISS error on macOS
 # See https://stackoverflow.com/a/63374568/4545692
@@ -55,6 +56,8 @@ class FaissRetriever(Retriever):
                  force_new_embedding: bool = False):
 
         ds = self.dataset["train"]
+        ds = ds.map(remove_formulas)
+
 
         if not force_new_embedding and os.path.exists(self.embedding_path):
             ds.load_faiss_index(
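For reference, ds.map(remove_formulas) relies on the standard Hugging Face datasets behaviour of Dataset.map: the mapped function receives one example dict at a time and its return value replaces that example. A small sketch with made-up toy rows (the two texts below are illustrative, not taken from the ik-nlp-22_slp data):

from datasets import Dataset

from src.utils.preprocessing import remove_formulas

# Toy rows: one formula-like text made of very short tokens, one prose text.
toy = Dataset.from_dict({
    "text": [
        "P ( A | B ) = P ( A , B ) / P ( B )",
        "Language models assign probabilities to sequences of words.",
    ]
})

cleaned = toy.map(remove_formulas)
# The formula-like row has an average word length of 1.0 (<= 3.5), so its
# text is blanked; the prose row (average ~6.5) is left unchanged.
print(cleaned["text"])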
src/utils/preprocessing.py CHANGED
@@ -33,3 +33,22 @@ def result_to_reader_input(result: Dict[str, List[str]]) \
         reader_result['texts'].append(result['text'][n])
 
     return reader_result
+
+
+def remove_formulas(ds):
+    """Replaces text in the 'text' column of the ds which has an average
+    word length of <= 3.5 with blanks. This essentially means that most
+    of the formulas are removed.
+    To-do:
+        - more-preprocessing
+        - a summarization model perhaps
+    Args:
+        ds: HuggingFace dataset that contains the information for the retriever
+    Returns:
+        ds: preprocessed HuggingFace dataset
+    """
+    words = ds['text'].split()
+    average = sum(len(word) for word in words) / len(words)
+    if average <= 3.5:
+        ds['text'] = ''
+    return ds
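One edge case worth noting: remove_formulas divides by len(words), so a row whose 'text' is empty or whitespace-only raises a ZeroDivisionError inside Dataset.map. A defensive variant is sketched below (an illustration, not part of the commit); it keeps the same <= 3.5 average-word-length heuristic and only adds a guard for empty rows.

def remove_formulas_safe(example):
    """Like remove_formulas, but leaves empty or whitespace-only rows untouched."""
    words = example['text'].split()
    if not words:  # nothing to measure: keep the row as-is
        return example
    average = sum(len(word) for word in words) / len(words)
    if average <= 3.5:  # mostly very short tokens -> likely a formula
        example['text'] = ''
    return example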