VikramSingh178
committed on
Commit
•
b9da07f
1
Parent(s):
3044016
chore: Refactor code to use shared BaseModel for Painting and InputFormat classes
Browse filesFormer-commit-id: 407eeff325cb35937072ae4d08ae116b0c938b52 [formerly aca93c4c8e33ae484649dafdb75471c4cd35c4b7]
Former-commit-id: 62ead1acffc613e477d909cb2c553fd5da68ba9f
- api/models/__init__.py +0 -0
- api/models/__pycache__/__init__.cpython-311.pyc +0 -0
- api/models/__pycache__/sdxl_input.cpython-311.pyc +0 -0
- api/models/painting.py +11 -0
- api/models/sdxl_input.py +12 -0
- api/routers/__pycache__/painting.cpython-311.pyc +0 -0
- api/routers/__pycache__/sdxl_text_to_image.cpython-311.pyc +0 -0
- api/routers/sdxl_text_to_image.py +46 -51
- scripts/__pycache__/config.cpython-311.pyc +0 -0
- scripts/__pycache__/inpainting_pipeline.cpython-311.pyc +0 -0
api/models/__init__.py
ADDED
File without changes
|
api/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (161 Bytes). View file
|
|
api/models/__pycache__/sdxl_input.cpython-311.pyc
ADDED
Binary file (691 Bytes). View file
|
|
api/models/painting.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel


class Painting(BaseModel):
    """Request payload for the painting (inpainting) router.

    Field set mirrors the SDXL text-to-image `InputFormat` model; the
    commit intent is a shared request schema across routers.
    """

    prompt: str                 # text prompt driving generation
    num_inference_steps: int    # number of diffusion denoising steps
    guidance_scale: float       # classifier-free guidance strength
    negative_prompt: str        # concepts to steer away from
    num_images: int             # images to generate for this request
    # Output-encoding selector; presumably "s3_json" or "b64_json" as in the
    # SDXL router — TODO confirm against the painting router that consumes this.
    mode: str
|
11 |
+
|
api/models/sdxl_input.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from pydantic import BaseModel


class InputFormat(BaseModel):
    """Request payload for the SDXL LoRA text-to-image endpoints.

    Consumed by both the single-request endpoint and the batch endpoint
    in api/routers/sdxl_text_to_image.py.
    """

    prompt: str                 # text prompt driving generation
    num_inference_steps: int    # number of diffusion denoising steps
    guidance_scale: float       # classifier-free guidance strength
    negative_prompt: str        # concepts to steer away from
    num_images: int             # images to generate per prompt
    # Output encoding: the router accepts "s3_json" (upload, return URL info)
    # or "b64_json" (inline base64); any other value raises ValueError.
    mode: str
|
api/routers/__pycache__/painting.cpython-311.pyc
CHANGED
Binary files a/api/routers/__pycache__/painting.cpython-311.pyc and b/api/routers/__pycache__/painting.cpython-311.pyc differ
|
|
api/routers/__pycache__/sdxl_text_to_image.cpython-311.pyc
CHANGED
Binary files a/api/routers/__pycache__/sdxl_text_to_image.cpython-311.pyc and b/api/routers/__pycache__/sdxl_text_to_image.cpython-311.pyc differ
|
|
api/routers/sdxl_text_to_image.py
CHANGED
@@ -14,6 +14,8 @@ from s3_manager import S3ManagerService
|
|
14 |
from PIL import Image
|
15 |
import io
|
16 |
from utils import accelerator
|
|
|
|
|
17 |
|
18 |
device = accelerator()
|
19 |
torch._inductor.config.conv_1x1_as_mm = True
|
@@ -101,7 +103,7 @@ loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME, config.E
|
|
101 |
|
102 |
|
103 |
# SDXLLoraInference class for running inference
|
104 |
-
class SDXLLoraInference:
|
105 |
"""
|
106 |
Class for performing SDXL Lora inference.
|
107 |
|
@@ -169,21 +171,37 @@ class SDXLLoraInference:
|
|
169 |
return pil_to_b64_json(image)
|
170 |
else:
|
171 |
raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
|
189 |
# Endpoint for single request
|
@@ -195,47 +213,24 @@ async def sdxl_v0_lora_inference(data: InputFormat):
|
|
195 |
data.num_images,
|
196 |
data.num_inference_steps,
|
197 |
data.guidance_scale,
|
198 |
-
data.mode
|
|
|
|
|
199 |
)
|
200 |
output_json = inference.run_inference()
|
201 |
return output_json
|
202 |
|
203 |
|
|
|
204 |
|
205 |
@router.post("/sdxl_v0_lora_inference/batch")
|
206 |
-
async def sdxl_v0_lora_inference_batch(data:
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
Raises:
|
217 |
-
HTTPException: If the number of requests exceeds the maximum queue size.
|
218 |
-
"""
|
219 |
-
MAX_QUEUE_SIZE = 64
|
220 |
-
|
221 |
-
if len(data.batch_input) > MAX_QUEUE_SIZE:
|
222 |
-
raise HTTPException(
|
223 |
-
status_code=400,
|
224 |
-
detail=f"Number of requests exceeds maximum queue size ({MAX_QUEUE_SIZE})",
|
225 |
-
)
|
226 |
-
|
227 |
-
processed_requests = []
|
228 |
-
for item in data.batch_input:
|
229 |
-
inference = SDXLLoraInference(
|
230 |
-
item.prompt,
|
231 |
-
item.negative_prompt,
|
232 |
-
item.num_images,
|
233 |
-
item.num_inference_steps,
|
234 |
-
item.guidance_scale,
|
235 |
-
item.mode,
|
236 |
-
)
|
237 |
-
output_json = inference.run_inference()
|
238 |
-
processed_requests.append(output_json)
|
239 |
-
|
240 |
-
return {"message": "Requests processed successfully", "data": processed_requests}
|
241 |
|
|
|
14 |
from PIL import Image
|
15 |
import io
|
16 |
from utils import accelerator
|
17 |
+
from models.sdxl_input import InputFormat
|
18 |
+
from async_batcher.batcher import AsyncBatcher
|
19 |
|
20 |
device = accelerator()
|
21 |
torch._inductor.config.conv_1x1_as_mm = True
|
|
|
103 |
|
104 |
|
105 |
# SDXLLoraInference class for running inference
|
106 |
+
class SDXLLoraInference(AsyncBatcher):
|
107 |
"""
|
108 |
Class for performing SDXL Lora inference.
|
109 |
|
|
|
171 |
return pil_to_b64_json(image)
|
172 |
else:
|
173 |
raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
|
174 |
+
|
175 |
+
|
176 |
+
class SDXLLoraBatcher(AsyncBatcher[InputFormat, dict]):
    """Async micro-batcher running SDXL LoRA inference over queued requests.

    Each `InputFormat` item in a batch is run through the shared,
    module-level `loaded_pipeline`, and every generated image is serialized
    according to the item's `mode` ("s3_json" or "b64_json").
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Reuse the module-level pipeline so batching never reloads weights.
        self.pipe = loaded_pipeline

    def process_batch(self, batch: List[InputFormat]) -> List[dict]:
        """Run inference for every request in `batch`.

        Returns:
            One serialized-image dict per generated image, in batch order.
            Note a request with num_images > 1 contributes multiple entries.

        Raises:
            HTTPException: 500 if any item in the batch fails; the whole
                batch is aborted (results for earlier items are discarded).
        """
        results = []
        for data in batch:
            try:
                images = self.pipe(
                    prompt=data.prompt,
                    num_inference_steps=data.num_inference_steps,
                    guidance_scale=data.guidance_scale,
                    negative_prompt=data.negative_prompt,
                    num_images_per_prompt=data.num_images,
                ).images

                for image in images:
                    if data.mode == "s3_json":
                        result = pil_to_s3_json(image, 'sdxl_image')
                    elif data.mode == "b64_json":
                        result = pil_to_b64_json(image)
                    else:
                        raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
                    results.append(result)
            except Exception as e:
                # NOTE(review): the ValueError above is also swallowed into a
                # generic 500 here — a bad `mode` reads as a server error.
                print(f"Error in process_batch: {e}")
                raise HTTPException(status_code=500, detail="Batch inference failed")
        return results
|
205 |
|
206 |
|
207 |
# Endpoint for single request
|
|
|
213 |
data.num_images,
|
214 |
data.num_inference_steps,
|
215 |
data.guidance_scale,
|
216 |
+
data.mode,
|
217 |
+
|
218 |
+
|
219 |
)
|
220 |
output_json = inference.run_inference()
|
221 |
return output_json
|
222 |
|
223 |
|
224 |
+
# Endpoint for batch requests
@router.post("/sdxl_v0_lora_inference/batch")
async def sdxl_v0_lora_inference_batch(data: List[InputFormat]):
    """Run SDXL LoRA inference for a list of requests via the async batcher.

    Args:
        data: List of inference requests to process as one batch.

    Returns:
        The batcher's predictions (one serialized-image dict per image).

    Raises:
        HTTPException: 500 if batch processing fails for any reason.
    """
    # NOTE(review): constructing a fresh batcher per HTTP request defeats
    # cross-request batching; consider a module-level instance — verify the
    # AsyncBatcher lifecycle first.
    batcher = SDXLLoraBatcher(max_batch_size=64, max_queue_time=0.001)
    try:
        predictions = await batcher.process(batch=data)
        return predictions
    except Exception as e:
        print(f"Error in /sdxl_v0_lora_inference/batch: {e}")
        raise HTTPException(status_code=500, detail="Batch inference endpoint failed")
|
235 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
scripts/__pycache__/config.cpython-311.pyc
CHANGED
Binary files a/scripts/__pycache__/config.cpython-311.pyc and b/scripts/__pycache__/config.cpython-311.pyc differ
|
|
scripts/__pycache__/inpainting_pipeline.cpython-311.pyc
CHANGED
Binary files a/scripts/__pycache__/inpainting_pipeline.cpython-311.pyc and b/scripts/__pycache__/inpainting_pipeline.cpython-311.pyc differ
|
|