SwatGarg committed on
Commit
111d456
1 Parent(s): 1346cad

Create model_pipelineV2.py

Files changed (1)
  1. model_pipelineV2.py +126 -0
model_pipelineV2.py ADDED
@@ -0,0 +1,126 @@
# Implement Classification

import os
from langchain.prompts.chat import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory
from generator import load_llm
from langchain.prompts import PromptTemplate
from retrieverV2 import process_pdf_document, create_vectorstore, rag_retriever
from langchain.schema import format_document
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from operator import itemgetter
from langchain_text_splitters import RecursiveCharacterTextSplitter


class ModelPipeLine:

    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

    def __init__(self):
        self.curr_dir = os.path.dirname(__file__)
        self.knowledge_dir = os.path.join(os.path.dirname(os.path.dirname(self.curr_dir)), 'data', 'Knowledge_base')
        self.prompt_dir = os.path.join(os.path.dirname(os.path.dirname(self.curr_dir)), 'src', 'prompts')
        self.child_splitter = RecursiveCharacterTextSplitter(chunk_size=200)
        self.parent_splitter = RecursiveCharacterTextSplitter(chunk_size=500)
        self.documents = process_pdf_document([os.path.join(self.knowledge_dir, 'depression_1.pdf'),
                                               os.path.join(self.knowledge_dir, 'depression_2.pdf')])
        self.vectorstore, self.store = create_vectorstore()
        self.retriever = rag_retriever(self.vectorstore, self.store, self.documents,
                                       self.parent_splitter, self.child_splitter)  # Create the retriever
        self.llm = load_llm()  # Load the LLM model
        self.memory = ConversationBufferMemory(return_messages=True,
                                               output_key="answer",
                                               input_key="question")  # Instantiate ConversationBufferMemory

    def get_prompts(self, system_file_path='system_prompt_template.txt',
                    condense_file_path='condense_question_prompt_template.txt'):

        with open(os.path.join(self.prompt_dir, system_file_path), 'r') as f:
            system_prompt_template = f.read()

        with open(os.path.join(self.prompt_dir, condense_file_path), 'r') as f:
            condense_question_prompt = f.read()

        # Create the chat prompt used to answer from the retrieved context
        ANSWER_PROMPT = ChatPromptTemplate.from_template(system_prompt_template)

        # Create the prompt used to condense the question against the chat history
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_question_prompt)

        return ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT

    def _combine_documents(self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    def create_final_chain(self):
        answer_prompt, condense_question_prompt = self.get_prompts()

        # This adds a "chat_history" key, loaded from memory, to the input object
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
        )
        # Now we calculate the standalone question
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_question_prompt
            | self.llm,
        }
        # Now we retrieve the documents
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | self.retriever,
            "question": lambda x: x["standalone_question"],
        }
        # Now we construct the inputs for the final prompt
        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }
        # And finally, we do the part that returns the answers
        answer = {
            "answer": final_inputs | answer_prompt | self.llm,
            "docs": itemgetter("docs"),
        }
        # And now we put it all together!
        final_chain = loaded_memory | standalone_question | retrieved_documents | answer

        return final_chain

    def call_conversational_rag(self, question, chain):
        """
        Calls a conversational RAG (Retrieval-Augmented Generation) chain to generate an answer to a given question.

        This function sends a question to the RAG chain, retrieves the answer, and stores the question-answer pair
        in the pipeline's memory for context in future interactions.

        Parameters:
            question (str): The question to be answered by the RAG chain.
            chain (Runnable): The chain built by create_final_chain, which encapsulates retrieval,
                prompting, and generation.

        Returns:
            dict: A dictionary containing the generated answer and the retrieved documents.
        """
        # Prepare the input for the RAG chain
        inputs = {"question": question}

        # Invoke the RAG chain to get an answer
        result = chain.invoke(inputs)

        # Save the current question and its answer to memory for future context
        self.memory.save_context(inputs, {"answer": result["answer"]})

        # Return the result
        return result


ml_pipeline = ModelPipeLine()
final_chain = ml_pipeline.create_final_chain()
question = "i am feeling sad"
res = ml_pipeline.call_conversational_rag(question, final_chain)
print(res['answer'])
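
For reference, a minimal sketch of the memory round trip the chain depends on. It only uses calls that already appear in the file (ConversationBufferMemory.save_context, load_memory_variables, and get_buffer_string); the question and answer strings are made up for illustration.

# Sketch: how loaded_memory and call_conversational_rag exchange state.
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import get_buffer_string

memory = ConversationBufferMemory(return_messages=True, output_key="answer", input_key="question")

# call_conversational_rag stores each exchange like this after chain.invoke():
memory.save_context({"question": "i am feeling sad"}, {"answer": "I'm sorry you're feeling this way ..."})

# loaded_memory replays it on the next turn: load_memory_variables returns the
# stored messages under the "history" key, and get_buffer_string flattens them
# into the chat_history text fed to the condense-question prompt.
history = memory.load_memory_variables({})["history"]
print(get_buffer_string(history))
# Human: i am feeling sad
# AI: I'm sorry you're feeling this way ...

Because the memory lives on the ModelPipeLine instance rather than inside the chain, each follow-up question is condensed against every earlier turn made in the same process.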