Spaces:
Runtime error
Runtime error
File size: 3,317 Bytes
1e5b124 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from abc import ABC, abstractmethod
from haystack.nodes import BM25Retriever, FARMReader
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from haystack.document_stores import PineconeDocumentStore
from haystack.nodes import EmbeddingRetriever
import certifi
import datetime
import requests
from base64 import b64encode
# Location of certifi's CA certificate bundle, resolved once at import time.
# NOTE(review): nothing in the code visible here reads `ca_certs`; it looks
# like a leftover from an Elasticsearch TLS setup — confirm before removing.
ca_certs=certifi.where()
class DocumentQueries(ABC):
    """Abstract interface for objects that answer a text query over documents."""

    @abstractmethod
    def search_by_query(self, query: str, retriever_top_k: int, reader_top_k: int, es_index: str):
        """Answer ``query``; concrete subclasses must override this method.

        The base implementation performs no work and returns ``None``.

        Args:
            query: The question or search text.
            retriever_top_k: Result limit for the retrieval stage.
            reader_top_k: Result limit for the reading (answer extraction) stage.
            es_index: Name of the index to query.
        """
class PinecodeProposalQueries(DocumentQueries):
    """Extractive question answering backed by a Pinecone document store.

    The constructor parameters keep their historical ``es_*`` names for
    backward compatibility even though the store is Pinecone:
    ``es_password`` is passed to Pinecone as the API key and ``es_index`` as
    the Pinecone index name. ``es_host`` and ``es_user`` are only stored on
    the instance.
    """

    def __init__(self, es_host: str, es_index: str, es_user, es_password, reader_name_or_path: str, use_gpu = True) -> None:
        """Load the reader model and build the retrieval/reading pipeline.

        Args:
            es_host: Stored on the instance; not used to reach Pinecone.
            es_index: Name of the Pinecone index to query.
            es_user: Stored on the instance; not used to reach Pinecone.
            es_password: Pinecone API key.
            reader_name_or_path: Model name or local path given to ``FARMReader``.
            use_gpu: Whether the reader should run on a GPU.
        """
        reader = FARMReader(
            model_name_or_path=reader_name_or_path,
            use_gpu=use_gpu,
            num_processes=1,
            context_window_size=200,
        )
        self._initialize_pipeline(es_host, es_index, es_user, es_password, reader=reader)

    def _initialize_pipeline(self, es_host, es_index, es_user, es_password, reader = None):
        """(Re)create the document store, the retriever and the QA pipeline.

        Args:
            es_host: Stored as ``self.es_host``.
            es_index: Pinecone index name.
            es_user: Stored as ``self.es_user``.
            es_password: Pinecone API key; also stored as ``self.es_password``.
            reader: Reader to install. When ``None`` the reader already set
                on the instance is kept, so the pipeline can be rebuilt
                without reloading the model.
        """
        if reader is not None:
            self.reader = reader
        self.es_host = es_host
        self.es_user = es_user
        self.es_password = es_password
        self.document_store = PineconeDocumentStore(
            api_key=es_password,
            environment="us-east1-gcp",
            index=es_index,
            similarity="cosine",
            # presumably the output size of the embedding model below — confirm
            embedding_dim=384,
        )
        self.retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model="multi-qa-MiniLM-L6-cos-v1",
            model_format="sentence_transformers",
        )
        self.document_store.update_embeddings(self.retriever, batch_size=16)
        # search_by_query() runs self.pipe; this assignment was previously
        # commented out, so every query failed with AttributeError.
        self.pipe = ExtractiveQAPipeline(self.reader, self.retriever)

    def search_by_query(self, query : str, retriever_top_k: int, reader_top_k: int, es_index: str = None):
        """Run ``query`` through the pipeline and return its answers.

        Args:
            query: The question to answer.
            retriever_top_k: ``top_k`` passed to the pipeline's Retriever node.
            reader_top_k: ``top_k`` passed to the pipeline's Reader node.
            es_index: Accepted for interface compatibility but currently
                ignored; the index chosen at construction time is used.

        Returns:
            The ``"answers"`` entry of the pipeline's prediction.
        """
        params = {"Retriever": {"top_k": retriever_top_k}, "Reader": {"top_k": reader_top_k}}
        prediction = self.pipe.run(query=query, params=params)
        return prediction["answers"]
class Log:
    """Write simple log documents to an Elasticsearch index over HTTPS."""

    def __init__(self, es_host: str, es_index: str, es_user, es_password, timeout: float = 10.0) -> None:
        """Prepare the endpoint URL and the HTTP Basic auth header.

        Args:
            es_host: Elasticsearch host name (port 443 is always used).
            es_index: Index that receives the log documents.
            es_user: User name for HTTP Basic authentication.
            es_password: Password for HTTP Basic authentication.
            timeout: Seconds to wait for the HTTP request before giving up.
        """
        self.elastic_endpoint = f"https://{es_host}:443/{es_index}/_doc"
        # Build the credentials from the arguments; they were previously
        # hard-coded in source and the arguments were silently ignored.
        raw_credentials = f"{es_user}:{es_password}".encode("utf-8")
        self.credentials = b64encode(raw_credentials).decode("ascii")
        self.auth_header = {"Authorization": f"Basic {self.credentials}"}
        self.timeout = timeout

    def write_log(self, message: str, source: str) -> None:
        """POST one log document and print the server's response body.

        Args:
            message: Text to store in the document's ``message`` field.
            source: Identifier stored in the document's ``source`` field.
        """
        # The format ends in a literal "Z" (UTC), so the clock must be UTC too;
        # a naive local now() would mislabel the timestamp.
        created_date = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
        post_data = {
            "message": message,
            "createdDate": {
                "date": created_date,
            },
            "source": source,
        }
        # Without a timeout an unresponsive server would block the caller forever.
        r = requests.post(
            self.elastic_endpoint,
            json=post_data,
            headers=self.auth_header,
            timeout=self.timeout,
        )
        print(r.text)