# mes-chatbot-rag-backend / test_embeddings_retriever.py
# --- Legacy single-store embedding benchmark, kept commented out for
# --- reference; superseded by the multi-store tester below. ---
# # from langchain_chroma import Chroma
# # from langchain_openai import OpenAIEmbeddings
# from langchain_community.embeddings import HuggingFaceEmbeddings
# from utils.vector_store import get_vector_store
# import os
# # Define different embedding model options
# # EMBEDDING_CONFIGS = {
# # "Accuracy (OpenAI text-embedding-3-large)": OpenAIEmbeddings(model="text-embedding-3-large"),
# # "Performance (OpenAI text-embedding-3-small)": OpenAIEmbeddings(model="text-embedding-3-small"),
# # "Instruction-based (HuggingFace bge-large-en)": HuggingFaceEmbeddings(model_name="BAAI/bge-large-en"),
# # "QA Optimized (HuggingFace e5-large-v2)": HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2"),
# # }
# EMBEDDING_CONFIGS = {
# "General-purpose (bge-large-en)": HuggingFaceEmbeddings(
# model_name="BAAI/bge-large-en",
# model_kwargs={"device": "cpu"}, # or "cuda" if you have GPU
# encode_kwargs={"normalize_embeddings": True}
# ),
# "Fast & lightweight (bge-small-en)": HuggingFaceEmbeddings(
# model_name="BAAI/bge-small-en",
# model_kwargs={"device": "cpu"},
# encode_kwargs={"normalize_embeddings": True}
# ),
# "QA optimized (e5-large-v2)": HuggingFaceEmbeddings(
# model_name="intfloat/e5-large-v2",
# model_kwargs={"device": "cpu"},
# encode_kwargs={"normalize_embeddings": True}
# ),
# "Instruction-tuned (instructor-large)": HuggingFaceEmbeddings(
# model_name="hkunlp/instructor-large",
# model_kwargs={"device": "cpu"},
# encode_kwargs={"normalize_embeddings": True}
# ),
# }
# # Default vector store path
# VECTOR_STORE_PATH = "./vector_stores/mes_db"
# def test_retriever_with_embeddings(query: str, embedding_model, k: int = 3):
# """Retrieve documents using a specific embedding model."""
# vector_store = get_vector_store(
# persist_directory=VECTOR_STORE_PATH,
# embedding=embedding_model
# )
# retriever = vector_store.as_retriever(search_kwargs={"k": k})
# docs = retriever.get_relevant_documents(query)
# # Deduplicate based on page_content
# seen = set()
# unique_docs = []
# for doc in docs:
# if doc.page_content not in seen:
# seen.add(doc.page_content)
# unique_docs.append(doc)
# return unique_docs
# def compare_embeddings(query: str, k: int = 3):
# print(f"\n=== Comparing embeddings for: '{query}' ===\n")
# for name, embedding_model in EMBEDDING_CONFIGS.items():
# try:
# print(f"πŸ” {name}:")
# print("-" * 50)
# docs = test_retriever_with_embeddings(query, embedding_model, k)
# for i, doc in enumerate(docs, 1):
# source = doc.metadata.get("source", "unknown")
# page = doc.metadata.get("page", "N/A")
# preview = doc.page_content[:300]
# if len(doc.page_content) > 300:
# preview += "..."
# print(f"--- Chunk #{i} ---")
# print(f"Source: {source} | Page: {page}")
# print(preview)
# print()
# print("\n" + "=" * 60 + "\n")
# except Exception as e:
# print(f"❌ Error with {name}: {e}\n")
# if __name__ == "__main__":
# print("Embedding Model Benchmark Tool")
# print("\nType 'compare: <question>' to compare all embeddings")
# print("Type 'exit' to quit\n")
# while True:
# user_input = input("\nEnter your question: ").strip()
# if user_input.lower() == "exit":
# break
# elif user_input.lower().startswith("compare: "):
# query = user_input[9:]
# compare_embeddings(query)
# else:
# print("Please use the format: compare: <question>")
from utils.vector_store import get_vector_store
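
# `get_vector_store` is assumed to wrap a persisted Chroma collection, roughly
# like this hypothetical sketch (the real helper lives in utils/vector_store.py):
#
#     from langchain_chroma import Chroma
#
#     def get_vector_store(persist_directory, embedding):
#         return Chroma(persist_directory=persist_directory,
#                       embedding_function=embedding)
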
def test_retriever_with_embeddings(query: str, embedding_model, k: int = 3, vector_store_path="./chroma_db"):
"""Test retriever with a specific embedding model and vector store"""
vector_store = get_vector_store(
persist_directory=vector_store_path, embedding=embedding_model)
retriever = vector_store.as_retriever(search_kwargs={"k": k})
    # `invoke` is the current retriever API; `get_relevant_documents` is
    # deprecated as of LangChain 0.1.
    docs = retriever.invoke(query)
# Deduplicate based on page_content
seen = set()
unique_docs = []
for doc in docs:
if doc.page_content not in seen:
seen.add(doc.page_content)
unique_docs.append(doc)
print(f"\nUsing vector store: {vector_store_path}")
print(f"Top {len(unique_docs)} unique chunks retrieved for: '{query}'\n")
for i, doc in enumerate(unique_docs, 1):
source = doc.metadata.get("source", "unknown")
page = doc.metadata.get("page", "N/A")
print(f"--- Chunk #{i} ---")
print(f"Source: {source} | Page: {page}")
preview = doc.page_content[:300]
if len(doc.page_content) > 300:
preview += "..."
print(preview)
print()
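
# Example call (hypothetical query; assumes a populated Chroma store at
# ./chroma_db and an embedding instance, e.g. one from EMBEDDING_CONFIGS):
#     test_retriever_with_embeddings("How do I reset a work order?", embedding_model, k=3)
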
def compare_retrievers_with_embeddings(query: str, embedding_model, k: int = 3):
"""Compare results from different vector stores using the same embedding model"""
stores = {
"MES Manual": "./vector_stores/mes_db",
"Technical Docs": "./vector_stores/tech_db",
"General Docs": "./vector_stores/general_db"
}
print(f"\n=== Comparing retrievers for: '{query}' ===\n")
for store_name, store_path in stores.items():
try:
print(f"πŸ” {store_name}:")
print("-" * 50)
test_retriever_with_embeddings(
query, embedding_model, k=k, vector_store_path=store_path)
print("\n" + "="*60 + "\n")
except Exception as e:
print(f"❌ Could not access {store_name}: {e}\n")
if __name__ == "__main__":
from embedding_config import EMBEDDING_CONFIGS
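    # `EMBEDDING_CONFIGS` is expected to map display names to embedding
    # instances, e.g. this hypothetical sketch mirroring the legacy dict above
    # (`langchain_huggingface` is the maintained home of HuggingFaceEmbeddings):
    #
    #     from langchain_huggingface import HuggingFaceEmbeddings
    #
    #     EMBEDDING_CONFIGS = {
    #         "General-purpose (bge-large-en)": HuggingFaceEmbeddings(
    #             model_name="BAAI/bge-large-en",
    #             model_kwargs={"device": "cpu"},
    #             encode_kwargs={"normalize_embeddings": True},
    #         ),
    #     }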
print("Multi-Vector Store RAG Tester (with Embeddings)")
print("\nAvailable commands:")
print(" - Enter a question to test default store")
print(" - Type 'mes: <question>' for MES manual")
print(" - Type 'tech: <question>' for technical docs")
print(" - Type 'general: <question>' for general docs")
print(" - Type 'compare: <question>' to compare all stores")
print(" - Type 'exit' to quit")
# Choose embedding model at start
print("\nAvailable Embedding Models:")
for i, name in enumerate(EMBEDDING_CONFIGS.keys(), 1):
print(f" {i}. {name}")
    try:
        choice = int(input("Select embedding model number: ").strip())
        embedding_model = list(EMBEDDING_CONFIGS.values())[choice - 1]
    except (ValueError, IndexError):
        raise SystemExit("Invalid selection; run again and pick a listed number.")
while True:
user_input = input("\nEnter your question: ").strip()
if user_input.lower() == "exit":
break
elif user_input.lower().startswith("mes: "):
query = user_input[5:]
test_retriever_with_embeddings(
query, embedding_model, vector_store_path="./vector_stores/mes_db")
elif user_input.lower().startswith("tech: "):
query = user_input[6:]
test_retriever_with_embeddings(
query, embedding_model, vector_store_path="./vector_stores/tech_db")
elif user_input.lower().startswith("general: "):
query = user_input[9:]
test_retriever_with_embeddings(
query, embedding_model, vector_store_path="./vector_stores/general_db")
elif user_input.lower().startswith("compare: "):
query = user_input[9:]
compare_retrievers_with_embeddings(query, embedding_model)
else:
test_retriever_with_embeddings(
user_input, embedding_model) # Default store