Spaces:
Running
Running
| # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: BSD 2-Clause License | |
| """Message broker implementation. | |
| This module provides message broker implementations for Redis and local queue-based | |
| communication, enabling publish/subscribe patterns and key-value storage operations. | |
| """ | |
| import asyncio | |
| import itertools | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from datetime import timedelta | |
| import redis.asyncio as redis | |
| from loguru import logger | |
| from redis.asyncio.client import PubSub | |
@dataclass
class MessageBrokerConfig:
    """Configuration for message broker initialization.

    Declared as a dataclass so it can be constructed as
    ``MessageBrokerConfig(name=..., url=...)``; without the decorator the
    class has no ``__init__`` accepting these fields.

    Attributes:
        name: The type of message broker to use ('redis' or 'local_queue').
        url: Connection URL for the message broker (required for Redis).
    """

    name: str
    url: str = ""
| class MessageBroker(ABC): | |
| """Abstract interface for all message brokers. Defines interface to receive and send messages.""" | |
| async def receive_messages(self, timeout: timedelta | None = timedelta(seconds=0.5)) -> list[tuple[str, str]]: | |
| """Receive incoming messages. Returns when it received one or more messages. | |
| Args: | |
| timeout: Maximum time to wait for messages. None means wait indefinitely. | |
| Returns: | |
| list[tuple[str, str]]: List of (message_id, message_data) tuples. | |
| """ | |
| raise NotImplementedError | |
| async def send_message(self, channel_id: str, message_data: str) -> None: | |
| """Publish a message to a channel. | |
| Args: | |
| channel_id: The channel to publish to. | |
| message_data: The message content to publish. | |
| """ | |
| raise NotImplementedError | |
| async def get(self, key: str) -> str | None: | |
| """Get value for a key. | |
| Args: | |
| key: The key to retrieve. | |
| Returns: | |
| str | None: The value if found, None otherwise. | |
| """ | |
| raise NotImplementedError | |
| async def set(self, key: str, value: str) -> None: | |
| """Set value for a key. | |
| Args: | |
| key: The key to set. | |
| value: The value to store. | |
| """ | |
| raise NotImplementedError | |
| async def wait_for_connection(self) -> None: | |
| """Wait for the connection to the message broker to be established.""" | |
| raise NotImplementedError | |
| async def delete(self, key: str) -> None: | |
| """Delete a key from storage. | |
| Args: | |
| key: The key to delete. | |
| """ | |
| raise NotImplementedError | |
| async def get_latest_message(self, channel_id: str) -> str | None: | |
| """Return the most recent message in the channel. | |
| Args: | |
| channel_id: The channel to check. | |
| Returns: | |
| str | None: The latest message if available, None otherwise. | |
| """ | |
| raise NotImplementedError | |
| async def pubsub_receive_message(self, channels: list[str], timeout: timedelta | None = None) -> str | None: | |
| """Receive a message from specified channels using pub/sub pattern. | |
| Args: | |
| channels: List of channels to subscribe to. | |
| timeout: Maximum time to wait for a message. | |
| Returns: | |
| str | None: The received message if available, None otherwise. | |
| """ | |
| raise NotImplementedError | |
class RedisMessageBroker(MessageBroker):
    """Message broker implementation using Redis.

    Point-to-point messages use Redis streams (``XADD`` / ``XREAD``),
    ``pubsub_receive_message`` uses Redis pub/sub, and ``get``/``set``/``delete``
    operate on plain Redis keys.
    """

    def __init__(self, redis_url: str, channels: list[str]):
        """Initialize Redis message broker.

        Must be called while an asyncio event loop is running, because a
        background connection check is scheduled immediately.

        Args:
            redis_url: URL for Redis connection.
            channels: Stream names read by ``receive_messages``.
        """
        super().__init__()
        self.redis: redis.Redis = redis.from_url(redis_url)
        self._pubsub: PubSub | None = None
        # Last-seen stream ID per channel; starting at "0" makes the first XREAD
        # return every entry already in the stream.
        self._channel_state: dict[str, str] = dict.fromkeys(channels, "0")
        self.is_connected = asyncio.Event()
        # Keep a strong reference: the event loop only holds weak references to
        # tasks, so an unreferenced task can be garbage-collected before it runs.
        self._connection_check_task = asyncio.create_task(self._check_connection())

    async def wait_for_connection(self) -> None:
        """Wait until the initial connection check has succeeded.

        NOTE: ``_check_connection`` pings only once and does not retry, so if
        that ping failed this call waits forever.
        """
        await self.is_connected.wait()

    async def _check_connection(self) -> None:
        """Ping Redis once; set ``is_connected`` on success, log and clear it on failure."""
        try:
            await self.redis.ping()
            logger.info("Successfully connected to Redis")
            self.is_connected.set()
        except redis.ConnectionError as e:
            logger.error(f"Failed to connect to Redis: {e}")
            self.is_connected.clear()

    async def receive_messages(self, timeout: timedelta | None = timedelta(seconds=0.5)) -> list[tuple[str, str]]:
        """Receive incoming messages from all configured streams.

        Args:
            timeout: Maximum time to wait for messages. ``None`` waits
                indefinitely, matching the ``MessageBroker`` contract. A timeout
                that rounds down to zero milliseconds (or is negative) polls
                without blocking.

        Returns:
            list[tuple[str, str]]: List of (message_id, message_data) tuples,
            one per field of each stream entry.
        """
        block_ms: int | None
        if timeout is None:
            # XREAD treats BLOCK 0 as "block until data arrives".
            block_ms = 0
        else:
            timeout_ms = int(timeout.total_seconds() * 1000)
            if timeout_ms < 100:
                logger.warning(f"Redis timeout resolution is about 100ms, but a timeout of {timeout_ms} ms was given.")
            # BLOCK 0 would mean "forever", so a non-positive timeout must omit
            # BLOCK entirely to get a non-blocking poll instead.
            block_ms = timeout_ms if timeout_ms > 0 else None

        result = await self.redis.xread(streams=self._channel_state, block=block_ms)  # type: ignore

        message_list: list[tuple[str, str]] = []
        # Guard against a None/empty reply when nothing arrived before the deadline.
        for stream_name, entries in result or []:
            channel_id = str(stream_name.decode())
            for raw_message_id, fields in entries:
                message_id = raw_message_id.decode()
                for key in fields:
                    message_list.append((message_id, fields[key].decode()))
                # Advance the cursor so the next XREAD only returns newer entries.
                self._channel_state[channel_id] = message_id
        return message_list

    async def pubsub_receive_message(self, channels: list[str], timeout: timedelta | None = None) -> str | None:
        """Receive a message from specified channels using pub/sub pattern.

        Channels not yet subscribed are subscribed on first use and stay
        subscribed afterwards.

        Args:
            channels: List of channels to subscribe to.
            timeout: Maximum time to wait for a message; ``None`` waits indefinitely.

        Returns:
            str | None: The received message if available. ``None`` if the
            timeout expired, or if the next pub/sub event was not a payload
            message (e.g. a subscribe confirmation).
        """
        if self._pubsub is None:
            self._pubsub = self.redis.pubsub()
        for channel in channels:
            # The client is created without decode_responses, so PubSub tracks
            # channel names as bytes.
            if channel.encode("utf-8") not in self._pubsub.channels:
                await self._pubsub.subscribe(channel)

        wait_seconds = None if timeout is None else timeout.total_seconds()
        message = await self._pubsub.get_message(timeout=wait_seconds)
        # get_message returns None when the timeout expires with no data.
        if message is None or message["type"] != "message":
            return None
        return message["data"].decode("utf-8")

    async def send_message(self, channel_id: str, message_data: str) -> None:
        """Append a message to a Redis stream.

        Args:
            channel_id: The stream to append to.
            message_data: The message content, stored under the ``event`` field.
        """
        await self.redis.xadd(channel_id, {"event": message_data.encode()})

    async def get(self, key: str) -> str | None:
        """Get value for a key.

        Args:
            key: The key to retrieve.

        Returns:
            str | None: The value if found, None otherwise.
        """
        value = await self.redis.get(key)
        # The client is created without decode_responses, so Redis returns
        # bytes; decode to honour the ``str | None`` contract of the interface.
        if isinstance(value, bytes):
            return value.decode()
        return value

    async def set(self, key: str, value: str) -> None:
        """Set value for a key.

        Args:
            key: The key to set.
            value: The value to store.
        """
        await self.redis.set(name=key, value=value)

    async def delete(self, key: str) -> None:
        """Delete a key from storage.

        Args:
            key: The key to delete.
        """
        await self.redis.delete(key)

    async def get_latest_message(self, channel_id: str) -> str | None:
        """Return the most recent message in the channel.

        Args:
            channel_id: The stream to check.

        Returns:
            str | None: The first field value of the newest stream entry, or
            None if the stream is empty or the entry cannot be decoded.
        """
        result = await self.redis.xrevrange(channel_id, count=1)
        if not result:
            return None
        _, message_data = result[0]
        try:
            for key in message_data:
                return message_data[key].decode()
            return None
        except Exception:
            # Deliberate best-effort: a malformed entry is reported, not raised.
            logger.info(f"Latest message {message_data} on channel {channel_id} is no valid message")
            return None
| _MESSAGE_ID = itertools.count() | |
class LocalQueueMessageBroker(MessageBroker):
    """Message broker using a local queue for testing purposes.

    Implements the MessageBroker interface using in-memory asyncio queues,
    suitable for testing and local development.

    NOTE: ``_storage`` and ``_queues`` are class attributes, so all instances in
    the process share one key-value store and one set of channel queues
    (presumably intentional, so separately created brokers can exchange messages).
    """

    _storage: dict[str, str] = {}
    _queues: dict[str, asyncio.Queue] = {}

    def __init__(self, channels: list[str]):
        """Initialize local queue message broker.

        Args:
            channels: List of channels to create queues for.
        """
        super().__init__()
        self._update_channels(channels)

    def _update_channels(self, channels: list[str]) -> None:
        """Create a queue for every channel that does not have one yet."""
        for channel in channels:
            if channel not in self._queues:
                self._queues[channel] = asyncio.Queue()

    async def receive_messages(
        self, timeout: timedelta | None = timedelta(seconds=0.5), channels: list[str] | None = None
    ) -> list[tuple[str, str]]:
        """Receive incoming messages. Returns as soon as one or more messages are available.

        Args:
            timeout: Maximum time to wait for messages. ``None`` waits
                indefinitely; a zero or negative timeout polls without blocking.
            channels: Optional list of specific channels to receive from.
                Queues are created for channels that do not exist yet.
                Defaults to all known channels.

        Returns:
            list[tuple[str, str]]: List of (message_id, message_data) tuples,
            at most one per channel; empty if the timeout expired first.
        """
        queue_selection = self._queues
        if channels is not None:
            self._update_channels(channels)
            queue_selection = {channel: queue for channel, queue in self._queues.items() if channel in channels}
        queues = list(queue_selection.values())
        if not queues:
            return []

        # Compare against None explicitly: timedelta(0) is falsy and must not
        # turn into an unbounded wait.
        timeout_in_seconds = None if timeout is None else timeout.total_seconds()
        if timeout_in_seconds is not None and timeout_in_seconds <= 0:
            # Non-blocking poll: take whatever is already queued.
            return [queue.get_nowait() for queue in queues if not queue.empty()]

        getters = [asyncio.create_task(queue.get()) for queue in queues]
        try:
            # Wake up on the first delivered message instead of waiting for every channel.
            await asyncio.wait(getters, timeout=timeout_in_seconds, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Cancel getters still waiting (no-op for finished ones). A getter
            # cancelled before it dequeued leaves its item in the queue.
            for getter in getters:
                getter.cancel()
            await asyncio.gather(*getters, return_exceptions=True)

        message_list: list[tuple[str, str]] = []
        for getter in getters:
            if getter.cancelled() or getter.exception() is not None:
                continue
            result = getter.result()
            if isinstance(result, tuple):
                message_list.append(result)
        return message_list

    async def pubsub_receive_message(self, channels: list[str], timeout: timedelta | None = None) -> str | None:
        """Receive a message from specified channels using pub/sub pattern.

        Args:
            channels: List of channels to subscribe to.
            timeout: Maximum time to wait for a message.

        Returns:
            str | None: The received message if available, None if the timeout
            expired without a message.
        """
        messages = await self.receive_messages(timeout=timeout, channels=channels)
        if not messages:
            return None
        # NOTE(review): if several channels delivered at once, only the first
        # message is returned and the others are dropped.
        return messages[0][1]

    async def send_message(self, channel_id: str, message_data: str) -> None:
        """Publish a message to a channel, creating its queue if needed.

        Args:
            channel_id: The channel to publish to.
            message_data: The message content to publish.
        """
        message_id = str(next(_MESSAGE_ID))
        await self._queues.setdefault(channel_id, asyncio.Queue()).put((message_id, message_data))

    async def wait_for_connection(self) -> None:
        """Wait for the connection to the message broker to be established.

        For local queue, this is a no-op as no connection is needed.
        """

    async def get(self, key: str) -> str | None:
        """Get value for a key.

        Args:
            key: The key to retrieve.

        Returns:
            str | None: The value if found, None otherwise.
        """
        return self._storage.get(key, None)

    async def set(self, key: str, value: str) -> None:
        """Set value for a key.

        Args:
            key: The key to set.
            value: The value to store.
        """
        self._storage[key] = value

    async def delete(self, key: str) -> None:
        """Delete a key from storage, along with any channel queue of the same name.

        Args:
            key: The key to delete.
        """
        self._queues.pop(key, None)
        self._storage.pop(key, None)

    async def get_latest_message(self, channel_id: str) -> str | None:
        """Return the most recent message in the channel.

        Not supported by the local queue broker; always returns None.

        Args:
            channel_id: The channel to check.

        Returns:
            str | None: Always None.
        """
        return None
def message_broker_factory(config: MessageBrokerConfig, channels: list[str]) -> MessageBroker:
    """Create a new message broker based on the configuration.

    Currently supports Redis and local queue implementations.

    Args:
        config: Configuration specifying the broker type and connection details.
        channels: List of channels to initialize.

    Returns:
        MessageBroker: Configured message broker instance.

    Raises:
        ValueError: If the specified broker type is not supported. (A subclass
            of ``Exception``, so existing ``except Exception`` handlers still apply.)
    """
    if config.name == "redis":
        return RedisMessageBroker(config.url, channels)
    if config.name == "local_queue":
        return LocalQueueMessageBroker(channels)
    available = ",".join(("redis", "local_queue"))
    raise ValueError(f"message broker {config.name} does not exist. Available providers {available}")