# openai_server.py
from __future__ import annotations
import asyncio, json, time, uuid, logging
from typing import Any, AsyncIterable, Dict, List, Optional
import aio_pika
logger = logging.getLogger(__name__)
# --------------------------- Helpers ---------------------------
def _now() -> int:
    return int(time.time())

def _chunk_text(s: str, sz: int = 120) -> List[str]:
    if not s:
        return []
    return [s[i:i + sz] for i in range(0, len(s), sz)]
def _last_user_text(messages: List[Dict[str, Any]]) -> str:
    # Accept either a plain string or multimodal parts [{type: "text"/"image_url"/...}]
    for m in reversed(messages or []):
        if (m or {}).get("role") == "user":
            c = m.get("content", "")
            if isinstance(c, str):
                return c
            if isinstance(c, list):
                texts = [p.get("text", "") for p in c if isinstance(p, dict) and p.get("type") == "text"]
                return " ".join(t for t in texts if t)
    return ""
# --------------------------- Backends ---------------------------
# You can replace DummyChatBackend with a real LLM (OpenAI/HF/local).
class ChatBackend:
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        raise NotImplementedError
class DummyChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """
        Emits OpenAI-shaped *streaming* chunks.
        - No tool_calls for now (keeps the server simple)
        - Mimics delta frames plus a final finish_reason
        """
        rid = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        model = request.get("model", "gpt-4o-mini")
        text = _last_user_text(request.get("messages", [])) or "(empty)"
        answer = f"Echo (RabbitMQ): {text}"
        now = _now()
        # The first delta carries the role, per the OpenAI stream shape
        yield {
            "id": rid, "object": "chat.completion.chunk", "created": now, "model": model,
            "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
        }
        # Stream the content in small pieces
        for piece in _chunk_text(answer, 140):
            yield {
                "id": rid, "object": "chat.completion.chunk", "created": now, "model": model,
                "choices": [{"index": 0, "delta": {"content": piece}, "finish_reason": None}],
            }
        # Final delta with finish_reason
        yield {
            "id": rid, "object": "chat.completion.chunk", "created": now, "model": model,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        }
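# A hedged sketch of a non-stub backend with the same stream() contract, proxying
# to the OpenAI v1 Python SDK. This is an illustration, not wired in anywhere:
# it assumes the optional `openai` package is installed and OPENAI_API_KEY is set.
class OpenAIChatBackend(ChatBackend):
    def __init__(self, **client_kwargs: Any):
        # Lazy import so this module still loads without the optional dependency.
        from openai import AsyncOpenAI
        self._client = AsyncOpenAI(**client_kwargs)

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        upstream = await self._client.chat.completions.create(
            model=request.get("model", "gpt-4o-mini"),
            messages=request.get("messages", []),
            stream=True,
        )
        async for chunk in upstream:
            # ChatCompletionChunk is a pydantic model; dump it back to the
            # plain-dict, OpenAI-shaped chunk the servers below publish verbatim.
            yield chunk.model_dump()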
class ImagesBackend:
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """
        Return a base64-encoded image string. This is a stub; replace it with
        your image generator (e.g., SDXL, OpenAI gpt-image-1, etc.).
        """
        # For now, return a 1x1 transparent PNG
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
# --------------------------- Servers ---------------------------
class ChatCompletionsServer:
    """
    Consumes OpenAI Chat Completions requests from exchange `oa.chat.create`,
    routing key `default`, and streams OpenAI-shaped chunks back to `reply_to`.
    """
    def __init__(
        self,
        amqp_url: str,
        *,
        exchange_name: str = "oa.chat.create",
        routing_key: str = "default",
        backend: Optional[ChatBackend] = None,
    ):
        self._amqp_url = amqp_url
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        self._backend = backend or DummyChatBackend()
        self._conn: Optional[aio_pika.RobustConnection] = None
        self._ch: Optional[aio_pika.RobustChannel] = None
        self._ex: Optional[aio_pika.Exchange] = None
        self._queue_name = f"{exchange_name}.{routing_key}"

    async def start(self):
        self._conn = await aio_pika.connect_robust(self._amqp_url)
        self._ch = await self._conn.channel()
        self._ex = await self._ch.declare_exchange(self._exchange_name, aio_pika.ExchangeType.DIRECT, durable=True)
        q = await self._ch.declare_queue(self._queue_name, durable=True)
        await q.bind(self._ex, routing_key=self._routing_key)
        await q.consume(self._on_message)
        logger.info("ChatCompletionsServer listening on %s/%s → %s", self._exchange_name, self._routing_key, self._queue_name)
    async def _on_message(self, msg: aio_pika.IncomingMessage):
        async with msg.process(ignore_processed=True):
            try:
                req = json.loads(msg.body.decode("utf-8", errors="replace"))
                reply_to = msg.reply_to
                corr_id = msg.correlation_id
                if not reply_to or not corr_id:
                    logger.warning("Missing reply_to/correlation_id; dropping.")
                    return
                async for chunk in self._backend.stream(req):
                    await self._ch.default_exchange.publish(
                        aio_pika.Message(
                            body=json.dumps(chunk).encode("utf-8"),
                            correlation_id=corr_id,
                            content_type="application/json",
                            delivery_mode=aio_pika.DeliveryMode.NOT_PERSISTENT,
                        ),
                        routing_key=reply_to,
                    )
                # Optional end-of-stream sentinel so clients know when to stop consuming
                await self._ch.default_exchange.publish(
                    aio_pika.Message(
                        body=b'{"object":"stream.end"}',
                        correlation_id=corr_id,
                        content_type="application/json",
                    ),
                    routing_key=reply_to,
                )
            except Exception:
                logger.exception("ChatCompletionsServer: failed to process message")
class ImagesServer:
    """
    Consumes OpenAI Images API requests from exchange `oa.images.generate`,
    routing key `default`, and replies once with {data:[{b64_json:...}], created:...}.
    """
    def __init__(
        self,
        amqp_url: str,
        *,
        exchange_name: str = "oa.images.generate",
        routing_key: str = "default",
        backend: Optional[ImagesBackend] = None,
    ):
        self._amqp_url = amqp_url
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        self._backend = backend or ImagesBackend()
        self._conn: Optional[aio_pika.RobustConnection] = None
        self._ch: Optional[aio_pika.RobustChannel] = None
        self._ex: Optional[aio_pika.Exchange] = None
        self._queue_name = f"{exchange_name}.{routing_key}"

    async def start(self):
        self._conn = await aio_pika.connect_robust(self._amqp_url)
        self._ch = await self._conn.channel()
        self._ex = await self._ch.declare_exchange(self._exchange_name, aio_pika.ExchangeType.DIRECT, durable=True)
        q = await self._ch.declare_queue(self._queue_name, durable=True)
        await q.bind(self._ex, routing_key=self._routing_key)
        await q.consume(self._on_message)
        logger.info("ImagesServer listening on %s/%s → %s", self._exchange_name, self._routing_key, self._queue_name)
    async def _on_message(self, msg: aio_pika.IncomingMessage):
        async with msg.process(ignore_processed=True):
            try:
                req = json.loads(msg.body.decode("utf-8", errors="replace"))
                reply_to = msg.reply_to
                corr_id = msg.correlation_id
                if not reply_to or not corr_id:
                    logger.warning("Missing reply_to/correlation_id; dropping.")
                    return
                b64_img = await self._backend.generate_b64(req)
                resp = {"created": _now(), "data": [{"b64_json": b64_img}]}
                await self._ch.default_exchange.publish(
                    aio_pika.Message(
                        body=json.dumps(resp).encode("utf-8"),
                        correlation_id=corr_id,
                        content_type="application/json",
                    ),
                    routing_key=reply_to,
                )
            except Exception:
                logger.exception("ImagesServer: failed to process message")