Oritsemisan commited on
Commit
7c664a7
1 Parent(s): 53fadf1

Update model_pipelineV2.py

Browse files
Files changed (1) hide show
  1. model_pipelineV2.py +39 -13
model_pipelineV2.py CHANGED
@@ -13,24 +13,53 @@ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
13
  from operator import itemgetter
14
  from langchain_text_splitters import RecursiveCharacterTextSplitter
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  class ModelPipeLine:
18
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
19
  def __init__(self):
20
  self.curr_dir = os.path.dirname(__file__)
21
- self.knowledge_dir = "knowledge"
22
- print("Knowledge Directory:", self.knowledge_dir)
23
  self.prompt_dir = 'prompts'
24
  self.child_splitter = RecursiveCharacterTextSplitter(chunk_size=200)
25
  self.parent_splitter = RecursiveCharacterTextSplitter(chunk_size=500)
26
- self.documents = process_pdf_document([os.path.join(self.knowledge_dir, 'depression_1.pdf'), os.path.join(self.knowledge_dir, 'depression_2.pdf')])
27
- self.vectorstore, self.store = create_vectorstore()
28
- self.retriever = rag_retriever(self.vectorstore, self.store, self.documents, self.parent_splitter, self.child_splitter) # Create the retriever
29
- self.llm = load_llm() # Load the LLM model
30
- self.memory = ConversationBufferMemory(return_messages=True,
31
- output_key="answer",
32
- input_key="question") # Instantiate ConversationBufferMemory
33
-
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def get_prompts(self, system_file_path='system_prompt_template.txt',
36
  condense_file_path='condense_question_prompt_template.txt'):
@@ -95,15 +124,12 @@ class ModelPipeLine:
95
  def call_conversational_rag(self,question, chain):
96
  """
97
  Calls a conversational RAG (Retrieval-Augmented Generation) model to generate an answer to a given question.
98
-
99
  This function sends a question to the RAG model, retrieves the answer, and stores the question-answer pair in memory
100
  for context in future interactions.
101
-
102
  Parameters:
103
  question (str): The question to be answered by the RAG model.
104
  chain (LangChain object): An instance of LangChain which encapsulates the RAG model and its functionality.
105
  memory (Memory object): An object used for storing the context of the conversation.
106
-
107
  Returns:
108
  dict: A dictionary containing the generated answer from the RAG model.
109
  """
 
13
  from operator import itemgetter
14
  from langchain_text_splitters import RecursiveCharacterTextSplitter
15
 
16
+ class VectorStoreSingleton:
17
+ _instance = None
18
+
19
+ @classmethod
20
+ def get_instance(cls):
21
+ if cls._instance is None:
22
+ cls._instance = create_vectorstore() # Your existing function to create the vectorstore
23
+ return cls._instance
24
+
25
+ class LanguageModelSingleton:
26
+ _instance = None
27
+
28
+ @classmethod
29
+ def get_instance(cls):
30
+ if cls._instance is None:
31
+ cls._instance = load_llm() # Your existing function to load the LLM
32
+ return cls._instance
33
+
34
 
35
  class ModelPipeLine:
36
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
37
  def __init__(self):
38
  self.curr_dir = os.path.dirname(__file__)
39
+ self.knowledge_dir = 'knowledge'
 
40
  self.prompt_dir = 'prompts'
41
  self.child_splitter = RecursiveCharacterTextSplitter(chunk_size=200)
42
  self.parent_splitter = RecursiveCharacterTextSplitter(chunk_size=500)
43
+ self._documents = None # Initialize as None for lazy loading
44
+ self.vectorstore, self.store = VectorStoreSingleton.get_instance()
45
+ self._retriever = None # Corrected: Initialize _retriever as None for lazy loading
46
+ self.llm = LanguageModelSingleton.get_instance()
47
+ self.memory = ConversationBufferMemory(return_messages=True, output_key="answer", input_key="question")
48
+
49
+ @property
50
+ def documents(self):
51
+ if self._documents is None:
52
+ self._documents = process_pdf_document([
53
+ os.path.join(self.knowledge_dir, 'depression_1.pdf'),
54
+ os.path.join(self.knowledge_dir, 'depression_2.pdf')
55
+ ])
56
+ return self._documents
57
+
58
+ @property
59
+ def retriever(self):
60
+ if self._retriever is None:
61
+ self._retriever = rag_retriever(self.vectorstore, self.store, self.documents, self.parent_splitter, self.child_splitter)
62
+ return self._retriever
63
 
64
  def get_prompts(self, system_file_path='system_prompt_template.txt',
65
  condense_file_path='condense_question_prompt_template.txt'):
 
124
  def call_conversational_rag(self,question, chain):
125
  """
126
  Calls a conversational RAG (Retrieval-Augmented Generation) model to generate an answer to a given question.
 
127
  This function sends a question to the RAG model, retrieves the answer, and stores the question-answer pair in memory
128
  for context in future interactions.
 
129
  Parameters:
130
  question (str): The question to be answered by the RAG model.
131
  chain (LangChain object): An instance of LangChain which encapsulates the RAG model and its functionality.
132
  memory (Memory object): An object used for storing the context of the conversation.
 
133
  Returns:
134
  dict: A dictionary containing the generated answer from the RAG model.
135
  """