pseudotheos commited on
Commit
dfc480a
1 Parent(s): d362f04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -5
app.py CHANGED
@@ -26,6 +26,7 @@ from transformers import AutoFeatureExtractor, CLIPFeatureExtractor
26
  import random
27
  import time
28
  import tempfile
 
29
 
30
  logger = logging.getLogger(__name__)
31
  # Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
@@ -43,16 +44,32 @@ logger.addHandler(stream_handler)
43
  class ImageGenerationQueue:
44
  def __init__(self):
45
  self.queue = asyncio.Queue()
 
 
46
 
47
- async def add_task(self, task):
 
 
 
48
  await self.queue.put(task)
 
 
 
49
 
50
  async def process_queue(self):
51
  while True:
52
  task = await self.queue.get()
53
  await task()
 
 
 
54
  self.queue.task_done()
55
 
 
 
 
 
 
56
  app = FastAPI()
57
  queue_manager = ImageGenerationQueue()
58
 
@@ -230,13 +247,13 @@ def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, con
230
  output_image_binary = output_image_io.read()
231
 
232
  # Return the generated image binary data
233
- logger.debug("Output Values: generated_image=<binary data>")
234
  return output_image_binary
235
 
236
  except Exception as e:
237
  # Handle exceptions and return an error message if something goes wrong
238
  return str(e)
239
 
 
240
  app.add_middleware(
241
  CORSMiddleware,
242
  allow_origins=["*"], # You can replace ["*"] with specific origins if needed
@@ -256,7 +273,7 @@ async def generate_image(
256
  sampler_type: str = Form(...),
257
  image: UploadFile = File(...)
258
  ):
259
- async def generate_image_task():
260
  try:
261
  # Save the uploaded image to a temporary file
262
  temp_image_path = f"/tmp/{int(time.time())}_{image.filename}"
@@ -284,9 +301,11 @@ async def generate_image(
284
  return "Failed to generate image"
285
 
286
  try:
287
- await queue_manager.add_task(generate_image_task)
288
  position_in_queue = queue_manager.queue.qsize()
289
- total_queue_size = queue_manager.queue.qsize() # Implement this function
 
 
290
  return {"position_in_queue": position_in_queue, "total_queue_size": total_queue_size}
291
 
292
  except Exception as e:
 
26
  import random
27
  import time
28
  import tempfile
29
+ import threading
30
 
31
  logger = logging.getLogger(__name__)
32
  # Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
 
44
  class ImageGenerationQueue:
45
  def __init__(self):
46
  self.queue = asyncio.Queue()
47
+ self.queue_size = 0
48
+ self.queue_lock = threading.Lock()
49
 
50
+ def add_task(self, task):
51
+ asyncio.run_coroutine_threadsafe(self._add_task(task), loop=asyncio.get_event_loop())
52
+
53
+ async def _add_task(self, task):
54
  await self.queue.put(task)
55
+ # Update queue size in a thread-safe manner
56
+ with self.queue_lock:
57
+ self.queue_size = self.queue.qsize()
58
 
59
  async def process_queue(self):
60
  while True:
61
  task = await self.queue.get()
62
  await task()
63
+ # Update queue size in a thread-safe manner
64
+ with self.queue_lock:
65
+ self.queue_size = self.queue.qsize()
66
  self.queue.task_done()
67
 
68
+ async def get_total_queue_size(self):
69
+ # Return the queue size in a thread-safe manner
70
+ with self.queue_lock:
71
+ return self.queue_size
72
+
73
  app = FastAPI()
74
  queue_manager = ImageGenerationQueue()
75
 
 
247
  output_image_binary = output_image_io.read()
248
 
249
  # Return the generated image binary data
 
250
  return output_image_binary
251
 
252
  except Exception as e:
253
  # Handle exceptions and return an error message if something goes wrong
254
  return str(e)
255
 
256
+
257
  app.add_middleware(
258
  CORSMiddleware,
259
  allow_origins=["*"], # You can replace ["*"] with specific origins if needed
 
273
  sampler_type: str = Form(...),
274
  image: UploadFile = File(...)
275
  ):
276
+ def generate_image_task():
277
  try:
278
  # Save the uploaded image to a temporary file
279
  temp_image_path = f"/tmp/{int(time.time())}_{image.filename}"
 
301
  return "Failed to generate image"
302
 
303
  try:
304
+ queue_manager.add_task(generate_image_task)
305
  position_in_queue = queue_manager.queue.qsize()
306
+ # Total queue size is still async
307
+ total_queue_size = await get_total_queue_size(queue_manager.queue) # Implement this function
308
+
309
  return {"position_in_queue": position_in_queue, "total_queue_size": total_queue_size}
310
 
311
  except Exception as e: