|
import asyncio |
|
import functools |
|
import json |
|
import logging |
|
import os |
|
import ssl |
|
import time |
|
from enum import Enum |
|
from typing import Dict, Optional, Literal, Union |
|
|
|
import msgpack |
|
import pika |
|
from pika.exceptions import AMQPConnectionError |
|
from pydantic import BaseModel, field_validator |
|
|
|
# Module-level logger, named after this module so output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def sync(f):
    """Wrap coroutine function *f* so it can be called synchronously.

    Each call runs ``f(*args, **kwargs)`` to completion on a fresh event loop
    (via :func:`asyncio.run`) and returns the coroutine's result. As with
    ``loop.run_until_complete``, the wrapper must not be called from code that
    is already running inside an event loop on the same thread.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # asyncio.get_event_loop() is deprecated when no loop is running and no
        # longer creates one implicitly on newer Python versions; asyncio.run
        # creates, runs and closes a loop for every call.
        return asyncio.run(f(*args, **kwargs))

    return wrapper
|
|
|
|
|
class Priority(Enum):
    """Message priority levels.

    The integer value of a member is what gets published as the AMQP message
    priority; queues in this module are declared with ``x-max-priority`` 10 by
    default, so ``HIGH`` sits at that ceiling.
    """

    LOW = 1  # lowest level
    NORMAL = 5  # middle level
    HIGH = 10  # highest level
|
|
|
|
|
class Headers(BaseModel):
    """Application headers attached to published messages.

    ``priority`` accepts a ``Priority`` member, a member name such as
    ``"HIGH"``, or a member value such as ``10``.
    """

    job_id: str
    priority: Priority
    task_type: Optional[str] = None

    @field_validator('priority', mode='before')
    @classmethod
    def _convert_priority(cls, value):
        """Resolve a member name to its ``Priority``; defer everything else.

        ``Priority[value]`` only understands names, so raw values (e.g. ``10``)
        used to escape as a bare ``KeyError``. Members and values are instead
        left to pydantic's standard enum validation, which reports unknown
        inputs as a ``ValidationError``.
        """
        if isinstance(value, str) and value in Priority.__members__:
            return Priority[value]
        return value
|
|
|
|
|
class RabbitMQConfig(BaseModel):
    """Connection settings for a RabbitMQ broker.

    NOTE(review): not used by the clients in the visible code, which read their
    settings from environment variables instead — confirm whether it is used
    elsewhere.
    """

    host: str  # broker hostname
    port: int  # broker port
    username: str  # login user
    password: str  # login password
    protocol: str  # presumably "amqp" or "amqps", as in BasicPikaClient.protocol
|
|
|
|
|
class BasicPikaClient:
    """Blocking pika client that owns one connection and one channel.

    Settings are read from the environment in ``__init__``:
    ``RABBITMQ_USER`` / ``RABBITMQ_PASSWD`` (default empty),
    ``RABBITMQ_HOST`` (default ``localhost``) and ``RABBITMQ_PORT``
    (default 5672). The constructor connects immediately and raises
    ``AMQPConnectionError`` if that fails ``max_connect_attempts`` times.
    """

    # Number of attempts _connect makes before giving up.
    max_connect_attempts: int = 20
    # Seconds to wait between failed connection attempts.
    connect_retry_delay: float = 5

    def __init__(self):
        self.username = os.environ.get('RABBITMQ_USER', '')
        self.password = os.environ.get('RABBITMQ_PASSWD', '')
        self.host = os.environ.get('RABBITMQ_HOST', 'localhost')
        # str when taken from the environment, int for the default; converted
        # with int() when the connection parameters are built.
        self.port = os.environ.get('RABBITMQ_PORT', 5672)
        # "amqps" would enable the TLS branch of _init_connection_parameters.
        self.protocol = "amqp"

        self._init_connection_parameters()
        self._connect()

    def _connect(self):
        """(Re)open ``self.connection`` and ``self.channel``, retrying on failure.

        Raises:
            AMQPConnectionError: after ``max_connect_attempts`` unsuccessful
                attempts, chained to the last underlying error.
        """
        last_error = None
        for attempt in range(1, self.max_connect_attempts + 1):
            try:
                self.connection = pika.BlockingConnection(self.parameters)
                self.channel = self.connection.channel()
                if self.connection.is_open:
                    return
            except Exception as e:  # pika raises several unrelated types here
                last_error = e
                logger.warning(
                    "RabbitMQ connection attempt %d/%d failed: %s",
                    attempt,
                    self.max_connect_attempts,
                    e,
                )
            # No point sleeping after the final attempt: we are about to give up.
            if attempt < self.max_connect_attempts:
                time.sleep(self.connect_retry_delay)
        raise AMQPConnectionError(last_error) from last_error

    def _init_connection_parameters(self):
        """Build ``self.parameters`` (and ``self.credentials`` when a user/password is set)."""
        extra = {}
        if self.username or self.password:
            self.credentials = pika.PlainCredentials(self.username, self.password)
            extra["credentials"] = self.credentials
        # Without explicit credentials, pika's default credentials are used.
        self.parameters = pika.ConnectionParameters(
            host=self.host,
            port=int(self.port),
            virtual_host="/",
            **extra,
        )

        if self.protocol == "amqps":
            # NOTE(review): SSLContext(PROTOCOL_TLSv1_2) is a deprecated
            # constructor and performs no certificate or hostname verification
            # by default; consider ssl.create_default_context() together with
            # server_hostname in pika.SSLOptions.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
            ssl_context.set_ciphers("ECDHE+AESGCM:!ECDSA")
            self.parameters.ssl_options = pika.SSLOptions(context=ssl_context)

    def check_connection(self):
        """Ensure a usable connection and channel exist.

        Reconnects if the connection is missing or closed. If only the channel
        is closed, a new channel is opened on the existing connection (in AMQP
        a channel-level error closes the channel but not the connection).
        """
        connection = getattr(self, "connection", None)
        if connection is None or connection.is_closed:
            self._connect()
            return
        channel = getattr(self, "channel", None)
        if channel is None or channel.is_closed:
            self.channel = connection.channel()

    def close(self):
        """Close the channel, then the connection, skipping whichever is already closed.

        The connection is still closed even if the channel was already closed,
        so it is not leaked.
        """
        channel = getattr(self, "channel", None)
        if channel is not None and channel.is_open:
            channel.close()
        connection = getattr(self, "connection", None)
        if connection is not None and connection.is_open:
            connection.close()

    def declare_queue(
        self, queue_name, exclusive: bool = False, max_priority: int = 10
    ):
        """Declare a durable queue.

        Args:
            queue_name: name of the queue.
            exclusive: passed through to ``queue_declare``.
            max_priority: value for the ``x-max-priority`` argument; no
                arguments are sent when it is 0 or negative.
        """
        self.check_connection()
        logger.debug("Trying to declare queue(%s)...", queue_name)

        self.channel.queue_declare(
            queue=queue_name,
            exclusive=exclusive,
            durable=True,
            arguments={"x-max-priority": max_priority} if max_priority > 0 else None,
        )

    def declare_exchange(self, exchange_name: str, exchange_type: str = "direct"):
        """Declare an exchange of the given type (``direct`` by default)."""
        self.check_connection()
        self.channel.exchange_declare(
            exchange=exchange_name, exchange_type=exchange_type
        )

    def bind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
        """Bind *queue_name* to *exchange_name* with *routing_key*."""
        self.check_connection()
        self.channel.queue_bind(
            exchange=exchange_name, queue=queue_name, routing_key=routing_key
        )

    def unbind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
        """Remove the binding of *queue_name* to *exchange_name* for *routing_key*."""
        # Like the other channel operations, make sure the channel is usable first.
        self.check_connection()
        self.channel.queue_unbind(
            queue=queue_name, exchange=exchange_name, routing_key=routing_key
        )
|
|
|
|
|
class BasicMessageSender(BasicPikaClient):
    """Publisher that encodes message bodies and publishes them as persistent messages."""

    # Encoding used by send_message: 'json' (json.dumps unless already
    # str/bytes) or 'bytes' (msgpack).
    message_encoding_type: Literal['bytes', 'json'] = 'json'

    def encode_message(self, body: Union[Dict, str], encoding_type: str = "bytes"):
        """Serialize *body* for publishing.

        Args:
            body: payload to encode.
            encoding_type: ``"bytes"`` packs the body with msgpack; ``"json"``
                serializes it with ``json.dumps`` unless it is already a
                ``str``/``bytes`` payload, which is passed through unchanged.

        Raises:
            NotImplementedError: for any other encoding type.
        """
        if encoding_type == "bytes":
            return msgpack.packb(body)
        if encoding_type == "json":
            # Pre-encoded payloads go out as-is; anything else (dicts, lists,
            # numbers, ...) is serialized instead of being handed to pika raw.
            if isinstance(body, (str, bytes, bytearray)):
                return body
            return json.dumps(body)
        raise NotImplementedError(f"Unsupported encoding type: {encoding_type!r}")

    def send_message(
        self,
        exchange_name: str,
        routing_key: str,
        body: Union[Dict, str],
        headers: Optional[Headers] = None,
    ):
        """Encode *body* and publish it persistently to *exchange_name*.

        When *headers* is given, its priority value becomes the AMQP message
        priority and its fields are sent as the message headers table.
        """
        body = self.encode_message(body=body, encoding_type=self.message_encoding_type)

        priority = None
        header_table = None
        if headers is not None:
            priority = headers.priority.value
            header_table = headers.model_dump()
            # model_dump() keeps the Priority enum member, which pika cannot
            # encode as an AMQP field value. Send the member name instead, the
            # form Headers' priority validator resolves via Priority[name].
            header_table["priority"] = headers.priority.name

        properties = pika.BasicProperties(
            delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE,
            priority=priority,
            headers=header_table,
        )
        self.check_connection()
        self.channel.basic_publish(
            exchange=exchange_name,
            routing_key=routing_key,
            body=body,
            properties=properties,
        )
        logger.debug(
            "Sent message. Exchange: %s, Routing Key: %s, Body: %s",
            exchange_name,
            routing_key,
            body[:128],
        )
|
|
|
|
|
class BasicMessageReceiver(BasicPikaClient):
    """Consumer that pulls single messages or runs a blocking consume loop."""

    def __init__(self):
        super().__init__()
        # Consumer tag returned by basic_consume; None while not consuming.
        self.channel_tag = None

    def decode_message(self, body):
        """Parse a JSON message body.

        Accepts ``bytes``/``bytearray`` (as delivered by pika) or ``str``.

        Raises:
            NotImplementedError: for any other body type.
        """
        if isinstance(body, (bytes, bytearray, str)):
            return json.loads(body)
        raise NotImplementedError

    def get_message(self, queue_name: str, auto_ack: bool = False):
        """Fetch one message from *queue_name* without subscribing.

        Returns:
            ``(method_frame, header_frame, body)`` when a message was
            available, otherwise ``None``.
        """
        # Keep parity with the other channel operations.
        self.check_connection()
        method_frame, header_frame, body = self.channel.basic_get(
            queue=queue_name, auto_ack=auto_ack
        )
        if not method_frame:
            logger.debug("No message returned")
            return None
        logger.debug("%s, %s, %s", method_frame, header_frame, body)
        return method_frame, header_frame, body

    def consume_messages(self, queue, callback, auto_ack: bool = True):
        """Subscribe *callback* to *queue* and block in the consume loop.

        Args:
            queue: queue to consume from.
            callback: pika ``on_message_callback``.
            auto_ack: passed to ``basic_consume``; defaults to ``True`` as before.
        """
        self.check_connection()
        self.channel_tag = self.channel.basic_consume(
            queue=queue, on_message_callback=callback, auto_ack=auto_ack
        )
        logger.debug(" [*] Waiting for messages. To exit press CTRL+C")
        self.channel.start_consuming()

    def cancel_consumer(self):
        """Cancel the active consumer, if any; logs an error when there is none."""
        if self.channel_tag is None:
            logger.error("Do not cancel a non-existing job")
            return
        self.channel.basic_cancel(self.channel_tag)
        self.channel_tag = None
|
|