GGroenendaal committed on
Commit
af461f3
1 Parent(s): 1fb8ae3

add es retriever

Browse files
Files changed (3) hide show
  1. .env.example +1 -0
  2. main.py +10 -9
  3. src/retrievers/es_retriever.py +25 -5
.env.example CHANGED
@@ -1,4 +1,5 @@
1
  ELASTIC_USERNAME=elastic
2
  ELASTIC_PASSWORD=<password>
 
3
 
4
  LOG_LEVEL=INFO
 
1
  ELASTIC_USERNAME=elastic
2
  ELASTIC_PASSWORD=<password>
3
+ ELASTIC_HOST=<localhost>
4
 
5
  LOG_LEVEL=INFO
main.py CHANGED
@@ -1,18 +1,18 @@
 
 
 
 
 
 
1
  from datasets import DatasetDict, load_dataset
2
 
 
3
  from src.readers.dpr_reader import DprReader
 
4
  from src.retrievers.faiss_retriever import FaissRetriever
5
  from src.utils.log import get_logger
6
- from src.evaluation import evaluate
7
- from typing import cast
8
-
9
  from src.utils.preprocessing import result_to_reader_input
10
 
11
- import torch
12
- import transformers
13
- import os
14
- import random
15
-
16
  os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
17
 
18
  logger = get_logger()
@@ -31,7 +31,8 @@ if __name__ == '__main__':
31
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
32
 
33
  # Initialize retriever
34
- retriever = FaissRetriever(dataset_paragraphs)
 
35
 
36
  # Retrieve example
37
  # random.seed(111)
 
1
+ import os
2
+ import random
3
+ from typing import cast
4
+
5
+ import torch
6
+ import transformers
7
  from datasets import DatasetDict, load_dataset
8
 
9
+ from src.evaluation import evaluate
10
  from src.readers.dpr_reader import DprReader
11
+ from src.retrievers.es_retriever import ESRetriever
12
  from src.retrievers.faiss_retriever import FaissRetriever
13
  from src.utils.log import get_logger
 
 
 
14
  from src.utils.preprocessing import result_to_reader_input
15
 
 
 
 
 
 
16
  os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
17
 
18
  logger = get_logger()
 
31
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
32
 
33
  # Initialize retriever
34
+ # retriever = FaissRetriever(dataset_paragraphs)
35
+ retriever = ESRetriever(dataset_paragraphs)
36
 
37
  # Retrieve example
38
  # random.seed(111)
src/retrievers/es_retriever.py CHANGED
@@ -1,15 +1,35 @@
1
- from datasets import load_dataset
2
  from src.utils.log import get_logger
3
  from src.retrievers.base_retriever import Retriever
 
 
 
4
 
 
5
 
6
  logger = get_logger()
7
 
8
 
9
  class ESRetriever(Retriever):
10
- def __init__(self, data_set: ) -> None:
 
11
 
12
- pass
 
 
13
 
14
- def retrieve(self, query: str, k: int):
15
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import DatasetDict, load_dataset
2
  from src.utils.log import get_logger
3
  from src.retrievers.base_retriever import Retriever
4
+ from elasticsearch import Elasticsearch
5
+ from dotenv import load_dotenv
6
+ import os
7
 
8
+ load_dotenv()
9
 
10
  logger = get_logger()
11
 
12
 
13
  class ESRetriever(Retriever):
14
+ def __init__(self, dataset: DatasetDict) -> None:
15
+ self.dataset = dataset["train"]
16
 
17
+ es_host = os.getenv("ELASTIC_HOST", "localhost")
18
+ es_password = os.getenv("ELASTIC_PASSWORD")
19
+ es_username = os.getenv("ELASTIC_USERNAME")
20
 
21
+ self.client = Elasticsearch(
22
+ hosts=[es_host], http_auth=(es_username, es_password))
23
+
24
+ if self.client.indices.exists(index="paragraphs"):
25
+ self.dataset.load_elasticsearch_index(
26
+ "paragraphs", es_index_name="paragraphs", es_client=self.client)
27
+ else:
28
+ logger.info(f"Creating index 'paragraphs' on {es_host}")
29
+ self.dataset.add_elasticsearch_index(column="text",
30
+ index_name="paragraphs",
31
+ es_index_name="paragraphs",
32
+ es_client=self.client)
33
+
34
+ def retrieve(self, query: str, k: int = 5):
35
+ return self.dataset.get_nearest_examples("paragraphs", query, k)