import os
from operator import itemgetter

from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.schema import format_document
from langchain_core.messages import get_buffer_string
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

from generator import load_llm
from retriever import create_vectorstore, rag_retriever

class ModelPipeLine:
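    """Conversational RAG pipeline: wires the document retriever, the LLM and
    chat-history memory into a single LangChain runnable chain."""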

    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

    def __init__(self):
        self.curr_dir = os.path.dirname(__file__)
        # Resolve the prompts directory relative to this file (assumes prompts/
        # sits next to this module) so lookups do not depend on the working directory.
        self.prompt_dir = os.path.join(self.curr_dir, 'prompts')
        self.vectorstore, self.store = create_vectorstore()
        self.retriever = rag_retriever(self.vectorstore)  # Create the retriever
        self.llm = load_llm()  # Load the LLM model
        # Conversation memory keyed to the chain's input/output fields
        self.memory = ConversationBufferMemory(return_messages=True,
                                               output_key="answer",
                                               input_key="question")
    
    def get_prompts(self, system_file_path='system_prompt_template.txt', 
                    condense_file_path='condense_question_prompt_template.txt'):
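        """Load the two prompt templates from the prompts directory.

        Returns the answer prompt (expected to reference {context} and
        {question}) and the condense-question prompt (expected to reference
        {chat_history} and {question}), as inferred from how they are used in
        create_final_chain; the template text itself lives on disk and is not
        shown here.
        """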
        
        with open(os.path.join(self.prompt_dir, system_file_path), 'r') as f:
            system_prompt_template = f.read()

        with open(os.path.join(self.prompt_dir, condense_file_path), 'r') as f:
            condense_question_prompt = f.read()  

        # create the answer prompt from the system template
        ANSWER_PROMPT = ChatPromptTemplate.from_template(system_prompt_template)

        # create the prompt that condenses the chat history and follow-up into a standalone question
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(condense_question_prompt)

        return ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT
    

    def _combine_documents(self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
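        """Format each retrieved document and join them into a single context string."""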
        
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)
       
    def create_final_chain(self):
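        """Assemble the end-to-end conversational RAG chain.

        Data flow, as wired below: {"question"} -> load chat history from
        memory -> condense question + history into a standalone question ->
        retrieve documents for it -> build the final prompt from the combined
        context and the question -> return {"answer": ..., "docs": ...}.
        """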

        answer_prompt, condense_question_prompt = self.get_prompts()
        # This adds a "memory" key to the input object
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
        )
        # Now we calculate the standalone question
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_question_prompt
            | self.llm,
        }
        # Now we retrieve the documents
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | self.retriever,
            "question": lambda x: x["standalone_question"],
        }
        # Now we construct the inputs for the final prompt
        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }
        # And finally, we do the part that returns the answers
        answer = {
            "answer": final_inputs | answer_prompt | self.llm,
            "docs": itemgetter("docs"),
        }
        # And now we put it all together!
        final_chain = loaded_memory | standalone_question | retrieved_documents | answer
        return final_chain

    def call_conversational_rag(self, question, chain):
        """
        Calls a conversational RAG (Retrieval-Augmented Generation) model to generate an answer to a given question.

        This function sends a question to the RAG model, retrieves the answer, and stores the question-answer pair in memory 
        for context in future interactions.

        Parameters:
        question (str): The question to be answered by the RAG model.
        chain: The runnable chain built by create_final_chain, which encapsulates the retrieval and generation steps. Conversation context is stored on self.memory rather than passed in.

        Returns:
        dict: A dictionary containing the generated answer from the RAG model.
        """
        
        # Prepare the input for the RAG model
        inputs = {"question": question}

        # Invoke the RAG model to get an answer
        result = chain.invoke(inputs)
        
        # Save the current question and its answer to memory for future context
        self.memory.save_context(inputs, {"answer": result["answer"]})
        
        # Return the result
        return result


if __name__ == "__main__":
    # Quick smoke test: build the pipeline and run a single question through it.
    ml_pipeline = ModelPipeLine()
    final_chain = ml_pipeline.create_final_chain()
    question = "i am feeling sad"
    res = ml_pipeline.call_conversational_rag(question, final_chain)
    print(res['answer'])
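
    # A second, illustrative turn (the follow-up text here is hypothetical and
    # not part of the original script): because call_conversational_rag saved
    # the first question/answer pair into self.memory, the condense-question
    # step can rewrite this follow-up into a standalone question before retrieval.
    follow_up = "what can i do about it?"
    res = ml_pipeline.call_conversational_rag(follow_up, final_chain)
    print(res['answer'])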