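"""Multi-vector-store RAG retriever tester.

Interactively queries persisted vector stores (MES manual, technical docs,
general docs) with a user-selected embedding model and prints the top-k
deduplicated chunks for each question.
"""
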
# # from langchain_chroma import Chroma
# # from langchain_openai import OpenAIEmbeddings
# from langchain_community.embeddings import HuggingFaceEmbeddings
# from utils.vector_store import get_vector_store
# import os
#
# # Define different embedding model options
# # EMBEDDING_CONFIGS = {
# #     "Accuracy (OpenAI text-embedding-3-large)": OpenAIEmbeddings(model="text-embedding-3-large"),
# #     "Performance (OpenAI text-embedding-3-small)": OpenAIEmbeddings(model="text-embedding-3-small"),
# #     "Instruction-based (HuggingFace bge-large-en)": HuggingFaceEmbeddings(model_name="BAAI/bge-large-en"),
# #     "QA Optimized (HuggingFace e5-large-v2)": HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2"),
# # }
# EMBEDDING_CONFIGS = {
#     "General-purpose (bge-large-en)": HuggingFaceEmbeddings(
#         model_name="BAAI/bge-large-en",
#         model_kwargs={"device": "cpu"},  # or "cuda" if you have a GPU
#         encode_kwargs={"normalize_embeddings": True}
#     ),
#     "Fast & lightweight (bge-small-en)": HuggingFaceEmbeddings(
#         model_name="BAAI/bge-small-en",
#         model_kwargs={"device": "cpu"},
#         encode_kwargs={"normalize_embeddings": True}
#     ),
#     "QA optimized (e5-large-v2)": HuggingFaceEmbeddings(
#         model_name="intfloat/e5-large-v2",
#         model_kwargs={"device": "cpu"},
#         encode_kwargs={"normalize_embeddings": True}
#     ),
#     "Instruction-tuned (instructor-large)": HuggingFaceEmbeddings(
#         model_name="hkunlp/instructor-large",
#         model_kwargs={"device": "cpu"},
#         encode_kwargs={"normalize_embeddings": True}
#     ),
# }
#
# # Default vector store path
# VECTOR_STORE_PATH = "./vector_stores/mes_db"
#
# def test_retriever_with_embeddings(query: str, embedding_model, k: int = 3):
#     """Retrieve documents using a specific embedding model."""
#     vector_store = get_vector_store(
#         persist_directory=VECTOR_STORE_PATH,
#         embedding=embedding_model
#     )
#     retriever = vector_store.as_retriever(search_kwargs={"k": k})
#     docs = retriever.get_relevant_documents(query)
#     # Deduplicate based on page_content
#     seen = set()
#     unique_docs = []
#     for doc in docs:
#         if doc.page_content not in seen:
#             seen.add(doc.page_content)
#             unique_docs.append(doc)
#     return unique_docs
#
# def compare_embeddings(query: str, k: int = 3):
#     print(f"\n=== Comparing embeddings for: '{query}' ===\n")
#     for name, embedding_model in EMBEDDING_CONFIGS.items():
#         try:
#             print(f"{name}:")
#             print("-" * 50)
#             docs = test_retriever_with_embeddings(query, embedding_model, k)
#             for i, doc in enumerate(docs, 1):
#                 source = doc.metadata.get("source", "unknown")
#                 page = doc.metadata.get("page", "N/A")
#                 preview = doc.page_content[:300]
#                 if len(doc.page_content) > 300:
#                     preview += "..."
#                 print(f"--- Chunk #{i} ---")
#                 print(f"Source: {source} | Page: {page}")
#                 print(preview)
#                 print()
#             print("\n" + "=" * 60 + "\n")
#         except Exception as e:
#             print(f"Error with {name}: {e}\n")
#
# if __name__ == "__main__":
#     print("Embedding Model Benchmark Tool")
#     print("\nType 'compare: <question>' to compare all embeddings")
#     print("Type 'exit' to quit\n")
#     while True:
#         user_input = input("\nEnter your question: ").strip()
#         if user_input.lower() == "exit":
#             break
#         elif user_input.lower().startswith("compare: "):
#             query = user_input[9:]
#             compare_embeddings(query)
#         else:
#             print("Please use the format: compare: <question>")

from utils.vector_store import get_vector_store
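
# `get_vector_store` is a project-local helper whose source is not shown
# here. A minimal sketch of what it presumably does, assuming it wraps a
# persisted Chroma collection (as the commented-out `langchain_chroma`
# import above suggests); the project's actual helper may differ:
#
#   from langchain_chroma import Chroma
#
#   def get_vector_store(persist_directory: str, embedding):
#       # Open (or create) the persisted collection, embedding queries
#       # with the supplied model.
#       return Chroma(persist_directory=persist_directory,
#                     embedding_function=embedding)
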
def test_retriever_with_embeddings(query: str, embedding_model, k: int = 3, vector_store_path="./chroma_db"):
    """Test the retriever with a specific embedding model and vector store."""
    vector_store = get_vector_store(
        persist_directory=vector_store_path, embedding=embedding_model)
    retriever = vector_store.as_retriever(search_kwargs={"k": k})
    # Note: newer LangChain releases prefer retriever.invoke(query) over
    # the deprecated get_relevant_documents().
    docs = retriever.get_relevant_documents(query)
    # Deduplicate on page_content: overlapping chunks or re-ingested
    # documents can surface the same text more than once.
    seen = set()
    unique_docs = []
    for doc in docs:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            unique_docs.append(doc)
    print(f"\nUsing vector store: {vector_store_path}")
    print(f"Top {len(unique_docs)} unique chunks retrieved for: '{query}'\n")
    for i, doc in enumerate(unique_docs, 1):
        source = doc.metadata.get("source", "unknown")
        page = doc.metadata.get("page", "N/A")
        print(f"--- Chunk #{i} ---")
        print(f"Source: {source} | Page: {page}")
        # Preview only the first 300 characters of each chunk
        preview = doc.page_content[:300]
        if len(doc.page_content) > 300:
            preview += "..."
        print(preview)
        print()

def compare_retrievers_with_embeddings(query: str, embedding_model, k: int = 3):
    """Compare results from different vector stores using the same embedding model."""
    stores = {
        "MES Manual": "./vector_stores/mes_db",
        "Technical Docs": "./vector_stores/tech_db",
        "General Docs": "./vector_stores/general_db"
    }
    print(f"\n=== Comparing retrievers for: '{query}' ===\n")
    for store_name, store_path in stores.items():
        try:
            print(f"{store_name}:")
            print("-" * 50)
            test_retriever_with_embeddings(
                query, embedding_model, k=k, vector_store_path=store_path)
            print("\n" + "=" * 60 + "\n")
        except Exception as e:
            print(f"Could not access {store_name}: {e}\n")

if __name__ == "__main__":
    from embedding_config import EMBEDDING_CONFIGS

    print("Multi-Vector Store RAG Tester (with Embeddings)")
    print("\nAvailable commands:")
    print("  - Enter a question to test the default store")
    print("  - Type 'mes: <question>' for the MES manual")
    print("  - Type 'tech: <question>' for technical docs")
    print("  - Type 'general: <question>' for general docs")
    print("  - Type 'compare: <question>' to compare all stores")
    print("  - Type 'exit' to quit")

    # Choose an embedding model at start-up
    print("\nAvailable Embedding Models:")
    for i, name in enumerate(EMBEDDING_CONFIGS.keys(), 1):
        print(f"  {i}. {name}")
    choice = int(input("Select embedding model number: ").strip())
    embedding_model = list(EMBEDDING_CONFIGS.values())[choice - 1]

    while True:
        user_input = input("\nEnter your question: ").strip()
        if user_input.lower() == "exit":
            break
        elif user_input.lower().startswith("mes: "):
            # Strip the 'mes: ' prefix (5 characters)
            query = user_input[5:]
            test_retriever_with_embeddings(
                query, embedding_model, vector_store_path="./vector_stores/mes_db")
        elif user_input.lower().startswith("tech: "):
            query = user_input[6:]
            test_retriever_with_embeddings(
                query, embedding_model, vector_store_path="./vector_stores/tech_db")
        elif user_input.lower().startswith("general: "):
            query = user_input[9:]
            test_retriever_with_embeddings(
                query, embedding_model, vector_store_path="./vector_stores/general_db")
        elif user_input.lower().startswith("compare: "):
            query = user_input[9:]
            compare_retrievers_with_embeddings(query, embedding_model)
        else:
            # No prefix: query the default store (./chroma_db)
            test_retriever_with_embeddings(user_input, embedding_model)
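
# Example session commands (the questions are hypothetical):
#   mes: How do I acknowledge a line stoppage alarm?
#   compare: What is the maintenance escalation procedure?
#   exit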