radames HF staff commited on
Commit
87914c1
1 Parent(s): ef88349

better Websocket handling

Browse files
Files changed (3) hide show
  1. connection_manager.py +0 -5
  2. main.py +22 -31
  3. pipelines/txt2img.py +1 -1
connection_manager.py CHANGED
@@ -48,11 +48,6 @@ class ConnectionManager:
48
  user_session = self.active_connections.get(user_id)
49
  if user_session:
50
  queue = user_session["queue"]
51
- while not queue.empty():
52
- try:
53
- queue.get_nowait()
54
- except asyncio.QueueEmpty:
55
- continue
56
  await queue.put(new_data)
57
 
58
  async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
 
48
  user_session = self.active_connections.get(user_id)
49
  if user_session:
50
  queue = user_session["queue"]
 
 
 
 
 
51
  await queue.put(new_data)
52
 
53
  async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
main.py CHANGED
@@ -7,7 +7,7 @@ import markdown2
7
 
8
  import logging
9
  from config import config, Args
10
- from connection_manager import ConnectionManager
11
  import uuid
12
  import time
13
  from types import SimpleNamespace
@@ -72,25 +72,22 @@ class App:
72
  await self.conn_manager.disconnect(user_id)
73
  return
74
  data = await self.conn_manager.receive_json(user_id)
75
- if data["status"] != "next_frame":
76
- asyncio.sleep(THROTTLE)
77
- continue
78
-
79
- params = await self.conn_manager.receive_json(user_id)
80
- params = pipeline.InputParams(**params)
81
- info = pipeline.Info()
82
- params = SimpleNamespace(**params.dict())
83
- if info.input_mode == "image":
84
- image_data = await self.conn_manager.receive_bytes(user_id)
85
- if len(image_data) == 0:
86
- await self.conn_manager.send_json(
87
- user_id, {"status": "send_frame"}
88
- )
89
- await asyncio.sleep(THROTTLE)
90
- continue
91
- params.image = bytes_to_pil(image_data)
92
- await self.conn_manager.update_data(user_id, params)
93
- await self.conn_manager.send_json(user_id, {"status": "wait"})
94
 
95
  except Exception as e:
96
  logging.error(f"Websocket Error: {e}, {user_id} ")
@@ -109,28 +106,22 @@ class App:
109
  last_params = SimpleNamespace()
110
  while True:
111
  last_time = time.time()
 
 
 
112
  params = await self.conn_manager.get_latest_data(user_id)
113
- if not vars(params) or params.__dict__ == last_params.__dict__:
114
- await self.conn_manager.send_json(
115
- user_id, {"status": "send_frame"}
116
- )
117
  continue
118
-
119
  last_params = params
120
  image = pipeline.predict(params)
121
  if image is None:
122
- await self.conn_manager.send_json(
123
- user_id, {"status": "send_frame"}
124
- )
125
  continue
126
  frame = pil_to_frame(image)
127
  yield frame
128
  # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
129
  if not is_firefox(request.headers["user-agent"]):
130
  yield frame
131
- await self.conn_manager.send_json(
132
- user_id, {"status": "send_frame"}
133
- )
134
  if self.args.debug:
135
  print(f"Time taken: {time.time() - last_time}")
136
 
 
7
 
8
  import logging
9
  from config import config, Args
10
+ from connection_manager import ConnectionManager, ServerFullException
11
  import uuid
12
  import time
13
  from types import SimpleNamespace
 
72
  await self.conn_manager.disconnect(user_id)
73
  return
74
  data = await self.conn_manager.receive_json(user_id)
75
+ if data["status"] == "next_frame":
76
+ info = pipeline.Info()
77
+ params = await self.conn_manager.receive_json(user_id)
78
+ params = pipeline.InputParams(**params)
79
+ params = SimpleNamespace(**params.dict())
80
+ if info.input_mode == "image":
81
+ image_data = await self.conn_manager.receive_bytes(user_id)
82
+ if len(image_data) == 0:
83
+ await self.conn_manager.send_json(
84
+ user_id, {"status": "send_frame"}
85
+ )
86
+ continue
87
+ params.image = bytes_to_pil(image_data)
88
+
89
+ await self.conn_manager.update_data(user_id, params)
90
+ await self.conn_manager.send_json(user_id, {"status": "wait"})
 
 
 
91
 
92
  except Exception as e:
93
  logging.error(f"Websocket Error: {e}, {user_id} ")
 
106
  last_params = SimpleNamespace()
107
  while True:
108
  last_time = time.time()
109
+ await self.conn_manager.send_json(
110
+ user_id, {"status": "send_frame"}
111
+ )
112
  params = await self.conn_manager.get_latest_data(user_id)
113
+ if params.__dict__ == last_params.__dict__ or params is None:
114
+ await asyncio.sleep(THROTTLE)
 
 
115
  continue
 
116
  last_params = params
117
  image = pipeline.predict(params)
118
  if image is None:
 
 
 
119
  continue
120
  frame = pil_to_frame(image)
121
  yield frame
122
  # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
123
  if not is_firefox(request.headers["user-agent"]):
124
  yield frame
 
 
 
125
  if self.args.debug:
126
  print(f"Time taken: {time.time() - last_time}")
127
 
pipelines/txt2img.py CHANGED
@@ -7,10 +7,10 @@ try:
7
  except:
8
  pass
9
 
10
- import psutil
11
  from config import Args
12
  from pydantic import BaseModel, Field
13
  from PIL import Image
 
14
 
15
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
16
  taesd_model = "madebyollin/taesd"
 
7
  except:
8
  pass
9
 
 
10
  from config import Args
11
  from pydantic import BaseModel, Field
12
  from PIL import Image
13
+ from typing import List
14
 
15
  base_model = "SimianLuo/LCM_Dreamshaper_v7"
16
  taesd_model = "madebyollin/taesd"