radames commited on
Commit
488b360
1 Parent(s): 78b5416

better app structure

Browse files
Files changed (10) hide show
  1. app.py +0 -17
  2. app_init.py +0 -163
  3. config.py +2 -1
  4. connection_manager.py +116 -0
  5. frontend/src/lib/lcmLive.ts +9 -6
  6. frontend/src/routes/+page.svelte +3 -3
  7. main.py +184 -0
  8. run.py +0 -12
  9. user_queue.py +0 -63
  10. util.py +0 -2
app.py DELETED
@@ -1,17 +0,0 @@
1
- from fastapi import FastAPI
2
-
3
- from config import args
4
- from device import device, torch_dtype
5
- from app_init import init_app
6
- from user_queue import user_data
7
- from util import get_pipeline_class
8
-
9
- print("DEVICE:", device)
10
- print("TORCH_DTYPE:", torch_dtype)
11
- args.pretty_print()
12
-
13
- app = FastAPI()
14
-
15
- pipeline_class = get_pipeline_class(args.pipeline)
16
- pipeline = pipeline_class(args, device, torch_dtype)
17
- init_app(app, user_data, args, pipeline)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_init.py DELETED
@@ -1,163 +0,0 @@
1
- from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
2
- from fastapi.responses import StreamingResponse, JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi.staticfiles import StaticFiles
5
- from fastapi import Request
6
- import markdown2
7
-
8
- import logging
9
- import traceback
10
- from config import Args
11
- from user_queue import UserData
12
- import uuid
13
- import time
14
- from types import SimpleNamespace
15
- from util import pil_to_frame, bytes_to_pil, is_firefox
16
- import asyncio
17
- import os
18
- import time
19
-
20
- THROTTLE = 1.0 / 120
21
-
22
-
23
- def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
24
- app.add_middleware(
25
- CORSMiddleware,
26
- allow_origins=["*"],
27
- allow_credentials=True,
28
- allow_methods=["*"],
29
- allow_headers=["*"],
30
- )
31
-
32
- @app.websocket("/api/ws")
33
- async def websocket_endpoint(websocket: WebSocket):
34
- await websocket.accept()
35
- user_count = user_data.get_user_count()
36
- if args.max_queue_size > 0 and user_count >= args.max_queue_size:
37
- print("Server is full")
38
- await websocket.send_json({"status": "error", "message": "Server is full"})
39
- await websocket.close()
40
- return
41
- try:
42
- user_id = uuid.uuid4()
43
- print(f"New user connected: {user_id}")
44
- await user_data.create_user(user_id, websocket)
45
- await websocket.send_json(
46
- {"status": "connected", "message": "Connected", "userId": str(user_id)}
47
- )
48
- await websocket.send_json({"status": "send_frame"})
49
- await handle_websocket_data(user_id, websocket)
50
- except WebSocketDisconnect as e:
51
- logging.error(f"WebSocket Error: {e}, {user_id}")
52
- traceback.print_exc()
53
- finally:
54
- print(f"User disconnected: {user_id}")
55
- user_data.delete_user(user_id)
56
-
57
- async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
58
- if not user_data.check_user(user_id):
59
- return HTTPException(status_code=404, detail="User not found")
60
- last_time = time.time()
61
- try:
62
- while True:
63
- if args.timeout > 0 and time.time() - last_time > args.timeout:
64
- await websocket.send_json(
65
- {
66
- "status": "timeout",
67
- "message": "Your session has ended",
68
- "userId": str(user_id),
69
- }
70
- )
71
- await websocket.close()
72
- return
73
- data = await websocket.receive_json()
74
- if data["status"] != "next_frame":
75
- asyncio.sleep(THROTTLE)
76
- continue
77
-
78
- params = await websocket.receive_json()
79
- params = pipeline.InputParams(**params)
80
- info = pipeline.Info()
81
- params = SimpleNamespace(**params.dict())
82
- if info.input_mode == "image":
83
- image_data = await websocket.receive_bytes()
84
- if len(image_data) == 0:
85
- await websocket.send_json({"status": "send_frame"})
86
- await asyncio.sleep(THROTTLE)
87
- continue
88
- params.image = bytes_to_pil(image_data)
89
- await user_data.update_data(user_id, params)
90
- await websocket.send_json({"status": "wait"})
91
-
92
- except Exception as e:
93
- logging.error(f"Error: {e}")
94
- traceback.print_exc()
95
-
96
- @app.get("/api/queue")
97
- async def get_queue_size():
98
- queue_size = user_data.get_user_count()
99
- return JSONResponse({"queue_size": queue_size})
100
-
101
- @app.get("/api/stream/{user_id}")
102
- async def stream(user_id: uuid.UUID, request: Request):
103
- try:
104
-
105
- async def generate():
106
- websocket = user_data.get_websocket(user_id)
107
- last_params = SimpleNamespace()
108
- while True:
109
- last_time = time.time()
110
- params = await user_data.get_latest_data(user_id)
111
- if not vars(params) or params.__dict__ == last_params.__dict__:
112
- await websocket.send_json({"status": "send_frame"})
113
- await asyncio.sleep(THROTTLE)
114
- continue
115
-
116
- last_params = params
117
- image = pipeline.predict(params)
118
-
119
- if image is None:
120
- await websocket.send_json({"status": "send_frame"})
121
- await asyncio.sleep(THROTTLE)
122
- continue
123
- frame = pil_to_frame(image)
124
- yield frame
125
- # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
126
- if not is_firefox(request.headers["user-agent"]):
127
- yield frame
128
- await websocket.send_json({"status": "send_frame"})
129
- if args.debug:
130
- print(f"Time taken: {time.time() - last_time}")
131
-
132
- return StreamingResponse(
133
- generate(),
134
- media_type="multipart/x-mixed-replace;boundary=frame",
135
- headers={"Cache-Control": "no-cache"},
136
- )
137
- except Exception as e:
138
- logging.error(f"Streaming Error: {e}, {user_id} ")
139
- traceback.print_exc()
140
- return HTTPException(status_code=404, detail="User not found")
141
-
142
- # route to setup frontend
143
- @app.get("/api/settings")
144
- async def settings():
145
- info_schema = pipeline.Info.schema()
146
- info = pipeline.Info()
147
- if info.page_content:
148
- page_content = markdown2.markdown(info.page_content)
149
-
150
- input_params = pipeline.InputParams.schema()
151
- return JSONResponse(
152
- {
153
- "info": info_schema,
154
- "input_params": input_params,
155
- "max_queue_size": args.max_queue_size,
156
- "page_content": page_content if info.page_content else "",
157
- }
158
- )
159
-
160
- if not os.path.exists("public"):
161
- os.makedirs("public")
162
-
163
- app.mount("/", StaticFiles(directory="public", html=True), name="public")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.py CHANGED
@@ -124,4 +124,5 @@ parser.add_argument(
124
  )
125
  parser.set_defaults(taesd=USE_TAESD)
126
 
127
- args = Args(**vars(parser.parse_args()))
 
 
124
  )
125
  parser.set_defaults(taesd=USE_TAESD)
126
 
127
+ config = Args(**vars(parser.parse_args()))
128
+ config.pretty_print()
connection_manager.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Union
2
+ from uuid import UUID
3
+ import asyncio
4
+ from fastapi import WebSocket
5
+ from starlette.websockets import WebSocketState
6
+ import logging
7
+ from types import SimpleNamespace
8
+
9
+ Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
10
+
11
+
12
+ class ServerFullException(Exception):
13
+ """Exception raised when the server is full."""
14
+
15
+ pass
16
+
17
+
18
+ class ConnectionManager:
19
+ def __init__(self):
20
+ self.active_connections: Connections = {}
21
+
22
+ async def connect(
23
+ self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
24
+ ):
25
+ await websocket.accept()
26
+ user_count = self.get_user_count()
27
+ print(f"User count: {user_count}")
28
+ if max_queue_size > 0 and user_count >= max_queue_size:
29
+ print("Server is full")
30
+ await websocket.send_json({"status": "error", "message": "Server is full"})
31
+ await websocket.close()
32
+ raise ServerFullException("Server is full")
33
+ print(f"New user connected: {user_id}")
34
+ self.active_connections[user_id] = {
35
+ "websocket": websocket,
36
+ "queue": asyncio.Queue(),
37
+ }
38
+ await websocket.send_json(
39
+ {"status": "connected", "message": "Connected"},
40
+ )
41
+ await websocket.send_json({"status": "wait"})
42
+ await websocket.send_json({"status": "send_frame"})
43
+
44
+ def check_user(self, user_id: UUID) -> bool:
45
+ return user_id in self.active_connections
46
+
47
+ async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
48
+ user_session = self.active_connections.get(user_id)
49
+ if user_session:
50
+ queue = user_session["queue"]
51
+ while not queue.empty():
52
+ try:
53
+ queue.get_nowait()
54
+ except asyncio.QueueEmpty:
55
+ continue
56
+ await queue.put(new_data)
57
+
58
+ async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
59
+ user_session = self.active_connections.get(user_id)
60
+ if user_session:
61
+ queue = user_session["queue"]
62
+ try:
63
+ return await queue.get()
64
+ except asyncio.QueueEmpty:
65
+ return None
66
+
67
+ def delete_user(self, user_id: UUID):
68
+ user_session = self.active_connections.pop(user_id, None)
69
+ if user_session:
70
+ queue = user_session["queue"]
71
+ while not queue.empty():
72
+ try:
73
+ queue.get_nowait()
74
+ except asyncio.QueueEmpty:
75
+ continue
76
+
77
+ def get_user_count(self) -> int:
78
+ return len(self.active_connections)
79
+
80
+ def get_websocket(self, user_id: UUID) -> WebSocket:
81
+ user_session = self.active_connections.get(user_id)
82
+ if user_session:
83
+ websocket = user_session["websocket"]
84
+ if websocket.client_state == WebSocketState.CONNECTED:
85
+ return user_session["websocket"]
86
+ return None
87
+
88
+ async def disconnect(self, user_id: UUID):
89
+ websocket = self.get_websocket(user_id)
90
+ if websocket:
91
+ await websocket.close()
92
+ self.delete_user(user_id)
93
+
94
+ async def send_json(self, user_id: UUID, data: Dict):
95
+ try:
96
+ websocket = self.get_websocket(user_id)
97
+ if websocket:
98
+ await websocket.send_json(data)
99
+ except Exception as e:
100
+ logging.error(f"Error: Send json: {e}")
101
+
102
+ async def receive_json(self, user_id: UUID) -> Dict:
103
+ try:
104
+ websocket = self.get_websocket(user_id)
105
+ if websocket:
106
+ return await websocket.receive_json()
107
+ except Exception as e:
108
+ logging.error(f"Error: Receive json: {e}")
109
+
110
+ async def receive_bytes(self, user_id: UUID) -> bytes:
111
+ try:
112
+ websocket = self.get_websocket(user_id)
113
+ if websocket:
114
+ return await websocket.receive_bytes()
115
+ except Exception as e:
116
+ logging.error(f"Error: Receive bytes: {e}")
frontend/src/lib/lcmLive.ts CHANGED
@@ -6,6 +6,7 @@ export enum LCMLiveStatus {
6
  DISCONNECTED = "disconnected",
7
  WAIT = "wait",
8
  SEND_FRAME = "send_frame",
 
9
  }
10
 
11
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
@@ -19,8 +20,9 @@ export const lcmLiveActions = {
19
  return new Promise((resolve, reject) => {
20
 
21
  try {
 
22
  const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
23
- }:${window.location.host}/api/ws`;
24
 
25
  websocket = new WebSocket(websocketURL);
26
  websocket.onopen = () => {
@@ -37,9 +39,9 @@ export const lcmLiveActions = {
37
  const data = JSON.parse(event.data);
38
  switch (data.status) {
39
  case "connected":
40
- const userId = data.userId;
41
  lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
42
  streamId.set(userId);
 
43
  break;
44
  case "send_frame":
45
  lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
@@ -54,14 +56,16 @@ export const lcmLiveActions = {
54
  break;
55
  case "timeout":
56
  console.log("timeout");
57
- lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
58
  streamId.set(null);
59
- resolve({ status: "timeout" });
 
60
  case "error":
61
  console.log(data.message);
62
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
63
  streamId.set(null);
64
  reject(new Error(data.message));
 
65
  }
66
  };
67
 
@@ -85,12 +89,11 @@ export const lcmLiveActions = {
85
  }
86
  },
87
  async stop() {
88
-
89
  if (websocket) {
90
  websocket.close();
91
  }
92
  websocket = null;
93
- lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
94
  streamId.set(null);
95
  },
96
  };
 
6
  DISCONNECTED = "disconnected",
7
  WAIT = "wait",
8
  SEND_FRAME = "send_frame",
9
+ TIMEOUT = "timeout",
10
  }
11
 
12
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
 
20
  return new Promise((resolve, reject) => {
21
 
22
  try {
23
+ const userId = crypto.randomUUID();
24
  const websocketURL = `${window.location.protocol === "https:" ? "wss" : "ws"
25
+ }:${window.location.host}/api/ws/${userId}`;
26
 
27
  websocket = new WebSocket(websocketURL);
28
  websocket.onopen = () => {
 
39
  const data = JSON.parse(event.data);
40
  switch (data.status) {
41
  case "connected":
 
42
  lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
43
  streamId.set(userId);
44
+ resolve({ status: "connected", userId });
45
  break;
46
  case "send_frame":
47
  lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
 
56
  break;
57
  case "timeout":
58
  console.log("timeout");
59
+ lcmLiveStatus.set(LCMLiveStatus.TIMEOUT);
60
  streamId.set(null);
61
+ reject(new Error("timeout"));
62
+ break;
63
  case "error":
64
  console.log(data.message);
65
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
66
  streamId.set(null);
67
  reject(new Error(data.message));
68
+ break;
69
  }
70
  };
71
 
 
89
  }
90
  },
91
  async stop() {
92
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
93
  if (websocket) {
94
  websocket.close();
95
  }
96
  websocket = null;
 
97
  streamId.set(null);
98
  },
99
  };
frontend/src/routes/+page.svelte CHANGED
@@ -20,7 +20,6 @@
20
  let currentQueueSize: number = 0;
21
  let queueCheckerRunning: boolean = false;
22
  let warningMessage: string = '';
23
-
24
  onMount(() => {
25
  getSettings();
26
  });
@@ -59,7 +58,9 @@
59
  }
60
 
61
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
62
-
 
 
63
  let disabled = false;
64
  async function toggleLcmLive() {
65
  try {
@@ -70,7 +71,6 @@
70
  }
71
  disabled = true;
72
  await lcmLiveActions.start(getSreamdata);
73
- warningMessage = 'Timeout, please try again.';
74
  disabled = false;
75
  toggleQueueChecker(false);
76
  } else {
 
20
  let currentQueueSize: number = 0;
21
  let queueCheckerRunning: boolean = false;
22
  let warningMessage: string = '';
 
23
  onMount(() => {
24
  getSettings();
25
  });
 
58
  }
59
 
60
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
61
+ $: if ($lcmLiveStatus === LCMLiveStatus.TIMEOUT) {
62
+ warningMessage = 'Session timed out. Please try again.';
63
+ }
64
  let disabled = false;
65
  async function toggleLcmLive() {
66
  try {
 
71
  }
72
  disabled = true;
73
  await lcmLiveActions.start(getSreamdata);
 
74
  disabled = false;
75
  toggleQueueChecker(false);
76
  } else {
main.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
2
+ from fastapi.responses import StreamingResponse, JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi import Request
6
+ import markdown2
7
+
8
+ import logging
9
+ from config import config, Args
10
+ from connection_manager import ConnectionManager
11
+ import uuid
12
+ import time
13
+ from types import SimpleNamespace
14
+ from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class
15
+ from device import device, torch_dtype
16
+ import asyncio
17
+ import os
18
+ import time
19
+ import torch
20
+
21
+
22
+ THROTTLE = 1.0 / 120
23
+
24
+
25
+ class App:
26
+ def __init__(self, config: Args, pipeline):
27
+ self.args = config
28
+ self.pipeline = pipeline
29
+ self.app = FastAPI()
30
+ self.conn_manager = ConnectionManager()
31
+ self.init_app()
32
+
33
+ def init_app(self):
34
+ self.app.add_middleware(
35
+ CORSMiddleware,
36
+ allow_origins=["*"],
37
+ allow_credentials=True,
38
+ allow_methods=["*"],
39
+ allow_headers=["*"],
40
+ )
41
+
42
+ @self.app.websocket("/api/ws/{user_id}")
43
+ async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
44
+ try:
45
+ await self.conn_manager.connect(
46
+ user_id, websocket, self.args.max_queue_size
47
+ )
48
+ await handle_websocket_data(user_id)
49
+ except ServerFullException as e:
50
+ logging.error(f"Server Full: {e}")
51
+ finally:
52
+ await self.conn_manager.disconnect(user_id)
53
+ logging.info(f"User disconnected: {user_id}")
54
+
55
+ async def handle_websocket_data(user_id: uuid.UUID):
56
+ if not self.conn_manager.check_user(user_id):
57
+ return HTTPException(status_code=404, detail="User not found")
58
+ last_time = time.time()
59
+ try:
60
+ while True:
61
+ if (
62
+ self.args.timeout > 0
63
+ and time.time() - last_time > self.args.timeout
64
+ ):
65
+ await self.conn_manager.send_json(
66
+ user_id,
67
+ {
68
+ "status": "timeout",
69
+ "message": "Your session has ended",
70
+ },
71
+ )
72
+ await self.conn_manager.disconnect(user_id)
73
+ return
74
+ data = await self.conn_manager.receive_json(user_id)
75
+ if data["status"] != "next_frame":
76
+ asyncio.sleep(THROTTLE)
77
+ continue
78
+
79
+ params = await self.conn_manager.receive_json(user_id)
80
+ params = pipeline.InputParams(**params)
81
+ info = pipeline.Info()
82
+ params = SimpleNamespace(**params.dict())
83
+ if info.input_mode == "image":
84
+ image_data = await self.conn_manager.receive_bytes(user_id)
85
+ if len(image_data) == 0:
86
+ await self.conn_manager.send_json(
87
+ user_id, {"status": "send_frame"}
88
+ )
89
+ await asyncio.sleep(THROTTLE)
90
+ continue
91
+ params.image = bytes_to_pil(image_data)
92
+ await self.conn_manager.update_data(user_id, params)
93
+ await self.conn_manager.send_json(user_id, {"status": "wait"})
94
+
95
+ except Exception as e:
96
+ logging.error(f"Websocket Error: {e}, {user_id} ")
97
+ await self.conn_manager.disconnect(user_id)
98
+
99
+ @self.app.get("/api/queue")
100
+ async def get_queue_size():
101
+ queue_size = self.conn_manager.get_user_count()
102
+ return JSONResponse({"queue_size": queue_size})
103
+
104
+ @self.app.get("/api/stream/{user_id}")
105
+ async def stream(user_id: uuid.UUID, request: Request):
106
+ try:
107
+
108
+ async def generate():
109
+ last_params = SimpleNamespace()
110
+ while True:
111
+ last_time = time.time()
112
+ params = await self.conn_manager.get_latest_data(user_id)
113
+ if not vars(params) or params.__dict__ == last_params.__dict__:
114
+ await self.conn_manager.send_json(
115
+ user_id, {"status": "send_frame"}
116
+ )
117
+ continue
118
+
119
+ last_params = params
120
+ image = pipeline.predict(params)
121
+ if image is None:
122
+ await self.conn_manager.send_json(
123
+ user_id, {"status": "send_frame"}
124
+ )
125
+ continue
126
+ frame = pil_to_frame(image)
127
+ yield frame
128
+ # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
129
+ if not is_firefox(request.headers["user-agent"]):
130
+ yield frame
131
+ await self.conn_manager.send_json(
132
+ user_id, {"status": "send_frame"}
133
+ )
134
+ if self.args.debug:
135
+ print(f"Time taken: {time.time() - last_time}")
136
+
137
+ return StreamingResponse(
138
+ generate(),
139
+ media_type="multipart/x-mixed-replace;boundary=frame",
140
+ headers={"Cache-Control": "no-cache"},
141
+ )
142
+ except Exception as e:
143
+ logging.error(f"Streaming Error: {e}, {user_id} ")
144
+ return HTTPException(status_code=404, detail="User not found")
145
+
146
+ # route to setup frontend
147
+ @self.app.get("/api/settings")
148
+ async def settings():
149
+ info_schema = pipeline.Info.schema()
150
+ info = pipeline.Info()
151
+ if info.page_content:
152
+ page_content = markdown2.markdown(info.page_content)
153
+
154
+ input_params = pipeline.InputParams.schema()
155
+ return JSONResponse(
156
+ {
157
+ "info": info_schema,
158
+ "input_params": input_params,
159
+ "max_queue_size": self.args.max_queue_size,
160
+ "page_content": page_content if info.page_content else "",
161
+ }
162
+ )
163
+
164
+ if not os.path.exists("public"):
165
+ os.makedirs("public")
166
+
167
+ self.app.mount("/", StaticFiles(directory="public", html=True), name="public")
168
+
169
+
170
+ pipeline_class = get_pipeline_class(config.pipeline)
171
+ pipeline = pipeline_class(config, device, torch_dtype)
172
+ app = App(config, pipeline).app
173
+
174
+ if __name__ == "__main__":
175
+ import uvicorn
176
+
177
+ uvicorn.run(
178
+ "main:app",
179
+ host=config.host,
180
+ port=config.port,
181
+ reload=config.reload,
182
+ ssl_certfile=config.ssl_certfile,
183
+ ssl_keyfile=config.ssl_keyfile,
184
+ )
run.py DELETED
@@ -1,12 +0,0 @@
1
- if __name__ == "__main__":
2
- import uvicorn
3
- from config import args
4
-
5
- uvicorn.run(
6
- "app:app",
7
- host=args.host,
8
- port=args.port,
9
- reload=args.reload,
10
- ssl_certfile=args.ssl_certfile,
11
- ssl_keyfile=args.ssl_keyfile,
12
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
user_queue.py DELETED
@@ -1,63 +0,0 @@
1
- from typing import Dict
2
- from uuid import UUID
3
- import asyncio
4
- from fastapi import WebSocket
5
- from types import SimpleNamespace
6
- from typing import Dict
7
- from typing import Union
8
-
9
- UserDataContent = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
10
-
11
-
12
- class UserData:
13
- def __init__(self):
14
- self.data_content: Dict[UUID, UserDataContent] = {}
15
-
16
- async def create_user(self, user_id: UUID, websocket: WebSocket):
17
- self.data_content[user_id] = {
18
- "websocket": websocket,
19
- "queue": asyncio.Queue(),
20
- }
21
- await asyncio.sleep(1)
22
-
23
- def check_user(self, user_id: UUID) -> bool:
24
- return user_id in self.data_content
25
-
26
- async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
27
- user_session = self.data_content[user_id]
28
- queue = user_session["queue"]
29
- while not queue.empty():
30
- try:
31
- queue.get_nowait()
32
- except asyncio.QueueEmpty:
33
- continue
34
- await queue.put(new_data)
35
-
36
- async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
37
- user_session = self.data_content[user_id]
38
- queue = user_session["queue"]
39
-
40
- try:
41
- return await queue.get()
42
- except asyncio.QueueEmpty:
43
- return None
44
-
45
- def delete_user(self, user_id: UUID):
46
- user_session = self.data_content[user_id]
47
- queue = user_session["queue"]
48
- while not queue.empty():
49
- try:
50
- queue.get_nowait()
51
- except asyncio.QueueEmpty:
52
- continue
53
- if user_id in self.data_content:
54
- del self.data_content[user_id]
55
-
56
- def get_user_count(self) -> int:
57
- return len(self.data_content)
58
-
59
- def get_websocket(self, user_id: UUID) -> WebSocket:
60
- return self.data_content[user_id]["websocket"]
61
-
62
-
63
- user_data = UserData()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
util.py CHANGED
@@ -1,7 +1,5 @@
1
  from importlib import import_module
2
  from types import ModuleType
3
- from typing import Dict, Any
4
- from pydantic import BaseModel as PydanticBaseModel, Field
5
  from PIL import Image
6
  import io
7
 
 
1
  from importlib import import_module
2
  from types import ModuleType
 
 
3
  from PIL import Image
4
  import io
5