File size: 5,631 Bytes
111d456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb3b6fe
c795b61
111d456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Implement Classification

import os
from operator import itemgetter

from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import format_document
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter

from generator import load_llm
from retrieverV2 import process_pdf_document, create_vectorstore, rag_retriever

class ModelPipeLine:
    """Conversational RAG pipeline over a set of PDF documents.

    Construction loads and splits the PDFs, builds the vector store and the
    retriever (via the project's ``retrieverV2`` helpers), loads the LLM (via
    ``generator.load_llm``) and creates a ``ConversationBufferMemory`` that
    carries the chat history from one ``call_conversational_rag`` call to the
    next.
    """

    # Renders each retrieved document as just its page content.
    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

    def __init__(self, pdf_files=("depression_1.pdf", "depression_2.pdf"),
                 child_chunk_size=200, parent_chunk_size=500):
        """Build every component of the pipeline.

        Parameters:
        pdf_files (Iterable[str]): PDF file names, resolved relative to
            ``knowledge_dir``. Defaults to the two original depression PDFs.
        child_chunk_size (int): ``chunk_size`` of the child text splitter.
        parent_chunk_size (int): ``chunk_size`` of the parent text splitter.
        """
        # abspath() keeps the directory arithmetic below correct when __file__
        # is a relative path (os.path.dirname("x.py") would be "").
        self.curr_dir = os.path.dirname(os.path.abspath(__file__))
        # Two directories above curr_dir: PDFs live directly here, prompts in
        # its "prompts" subdirectory.
        project_root = os.path.dirname(os.path.dirname(self.curr_dir))
        self.knowledge_dir = project_root
        self.prompt_dir = os.path.join(project_root, 'prompts')
        self.child_splitter = RecursiveCharacterTextSplitter(chunk_size=child_chunk_size)
        self.parent_splitter = RecursiveCharacterTextSplitter(chunk_size=parent_chunk_size)
        self.documents = process_pdf_document(
            [os.path.join(self.knowledge_dir, name) for name in pdf_files]
        )
        self.vectorstore, self.store = create_vectorstore()
        # Create the retriever from the store, the documents and both splitters.
        self.retriever = rag_retriever(self.vectorstore, self.store, self.documents,
                                       self.parent_splitter, self.child_splitter)
        self.llm = load_llm()  # Load the LLM model
        # input/output keys match the dicts passed to save_context in
        # call_conversational_rag.
        self.memory = ConversationBufferMemory(return_messages=True,
                                               output_key="answer",
                                               input_key="question")

    def get_prompts(self, system_file_path='system_prompt_template.txt',
                    condense_file_path='condense_question_prompt_template.txt'):
        """Load the answer and condense-question prompt templates.

        Parameters:
        system_file_path (str): File name, relative to ``prompt_dir``, of the
            answer prompt. It is formatted with ``context`` and ``question``.
        condense_file_path (str): File name, relative to ``prompt_dir``, of the
            condense-question prompt. It is formatted with ``question`` and
            ``chat_history``.

        Returns:
        tuple: ``(answer_prompt, condense_question_prompt)``, i.e. a
        ``ChatPromptTemplate`` and a ``PromptTemplate``.
        """
        # Explicit encoding so template decoding does not depend on the platform locale.
        with open(os.path.join(self.prompt_dir, system_file_path), 'r', encoding='utf-8') as f:
            system_prompt_template = f.read()

        with open(os.path.join(self.prompt_dir, condense_file_path), 'r', encoding='utf-8') as f:
            condense_question_template = f.read()

        answer_prompt = ChatPromptTemplate.from_template(system_prompt_template)
        condense_question_prompt = PromptTemplate.from_template(condense_question_template)

        return answer_prompt, condense_question_prompt

    def _combine_documents(self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
        """Format each document with ``document_prompt`` and join them with ``document_separator``."""
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    def create_final_chain(self):
        """Assemble the conversational RAG chain.

        The stages run in this order:
        1. Add ``chat_history`` to the input from ``self.memory``.
        2. Have the condense prompt plus the LLM rewrite ``question`` and
           ``chat_history`` into a stand-alone question string.
        3. Retrieve documents for that stand-alone question.
        4. Have the answer prompt plus the LLM answer from the joined documents.

        The chain takes ``{"question": str}`` as input. It returns a runnable
        producing ``{"answer": <LLM output>, "docs": <retrieved documents>}``.
        """
        answer_prompt, condense_question_prompt = self.get_prompts()

        # Adds a "chat_history" key (the buffered messages) to the input dict.
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
        )

        # Compute the stand-alone question. StrOutputParser yields a plain str
        # whether the LLM returns a string or a chat message. That str is what
        # the retriever and the answer prompt's {question} need.
        standalone_question = {
            "standalone_question": {
                "question": itemgetter("question"),
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_question_prompt
            | self.llm
            | StrOutputParser(),
        }

        # Retrieve documents for the stand-alone question.
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | self.retriever,
            "question": itemgetter("standalone_question"),
        }

        # Inputs for the answer prompt.
        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }

        # Produce the answer and pass the retrieved docs through alongside it.
        answer = {
            "answer": final_inputs | answer_prompt | self.llm,
            "docs": itemgetter("docs"),
        }

        return loaded_memory | standalone_question | retrieved_documents | answer

    def call_conversational_rag(self, question, chain):
        """Ask ``chain`` a question and record the exchange in ``self.memory``.

        Parameters:
        question (str): The user's question.
        chain (Runnable): A chain built by ``create_final_chain``.

        Returns:
        dict: The chain output, with an ``"answer"`` key (the LLM output) and
        a ``"docs"`` key (the retrieved documents).
        """
        inputs = {"question": question}
        result = chain.invoke(inputs)

        answer = result["answer"]
        # A chat model returns a message object, and memory should store its
        # text. A plain-LLM string has no ``.content`` and is stored unchanged.
        answer_text = getattr(answer, "content", answer)
        self.memory.save_context(inputs, {"answer": answer_text})

        return result


# Demo run: build the pipeline and its chain, then ask a single question and
# print the model's answer.
question = "i am feeling sad"
ml_pipeline = ModelPipeLine()
final_chain = ml_pipeline.create_final_chain()
res = ml_pipeline.call_conversational_rag(question=question, chain=final_chain)
print(res["answer"])