Spaces:
Runtime error
Runtime error
File size: 6,510 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
"""Azure CosmosDB Memory History."""
from __future__ import annotations
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Any, List, Optional, Type
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
messages_from_dict,
messages_to_dict,
)
# Module-level logger, named after this module so host applications can tune it.
logger: logging.Logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Needed only for type annotations; the class imports azure-cosmos lazily at
    # runtime so the package stays an optional dependency.
    from azure.cosmos import ContainerProxy
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history backed by Azure CosmosDB.

    One session is stored as a single document whose ``id`` is the session ID,
    partitioned on ``/user_id`` by the user ID. The ``messages`` list is the
    in-memory working copy. It is written back with :meth:`upsert_messages`,
    which runs on every :meth:`add_message` and when leaving the context manager.
    """

    # Shared message for every lazy ``azure-cosmos`` import in this class.
    _IMPORT_ERROR_MESSAGE = (
        "You must install the azure-cosmos package to use the "
        "CosmosDBChatMessageHistory. "
        "Please install it with `pip install azure-cosmos`."
    )

    def __init__(
        self,
        cosmos_endpoint: str,
        cosmos_database: str,
        cosmos_container: str,
        session_id: str,
        user_id: str,
        credential: Any = None,
        connection_string: Optional[str] = None,
        ttl: Optional[int] = None,
        cosmos_client_kwargs: Optional[dict] = None,
    ):
        """
        Initializes a new instance of the CosmosDBChatMessageHistory class.

        Make sure to call prepare_cosmos or use the context manager to make
        sure your database is ready.

        Either a credential or a connection string must be provided. If both
        are given, the credential is used.

        :param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account.
        :param cosmos_database: The name of the database to use.
        :param cosmos_container: The name of the container to use.
        :param session_id: The session ID to use, can be overwritten while loading.
        :param user_id: The user ID to use, can be overwritten while loading.
        :param credential: The credential to use to authenticate to Azure Cosmos DB.
        :param connection_string: The connection string to use to authenticate.
        :param ttl: The time to live (in seconds) to use for documents in the container.
        :param cosmos_client_kwargs: Additional kwargs to pass to the CosmosClient.
        :raises ImportError: if the azure-cosmos package is not installed.
        :raises ValueError: if neither a credential nor a connection string is set.
        """
        self.cosmos_endpoint = cosmos_endpoint
        self.cosmos_database = cosmos_database
        self.cosmos_container = cosmos_container
        self.credential = credential
        self.conn_string = connection_string
        self.session_id = session_id
        self.user_id = user_id
        self.ttl = ttl
        self.messages: List[BaseMessage] = []
        try:
            from azure.cosmos import (  # pylint: disable=import-outside-toplevel
                CosmosClient,
            )
        except ImportError as exc:
            raise ImportError(self._IMPORT_ERROR_MESSAGE) from exc

        client_kwargs = cosmos_client_kwargs or {}
        if self.credential:
            self._client = CosmosClient(
                url=self.cosmos_endpoint,
                credential=self.credential,
                **client_kwargs,
            )
        elif self.conn_string:
            self._client = CosmosClient.from_connection_string(
                conn_str=self.conn_string,
                **client_kwargs,
            )
        else:
            raise ValueError("Either a connection string or a credential must be set.")
        # Set by prepare_cosmos(); None until the database/container are ready.
        self._container: Optional[ContainerProxy] = None

    def prepare_cosmos(self) -> None:
        """Prepare the CosmosDB client.

        Use this function or the context manager to make sure your database is ready.
        This creates the database and container if they do not exist yet. The
        container is partitioned on ``/user_id`` and uses ``ttl`` as its default
        TTL. The stored history is then loaded into ``messages``.

        :raises ImportError: if the azure-cosmos package is not installed.
        """
        try:
            from azure.cosmos import (  # pylint: disable=import-outside-toplevel
                PartitionKey,
            )
        except ImportError as exc:
            raise ImportError(self._IMPORT_ERROR_MESSAGE) from exc
        database = self._client.create_database_if_not_exists(self.cosmos_database)
        self._container = database.create_container_if_not_exists(
            self.cosmos_container,
            partition_key=PartitionKey("/user_id"),
            default_ttl=self.ttl,
        )
        self.load_messages()

    def __enter__(self) -> CosmosDBChatMessageHistory:
        """Context manager entry point: open the client and prepare the container."""
        self._client.__enter__()
        try:
            self.prepare_cosmos()
        except BaseException as exc:
            # __exit__ will not run if __enter__ fails, so release the client here.
            self._client.__exit__(type(exc), exc, exc.__traceback__)
            raise
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Context manager exit: persist the messages, then close the client."""
        try:
            self.upsert_messages()
        finally:
            # Always release the client, even if the final upsert fails.
            self._client.__exit__(exc_type, exc_val, traceback)

    def load_messages(self) -> None:
        """Retrieve the messages from Cosmos.

        This replaces the in-memory ``messages`` with the stored history of the
        current ``session_id``/``user_id``. If no document exists, or it holds
        no messages, the history becomes empty. This keeps messages from a
        previously loaded session out of the new one.

        :raises ValueError: if the container has not been prepared yet.
        :raises ImportError: if the azure-cosmos package is not installed.
        :raises CosmosHttpResponseError: for any failure other than "not found".
            Treating such failures as an empty session would make the next
            upsert overwrite the stored history.
        """
        if self._container is None:
            raise ValueError("Container not initialized")
        try:
            from azure.cosmos.exceptions import (  # pylint: disable=import-outside-toplevel
                CosmosResourceNotFoundError,
            )
        except ImportError as exc:
            raise ImportError(self._IMPORT_ERROR_MESSAGE) from exc
        try:
            item = self._container.read_item(
                item=self.session_id, partition_key=self.user_id
            )
        except CosmosResourceNotFoundError:
            logger.info("no session found")
            self.messages = []
            return
        stored = item.get("messages")
        self.messages = messages_from_dict(stored) if stored else []

    def add_message(self, message: BaseMessage) -> None:
        """Add a self-created message to the store and persist the session."""
        self.messages.append(message)
        self.upsert_messages()

    def upsert_messages(self) -> None:
        """Update the cosmosdb item with the full in-memory message list.

        :raises ValueError: if the container has not been prepared yet.
        """
        if self._container is None:
            raise ValueError("Container not initialized")
        self._container.upsert_item(
            body={
                "id": self.session_id,
                "user_id": self.user_id,
                "messages": messages_to_dict(self.messages),
            }
        )

    def clear(self) -> None:
        """Clear session memory from this memory and cosmos.

        The in-memory history is always emptied. The stored document is
        deleted only if the container has been prepared. Clearing a session
        that was never persisted is not an error.
        """
        self.messages = []
        if self._container is None:
            return
        try:
            from azure.cosmos.exceptions import (  # pylint: disable=import-outside-toplevel
                CosmosResourceNotFoundError,
            )
        except ImportError as exc:
            raise ImportError(self._IMPORT_ERROR_MESSAGE) from exc
        try:
            self._container.delete_item(
                item=self.session_id, partition_key=self.user_id
            )
        except CosmosResourceNotFoundError:
            logger.debug("no stored session %s to delete", self.session_id)
|