Spaces:
Runtime error
Runtime error
from typing import List, Optional, Union | |
from langchain_core.chat_history import BaseChatMessageHistory | |
from langchain_core.messages import BaseMessage, messages_from_dict | |
from langchain.utils import get_from_env | |
class Neo4jChatMessageHistory(BaseChatMessageHistory): | |
"""Chat message history stored in a Neo4j database.""" | |
def __init__( | |
self, | |
session_id: Union[str, int], | |
url: Optional[str] = None, | |
username: Optional[str] = None, | |
password: Optional[str] = None, | |
database: str = "neo4j", | |
node_label: str = "Session", | |
window: int = 3, | |
): | |
try: | |
import neo4j | |
except ImportError: | |
raise ValueError( | |
"Could not import neo4j python package. " | |
"Please install it with `pip install neo4j`." | |
) | |
# Make sure session id is not null | |
if not session_id: | |
raise ValueError("Please ensure that the session_id parameter is provided") | |
url = get_from_env("url", "NEO4J_URI", url) | |
username = get_from_env("username", "NEO4J_USERNAME", username) | |
password = get_from_env("password", "NEO4J_PASSWORD", password) | |
database = get_from_env("database", "NEO4J_DATABASE", database) | |
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) | |
self._database = database | |
self._session_id = session_id | |
self._node_label = node_label | |
self._window = window | |
# Verify connection | |
try: | |
self._driver.verify_connectivity() | |
except neo4j.exceptions.ServiceUnavailable: | |
raise ValueError( | |
"Could not connect to Neo4j database. " | |
"Please ensure that the url is correct" | |
) | |
except neo4j.exceptions.AuthError: | |
raise ValueError( | |
"Could not connect to Neo4j database. " | |
"Please ensure that the username and password are correct" | |
) | |
# Create session node | |
self._driver.execute_query( | |
f"MERGE (s:`{self._node_label}` {{id:$session_id}})", | |
{"session_id": self._session_id}, | |
).summary | |
def messages(self) -> List[BaseMessage]: # type: ignore | |
"""Retrieve the messages from Neo4j""" | |
query = ( | |
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) " | |
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.." | |
f"{self._window*2}]-() WITH p, length(p) AS length " | |
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node " | |
"RETURN {data:{content: node.content}, type:node.type} AS result" | |
) | |
records, _, _ = self._driver.execute_query( | |
query, {"session_id": self._session_id} | |
) | |
messages = messages_from_dict([el["result"] for el in records]) | |
return messages | |
def add_message(self, message: BaseMessage) -> None: | |
"""Append the message to the record in Neo4j""" | |
query = ( | |
f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id " | |
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) " | |
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) " | |
"SET new += {type:$type, content:$content} " | |
"WITH new, lm, last_message WHERE last_message IS NOT NULL " | |
"CREATE (last_message)-[:NEXT]->(new) " | |
"DELETE lm" | |
) | |
self._driver.execute_query( | |
query, | |
{ | |
"type": message.type, | |
"content": message.content, | |
"session_id": self._session_id, | |
}, | |
).summary | |
def clear(self) -> None: | |
"""Clear session memory from Neo4j""" | |
query = ( | |
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) " | |
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() " | |
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 " | |
"UNWIND nodes(p) as node DETACH DELETE node;" | |
) | |
self._driver.execute_query(query, {"session_id": self._session_id}).summary | |
def __del__(self) -> None: | |
if self._driver: | |
self._driver.close() | |