File size: 1,389 Bytes
de850e8
 
60d7b42
 
de850e8
 
 
 
 
 
 
 
 
 
5a66b9f
de850e8
 
 
 
 
 
 
5a66b9f
de850e8
 
 
5a66b9f
de850e8
5a66b9f
de850e8
60d7b42
de850e8
 
 
5a66b9f
60d7b42
 
 
 
 
8eff983
 
60d7b42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# retrievers.py
import pickle
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever


from .config import FAISS_DB_INDEX, BM25_INDEX


def load_bm25_retriever():
    """Load the pickled BM25 retriever stored at ``BM25_INDEX``.

    Returns:
        The unpickled retriever, wrapped via ``with_config`` so its runs are
        labelled "BM25Retriever".
    """
    # SECURITY NOTE: unpickling can execute arbitrary code. BM25_INDEX must
    # only ever point at a file produced by this project, never at
    # user-supplied data.
    with open(BM25_INDEX, mode="rb") as index_file:
        retriever = pickle.load(index_file)
    labelled = retriever.with_config(run_name="BM25Retriever")
    return labelled


def load_faiss_retriever(embeddings):
    """Load the local FAISS index and expose it as an MMR retriever.

    Args:
        embeddings: Embedding model handed to ``FAISS.load_local``; presumably
            it must match the model used to build the index (TODO confirm).

    Returns:
        A retriever using ``search_type="mmr"`` with ``k=10``, wrapped via
        ``with_config`` so its runs are labelled "FaissRetriever".
    """
    # allow_dangerous_deserialization opts in to LangChain's pickle-based
    # loading of the stored index metadata; only use it with trusted files.
    vector_store = FAISS.load_local(
        FAISS_DB_INDEX,
        embeddings,
        allow_dangerous_deserialization=True,
    )
    mmr_settings = {"k": 10}
    mmr_retriever = vector_store.as_retriever(
        search_type="mmr",
        search_kwargs=mmr_settings,
    )
    return mmr_retriever.with_config(run_name="FaissRetriever")


def load_retrievers(
    embeddings,
    *,
    weights=(0.7, 0.3),
    top_n=5,
    rerank_model="rerank-multilingual-v3.0",
):
    """Build the hybrid BM25 + FAISS retriever with Cohere reranking.

    The BM25 and FAISS retrievers are combined with ``EnsembleRetriever``,
    which fuses their rankings according to ``weights``. The fused results
    are then reranked by Cohere and trimmed to the ``top_n`` best documents.

    Args:
        embeddings: Embedding model forwarded to ``load_faiss_retriever``.
        weights: Fusion weights in the order ``(bm25, faiss)``. The default
            favours the lexical BM25 results.
        top_n: Number of documents kept after reranking.
        rerank_model: Name of the Cohere rerank model to use.

    Returns:
        A ``ContextualCompressionRetriever`` wrapped via ``with_config`` so its
        runs are labelled "ContextualCompressionRetriever".

    Raises:
        ValueError: If ``weights`` does not contain one weight per retriever.
    """
    faiss_retriever = load_faiss_retriever(embeddings)
    bm25_retriever = load_bm25_retriever()

    # Order must match ``weights``: (bm25, faiss).
    retrievers = [bm25_retriever, faiss_retriever]
    if len(weights) != len(retrievers):
        raise ValueError(
            f"Expected {len(retrievers)} weights (bm25, faiss), "
            f"got {len(weights)}."
        )

    # EnsembleRetriever has no ``search_type`` option. The former
    # ``search_type="mmr"`` argument was silently ignored. MMR diversity is
    # already applied inside the FAISS retriever itself.
    ensemble_retriever = EnsembleRetriever(
        retrievers=retrievers,
        weights=list(weights),
    )

    compressor = CohereRerank(model=rerank_model, top_n=top_n)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    ).with_config(run_name="ContextualCompressionRetriever")

    return compression_retriever