File size: 6,289 Bytes
cb92d2b
 
 
 
d6fedfa
46bd9ac
cb92d2b
 
 
 
3207814
cb92d2b
 
ff9325e
3207814
 
c6b09d3
8a96a46
cb92d2b
a659304
 
cb92d2b
3207814
cb92d2b
 
 
 
 
 
 
 
995fb39
cb92d2b
 
3207814
 
cb92d2b
 
 
 
 
3207814
 
 
cb92d2b
3207814
cb92d2b
64e241d
3207814
cb92d2b
3207814
cb92d2b
 
3207814
 
cb92d2b
1d3190d
 
 
 
 
 
78b5416
 
 
 
 
 
 
 
 
 
1d3190d
 
a659304
1d3190d
 
 
 
 
 
 
 
64e241d
 
78b5416
64e241d
1d3190d
 
64e241d
1d3190d
 
 
 
 
995fb39
cb92d2b
3207814
cb92d2b
 
995fb39
d6fedfa
cb92d2b
 
 
1d3190d
 
cb92d2b
8a96a46
3207814
1d3190d
64e241d
78b5416
3207814
1d3190d
 
ff9325e
4b58964
cb92d2b
7d67dc6
78b5416
cb92d2b
d6fedfa
 
 
 
 
7d67dc6
8a96a46
 
cb92d2b
 
d6fedfa
 
 
cb92d2b
 
3207814
cb92d2b
 
 
 
995fb39
cb92d2b
46bd9ac
 
 
 
 
43148fd
d6fedfa
 
46bd9ac
d6fedfa
 
46bd9ac
d6fedfa
 
cb92d2b
c6b09d3
 
 
cb92d2b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi import Request
import markdown2

import logging
import traceback
from config import Args
from user_queue import UserData
import uuid
import time
from types import SimpleNamespace
from util import pil_to_frame, bytes_to_pil, is_firefox
import asyncio
import os
import time

# Pause, in seconds, used by the polling loops' asyncio.sleep calls
# (i.e. at most ~120 iterations per second).
THROTTLE = 1 / 120


def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
    """Register middleware, API routes and the static frontend on *app*.

    Routes added:

    * ``/api/ws`` (WebSocket) -- per-user control channel. For each frame the
      client sends ``{"status": "next_frame"}``, then a JSON object of pipeline
      input parameters, then (when ``pipeline.Info().input_mode == "image"``)
      the raw image bytes. Parsed parameters are stored with
      ``user_data.update_data``.
    * ``GET /api/queue`` -- number of currently registered users.
    * ``GET /api/stream/{user_id}`` -- ``multipart/x-mixed-replace`` stream of
      frames produced by ``pipeline.predict`` from the user's latest parameters.
    * ``GET /api/settings`` -- pipeline info / input-parameter schemas and the
      rendered page content, used by the frontend.
    * ``/`` -- static files served from the ``public`` directory.

    Args:
        app: Application to configure.
        user_data: Per-user session store (websocket and latest parameters).
        args: Server configuration; reads ``max_queue_size``, ``timeout`` and
            ``debug``.
        pipeline: Object exposing ``Info``, ``InputParams`` (presumably pydantic
            v1 models, given ``.schema()`` / ``.dict()``) and ``predict(params)``.
    """
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.websocket("/api/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Accept a control websocket, register the user and relay their input.

        Rejects the connection with an error message when ``max_queue_size`` is
        positive and already reached. The user is always removed from
        ``user_data`` when the session ends.
        """
        await websocket.accept()
        user_count = user_data.get_user_count()
        if args.max_queue_size > 0 and user_count >= args.max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            return
        # Assigned before the ``try`` so the ``finally`` clause can always use it.
        user_id = uuid.uuid4()
        try:
            print(f"New user connected: {user_id}")
            await user_data.create_user(user_id, websocket)
            await websocket.send_json(
                {"status": "connected", "message": "Connected", "userId": str(user_id)}
            )
            await websocket.send_json({"status": "send_frame"})
            await handle_websocket_data(user_id, websocket)
        except WebSocketDisconnect as e:
            # The client closing the socket is a normal end of session, not a
            # server error, so no traceback is logged.
            logging.info(f"WebSocket disconnected: {e}, {user_id}")
        finally:
            print(f"User disconnected: {user_id}")
            user_data.delete_user(user_id)

    async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
        """Receive frame parameters from *websocket* and store them for *user_id*.

        Runs until one of the following happens:

        * the session exceeds ``args.timeout`` seconds (when positive); a
          ``"timeout"`` message is sent and the socket is closed;
        * the client disconnects; ``WebSocketDisconnect`` is re-raised for the
          caller;
        * any other error occurs; it is logged and the loop ends.
        """
        if not user_data.check_user(user_id):
            logging.error(f"User not found: {user_id}")
            return
        session_start = time.time()
        try:
            # Built once per session rather than once per frame. It is called
            # without arguments, so input_mode is presumably constant.
            info = pipeline.Info()
            while True:
                # NOTE: checked only between messages, because receive_json()
                # blocks until the client sends something.
                if args.timeout > 0 and time.time() - session_start > args.timeout:
                    await websocket.send_json(
                        {
                            "status": "timeout",
                            "message": "Your session has ended",
                            "userId": str(user_id),
                        }
                    )
                    await websocket.close()
                    return
                data = await websocket.receive_json()
                if data.get("status") != "next_frame":
                    # Must be awaited: a bare asyncio.sleep() call only creates a
                    # coroutine object and never actually pauses.
                    await asyncio.sleep(THROTTLE)
                    continue

                raw_params = await websocket.receive_json()
                params = SimpleNamespace(**pipeline.InputParams(**raw_params).dict())
                if info.input_mode == "image":
                    image_data = await websocket.receive_bytes()
                    if len(image_data) == 0:
                        await websocket.send_json({"status": "send_frame"})
                        await asyncio.sleep(THROTTLE)
                        continue
                    params.image = bytes_to_pil(image_data)
                await user_data.update_data(user_id, params)
                await websocket.send_json({"status": "wait"})

        except WebSocketDisconnect:
            raise
        except Exception as e:
            logging.exception(f"Error: {e}")

    @app.get("/api/queue")
    async def get_queue_size():
        """Return the number of currently registered users as ``{"queue_size": n}``."""
        queue_size = user_data.get_user_count()
        return JSONResponse({"queue_size": queue_size})

    @app.get("/api/stream/{user_id}")
    async def stream(user_id: uuid.UUID, request: Request):
        """Stream generated frames for *user_id* as ``multipart/x-mixed-replace``.

        After each produced frame (or each skipped iteration), the user's
        websocket is sent ``{"status": "send_frame"}`` to request more input.

        Raises:
            HTTPException: 404 if *user_id* has no active session.
        """
        # Validate before the response starts. Once StreamingResponse is
        # returned, a failure inside the generator can no longer become a 404.
        if not user_data.check_user(user_id):
            raise HTTPException(status_code=404, detail="User not found")
        websocket = user_data.get_websocket(user_id)
        # Headers do not change during the stream. A missing User-Agent is
        # passed as "" instead of raising KeyError.
        # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
        send_twice = not is_firefox(request.headers.get("user-agent", ""))

        async def generate():
            last_params = SimpleNamespace()
            # Stop once the session has been removed (e.g. after a websocket
            # disconnect or timeout) instead of polling a deleted user.
            while user_data.check_user(user_id):
                last_time = time.time()
                params = await user_data.get_latest_data(user_id)
                if (
                    params is None
                    or not vars(params)
                    or vars(params) == vars(last_params)
                ):
                    await websocket.send_json({"status": "send_frame"})
                    await asyncio.sleep(THROTTLE)
                    continue

                last_params = params
                # NOTE(review): predict() is called synchronously, so it blocks
                # the event loop while it runs. Consider offloading it if it is
                # slow.
                image = pipeline.predict(params)

                if image is None:
                    await websocket.send_json({"status": "send_frame"})
                    await asyncio.sleep(THROTTLE)
                    continue
                frame = pil_to_frame(image)
                yield frame
                if send_twice:
                    yield frame
                await websocket.send_json({"status": "send_frame"})
                if args.debug:
                    print(f"Time taken: {time.time() - last_time}")

        return StreamingResponse(
            generate(),
            media_type="multipart/x-mixed-replace;boundary=frame",
            headers={"Cache-Control": "no-cache"},
        )

    # route to set up the frontend
    @app.get("/api/settings")
    async def settings():
        """Return the pipeline's JSON schemas, queue limit and rendered page content."""
        info = pipeline.Info()
        # page_content is Markdown; render it to HTML for the frontend.
        page_content = (
            markdown2.markdown(info.page_content) if info.page_content else ""
        )
        return JSONResponse(
            {
                "info": pipeline.Info.schema(),
                "input_params": pipeline.InputParams.schema(),
                "max_queue_size": args.max_queue_size,
                "page_content": page_content,
            }
        )

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("public", exist_ok=True)

    app.mount("/", StaticFiles(directory="public", html=True), name="public")