pseudotheos
commited on
Commit
·
dfc480a
1
Parent(s):
d362f04
Update app.py
Browse files
app.py
CHANGED
@@ -26,6 +26,7 @@ from transformers import AutoFeatureExtractor, CLIPFeatureExtractor
|
|
26 |
import random
|
27 |
import time
|
28 |
import tempfile
|
|
|
29 |
|
30 |
logger = logging.getLogger(__name__)
|
31 |
# Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
@@ -43,16 +44,32 @@ logger.addHandler(stream_handler)
|
|
43 |
class ImageGenerationQueue:
|
44 |
def __init__(self):
|
45 |
self.queue = asyncio.Queue()
|
|
|
|
|
46 |
|
47 |
-
|
|
|
|
|
|
|
48 |
await self.queue.put(task)
|
|
|
|
|
|
|
49 |
|
50 |
async def process_queue(self):
|
51 |
while True:
|
52 |
task = await self.queue.get()
|
53 |
await task()
|
|
|
|
|
|
|
54 |
self.queue.task_done()
|
55 |
|
|
|
|
|
|
|
|
|
|
|
56 |
app = FastAPI()
|
57 |
queue_manager = ImageGenerationQueue()
|
58 |
|
@@ -230,13 +247,13 @@ def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, con
|
|
230 |
output_image_binary = output_image_io.read()
|
231 |
|
232 |
# Return the generated image binary data
|
233 |
-
logger.debug("Output Values: generated_image=<binary data>")
|
234 |
return output_image_binary
|
235 |
|
236 |
except Exception as e:
|
237 |
# Handle exceptions and return an error message if something goes wrong
|
238 |
return str(e)
|
239 |
|
|
|
240 |
app.add_middleware(
|
241 |
CORSMiddleware,
|
242 |
allow_origins=["*"], # You can replace ["*"] with specific origins if needed
|
@@ -256,7 +273,7 @@ async def generate_image(
|
|
256 |
sampler_type: str = Form(...),
|
257 |
image: UploadFile = File(...)
|
258 |
):
|
259 |
-
|
260 |
try:
|
261 |
# Save the uploaded image to a temporary file
|
262 |
temp_image_path = f"/tmp/{int(time.time())}_{image.filename}"
|
@@ -284,9 +301,11 @@ async def generate_image(
|
|
284 |
return "Failed to generate image"
|
285 |
|
286 |
try:
|
287 |
-
|
288 |
position_in_queue = queue_manager.queue.qsize()
|
289 |
-
|
|
|
|
|
290 |
return {"position_in_queue": position_in_queue, "total_queue_size": total_queue_size}
|
291 |
|
292 |
except Exception as e:
|
|
|
26 |
import random
|
27 |
import time
|
28 |
import tempfile
|
29 |
+
import threading
|
30 |
|
31 |
logger = logging.getLogger(__name__)
|
32 |
# Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
|
|
44 |
class ImageGenerationQueue:
|
45 |
def __init__(self):
|
46 |
self.queue = asyncio.Queue()
|
47 |
+
self.queue_size = 0
|
48 |
+
self.queue_lock = threading.Lock()
|
49 |
|
50 |
+
def add_task(self, task):
|
51 |
+
asyncio.run_coroutine_threadsafe(self._add_task(task), loop=asyncio.get_event_loop())
|
52 |
+
|
53 |
+
async def _add_task(self, task):
|
54 |
await self.queue.put(task)
|
55 |
+
# Update queue size in a thread-safe manner
|
56 |
+
with self.queue_lock:
|
57 |
+
self.queue_size = self.queue.qsize()
|
58 |
|
59 |
async def process_queue(self):
|
60 |
while True:
|
61 |
task = await self.queue.get()
|
62 |
await task()
|
63 |
+
# Update queue size in a thread-safe manner
|
64 |
+
with self.queue_lock:
|
65 |
+
self.queue_size = self.queue.qsize()
|
66 |
self.queue.task_done()
|
67 |
|
68 |
+
async def get_total_queue_size(self):
|
69 |
+
# Return the queue size in a thread-safe manner
|
70 |
+
with self.queue_lock:
|
71 |
+
return self.queue_size
|
72 |
+
|
73 |
app = FastAPI()
|
74 |
queue_manager = ImageGenerationQueue()
|
75 |
|
|
|
247 |
output_image_binary = output_image_io.read()
|
248 |
|
249 |
# Return the generated image binary data
|
|
|
250 |
return output_image_binary
|
251 |
|
252 |
except Exception as e:
|
253 |
# Handle exceptions and return an error message if something goes wrong
|
254 |
return str(e)
|
255 |
|
256 |
+
|
257 |
app.add_middleware(
|
258 |
CORSMiddleware,
|
259 |
allow_origins=["*"], # You can replace ["*"] with specific origins if needed
|
|
|
273 |
sampler_type: str = Form(...),
|
274 |
image: UploadFile = File(...)
|
275 |
):
|
276 |
+
def generate_image_task():
|
277 |
try:
|
278 |
# Save the uploaded image to a temporary file
|
279 |
temp_image_path = f"/tmp/{int(time.time())}_{image.filename}"
|
|
|
301 |
return "Failed to generate image"
|
302 |
|
303 |
try:
|
304 |
+
queue_manager.add_task(generate_image_task)
|
305 |
position_in_queue = queue_manager.queue.qsize()
|
306 |
+
# Total queue size is still async
|
307 |
+
total_queue_size = await get_total_queue_size(queue_manager.queue) # Implement this function
|
308 |
+
|
309 |
return {"position_in_queue": position_in_queue, "total_queue_size": total_queue_size}
|
310 |
|
311 |
except Exception as e:
|