refactor: Update retrievers.py to improve retriever configuration and naming conventions
Browse files
- libs/retrievers.py +5 -7
libs/retrievers.py
CHANGED
@@ -12,7 +12,7 @@ from .config import FAISS_DB_INDEX, BM25_INDEX
|
|
12 |
def load_bm25_retriever():
|
13 |
with open(BM25_INDEX, "rb") as f:
|
14 |
bm25_retriever = pickle.load(f)
|
15 |
-
return bm25_retriever
|
16 |
|
17 |
|
18 |
def load_faiss_retriever(embeddings):
|
@@ -20,21 +20,19 @@ def load_faiss_retriever(embeddings):
|
|
20 |
FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
|
21 |
)
|
22 |
faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
|
23 |
-
return faiss_retriever
|
24 |
|
25 |
|
26 |
def load_retrievers(embeddings):
|
27 |
-
faiss_retriever = load_faiss_retriever(embeddings)
|
28 |
-
run_name="FaissRetriever"
|
29 |
-
)
|
30 |
|
31 |
-
bm25_retriever = load_bm25_retriever()
|
32 |
|
33 |
ensemble_retriever = EnsembleRetriever(
|
34 |
retrievers=[bm25_retriever, faiss_retriever],
|
35 |
weights=[0.7, 0.3],
|
36 |
search_type="mmr",
|
37 |
-
)
|
38 |
|
39 |
compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
|
40 |
compression_retriever = ContextualCompressionRetriever(
|
|
|
12 |
def load_bm25_retriever():
|
13 |
with open(BM25_INDEX, "rb") as f:
|
14 |
bm25_retriever = pickle.load(f)
|
15 |
+
return bm25_retriever.with_config(run_name="BM25Retriever")
|
16 |
|
17 |
|
18 |
def load_faiss_retriever(embeddings):
|
|
|
20 |
FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
|
21 |
)
|
22 |
faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
|
23 |
+
return faiss_retriever.with_config(run_name="FaissRetriever")
|
24 |
|
25 |
|
26 |
def load_retrievers(embeddings):
|
27 |
+
faiss_retriever = load_faiss_retriever(embeddings)
|
|
|
|
|
28 |
|
29 |
+
bm25_retriever = load_bm25_retriever()
|
30 |
|
31 |
ensemble_retriever = EnsembleRetriever(
|
32 |
retrievers=[bm25_retriever, faiss_retriever],
|
33 |
weights=[0.7, 0.3],
|
34 |
search_type="mmr",
|
35 |
+
)
|
36 |
|
37 |
compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
|
38 |
compression_retriever = ContextualCompressionRetriever(
|