import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List

import redis
import structlog
from prometheus_client import Counter, Histogram

from .config import BrokerConfig
from .message import Message
logger = structlog.get_logger()
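
# Prometheus metrics: per-queue throughput (by success/error status) and
# handler latency.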
messages_processed = Counter(
"messages_processed_total",
"Total number of messages processed",
["queue", "status"]
)
processing_time = Histogram(
"message_processing_seconds",
"Time spent processing messages",
["queue"]
)
class MessageConsumer:
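    """Consume messages from a Redis list queue and dispatch them to a handler.

    Failed messages are retried with exponential backoff via a sorted-set
    retry queue; messages that exhaust max_retries go to a dead letter queue.
    """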
def __init__(self, config: BrokerConfig, queue: str, handler: Callable[[Message], None]):
self.config = config
self.queue = queue
self.handler = handler
logger.info("Creating Redis connection pool",
host=config.redis.host,
port=config.redis.port,
ssl=config.redis.ssl)
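        # A shared connection pool lets all worker threads reuse sockets.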
        connection_params = {
            "host": config.redis.host,
            "port": config.redis.port,
            "db": config.redis.db,
            "password": config.redis.password,
            # Decode responses to str so Message.from_json receives text, not bytes.
            "decode_responses": True,
            "max_connections": config.redis.connection_pool_size
        }
        if config.redis.ssl:
            # redis.ConnectionPool does not accept an `ssl` flag; select the
            # SSL connection class explicitly instead.
            connection_params.update({
                "connection_class": redis.SSLConnection,
                # NOTE: ssl_cert_reqs=None disables certificate verification;
                # set ssl_ca_certs in production to verify the server.
                "ssl_cert_reqs": None,
                "ssl_ca_certs": None
            })
connection_pool = redis.ConnectionPool(**connection_params)
self._redis = redis.Redis(connection_pool=connection_pool)
        # num_workers consumer threads plus one slot for the retry-queue
        # thread; sizing the pool at num_workers alone would leave one
        # long-running task queued forever.
        self._executor = ThreadPoolExecutor(max_workers=config.num_workers + 1)
self._running = False
logger.info("Message consumer initialized", queue=queue, config=config.dict())
    def start(self) -> None:
        """Start consuming messages."""
        self._running = True
        # One background task moves due retry-queue messages back onto the
        # main queue; the rest are consumer workers.
        self._executor.submit(self._process_retry_queue)
        logger.info("Starting workers", queue=self.queue, num_workers=self.config.num_workers)
        for _ in range(self.config.num_workers):
            self._executor.submit(self._consume)
        logger.info("Consumer started", queue=self.queue)
def stop(self) -> None:
"""Stop consuming messages."""
self._running = False
self._executor.shutdown(wait=True)
logger.info("Consumer stopped", queue=self.queue)
def _consume(self) -> None:
"""Consume messages from the queue."""
logger.info("Consumer thread started", queue=self.queue)
while self._running:
try:
messages = self._batch_pop_messages()
if messages:
logger.info("Received messages", queue=self.queue, count=len(messages))
for message_data in messages:
self._process_message(Message.from_json(message_data))
else:
# Small sleep to prevent CPU spinning when queue is empty
time.sleep(0.1)
except Exception as e:
logger.error("Error in consumer loop", error=str(e), queue=self.queue)
time.sleep(1)
def _batch_pop_messages(self) -> List[str]:
"""Pop a batch of messages from the queue."""
messages = []
try:
# Using brpop instead of rpop for blocking operation
result = self._redis.brpop([f"queue:{self.queue}"], timeout=1)
if result:
messages.append(result[1]) # brpop returns (key, value) tuple
# Try to get more messages up to batch size
for _ in range(self.config.batch_size - 1):
msg = self._redis.rpop(f"queue:{self.queue}")
if msg:
messages.append(msg)
else:
break
logger.debug("Batch pop result",
queue=self.queue,
messages_count=len(messages))
return messages
except Exception as e:
logger.error("Error in batch pop", error=str(e), queue=self.queue)
return []
def _process_message(self, message: Message) -> None:
"""Process a single message."""
with processing_time.labels(queue=self.queue).time():
try:
self.handler(message)
messages_processed.labels(
queue=self.queue, status="success"
).inc()
logger.info(
"Message processed successfully",
message_id=message.id,
queue=self.queue
)
except Exception as e:
messages_processed.labels(
queue=self.queue, status="error"
).inc()
message.error = str(e)
self._handle_processing_error(message)
def _handle_processing_error(self, message: Message) -> None:
"""Handle a message processing error."""
if message.retry_count < message.max_retries:
self._retry_message(message)
else:
self._move_to_dead_letter(message)
def _retry_message(self, message: Message) -> None:
"""Move a message to the retry queue with exponential backoff."""
message.retry_count += 1
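        # Exponential backoff: initial_delay * backoff_factor**(retry_count - 1),
        # capped at max_delay.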
delay = min(
self.config.retry.initial_delay *
(self.config.retry.backoff_factor ** (message.retry_count - 1)),
self.config.retry.max_delay
)
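        # The sorted-set score is the earliest time the message may be retried.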
self._redis.zadd(
f"retry:{self.queue}",
{message.to_json(): time.time() + delay}
)
logger.info(
"Message scheduled for retry",
message_id=message.id,
queue=self.queue,
retry_count=message.retry_count,
delay=delay
)
    def _process_retry_queue(self) -> None:
        """Move due messages from the retry queue back onto the main queue.

        The read-then-move below is atomic within one pipeline but not across
        processes, so with multiple consumer instances a message can
        occasionally be re-queued twice (at-least-once delivery).
        """
while self._running:
try:
# Get messages that are ready to be retried
messages = self._redis.zrangebyscore(
f"retry:{self.queue}",
"-inf",
time.time(),
start=0,
num=self.config.batch_size
)
if not messages:
time.sleep(self.config.polling_interval)
continue
                # Atomically remove each due message from the retry queue and
                # push it back onto the main queue.
                pipeline = self._redis.pipeline()
                for message_data in messages:
                    pipeline.zrem(f"retry:{self.queue}", message_data)
                    pipeline.lpush(f"queue:{self.queue}", message_data)
pipeline.execute()
except Exception as e:
logger.error("Error processing retry queue", error=str(e))
time.sleep(1)
def _move_to_dead_letter(self, message: Message) -> None:
"""Move a message to the dead letter queue."""
try:
self._redis.lpush(f"dead_letter:{self.queue}", message.to_json())
logger.warning(
"Message moved to dead letter queue",
message_id=message.id,
queue=self.queue,
error=message.error
)
except redis.RedisError as e:
logger.error("Failed to move message to dead letter queue", error=str(e))