import os

from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

# Load environment variables via python-dotenv before the os.getenv()
# lookups that follow in this module.
load_dotenv()

# MongoDB settings read from the process environment.
# os.getenv() returns None for a variable that is not set, and nothing here
# validates the values, so any of these names may be None.
mongo_username = os.getenv("MONGO_USERNAME")
mongo_password = os.getenv("MONGO_PASSWORD")
mongo_database = os.getenv("MONGO_DATABASE")
mongo_connection_str = os.getenv("MONGO_CONNECTION_STRING")
mongo_collection_name = os.getenv("MONGO_COLLECTION")

# Keyword arguments handed to the embedding model constructor.
MODEL_KWARGS = {"device": "cpu"}
ENCODE_KWARGS = {
    "normalize_embeddings": True,
    "batch_size": 32,
}

# Not referenced anywhere else in the visible code.
# NOTE(review): presumably the vector size produced by MODEL_NAME and expected
# by the Atlas vector index - confirm against the index definition.
EMBEDDING_DIMENSIONS = 1024
MODEL_NAME = "BAAI/bge-m3"

# Number of documents requested from the hybrid retriever (passed as top_k).
FINAL_TOP_K = 10

# Penalties passed to the hybrid retriever for the full-text and vector
# branches respectively.
HYBRID_FULLTEXT_PENALTY = 60
HYBRID_VECTOR_PENALTY = 60

# Single embedding model instance, shared by the vector store and by the
# retriever built in get_retriever().
embed_model = HuggingFaceEmbeddings(
    model_name=MODEL_NAME,
    model_kwargs=MODEL_KWARGS,
    encode_kwargs=ENCODE_KWARGS,
)

# Per-branch candidate counts: twice the final result count, with a floor
# of 20, for both the vector branch and the full-text branch.
num_vector_candidates = max(20, 2 * FINAL_TOP_K)
num_text_candidates = max(20, 2 * FINAL_TOP_K)

# "k" for the vector search, and ten times that for "numCandidates"
# (both are passed in vector_search_params by get_retriever()).
vector_k = num_vector_candidates
vector_num_candidates_for_operator = vector_k * 10

# Vector store bound to "<database>.<collection>" taken from the environment.
# If either variable is unset the namespace contains the literal text "None".
# NOTE(review): get_retriever() passes the same name, "search_index_v1", as
# its full-text search_index_name - confirm that one index name is really
# meant to serve both the vector index and the full-text index.
vector_store = MongoDBAtlasVectorSearch.from_connection_string(
    connection_string=mongo_connection_str,
    namespace=f"{mongo_database}.{mongo_collection_name}",
    embedding=embed_model,
    index_name="search_index_v1",
)

def get_retriever(**kwargs):
    """Build a hybrid (vector + full-text) retriever over ``vector_store``.

    Args:
        **kwargs: Forwarded unchanged, as a dict, to the retriever's
            ``pre_filter`` argument. When no keyword arguments are given,
            an empty dict is passed.

    Returns:
        A newly constructed ``MongoDBAtlasHybridSearchRetriever``.
    """
    # NOTE(review): embedding, text_key, embedding_key, vector_search_params
    # and text_search_params may not be declared fields of the retriever in
    # every langchain-mongodb release - confirm they actually take effect
    # rather than being ignored.
    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=vector_store,
        search_index_name="search_index_v1",
        embedding=embed_model,
        text_key="text",
        embedding_key="embedding",
        top_k=FINAL_TOP_K,
        vector_penalty=HYBRID_VECTOR_PENALTY,
        fulltext_penalty=HYBRID_FULLTEXT_PENALTY,
        vector_search_params={
            "k": vector_k,
            "numCandidates": vector_num_candidates_for_operator,
        },
        text_search_params={"limit": num_text_candidates},
        pre_filter=kwargs,
    )