Spaces:
Running
on
A100
Running
on
A100
extra msgs
Browse files- app.py +38 -32
- public/index.html +7 -5
app.py
CHANGED
@@ -55,7 +55,7 @@ def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232)
|
|
55 |
strength=strength,
|
56 |
num_inference_steps=num_inference_steps,
|
57 |
guidance_scale=guidance_scale,
|
58 |
-
lcm_origin_steps=
|
59 |
output_type="pil",
|
60 |
)
|
61 |
nsfw_content_detected = (
|
@@ -106,9 +106,12 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
106 |
"queue": asyncio.Queue(),
|
107 |
"params": params,
|
108 |
}
|
|
|
|
|
|
|
109 |
await handle_websocket_data(websocket, uid)
|
110 |
except WebSocketDisconnect as e:
|
111 |
-
logging.error(f"Error: {e}")
|
112 |
traceback.print_exc()
|
113 |
finally:
|
114 |
print(f"User disconnected: {uid}")
|
@@ -131,36 +134,39 @@ async def get_queue_size():
|
|
131 |
@app.get("/stream/{user_id}")
|
132 |
async def stream(user_id: uuid.UUID):
|
133 |
uid = str(user_id)
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
generate(), media_type="multipart/x-mixed-replace;boundary=frame"
|
163 |
-
)
|
164 |
|
165 |
|
166 |
async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
@@ -182,7 +188,7 @@ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
|
182 |
continue
|
183 |
await queue.put(pil_image)
|
184 |
if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
|
185 |
-
await websocket.send_json(
|
186 |
{
|
187 |
"status": "timeout",
|
188 |
"message": "Your session has ended",
|
|
|
55 |
strength=strength,
|
56 |
num_inference_steps=num_inference_steps,
|
57 |
guidance_scale=guidance_scale,
|
58 |
+
lcm_origin_steps=30,
|
59 |
output_type="pil",
|
60 |
)
|
61 |
nsfw_content_detected = (
|
|
|
106 |
"queue": asyncio.Queue(),
|
107 |
"params": params,
|
108 |
}
|
109 |
+
await websocket.send_json(
|
110 |
+
{"status": "start", "message": "Start Streaming", "userId": uid}
|
111 |
+
)
|
112 |
await handle_websocket_data(websocket, uid)
|
113 |
except WebSocketDisconnect as e:
|
114 |
+
logging.error(f"WebSocket Error: {e}, {uid}")
|
115 |
traceback.print_exc()
|
116 |
finally:
|
117 |
print(f"User disconnected: {uid}")
|
|
|
134 |
@app.get("/stream/{user_id}")
|
135 |
async def stream(user_id: uuid.UUID):
|
136 |
uid = str(user_id)
|
137 |
+
try:
|
138 |
+
user_queue = user_queue_map[uid]
|
139 |
+
queue = user_queue["queue"]
|
140 |
+
params = user_queue["params"]
|
141 |
+
seed = params.seed
|
142 |
+
prompt = params.prompt
|
143 |
+
strength = params.strength
|
144 |
+
guidance_scale = params.guidance_scale
|
145 |
+
|
146 |
+
async def generate():
|
147 |
+
while True:
|
148 |
+
input_image = await queue.get()
|
149 |
+
if input_image is None:
|
150 |
+
continue
|
151 |
|
152 |
+
image = predict(input_image, prompt, guidance_scale, strength, seed)
|
153 |
+
if image is None:
|
154 |
+
continue
|
155 |
+
frame_data = io.BytesIO()
|
156 |
+
image.save(frame_data, format="JPEG")
|
157 |
+
frame_data = frame_data.getvalue()
|
158 |
+
if frame_data is not None and len(frame_data) > 0:
|
159 |
+
yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
|
160 |
+
|
161 |
+
await asyncio.sleep(1.0 / 120.0)
|
162 |
+
|
163 |
+
return StreamingResponse(
|
164 |
+
generate(), media_type="multipart/x-mixed-replace;boundary=frame"
|
165 |
+
)
|
166 |
+
except Exception as e:
|
167 |
+
logging.error(f"Streaming Error: {e}, {user_queue_map}")
|
168 |
+
traceback.print_exc()
|
169 |
+
return HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
170 |
|
171 |
|
172 |
async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
|
|
188 |
continue
|
189 |
await queue.put(pil_image)
|
190 |
if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
|
191 |
+
await websocket.send_json(
|
192 |
{
|
193 |
"status": "timeout",
|
194 |
"message": "Your session has ended",
|
public/index.html
CHANGED
@@ -47,9 +47,10 @@
|
|
47 |
switch (data.status) {
|
48 |
case "success":
|
49 |
socket.send(JSON.stringify(params));
|
|
|
|
|
50 |
const userId = data.userId;
|
51 |
-
|
52 |
-
initVideoStream();
|
53 |
break;
|
54 |
case "timeout":
|
55 |
stop();
|
@@ -72,7 +73,8 @@
|
|
72 |
websocket.send(blob);
|
73 |
}
|
74 |
|
75 |
-
function initVideoStream() {
|
|
|
76 |
const constraints = {
|
77 |
audio: false,
|
78 |
video: { width: 512, height: 512 },
|
@@ -118,7 +120,7 @@
|
|
118 |
}
|
119 |
setTimeout(() => {
|
120 |
errorEl.hidden = true;
|
121 |
-
},
|
122 |
}
|
123 |
|
124 |
|
@@ -154,7 +156,7 @@
|
|
154 |
.catch((err) => {
|
155 |
console.log(err);
|
156 |
})
|
157 |
-
,
|
158 |
</script>
|
159 |
</head>
|
160 |
|
|
|
47 |
switch (data.status) {
|
48 |
case "success":
|
49 |
socket.send(JSON.stringify(params));
|
50 |
+
break;
|
51 |
+
case "start":
|
52 |
const userId = data.userId;
|
53 |
+
initVideoStream(userId);
|
|
|
54 |
break;
|
55 |
case "timeout":
|
56 |
stop();
|
|
|
73 |
websocket.send(blob);
|
74 |
}
|
75 |
|
76 |
+
function initVideoStream(userId) {
|
77 |
+
liveImage.src = `/stream/${userId}`;
|
78 |
const constraints = {
|
79 |
audio: false,
|
80 |
video: { width: 512, height: 512 },
|
|
|
120 |
}
|
121 |
setTimeout(() => {
|
122 |
errorEl.hidden = true;
|
123 |
+
}, 2000);
|
124 |
}
|
125 |
|
126 |
|
|
|
156 |
.catch((err) => {
|
157 |
console.log(err);
|
158 |
})
|
159 |
+
, 5000);
|
160 |
</script>
|
161 |
</head>
|
162 |
|