from typing import Any, Dict, List, Type, Union

from pydantic import BaseModel, Field

from langchain.chains.llm import LLMChain
from langchain.graphs import NetworkxEntityGraph
from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
    ENTITY_EXTRACTION_PROMPT,
    KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import (
    BaseLanguageModel,
    BaseMessage,
    SystemMessage,
    get_buffer_string,
)


class ConversationKGMemory(BaseChatMemory, BaseModel):
    """Knowledge graph memory for storing conversation memory.

    Integrates with external knowledge graph to store and retrieve
    information about knowledge triples in the conversation.
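
    Example:
        A minimal usage sketch; the ``OpenAI`` import path below is an
        assumption, and any ``BaseLanguageModel`` works in its place. The
        triples actually stored depend on what the LLM extracts.

        .. code-block:: python

            from langchain.llms import OpenAI
            from langchain.memory import ConversationKGMemory

            llm = OpenAI(temperature=0)
            memory = ConversationKGMemory(llm=llm)
            memory.save_context({"input": "Sam is my friend"}, {"output": "Nice!"})
            memory.load_memory_variables({"input": "Who is Sam?"})
            # e.g. {"history": "On Sam: Sam is a friend."} -- varies with extraction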
    """

    k: int = 2
    """Number of previous utterances to include in the context."""
    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
    knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
    llm: BaseLanguageModel
    summary_message_cls: Type[BaseMessage] = SystemMessage
    """Message class to wrap entity knowledge summaries in when
    ``return_messages`` is True."""
    memory_key: str = "history"  #: :meta private:

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        entities = self._get_current_entities(inputs)
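        # Summarize what the graph already knows about each entity mentioned
        # in the new input.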
        summaries = {}
        for entity in entities:
            knowledge = self.kg.get_entity_knowledge(entity)
            if knowledge:
                summaries[entity] = ". ".join(knowledge) + "."
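        # ``return_messages`` controls whether the context comes back as a list
        # of message objects or as a single newline-joined string.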
        context: Union[str, List]
        if summaries:
            summary_strings = [
                f"On {entity}: {summary}" for entity, summary in summaries.items()
            ]
            if self.return_messages:
                context = [
                    self.summary_message_cls(content=text) for text in summary_strings
                ]
            else:
                context = "\n".join(summary_strings)
        else:
            if self.return_messages:
                context = []
            else:
                context = ""
        return {self.memory_key: context}

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Get the input key for the prompt."""
        if self.input_key is None:
            return get_prompt_input_key(inputs, self.memory_variables)
        return self.input_key

    def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
        """Get the output key for the prompt."""
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            return list(outputs.keys())[0]
        return self.output_key

    def get_current_entities(self, input_string: str) -> List[str]:
        """Extract entities mentioned in the recent history plus the new input."""
        chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
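        # A "turn" is one human plus one AI message, so the last k turns span
        # the last 2 * k messages.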
        buffer_string = get_buffer_string(
            self.chat_memory.messages[-self.k * 2 :],
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        output = chain.predict(
            history=buffer_string,
            input=input_string,
        )
        return get_entities(output)

    def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
        """Get the current entities in the conversation."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        return self.get_current_entities(inputs[prompt_input_key])

    def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
        """Extract knowledge triples from the recent history and the new input."""
        chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
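        # Same 2 * k-message window as entity extraction, but run through the
        # knowledge-triple prompt.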
        buffer_string = get_buffer_string(
            self.chat_memory.messages[-self.k * 2 :],
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        output = chain.predict(
            history=buffer_string,
            input=input_string,
        )
        knowledge = parse_triples(output)
        return knowledge

    def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
        """Get and update knowledge graph from the conversation history."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
        for triple in knowledge:
            self.kg.add_triple(triple)

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        super().save_context(inputs, outputs)
        self._get_and_update_kg(inputs)

    def clear(self) -> None:
        """Clear memory contents."""
        super().clear()
        self.kg.clear()