GGroenendaal committed on
Commit
1fb8ae3
1 Parent(s): 2a1de95

decouple dataset loading from retriever

Browse files
main.py CHANGED
@@ -27,11 +27,14 @@ if __name__ == '__main__':
27
 
28
  # logger.info(questions)
29
 
 
 
 
30
  # Initialize retriever
31
- retriever = FaissRetriever()
32
 
33
  # Retrieve example
34
- #random.seed(111)
35
  random_index = random.randint(0, len(questions_test["question"])-1)
36
  example_q = questions_test["question"][random_index]
37
  example_a = questions_test["answer"][random_index]
 
27
 
28
  # logger.info(questions)
29
 
30
+ dataset_paragraphs = cast(DatasetDict, load_dataset(
31
+ "GroNLP/ik-nlp-22_slp", "paragraphs"))
32
+
33
  # Initialize retriever
34
+ retriever = FaissRetriever(dataset_paragraphs)
35
 
36
  # Retrieve example
37
+ # random.seed(111)
38
  random_index = random.randint(0, len(questions_test["question"])-1)
39
  example_q = questions_test["question"][random_index]
40
  example_a = questions_test["answer"][random_index]
src/retrievers/es_retriever.py CHANGED
@@ -1,10 +1,14 @@
 
1
  from src.utils.log import get_logger
 
 
2
 
3
  logger = get_logger()
4
 
5
 
6
  class ESRetriever(Retriever):
7
- def __init__(self, data_set):
 
8
  pass
9
 
10
  def retrieve(self, query: str, k: int):
 
1
+ from datasets import load_dataset
2
  from src.utils.log import get_logger
3
+ from src.retrievers.base_retriever import Retriever
4
+
5
 
6
  logger = get_logger()
7
 
8
 
9
  class ESRetriever(Retriever):
10
+ def __init__(self, data_set: ) -> None:
11
+
12
  pass
13
 
14
  def retrieve(self, query: str, k: int):
src/retrievers/faiss_retriever.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import os.path
3
 
4
  import torch
5
- from datasets import load_dataset
6
  from transformers import (
7
  DPRContextEncoder,
8
  DPRContextEncoderTokenizer,
@@ -26,14 +26,7 @@ class FaissRetriever(Retriever):
26
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
27
  """
28
 
29
- def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp") -> None:
30
- """Initialize the retriever
31
-
32
- Args:
33
- dataset (str, optional): The dataset to train on. Assumes the
34
- information is stored in a column named 'text'. Defaults to
35
- "GroNLP/ik-nlp-22_slp".
36
- """
37
  torch.set_grad_enabled(False)
38
 
39
  # Context encoding and tokenization
@@ -52,36 +45,22 @@ class FaissRetriever(Retriever):
52
  "facebook/dpr-question_encoder-single-nq-base"
53
  )
54
 
55
- # Dataset building
56
- self.dataset_name = dataset_name
57
- self.dataset = self._init_dataset(dataset_name)
 
58
 
59
- def _init_dataset(
60
  self,
61
- dataset_name: str,
62
- embedding_path: str = "./src/models/paragraphs_embedding.faiss",
63
  force_new_embedding: bool = False):
64
- """Loads the dataset and adds FAISS embeddings.
65
-
66
- Args:
67
- dataset (str): A HuggingFace dataset name.
68
- fname (str): The name to use to save the embeddings to disk for
69
- faster loading after the first run.
70
-
71
- Returns:
72
- Dataset: A dataset with a new column 'embeddings' containing FAISS
73
- embeddings.
74
- """
75
- # Load dataset
76
- ds = load_dataset(dataset_name, name="paragraphs")[
77
- "train"] # type: ignore
78
-
79
- if not force_new_embedding and os.path.exists(embedding_path):
80
- # If we already have FAISS embeddings, load them from disk
81
- ds.load_faiss_index('embeddings', embedding_path) # type: ignore
82
  return ds
83
  else:
84
- # If there are no FAISS embeddings, generate them
85
  def embed(row):
86
  # Inline helper function to perform embedding
87
  p = row["text"]
@@ -91,35 +70,25 @@ class FaissRetriever(Retriever):
91
  return {"embeddings": enc}
92
 
93
  # Add FAISS embeddings
94
- ds_with_embeddings = ds.map(embed) # type: ignore
95
 
96
- ds_with_embeddings.add_faiss_index(column="embeddings")
97
 
98
  # save dataset w/ embeddings
99
  os.makedirs("./src/models/", exist_ok=True)
100
- ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
 
101
 
102
- return ds_with_embeddings
103
 
104
  def retrieve(self, query: str, k: int = 5):
105
- """Retrieve the top k matches for a search query.
106
-
107
- Args:
108
- query (str): A search query
109
- k (int, optional): The number of documents to retrieve. Defaults to
110
- 5.
111
-
112
- Returns:
113
- tuple: A tuple of lists of scores and results.
114
- """
115
-
116
  def embed(q):
117
  # Inline helper function to perform embedding
118
  tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
119
  return self.q_encoder(**tok)[0][0].numpy()
120
 
121
  question_embedding = embed(query)
122
- scores, results = self.dataset.get_nearest_examples(
123
  "embeddings", question_embedding, k=k
124
  )
125
 
 
2
  import os.path
3
 
4
  import torch
5
+ from datasets import DatasetDict, load_dataset
6
  from transformers import (
7
  DPRContextEncoder,
8
  DPRContextEncoderTokenizer,
 
26
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
27
  """
28
 
29
+ def __init__(self, dataset: DatasetDict, embedding_path: str = "./src/models/paragraphs_embedding.faiss") -> None:
 
 
 
 
 
 
 
30
  torch.set_grad_enabled(False)
31
 
32
  # Context encoding and tokenization
 
45
  "facebook/dpr-question_encoder-single-nq-base"
46
  )
47
 
48
+ self.dataset = dataset
49
+ self.embedding_path = embedding_path
50
+
51
+ self.index = self._init_index()
52
 
53
+ def _init_index(
54
  self,
 
 
55
  force_new_embedding: bool = False):
56
+
57
+ ds = self.dataset["train"]
58
+
59
+ if not force_new_embedding and os.path.exists(self.embedding_path):
60
+ ds.load_faiss_index(
61
+ 'embeddings', self.embedding_path) # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
62
  return ds
63
  else:
 
64
  def embed(row):
65
  # Inline helper function to perform embedding
66
  p = row["text"]
 
70
  return {"embeddings": enc}
71
 
72
  # Add FAISS embeddings
73
+ index = ds.map(embed) # type: ignore
74
 
75
+ index.add_faiss_index(column="embeddings")
76
 
77
  # save dataset w/ embeddings
78
  os.makedirs("./src/models/", exist_ok=True)
79
+ index.save_faiss_index(
80
+ "embeddings", self.embedding_path)
81
 
82
+ return index
83
 
84
  def retrieve(self, query: str, k: int = 5):
 
 
 
 
 
 
 
 
 
 
 
85
  def embed(q):
86
  # Inline helper function to perform embedding
87
  tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
88
  return self.q_encoder(**tok)[0][0].numpy()
89
 
90
  question_embedding = embed(query)
91
+ scores, results = self.index.get_nearest_examples(
92
  "embeddings", question_embedding, k=k
93
  )
94