VikramSingh178 committed
Commit b9da07f
1 Parent(s): 3044016

chore: Refactor code to use shared BaseModel for Painting and InputFormat classes


Former-commit-id: 407eeff325cb35937072ae4d08ae116b0c938b52 [formerly aca93c4c8e33ae484649dafdb75471c4cd35c4b7]
Former-commit-id: 62ead1acffc613e477d909cb2c553fd5da68ba9f

api/models/__init__.py ADDED
File without changes
api/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (161 Bytes).

api/models/__pycache__/sdxl_input.cpython-311.pyc ADDED
Binary file (691 Bytes).

api/models/painting.py ADDED
@@ -0,0 +1,11 @@
+ from pydantic import BaseModel
+
+
+ class Painting(BaseModel):
+     prompt: str
+     num_inference_steps: int
+     guidance_scale: float
+     negative_prompt: str
+     num_images: int
+     mode: str
+
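For reference, the new Painting schema validates a painting request at the API boundary. A minimal usage sketch, assuming the api/ directory is on the import path (as the router import further below suggests); the field values are illustrative only, not taken from the repo:

```python
from models.painting import Painting

# Hypothetical request payload; values are illustrative.
payload = {
    "prompt": "a watercolor landscape",
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
    "negative_prompt": "blurry, low quality",
    "num_images": 1,
    "mode": "s3_json",
}

# Pydantic validates and coerces field types on construction; a missing or
# mistyped field raises a ValidationError.
painting = Painting(**payload)
print(painting.prompt, painting.num_inference_steps)
```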
api/models/sdxl_input.py ADDED
@@ -0,0 +1,12 @@
+
+ from pydantic import BaseModel
+
+
+
+ class InputFormat(BaseModel):
+     prompt: str
+     num_inference_steps: int
+     guidance_scale: float
+     negative_prompt: str
+     num_images: int
+     mode: str
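InputFormat now lives in its own module so routers can import it instead of defining it inline. A minimal sketch of how a FastAPI route consumes it as a JSON body; the route path and handler body here are illustrative, not the project's actual endpoint:

```python
from fastapi import APIRouter
from models.sdxl_input import InputFormat

router = APIRouter()

# Illustrative endpoint: FastAPI parses the JSON body into an InputFormat
# instance and returns a 422 response automatically if validation fails.
@router.post("/example_inference")
async def example_inference(data: InputFormat) -> dict:
    return {"prompt": data.prompt, "mode": data.mode}
```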
api/routers/__pycache__/painting.cpython-311.pyc CHANGED
Binary files a/api/routers/__pycache__/painting.cpython-311.pyc and b/api/routers/__pycache__/painting.cpython-311.pyc differ
 
api/routers/__pycache__/sdxl_text_to_image.cpython-311.pyc CHANGED
Binary files a/api/routers/__pycache__/sdxl_text_to_image.cpython-311.pyc and b/api/routers/__pycache__/sdxl_text_to_image.cpython-311.pyc differ
 
api/routers/sdxl_text_to_image.py CHANGED
@@ -14,6 +14,8 @@ from s3_manager import S3ManagerService
  from PIL import Image
  import io
  from utils import accelerator
+ from models.sdxl_input import InputFormat
+ from async_batcher.batcher import AsyncBatcher
 
  device = accelerator()
  torch._inductor.config.conv_1x1_as_mm = True
@@ -101,7 +103,7 @@ loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME, config.E
 
 
  # SDXLLoraInference class for running inference
- class SDXLLoraInference:
+ class SDXLLoraInference(AsyncBatcher):
      """
      Class for performing SDXL Lora inference.
 
@@ -169,21 +171,37 @@ class SDXLLoraInference:
              return pil_to_b64_json(image)
          else:
              raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
 
-
- # Input format for single request
- class InputFormat(BaseModel):
-     prompt: str
-     num_inference_steps: int
-     guidance_scale: float
-     negative_prompt: str
-     num_images: int
-     mode: str
-
-
- # Input format for batch requests
- class BatchInputFormat(BaseModel):
-     batch_input: List[InputFormat]
+
+
+ class SDXLLoraBatcher(AsyncBatcher[InputFormat, dict]):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.pipe = loaded_pipeline
+
+     def process_batch(self, batch: List[InputFormat]) -> List[dict]:
+         results = []
+         for data in batch:
+             try:
+                 images = self.pipe(
+                     prompt=data.prompt,
+                     num_inference_steps=data.num_inference_steps,
+                     guidance_scale=data.guidance_scale,
+                     negative_prompt=data.negative_prompt,
+                     num_images_per_prompt=data.num_images,
+                 ).images
+
+                 for image in images:
+                     if data.mode == "s3_json":
+                         result = pil_to_s3_json(image, 'sdxl_image')
+                     elif data.mode == "b64_json":
+                         result = pil_to_b64_json(image)
+                     else:
+                         raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
+                     results.append(result)
+             except Exception as e:
+                 print(f"Error in process_batch: {e}")
+                 raise HTTPException(status_code=500, detail="Batch inference failed")
+         return results
 
 
  # Endpoint for single request
@@ -195,47 +213,24 @@ async def sdxl_v0_lora_inference(data: InputFormat):
          data.num_images,
          data.num_inference_steps,
          data.guidance_scale,
-         data.mode
+         data.mode,
+
+
      )
      output_json = inference.run_inference()
      return output_json
 
 
+ # Endpoint for batch requests
 
  @router.post("/sdxl_v0_lora_inference/batch")
- async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
-     """
-     Perform batch inference for SDXL V0 LoRa model.
-
-     Args:
-         data (BatchInputFormat): The input data containing a batch of requests.
-
-     Returns:
-         dict: A dictionary containing the message and processed requests data.
-
-     Raises:
-         HTTPException: If the number of requests exceeds the maximum queue size.
-     """
-     MAX_QUEUE_SIZE = 64
-
-     if len(data.batch_input) > MAX_QUEUE_SIZE:
-         raise HTTPException(
-             status_code=400,
-             detail=f"Number of requests exceeds maximum queue size ({MAX_QUEUE_SIZE})",
-         )
-
-     processed_requests = []
-     for item in data.batch_input:
-         inference = SDXLLoraInference(
-             item.prompt,
-             item.negative_prompt,
-             item.num_images,
-             item.num_inference_steps,
-             item.guidance_scale,
-             item.mode,
-         )
-         output_json = inference.run_inference()
-         processed_requests.append(output_json)
-
-     return {"message": "Requests processed successfully", "data": processed_requests}
+ async def sdxl_v0_lora_inference_batch(data: List[InputFormat]):
+     batcher = SDXLLoraBatcher(max_batch_size=64, max_queue_time=0.001)
+     try:
+         predictions = await batcher.process(batch=data)
+         return predictions
+     except Exception as e:
+         print(f"Error in /sdxl_v0_lora_inference/batch: {e}")
+         raise HTTPException(status_code=500, detail="Batch inference endpoint failed")
+
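As a quick sanity check of the new batch route, a client can POST a JSON array matching List[InputFormat]; note that max_batch_size=64 mirrors the MAX_QUEUE_SIZE = 64 limit the removed handler enforced. A minimal sketch using httpx; the base URL and any router prefix are assumptions, since they are configured outside this diff:

```python
import httpx

# Hypothetical base URL; the actual host/port and router prefix are configured
# outside this commit.
BASE_URL = "http://localhost:8000"

# Each element must satisfy InputFormat; values are illustrative.
batch = [
    {
        "prompt": "a studio photo of a red sneaker",
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
        "negative_prompt": "blurry",
        "num_images": 1,
        "mode": "b64_json",
    },
    {
        "prompt": "a product shot of a ceramic mug",
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
        "negative_prompt": "low quality",
        "num_images": 1,
        "mode": "s3_json",
    },
]

# FastAPI validates every element against InputFormat before the handler runs.
response = httpx.post(f"{BASE_URL}/sdxl_v0_lora_inference/batch", json=batch, timeout=300)
response.raise_for_status()
print(response.json())
```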
 
scripts/__pycache__/config.cpython-311.pyc CHANGED
Binary files a/scripts/__pycache__/config.cpython-311.pyc and b/scripts/__pycache__/config.cpython-311.pyc differ
 
scripts/__pycache__/inpainting_pipeline.cpython-311.pyc CHANGED
Binary files a/scripts/__pycache__/inpainting_pipeline.cpython-311.pyc and b/scripts/__pycache__/inpainting_pipeline.cpython-311.pyc differ