"""Helpers for querying, storing, updating and deleting rows of ``MessageTable``.

Each operation is provided in a synchronous and an asynchronous flavour, and
``LCBuiltinChatMemory`` exposes them through LangChain's
``BaseChatMessageHistory`` interface.
"""

import json
from collections.abc import Sequence
from uuid import UUID

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session, col, select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.schema.message import Message
from langflow.services.database.models.message.model import MessageRead, MessageTable
from langflow.services.deps import async_session_scope, session_scope
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER


def _get_variable_query(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
):
    """Build the ``select`` statement shared by ``get_messages`` and ``aget_messages``.

    Rows whose ``error`` column is true are always excluded. Every other
    filter is applied only when its argument is truthy.

    Args:
        sender: Keep only rows with this ``sender`` value.
        sender_name: Keep only rows with this ``sender_name`` value.
        session_id: Keep only rows with this ``session_id`` value.
        order_by: Name of the ``MessageTable`` attribute to sort on. It is
            resolved with ``getattr``, so an unknown name raises
            ``AttributeError``. A falsy value disables ordering.
        order: ``"DESC"`` sorts descending; any other value sorts ascending.
        flow_id: Keep only rows with this ``flow_id`` value.
        limit: Maximum number of rows. A falsy value (``None`` or ``0``)
            leaves the statement unlimited.

    Returns:
        The composed statement, ready to be passed to ``session.exec``.
    """
    stmt = select(MessageTable).where(MessageTable.error == False)  # noqa: E712
    if sender:
        stmt = stmt.where(MessageTable.sender == sender)
    if sender_name:
        stmt = stmt.where(MessageTable.sender_name == sender_name)
    if session_id:
        stmt = stmt.where(MessageTable.session_id == session_id)
    if flow_id:
        stmt = stmt.where(MessageTable.flow_id == flow_id)
    if order_by:
        # Named ``order_column`` so the module-level ``col`` import from
        # sqlmodel is not shadowed inside this function.
        order_attribute = getattr(MessageTable, order_by)
        order_column = order_attribute.desc() if order == "DESC" else order_attribute.asc()
        stmt = stmt.order_by(order_column)
    if limit:
        stmt = stmt.limit(limit)
    return stmt


def get_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Retrieve stored, non-error messages matching the provided filters.

    Args:
        sender: The sender of the messages (e.g., "Machine" or "User").
        sender_name: The name of the sender.
        session_id: The session ID associated with the messages.
        order_by: The field to order the messages by. Defaults to "timestamp".
        order: The order in which to retrieve the messages. Defaults to "DESC".
        flow_id: The flow ID associated with the messages.
        limit: The maximum number of messages to retrieve.

    Returns:
        A list of ``Message`` objects, one per matching row.
    """
    with session_scope() as session:
        stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
        rows = session.exec(stmt)
        return [Message(**row.model_dump()) for row in rows]


async def aget_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Asynchronously retrieve stored, non-error messages matching the filters.

    Args:
        sender: The sender of the messages (e.g., "Machine" or "User").
        sender_name: The name of the sender.
        session_id: The session ID associated with the messages.
        order_by: The field to order the messages by. Defaults to "timestamp".
        order: The order in which to retrieve the messages. Defaults to "DESC".
        flow_id: The flow ID associated with the messages.
        limit: The maximum number of messages to retrieve.

    Returns:
        A list of ``Message`` objects, one per matching row, each built with
        ``Message.create``.
    """
    async with async_session_scope() as session:
        stmt = _get_variable_query(sender, sender_name, session_id, order_by, order, flow_id, limit)
        rows = await session.exec(stmt)
        return [await Message.create(**row.model_dump()) for row in rows]


def _as_message_list(messages: Message | list[Message]) -> list[Message]:
    """Wrap a single message in a list and check that every item is a ``Message``.

    Raises:
        ValueError: If any item is not a ``Message`` instance.
    """
    if not isinstance(messages, list):
        messages = [messages]

    if not all(isinstance(message, Message) for message in messages):
        types = ", ".join(str(type(message)) for message in messages)
        msg = f"The messages must be instances of Message. Found: {types}"
        raise ValueError(msg)
    return messages


def add_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Insert one or more messages as new ``MessageTable`` rows.

    Args:
        messages: A single ``Message`` or a list of them.
        flow_id: Flow ID forwarded to ``MessageTable.from_message``.

    Returns:
        A list of ``Message`` objects rebuilt from the stored rows.

    Raises:
        ValueError: If any item is not a ``Message`` instance.
        Exception: Any error raised while converting or storing is logged and
            re-raised unchanged.
    """
    messages = _as_message_list(messages)

    try:
        messages_models = [MessageTable.from_message(message, flow_id=flow_id) for message in messages]
        with session_scope() as session:
            messages_models = add_messagetables(messages_models, session)
        return [Message(**stored.model_dump()) for stored in messages_models]
    except Exception as e:
        logger.exception(e)
        raise


async def aadd_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Asynchronously insert one or more messages as new ``MessageTable`` rows.

    Args:
        messages: A single ``Message`` or a list of them.
        flow_id: Flow ID forwarded to ``MessageTable.from_message``.

    Returns:
        A list of ``Message`` objects rebuilt from the stored rows with
        ``Message.create``.

    Raises:
        ValueError: If any item is not a ``Message`` instance.
        Exception: Any error raised while converting or storing is logged and
            re-raised unchanged.
    """
    messages = _as_message_list(messages)

    try:
        messages_models = [MessageTable.from_message(message, flow_id=flow_id) for message in messages]
        async with async_session_scope() as session:
            messages_models = await aadd_messagetables(messages_models, session)
        return [await Message.create(**stored.model_dump()) for stored in messages_models]
    except Exception as e:
        logger.exception(e)
        raise


def update_messages(messages: Message | list[Message]) -> list[Message]:
    """Update existing rows from the given messages, looked up by ``id``.

    Only fields that are set and not ``None`` on each message are written.
    A message whose ``id`` has no matching row is skipped with a warning.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        The refreshed rows that were found and updated, validated as
        ``MessageRead`` objects (the declared ``list[Message]`` annotation is
        kept for interface compatibility).
    """
    if not isinstance(messages, list):
        messages = [messages]

    with session_scope() as session:
        updated_messages: list[MessageTable] = []
        for message in messages:
            row = session.get(MessageTable, message.id)
            if not row:
                logger.warning(f"Message with id {message.id} not found")
                continue
            row.sqlmodel_update(message.model_dump(exclude_unset=True, exclude_none=True))
            session.add(row)
            session.commit()
            session.refresh(row)
            updated_messages.append(row)
        return [MessageRead.model_validate(row, from_attributes=True) for row in updated_messages]


async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
    """Asynchronously update existing rows from the given messages, looked up by ``id``.

    Only fields that are set and not ``None`` on each message are written.
    A message whose ``id`` has no matching row is skipped with a warning.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        The refreshed rows that were found and updated, validated as
        ``MessageRead`` objects (the declared ``list[Message]`` annotation is
        kept for interface compatibility).
    """
    if not isinstance(messages, list):
        messages = [messages]

    async with async_session_scope() as session:
        updated_messages: list[MessageTable] = []
        for message in messages:
            row = await session.get(MessageTable, message.id)
            if not row:
                logger.warning(f"Message with id {message.id} not found")
                continue
            row.sqlmodel_update(message.model_dump(exclude_unset=True, exclude_none=True))
            session.add(row)
            await session.commit()
            await session.refresh(row)
            updated_messages.append(row)
        return [MessageRead.model_validate(row, from_attributes=True) for row in updated_messages]


def _to_message_reads(messages: list[MessageTable]) -> list[MessageRead]:
    """Decode JSON-string fields on each row in place, then validate as ``MessageRead``.

    ``properties`` and the items of ``content_blocks`` are parsed with
    ``json.loads`` when they are strings, and a falsy ``category`` becomes
    ``""``. All rows are normalised before any of them is validated.
    """
    for msg in messages:
        msg.properties = json.loads(msg.properties) if isinstance(msg.properties, str) else msg.properties  # type: ignore[arg-type]
        msg.content_blocks = [json.loads(j) if isinstance(j, str) else j for j in msg.content_blocks]  # type: ignore[arg-type]
        msg.category = msg.category or ""

    return [MessageRead.model_validate(msg, from_attributes=True) for msg in messages]


def add_messagetables(messages: list[MessageTable], session: Session):
    """Persist ``MessageTable`` rows one at a time and return them as ``MessageRead``.

    Each row is added, committed and refreshed before the next one is
    processed.

    Args:
        messages: Rows to insert.
        session: Open synchronous session used for the writes.

    Returns:
        A list of ``MessageRead`` objects built from the refreshed rows.

    Raises:
        Exception: The first error raised while adding, committing or
            refreshing is logged and re-raised; later rows are not processed.
    """
    # A single try around the loop is equivalent to one per iteration because
    # the handler always re-raises.
    try:
        for message in messages:
            session.add(message)
            session.commit()
            session.refresh(message)
    except Exception as e:
        logger.exception(e)
        raise

    return _to_message_reads(messages)


async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
    """Persist ``MessageTable`` rows with a single commit and return them as ``MessageRead``.

    All rows are added first, committed once, and then refreshed one by one.

    Args:
        messages: Rows to insert.
        session: Open asynchronous session used for the writes.

    Returns:
        A list of ``MessageRead`` objects built from the refreshed rows.

    Raises:
        Exception: Any error raised while adding, committing or refreshing is
            logged and re-raised.
    """
    try:
        for message in messages:
            session.add(message)
        await session.commit()
        for message in messages:
            await session.refresh(message)
    except Exception as e:
        logger.exception(e)
        raise

    return _to_message_reads(messages)


def delete_messages(session_id: str) -> None:
    """Delete every stored message that belongs to the given session.

    Args:
        session_id (str): The session ID associated with the messages to delete.
    """
    with session_scope() as session:
        session.exec(
            delete(MessageTable)
            .where(col(MessageTable.session_id) == session_id)
            .execution_options(synchronize_session="fetch")
        )


async def adelete_messages(session_id: str) -> None:
    """Asynchronously delete every stored message that belongs to the given session.

    Args:
        session_id (str): The session ID associated with the messages to delete.
    """
    async with async_session_scope() as session:
        stmt = (
            delete(MessageTable)
            .where(col(MessageTable.session_id) == session_id)
            .execution_options(synchronize_session="fetch")
        )
        await session.exec(stmt)


async def delete_message(id_: str) -> None:
    """Asynchronously delete a single stored message by its ID.

    Nothing happens when no row with that ID exists.

    Args:
        id_ (str): The ID of the message to delete.
    """
    async with async_session_scope() as session:
        message = await session.get(MessageTable, id_)
        if message:
            await session.delete(message)
            await session.commit()


def store_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Store a message, updating it when it already has an ID and inserting it otherwise.

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: The stored message in a list, or an empty list when
        ``message`` is falsy.

    Raises:
        ValueError: If any of the required fields (session_id, sender, sender_name) is missing.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    required_fields = ["session_id", "sender", "sender_name"]
    missing_fields = [field for field in required_fields if not getattr(message, field)]

    if missing_fields:
        missing_descriptions = {
            "session_id": "session_id (unique conversation identifier)",
            "sender": f"sender (e.g., '{MESSAGE_SENDER_USER}' or '{MESSAGE_SENDER_AI}')",
            "sender_name": "sender_name (display name, e.g., 'User' or 'Assistant')",
        }
        missing = ", ".join(missing_descriptions[field] for field in missing_fields)
        msg = (
            f"It looks like we're missing some important information: {missing}. "
            "Please ensure that your message includes all the required fields."
        )
        raise ValueError(msg)

    # A truthy id means the message is treated as already stored.
    if getattr(message, "id", None):
        return update_messages([message])
    return add_messages([message], flow_id=flow_id)


async def astore_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Asynchronously store a message, updating it when it has an ID and inserting it otherwise.

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: The stored message in a list, or an empty list when
        ``message`` is falsy.

    Raises:
        ValueError: If any of the required fields (session_id, sender, sender_name) is missing.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    if not message.session_id or not message.sender or not message.sender_name:
        msg = "All of session_id, sender, and sender_name must be provided."
        raise ValueError(msg)

    # A truthy id means the message is treated as already stored.
    if getattr(message, "id", None):
        return await aupdate_messages([message])
    return await aadd_messages([message], flow_id=flow_id)


class LCBuiltinChatMemory(BaseChatMessageHistory):
    """LangChain chat history backed by this module's message-store helpers.

    All reads, writes and deletes are scoped to ``session_id``; ``flow_id`` is
    attached to messages when they are stored.
    """

    def __init__(
        self,
        flow_id: str,
        session_id: str,
    ) -> None:
        """Remember the flow and session this history is bound to."""
        self.flow_id = flow_id
        self.session_id = session_id

    @property
    def messages(self) -> list[BaseMessage]:
        """Return the session's stored messages converted to LangChain messages."""
        messages = get_messages(
            session_id=self.session_id,
        )
        return [m.to_lc_message() for m in messages if not m.error]  # Exclude error messages

    async def aget_messages(self) -> list[BaseMessage]:
        """Asynchronously return the session's stored messages as LangChain messages."""
        messages = await aget_messages(
            session_id=self.session_id,
        )
        return [m.to_lc_message() for m in messages if not m.error]  # Exclude error messages

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Convert each LangChain message and store it under this session."""
        for lc_message in messages:
            message = Message.from_lc_message(lc_message)
            message.session_id = self.session_id
            store_message(message, flow_id=self.flow_id)

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Asynchronously convert each LangChain message and store it under this session."""
        for lc_message in messages:
            message = Message.from_lc_message(lc_message)
            message.session_id = self.session_id
            await astore_message(message, flow_id=self.flow_id)

    def clear(self) -> None:
        """Delete every stored message of this session."""
        delete_messages(self.session_id)

    async def aclear(self) -> None:
        """Asynchronously delete every stored message of this session."""
        await adelete_messages(self.session_id)