File size: 5,282 Bytes
cb92d2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

import logging
import traceback
from config import Args
from user_queue import UserQueueDict
import uuid
import asyncio
import time
from PIL import Image
import io


# Put on a user's queue when their websocket session ends, so that a /stream
# generator blocked on queue.get() wakes up and finishes its response.
_END_OF_STREAM = object()


def _drain_queue(queue: asyncio.Queue) -> None:
    """Discard every item currently buffered in *queue* without blocking."""
    while not queue.empty():
        try:
            queue.get_nowait()
        except asyncio.QueueEmpty:
            break


def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
    """Register CORS middleware, the API routes and the static frontend on *app*.

    Args:
        app: the FastAPI application to configure.
        user_queue_map: mapping from a connected user's UUID to a dict whose
            ``"queue"`` entry is that user's ``asyncio.Queue`` of pending frames.
        args: server configuration; ``max_queue_size`` and ``timeout`` are read
            here, and a value <= 0 disables the respective limit.
        pipeline: model wrapper exposing ``InputParams`` and
            ``predict(image, params)``.
    """
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    print("Init app", app)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Register a new user, then forward their frames until the session ends."""
        await websocket.accept()
        if args.max_queue_size > 0 and len(user_queue_map) >= args.max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            return

        uid = uuid.uuid4()
        try:
            print(f"New user connected: {uid}")
            # uuid.UUID is not JSON-serializable; send its canonical string form.
            await websocket.send_json(
                {"status": "success", "message": "Connected", "userId": str(uid)}
            )
            user_queue_map[uid] = {"queue": asyncio.Queue()}
            await websocket.send_json(
                {"status": "start", "message": "Start Streaming", "userId": str(uid)}
            )
            await handle_websocket_data(websocket, uid)
        except WebSocketDisconnect as e:
            logging.error(f"WebSocket Error: {e}, {uid}")
            traceback.print_exc()
        finally:
            print(f"User disconnected: {uid}")
            # The user may leave before being registered, so the entry can be absent.
            queue_value = user_queue_map.pop(uid, None)
            if queue_value is not None:
                queue = queue_value.get("queue")
                if queue is not None:
                    _drain_queue(queue)
                    queue.put_nowait(_END_OF_STREAM)

    @app.get("/queue_size")
    async def get_queue_size():
        """Report how many users are currently registered."""
        queue_size = len(user_queue_map)
        return JSONResponse({"queue_size": queue_size})

    @app.get("/stream/{user_id}")
    async def stream(user_id: uuid.UUID):
        """Stream the user's generated frames as a multipart JPEG response.

        Raises:
            HTTPException: 404 if *user_id* is not a connected user.
        """
        uid = user_id
        try:
            queue = user_queue_map[uid]["queue"]
        except KeyError:
            logging.error(f"Streaming Error: unknown user {uid}")
            # Must be raised (not returned) for FastAPI to emit a 404 response.
            raise HTTPException(status_code=404, detail="User not found") from None

        async def generate():
            while True:
                data = await queue.get()
                if data is _END_OF_STREAM:
                    # The websocket session ended; finish the response.
                    break
                input_image = data["image"]
                params = data["params"]
                if input_image is None:
                    continue

                # NOTE(review): predict() is a synchronous call made on the event
                # loop thread, so it blocks other coroutines while it runs.
                image = pipeline.predict(input_image, params)
                if image is None:
                    continue
                buffer = io.BytesIO()
                image.save(buffer, format="JPEG")
                frame_data = buffer.getvalue()
                if frame_data:
                    yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"

                await asyncio.sleep(1.0 / 120.0)

        return StreamingResponse(
            generate(), media_type="multipart/x-mixed-replace;boundary=frame"
        )

    async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
        """Receive (image bytes, JSON params) message pairs into the user's queue.

        Each new pair replaces whatever is still waiting in the queue, so the
        stream always renders the most recent frame. Returns when the session
        timeout (measured from the start of this loop) elapses, when the client
        disconnects, or after logging any other error.
        """
        uid = user_id
        queue = user_queue_map[uid]["queue"]
        session_start = time.time()
        try:
            while True:
                data = await websocket.receive_bytes()
                params = await websocket.receive_json()
                params = pipeline.InputParams(**params)
                pil_image = Image.open(io.BytesIO(data))

                # Drop stale frames: only the newest one is worth rendering.
                _drain_queue(queue)
                await queue.put({"image": pil_image, "params": params})
                if args.timeout > 0 and time.time() - session_start > args.timeout:
                    await websocket.send_json(
                        {
                            "status": "timeout",
                            "message": "Your session has ended",
                            "userId": str(uid),
                        }
                    )
                    await websocket.close()
                    return

        except WebSocketDisconnect:
            # Normal end of session; the caller's cleanup handles the rest.
            return
        except Exception as e:
            logging.error(f"Error: {e}")
            traceback.print_exc()

    # route to setup frontend
    @app.get("/settings")
    async def settings():
        """Return the pipeline's default input parameters for the frontend."""
        params = pipeline.InputParams()
        return JSONResponse({"settings": params.dict()})

    # Registered last: routes are matched in order, so the catch-all "/" mount
    # must come after the API routes above.
    app.mount("/", StaticFiles(directory="public", html=True), name="public")