|
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect |
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.staticfiles import StaticFiles |
|
|
|
import logging |
|
import traceback |
|
from config import Args |
|
from user_queue import UserQueueDict |
|
import uuid |
|
import asyncio |
|
import time |
|
from PIL import Image |
|
import io |
|
|
|
|
|
def _drain_queue(queue: asyncio.Queue) -> None:
    """Discard every item currently waiting in *queue* without blocking."""
    while not queue.empty():
        try:
            queue.get_nowait()
        except asyncio.QueueEmpty:
            # The queue was emptied between the check and the get.
            break


def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
    """Register middleware, routes and the static file mount on *app*.

    Args:
        app: The FastAPI application to configure.
        user_queue_map: Mapping from a user id (``uuid.UUID``) to a dict that
            holds that user's ``asyncio.Queue`` under the ``"queue"`` key.
        args: Runtime settings; ``max_queue_size`` and ``timeout`` are read
            here (a value ``<= 0`` disables the respective limit).
        pipeline: Object exposing ``InputParams`` and
            ``predict(image, params)``.
    """
    # NOTE(review): wildcard origins combined with credentials is very
    # permissive -- confirm this is intended outside of development.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    print("Init app", app)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Accept a client, register its queue and pump frames until it leaves."""
        await websocket.accept()
        if args.max_queue_size > 0 and len(user_queue_map) >= args.max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            return

        # Created before the ``try`` so the ``finally`` clause can always
        # reference it.
        uid = uuid.uuid4()
        try:
            print(f"New user connected: {uid}")
            # A UUID object is not JSON serializable; send its string form.
            await websocket.send_json(
                {"status": "success", "message": "Connected", "userId": str(uid)}
            )
            user_queue_map[uid] = {"queue": asyncio.Queue()}
            await websocket.send_json(
                {"status": "start", "message": "Start Streaming", "userId": str(uid)}
            )
            await handle_websocket_data(websocket, uid)
        except WebSocketDisconnect as e:
            logging.error(f"WebSocket Error: {e}, {uid}")
            traceback.print_exc()
        finally:
            print(f"User disconnected: {uid}")
            # The entry is missing when the failure happened before the user
            # was registered, so guard against ``None``.
            queue_value = user_queue_map.pop(uid, None)
            queue = queue_value.get("queue") if queue_value else None
            if queue is not None:
                _drain_queue(queue)

    @app.get("/queue_size")
    async def get_queue_size():
        """Return the number of users currently registered."""
        queue_size = len(user_queue_map)
        return JSONResponse({"queue_size": queue_size})

    @app.get("/stream/{user_id}")
    async def stream(user_id: uuid.UUID):
        """Stream the pipeline output for *user_id* as multipart JPEG frames.

        Raises:
            HTTPException: 404 when *user_id* has no registered queue.
        """
        try:
            queue = user_queue_map[user_id]["queue"]
        except KeyError as e:
            logging.error(f"Streaming Error: {e}, {user_queue_map}")
            traceback.print_exc()
            # Must be raised (not returned) for FastAPI to emit a 404.
            raise HTTPException(status_code=404, detail="User not found") from e

        async def generate():
            """Yield one multipart chunk per frame produced by the pipeline."""
            while True:
                data = await queue.get()
                input_image = data["image"]
                params = data["params"]
                if input_image is None:
                    continue

                image = pipeline.predict(input_image, params)
                if image is None:
                    continue

                buffer = io.BytesIO()
                image.save(buffer, format="JPEG")
                frame_data = buffer.getvalue()
                if frame_data:
                    yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"

                # Brief pause between iterations so the loop yields control.
                await asyncio.sleep(1.0 / 120.0)

        return StreamingResponse(
            generate(), media_type="multipart/x-mixed-replace;boundary=frame"
        )

    async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
        """Receive (image bytes, JSON params) pairs and queue the newest one.

        Each pair replaces whatever is still waiting in the user's queue, so
        the consumer only ever sees the latest input. The session is closed
        once ``args.timeout`` seconds have elapsed since this handler started
        (when ``args.timeout > 0``). Any exception raised while receiving or
        decoding is logged and ends the loop.
        """
        uid = user_id
        queue = user_queue_map[uid]["queue"]
        if queue is None:
            return

        # Session start; never refreshed, so the timeout bounds the whole
        # session rather than idle time.
        last_time = time.time()
        try:
            while True:
                data = await websocket.receive_bytes()
                params = await websocket.receive_json()
                params = pipeline.InputParams(**params)
                pil_image = Image.open(io.BytesIO(data))

                # Drop stale inputs so only the most recent one is pending.
                _drain_queue(queue)
                await queue.put({"image": pil_image, "params": params})

                if args.timeout > 0 and time.time() - last_time > args.timeout:
                    await websocket.send_json(
                        {
                            "status": "timeout",
                            "message": "Your session has ended",
                            # UUID objects are not JSON serializable.
                            "userId": str(uid),
                        }
                    )
                    await websocket.close()
                    return
        except Exception as e:
            logging.error(f"Error: {e}")
            traceback.print_exc()

    @app.get("/settings")
    async def settings():
        """Return the JSON schema of the pipeline's input parameters."""
        params = pipeline.InputParams()
        return JSONResponse(params.model_json_schema())

    # Mounted last so the API routes registered above take precedence over
    # the catch-all static mount at "/".
    app.mount("/", StaticFiles(directory="public", html=True), name="public")
|
|