File size: 2,912 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import warnings
from typing import Any, Dict, List, Set

from langchain_core.memory import BaseMemory
from langchain_core.pydantic_v1 import validator

from langchain.memory.chat_memory import BaseChatMemory


class CombinedMemory(BaseMemory):
    """Combining multiple memories' data together.

    Each sub-memory in ``memories`` contributes its own memory variables.
    Loading, saving and clearing are fanned out to every sub-memory in order.
    Two sub-memories may not provide the same memory variable.
    """

    memories: List[BaseMemory]
    """For tracking all the memories that should be accessed."""

    @validator("memories")
    def check_repeated_memory_variable(
        cls, value: List[BaseMemory]
    ) -> List[BaseMemory]:
        """Reject sub-memories whose declared memory variables overlap.

        Args:
            value: The candidate list of sub-memories.

        Returns:
            ``value`` unchanged when no variable is declared twice.

        Raises:
            ValueError: If any memory variable is declared by more than one
                sub-memory.
        """
        all_variables: Set[str] = set()
        for val in value:
            # Read the property once; it may be computed on each access.
            variables = set(val.memory_variables)
            overlap = all_variables & variables
            if overlap:
                raise ValueError(
                    f"The same variables {overlap} are found in multiple "
                    "memory objects, which is not allowed by CombinedMemory."
                )
            all_variables |= variables

        return value

    @validator("memories")
    def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
        """Check that if memories are of type BaseChatMemory that input keys exist.

        This only emits a warning; the list is never rejected here.

        Args:
            value: The candidate list of sub-memories.

        Returns:
            ``value`` unchanged.
        """
        for val in value:
            if isinstance(val, BaseChatMemory) and val.input_key is None:
                warnings.warn(
                    "When using CombinedMemory, "
                    "input keys should be set so the input is known. "
                    f"Was not set on {val}"
                )
        return value

    @property
    def memory_variables(self) -> List[str]:
        """All the memory variables that this instance provides.

        Collected from all the linked memories, in the order they appear in
        ``memories``.
        """
        memory_variables: List[str] = []

        for memory in self.memories:
            memory_variables.extend(memory.memory_variables)

        return memory_variables

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Load all vars from sub-memories.

        Args:
            inputs: Chain inputs, passed unchanged to every sub-memory.

        Returns:
            The union of the dicts returned by each sub-memory.

        Raises:
            ValueError: If two sub-memories return the same key at load time.
                The validator checks the declared ``memory_variables``, but
                the keys actually returned could still differ from them.
        """
        memory_data: Dict[str, Any] = {}

        # Collect vars from all sub-memories
        for memory in self.memories:
            data = memory.load_memory_variables(inputs)
            for key, value in data.items():
                if key in memory_data:
                    raise ValueError(
                        f"The variable {key} is repeated in the CombinedMemory."
                    )
                memory_data[key] = value

        return memory_data

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this session for every memory.

        Args:
            inputs: Chain inputs, forwarded to each sub-memory.
            outputs: Chain outputs, forwarded to each sub-memory.
        """
        # Save context for all sub-memories
        for memory in self.memories:
            memory.save_context(inputs, outputs)

    def clear(self) -> None:
        """Clear context from this session for every memory."""
        for memory in self.memories:
            memory.clear()