Jorge Henao committed on
Commit
3a70faa
•
1 Parent(s): 9129bee

init app update

Browse files
Files changed (2) hide show
  1. config.py +0 -0
  2. document_quieries +39 -0
config.py ADDED
File without changes
document_quieries ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from haystack.nodes import BM25Retriever, FARMReader
2
+ from haystack.document_stores import ElasticsearchDocumentStore
3
+ from haystack.pipelines import ExtractiveQAPipeline
4
+
5
+ import certifi
6
+ ca_certs=certifi.where()
7
+
8
+ from abc import ABC, abstractmethod
9
+
10
class DocumentQueries(ABC):
    """Abstract interface for objects that answer a query against an index."""

    @abstractmethod
    def search_by_query(self, query: str, retriever_top_k: int, reader_top_k: int, es_index: str = None):
        """Run ``query`` and return the results found.

        Args:
            query: The text of the query to run.
            retriever_top_k: Number of candidates requested from the retriever.
            reader_top_k: Number of results requested from the reader.
            es_index: Optional index name to search. It defaults to ``None``
                so the abstract signature matches implementations that treat
                the index as optional.
        """
15
+
16
class ExtractiveProposalQueries(DocumentQueries):
    """Extractive question answering over an Elasticsearch index.

    Wires an ``ElasticsearchDocumentStore``, a ``BM25Retriever`` and a
    ``FARMReader`` into an ``ExtractiveQAPipeline`` and exposes it through
    ``search_by_query``.
    """

    def __init__(self, es_host: str, es_index: str, es_user, es_password, reader_name_or_path: str, use_gpu = False) -> None:
        """Load the reader model and build the pipeline for ``es_index``.

        Args:
            es_host: Elasticsearch host to connect to.
            es_index: Name of the index that is searched initially.
            es_user: Username passed to the document store.
            es_password: Password passed to the document store.
            reader_name_or_path: Model name or path handed to ``FARMReader``.
            use_gpu: Forwarded to ``FARMReader``.
        """
        reader = FARMReader(model_name_or_path=reader_name_or_path, use_gpu=use_gpu, num_processes=1)
        self._initialize_pipeline(es_host, es_index, es_user, es_password, reader=reader)

    def _initialize_pipeline(self, es_host, es_index, es_user, es_password, reader = None):
        """(Re)build the document store, retriever and pipeline.

        Args:
            es_host: Elasticsearch host to connect to.
            es_index: Name of the index the new document store targets.
            es_user: Username passed to the document store.
            es_password: Password passed to the document store.
            reader: Reader to install. When ``None`` the reader already stored
                on the instance is reused, so the model is not reloaded.
        """
        if reader is not None:
            self.reader = reader
        self.es_host = es_host
        self.es_user = es_user
        self.es_password = es_password
        # The connection is always HTTPS on port 443, verified against the
        # module-level ``ca_certs`` bundle.
        self.document_store = ElasticsearchDocumentStore(
            host=es_host,
            username=es_user,
            password=es_password,
            index=es_index,
            port=443,
            scheme='https',
            verify_certs=True,
            ca_certs=ca_certs,
        )
        self.retriever = BM25Retriever(document_store=self.document_store)
        self.pipe = ExtractiveQAPipeline(self.reader, self.retriever)
        # Recorded last so a failed rebuild never marks the new index active.
        self.es_index = es_index

    def search_by_query(self, query : str, retriever_top_k: int, reader_top_k: int, es_index: str = None):
        """Run the extractive QA pipeline for ``query``.

        Args:
            query: The question text passed to the pipeline.
            retriever_top_k: ``top_k`` given to the "Retriever" node.
            reader_top_k: ``top_k`` given to the "Reader" node.
            es_index: Optional index to search. When given and different from
                the active index, the pipeline is rebuilt against it. The
                switch persists, so later calls without ``es_index`` keep
                using the most recently selected index.

        Returns:
            The ``"answers"`` entry of the pipeline's prediction.
        """
        # Only rebuild when the index actually changes. Rebuilding on every
        # call would create a new document store, retriever and pipeline for
        # no benefit.
        if es_index is not None and es_index != getattr(self, "es_index", None):
            self._initialize_pipeline(self.es_host, es_index, self.es_user, self.es_password)
        params = {"Retriever": {"top_k": retriever_top_k}, "Reader": {"top_k": reader_top_k}}
        prediction = self.pipe.run(query=query, params=params)
        return prediction["answers"]
39
+