rohit
remove logs
8946f02
raw
history blame
5.49 kB
from typing import Optional, List, Union, Dict

from datasets import load_dataset
from haystack import Document, Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy

from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG


class RAGPipeline:
    """Embedding-based retrieval over an in-memory store, rendered into a prompt.

    Documents are embedded once at construction time and written to an
    ``InMemoryDocumentStore``. ``answer_question`` embeds a question, retrieves
    matching documents and renders them into a prompt string; it does not run
    a generator/LLM itself.
    """

    # Used when the dataset config does not provide its own prompt_template.
    _DEFAULT_PROMPT_TEMPLATE = """
Given the following context, please answer the question.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: {{question}}
Answer:
"""

    def __init__(
        self,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None,
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            dataset_config: Either a string key from DATASET_CONFIGS or a DatasetConfig object.
            documents: Optional list of documents to use instead of loading from a dataset.
                Plain strings are wrapped in Document objects with empty metadata.
            embedding_model: Optional override for the embedding model; defaults to
                MODEL_CONFIG["embedding_model"].

        Raises:
            ValueError: If dataset_config is a string that is not a key of DATASET_CONFIGS.
        """
        self.config = self._resolve_config(dataset_config)

        # Load documents either from the provided list or from the configured dataset.
        if documents is not None:
            self.documents = self._coerce_documents(documents)
        else:
            self.documents = self._load_dataset_documents()

        # Document and query embedders share one model so both live in the
        # same vector space.
        model_name = embedding_model or MODEL_CONFIG["embedding_model"]
        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(
            model=model_name,
            progress_bar=False,
        )
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=model_name,
            progress_bar=False,
        )
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # Warm up the embedders up front: both are invoked directly via run()
        # (during indexing below and in answer_question), not through Pipeline.run.
        self.doc_embedder.warm_up()
        self.text_embedder.warm_up()

        self.prompt_builder = PromptBuilder(
            template=self.config.prompt_template or self._DEFAULT_PROMPT_TEMPLATE
        )

        self._index_documents(self.documents)
        self.pipeline = self._build_pipeline()

    @staticmethod
    def _resolve_config(dataset_config: Union[str, DatasetConfig]) -> DatasetConfig:
        """Return the preset config for a string key, or the given config object as-is."""
        if not isinstance(dataset_config, str):
            return dataset_config
        if dataset_config not in DATASET_CONFIGS:
            raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}")
        return DATASET_CONFIGS[dataset_config]

    @staticmethod
    def _coerce_documents(documents: List[Union[str, Document]]) -> List[Document]:
        """Wrap plain strings as Documents; the document embedder only accepts Documents."""
        return [Document(content=doc) if isinstance(doc, str) else doc for doc in documents]

    def _load_dataset_documents(self) -> List[Document]:
        """Load the configured dataset split and turn each row into a Document.

        Row ``content_field`` becomes the document content. ``config.fields`` maps
        metadata keys to dataset columns; columns missing from a row are skipped.
        """
        dataset = load_dataset(self.config.name, split=self.config.split)
        field_map = self.config.fields or {}
        loaded = []
        for row in dataset:
            meta = {
                meta_key: row[dataset_field]
                for meta_key, dataset_field in field_map.items()
                if dataset_field in row
            }
            loaded.append(Document(content=row[self.config.content_field], meta=meta))
        return loaded

    @classmethod
    def from_preset(cls, preset_name: str):
        """
        Create a pipeline from a preset configuration.

        Args:
            preset_name: Name of the preset configuration to use (a key of DATASET_CONFIGS).
        """
        return cls(dataset_config=preset_name)

    def _index_documents(self, documents):
        """Embed ``documents`` and write them into the in-memory document store.

        Documents whose id already exists in the store are skipped instead of
        raising, so datasets containing repeated rows can still be indexed.
        """
        docs_with_embeddings = self.doc_embedder.run(documents=documents)
        self.document_store.write_documents(
            docs_with_embeddings["documents"],
            policy=DuplicatePolicy.SKIP,
        )

    def _build_pipeline(self):
        """Wire text_embedder -> retriever -> prompt_builder into a Haystack Pipeline.

        NOTE(review): answer_question runs these components directly and does
        not use this Pipeline; it is kept as the public ``self.pipeline`` attribute.
        """
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever", "prompt_builder")
        return pipeline

    def answer_question(self, question: str) -> str:
        """Retrieve context for ``question`` and return the rendered prompt.

        Args:
            question: The user question to embed and answer.

        Returns:
            The prompt text built from the template, the retrieved documents and
            the question. No answer is generated here; the caller is expected to
            pass this prompt to a language model.
        """
        # Embed the question and retrieve the most similar documents.
        embedded_question = self.text_embedder.run(text=question)
        retrieved_docs = self.retriever.run(query_embedding=embedded_question["embedding"])
        # Render the prompt from the retrieved documents.
        prompt_result = self.prompt_builder.run(
            question=question,
            documents=retrieved_docs["documents"],
        )
        return prompt_result["prompt"]