VikramSingh178 commited on
Commit
6dec4df
1 Parent(s): fac71dd

Former-commit-id: 6e67c162ab28997e4b3581e2482f192da255a4be [formerly 522ceaeca59f5a0db3c0e1e300b6ac98b69656af]

Former-commit-id: 62721f17c8bc87e55be13bfc8363e09454779126

api/routers/painting.py CHANGED
@@ -1,49 +1,64 @@
1
  import os
2
  import uuid
3
- import json
4
- from typing import List
5
  from fastapi import APIRouter, File, UploadFile, HTTPException, Form
6
- from pydantic import BaseModel, Field, ValidationError
 
7
  import lightning.pytorch as pl
8
- from scripts.api_utils import pil_to_b64_json
9
- from scripts.outpainting import ControlNetZoeDepthOutpainting
 
10
  from async_batcher.batcher import AsyncBatcher
 
11
  from functools import lru_cache
12
-
13
  pl.seed_everything(42)
14
-
15
  router = APIRouter()
16
 
17
- @lru_cache(maxsize=1)
18
- def load_pipeline():
19
- outpainting_pipeline = ControlNetZoeDepthOutpainting(target_size=(1024, 1024))
20
- return outpainting_pipeline
21
 
22
- class OutpaintingRequest(BaseModel):
 
 
 
 
 
23
  """
24
- Model representing a request for outpainting inference.
 
 
 
25
  """
26
- controlnet_prompt: str = Field(...)
27
- controlnet_negative_prompt: str = Field(...)
28
- controlnet_conditioning_scale: float = Field(...)
29
- controlnet_guidance_scale: float = Field(...)
30
- controlnet_num_inference_steps: int = Field(...)
31
- controlnet_guidance_end: float = Field(...)
32
- inpainting_prompt: str = Field(...)
33
- inpainting_negative_prompt: str = Field(...)
34
- inpainting_guidance_scale: float = Field(...)
35
- inpainting_strength: float = Field(...)
36
- inpainting_num_inference_steps: int = Field(...)
37
 
38
- class OutpaintingBatchRequestModel(BaseModel):
 
 
39
  """
40
- Model representing a batch request for outpainting inference.
 
 
 
 
 
 
 
 
 
41
  """
42
- requests: List[OutpaintingRequest]
 
 
43
 
44
  async def save_image(image: UploadFile) -> str:
45
  """
46
  Save an uploaded image to a temporary file and return the file path.
 
 
 
 
 
 
47
  """
48
  file_name = f"{uuid.uuid4()}.png"
49
  file_path = os.path.join("/tmp", file_name)
@@ -51,75 +66,149 @@ async def save_image(image: UploadFile) -> str:
51
  f.write(await image.read())
52
  return file_path
53
 
54
- def run_inference(image_path: str, request: OutpaintingRequest):
55
- pipeline = load_pipeline()
56
- result = pipeline.run_pipeline(
57
- image_path,
58
- controlnet_prompt=request.controlnet_prompt,
59
- controlnet_negative_prompt=request.controlnet_negative_prompt,
60
- controlnet_conditioning_scale=request.controlnet_conditioning_scale,
61
- controlnet_guidance_scale=request.controlnet_guidance_scale,
62
- controlnet_num_inference_steps=request.controlnet_num_inference_steps,
63
- controlnet_guidance_end=request.controlnet_guidance_end,
64
- inpainting_prompt=request.inpainting_prompt,
65
- inpainting_negative_prompt=request.inpainting_negative_prompt,
66
- inpainting_guidance_scale=request.inpainting_guidance_scale,
67
- inpainting_strength=request.inpainting_strength,
68
- inpainting_num_inference_steps=request.inpainting_num_inference_steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  )
70
- return result
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- @router.post("/outpaint")
73
- async def outpaint(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  image: UploadFile = File(...),
75
- request: str = Form(...)
76
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  try:
78
- request_dict = json.loads(request)
79
- outpainting_request = OutpaintingRequest(**request_dict)
80
-
81
  image_path = await save_image(image)
82
- result = run_inference(image_path, outpainting_request)
83
-
84
- result_json = pil_to_b64_json(result)
85
-
86
- os.remove(image_path)
87
-
88
- return {"result": result_json}
89
- except json.JSONDecodeError:
90
- raise HTTPException(status_code=400, detail="Invalid JSON in request data")
91
- except ValidationError as e:
92
- raise HTTPException(status_code=422, detail=str(e))
93
  except Exception as e:
94
  raise HTTPException(status_code=500, detail=str(e))
95
 
96
- class OutpaintingBatcher(AsyncBatcher):
97
- async def process_batch(self, batch):
98
- results = []
99
- for image, request in batch:
100
- image_path = await save_image(image)
101
- try:
102
- result = run_inference(image_path, request)
103
- results.append(result)
104
- finally:
105
- os.remove(image_path)
106
- return results
 
 
 
107
 
108
- @router.post("/batch_outpaint")
109
- async def batch_outpaint(images: List[UploadFile] = File(...), batch_request: str = Form(...)):
 
110
  try:
111
- batch_request_dict = json.loads(batch_request)
112
- batch_outpainting_request = OutpaintingBatchRequestModel(**batch_request_dict)
113
-
114
- batcher = OutpaintingBatcher(max_queue_size=64)
115
- results = await batcher.process_batch(list(zip(images, batch_outpainting_request.requests)))
116
-
117
- result_jsons = [pil_to_b64_json(result) for result in results]
118
-
119
- return {"results": result_jsons}
120
- except json.JSONDecodeError:
121
- raise HTTPException(status_code=400, detail="Invalid JSON in batch request data")
122
- except ValidationError as e:
123
- raise HTTPException(status_code=422, detail=str(e))
124
  except Exception as e:
125
  raise HTTPException(status_code=500, detail=str(e))
 
1
  import os
2
  import uuid
3
+ from typing import List, Tuple, Any, Dict
 
4
  from fastapi import APIRouter, File, UploadFile, HTTPException, Form
5
+ from pydantic import BaseModel, Field
6
+ from PIL import Image
7
  import lightning.pytorch as pl
8
+ from scripts.api_utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
9
+ from scripts.inpainting_pipeline import AutoPaintingPipeline, load_pipeline
10
+ from hydra import compose, initialize
11
  from async_batcher.batcher import AsyncBatcher
12
+ import json
13
  from functools import lru_cache
 
14
  pl.seed_everything(42)
 
15
  router = APIRouter()
16
 
 
 
 
 
17
 
18
+ with initialize(version_base=None, config_path="../../configs"):
19
+ cfg = compose(config_name="inpainting")
20
+
21
+ # Load the inpainting pipeline
22
+ @lru_cache(maxsize=1)
23
+ def load_pipeline_wrapper():
24
  """
25
+ Load the inpainting pipeline with the specified configuration.
26
+
27
+ Returns:
28
+ pipeline: The loaded inpainting pipeline.
29
  """
30
+ pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
31
+ return pipeline
32
+ inpainting_pipeline = load_pipeline_wrapper()
 
 
 
 
 
 
 
 
33
 
34
+ class InpaintingRequest(BaseModel):
35
+ """
36
+ Model representing a request for inpainting inference.
37
  """
38
+ prompt: str = Field(..., description="Prompt text for inference")
39
+ negative_prompt: str = Field(..., description="Negative prompt text for inference")
40
+ num_inference_steps: int = Field(..., description="Number of inference steps")
41
+ strength: float = Field(..., description="Strength of the inference")
42
+ guidance_scale: float = Field(..., description="Guidance scale for inference")
43
+ mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
44
+ num_images: int = Field(..., description="Number of images to generate")
45
+ use_augmentation: bool = Field(True, description="Whether to use image augmentation")
46
+
47
+ class InpaintingBatchRequestModel(BaseModel):
48
  """
49
+ Model representing a batch request for inpainting inference.
50
+ """
51
+ requests: List[InpaintingRequest]
52
 
53
  async def save_image(image: UploadFile) -> str:
54
  """
55
  Save an uploaded image to a temporary file and return the file path.
56
+
57
+ Args:
58
+ image (UploadFile): The uploaded image file.
59
+
60
+ Returns:
61
+ str: File path where the image is saved.
62
  """
63
  file_name = f"{uuid.uuid4()}.png"
64
  file_path = os.path.join("/tmp", file_name)
 
66
  f.write(await image.read())
67
  return file_path
68
 
69
+ def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
70
+ """
71
+ Augment an image by extending its dimensions and generating masks.
72
+
73
+ Args:
74
+ image_path (str): Path to the image file.
75
+ target_width (int): Target width for augmentation.
76
+ target_height (int): Target height for augmentation.
77
+ roi_scale (float): Scale factor for region of interest.
78
+ segmentation_model_name (str): Name of the segmentation model.
79
+ detection_model_name (str): Name of the detection model.
80
+
81
+ Returns:
82
+ Tuple[Image.Image, Image.Image]: Augmented image and inverted mask.
83
+ """
84
+ image = Image.open(image_path)
85
+ image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
86
+ image = image_augmentation.extend_image(image)
87
+ mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
88
+ inverted_mask = image_augmentation.invert_mask(mask)
89
+ return image, inverted_mask
90
+
91
+ def run_inference(cfg, image_path: str, request: InpaintingRequest):
92
+ """
93
+ Run inference using an inpainting pipeline on an image.
94
+
95
+ Args:
96
+ cfg (dict): Configuration dictionary.
97
+ image_path (str): Path to the image file.
98
+ request (InpaintingRequest): Pydantic model containing inference parameters.
99
+
100
+ Returns:
101
+ dict: Resulting image in the specified mode ('b64_json' or 's3_json').
102
+
103
+ Raises:
104
+ ValueError: If an invalid mode is provided.
105
+ """
106
+ if request.use_augmentation:
107
+ image, mask_image = augment_image(image_path,
108
+ cfg['target_width'],
109
+ cfg['target_height'],
110
+ cfg['roi_scale'],
111
+ cfg['segmentation_model'],
112
+ cfg['detection_model'])
113
+ else:
114
+ image = Image.open(image_path)
115
+ mask_image = None
116
+
117
+ painting_pipeline = AutoPaintingPipeline(
118
+ pipeline=inpainting_pipeline,
119
+ image=image,
120
+ mask_image=mask_image,
121
+ target_height=cfg['target_height'],
122
+ target_width=cfg['target_width']
123
  )
124
+ output = painting_pipeline.run_inference(prompt=request.prompt,
125
+ negative_prompt=request.negative_prompt,
126
+ num_inference_steps=request.num_inference_steps,
127
+ strength=request.strength,
128
+ guidance_scale=request.guidance_scale,
129
+ num_images=request.num_images)
130
+ if request.mode == "s3_json":
131
+ return pil_to_s3_json(output, file_name="output.png")
132
+ elif request.mode == "b64_json":
133
+ return pil_to_b64_json(output)
134
+ else:
135
+ raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
136
 
137
+ class InpaintingBatcher(AsyncBatcher):
138
+ async def process_batch(self, batch: Tuple[List[str], List[InpaintingRequest]]) -> List[Dict[str, Any]]:
139
+ """
140
+ Process a batch of images and requests for inpainting inference.
141
+
142
+ Args:
143
+ batch (Tuple[List[str], List[InpaintingRequest]]): Tuple of image paths and corresponding requests.
144
+
145
+ Returns:
146
+ List[Dict[str, Any]]: List of resulting images in the specified mode ('b64_json' or 's3_json').
147
+ """
148
+ image_paths, requests = batch
149
+ results = []
150
+ for image_path, request in zip(image_paths, requests):
151
+ result = run_inference(cfg, image_path, request)
152
+ results.append(result)
153
+ return results
154
+
155
+ @router.post("/inpainting")
156
+ async def inpainting_inference(
157
  image: UploadFile = File(...),
158
+ request_data: str = Form(...),
159
  ):
160
+ """
161
+ Handle POST request for inpainting inference.
162
+
163
+ Args:
164
+ image (UploadFile): Uploaded image file.
165
+ request_data (str): JSON string of the request parameters.
166
+
167
+ Returns:
168
+ dict: Resulting image in the specified mode ('b64_json' or 's3_json').
169
+
170
+ Raises:
171
+ HTTPException: If there is an error during image processing.
172
+ """
173
  try:
 
 
 
174
  image_path = await save_image(image)
175
+ request_dict = json.loads(request_data)
176
+ request = InpaintingRequest(**request_dict)
177
+ result = run_inference(cfg, image_path, request)
178
+ return result
 
 
 
 
 
 
 
179
  except Exception as e:
180
  raise HTTPException(status_code=500, detail=str(e))
181
 
182
+ @router.post("/inpainting/batch")
183
+ async def inpainting_batch_inference(
184
+ images: List[UploadFile] = File(...),
185
+ request_data: str = Form(...),
186
+ ):
187
+ """
188
+ Handle POST request for batch inpainting inference.
189
+
190
+ Args:
191
+ images (List[UploadFile]): List of uploaded image files.
192
+ request_data (str): JSON string of the request parameters.
193
+
194
+ Returns:
195
+ List[dict]: List of resulting images in the specified mode ('b64_json' or 's3_json').
196
 
197
+ Raises:
198
+ HTTPException: If there is an error during image processing.
199
+ """
200
  try:
201
+ request_dict = json.loads(request_data)
202
+ batch_request = InpaintingBatchRequestModel(**request_dict)
203
+ requests = batch_request.requests
204
+
205
+ if len(images) != len(requests):
206
+ raise HTTPException(status_code=400, detail="The number of images and requests must match.")
207
+
208
+ batcher = InpaintingBatcher(max_batch_size=64)
209
+ image_paths = [await save_image(image) for image in images]
210
+ results = batcher.process_batch((image_paths, requests))
211
+
212
+ return results
 
213
  except Exception as e:
214
  raise HTTPException(status_code=500, detail=str(e))
sample_data/mask_image.png ADDED
scripts/__pycache__/inpainting_pipeline.cpython-310.pyc ADDED
Binary file (2.64 kB). View file
 
scripts/kandinsky3_inpainting.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from Kandinsky.kandinsky3 import get_inpainting_pipeline
3
+ from scripts.api_utils import ImageAugmentation,accelerator
4
+ from diffusers.utils import load_image
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ device_map = torch.device(accelerator())
9
+ dtype_map = {
10
+ 'unet': torch.float16,
11
+ 'text_encoder': torch.float16,
12
+ 'movq': torch.float32,
13
+ }
14
+
15
+
16
+ pipe = get_inpainting_pipeline(
17
+ device_map, dtype_map,
18
+ )
19
+
20
+ image = Image.open('/home/PicPilot/sample_data/image.png')
21
+ mask_image = Image.open('/home/PicPilot/sample_data/mask_image.png')
22
+ image = load_image(image=image)
23
+ mask_image = np.array(mask_image)
24
+
25
+
26
+
27
+
28
+ image = pipe( "Product on the Kitchen used for cooking", image, mask_image)
29
+ image.save('output.jpg')
ui/ui.py CHANGED
@@ -5,11 +5,18 @@ import json
5
  from PIL import Image
6
  from diffusers.utils import load_image
7
  from io import BytesIO
 
8
  from vars import base_url
9
 
10
  # API endpoints
11
  sdxl_inference_endpoint = f'{base_url}/api/v1/product-diffusion/sdxl_v0_lora_inference'
12
  kandinsky_inpainting_inference = f'{base_url}/api/v1/product-diffusion/inpainting'
 
 
 
 
 
 
13
 
14
  def generate_sdxl_lora_image(prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, mode):
15
  payload = {
 
5
  from PIL import Image
6
  from diffusers.utils import load_image
7
  from io import BytesIO
8
+ <<<<<<< HEAD
9
  from vars import base_url
10
 
11
  # API endpoints
12
  sdxl_inference_endpoint = f'{base_url}/api/v1/product-diffusion/sdxl_v0_lora_inference'
13
  kandinsky_inpainting_inference = f'{base_url}/api/v1/product-diffusion/inpainting'
14
+ =======
15
+
16
+ # API endpoints
17
+ sdxl_inference_endpoint = 'http://localhost:7860/api/v1/product-diffusion/sdxl_v0_lora_inference'
18
+ kandinsky_inpainting_inference = 'http://localhost:7860/api/v1/product-diffusion/inpainting'
19
+ >>>>>>> 41b06fc (commit)
20
 
21
  def generate_sdxl_lora_image(prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, mode):
22
  payload = {