Create model_pipeline.py
model_pipeline.py
ADDED
@@ -0,0 +1,115 @@
import os
from langchain.prompts.chat import ChatPromptTemplate
from langchain.memory import ConversationBufferMemory
from generator import load_llm
from langchain.prompts import PromptTemplate
from retriever import process_pdf_document, create_vectorstore, rag_retriever
from langchain.schema import format_document
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from operator import itemgetter


class ModelPipeLine:

    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

    def __init__(self):
        self.curr_dir = os.path.dirname(__file__)
        self.prompt_dir = os.path.join(os.path.dirname(os.path.dirname(self.curr_dir)), 'src', 'prompts')
        self.vectorstore, self.store = create_vectorstore()
        self.retriever = rag_retriever(self.vectorstore)  # Create the retriever
        self.llm = load_llm()  # Load the LLM model
        self.memory = ConversationBufferMemory(return_messages=True,
                                               output_key="answer",
                                               input_key="question")  # Instantiate ConversationBufferMemory

    def get_prompts(self, system_file_path='system_prompt_template.txt',
                    condense_file_path='condense_question_prompt_template.txt'):

        with open(os.path.join(self.prompt_dir, system_file_path), 'r') as f:
            system_prompt_template = f.read()

        with open(os.path.join(self.prompt_dir, condense_file_path), 'r') as f:
            condense_question_prompt = f.read()

        # Create the answer prompt from the system template
        ANSWER_PROMPT = ChatPromptTemplate.from_template(system_prompt_template)

        # Create the prompt that condenses the chat history and follow-up question
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_question_prompt)

        return ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT

    def _combine_documents(self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    def create_final_chain(self):
        answer_prompt, condense_question_prompt = self.get_prompts()

        # This adds the "chat_history" key, loaded from memory, to the input object
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
        )
        # Condense the chat history and the new question into a standalone question
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_question_prompt
            | self.llm,
        }
        # Retrieve the documents relevant to the standalone question
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | self.retriever,
            "question": lambda x: x["standalone_question"],
        }
        # Construct the inputs for the final prompt
        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }
        # And finally, the part that returns the answer
        answer = {
            "answer": final_inputs | answer_prompt | self.llm,
            "docs": itemgetter("docs"),
        }
        # And now we put it all together!
        final_chain = loaded_memory | standalone_question | retrieved_documents | answer
        return final_chain

    def call_conversational_rag(self, question, chain):
        """
        Calls the conversational RAG (Retrieval-Augmented Generation) chain to answer a question.

        The question is sent to the chain, the answer is retrieved, and the question-answer pair
        is stored in memory so it can provide context for future turns.

        Parameters:
            question (str): The question to be answered by the RAG chain.
            chain (Runnable): The chain built by create_final_chain().

        Returns:
            dict: A dictionary containing the generated answer and the retrieved documents.
        """
        # Prepare the input for the RAG chain
        inputs = {"question": question}

        # Invoke the RAG chain to get an answer
        result = chain.invoke(inputs)

        # Save the current question and its answer to memory for future context
        self.memory.save_context(inputs, {"answer": result["answer"]})

        # Return the result
        return result


ml_pipeline = ModelPipeLine()
final_chain = ml_pipeline.create_final_chain()
question = "i am feeling sad"
res = ml_pipeline.call_conversational_rag(question, final_chain)
print(res['answer'])
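Because call_conversational_rag writes each question-answer pair back into ConversationBufferMemory, the condense-question step can resolve follow-up questions against earlier turns. A minimal sketch of a second turn, reusing the ml_pipeline and final_chain objects created above (the follow-up wording here is illustrative only):

# Follow-up turn: the condense-question prompt rewrites "it" into a
# standalone question using the chat history saved from the first turn.
follow_up = "what can i do about it?"
res2 = ml_pipeline.call_conversational_rag(follow_up, final_chain)
print(res2["answer"])
print(len(res2["docs"]), "documents retrieved for this turn")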