|
import warnings |
|
from typing import Any, Dict, List, Set |
|
|
|
from langchain_core.memory import BaseMemory |
|
from pydantic import field_validator |
|
|
|
from langchain.memory.chat_memory import BaseChatMemory |
|
|
|
|
|
class CombinedMemory(BaseMemory):
    """Combining multiple memories' data together."""

    memories: List[BaseMemory]
    """For tracking all the memories that should be accessed."""

    @field_validator("memories")
    @classmethod
    def check_repeated_memory_variable(
        cls, value: List[BaseMemory]
    ) -> List[BaseMemory]:
        """Check that no memory variable is provided by more than one memory.

        Args:
            value: The sub-memories being assigned to ``memories``.

        Returns:
            The same list, unchanged, when no variable name is repeated.

        Raises:
            ValueError: If a memory declares a variable name that an earlier
                memory in the list already declared.
        """
        all_variables: Set[str] = set()
        for val in value:
            overlap = all_variables.intersection(val.memory_variables)
            if overlap:
                raise ValueError(
                    f"The same variables {overlap} are found in multiple "
                    "memory object, which is not allowed by CombinedMemory."
                )
            all_variables |= set(val.memory_variables)

        return value

    @field_validator("memories")
    @classmethod
    def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
        """Check that if memories are of type BaseChatMemory that input keys exist.

        A missing input key only triggers a warning; validation still succeeds.

        Args:
            value: The sub-memories being assigned to ``memories``.

        Returns:
            The same list, unchanged.
        """
        for val in value:
            if isinstance(val, BaseChatMemory) and val.input_key is None:
                warnings.warn(
                    "When using CombinedMemory, "
                    "input keys should be set so the input is known. "
                    f"Was not set on {val}"
                )
        return value

    @property
    def memory_variables(self) -> List[str]:
        """All the memory variables that this instance provides.

        Collected from all the linked memories, in the order the memories
        are listed.
        """
        memory_variables: List[str] = []

        for memory in self.memories:
            memory_variables.extend(memory.memory_variables)

        return memory_variables

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Load all vars from sub-memories.

        Args:
            inputs: Passed unchanged to every sub-memory's
                ``load_memory_variables``.

        Returns:
            A single dict merging the variables loaded from each sub-memory.

        Raises:
            ValueError: If two sub-memories return the same key.
        """
        memory_data: Dict[str, Any] = {}

        for memory in self.memories:
            data = memory.load_memory_variables(inputs)
            for key, value in data.items():
                # Checked at load time as well: the keys a memory actually
                # returns are only known once it has been queried.
                if key in memory_data:
                    raise ValueError(
                        f"The variable {key} is repeated in the CombinedMemory."
                    )
                memory_data[key] = value

        return memory_data

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this session for every memory.

        Args:
            inputs: Forwarded unchanged to each sub-memory.
            outputs: Forwarded unchanged to each sub-memory.
        """
        for memory in self.memories:
            memory.save_context(inputs, outputs)

    def clear(self) -> None:
        """Clear context from this session for every memory."""
        for memory in self.memories:
            memory.clear()
|
|