# ltu-chat / rag_pipeline.py
# Stepan
# Init
# 4717959
# raw
# history blame
# 9.4 kB
import os
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path
# import torch
from dotenv import load_dotenv
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever, QdrantSparseEmbeddingRetriever
from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.joiners.document_joiner import DocumentJoiner
from haystack.components.preprocessors.document_cleaner import DocumentCleaner
# from haystack.components.rankers.transformers import TransformersRanker
from haystack.components.writers import DocumentWriter
from haystack.components.generators.openai import OpenAIGenerator
from haystack import Pipeline
from haystack.utils import Secret
from haystack import tracing
from haystack.tracing.logging_tracer import LoggingTracer
# Load environment variables (e.g. NEBIUS_API_KEY) from a local .env file, if present.
load_dotenv()

# Configure the root logger exactly once. logging.basicConfig() does nothing when
# the root logger already has handlers, so a second call with a different format or
# level would be silently ignored; the format and level are therefore set together.
# INFO is kept so that this module's own logger.info() messages remain visible.
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Verbose Haystack logging plus content tracing, to make pipeline runs debuggable.
logging.getLogger("haystack").setLevel(logging.DEBUG)
tracing.tracer.is_content_tracing_enabled = True  # include component inputs/outputs in traces
tracing.enable_tracing(
    LoggingTracer(
        tags_color_strings={
            "haystack.component.input": "\x1b[1;31m",
            "haystack.component.name": "\x1b[1;34m",
        }
    )
)
class RAGPipeline:
    """Retrieval-augmented generation on top of a Qdrant document store.

    The class owns two Haystack pipelines:

    * an indexing pipeline: clean -> embed -> write to Qdrant;
    * a query pipeline: embed question -> dense retrieval -> prompt -> LLM.

    Embeddings and text generation both go through the OpenAI-compatible
    endpoint in ``_NEBIUS_API_BASE_URL``.

    A sparse retriever (``bm25_retriever``) and a ``document_joiner`` are
    constructed as well, but they are not wired into the query pipeline yet;
    they are kept as attributes for a future hybrid-retrieval setup.
    """

    # OpenAI-compatible endpoint shared by the embedders and the generator.
    _NEBIUS_API_BASE_URL = "https://api.studio.nebius.com/v1/"

    def __init__(
        self,
        embedding_model_name: str = "BAAI/bge-en-icl",
        llm_model_name: str = "meta-llama/Llama-3.3-70B-Instruct",
        qdrant_path: Optional[str] = None,
        embedding_dim: int = 4096,
    ):
        """Create the document store, all components and both pipelines.

        Args:
            embedding_model_name: Model used for document and query embeddings.
            llm_model_name: Model used to generate the answer.
            qdrant_path: Passed unchanged as ``path`` to ``QdrantDocumentStore``.
            embedding_dim: Vector size of ``embedding_model_name``. The default
                (4096) matches the default embedding model; change it together
                with the model.
        """
        self.embedding_model_name = embedding_model_name
        self.llm_model_name = llm_model_name
        self.qdrant_path = qdrant_path
        self.embedding_dim = embedding_dim

        # Check the raw environment value: building a token secret from a missing
        # value fails before any warning could be logged, and the Secret wrapper
        # itself is not a reliable stand-in for "is the key set?".
        if not os.getenv("NEBIUS_API_KEY"):
            logger.warning("NEBIUS_API_KEY not found in environment variables")
        # strict=False defers the failure to the component that actually needs the key.
        self.nebius_api_key = Secret.from_env_var("NEBIUS_API_KEY", strict=False)

        # Initialize document stores and components
        self.init_document_store()
        self.init_components()
        self.build_indexing_pipeline()
        self.build_query_pipeline()

    def init_document_store(self):
        """Initialize the Qdrant document store (dense vectors, sparse vectors enabled)."""
        # getattr keeps this method usable on instances created before
        # ``embedding_dim`` existed as an attribute.
        self.document_store = QdrantDocumentStore(
            path=self.qdrant_path,
            embedding_dim=getattr(self, "embedding_dim", 4096),
            recreate_index=False,  # keep whatever is already indexed
            on_disk=True,
            on_disk_payload=True,
            index="ltu_documents",
            force_disable_check_same_thread=True,
            # Sparse vectors are enabled for the sparse retriever; note that the
            # indexing pipeline in this class only writes dense embeddings.
            use_sparse_embeddings=True,
        )

    def init_components(self):
        """Initialize all components needed for the pipelines."""
        # Document processing
        self.document_cleaner = DocumentCleaner()

        # Embedding components (documents at indexing time, text at query time)
        self.document_embedder = OpenAIDocumentEmbedder(
            api_base_url=self._NEBIUS_API_BASE_URL,
            model=self.embedding_model_name,
            api_key=self.nebius_api_key,
        )
        self.text_embedder = OpenAITextEmbedder(
            api_base_url=self._NEBIUS_API_BASE_URL,
            model=self.embedding_model_name,
            api_key=self.nebius_api_key,
        )

        # Retrievers. Only the dense one is used by the query pipeline for now.
        self.bm25_retriever = QdrantSparseEmbeddingRetriever(
            document_store=self.document_store,
            top_k=5,
        )
        self.embedding_retriever = QdrantEmbeddingRetriever(
            document_store=self.document_store,
            top_k=5,
        )

        # Document joiner for combining results of several retrievers (not wired yet).
        self.document_joiner = DocumentJoiner()

        # NOTE: a cross-encoder re-ranker ("cross-encoder/ms-marco-MiniLM-L-6-v2",
        # top_k=5) was planned here but is currently disabled.

        # LLM component
        self.llm = OpenAIGenerator(
            api_base_url=self._NEBIUS_API_BASE_URL,
            model=self.llm_model_name,
            api_key=self.nebius_api_key,
            generation_kwargs={
                "max_tokens": 1024,
                "temperature": 0.1,
                "top_p": 0.95,
            },
        )

        # Prompt builder. The template lines are intentionally flush-left: the
        # literal is reproduced exactly, including its whitespace.
        self.prompt_builder = PromptBuilder(
            template="""
<s>[INST] You are a helpful assistant that answers questions based on the provided context.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: {{ question }}
Answer the question based only on the provided context. If the context doesn't contain the answer, say "I don't have enough information to answer this question."
Answer: [/INST]
"""
        )

    def build_indexing_pipeline(self):
        """Build the pipeline for indexing documents: cleaner -> embedder -> writer."""
        pipeline = Pipeline()
        pipeline.add_component("document_cleaner", self.document_cleaner)
        pipeline.add_component("document_embedder", self.document_embedder)
        pipeline.add_component("document_writer", DocumentWriter(document_store=self.document_store))

        # Connect components
        pipeline.connect("document_cleaner", "document_embedder")
        pipeline.connect("document_embedder", "document_writer")
        self.indexing_pipeline = pipeline

    def build_query_pipeline(self):
        """Build the pipeline for querying: embedder -> retriever -> prompt -> LLM."""
        pipeline = Pipeline()

        # Add components
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("embedding_retriever", self.embedding_retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)
        pipeline.add_component("llm", self.llm)

        # Connect components
        pipeline.connect("text_embedder.embedding", "embedding_retriever.query_embedding")
        pipeline.connect("embedding_retriever.documents", "prompt_builder.documents")
        pipeline.connect("prompt_builder.prompt", "llm")

        # NOTE: the hybrid variant (bm25_retriever + embedding_retriever ->
        # document_joiner -> ranker -> prompt_builder.documents) is disabled;
        # enabling it requires adding those components and connections here and
        # feeding {"bm25_retriever": {"query": question}} in query().
        self.query_pipeline = pipeline

    def index_documents(self, documents: List[Any]):
        """
        Index documents in the document store.

        Errors are logged (with traceback) and not re-raised, so indexing stays
        best-effort for the caller.

        Args:
            documents: Documents to index. Items may be Haystack ``Document``
                objects or plain dicts; dicts are converted with
                ``Document.from_dict`` because the pipeline components operate
                on ``Document`` objects.
        """
        logger.info("Indexing %d documents", len(documents))
        try:
            # Local import: keeps the dependency next to its only use.
            from haystack import Document

            # dict(item) protects the caller's dict from being modified during conversion.
            prepared = [
                Document.from_dict(dict(item)) if isinstance(item, dict) else item
                for item in documents
            ]
            self.indexing_pipeline.run(
                {"document_cleaner": {"documents": prepared}}
            )
            logger.info("Indexing completed successfully")
        except Exception as e:
            logger.exception("Error during indexing: %s", e)

    def query(self, question: str, top_k: int = 5) -> Dict[str, Any]:
        """
        Query the RAG pipeline with a question.

        Args:
            question: The question to ask
            top_k: Number of documents to retrieve

        Returns:
            Dictionary with the keys ``answer`` (first LLM reply, or an error
            message if the run failed), ``documents`` (retrieved documents,
            empty on failure) and ``question``.
        """
        logger.info("Querying with question: %s", question)
        try:
            result = self.query_pipeline.run(
                {
                    "text_embedder": {"text": question},
                    # top_k is a run-time input of the retriever; passing it here
                    # avoids mutating the shared component between calls.
                    "embedding_retriever": {"top_k": top_k},
                    "prompt_builder": {"question": question},
                },
                # The retriever is not a leaf component, so its output has to be
                # requested explicitly to be present in the result.
                include_outputs_from={"embedding_retriever"},
            )
            answer = result["llm"]["replies"][0]
            documents = result.get("embedding_retriever", {}).get("documents", [])
        except Exception as e:
            logger.exception("Error during query: %s", e)
            return {
                "answer": f"An error occurred: {str(e)}",
                "documents": [],
                "question": question,
            }
        return {
            "answer": answer,
            "documents": documents,
            "question": question,
        }

    def get_document_count(self) -> int:
        """
        Get the number of documents in the document store.

        Returns:
            Document count
        """
        return self.document_store.count_documents()