| """Tree-Based Message Queue Manager - Refactored. |
| |
| Coordinates data access, async processing, and error handling. |
| Uses TreeRepository for data, TreeQueueProcessor for async logic. |
| """ |
|
|
| import asyncio |
| from collections.abc import Awaitable, Callable |
|
|
| from loguru import logger |
|
|
| from ..models import IncomingMessage |
| from .data import MessageNode, MessageState, MessageTree |
| from .processor import TreeQueueProcessor |
| from .repository import TreeRepository |
|
|
| |
| __all__ = [ |
| "MessageNode", |
| "MessageState", |
| "MessageTree", |
| "TreeQueueManager", |
| ] |
|
|
|
|
| class TreeQueueManager: |
| """ |
| Manages multiple message trees. Facade that coordinates components. |
| |
| Each new conversation creates a new tree. |
| Replies to existing messages add nodes to existing trees. |
| |
| Components: |
| - TreeRepository: Data access layer |
| - TreeQueueProcessor: Async queue processing |
| """ |
|
|
| def __init__( |
| self, |
| queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None, |
| node_started_callback: Callable[[MessageTree, str], Awaitable[None]] |
| | None = None, |
| _repository: TreeRepository | None = None, |
| ): |
| self._repository = _repository or TreeRepository() |
| self._processor = TreeQueueProcessor( |
| queue_update_callback=queue_update_callback, |
| node_started_callback=node_started_callback, |
| ) |
| self._lock = asyncio.Lock() |
|
|
| logger.info("TreeQueueManager initialized") |
|
|
| async def create_tree( |
| self, |
| node_id: str, |
| incoming: IncomingMessage, |
| status_message_id: str, |
| ) -> MessageTree: |
| """ |
| Create a new tree with a root node. |
| |
| Args: |
| node_id: ID for the root node |
| incoming: The incoming message |
| status_message_id: Bot's status message ID |
| |
| Returns: |
| The created MessageTree |
| """ |
| async with self._lock: |
| root_node = MessageNode( |
| node_id=node_id, |
| incoming=incoming, |
| status_message_id=status_message_id, |
| state=MessageState.PENDING, |
| ) |
|
|
| tree = MessageTree(root_node) |
| self._repository.add_tree(node_id, tree) |
|
|
| logger.info(f"Created new tree with root {node_id}") |
| return tree |
|
|
| async def add_to_tree( |
| self, |
| parent_node_id: str, |
| node_id: str, |
| incoming: IncomingMessage, |
| status_message_id: str, |
| ) -> tuple[MessageTree, MessageNode]: |
| """ |
| Add a reply as a child node to an existing tree. |
| |
| Args: |
| parent_node_id: ID of the parent message |
| node_id: ID for the new node |
| incoming: The incoming reply message |
| status_message_id: Bot's status message ID |
| |
| Returns: |
| Tuple of (tree, new_node) |
| """ |
| async with self._lock: |
| if not self._repository.has_node(parent_node_id): |
| raise ValueError(f"Parent node {parent_node_id} not found in any tree") |
|
|
| tree = self._repository.get_tree_for_node(parent_node_id) |
| if not tree: |
| raise ValueError(f"Parent node {parent_node_id} not found in any tree") |
|
|
| |
| node = await tree.add_node( |
| node_id=node_id, |
| incoming=incoming, |
| status_message_id=status_message_id, |
| parent_id=parent_node_id, |
| ) |
|
|
| async with self._lock: |
| self._repository.register_node(node_id, tree.root_id) |
|
|
| logger.info(f"Added node {node_id} to tree {tree.root_id}") |
| return tree, node |
|
|
| def get_tree(self, root_id: str) -> MessageTree | None: |
| """Get a tree by its root ID.""" |
| return self._repository.get_tree(root_id) |
|
|
| def get_tree_for_node(self, node_id: str) -> MessageTree | None: |
| """Get the tree containing a given node.""" |
| return self._repository.get_tree_for_node(node_id) |
|
|
| def get_node(self, node_id: str) -> MessageNode | None: |
| """Get a node from any tree.""" |
| return self._repository.get_node(node_id) |
|
|
| def resolve_parent_node_id(self, msg_id: str) -> str | None: |
| """Resolve a message ID to the actual parent node ID.""" |
| return self._repository.resolve_parent_node_id(msg_id) |
|
|
| def is_tree_busy(self, root_id: str) -> bool: |
| """Check if a tree is currently processing.""" |
| return self._repository.is_tree_busy(root_id) |
|
|
| def is_node_tree_busy(self, node_id: str) -> bool: |
| """Check if the tree containing a node is busy.""" |
| return self._repository.is_node_tree_busy(node_id) |
|
|
| async def enqueue( |
| self, |
| node_id: str, |
| processor: Callable[[str, MessageNode], Awaitable[None]], |
| ) -> bool: |
| """ |
| Enqueue a node for processing. |
| |
| If the tree is not busy, processing starts immediately. |
| If busy, the message is queued. |
| |
| Args: |
| node_id: Node to process |
| processor: Async function to process the node |
| |
| Returns: |
| True if queued, False if processing immediately |
| """ |
| tree = self._repository.get_tree_for_node(node_id) |
| if not tree: |
| logger.error(f"No tree found for node {node_id}") |
| return False |
|
|
| return await self._processor.enqueue_and_start(tree, node_id, processor) |
|
|
| def get_queue_size(self, node_id: str) -> int: |
| """Get queue size for the tree containing a node.""" |
| return self._repository.get_queue_size(node_id) |
|
|
| def get_pending_children(self, node_id: str) -> list[MessageNode]: |
| """Get all pending child nodes (recursively) of a given node.""" |
| return self._repository.get_pending_children(node_id) |
|
|
| async def mark_node_error( |
| self, |
| node_id: str, |
| error_message: str, |
| propagate_to_children: bool = True, |
| ) -> list[MessageNode]: |
| """ |
| Mark a node as ERROR and optionally propagate to pending children. |
| |
| Args: |
| node_id: The node to mark as error |
| error_message: Error description |
| propagate_to_children: If True, also mark pending children as error |
| |
| Returns: |
| List of all nodes marked as error (including children) |
| """ |
| tree = self._repository.get_tree_for_node(node_id) |
| if not tree: |
| return [] |
|
|
| affected = [] |
| node = tree.get_node(node_id) |
| if node: |
| await tree.update_state( |
| node_id, MessageState.ERROR, error_message=error_message |
| ) |
| affected.append(node) |
|
|
| if propagate_to_children: |
| pending_children = self._repository.get_pending_children(node_id) |
| for child in pending_children: |
| await tree.update_state( |
| child.node_id, |
| MessageState.ERROR, |
| error_message=f"Parent failed: {error_message}", |
| ) |
| affected.append(child) |
|
|
| return affected |
|
|
| async def cancel_tree(self, root_id: str) -> list[MessageNode]: |
| """ |
| Cancel all queued and in-progress messages in a tree. |
| |
| Updates node states to ERROR and returns list of affected nodes |
| that were actually active or in the current processing queue. |
| """ |
| tree = self._repository.get_tree(root_id) |
| if not tree: |
| return [] |
|
|
| cancelled_nodes = [] |
|
|
| cleanup_count = 0 |
| async with tree.with_lock(): |
| |
| if tree.cancel_current_task(): |
| current_id = tree.current_node_id |
| if current_id: |
| node = tree.get_node(current_id) |
| if node and node.state not in ( |
| MessageState.COMPLETED, |
| MessageState.ERROR, |
| ): |
| tree.set_node_error_sync(node, "Cancelled by user") |
| cancelled_nodes.append(node) |
|
|
| |
| queue_nodes = tree.drain_queue_and_mark_cancelled() |
| cancelled_nodes.extend(queue_nodes) |
| cancelled_ids = {n.node_id for n in cancelled_nodes} |
|
|
| |
| for node in tree.all_nodes(): |
| if ( |
| node.state in (MessageState.PENDING, MessageState.IN_PROGRESS) |
| and node.node_id not in cancelled_ids |
| ): |
| tree.set_node_error_sync(node, "Stale task cleaned up") |
| cleanup_count += 1 |
|
|
| tree.reset_processing_state() |
|
|
| if cancelled_nodes: |
| logger.info( |
| f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}" |
| ) |
| if cleanup_count: |
| logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}") |
|
|
| return cancelled_nodes |
|
|
| async def cancel_node(self, node_id: str) -> list[MessageNode]: |
| """ |
| Cancel a single node (queued or in-progress) without affecting other nodes. |
| |
| - If the node is currently running, cancels the current asyncio task. |
| - If the node is queued, removes it from the queue. |
| - Marks the node as ERROR with "Cancelled by user". |
| |
| Returns: |
| List containing the cancelled node if it was cancellable, else empty list. |
| """ |
| tree = self._repository.get_tree_for_node(node_id) |
| if not tree: |
| return [] |
|
|
| async with tree.with_lock(): |
| node = tree.get_node(node_id) |
| if not node: |
| return [] |
|
|
| if node.state in (MessageState.COMPLETED, MessageState.ERROR): |
| return [] |
|
|
| if tree.is_current_node(node_id): |
| self._processor.cancel_current(tree) |
|
|
| try: |
| tree.remove_from_queue(node_id) |
| except Exception: |
| logger.debug( |
| "Failed to remove node from queue; will rely on state=ERROR" |
| ) |
|
|
| tree.set_node_error_sync(node, "Cancelled by user") |
|
|
| return [node] |
|
|
| async def cancel_all(self) -> list[MessageNode]: |
| """Cancel all messages in all trees.""" |
| async with self._lock: |
| root_ids = list(self._repository.tree_ids()) |
| all_cancelled: list[MessageNode] = [] |
| for root_id in root_ids: |
| all_cancelled.extend(await self.cancel_tree(root_id)) |
| return all_cancelled |
|
|
| def cleanup_stale_nodes(self) -> int: |
| """ |
| Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR. |
| Used on startup to reconcile restored state. |
| """ |
| count = 0 |
| for tree in self._repository.all_trees(): |
| for node in tree.all_nodes(): |
| if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS): |
| tree.set_node_error_sync(node, "Lost during server restart") |
| count += 1 |
| if count: |
| logger.info(f"Cleaned up {count} stale nodes during startup") |
| return count |
|
|
| def get_tree_count(self) -> int: |
| """Get the number of active message trees.""" |
| return self._repository.tree_count() |
|
|
| def set_queue_update_callback( |
| self, |
| queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None, |
| ) -> None: |
| """Set callback for queue position updates.""" |
| self._processor.set_queue_update_callback(queue_update_callback) |
|
|
| def set_node_started_callback( |
| self, |
| node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None, |
| ) -> None: |
| """Set callback for when a queued node starts processing.""" |
| self._processor.set_node_started_callback(node_started_callback) |
|
|
| def register_node(self, node_id: str, root_id: str) -> None: |
| """Register a node ID to a tree (for external mapping).""" |
| self._repository.register_node(node_id, root_id) |
|
|
| async def cancel_branch(self, branch_root_id: str) -> list[MessageNode]: |
| """ |
| Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants). |
| |
| Does not call cli_manager.stop_all(). Returns list of cancelled nodes. |
| """ |
| tree = self._repository.get_tree_for_node(branch_root_id) |
| if not tree: |
| return [] |
|
|
| branch_ids = set(tree.get_descendants(branch_root_id)) |
| cancelled: list[MessageNode] = [] |
|
|
| async with tree.with_lock(): |
| for nid in branch_ids: |
| node = tree.get_node(nid) |
| if not node or node.state in ( |
| MessageState.COMPLETED, |
| MessageState.ERROR, |
| ): |
| continue |
|
|
| if tree.is_current_node(nid): |
| self._processor.cancel_current(tree) |
| tree.set_node_error_sync(node, "Cancelled by user") |
| cancelled.append(node) |
| else: |
| tree.remove_from_queue(nid) |
| tree.set_node_error_sync(node, "Cancelled by user") |
| cancelled.append(node) |
|
|
| if cancelled: |
| logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}") |
| return cancelled |
|
|
| async def remove_branch( |
| self, branch_root_id: str |
| ) -> tuple[list[MessageNode], str, bool]: |
| """ |
| Remove a branch (subtree) from the tree. |
| |
| If branch_root is the tree root, removes the entire tree. |
| |
| Returns: |
| (removed_nodes, root_id, removed_entire_tree) |
| """ |
| tree = self._repository.get_tree_for_node(branch_root_id) |
| if not tree: |
| return ([], "", False) |
|
|
| root_id = tree.root_id |
|
|
| if branch_root_id == root_id: |
| cancelled = await self.cancel_tree(root_id) |
| removed_tree = self._repository.remove_tree(root_id) |
| if removed_tree: |
| return (removed_tree.all_nodes(), root_id, True) |
| return (cancelled, root_id, True) |
|
|
| async with tree.with_lock(): |
| removed = tree.remove_branch(branch_root_id) |
|
|
| self._repository.unregister_nodes([n.node_id for n in removed]) |
| return (removed, root_id, False) |
|
|
| def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]: |
| """Get all message IDs for a given platform/chat.""" |
| return self._repository.get_message_ids_for_chat(platform, chat_id) |
|
|
| def to_dict(self) -> dict: |
| """Serialize all trees.""" |
| return self._repository.to_dict() |
|
|
| @classmethod |
| def from_dict( |
| cls, |
| data: dict, |
| queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None, |
| node_started_callback: Callable[[MessageTree, str], Awaitable[None]] |
| | None = None, |
| ) -> TreeQueueManager: |
| """Deserialize from dictionary.""" |
| return cls( |
| queue_update_callback=queue_update_callback, |
| node_started_callback=node_started_callback, |
| _repository=TreeRepository.from_dict(data), |
| ) |
|
|