Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| from time import time | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_core.messages import ( | |
| BaseMessage, | |
| message_to_dict, | |
| messages_from_dict, | |
| ) | |
| if TYPE_CHECKING: | |
| from elasticsearch import Elasticsearch | |
# Module-level logger used by ElasticsearchChatMessageHistory to report
# Elasticsearch connection/API errors and index-creation debug messages.
logger = logging.getLogger(__name__)
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Elasticsearch.

    Every message is stored as its own document containing the session id,
    a creation timestamp in milliseconds since the epoch, and the message
    serialized to JSON. Messages are read back sorted by that timestamp.

    Args:
        index: Name of the index to use. It is created (with an explicit
            mapping) if it does not exist yet.
        session_id: Arbitrary key that is used to store the messages
            of a single chat session.
        es_connection: Optional pre-existing Elasticsearch connection.
        es_url: URL of the Elasticsearch instance to connect to.
        es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
        es_user: Username to use when connecting to Elasticsearch.
        es_api_key: API key to use when connecting to Elasticsearch.
        es_password: Password to use when connecting to Elasticsearch.

    Raises:
        ValueError: If neither ``es_connection`` nor one of ``es_url`` /
            ``es_cloud_id`` is provided.
    """

    # Number of documents requested per page when reading a session back.
    _PAGE_SIZE = 100

    def __init__(
        self,
        index: str,
        session_id: str,
        *,
        es_connection: Optional["Elasticsearch"] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
    ):
        self.index: str = index
        self.session_id: str = session_id

        # Initialize Elasticsearch client from passed client arg or connection info
        if es_connection is not None:
            # Derive a client from the caller's one that sends our user agent.
            self.client = es_connection.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                "Either provide a pre-existing Elasticsearch connection, "
                "or valid credentials for creating a new connection."
            )

        if self.client.indices.exists(index=index):
            logger.debug(
                f"Chat history index {index} already exists, skipping creation."
            )
        else:
            logger.debug(f"Creating index {index} for storing chat history.")
            self.client.indices.create(
                index=index,
                mappings={
                    "properties": {
                        # keyword (not text) so the exact-match `term` query works
                        "session_id": {"type": "keyword"},
                        "created_at": {"type": "date"},
                        "history": {"type": "text"},
                    }
                },
            )

    @staticmethod
    def get_user_agent() -> str:
        """Return the ``user-agent`` header value attached to every request."""
        from langchain import __version__

        return f"langchain-py-ms/{__version__}"

    @staticmethod
    def connect_to_elasticsearch(
        *,
        es_url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> "Elasticsearch":
        """Create an Elasticsearch client and verify it can reach the cluster.

        Exactly one of ``es_url`` / ``cloud_id`` must be given. ``api_key``
        takes precedence over ``username``/``password`` basic auth.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.
            ValueError: If both or neither of ``es_url`` / ``cloud_id`` are set.
            Exception: Whatever ``Elasticsearch.info()`` raises when the
                connection check fails (logged, then re-raised).
        """
        try:
            import elasticsearch
        except ImportError as exc:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            ) from exc

        if es_url and cloud_id:
            raise ValueError(
                "Both es_url and cloud_id are defined. Please provide only one."
            )

        connection_params: Dict[str, Any] = {}

        if es_url:
            connection_params["hosts"] = [es_url]
        elif cloud_id:
            connection_params["cloud_id"] = cloud_id
        else:
            raise ValueError("Please provide either elasticsearch_url or cloud_id.")

        if api_key:
            connection_params["api_key"] = api_key
        elif username and password:
            connection_params["basic_auth"] = (username, password)

        es_client = elasticsearch.Elasticsearch(
            **connection_params,
            headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
        )
        # Fail fast on bad URLs/credentials instead of on the first real query.
        try:
            es_client.info()
        except Exception as err:
            logger.error(f"Error connecting to Elasticsearch: {err}")
            raise

        return es_client

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """Retrieve all messages of this session from Elasticsearch, oldest first.

        A single search only returns the first 10 hits by default, so results
        are fetched page by page using ``search_after`` on the sort value of
        the last hit of the previous page.

        Raises:
            elasticsearch.ApiError: If a search request fails (logged first).
        """
        from elasticsearch import ApiError

        items: List[Dict[str, Any]] = []
        search_after: Optional[List[Any]] = None
        while True:
            paging: Dict[str, Any] = (
                {} if search_after is None else {"search_after": search_after}
            )
            try:
                result = self.client.search(
                    index=self.index,
                    query={"term": {"session_id": self.session_id}},
                    sort="created_at:asc",
                    size=self._PAGE_SIZE,
                    **paging,
                )
            except ApiError as err:
                logger.error(f"Could not retrieve messages from Elasticsearch: {err}")
                raise

            hits = result["hits"]["hits"] if result else []
            if not hits:
                break
            items.extend(json.loads(hit["_source"]["history"]) for hit in hits)
            if len(hits) < self._PAGE_SIZE:
                break  # short page: nothing left, skip the extra round-trip
            # NOTE(review): paging only on created_at means documents sharing the
            # exact same millisecond across a page boundary could be skipped;
            # consider a unique tiebreaker sort field if that matters.
            search_after = hits[-1]["sort"]

        return messages_from_dict(items)

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the chat session in Elasticsearch.

        The index is refreshed immediately (``refresh=True``) so the message
        is visible to the next ``messages`` read.

        Raises:
            elasticsearch.ApiError: If indexing fails (logged first).
        """
        from elasticsearch import ApiError

        try:
            self.client.index(
                index=self.index,
                document={
                    "session_id": self.session_id,
                    "created_at": round(time() * 1000),
                    "history": json.dumps(message_to_dict(message)),
                },
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not add message to Elasticsearch: {err}")
            raise

    def clear(self) -> None:
        """Clear session memory in Elasticsearch.

        Deletes every document whose ``session_id`` matches this session and
        refreshes the index so the deletion is immediately visible.

        Raises:
            elasticsearch.ApiError: If the delete-by-query fails (logged first).
        """
        from elasticsearch import ApiError

        try:
            self.client.delete_by_query(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                refresh=True,
            )
        except ApiError as err:
            logger.error(f"Could not clear session memory in Elasticsearch: {err}")
            raise