File size: 3,376 Bytes
1e5b124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98f2c45
1e5b124
 
 
 
b9d8bb3
1e5b124
 
0891da7
 
1e5b124
 
 
 
 
fd70a0d
 
1e5b124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from abc import ABC, abstractmethod
from haystack.nodes import BM25Retriever, FARMReader
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from haystack.document_stores import PineconeDocumentStore
from haystack.nodes import EmbeddingRetriever

import certifi
import datetime
import requests
from base64 import b64encode

# Location of the CA bundle reported by certifi.
# NOTE(review): not referenced by any code visible in this file — confirm it is still needed.
ca_certs=certifi.where()

class DocumentQueries(ABC):
    """Abstract interface for objects that answer a query string."""

    @abstractmethod
    def search_by_query(self, query: str, retriever_top_k: int, reader_top_k: int, es_index: str):
        """Answer ``query``; concrete subclasses must override this method.

        The base implementation performs no work and returns ``None``.
        """
        return None

class PinecodeProposalQueries(DocumentQueries):
    """Extractive question answering over a Pinecone-backed document store.

    An ``EmbeddingRetriever`` and a ``FARMReader`` are combined in a Haystack
    ``ExtractiveQAPipeline``.

    The ``es_*`` parameter names are kept for interface compatibility. Only
    two of them feed the document store: ``es_index`` is used as the Pinecone
    index name and ``es_password`` as the Pinecone API key. ``es_host`` and
    ``es_user`` are merely stored on the instance.
    """

    def __init__(self, es_host: str, es_index: str, es_user, es_password, reader_name_or_path: str, use_gpu = True) -> None:
        """Load the reader model and build the query pipeline.

        Args:
            es_host: Stored on the instance; not used to build the store.
            es_index: Name of the Pinecone index to query.
            es_user: Stored on the instance; not used to build the store.
            es_password: Pinecone API key.
            reader_name_or_path: Model name or path handed to ``FARMReader``.
            use_gpu: Forwarded to ``FARMReader``.
        """
        reader = FARMReader(
            model_name_or_path=reader_name_or_path,
            use_gpu=use_gpu,
            num_processes=1,
            context_window_size=200,
        )
        self._initialize_pipeline(es_host, es_index, es_user, es_password, reader=reader)

    def _initialize_pipeline(self, es_host, es_index, es_user, es_password, reader = None):
        """(Re)create the document store, retriever and pipeline.

        Args:
            es_host: Stored as ``self.es_host``.
            es_index: Pinecone index name.
            es_user: Stored as ``self.es_user``.
            es_password: Pinecone API key; also stored as ``self.es_password``.
            reader: Reader to install as ``self.reader``. When ``None`` the
                reader already present on the instance is reused.
        """
        if reader is not None:
            self.reader = reader
        self.es_host = es_host
        self.es_user = es_user
        self.es_password = es_password
        self.document_store = PineconeDocumentStore(
            api_key=es_password,
            environment="us-east1-gcp",
            index=es_index,
            similarity="cosine",
            embedding_dim=768,
        )
        # NOTE(review): the model name suggests dot-product scoring while the
        # store is configured for cosine similarity — confirm this is intended.
        self.retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model="multi-qa-distilbert-dot-v1",
            model_format="sentence_transformers",
        )
        # Runs on every (re)initialisation of the pipeline.
        self.document_store.update_embeddings(self.retriever, batch_size=16)
        self.pipe = ExtractiveQAPipeline(self.reader, self.retriever)

    def search_by_query(self, query: str, retriever_top_k: int, reader_top_k: int, es_index: str = None):
        """Run ``query`` through the pipeline and return its answers.

        Args:
            query: Question text.
            retriever_top_k: ``top_k`` passed to the ``Retriever`` node.
            reader_top_k: ``top_k`` passed to the ``Reader`` node. ``None``
                leaves the reader's own default in effect.
            es_index: Accepted for interface compatibility; currently ignored
                (the index chosen at construction time is always used).

        Returns:
            The ``"answers"`` entry of the pipeline's prediction.
        """
        params = {"Retriever": {"top_k": retriever_top_k}}
        # Previously this argument was accepted but never forwarded, so callers
        # could not control how many answers the reader produced.
        if reader_top_k is not None:
            params["Reader"] = {"top_k": reader_top_k}
        prediction = self.pipe.run(query=query, params=params)
        return prediction["answers"]

class Log():
    """Posts simple log records to an Elasticsearch index over HTTPS."""

    def __init__(self, es_host: str, es_index: str, es_user, es_password, timeout: float = 10.0) -> None:
        """Prepare the endpoint URL and the HTTP Basic authorization header.

        Args:
            es_host: Host name of the Elasticsearch service (port 443 is used).
            es_index: Index that receives the log documents.
            es_user: User name for HTTP Basic authentication.
            es_password: Password for HTTP Basic authentication.
            timeout: Seconds to wait for the server when posting a record.
        """
        self.elastic_endpoint = f"https://{es_host}:443/{es_index}/_doc"
        # Build the header from the supplied credentials instead of a secret
        # embedded in the source code.
        raw_credentials = f"{es_user}:{es_password}".encode("utf-8")
        self.credentials = b64encode(raw_credentials).decode("ascii")
        self.auth_header = {'Authorization': f'Basic {self.credentials}'}
        self.timeout = timeout

    def write_log(self, message: str, source: str) -> None:
        """Post one log record and print the server's response body.

        Args:
            message: Text to store in the record's ``message`` field.
            source: Value for the record's ``source`` field.
        """
        # The format string ends in "Z" (UTC), so the timestamp must be taken
        # in UTC rather than in the machine's local time zone.
        now_utc = datetime.datetime.now(datetime.timezone.utc)
        created_date = now_utc.strftime('%Y-%m-%dT%H:%M:%SZ')
        post_data = {
            "message": message,
            "createdDate": {
                "date": created_date,
            },
            "source": source,
        }
        # A timeout keeps an unresponsive log server from blocking the caller
        # indefinitely.
        response = requests.post(
            self.elastic_endpoint,
            json=post_data,
            headers=self.auth_header,
            timeout=self.timeout,
        )
        print(response.text)