radames commited on
Commit
5aa8efd
·
1 Parent(s): 3187be1

lock via async

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -14,6 +14,7 @@ import uuid
14
  import logging
15
  from fastapi import FastAPI, Request, HTTPException
16
  from fastapi.middleware.cors import CORSMiddleware
 
17
 
18
  logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
19
 
@@ -26,6 +27,7 @@ DB_PATH.mkdir(exist_ok=True, parents=True)
26
  IMGS_PATH.mkdir(exist_ok=True, parents=True)
27
 
28
  database = Database(DB_PATH)
 
29
 
30
 
31
  model_id = "segmind/Segmind-Vega"
@@ -98,14 +100,15 @@ async def generate_image(
98
  return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
99
 
100
  logging.info(f"Image not found in cache, generating new image")
101
- pil_image = generate(prompt, negative_prompt, seed)
102
- img_id = str(uuid.uuid4())
103
- img_path = IMGS_PATH / f"{img_id}.jpg"
104
- pil_image.save(img_path)
105
- img_io = io.BytesIO()
106
- pil_image.save(img_io, "JPEG")
107
- img_io.seek(0)
108
- database.insert(prompt, negative_prompt, str(img_path), seed)
 
109
 
110
  return StreamingResponse(img_io, media_type="image/jpeg")
111
 
 
14
  import logging
15
  from fastapi import FastAPI, Request, HTTPException
16
  from fastapi.middleware.cors import CORSMiddleware
17
+ from asyncio import Lock
18
 
19
  logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
20
 
 
27
  IMGS_PATH.mkdir(exist_ok=True, parents=True)
28
 
29
  database = Database(DB_PATH)
30
+ generate_lock = Lock()
31
 
32
 
33
  model_id = "segmind/Segmind-Vega"
 
100
  return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
101
 
102
  logging.info(f"Image not found in cache, generating new image")
103
+ async with generate_lock:
104
+ pil_image = generate(prompt, negative_prompt, seed)
105
+ img_id = str(uuid.uuid4())
106
+ img_path = IMGS_PATH / f"{img_id}.jpg"
107
+ pil_image.save(img_path)
108
+ img_io = io.BytesIO()
109
+ pil_image.save(img_io, "JPEG")
110
+ img_io.seek(0)
111
+ database.insert(prompt, negative_prompt, str(img_path), seed)
112
 
113
  return StreamingResponse(img_io, media_type="image/jpeg")
114