# Hugging Face Spaces metadata (scrape residue, preserved as a comment):
# Spaces: Running on CPU Upgrade
import logging
import os
import pickle
from typing import Any

from dotenv import load_dotenv
from haystack.nodes import (  # type: ignore
    AnswerParser,
    EmbeddingRetriever,
    PromptNode,
    PromptTemplate,
)
from haystack.pipelines import Pipeline

from src.document_store.document_store import get_document_store
load_dotenv() | |
OPENAI_API_KEY = os.environ.get("OPEN_API_KEY") | |
class RAGPipeline:
    """Retrieval-augmented generation pipeline built on Haystack.

    Wires an :class:`EmbeddingRetriever` into a GPT-4 :class:`PromptNode`
    and caches the embedded document store on disk as a pickle so the
    (expensive) embedding pass runs only once across restarts.
    """

    # Single source of truth for the on-disk cache location (the original
    # repeated this os.path.join expression in three places).
    _STORE_PATH = os.path.join("database", "document_store.pkl")

    def __init__(
        self,
        embedding_model: str,
        prompt_template: str,
    ):
        """Build and wire the full pipeline.

        Args:
            embedding_model: Name or path of the embedding model used by
                the retriever.
            prompt_template: Haystack prompt template string for the
                generation step.
        """
        self.load_document_store()
        self.embedding_model = embedding_model
        self.prompt_template = prompt_template
        self.retriever_node = self.generate_retriever_node()
        self.prompt_node = self.generate_prompt_node()
        self.update_embeddings()
        self.pipe = self.build_pipeline()

    def run(self, prompt: str, filters: dict) -> Any:
        """Execute the pipeline for ``prompt``.

        Args:
            prompt: The user query passed to the retriever/prompt chain.
            filters: Haystack retriever filters forwarded via ``params``.

        Returns:
            The Haystack result dict, or ``None`` if any step fails
            (deliberate best-effort contract preserved from the original).
        """
        try:
            return self.pipe.run(query=prompt, params={"filters": filters})
        except Exception:
            # The original printed only str(e) to stdout, losing the
            # traceback; log it properly but keep returning None so
            # callers' error handling is unchanged.
            logging.getLogger(__name__).exception(
                "RAG pipeline run failed for prompt %r", prompt
            )
            return None

    def build_pipeline(self):
        """Assemble the two-node Query -> retriever -> prompt pipeline."""
        pipe = Pipeline()
        pipe.add_node(component=self.retriever_node, name="retriever", inputs=["Query"])
        pipe.add_node(
            component=self.prompt_node,
            name="prompt_node",
            inputs=["retriever"],
        )
        return pipe

    def load_document_store(self):
        """Load the cached document store, or build a fresh one.

        Sets ``self.document_store`` from the pickle cache when present,
        otherwise from :func:`get_document_store`.
        """
        if os.path.exists(self._STORE_PATH):
            # NOTE: pickle.load is acceptable here only because this file
            # is produced locally by update_embeddings() below — never
            # load this cache from an untrusted source.
            with open(file=self._STORE_PATH, mode="rb") as f:
                self.document_store = pickle.load(f)
        else:
            self.document_store = get_document_store()

    def generate_retriever_node(self):
        """Create the embedding retriever over the current document store."""
        retriever_node = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model=self.embedding_model,
            top_k=7,
        )
        return retriever_node

    def update_embeddings(self):
        """Embed all documents once and persist the store to the cache.

        No-op when the cache file already exists (the store was loaded
        pre-embedded by :meth:`load_document_store`).
        """
        if not os.path.exists(self._STORE_PATH):
            self.document_store.update_embeddings(
                self.retriever_node, update_existing_embeddings=True
            )
            with open(file=self._STORE_PATH, mode="wb") as f:
                pickle.dump(self.document_store, f)

    def generate_prompt_node(self):
        """Create the GPT-4 prompt node with answer-reference parsing."""
        rag_prompt = PromptTemplate(
            prompt=self.prompt_template,
            output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
        )
        prompt_node = PromptNode(
            model_name_or_path="gpt-4",
            default_prompt_template=rag_prompt,
            api_key=OPENAI_API_KEY,
            # max_length is Haystack's generation cap; model_kwargs'
            # max_tokens is passed to the OpenAI API — presumably the
            # API-side value wins (TODO confirm against Haystack docs).
            max_length=4000,
            model_kwargs={"temperature": 0.2, "max_tokens": 4096},
        )
        return prompt_node