Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import warnings | |
| from typing import Any, Dict, List, Set | |
| from langchain_core.memory import BaseMemory | |
| from langchain_core.pydantic_v1 import validator | |
| from langchain.memory.chat_memory import BaseChatMemory | |
class CombinedMemory(BaseMemory):
    """Combining multiple memories' data together.

    Each sub-memory in ``memories`` contributes its own memory variables, and
    the variable names must be disjoint across sub-memories (enforced by a
    validator at construction time and re-checked when loading). Loading
    merges every sub-memory's variables into one dict; saving and clearing
    are forwarded to every sub-memory in list order.
    """

    memories: List[BaseMemory]
    """For tracking all the memories that should be accessed."""

    @validator("memories")
    def check_repeated_memory_variable(
        cls, value: List[BaseMemory]
    ) -> List[BaseMemory]:
        """Reject sub-memories whose declared memory variables overlap.

        Args:
            value: The candidate list of sub-memories.

        Returns:
            ``value`` unchanged when every variable name is declared by at
            most one sub-memory.

        Raises:
            ValueError: If a variable name is declared by more than one
                sub-memory.
        """
        all_variables: Set[str] = set()
        for val in value:
            # Read the property once per sub-memory instead of twice.
            variables = set(val.memory_variables)
            overlap = all_variables & variables
            if overlap:
                raise ValueError(
                    f"The same variables {overlap} are found in multiple "
                    "memory objects, which is not allowed by CombinedMemory."
                )
            all_variables |= variables
        return value

    @validator("memories")
    def check_input_key(cls, value: List[BaseMemory]) -> List[BaseMemory]:
        """Check that if memories are of type BaseChatMemory that input keys exist.

        Emits a ``UserWarning`` (via ``warnings.warn``) for every
        ``BaseChatMemory`` sub-memory whose ``input_key`` is ``None``; it does
        not raise.

        Args:
            value: The candidate list of sub-memories.

        Returns:
            ``value`` unchanged.
        """
        for val in value:
            if isinstance(val, BaseChatMemory) and val.input_key is None:
                warnings.warn(
                    "When using CombinedMemory, "
                    "input keys should be set so the input is known. "
                    f"Was not set on {val}"
                )
        return value

    @property
    def memory_variables(self) -> List[str]:
        """All the memory variables that this instance provides.

        Collected, in order, from all the linked sub-memories.
        """
        return [
            variable
            for memory in self.memories
            for variable in memory.memory_variables
        ]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Load all vars from sub-memories and merge them into one dict.

        Args:
            inputs: Passed unchanged to each sub-memory's
                ``load_memory_variables``.

        Returns:
            A dict containing every key/value returned by every sub-memory.

        Raises:
            ValueError: If the same key is returned by more than one
                sub-memory.
        """
        memory_data: Dict[str, Any] = {}
        for memory in self.memories:
            data = memory.load_memory_variables(inputs)
            for key, value in data.items():
                # The constructor validator only compares the *declared*
                # memory_variables; also guard the keys actually returned.
                if key in memory_data:
                    raise ValueError(
                        f"The variable {key} is repeated in the CombinedMemory."
                    )
                memory_data[key] = value
        return memory_data

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this session for every memory.

        Args:
            inputs: Forwarded unchanged to each sub-memory's ``save_context``.
            outputs: Forwarded unchanged to each sub-memory's ``save_context``.
        """
        for memory in self.memories:
            memory.save_context(inputs, outputs)

    def clear(self) -> None:
        """Clear context from this session for every memory."""
        for memory in self.memories:
            memory.clear()