File size: 1,007 Bytes
4962437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import Any, Dict, List

from swarms.memory.base_memory import BaseChatMemory, get_prompt_input_key
from swarms.memory.base import VectorStoreRetriever

class AgentMemory(BaseChatMemory):
    """Chat memory that pairs recent chat history with retrieved context.

    Each call to ``load_memory_variables`` uses the current prompt input as
    the query for ``retriever`` and returns the retrieval result together
    with the trailing messages of ``self.chat_memory``.
    """

    # Retriever queried with the current prompt input on every load.
    retriever: VectorStoreRetriever

    # Maximum number of trailing chat messages returned as "chat_history".
    # Defaults to 10, the window that was previously hard-coded.
    history_window: int = 10

    @property
    def memory_variables(self) -> List[str]:
        """Return the keys that ``load_memory_variables`` produces."""
        return ["chat_history", "relevant_context"]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Get the input key for the prompt.

        Returns ``self.input_key`` when it is set; otherwise delegates to
        ``get_prompt_input_key`` with ``inputs`` and this memory's variable
        names.
        """
        if self.input_key is not None:
            return self.input_key
        return get_prompt_input_key(inputs, self.memory_variables)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the memory variables for the given prompt inputs.

        Args:
            inputs: Prompt inputs; the value under the resolved input key is
                passed to the retriever as the query.

        Returns:
            A dict with ``"chat_history"`` (at most ``history_window``
            trailing messages from ``self.chat_memory.messages``) and
            ``"relevant_context"`` (the result of
            ``self.retriever.get_relevant_documents``).

        Raises:
            KeyError: If the resolved input key is not present in ``inputs``.
        """
        query = inputs[self._get_prompt_input_key(inputs)]
        relevant_docs = self.retriever.get_relevant_documents(query)

        window = self.history_window
        messages = self.chat_memory.messages
        # A bare ``messages[-window:]`` with window == 0 would be
        # ``messages[-0:]``, i.e. the entire history, so a non-positive
        # window is mapped to an empty history explicitly.
        recent_messages = messages[-window:] if window > 0 else []

        return {
            "chat_history": recent_messages,
            "relevant_context": relevant_docs,
        }