File size: 3,761 Bytes
966108f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import logging
from typing import List, Optional

import elasticsearch

from langchain_community.retrievers import ElasticSearchBM25Retriever
from langchain_community.vectorstores.elastic_vector_search import ElasticVectorSearch
from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever


class HybridRetriever(BaseRetriever):
    """Hybrid retriever combining dense (vector) and sparse (BM25) Elasticsearch search.

    Query results from both retrievers are truncated to their respective
    ``top_k`` and concatenated (dense first). No deduplication or reranking
    is applied yet — see the commented-out rerank hooks in
    ``_get_relevant_documents``.
    """

    # Vector store backing the dense retriever; also used for index admin.
    dense_db: ElasticVectorSearch
    # Embedding-based retriever over `index_dense`.
    dense_retriever: VectorStoreRetriever
    # BM25 keyword retriever over `index_sparse`.
    sparse_retriever: ElasticSearchBM25Retriever
    index_dense: str
    index_sparse: str
    # Per-query result caps for each leg of the hybrid search.
    top_k_dense: int
    top_k_sparse: int

    is_training: bool = False

    @classmethod
    def create(
        cls, dense_db, dense_retriever, sparse_retriever, index_dense, index_sparse, top_k_dense, top_k_sparse
        ):
        """Alternate constructor: forwards every field to pydantic by keyword."""
        return cls(
                dense_db=dense_db,
                dense_retriever=dense_retriever,
                sparse_retriever=sparse_retriever,
                index_dense=index_dense,
                index_sparse=index_sparse,
                top_k_dense=top_k_dense,
                top_k_sparse=top_k_sparse,
            )

    def reset_indices(self):
        """Delete both backing Elasticsearch indices, ignoring missing ones.

        ``ignore_unavailable``/``allow_no_indices`` make the delete a no-op
        rather than an error when an index does not exist yet.
        """
        result = self.dense_db.client.indices.delete(
            index=self.index_dense,
            ignore_unavailable=True,
            allow_no_indices=True,
        )

        # BUG FIX: the original `logging.info('dense_db delete:', result)`
        # passed `result` as a %-format argument with no placeholder in the
        # message, so the result was never rendered and the logging module
        # emitted a "not all arguments converted" formatting error instead.
        logging.info('dense_db delete: %s', result)

        result = self.sparse_retriever.client.indices.delete(
            index=self.index_sparse,
            ignore_unavailable=True,
            allow_no_indices=True,
        )

        logging.info('sparse_retriever delete: %s', result)

    def add_documents(self, documents, batch_size=25):
        """Index `documents` into both stores in batches of `batch_size`.

        NOTE: only ``page_content`` is sent to the sparse BM25 index —
        document metadata is stored on the dense side only.
        """
        for i in range(0, len(documents), batch_size):
            # Was a bare print(); use logging for consistency with the rest
            # of the class (lazy %-args, no work unless the level is enabled).
            logging.info('add_documents batch %d', i)
            dense_batch = documents[i:i + batch_size]
            sparse_batch = [doc.page_content for doc in dense_batch]
            self.dense_retriever.add_documents(dense_batch)
            self.sparse_retriever.add_texts(sparse_batch)

    def _get_relevant_documents(self, query: str, **kwargs):
        """Return top-k dense + top-k sparse hits for `query`, dense first.

        Duplicates between the two result lists are NOT removed.
        """
        dense_results = self.dense_retriever.get_relevant_documents(query)[:self.top_k_dense]
        sparse_results = self.sparse_retriever.get_relevant_documents(query)[:self.top_k_sparse]

        # Combine results (you'll need a strategy here)
        combined_results = dense_results + sparse_results
        # result_text = [doc.page_content for doc in combined_results]

        # reranked_result = rerank.rerank(query, documents=result_text, model="rerank-lite-1", top_k=self.top_k_dense+self.top_k_sparse)
        # reranked_result = sorted(reranked_result.results, key=lambda result: result.index)

        # Re-wrap as fresh LangChain Documents.
        documents = [Document(page_content=doc.page_content, metadata=doc.metadata) for doc in combined_results]
        # documents = [Document(page_content=doc.page_content, metadata=doc.metadata, relevance_score=result.relevance_score) for result, doc in zip(reranked_result, combined_results)]
        return documents

    async def aget_relevant_documents(self, query: str):
        """Async retrieval is deliberately unsupported for this retriever."""
        raise NotImplementedError

def get_dense_db(elasticsearch_url, index_dense, embeddings):
    """Construct an ElasticVectorSearch store over `index_dense`.

    `embeddings` is the embedding function the store uses to vectorize
    documents and queries.
    """
    return ElasticVectorSearch(
        elasticsearch_url=elasticsearch_url,
        index_name=index_dense,
        embedding=embeddings,
    )
    
def get_sparse_retriever(elasticsearch_url, index_sparse):
    """Construct a BM25 retriever over `index_sparse`.

    Builds its own Elasticsearch client from `elasticsearch_url`.
    """
    es_client = elasticsearch.Elasticsearch(elasticsearch_url)
    return ElasticSearchBM25Retriever(client=es_client, index_name=index_sparse)