# law_poc/utils/retriever.py
# SUMANA SUMANAKUL (ING) — commit 8e5a9dd
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
import os
from dotenv import load_dotenv
load_dotenv()
# ---- MongoDB credentials ----
# Read from the process environment, which load_dotenv() above may populate
# from a .env file. Each value is None when its variable is unset.
# NOTE(review): mongo_username / mongo_password are not used anywhere in this
# file's visible code; presumably they are embedded in MONGO_CONNECTION_STRING
# or read by other modules — confirm before removing them.
(
    mongo_username,
    mongo_password,
    mongo_database,
    mongo_connection_str,
    mongo_collection_name,
) = (
    os.environ.get(env_var)
    for env_var in (
        "MONGO_USERNAME",
        "MONGO_PASSWORD",
        "MONGO_DATABASE",
        "MONGO_CONNECTION_STRING",
        "MONGO_COLLECTION",
    )
)
# ---- Shared configuration: embedding model + hybrid retrieval ----
# Hugging Face model name handed to HuggingFaceEmbeddings.
MODEL_NAME = "BAAI/bge-m3"
# NOTE(review): not referenced in this file's visible code; presumably it must
# match the dimension configured on the Atlas vector index — confirm.
EMBEDDING_DIMENSIONS = 1024
# Forwarded verbatim to HuggingFaceEmbeddings as model_kwargs / encode_kwargs.
MODEL_KWARGS = dict(device="cpu")
ENCODE_KWARGS = dict(normalize_embeddings=True, batch_size=32)
# Passed as top_k to the hybrid retriever; also drives the candidate-pool sizes.
FINAL_TOP_K = 10
# Rank penalties for fusing the vector and full-text result lists; presumably
# the reciprocal-rank-fusion constant (score ~ 1 / (rank + penalty)) — confirm
# against the langchain_mongodb documentation.
HYBRID_VECTOR_PENALTY = HYBRID_FULLTEXT_PENALTY = 60
# ---- Embedding model ----
# A single instance, built when this module is imported and shared by the
# vector store and the hybrid retriever below.
embed_model = HuggingFaceEmbeddings(
    encode_kwargs=ENCODE_KWARGS,
    model_kwargs=MODEL_KWARGS,
    model_name=MODEL_NAME,
)
# ---- Vector / full-text candidate sizing ----
# Intended pool size for each leg of the hybrid search: twice the final
# result count, never fewer than 20 (consumed by get_retriever).
num_vector_candidates = num_text_candidates = max(20, 2 * FINAL_TOP_K)
vector_k = num_vector_candidates
# Candidate pool requested from the vector-search operator: 10x vector_k.
vector_num_candidates_for_operator = 10 * vector_k
# ---- Vector store ----
# Fail fast on missing connection settings. os.getenv returns None for unset
# variables, so without this check the namespace below would silently become
# "None.None" and errors would only surface later, far from the cause.
_missing_mongo_settings = [
    env_var
    for env_var, value in (
        ("MONGO_CONNECTION_STRING", mongo_connection_str),
        ("MONGO_DATABASE", mongo_database),
        ("MONGO_COLLECTION", mongo_collection_name),
    )
    if not value
]
if _missing_mongo_settings:
    raise RuntimeError(
        "Missing MongoDB configuration; set these environment variables "
        f"(or add them to .env): {', '.join(_missing_mongo_settings)}"
    )

# NOTE(review): "search_index_v1" is also passed to get_retriever as the
# full-text search_index_name; presumably this one is the vector index —
# confirm both index roles exist under this name in Atlas.
vector_store = MongoDBAtlasVectorSearch.from_connection_string(
    connection_string=mongo_connection_str,
    namespace=f"{mongo_database}.{mongo_collection_name}",
    embedding=embed_model,
    index_name="search_index_v1",
)
# ---- Retriever (Hybrid) ----
def get_retriever(**kwargs):
    """Build a hybrid (vector + full-text) retriever over the shared vector store.

    All keyword arguments are collected into one dict and passed unchanged as
    ``pre_filter``. For example, ``get_retriever(category="statute")`` yields
    ``pre_filter={"category": "statute"}``. With no arguments the filter is an
    empty dict.

    Returns:
        A new ``MongoDBAtlasHybridSearchRetriever`` configured with
        ``top_k=FINAL_TOP_K`` and the module's hybrid penalties.

    NOTE(review): ``embedding``, ``text_key``, ``embedding_key``,
    ``vector_search_params`` and ``text_search_params`` may not be declared
    fields of ``MongoDBAtlasHybridSearchRetriever``. If the installed
    langchain_mongodb version ignores unknown fields, the candidate sizes
    computed above have no effect; candidate pools may instead be governed by
    ``top_k``/``oversampling_factor``. Confirm against the library version in use.
    """
    hybrid_options = {
        # Stores and indexes (same index name as the vector store; see note there).
        "vectorstore": vector_store,
        "search_index_name": "search_index_v1",
        "embedding": embed_model,
        # Presumably the document field with the raw text; 'token' was a
        # commented-out alternative in an earlier revision.
        "text_key": "text",
        "embedding_key": "embedding",
        # Final result count and rank-fusion weighting.
        "top_k": FINAL_TOP_K,
        "vector_penalty": HYBRID_VECTOR_PENALTY,
        "fulltext_penalty": HYBRID_FULLTEXT_PENALTY,
        # Per-leg candidate pools (see the NOTE above).
        "vector_search_params": {
            "k": vector_k,
            "numCandidates": vector_num_candidates_for_operator,
        },
        "text_search_params": {"limit": num_text_candidates},
        "pre_filter": kwargs,
    }
    return MongoDBAtlasHybridSearchRetriever(**hybrid_options)