"""RAG pipeline: text embedder -> Qdrant retriever -> prompt builder -> LLM."""

from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever

from src.settings import settings


class RAGPipeline:
    """Retrieval-augmented generation over a Qdrant document store.

    Components (embedder, retriever, prompt builder, LLM) are constructed
    eagerly in ``__init__`` so configuration errors surface immediately;
    the Haystack ``Pipeline`` itself is wired lazily on the first ``run()``.
    """

    def __init__(
        self,
        document_store,
        template: str,
        top_k: int,
    ) -> None:
        """Create the pipeline components.

        Args:
            document_store: Qdrant-backed document store the retriever reads from.
            template: Jinja prompt template passed to ``PromptBuilder``.
            top_k: Number of documents the retriever returns per query.
        """
        self.text_embedder: SentenceTransformersTextEmbedder  # type: ignore
        self.retriever: QdrantEmbeddingRetriever  # type: ignore
        self.prompt_builder: PromptBuilder  # type: ignore
        self.llm_provider: OpenAIGenerator  # type: ignore
        # Wired lazily by build_pipeline() on the first run() call.
        self.pipeline: Pipeline | None = None
        self.document_store = document_store
        self.template = template
        self.top_k = top_k
        # Build components up front so misconfiguration fails fast.
        self.get_text_embedder()
        self.get_retriever()
        self.get_prompt_builder()
        self.get_llm_provider()

    def run(self, query: str, filter_selections: dict[str, list] | None = None) -> dict:
        """Answer ``query`` with RAG, optionally restricted by metadata filters.

        Args:
            query: The user question; embedded for retrieval and injected
                into the prompt template.
            filter_selections: Optional metadata-field -> allowed-values map
                (see ``build_filter``).

        Returns:
            The Haystack pipeline result, including ``retriever`` and ``llm``
            outputs (``include_outputs_from``).
        """
        # Lazy one-time wiring; build_pipeline() always assigns self.pipeline,
        # so the previous implicit-None return path is gone.
        if self.pipeline is None:
            self.build_pipeline()
        filters = RAGPipeline.build_filter(filter_selections=filter_selections)
        return self.pipeline.run(
            data={
                "text_embedder": {"text": query},
                "retriever": {"filters": filters},
                "prompt_builder": {"query": query},
            },
            include_outputs_from=["retriever", "llm"],
        )

    def get_text_embedder(self) -> None:
        """Construct and warm up the sentence-transformers query embedder."""
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=settings.qdrant_database.model
        )
        # Load model weights now rather than on the first query.
        self.text_embedder.warm_up()

    def get_retriever(self) -> None:
        """Construct the Qdrant embedding retriever over the document store."""
        self.retriever = QdrantEmbeddingRetriever(
            document_store=self.document_store, top_k=self.top_k
        )

    def get_prompt_builder(self) -> None:
        """Construct the prompt builder from the configured template."""
        self.prompt_builder = PromptBuilder(template=self.template)

    def get_llm_provider(self) -> None:
        """Construct the OpenAI generator; API key comes from the environment."""
        self.llm_provider = OpenAIGenerator(
            model=settings.llm_provider.model,
            api_key=Secret.from_env_var("LLM_PROVIDER__API_KEY"),
            max_retries=3,
            generation_kwargs={"max_tokens": 5000, "temperature": 0.2},
        )

    @staticmethod
    def build_filter(filter_selections: dict[str, list] | None = None) -> dict:
        """Translate UI selections into a Qdrant metadata filter.

        Args:
            filter_selections: Map of metadata field name -> list of accepted
                values. Each entry becomes an ``in`` condition on
                ``meta.<field>``; conditions are AND-combined.

        Returns:
            A Qdrant filter dict, or ``{}`` when no selections are given
            (i.e. no filtering).
        """
        # Guard clause replaces the original build-then-discard else branch.
        if not filter_selections:
            return {}
        conditions: list[dict] = [
            {
                "field": "meta." + meta_data_name,
                "operator": "in",
                "value": selections,
            }
            for meta_data_name, selections in filter_selections.items()
        ]
        return {"operator": "AND", "conditions": conditions}

    def build_pipeline(self) -> None:
        """Wire the components into a Haystack pipeline and store it."""
        self.pipeline = Pipeline()
        self.pipeline.add_component("text_embedder", self.text_embedder)
        self.pipeline.add_component("retriever", self.retriever)
        self.pipeline.add_component("prompt_builder", self.prompt_builder)
        self.pipeline.add_component("llm", self.llm_provider)
        self.pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        self.pipeline.connect("retriever", "prompt_builder.documents")
        self.pipeline.connect("prompt_builder", "llm")


if __name__ == "__main__":
    # NOTE(review): DocumentStore is not imported anywhere in this file, so
    # running this script raises NameError — confirm and import its module.
    document_store = DocumentStore(index="inc_data")
    with open("src/rag/prompt_templates/inc_template.txt", "r", encoding="utf-8") as file:
        template = file.read()
    pipeline = RAGPipeline(
        document_store=document_store.document_store, template=template, top_k=5
    )
    filter_selections = {
        "author": ["Malaysia", "Australia"],
    }
    result = pipeline.run(
        "What is Malaysia's position on plastic waste?",
        filter_selections=filter_selections,
    )