pseudotheos commited on
Commit
5d45828
·
1 Parent(s): 50ea979

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -11
app.py CHANGED
@@ -46,6 +46,7 @@ class ImageGenerationQueue:
46
  self.queue = asyncio.Queue()
47
  self.queue_size = 0
48
  self.queue_lock = threading.Lock()
 
49
 
50
  def add_task(self, task):
51
  asyncio.run_coroutine_threadsafe(self._add_task(task), loop=asyncio.get_event_loop())
@@ -262,6 +263,23 @@ app.add_middleware(
262
  allow_headers=["*"], # Allow all headers
263
  )
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  @app.post("/generate_image")
266
  async def generate_image(
267
  prompt: str = Form(...),
@@ -274,10 +292,10 @@ async def generate_image(
274
  image: UploadFile = File(...),
275
  background_tasks: BackgroundTasks = BackgroundTasks()
276
  ):
277
- async def generate_image_task(position_in_queue):
278
  try:
279
  # Save the uploaded image to a temporary file
280
- temp_image_path = f"/tmp/{position_in_queue}_{image.filename}"
281
  with open(temp_image_path, "wb") as temp_image:
282
  temp_image.write(image.file.read())
283
 
@@ -289,25 +307,30 @@ async def generate_image(
289
  if generated_image is None:
290
  return "Failed to generate image"
291
 
292
- # Save the generated image as binary data
293
- output_image_io = io.BytesIO()
294
- generated_image.save(output_image_io, format="PNG")
295
- output_image_io.seek(0)
 
 
 
 
296
 
297
- # Return the image as a streaming response
298
- return StreamingResponse(content=output_image_io, media_type="image/png")
299
 
300
  except Exception as e:
301
  logger.error("Error occurred during image generation: %s", str(e))
302
  return "Failed to generate image"
303
 
304
  try:
305
- position_in_queue = queue_manager.queue.qsize() + 1
306
- queue_manager.add_task(lambda _: generate_image_task(position_in_queue))
 
307
  # Total queue size is still async
308
  total_queue_size = await queue_manager.get_total_queue_size() # Implement this function
309
 
310
- return {"position_in_queue": position_in_queue, "total_queue_size": total_queue_size}
311
 
312
  except Exception as e:
313
  logger.error("Error occurred during image generation: %s", str(e))
 
46
  self.queue = asyncio.Queue()
47
  self.queue_size = 0
48
  self.queue_lock = threading.Lock()
49
+ self.next_id = 0
50
 
51
  def add_task(self, task):
52
  asyncio.run_coroutine_threadsafe(self._add_task(task), loop=asyncio.get_event_loop())
 
263
  allow_headers=["*"], # Allow all headers
264
  )
265
 
266
+ @app.post("/get_image")
267
+ async def get_image(
268
+ job_id: int = Form(...),
269
+ ):
270
+ image_path = f"/tmp/{job_id}_output_{image.filename}"
271
+ with open(image_path, "rb") as file:
272
+ generated_image = file.read()
273
+
274
+ # Save the generated image as binary data
275
+ output_image_io = io.BytesIO()
276
+ generated_image.save(output_image_io, format="PNG")
277
+ output_image_io.seek(0)
278
+
279
+ # Return the image as a streaming response
280
+ return StreamingResponse(content=output_image_io, media_type="image/png")
281
+
282
+
283
  @app.post("/generate_image")
284
  async def generate_image(
285
  prompt: str = Form(...),
 
292
  image: UploadFile = File(...),
293
  background_tasks: BackgroundTasks = BackgroundTasks()
294
  ):
295
+ async def generate_image_task(job_id):
296
  try:
297
  # Save the uploaded image to a temporary file
298
+ temp_image_path = f"/tmp/{job_id}_{image.filename}"
299
  with open(temp_image_path, "wb") as temp_image:
300
  temp_image.write(image.file.read())
301
 
 
307
  if generated_image is None:
308
  return "Failed to generate image"
309
 
310
+ output_image_path = f"/tmp/{job_id}_output_{image.filename}"
311
+ with open(output_image_path, "wb") as output_image:
312
+ output_image.write(generated_image)
313
+
314
+ # # Save the generated image as binary data
315
+ # output_image_io = io.BytesIO()
316
+ # generated_image.save(output_image_io, format="PNG")
317
+ # output_image_io.seek(0)
318
 
319
+ # # Return the image as a streaming response
320
+ # return StreamingResponse(content=output_image_io, media_type="image/png")
321
 
322
  except Exception as e:
323
  logger.error("Error occurred during image generation: %s", str(e))
324
  return "Failed to generate image"
325
 
326
  try:
327
+ id = queue_manager.next_id++
328
+ queue_manager.add_task(lambda _: generate_image_task(id))
329
+ position_in_queue = queue_manager.queue.qsize()
330
  # Total queue size is still async
331
  total_queue_size = await queue_manager.get_total_queue_size() # Implement this function
332
 
333
+ return {"job_id": id, "position_in_queue": position_in_queue, "total_queue_size": total_queue_size}
334
 
335
  except Exception as e:
336
  logger.error("Error occurred during image generation: %s", str(e))