import os
import logging
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever, QdrantSparseEmbeddingRetriever
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.joiners.document_joiner import DocumentJoiner
from haystack.components.preprocessors.document_cleaner import DocumentCleaner
# from haystack.components.rankers import TransformersSimilarityRanker
from haystack.components.writers import DocumentWriter
from haystack.components.generators.openai import OpenAIGenerator
from haystack import Document, Pipeline
from haystack.utils import Secret
from haystack import tracing
from haystack.tracing.logging_tracer import LoggingTracer
# Load environment variables
load_dotenv()

# Configure logging once: root logger at WARNING, this module at INFO,
# Haystack at DEBUG so pipeline traces are visible
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.getLogger("haystack").setLevel(logging.DEBUG)

# Enable tracing/logging of component content (inputs/outputs)
tracing.tracer.is_content_tracing_enabled = True
tracing.enable_tracing(LoggingTracer(tags_color_strings={
    "haystack.component.input": "\x1b[1;31m",
    "haystack.component.name": "\x1b[1;34m",
}))
class RAGPipeline:
    def __init__(
        self,
        embedding_model_name: str = "BAAI/bge-en-icl",
        llm_model_name: str = "meta-llama/Llama-3.3-70B-Instruct",
        qdrant_path: Optional[str] = None,
    ):
        self.embedding_model_name = embedding_model_name
        self.llm_model_name = llm_model_name
        self.qdrant_path = qdrant_path

        # Wrap the API key in a Secret; warn early if it is missing instead of
        # crashing in Secret.from_token(None)
        api_key = os.getenv("NEBIUS_API_KEY")
        if not api_key:
            logger.warning("NEBIUS_API_KEY not found in environment variables")
        self.nebius_api_key = Secret.from_token(api_key) if api_key else None

        # Initialize document store, components, and pipelines
        self.init_document_store()
        self.init_components()
        self.build_indexing_pipeline()
        self.build_query_pipeline()
    def init_document_store(self):
        """Initialize the Qdrant document store for both dense and sparse search"""
        self.document_store = QdrantDocumentStore(
            path=self.qdrant_path,
            embedding_dim=4096,  # Output dimension of BAAI/bge-en-icl
            recreate_index=False,
            on_disk=True,
            on_disk_payload=True,
            index="ltu_documents",
            force_disable_check_same_thread=True,
            use_sparse_embeddings=True,  # Enable sparse-vector (BM25-style) retrieval
        )
    def init_components(self):
        """Initialize all components needed for the pipelines"""
        # Document processing
        self.document_cleaner = DocumentCleaner()

        # Embedding components (Nebius exposes an OpenAI-compatible API)
        self.document_embedder = OpenAIDocumentEmbedder(
            api_base_url="https://api.studio.nebius.com/v1/",
            model=self.embedding_model_name,
            api_key=self.nebius_api_key,
        )
        self.text_embedder = OpenAITextEmbedder(
            api_base_url="https://api.studio.nebius.com/v1/",
            model=self.embedding_model_name,
            api_key=self.nebius_api_key,
        )

        # Retrievers
        self.bm25_retriever = QdrantSparseEmbeddingRetriever(
            document_store=self.document_store,
            top_k=5,
        )
        self.embedding_retriever = QdrantEmbeddingRetriever(
            document_store=self.document_store,
            top_k=5,
        )

        # Document joiner for combining results
        self.document_joiner = DocumentJoiner()

        # Ranker for re-ranking combined results (disabled for now)
        # self.ranker = TransformersSimilarityRanker(
        #     model="cross-encoder/ms-marco-MiniLM-L-6-v2",
        #     top_k=5,
        # )

        # LLM components
        self.llm = OpenAIGenerator(
            api_base_url="https://api.studio.nebius.com/v1/",
            model=self.llm_model_name,
            api_key=self.nebius_api_key,
            generation_kwargs={
                "max_tokens": 1024,
                "temperature": 0.1,
                "top_p": 0.95,
            },
        )
        # Prompt builder (plain text; the Llama-2-style <s>[INST] tags are not
        # needed for Llama 3.3 behind a chat-completions API)
        self.prompt_builder = PromptBuilder(
            template="""
You are a helpful assistant that answers questions based on the provided context.

Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}

Question: {{ question }}

Answer the question based only on the provided context. If the context doesn't contain the answer, say "I don't have enough information to answer this question."

Answer:
"""
        )
    def build_indexing_pipeline(self):
        """Build the pipeline for indexing documents"""
        self.indexing_pipeline = Pipeline()
        self.indexing_pipeline.add_component("document_cleaner", self.document_cleaner)
        self.indexing_pipeline.add_component("document_embedder", self.document_embedder)
        self.indexing_pipeline.add_component("document_writer", DocumentWriter(document_store=self.document_store))

        # Connect components: clean -> embed -> write
        self.indexing_pipeline.connect("document_cleaner", "document_embedder")
        self.indexing_pipeline.connect("document_embedder", "document_writer")
    def build_query_pipeline(self):
        """Build the pipeline for querying"""
        self.query_pipeline = Pipeline()

        # Add components (the hybrid bm25 + joiner + ranker path is disabled for now)
        self.query_pipeline.add_component("text_embedder", self.text_embedder)
        # self.query_pipeline.add_component("bm25_retriever", self.bm25_retriever)
        self.query_pipeline.add_component("embedding_retriever", self.embedding_retriever)
        # self.query_pipeline.add_component("document_joiner", self.document_joiner)
        # self.query_pipeline.add_component("ranker", self.ranker)
        self.query_pipeline.add_component("prompt_builder", self.prompt_builder)
        self.query_pipeline.add_component("llm", self.llm)

        # Connect components: embed query -> retrieve -> build prompt -> generate
        self.query_pipeline.connect("text_embedder.embedding", "embedding_retriever.query_embedding")
        # DocumentJoiner takes a variadic "documents" input, so both retrievers
        # connect to the same socket:
        # self.query_pipeline.connect("bm25_retriever", "document_joiner")
        # self.query_pipeline.connect("embedding_retriever", "document_joiner")
        # self.query_pipeline.connect("document_joiner", "ranker")
        # self.query_pipeline.connect("ranker", "prompt_builder.documents")
        self.query_pipeline.connect("embedding_retriever.documents", "prompt_builder.documents")
        self.query_pipeline.connect("prompt_builder.prompt", "llm")
    def index_documents(self, documents: List[Document]):
        """
        Index documents in the document store.

        Args:
            documents: List of Haystack Document objects to index
        """
        logger.info(f"Indexing {len(documents)} documents")
        try:
            self.indexing_pipeline.run(
                {"document_cleaner": {"documents": documents}}
            )
            logger.info("Indexing completed successfully")
        except Exception as e:
            logger.error(f"Error during indexing: {e}")
def query(self, question: str, top_k: int = 5) -> Dict[str, Any]: | |
""" | |
Query the RAG pipeline with a question. | |
Args: | |
question: The question to ask | |
top_k: Number of documents to retrieve | |
Returns: | |
Dictionary containing the answer and retrieved documents | |
""" | |
logger.info(f"Querying with question: {question}") | |
try: | |
# Update top_k for retrievers | |
self.bm25_retriever.top_k = top_k | |
self.embedding_retriever.top_k = top_k | |
# Run the query pipeline | |
result = self.query_pipeline.run({ | |
"text_embedder": {"text": question}, | |
# "bm25_retriever": {"query": question}, | |
"prompt_builder": {"question": question} | |
}) | |
# Extract answer and documents | |
answer = result["llm"]["replies"][0] | |
# documents = result["embedding_retriever"]["documents"] | |
return { | |
"answer": answer, | |
"documents": [], #documents, | |
"question": question | |
} | |
except Exception as e: | |
logger.error(f"Error during query: {e}") | |
return { | |
"answer": f"An error occurred: {str(e)}", | |
"documents": [], | |
"question": question | |
} | |
def get_document_count(self) -> int: | |
""" | |
Get the number of documents in the document store. | |
Returns: | |
Document count | |
""" | |
return self.document_store.count_documents() |
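

# Minimal usage sketch (assumes NEBIUS_API_KEY is set in the environment and
# that "./qdrant_data" is a writable local path; both are assumptions made for
# illustration, not part of the pipeline above):
if __name__ == "__main__":
    pipeline = RAGPipeline(qdrant_path="./qdrant_data")

    # Index a couple of toy documents (hypothetical content)
    pipeline.index_documents([
        Document(content="LTU was founded in 1971."),
        Document(content="LTU is located in Lulea, Sweden."),
    ])
    logger.info(f"Documents in store: {pipeline.get_document_count()}")

    # Ask a question against the indexed documents
    result = pipeline.query("When was LTU founded?", top_k=3)
    print(result["answer"])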