anpigon commited on
Commit
5a66b9f
1 Parent(s): 1c4aaba

refactor: Update retrievers.py to improve retriever configuration and naming conventions

Browse files
Files changed (1) hide show
  1. libs/retrievers.py +5 -7
libs/retrievers.py CHANGED
@@ -12,7 +12,7 @@ from .config import FAISS_DB_INDEX, BM25_INDEX
12
  def load_bm25_retriever():
13
  with open(BM25_INDEX, "rb") as f:
14
  bm25_retriever = pickle.load(f)
15
- return bm25_retriever
16
 
17
 
18
  def load_faiss_retriever(embeddings):
@@ -20,21 +20,19 @@ def load_faiss_retriever(embeddings):
20
  FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
21
  )
22
  faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
23
- return faiss_retriever
24
 
25
 
26
  def load_retrievers(embeddings):
27
- faiss_retriever = load_faiss_retriever(embeddings).with_config(
28
- run_name="FaissRetriever"
29
- )
30
 
31
- bm25_retriever = load_bm25_retriever().with_config(run_name="BM25Retriever")
32
 
33
  ensemble_retriever = EnsembleRetriever(
34
  retrievers=[bm25_retriever, faiss_retriever],
35
  weights=[0.7, 0.3],
36
  search_type="mmr",
37
- ).with_config(run_name="EnsembleRetriever")
38
 
39
  compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
40
  compression_retriever = ContextualCompressionRetriever(
 
12
  def load_bm25_retriever():
13
  with open(BM25_INDEX, "rb") as f:
14
  bm25_retriever = pickle.load(f)
15
+ return bm25_retriever.with_config(run_name="BM25Retriever")
16
 
17
 
18
  def load_faiss_retriever(embeddings):
 
20
  FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
21
  )
22
  faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
23
+ return faiss_retriever.with_config(run_name="FaissRetriever")
24
 
25
 
26
  def load_retrievers(embeddings):
27
+ faiss_retriever = load_faiss_retriever(embeddings)
 
 
28
 
29
+ bm25_retriever = load_bm25_retriever()
30
 
31
  ensemble_retriever = EnsembleRetriever(
32
  retrievers=[bm25_retriever, faiss_retriever],
33
  weights=[0.7, 0.3],
34
  search_type="mmr",
35
+ )
36
 
37
  compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
38
  compression_retriever = ContextualCompressionRetriever(