Spaces:
Runtime error
Runtime error
"""Cassandra-based chat message history, based on cassIO.""" | |
from __future__ import annotations | |
import json | |
import typing | |
from typing import List | |
if typing.TYPE_CHECKING: | |
from cassandra.cluster import Session | |
from langchain_core.chat_history import BaseChatMessageHistory | |
from langchain_core.messages import ( | |
BaseMessage, | |
message_to_dict, | |
messages_from_dict, | |
) | |
# Table used when the caller does not pass ``table_name``.
DEFAULT_TABLE_NAME: str = "message_store"
# ``None`` means stored entries never expire (see ``ttl_seconds`` below).
DEFAULT_TTL_SECONDS: typing.Optional[int] = None
class CassandraChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Cassandra.

    Each message is converted with ``message_to_dict``, serialized to JSON
    and persisted through cassIO's ``StoredBlobHistory`` under ``session_id``.

    Args:
        session_id: arbitrary key that is used to store the messages
            of a single chat session.
        session: a Cassandra `Session` object (an open DB connection)
        keyspace: name of the keyspace to use.
        table_name: name of the table to use.
        ttl_seconds: time-to-live (seconds) for automatic expiration
            of stored entries. None (default) for no expiration.

    Raises:
        ImportError: if the ``cassio`` package cannot be imported.
    """

    def __init__(
        self,
        session_id: str,
        session: Session,
        keyspace: str,
        table_name: str = DEFAULT_TABLE_NAME,
        ttl_seconds: typing.Optional[int] = DEFAULT_TTL_SECONDS,
    ) -> None:
        try:
            from cassio.history import StoredBlobHistory
        except ImportError as err:
            # ModuleNotFoundError is a subclass of ImportError, so this one
            # clause covers both; chain the cause so the real failure is kept.
            raise ImportError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            ) from err

        self.session_id = session_id
        self.ttl_seconds = ttl_seconds
        self.blob_history = StoredBlobHistory(session, keyspace, table_name)

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve all session messages from DB.

        Exposed as a property so that ``history.messages`` yields the list of
        messages itself (as ``BaseChatMessageHistory`` expects); without the
        decorator, attribute access would return a bound method instead.
        The ``type: ignore`` silences the property-over-attribute override.
        """
        message_blobs = self.blob_history.retrieve(self.session_id)
        # Each stored blob is the JSON text written by ``add_message``.
        items = [json.loads(message_blob) for message_blob in message_blobs]
        return messages_from_dict(items)

    def add_message(self, message: BaseMessage) -> None:
        """Write a message to the table.

        The message is serialized with ``message_to_dict`` + ``json.dumps``
        and stored under this session, forwarding ``self.ttl_seconds``.
        """
        self.blob_history.store(
            self.session_id, json.dumps(message_to_dict(message)), self.ttl_seconds
        )

    def clear(self) -> None:
        """Clear session memory from DB for this ``session_id``."""
        self.blob_history.clear_session_id(self.session_id)