Spaces:
Runtime error
Runtime error
"""Class for a VectorStore-backed memory object.""" | |
from typing import Any, Dict, List, Optional, Sequence, Union | |
from langchain_core.documents import Document | |
from langchain_core.pydantic_v1 import Field | |
from langchain_core.vectorstores import VectorStoreRetriever | |
from langchain.memory.chat_memory import BaseMemory | |
from langchain.memory.utils import get_prompt_input_key | |
class VectorStoreRetrieverMemory(BaseMemory): | |
"""VectorStoreRetriever-backed memory.""" | |
retriever: VectorStoreRetriever = Field(exclude=True) | |
"""VectorStoreRetriever object to connect to.""" | |
memory_key: str = "history" #: :meta private: | |
"""Key name to locate the memories in the result of load_memory_variables.""" | |
input_key: Optional[str] = None | |
"""Key name to index the inputs to load_memory_variables.""" | |
return_docs: bool = False | |
"""Whether or not to return the result of querying the database directly.""" | |
exclude_input_keys: Sequence[str] = Field(default_factory=tuple) | |
"""Input keys to exclude in addition to memory key when constructing the document""" | |
def memory_variables(self) -> List[str]: | |
"""The list of keys emitted from the load_memory_variables method.""" | |
return [self.memory_key] | |
def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str: | |
"""Get the input key for the prompt.""" | |
if self.input_key is None: | |
return get_prompt_input_key(inputs, self.memory_variables) | |
return self.input_key | |
def load_memory_variables( | |
self, inputs: Dict[str, Any] | |
) -> Dict[str, Union[List[Document], str]]: | |
"""Return history buffer.""" | |
input_key = self._get_prompt_input_key(inputs) | |
query = inputs[input_key] | |
docs = self.retriever.get_relevant_documents(query) | |
result: Union[List[Document], str] | |
if not self.return_docs: | |
result = "\n".join([doc.page_content for doc in docs]) | |
else: | |
result = docs | |
return {self.memory_key: result} | |
def _form_documents( | |
self, inputs: Dict[str, Any], outputs: Dict[str, str] | |
) -> List[Document]: | |
"""Format context from this conversation to buffer.""" | |
# Each document should only include the current turn, not the chat history | |
exclude = set(self.exclude_input_keys) | |
exclude.add(self.memory_key) | |
filtered_inputs = {k: v for k, v in inputs.items() if k not in exclude} | |
texts = [ | |
f"{k}: {v}" | |
for k, v in list(filtered_inputs.items()) + list(outputs.items()) | |
] | |
page_content = "\n".join(texts) | |
return [Document(page_content=page_content)] | |
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: | |
"""Save context from this conversation to buffer.""" | |
documents = self._form_documents(inputs, outputs) | |
self.retriever.add_documents(documents) | |
def clear(self) -> None: | |
"""Nothing to clear.""" | |