SwatGarg committed on
Commit
eb5e96e
1 Parent(s): c2514dd

Create model_pipeline.py

Files changed (1)
  1. model_pipeline.py +115 -0
model_pipeline.py ADDED
@@ -0,0 +1,115 @@
+ import os
+ from langchain.prompts.chat import ChatPromptTemplate
+ from langchain.memory import ConversationBufferMemory
+ from generator import load_llm
+ from langchain.prompts import PromptTemplate
+ from retriever import process_pdf_document, create_vectorstore, rag_retriever
+ from langchain.schema import format_document
+ from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
+ from langchain_core.runnables import RunnableParallel
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+ from operator import itemgetter
+
+ class ModelPipeLine:
+
+     DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+     def __init__(self):
+         self.curr_dir = os.path.dirname(__file__)
+         self.prompt_dir = os.path.join(os.path.dirname(os.path.dirname(self.curr_dir)), 'src', 'prompts')
+         self.vectorstore, self.store = create_vectorstore()
+         self.retriever = rag_retriever(self.vectorstore)  # Create the retriever
+         self.llm = load_llm()  # Load the LLM model
+         self.memory = ConversationBufferMemory(return_messages=True,
+                                                output_key="answer",
+                                                input_key="question")  # Instantiate ConversationBufferMemory
+
+     def get_prompts(self, system_file_path='system_prompt_template.txt',
+                     condense_file_path='condense_question_prompt_template.txt'):
+
+         with open(os.path.join(self.prompt_dir, system_file_path), 'r') as f:
+             system_prompt_template = f.read()
+
+         with open(os.path.join(self.prompt_dir, condense_file_path), 'r') as f:
+             condense_question_prompt = f.read()
+
+         # create the answer (system) message template
+         ANSWER_PROMPT = ChatPromptTemplate.from_template(system_prompt_template)
+
+         # create the condense-question message template
+         CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_question_prompt)
+
+         return ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT
+
+
+     def _combine_documents(self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
+
+         doc_strings = [format_document(doc, document_prompt) for doc in docs]
+         return document_separator.join(doc_strings)
+
+     def create_final_chain(self):
+
+         answer_prompt, condense_question_prompt = self.get_prompts()
+         # This adds a "memory" key to the input object
+         loaded_memory = RunnablePassthrough.assign(
+             chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
+         )
+         # Now we calculate the standalone question
+         standalone_question = {
+             "standalone_question": {
+                 "question": lambda x: x["question"],
+                 "chat_history": lambda x: get_buffer_string(x["chat_history"]),
+             }
+             | condense_question_prompt
+             | self.llm,
+         }
+         # Now we retrieve the documents
+         retrieved_documents = {
+             "docs": itemgetter("standalone_question") | self.retriever,
+             "question": lambda x: x["standalone_question"],
+         }
+         # Now we construct the inputs for the final prompt
+         final_inputs = {
+             "context": lambda x: self._combine_documents(x["docs"]),
+             "question": itemgetter("question"),
+         }
+         # And finally, we do the part that returns the answers
+         answer = {
+             "answer": final_inputs | answer_prompt | self.llm,
+             "docs": itemgetter("docs"),
+         }
+         # And now we put it all together!
+         final_chain = loaded_memory | standalone_question | retrieved_documents | answer
+         return final_chain
+     def call_conversational_rag(self, question, chain):
+         """
+         Call the conversational RAG (Retrieval-Augmented Generation) chain to answer a question.
+
+         The question is sent through the chain, the answer is retrieved, and the question-answer
+         pair is saved to self.memory so it is available as context in future interactions.
+
+         Parameters:
+             question (str): The question to be answered by the RAG chain.
+             chain: The runnable chain built by create_final_chain.
+
+         Returns:
+             dict: The chain output, with the generated answer under "answer" and the
+                 retrieved documents under "docs".
+         """
+
+         # Prepare the input for the RAG model
+         inputs = {"question": question}
+
+         # Invoke the RAG model to get an answer
+         result = chain.invoke(inputs)
+
+         # Save the current question and its answer to memory for future context
+         self.memory.save_context(inputs, {"answer": result["answer"]})
+
+         # Return the result
+         return result
+
+ ml_pipeline = ModelPipeLine()
+ final_chain = ml_pipeline.create_final_chain()
+ question = "i am feeling sad"
+ res = ml_pipeline.call_conversational_rag(question, final_chain)
+ print(res['answer'])