File size: 4,304 Bytes
af9251e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from langchain.base_language import BaseLanguageModel
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.chains import LLMChain, RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

from loader import DialogueLoader
from chains.dialogue_answering.prompts import (
    DIALOGUE_PREFIX,
    DIALOGUE_SUFFIX,
    SUMMARY_PROMPT
)


class DialogueWithSharedMemoryChains:
    """ReAct agent that answers questions about a dialogue loaded from disk.

    The agent (``agent_chain``) is a ``ZeroShotAgent`` wrapped in an
    ``AgentExecutor`` with two tools:

    * "State of Dialogue History System" -- a ``RetrievalQA`` chain over a
      Chroma vector store built from the dialogue documents.
    * "Summary" -- an ``LLMChain`` using ``SUMMARY_PROMPT`` that sees the
      agent's chat history through a ``ReadOnlySharedMemory`` wrapper.

    The executor and the summary tool share a single
    ``ConversationBufferMemory`` stored under the ``"chat_history"`` key.
    """

    zero_shot_react_llm: BaseLanguageModel = None  # LLM behind the ZeroShotAgent (tool selection / reasoning)
    ask_llm: BaseLanguageModel = None  # LLM behind the retrieval-QA and summary chains
    embeddings: HuggingFaceEmbeddings = None
    embedding_model: str = None
    vector_search_top_k: int = 6  # number of chunks the retriever returns per query
    chunk_size: int = 3  # CharacterTextSplitter chunk_size
    chunk_overlap: int = 1  # CharacterTextSplitter chunk_overlap
    dialogue_path: str = None
    dialogue_loader: DialogueLoader = None
    device: str = None
    state_of_history: RetrievalQA = None
    memory_chain: LLMChain = None
    memory: ConversationBufferMemory = None
    agent_chain: AgentExecutor = None

    def __init__(self, zero_shot_react_llm: BaseLanguageModel = None, ask_llm: BaseLanguageModel = None,
                 params: dict = None):
        """Load the dialogue, build the vector index, and assemble the agent.

        Args:
            zero_shot_react_llm: LLM used by the ZeroShotAgent.
            ask_llm: LLM used by the retrieval-QA and summary chains.
            params: Optional settings. Recognised keys:
                ``embedding_model`` (default ``'GanymedeNil/text2vec-large-chinese'``),
                ``vector_search_top_k`` (default ``6``),
                ``chunk_size`` (default ``3``),
                ``chunk_overlap`` (default ``1``),
                ``dialogue_path`` (default ``''``),
                ``use_cuda`` (default ``False``; selects ``'cuda'`` vs ``'cpu'``).
        """
        self.zero_shot_react_llm = zero_shot_react_llm
        self.ask_llm = ask_llm
        params = params or {}
        self.embedding_model = params.get('embedding_model', 'GanymedeNil/text2vec-large-chinese')
        self.vector_search_top_k = params.get('vector_search_top_k', 6)
        self.chunk_size = params.get('chunk_size', 3)
        self.chunk_overlap = params.get('chunk_overlap', 1)
        self.dialogue_path = params.get('dialogue_path', '')
        self.device = 'cuda' if params.get('use_cuda', False) else 'cpu'

        # Order matters: embeddings must exist before the index is built, and
        # the memory/summary chain must exist before the agent's tools reference it.
        self.dialogue_loader = DialogueLoader(self.dialogue_path)
        self._init_cfg()
        self._init_state_of_history()
        self.memory_chain, self.memory = self._agents_answer()
        self.agent_chain = self._create_agent_chain()

    def _init_cfg(self):
        """Create the HuggingFace embedding model on the configured device."""
        model_kwargs = {
            'device': self.device
        }
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model, model_kwargs=model_kwargs)

    def _init_state_of_history(self):
        """Index the dialogue in Chroma and build the retrieval-QA chain over it."""
        documents = self.dialogue_loader.load()
        # NOTE(review): the very small default chunk_size presumably makes the
        # splitter emit each separator-delimited piece (e.g. one dialogue turn)
        # as its own chunk rather than merging pieces -- confirm against
        # DialogueLoader's output format.
        text_splitter = CharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        texts = text_splitter.split_documents(documents)
        docsearch = Chroma.from_documents(texts, self.embeddings, collection_name="state-of-history")
        # Pass the configured top-k explicitly; without search_kwargs the
        # retriever uses LangChain's default k and vector_search_top_k is ignored.
        retriever = docsearch.as_retriever(search_kwargs={"k": self.vector_search_top_k})
        self.state_of_history = RetrievalQA.from_chain_type(llm=self.ask_llm, chain_type="stuff",
                                                            retriever=retriever)

    def _agents_answer(self):
        """Build the shared chat memory and the summary chain that reads it.

        Returns:
            A ``(memory_chain, memory)`` tuple: the summary ``LLMChain`` and the
            writable ``ConversationBufferMemory`` intended for the agent executor.
        """
        memory = ConversationBufferMemory(memory_key="chat_history")
        readonly_memory = ReadOnlySharedMemory(memory=memory)
        memory_chain = LLMChain(
            llm=self.ask_llm,
            prompt=SUMMARY_PROMPT,
            verbose=True,
            memory=readonly_memory,  # use the read-only memory to prevent the tool from modifying the memory
        )
        return memory_chain, memory

    def _create_agent_chain(self):
        """Assemble the ZeroShotAgent with its two tools and wrap it in an AgentExecutor."""
        dialogue_participants = self.dialogue_loader.dialogue.participants_to_export()
        tools = [
            Tool(
                name="State of Dialogue History System",
                func=self.state_of_history.run,
                description=f"Dialogue with {dialogue_participants} - The answers in this section are very useful "
                            f"when searching for chat content between {dialogue_participants}. Input should be a "
                            f"complete question. "
            ),
            Tool(
                name="Summary",
                func=self.memory_chain.run,
                description="useful for when you summarize a conversation. The input to this tool should be a string, "
                            "representing who will read this summary. "
            )
        ]

        # "chat_history" must match the memory_key used in _agents_answer so the
        # executor's memory fills that prompt variable.
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=DIALOGUE_PREFIX,
            suffix=DIALOGUE_SUFFIX,
            input_variables=["input", "chat_history", "agent_scratchpad"]
        )

        llm_chain = LLMChain(llm=self.zero_shot_react_llm, prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        # The executor owns the writable memory; the Summary tool only reads it.
        agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=self.memory)

        return agent_chain