Spaces:
Runtime error
Runtime error
import warnings | |
from typing import Any, Dict, List, Set | |
from langchain_core.memory import BaseMemory | |
from langchain_core.pydantic_v1 import validator | |
from langchain.memory.chat_memory import BaseChatMemory | |
class CombinedMemory(BaseMemory):
    """Combining multiple memories' data together.

    Every operation (load, save, clear) is fanned out to each memory in
    ``memories``. The memory variables exposed by the sub-memories must not
    overlap, so that the merged result is unambiguous.
    """

    memories: List[BaseMemory]
    """For tracking all the memories that should be accessed."""

    @validator("memories")
    def check_repeated_memory_variable(
        cls, value: List[BaseMemory]
    ) -> List[BaseMemory]:
        """Reject sub-memories that expose the same memory variable.

        Args:
            value: The candidate list of sub-memories.

        Returns:
            ``value`` unchanged when no variable name is shared.

        Raises:
            ValueError: If a variable name is provided by more than one
                sub-memory.
        """
        seen: Set[str] = set()
        for val in value:
            # Read the variables once per memory; reused for both the
            # overlap check and the bookkeeping below.
            current = set(val.memory_variables)
            overlap = seen & current
            if overlap:
                raise ValueError(
                    f"The same variables {overlap} are found in multiple "
                    "memory object, which is not allowed by CombinedMemory."
                )
            seen |= current
        return value

    @validator("memories")
    def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
        """Check that if memories are of type BaseChatMemory that input keys exist.

        A missing ``input_key`` only triggers a warning; it is not an error.

        Args:
            value: The candidate list of sub-memories.

        Returns:
            ``value`` unchanged.
        """
        for val in value:
            if isinstance(val, BaseChatMemory) and val.input_key is None:
                warnings.warn(
                    "When using CombinedMemory, "
                    "input keys should be set so the input is known. "
                    f" Was not set on {val}"
                )
        return value

    @property
    def memory_variables(self) -> List[str]:
        """All the memory variables that this instance provides.

        Collected from all the linked memories, in the order the memories
        are listed.
        """
        return [
            variable
            for memory in self.memories
            for variable in memory.memory_variables
        ]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load all vars from sub-memories.

        Args:
            inputs: Passed unchanged to each sub-memory's
                ``load_memory_variables``.

        Returns:
            A single dict merging the variables returned by every sub-memory.

        Raises:
            ValueError: If two sub-memories return the same key.
        """
        memory_data: Dict[str, Any] = {}

        # Collect vars from all sub-memories
        for memory in self.memories:
            data = memory.load_memory_variables(inputs)
            for key, value in data.items():
                # Guard at load time as well: refuse to silently overwrite
                # a key that an earlier sub-memory already produced.
                if key in memory_data:
                    raise ValueError(
                        f"The variable {key} is repeated in the CombinedMemory."
                    )
                memory_data[key] = value

        return memory_data

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this session for every memory.

        Args:
            inputs: Passed unchanged to each sub-memory's ``save_context``.
            outputs: Passed unchanged to each sub-memory's ``save_context``.
        """
        # Save context for all sub-memories
        for memory in self.memories:
            memory.save_context(inputs, outputs)

    def clear(self) -> None:
        """Clear context from this session for every memory."""
        for memory in self.memories:
            memory.clear()