radames commited on
Commit
1383dae
1 Parent(s): 8c2b71b

extra msgs

Browse files
Files changed (2) hide show
  1. app.py +38 -32
  2. public/index.html +7 -5
app.py CHANGED
@@ -55,7 +55,7 @@ def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232)
55
  strength=strength,
56
  num_inference_steps=num_inference_steps,
57
  guidance_scale=guidance_scale,
58
- lcm_origin_steps=20,
59
  output_type="pil",
60
  )
61
  nsfw_content_detected = (
@@ -106,9 +106,12 @@ async def websocket_endpoint(websocket: WebSocket):
106
  "queue": asyncio.Queue(),
107
  "params": params,
108
  }
 
 
 
109
  await handle_websocket_data(websocket, uid)
110
  except WebSocketDisconnect as e:
111
- logging.error(f"Error: {e}")
112
  traceback.print_exc()
113
  finally:
114
  print(f"User disconnected: {uid}")
@@ -131,36 +134,39 @@ async def get_queue_size():
131
  @app.get("/stream/{user_id}")
132
  async def stream(user_id: uuid.UUID):
133
  uid = str(user_id)
134
- user_queue = user_queue_map[uid]
135
- queue = user_queue["queue"]
136
- params = user_queue["params"]
137
- seed = params.seed
138
- prompt = params.prompt
139
- strength = params.strength
140
- guidance_scale = params.guidance_scale
141
- if not queue:
142
- return HTTPException(status_code=404, detail="User not found")
 
 
 
 
 
143
 
144
- async def generate():
145
- while True:
146
- input_image = await queue.get()
147
- if input_image is None:
148
- continue
149
-
150
- image = predict(input_image, prompt, guidance_scale, strength, seed)
151
- if image is None:
152
- continue
153
- frame_data = io.BytesIO()
154
- image.save(frame_data, format="JPEG")
155
- frame_data = frame_data.getvalue()
156
- if frame_data is not None and len(frame_data) > 0:
157
- yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
158
-
159
- await asyncio.sleep(1.0 / 120.0)
160
-
161
- return StreamingResponse(
162
- generate(), media_type="multipart/x-mixed-replace;boundary=frame"
163
- )
164
 
165
 
166
  async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
@@ -182,7 +188,7 @@ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
182
  continue
183
  await queue.put(pil_image)
184
  if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
185
- await websocket.send_json(
186
  {
187
  "status": "timeout",
188
  "message": "Your session has ended",
 
55
  strength=strength,
56
  num_inference_steps=num_inference_steps,
57
  guidance_scale=guidance_scale,
58
+ lcm_origin_steps=30,
59
  output_type="pil",
60
  )
61
  nsfw_content_detected = (
 
106
  "queue": asyncio.Queue(),
107
  "params": params,
108
  }
109
+ await websocket.send_json(
110
+ {"status": "start", "message": "Start Streaming", "userId": uid}
111
+ )
112
  await handle_websocket_data(websocket, uid)
113
  except WebSocketDisconnect as e:
114
+ logging.error(f"WebSocket Error: {e}, {uid}")
115
  traceback.print_exc()
116
  finally:
117
  print(f"User disconnected: {uid}")
 
134
  @app.get("/stream/{user_id}")
135
  async def stream(user_id: uuid.UUID):
136
  uid = str(user_id)
137
+ try:
138
+ user_queue = user_queue_map[uid]
139
+ queue = user_queue["queue"]
140
+ params = user_queue["params"]
141
+ seed = params.seed
142
+ prompt = params.prompt
143
+ strength = params.strength
144
+ guidance_scale = params.guidance_scale
145
+
146
+ async def generate():
147
+ while True:
148
+ input_image = await queue.get()
149
+ if input_image is None:
150
+ continue
151
 
152
+ image = predict(input_image, prompt, guidance_scale, strength, seed)
153
+ if image is None:
154
+ continue
155
+ frame_data = io.BytesIO()
156
+ image.save(frame_data, format="JPEG")
157
+ frame_data = frame_data.getvalue()
158
+ if frame_data is not None and len(frame_data) > 0:
159
+ yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
160
+
161
+ await asyncio.sleep(1.0 / 120.0)
162
+
163
+ return StreamingResponse(
164
+ generate(), media_type="multipart/x-mixed-replace;boundary=frame"
165
+ )
166
+ except Exception as e:
167
+ logging.error(f"Streaming Error: {e}, {user_queue_map}")
168
+ traceback.print_exc()
169
+ return HTTPException(status_code=404, detail="User not found")
 
 
170
 
171
 
172
  async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
 
188
  continue
189
  await queue.put(pil_image)
190
  if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
191
+ await websocket.send_json(
192
  {
193
  "status": "timeout",
194
  "message": "Your session has ended",
public/index.html CHANGED
@@ -47,9 +47,10 @@
47
  switch (data.status) {
48
  case "success":
49
  socket.send(JSON.stringify(params));
 
 
50
  const userId = data.userId;
51
- liveImage.src = `/stream/${userId}`;
52
- initVideoStream();
53
  break;
54
  case "timeout":
55
  stop();
@@ -72,7 +73,8 @@
72
  websocket.send(blob);
73
  }
74
 
75
- function initVideoStream() {
 
76
  const constraints = {
77
  audio: false,
78
  video: { width: 512, height: 512 },
@@ -118,7 +120,7 @@
118
  }
119
  setTimeout(() => {
120
  errorEl.hidden = true;
121
- }, 5000);
122
  }
123
 
124
 
@@ -154,7 +156,7 @@
154
  .catch((err) => {
155
  console.log(err);
156
  })
157
- , 1000);
158
  </script>
159
  </head>
160
 
 
47
  switch (data.status) {
48
  case "success":
49
  socket.send(JSON.stringify(params));
50
+ break;
51
+ case "start":
52
  const userId = data.userId;
53
+ initVideoStream(userId);
 
54
  break;
55
  case "timeout":
56
  stop();
 
73
  websocket.send(blob);
74
  }
75
 
76
+ function initVideoStream(userId) {
77
+ liveImage.src = `/stream/${userId}`;
78
  const constraints = {
79
  audio: false,
80
  video: { width: 512, height: 512 },
 
120
  }
121
  setTimeout(() => {
122
  errorEl.hidden = true;
123
+ }, 2000);
124
  }
125
 
126
 
 
156
  .catch((err) => {
157
  console.log(err);
158
  })
159
+ , 5000);
160
  </script>
161
  </head>
162