"""Retrieval-augmented generation (RAG) pipeline built on Haystack.

Wires an ``EmbeddingRetriever`` into a GPT-4 ``PromptNode`` and caches the
embedded document store on disk as a pickle so embeddings are computed once.
"""

import os
import pickle
from typing import Any

from dotenv import load_dotenv
from haystack.nodes import (  # type: ignore
    AnswerParser,
    EmbeddingRetriever,
    PromptNode,
    PromptTemplate,
)
from haystack.pipelines import Pipeline

from src.document_store.document_store import get_document_store

load_dotenv()

# BUG FIX: the original read env var "OPEN_API_KEY", almost certainly a typo
# for the conventional "OPENAI_API_KEY" (the value is passed to PromptNode as
# an OpenAI api_key). Read the conventional name first, but keep the old name
# as a fallback so existing .env files keep working.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("OPEN_API_KEY")

# Single source of truth for the on-disk document-store cache (was repeated
# four times in the original).
_STORE_PATH = os.path.join("database", "document_store.pkl")


class RAGPipeline:
    """Query pipeline: embedding retrieval followed by GPT-4 generation.

    Construction has side effects: it loads (or builds) the document store,
    computes embeddings if the on-disk cache is missing, and writes the cache.
    """

    def __init__(
        self,
        embedding_model: str,
        prompt_template: str,
    ):
        """Build the full pipeline.

        Args:
            embedding_model: Model name/path used by the ``EmbeddingRetriever``.
            prompt_template: Haystack prompt template string for generation.
        """
        self.load_document_store()
        self.embedding_model = embedding_model
        self.prompt_template = prompt_template
        self.retriever_node = self.generate_retriever_node()
        self.prompt_node = self.generate_prompt_node()
        self.update_embeddings()
        self.pipe = self.build_pipeline()

    def run(self, prompt: str, filters: dict) -> Any:
        """Run the pipeline for one query.

        Args:
            prompt: The user query passed to the retriever and prompt node.
            filters: Haystack metadata filters applied during retrieval.

        Returns:
            The Haystack result dict, or ``None`` if anything went wrong.
            This is a deliberate best-effort contract: failures are printed,
            never raised, and the caller must handle ``None``.
        """
        try:
            return self.pipe.run(query=prompt, params={"filters": filters})
        except Exception as e:  # broad on purpose: callers expect None, not a raise
            print(e)
            return None

    def build_pipeline(self):
        """Assemble the two-node Haystack pipeline (Query -> retriever -> prompt)."""
        pipe = Pipeline()
        pipe.add_node(component=self.retriever_node, name="retriever", inputs=["Query"])
        pipe.add_node(
            component=self.prompt_node,
            name="prompt_node",
            inputs=["retriever"],
        )
        return pipe

    def load_document_store(self):
        """Load the cached document store from disk, or build a fresh one.

        SECURITY NOTE(review): ``pickle.load`` executes arbitrary code if the
        cache file is attacker-controlled. Acceptable only because the file is
        produced locally by this same class; do not point it at untrusted data.
        """
        if os.path.exists(_STORE_PATH):
            with open(file=_STORE_PATH, mode="rb") as f:
                self.document_store = pickle.load(f)
        else:
            self.document_store = get_document_store()

    def generate_retriever_node(self):
        """Create the embedding retriever over the loaded document store."""
        retriever_node = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model=self.embedding_model,
            top_k=7,  # number of documents handed to the prompt node
        )
        return retriever_node

    def update_embeddings(self):
        """Compute embeddings and write the pickle cache, but only on a cache miss.

        If the cache file already exists, the store loaded from it is assumed
        to contain up-to-date embeddings and nothing is recomputed.
        """
        if not os.path.exists(_STORE_PATH):
            self.document_store.update_embeddings(
                self.retriever_node, update_existing_embeddings=True
            )
            with open(file=_STORE_PATH, mode="wb") as f:
                pickle.dump(self.document_store, f)

    def generate_prompt_node(self):
        """Create the GPT-4 prompt node with answer-reference parsing.

        The ``AnswerParser`` extracts citations of the form ``Document[3]``
        from the model output.
        """
        rag_prompt = PromptTemplate(
            prompt=self.prompt_template,
            output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
        )
        # NOTE(review): max_length=4000 and model_kwargs max_tokens=4096 look
        # redundant/inconsistent — confirm which one the Haystack version in
        # use actually honors. Values preserved as-is.
        prompt_node = PromptNode(
            model_name_or_path="gpt-4",
            default_prompt_template=rag_prompt,
            api_key=OPENAI_API_KEY,
            max_length=4000,
            model_kwargs={"temperature": 0.2, "max_tokens": 4096},
        )
        return prompt_node