VikramSingh178 commited on
Commit
080c0ac
1 Parent(s): c9c9773

Refactor code to use shared BaseModel for Painting and InpaintingRequest classes

Browse files

Former-commit-id: 4484440577c279f069f9e6c8c11d9ab9685f26e8 [formerly b52b26e6289548705fa59307a387beb19edec4f6]
Former-commit-id: 60b102b660610226910f419923c877a820891aab

api/__pycache__/endpoints.cpython-310.pyc CHANGED
Binary files a/api/__pycache__/endpoints.cpython-310.pyc and b/api/__pycache__/endpoints.cpython-310.pyc differ
 
api/endpoints.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from routers import (sdxl_text_to_image,painting,batch_painting)
4
  import logfire
5
  import uvicorn
6
 
@@ -26,7 +26,6 @@ app.add_middleware(
26
 
27
  app.include_router(sdxl_text_to_image.router, prefix='/api/v1/product-diffusion')
28
  app.include_router(painting.router,prefix='/api/v1/product-diffusion')
29
- app.include_router(batch_painting.router,prefix='/api/v1/product-diffusion')
30
  logfire.instrument_fastapi(app)
31
 
32
 
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from routers import (sdxl_text_to_image,painting)
4
  import logfire
5
  import uvicorn
6
 
 
26
 
27
  app.include_router(sdxl_text_to_image.router, prefix='/api/v1/product-diffusion')
28
  app.include_router(painting.router,prefix='/api/v1/product-diffusion')
 
29
  logfire.instrument_fastapi(app)
30
 
31
 
api/routers/__pycache__/batch_painting.cpython-310.pyc CHANGED
Binary files a/api/routers/__pycache__/batch_painting.cpython-310.pyc and b/api/routers/__pycache__/batch_painting.cpython-310.pyc differ
 
api/routers/__pycache__/painting.cpython-310.pyc CHANGED
Binary files a/api/routers/__pycache__/painting.cpython-310.pyc and b/api/routers/__pycache__/painting.cpython-310.pyc differ
 
api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc CHANGED
Binary files a/api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc and b/api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc differ
 
api/routers/batch_painting.py DELETED
File without changes
api/routers/painting.py CHANGED
@@ -2,25 +2,57 @@ import sys
2
  sys.path.append('../scripts')
3
  import os
4
  import uuid
5
- from typing import List, Tuple, Any
6
- from fastapi import APIRouter, File, UploadFile, HTTPException, Form
 
7
  from PIL import Image
8
  import lightning.pytorch as pl
9
  from utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
10
  from inpainting_pipeline import AutoPaintingPipeline, load_pipeline
11
  from hydra import compose, initialize
12
  from async_batcher.batcher import AsyncBatcher
 
 
 
 
13
 
14
  pl.seed_everything(42)
15
  router = APIRouter()
16
 
17
- # Initialize the configuration and pipeline
18
  with initialize(version_base=None, config_path="../../configs"):
19
  cfg = compose(config_name="inpainting")
20
- inpainting_pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  async def save_image(image: UploadFile) -> str:
23
- """Save an uploaded image to a temporary file and return the file path."""
 
 
 
 
 
 
 
 
24
  file_name = f"{uuid.uuid4()}.png"
25
  file_path = os.path.join("/tmp", file_name)
26
  with open(file_path, "wb") as f:
@@ -28,6 +60,20 @@ async def save_image(image: UploadFile) -> str:
28
  return file_path
29
 
30
  def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  image = Image.open(image_path)
32
  image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
33
  image = image_augmentation.extend_image(image)
@@ -35,17 +81,22 @@ def augment_image(image_path, target_width, target_height, roi_scale, segmentati
35
  inverted_mask = image_augmentation.invert_mask(mask)
36
  return image, inverted_mask
37
 
38
- def run_inference(cfg,
39
- image_path: str,
40
- prompt: str,
41
- negative_prompt: str,
42
- num_inference_steps: int,
43
- strength: float,
44
- guidance_scale: float,
45
- mode: str,
46
- num_images: int,
47
- use_augmentation: bool):
48
- if use_augmentation:
 
 
 
 
 
49
  image, mask_image = augment_image(image_path,
50
  cfg['target_width'],
51
  cfg['target_height'],
@@ -63,37 +114,84 @@ def run_inference(cfg,
63
  target_height=cfg['target_height'],
64
  target_width=cfg['target_width']
65
  )
66
- output = painting_pipeline.run_inference(prompt=prompt,
67
- negative_prompt=negative_prompt,
68
- num_inference_steps=num_inference_steps,
69
- strength=strength,
70
- guidance_scale=guidance_scale)
71
- if mode == "s3_json":
72
  return pil_to_s3_json(output, file_name="output.png")
73
- elif mode == "b64_json":
74
  return pil_to_b64_json(output)
75
  else:
76
  raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
77
-
78
-
 
 
 
 
 
 
 
79
 
80
  @router.post("/inpainting")
81
  async def inpainting_inference(
82
  image: UploadFile = File(...),
83
- prompt: str = Form(...),
84
- negative_prompt: str = Form(...),
85
- num_inference_steps: int = Form(...),
86
- strength: float = Form(...),
87
- guidance_scale: float = Form(...),
88
- mode: str = Form(...),
89
- num_images: int = Form(1),
90
-
91
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  try:
93
  image_path = await save_image(image)
94
- result = run_inference(cfg, image_path, prompt, negative_prompt, num_inference_steps, strength, guidance_scale, mode, num_images, use_augmentation=True)
 
 
95
  return result
96
  except Exception as e:
97
  raise HTTPException(status_code=500, detail=str(e))
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  sys.path.append('../scripts')
3
  import os
4
  import uuid
5
+ from typing import List, Tuple, Any,Dict
6
+ from fastapi import APIRouter, File, UploadFile, HTTPException, Form, Depends, Body
7
+ from pydantic import BaseModel, Field
8
  from PIL import Image
9
  import lightning.pytorch as pl
10
  from utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
11
  from inpainting_pipeline import AutoPaintingPipeline, load_pipeline
12
  from hydra import compose, initialize
13
  from async_batcher.batcher import AsyncBatcher
14
+ from concurrent.futures import Executor
15
+ import json
16
+ import asyncio
17
+ from functools import lru_cache
18
 
19
  pl.seed_everything(42)
20
  router = APIRouter()
21
 
22
+ # Initialize Hydra configuration
23
  with initialize(version_base=None, config_path="../../configs"):
24
  cfg = compose(config_name="inpainting")
25
+
26
+ # Load the inpainting pipeline
27
+ @lru_cache(maxsize=1)
28
+ def load_pipeline_wrapper():
29
+ pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
30
+ return pipeline
31
+ inpainting_pipeline = load_pipeline_wrapper()
32
+
33
+ class InpaintingRequest(BaseModel):
34
+ prompt: str = Field(..., description="Prompt text for inference")
35
+ negative_prompt: str = Field(..., description="Negative prompt text for inference")
36
+ num_inference_steps: int = Field(..., description="Number of inference steps")
37
+ strength: float = Field(..., description="Strength of the inference")
38
+ guidance_scale: float = Field(..., description="Guidance scale for inference")
39
+ mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
40
+ num_images: int = Field(..., description="Number of images to generate")
41
+ use_augmentation: bool = Field(True, description="Whether to use image augmentation")
42
+
43
+ class InpaintingBatchRequestModel(BaseModel):
44
+ requests: List[InpaintingRequest]
45
 
46
  async def save_image(image: UploadFile) -> str:
47
+ """
48
+ Save an uploaded image to a temporary file and return the file path.
49
+
50
+ Args:
51
+ image (UploadFile): The uploaded image file.
52
+
53
+ Returns:
54
+ str: File path where the image is saved.
55
+ """
56
  file_name = f"{uuid.uuid4()}.png"
57
  file_path = os.path.join("/tmp", file_name)
58
  with open(file_path, "wb") as f:
 
60
  return file_path
61
 
62
  def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
63
+ """
64
+ Augment an image by extending its dimensions and generating masks.
65
+
66
+ Args:
67
+ image_path (str): Path to the image file.
68
+ target_width (int): Target width for augmentation.
69
+ target_height (int): Target height for augmentation.
70
+ roi_scale (float): Scale factor for region of interest.
71
+ segmentation_model_name (str): Name of the segmentation model.
72
+ detection_model_name (str): Name of the detection model.
73
+
74
+ Returns:
75
+ Tuple[Image.Image, Image.Image]: Augmented image and inverted mask.
76
+ """
77
  image = Image.open(image_path)
78
  image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
79
  image = image_augmentation.extend_image(image)
 
81
  inverted_mask = image_augmentation.invert_mask(mask)
82
  return image, inverted_mask
83
 
84
+ def run_inference(cfg, image_path: str, request: InpaintingRequest):
85
+ """
86
+ Run inference using an inpainting pipeline on an image.
87
+
88
+ Args:
89
+ cfg (dict): Configuration dictionary.
90
+ image_path (str): Path to the image file.
91
+ request (InpaintingRequest): Pydantic model containing inference parameters.
92
+
93
+ Returns:
94
+ dict: Resulting image in the specified mode ('b64_json' or 's3_json').
95
+
96
+ Raises:
97
+ ValueError: If an invalid mode is provided.
98
+ """
99
+ if request.use_augmentation:
100
  image, mask_image = augment_image(image_path,
101
  cfg['target_width'],
102
  cfg['target_height'],
 
114
  target_height=cfg['target_height'],
115
  target_width=cfg['target_width']
116
  )
117
+ output = painting_pipeline.run_inference(prompt=request.prompt,
118
+ negative_prompt=request.negative_prompt,
119
+ num_inference_steps=request.num_inference_steps,
120
+ strength=request.strength,
121
+ guidance_scale=request.guidance_scale)
122
+ if request.mode == "s3_json":
123
  return pil_to_s3_json(output, file_name="output.png")
124
+ elif request.mode == "b64_json":
125
  return pil_to_b64_json(output)
126
  else:
127
  raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
128
+
129
+ class InpaintingBatcher(AsyncBatcher):
130
+ async def process_batch(self, batch: Tuple[List[str], List[InpaintingRequest]]) -> List[Dict[str, Any]]:
131
+ image_paths, requests = batch
132
+ results = []
133
+ for image_path, request in zip(image_paths, requests):
134
+ result = run_inference(cfg, image_path, request)
135
+ results.append(result)
136
+ return results
137
 
138
  @router.post("/inpainting")
139
  async def inpainting_inference(
140
  image: UploadFile = File(...),
141
+ request_data: str = Form(...),
 
 
 
 
 
 
 
142
  ):
143
+ """
144
+ Handle POST request for inpainting inference.
145
+
146
+ Args:
147
+ image (UploadFile): Uploaded image file.
148
+ request_data (str): JSON string of the request parameters.
149
+
150
+ Returns:
151
+ dict: Resulting image in the specified mode ('b64_json' or 's3_json').
152
+
153
+ Raises:
154
+ HTTPException: If there is an error during image processing.
155
+ """
156
  try:
157
  image_path = await save_image(image)
158
+ request_dict = json.loads(request_data)
159
+ request = InpaintingRequest(**request_dict)
160
+ result = run_inference(cfg, image_path, request)
161
  return result
162
  except Exception as e:
163
  raise HTTPException(status_code=500, detail=str(e))
164
 
165
+ @router.post("/inpainting/batch")
166
+ async def inpainting_batch_inference(
167
+ images: List[UploadFile] = File(...),
168
+ request_data: str = Form(...),
169
+ ):
170
+ """
171
+ Handle POST request for batch inpainting inference.
172
+
173
+ Args:
174
+ images (List[UploadFile]): List of uploaded image files.
175
+ request_data (str): JSON string of the request parameters.
176
+
177
+ Returns:
178
+ List[dict]: List of resulting images in the specified mode ('b64_json' or 's3_json').
179
 
180
+ Raises:
181
+ HTTPException: If there is an error during image processing.
182
+ """
183
+ try:
184
+ request_dict = json.loads(request_data)
185
+ batch_request = InpaintingBatchRequestModel(**request_dict)
186
+ requests = batch_request.requests
187
+
188
+ if len(images) != len(requests):
189
+ raise HTTPException(status_code=400, detail="The number of images and requests must match.")
190
+
191
+ batcher = InpaintingBatcher(max_batch_size=64)
192
+ image_paths = [await save_image(image) for image in images]
193
+ results = await batcher.process_batch((image_paths, requests))
194
+
195
+ return results
196
+ except Exception as e:
197
+ raise HTTPException(status_code=500, detail=str(e))
api/routers/sdxl_text_to_image.py CHANGED
@@ -32,7 +32,7 @@ router = APIRouter()
32
 
33
 
34
  # Load the diffusion pipeline
35
- @lru_cache(maxsize=1)
36
  def load_pipeline(model_name, adapter_name,enable_compile:bool):
37
  """
38
  Load the diffusion pipeline with the specified model and adapter names.
@@ -182,7 +182,7 @@ async def sdxl_v0_lora_inference(data: InputFormat):
182
 
183
  @router.post("/sdxl_v0_lora_inference/batch")
184
  async def sdxl_v0_lora_inference_batch(data: List[InputFormat]):
185
- batcher = SDXLLoraBatcher(max_batch_size=64)
186
  try:
187
  predictions = await batcher.process_batch(data)
188
  return predictions
 
32
 
33
 
34
  # Load the diffusion pipeline
35
+
36
  def load_pipeline(model_name, adapter_name,enable_compile:bool):
37
  """
38
  Load the diffusion pipeline with the specified model and adapter names.
 
182
 
183
  @router.post("/sdxl_v0_lora_inference/batch")
184
  async def sdxl_v0_lora_inference_batch(data: List[InputFormat]):
185
+ batcher = SDXLLoraBatcher(max_batch_size=-1)
186
  try:
187
  predictions = await batcher.process_batch(data)
188
  return predictions
scripts/__pycache__/inpainting_pipeline.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc and b/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc differ
 
scripts/inpainting_pipeline.py CHANGED
@@ -7,8 +7,8 @@ from omegaconf import DictConfig
7
  from PIL import Image
8
  from functools import lru_cache
9
 
10
- @lru_cache(maxsize=1)
11
- def load_pipeline(model_name: str, device, enable_compile: bool = False):
12
  pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float16)
13
  if enable_compile:
14
  pipeline.unet.to(memory_format=torch.channels_last)
 
7
  from PIL import Image
8
  from functools import lru_cache
9
 
10
+
11
+ def load_pipeline(model_name: str, device, enable_compile: bool = True):
12
  pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float16)
13
  if enable_compile:
14
  pipeline.unet.to(memory_format=torch.channels_last)