# NegotiateAI — src/rag/pipeline.py
# (Header reconstructed from Hugging Face file-viewer chrome; original
# upload note: "Update src/rag/pipeline.py", commit 1357bc1, 2.88 kB.)
import os
import pickle
from typing import Any

from dotenv import load_dotenv
from haystack.nodes import (  # type: ignore
    AnswerParser,
    EmbeddingRetriever,
    PromptNode,
    PromptTemplate,
)
from haystack.pipelines import Pipeline

from src.document_store.document_store import get_document_store

# Populate os.environ from a local .env file before reading any keys.
load_dotenv()

# BUG FIX: the key was previously read from "OPEN_API_KEY" (typo). Read the
# conventional "OPENAI_API_KEY" first, falling back to the old misspelled
# name so existing .env files keep working.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("OPEN_API_KEY")
class RAGPipeline:
    """Retrieval-augmented generation pipeline.

    Wires an ``EmbeddingRetriever`` (top-7 documents) into a GPT-4
    ``PromptNode``. The underlying document store is cached to disk with
    pickle so embeddings are only computed once.
    """

    # Single source of truth for the pickled document-store cache
    # (previously duplicated in three methods).
    _STORE_PATH = os.path.join("database", "document_store.pkl")

    def __init__(
        self,
        embedding_model: str,
        prompt_template: str,
    ):
        """Build the full pipeline.

        Args:
            embedding_model: Model name/path for the ``EmbeddingRetriever``.
            prompt_template: Prompt text for the ``PromptNode`` template.
        """
        self.load_document_store()
        self.embedding_model = embedding_model
        self.prompt_template = prompt_template
        self.retriever_node = self.generate_retriever_node()
        self.prompt_node = self.generate_prompt_node()
        self.update_embeddings()
        self.pipe = self.build_pipeline()

    def run(self, prompt: str, filters: dict) -> Any:
        """Run a query through retriever + prompt node.

        Args:
            prompt: The user query.
            filters: Haystack metadata filters forwarded to the retriever.

        Returns:
            The Haystack result dict, or ``None`` on any failure (errors
            are printed rather than raised, keeping the caller alive).
        """
        try:
            return self.pipe.run(query=prompt, params={"filters": filters})
        except Exception as e:  # broad by design: report, don't crash the caller
            print(e)
            return None

    def build_pipeline(self) -> Pipeline:
        """Wire the graph: Query -> retriever -> prompt_node."""
        pipe = Pipeline()
        pipe.add_node(component=self.retriever_node, name="retriever", inputs=["Query"])
        pipe.add_node(
            component=self.prompt_node,
            name="prompt_node",
            inputs=["retriever"],
        )
        return pipe

    def load_document_store(self) -> None:
        """Load the pickled document store if cached, else build a fresh one."""
        if os.path.exists(self._STORE_PATH):
            # SECURITY NOTE: pickle.load executes arbitrary code from the
            # file; acceptable only because this cache is produced locally
            # by update_embeddings() — never load an untrusted pickle here.
            with open(file=self._STORE_PATH, mode="rb") as f:
                self.document_store = pickle.load(f)
        else:
            self.document_store = get_document_store()

    def generate_retriever_node(self) -> EmbeddingRetriever:
        """Create the embedding retriever over the loaded document store."""
        return EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model=self.embedding_model,
            top_k=7,  # number of documents handed to the prompt node
        )

    def update_embeddings(self) -> None:
        """Compute embeddings and write the pickle cache, unless one exists.

        Skipped entirely when the cache file is present, so the (expensive)
        embedding pass runs only on the first startup.
        """
        if not os.path.exists(self._STORE_PATH):
            self.document_store.update_embeddings(
                self.retriever_node, update_existing_embeddings=True
            )
            with open(file=self._STORE_PATH, mode="wb") as f:
                pickle.dump(self.document_store, f)

    def generate_prompt_node(self) -> PromptNode:
        """Create the GPT-4 prompt node with answer-reference parsing.

        The ``AnswerParser`` extracts citations matching ``Document[<n>]``
        from the model output.
        """
        rag_prompt = PromptTemplate(
            prompt=self.prompt_template,
            output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
        )
        return PromptNode(
            model_name_or_path="gpt-4",
            default_prompt_template=rag_prompt,
            api_key=OPENAI_API_KEY,
            max_length=4000,
            model_kwargs={"temperature": 0.2, "max_tokens": 4096},
        )