VikramSingh178 committed on
Commit ffaa8aa
1 Parent(s): f596b65

Former-commit-id: b753d5f0ba3f3f032ac18f11985cbdfccbd7afe9

api/endpoints.py CHANGED
@@ -7,13 +7,6 @@ import uvicorn
 
 
 logfire.configure(pydantic_plugin=logfire.PydanticPlugin(record='all'))
-
-
-
-
-
-
-
 app = FastAPI(openapi_url='/api/v1/product-diffusion/openapi.json', docs_url='/api/v1/product-diffusion/docs')
 app.add_middleware(
     CORSMiddleware,
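
For context, the /outpaint route added in api/routers/painting.py below is only reachable if its router is mounted on this app; the diff does not show that wiring. A minimal sketch of how it typically looks (the import path and URL prefix here are assumptions, not part of this commit):

# Hypothetical wiring, not shown in this diff: mount the painting router
# on the FastAPI app so /outpaint and /batch_outpaint are served.
from api.routers import painting

app.include_router(painting.router, prefix="/api/v1/product-diffusion")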
api/routers/painting.py CHANGED
@@ -1,64 +1,49 @@
 import os
 import uuid
-from typing import List, Tuple, Any, Dict
+import json
+from typing import List
 from fastapi import APIRouter, File, UploadFile, HTTPException, Form
-from pydantic import BaseModel, Field
-from PIL import Image
+from pydantic import BaseModel, Field, ValidationError
 import lightning.pytorch as pl
-from scripts.api_utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
-from scripts.inpainting_pipeline import AutoPaintingPipeline, load_pipeline
-from hydra import compose, initialize
+from scripts.api_utils import pil_to_b64_json
+from scripts.outpainting import ControlNetZoeDepthOutpainting
 from async_batcher.batcher import AsyncBatcher
-import json
 from functools import lru_cache
-pl.seed_everything(42)
-router = APIRouter()
 
+pl.seed_everything(42)
 
-with initialize(version_base=None, config_path="../../configs"):
-    cfg = compose(config_name="inpainting")
+router = APIRouter()
 
-# Load the inpainting pipeline
 @lru_cache(maxsize=1)
-def load_pipeline_wrapper():
-    """
-    Load the inpainting pipeline with the specified configuration.
-
-    Returns:
-        pipeline: The loaded inpainting pipeline.
-    """
-    pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
-    return pipeline
-inpainting_pipeline = load_pipeline_wrapper()
+def load_pipeline():
+    outpainting_pipeline = ControlNetZoeDepthOutpainting(target_size=(1024, 1024))
+    return outpainting_pipeline
 
-class InpaintingRequest(BaseModel):
+class OutpaintingRequest(BaseModel):
     """
-    Model representing a request for inpainting inference.
+    Model representing a request for outpainting inference.
     """
-    prompt: str = Field(..., description="Prompt text for inference")
-    negative_prompt: str = Field(..., description="Negative prompt text for inference")
-    num_inference_steps: int = Field(..., description="Number of inference steps")
-    strength: float = Field(..., description="Strength of the inference")
-    guidance_scale: float = Field(..., description="Guidance scale for inference")
-    mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
-    num_images: int = Field(..., description="Number of images to generate")
-    use_augmentation: bool = Field(True, description="Whether to use image augmentation")
-
-class InpaintingBatchRequestModel(BaseModel):
+    controlnet_prompt: str = Field(...)
+    controlnet_negative_prompt: str = Field(...)
+    controlnet_conditioning_scale: float = Field(...)
+    controlnet_guidance_scale: float = Field(...)
+    controlnet_num_inference_steps: int = Field(...)
+    controlnet_guidance_end: float = Field(...)
+    inpainting_prompt: str = Field(...)
+    inpainting_negative_prompt: str = Field(...)
+    inpainting_guidance_scale: float = Field(...)
+    inpainting_strength: float = Field(...)
+    inpainting_num_inference_steps: int = Field(...)
+
+class OutpaintingBatchRequestModel(BaseModel):
     """
-    Model representing a batch request for inpainting inference.
+    Model representing a batch request for outpainting inference.
     """
-    requests: List[InpaintingRequest]
+    requests: List[OutpaintingRequest]
 
 async def save_image(image: UploadFile) -> str:
     """
     Save an uploaded image to a temporary file and return the file path.
-
-    Args:
-        image (UploadFile): The uploaded image file.
-
-    Returns:
-        str: File path where the image is saved.
     """
     file_name = f"{uuid.uuid4()}.png"
     file_path = os.path.join("/tmp", file_name)
@@ -66,149 +51,75 @@ async def save_image(image: UploadFile) -> str:
         f.write(await image.read())
     return file_path
 
-def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
-    """
-    Augment an image by extending its dimensions and generating masks.
-
-    Args:
-        image_path (str): Path to the image file.
-        target_width (int): Target width for augmentation.
-        target_height (int): Target height for augmentation.
-        roi_scale (float): Scale factor for region of interest.
-        segmentation_model_name (str): Name of the segmentation model.
-        detection_model_name (str): Name of the detection model.
-
-    Returns:
-        Tuple[Image.Image, Image.Image]: Augmented image and inverted mask.
-    """
-    image = Image.open(image_path)
-    image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
-    image = image_augmentation.extend_image(image)
-    mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
-    inverted_mask = image_augmentation.invert_mask(mask)
-    return image, inverted_mask
-
-def run_inference(cfg, image_path: str, request: InpaintingRequest):
-    """
-    Run inference using an inpainting pipeline on an image.
-
-    Args:
-        cfg (dict): Configuration dictionary.
-        image_path (str): Path to the image file.
-        request (InpaintingRequest): Pydantic model containing inference parameters.
-
-    Returns:
-        dict: Resulting image in the specified mode ('b64_json' or 's3_json').
-
-    Raises:
-        ValueError: If an invalid mode is provided.
-    """
-    if request.use_augmentation:
-        image, mask_image = augment_image(image_path,
-                                          cfg['target_width'],
-                                          cfg['target_height'],
-                                          cfg['roi_scale'],
-                                          cfg['segmentation_model'],
-                                          cfg['detection_model'])
-    else:
-        image = Image.open(image_path)
-        mask_image = None
-
-    painting_pipeline = AutoPaintingPipeline(
-        pipeline=inpainting_pipeline,
-        image=image,
-        mask_image=mask_image,
-        target_height=cfg['target_height'],
-        target_width=cfg['target_width']
+def run_inference(image_path: str, request: OutpaintingRequest):
+    pipeline = load_pipeline()
+    result = pipeline.run_pipeline(
+        image_path,
+        controlnet_prompt=request.controlnet_prompt,
+        controlnet_negative_prompt=request.controlnet_negative_prompt,
+        controlnet_conditioning_scale=request.controlnet_conditioning_scale,
+        controlnet_guidance_scale=request.controlnet_guidance_scale,
+        controlnet_num_inference_steps=request.controlnet_num_inference_steps,
+        controlnet_guidance_end=request.controlnet_guidance_end,
+        inpainting_prompt=request.inpainting_prompt,
+        inpainting_negative_prompt=request.inpainting_negative_prompt,
+        inpainting_guidance_scale=request.inpainting_guidance_scale,
+        inpainting_strength=request.inpainting_strength,
+        inpainting_num_inference_steps=request.inpainting_num_inference_steps
     )
-    output = painting_pipeline.run_inference(prompt=request.prompt,
-                                             negative_prompt=request.negative_prompt,
-                                             num_inference_steps=request.num_inference_steps,
-                                             strength=request.strength,
-                                             guidance_scale=request.guidance_scale,
-                                             num_images=request.num_images)
-    if request.mode == "s3_json":
-        return pil_to_s3_json(output, file_name="output.png")
-    elif request.mode == "b64_json":
-        return pil_to_b64_json(output)
-    else:
-        raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
+    return result
 
-class InpaintingBatcher(AsyncBatcher):
-    async def process_batch(self, batch: Tuple[List[str], List[InpaintingRequest]]) -> List[Dict[str, Any]]:
-        """
-        Process a batch of images and requests for inpainting inference.
-
-        Args:
-            batch (Tuple[List[str], List[InpaintingRequest]]): Tuple of image paths and corresponding requests.
-
-        Returns:
-            List[Dict[str, Any]]: List of resulting images in the specified mode ('b64_json' or 's3_json').
-        """
-        image_paths, requests = batch
-        results = []
-        for image_path, request in zip(image_paths, requests):
-            result = run_inference(cfg, image_path, request)
-            results.append(result)
-        return results
-
-@router.post("/inpainting")
-async def inpainting_inference(
+@router.post("/outpaint")
+async def outpaint(
     image: UploadFile = File(...),
-    request_data: str = Form(...),
+    request: str = Form(...)
 ):
-    """
-    Handle POST request for inpainting inference.
-
-    Args:
-        image (UploadFile): Uploaded image file.
-        request_data (str): JSON string of the request parameters.
-
-    Returns:
-        dict: Resulting image in the specified mode ('b64_json' or 's3_json').
-
-    Raises:
-        HTTPException: If there is an error during image processing.
-    """
     try:
+        request_dict = json.loads(request)
+        outpainting_request = OutpaintingRequest(**request_dict)
+
         image_path = await save_image(image)
-        request_dict = json.loads(request_data)
-        request = InpaintingRequest(**request_dict)
-        result = run_inference(cfg, image_path, request)
-        return result
+        result = run_inference(image_path, outpainting_request)
+
+        result_json = pil_to_b64_json(result)
+
+        os.remove(image_path)
+
+        return {"result": result_json}
+    except json.JSONDecodeError:
+        raise HTTPException(status_code=400, detail="Invalid JSON in request data")
+    except ValidationError as e:
+        raise HTTPException(status_code=422, detail=str(e))
    except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
-@router.post("/inpainting/batch")
-async def inpainting_batch_inference(
-    images: List[UploadFile] = File(...),
-    request_data: str = Form(...),
-):
-    """
-    Handle POST request for batch inpainting inference.
-
-    Args:
-        images (List[UploadFile]): List of uploaded image files.
-        request_data (str): JSON string of the request parameters.
-
-    Returns:
-        List[dict]: List of resulting images in the specified mode ('b64_json' or 's3_json').
+class OutpaintingBatcher(AsyncBatcher):
+    async def process_batch(self, batch):
+        results = []
+        for image, request in batch:
+            image_path = await save_image(image)
+            try:
+                result = run_inference(image_path, request)
+                results.append(result)
+            finally:
+                os.remove(image_path)
+        return results
 
-    Raises:
-        HTTPException: If there is an error during image processing.
-    """
+@router.post("/batch_outpaint")
+async def batch_outpaint(images: List[UploadFile] = File(...), batch_request: str = Form(...)):
     try:
-        request_dict = json.loads(request_data)
-        batch_request = InpaintingBatchRequestModel(**request_dict)
-        requests = batch_request.requests
-
-        if len(images) != len(requests):
-            raise HTTPException(status_code=400, detail="The number of images and requests must match.")
-
-        batcher = InpaintingBatcher(max_batch_size=64)
-        image_paths = [await save_image(image) for image in images]
-        results = batcher.process_batch((image_paths, requests))
-
-        return results
+        batch_request_dict = json.loads(batch_request)
+        batch_outpainting_request = OutpaintingBatchRequestModel(**batch_request_dict)
+
+        batcher = OutpaintingBatcher(max_queue_size=64)
+        results = await batcher.process_batch(list(zip(images, batch_outpainting_request.requests)))
+
+        result_jsons = [pil_to_b64_json(result) for result in results]
+
+        return {"results": result_jsons}
+    except json.JSONDecodeError:
+        raise HTTPException(status_code=400, detail="Invalid JSON in batch request data")
+    except ValidationError as e:
+        raise HTTPException(status_code=422, detail=str(e))
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
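
Since the new /outpaint endpoint takes the image as an uploaded file and the parameters as a JSON string in a multipart form field, a client has to serialize the OutpaintingRequest payload itself. A minimal sketch of calling it with the requests library (the base URL and route prefix are assumptions; adjust them to wherever the app is mounted):

import json
import requests

# Hypothetical client call for the new /outpaint route.
payload = {
    "controlnet_prompt": "product in the kitchen",
    "controlnet_negative_prompt": "low resolution, bad resolution",
    "controlnet_conditioning_scale": 0.9,
    "controlnet_guidance_scale": 7.5,
    "controlnet_num_inference_steps": 50,
    "controlnet_guidance_end": 0.6,
    "inpainting_prompt": "editorial photography of the product in the kitchen",
    "inpainting_negative_prompt": "low resolution, bad resolution",
    "inpainting_guidance_scale": 8.0,
    "inpainting_strength": 0.7,
    "inpainting_num_inference_steps": 30,
}

with open("example1.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:8000/api/v1/product-diffusion/outpaint",
        files={"image": ("example1.jpg", f, "image/jpeg")},
        data={"request": json.dumps(payload)},  # JSON string in a form field
    )
response.raise_for_status()
b64_image = response.json()["result"]  # base64-encoded image from pil_to_b64_json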
scripts/controlnet_outpainting.py CHANGED
@@ -1,35 +1,96 @@
-from diffusers import ControlNetModel,StableDiffusionXLControlNetPipeline
+from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
 import torch
-import requests
 from PIL import Image
-from io import BytesIO
+import lightning.pytorch as pl
+from scripts.api_utils import accelerator
+from typing import Optional
+import matplotlib.pyplot as plt
+pl.seed_everything(42)
 
-controlnet = ControlNetModel.from_pretrained(
-    "destitech/controlnet-inpaint-dreamer-sdxl", torch_dtype=torch.float16, variant="fp16"
-)
+class ImageGenerator:
+    """
+    A class to generate images using ControlNet and Stable Diffusion XL pipelines.
 
-response = requests.get("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/outpainting/313891870-adb6dc80-2e9e-420c-bac3-f93e6de8d06b.png?download=true")
-control_image = Image.open('/home/PicPilot/sample_data/example2.jpg')
+    Attributes:
+        controlnet (ControlNetModel): The ControlNet model.
+        pipeline (StableDiffusionXLControlNetPipeline): The Stable Diffusion XL pipeline with ControlNet.
+    """
 
+    def __init__(self, controlnet_model_name, sd_pipeline_model_name):
+        """
+        Initializes the ImageGenerator with the specified models.
 
-pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
-    "RunDiffusion/Juggernaut-XL-v9",
-    torch_dtype=torch.float16,
-    variant="fp16",
-    controlnet=controlnet,
-).to("cuda")
+        Args:
+            controlnet_model_name (str): The name of the ControlNet model.
+            sd_pipeline_model_name (str): The name of the Stable Diffusion XL pipeline model.
+        """
+        self.controlnet = ControlNetModel.from_pretrained(
+            controlnet_model_name, torch_dtype=torch.float16, variant="fp16"
+        )
+        self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+            sd_pipeline_model_name,
+            torch_dtype=torch.float16,
+            variant="fp16",
+            controlnet=self.controlnet,
+        ).to(accelerator())
 
-image = pipeline(
-    prompt='Showcase 4k',
-    negative_prompt='low Resolution , Bad Resolution',
-    height=1024,
-    width=1024,
-    guidance_scale=7.5,
-    num_inference_steps=100,
-    image=control_image,
-    controlnet_conditioning_scale=0.9,
-    control_guidance_end=0.9,
-).images[0]
+    def inference(self, prompt, negative_prompt, height, width, guidance_scale, num_images_per_prompt, num_inference_steps, image_path, controlnet_conditioning_scale, control_guidance_end, output_path: Optional[str]):
+        """
+        Generates images based on the provided parameters.
 
-image.save('output.png')
+        Args:
+            prompt (str): The prompt for image generation.
+            negative_prompt (str): The negative prompt for image generation.
+            height (int): The height of the generated images.
+            width (int): The width of the generated images.
+            guidance_scale (float): The guidance scale for image generation.
+            num_images_per_prompt (int): The number of images to generate per prompt.
+            num_inference_steps (int): The number of inference steps.
+            image_path (str): The path to the image to be used.
+            controlnet_conditioning_scale (float): The conditioning scale for ControlNet.
+            control_guidance_end (float): The end guidance for ControlNet.
+
+        Returns:
+            list: A list of generated images.
+        """
+        images_list = self.pipeline(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images_per_prompt,
+            num_inference_steps=num_inference_steps,
+            image=Image.open(image_path),
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            control_guidance_end=control_guidance_end,
+        ).images
+        if output_path:
+            for i, image in enumerate(images_list):
+                image.save(f'{output_path}/output_{i}.png')
+        else:
+            return images_list
+
+if __name__ == "__main__":
+    generator = ImageGenerator(
+        controlnet_model_name="destitech/controlnet-inpaint-dreamer-sdxl",
+        sd_pipeline_model_name="RunDiffusion/Juggernaut-XL-v9"
+    )
+    generator.inference(
+        prompt='Park',
+        negative_prompt='low Resolution , Bad Resolution',
+        height=1080,
+        width=1920,
+        guidance_scale=7.5,
+        num_images_per_prompt=4,
+        num_inference_steps=100,
+        image_path='/home/PicPilot/sample_data/example1.jpg',
+        controlnet_conditioning_scale=0.9,
+        control_guidance_end=0.9,
+        output_path='/home/PicPilot/output'
+    )
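
This refactor also swaps the hard-coded .to("cuda") for .to(accelerator()) from scripts.api_utils, which is not shown in the diff. A plausible minimal sketch of such a helper, assuming it simply picks the best available torch device:

import torch

def accelerator() -> str:
    # Hypothetical sketch of scripts.api_utils.accelerator (not in this diff):
    # prefer CUDA, then Apple MPS, otherwise fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"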
scripts/outpainting.py CHANGED
@@ -1,24 +1,15 @@
-import requests
 import torch
 from controlnet_aux import ZoeDetector
 from PIL import Image
-from diffusers import (
-    AutoencoderKL,
-    ControlNetModel,
-    StableDiffusionXLControlNetPipeline,
-    StableDiffusionXLInpaintPipeline
-)
-from typing import Optional
-from api_utils import ImageAugmentation
-import lightning.pytorch as pl
-pl.seed_everything(42)
-
-
-
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline, StableDiffusionXLInpaintPipeline
+from scripts.api_utils import ImageAugmentation, accelerator
+import lightning.pytorch as pl
+from rembg import remove
 
+pl.seed_everything(42)
 
 
-class OutpaintingProcessor:
+class ControlNetZoeDepthOutpainting:
     """
     A class for processing and outpainting images using Stable Diffusion XL.
 
@@ -27,65 +18,56 @@
     the final outpainting.
     """
 
-    def __init__(self, target_size=(1024, 1024)):
+    def __init__(self, target_size: tuple[int, int] = (1024, 1024)):
         """
-        Initialize the OutpaintingProcessor with necessary models and pipelines.
+        Initialize the ControlNetZoeDepthOutpainting processor with necessary models and pipelines.
 
         Args:
-            target_size (tuple): The target size for the output image (width, height).
+            target_size (tuple[int, int]): The target size for the output image (width, height).
         """
         self.target_size = target_size
         print("Initializing models and pipelines...")
-        self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(self.device)
+        self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(accelerator())
         self.zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
         self.controlnets = [
             ControlNetModel.from_pretrained("destitech/controlnet-inpaint-dreamer-sdxl", torch_dtype=torch.float16, variant="fp16"),
-            ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16, variant='fp16')
-        ]
-
-        print("Setting up initial pipeline...")
-        self.controlnet_pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
-            "SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, variant="fp16",
-            controlnet=self.controlnets, vae=self.vae
-        ).to(self.device)
-
+            ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
+        ]
+        print("Setting up sdxl pipeline...")
+        self.controlnet_pipeline = StableDiffusionXLControlNetPipeline.from_pretrained("SG161222/RealVisXL_V4.0", torch_dtype=torch.float16, variant="fp16", controlnet=self.controlnets, vae=self.vae).to(accelerator())
         print("Setting up inpaint pipeline...")
-        self.inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained("OzzyGT/RealVisXL_V4.0_inpainting", torch_dtype=torch.float16,
-                                                                                 variant="fp16",
-                                                                                 vae=self.vae,
-                                                                                 ).to(self.device)
+        self.inpaint_pipeline = StableDiffusionXLInpaintPipeline.from_pretrained("OzzyGT/RealVisXL_V4.0_inpainting", torch_dtype=torch.float16, variant="fp16", vae=self.vae).to(accelerator())
 
-        print("Initialization complete.")
-
-    def load_and_preprocess_image(self, image_url):
+    def load_and_preprocess_image(self, image_path: str) -> tuple[Image.Image, Image.Image]:
         """
-        Load an image from a URL and preprocess it for outpainting.
+        Load an image from a file path and preprocess it for outpainting.
 
         Args:
-            image_url (str): URL of the image to process.
+            image_path (str): Path of the image to process.
 
         Returns:
-            tuple: A tuple containing the resized original image and the background image.
+            tuple[Image.Image, Image.Image]: A tuple containing the resized original image and the background image.
         """
-        original_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGBA")
+        original_image = Image.open(image_path).convert("RGBA")
+        original_image = remove(original_image)
         return self.scale_and_paste(original_image, self.target_size)
 
-    def scale_and_paste(self, original_image, target_size, scale_factor=0.95):
+    def scale_and_paste(self, original_image: Image.Image, target_size: tuple[int, int], scale_factor: float = 0.95) -> tuple[Image.Image, Image.Image]:
         """
         Scale the original image and paste it onto a background of the target size.
 
         Args:
-            original_image (PIL.Image): The original image to process.
-            target_size (tuple): The target size (width, height) for the output image.
+            original_image (Image.Image): The original image to process.
+            target_size (tuple[int, int]): The target size (width, height) for the output image.
             scale_factor (float): Factor to scale down the image to leave some padding (default: 0.95).
 
         Returns:
-            tuple: A tuple containing the resized original image and the background image.
+            tuple[Image.Image, Image.Image]: A tuple containing the resized original image and the background image.
         """
         target_width, target_height = target_size
         aspect_ratio = original_image.width / original_image.height
 
-        if (target_width / target_height) < aspect_ratio:
+        if (target_width / target_height) < aspect_ratio:
             new_width = int(target_width * scale_factor)
             new_height = int(new_width / aspect_ratio)
         else:
@@ -97,128 +79,148 @@ class OutpaintingProcessor:
         x = (target_width - new_width) // 2
         y = (target_height - new_height) // 2
         background.paste(resized_original, (x, y), resized_original)
-
         return resized_original, background
 
-    def generate_depth_map(self, image):
+    def generate_depth_map(self, image: Image.Image) -> Image.Image:
         """
         Generate a depth map for the given image using the Zoe model.
 
         Args:
-            image (PIL.Image): The image to generate a depth map for.
+            image (Image.Image): The image to generate a depth map for.
 
         Returns:
-            PIL.Image: The generated depth map.
+            Image.Image: The generated depth map.
         """
         return self.zoe(image, detect_resolution=512, image_resolution=self.target_size[0])
 
-    def generate_image(self, prompt, negative_prompt, inpaint_image, zoe_image, guidance_scale, num_inference_steps):
+    def generate_base_image(self, prompt: str, negative_prompt: str, inpaint_image: Image.Image, zoe_image: Image.Image, guidance_scale: float, controlnet_num_inference_steps: int, controlnet_conditioning_scale: float, control_guidance_end: float) -> Image.Image:
         """
-        Generate an image using the initial pipeline.
+        Generate an image using the controlnet pipeline.
 
         Args:
             prompt (str): The prompt for image generation.
             negative_prompt (str): The negative prompt for image generation.
-            inpaint_image (PIL.Image): The image to inpaint.
-            zoe_image (PIL.Image): The depth map image.
-            seed (int, optional): Seed for random number generation.
+            inpaint_image (Image.Image): The image to inpaint.
+            zoe_image (Image.Image): The depth map image.
+            guidance_scale (float): Guidance scale for controlnet.
+            controlnet_num_inference_steps (int): Number of inference steps for controlnet.
+            controlnet_conditioning_scale (float): Conditioning scale for controlnet.
+            control_guidance_end (float): Guidance end for controlnet.
 
         Returns:
-            PIL.Image: The generated image.
+            Image.Image: The generated image.
         """
-
-        return self.initial_pipeline(
+        return self.controlnet_pipeline(
             prompt,
             negative_prompt=negative_prompt,
             image=[inpaint_image, zoe_image],
             guidance_scale=guidance_scale,
-            num_inference_steps=25,
-            controlnet_conditioning_scale=[0.5, 0.8],
-            control_guidance_end=[0.9, 0.6],
+            num_inference_steps=controlnet_num_inference_steps,
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            control_guidance_end=control_guidance_end,
         ).images[0]
 
-    def create_mask(self, image, segmentation_model, detection_model):
+    def create_mask(self, image: Image.Image, segmentation_model: str, detection_model: str) -> Image.Image:
         """
         Create a mask for the final outpainting process.
 
         Args:
-            image (PIL.Image): The original image.
+            image (Image.Image): The original image.
             segmentation_model (str): The segmentation model identifier.
             detection_model (str): The detection model identifier.
 
         Returns:
-            PIL.Image: The created mask.
+            Image.Image: The created mask.
         """
-        image_augmenter = ImageAugmentation(self.target_size[0], self.target_size[1])
-        mask_image = image_augmenter.generate_mask_from_bbox(image, segmentation_model, detection_model)
+        image_augmenter = ImageAugmentation(self.target_size[0], self.target_size[1], roi_scale=0.4)
+        mask_image = image_augmenter.generate_mask_from_bbox(image, segmentation_model, detection_model)
         inverted_mask = image_augmenter.invert_mask(mask_image)
         return inverted_mask
 
-    def generate_outpainting(self, prompt, negative_prompt, image, mask, seed: Optional[int] = 42):
+    def generate_outpainting(self, prompt: str, negative_prompt: str, image: Image.Image, mask: Image.Image, guidance_scale: float, strength: float, num_inference_steps: int) -> Image.Image:
         """
         Generate the final outpainted image.
 
         Args:
             prompt (str): The prompt for image generation.
             negative_prompt (str): The negative prompt for image generation.
-            image (PIL.Image): The image to outpaint.
-            mask (PIL.Image): The mask for outpainting.
-            seed (int, optional): Seed for random number generation.
+            image (Image.Image): The image to outpaint.
+            mask (Image.Image): The mask for outpainting.
+            guidance_scale (float): Guidance scale for inpainting.
+            strength (float): Strength for inpainting.
+            num_inference_steps (int): Number of inference steps for inpainting.
 
         Returns:
-            PIL.Image: The final outpainted image.
+            Image.Image: The final outpainted image.
         """
-
         return self.inpaint_pipeline(
             prompt,
             negative_prompt=negative_prompt,
             image=image,
             mask_image=mask,
-            guidance_scale=10.0,
-            strength=0.8,
-            num_inference_steps=30,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_inference_steps=num_inference_steps,
        ).images[0]
 
-    def process(self, image_url, initial_prompt, final_prompt, negative_prompt=""):
+    def run_pipeline(self, image_path: str, controlnet_prompt: str, controlnet_negative_prompt: str, controlnet_conditioning_scale: float, controlnet_guidance_scale: float, controlnet_num_inference_steps: int, controlnet_guidance_end: float, inpainting_prompt: str, inpainting_negative_prompt: str, inpainting_guidance_scale: float, inpainting_strength: float, inpainting_num_inference_steps: int) -> Image.Image:
         """
         Process an image through the entire outpainting pipeline.
 
         Args:
-            image_url (str): URL of the image to process.
-            initial_prompt (str): Prompt for the initial image generation.
-            final_prompt (str): Prompt for the final outpainting.
-            negative_prompt (str, optional): Negative prompt for both stages.
+            image_path (str): Path of the image to process.
+            controlnet_prompt (str): Prompt for the controlnet image generation.
+            controlnet_negative_prompt (str): Negative prompt for controlnet image generation.
+            controlnet_conditioning_scale (float): Conditioning scale for controlnet.
+            controlnet_guidance_scale (float): Guidance scale for controlnet.
+            controlnet_num_inference_steps (int): Number of inference steps for controlnet.
+            controlnet_guidance_end (float): Guidance end for controlnet.
+            inpainting_prompt (str): Prompt for the inpainting image generation.
+            inpainting_negative_prompt (str): Negative prompt for inpainting image generation.
+            inpainting_guidance_scale (float): Guidance scale for inpainting.
+            inpainting_strength (float): Strength for inpainting.
+            inpainting_num_inference_steps (int): Number of inference steps for inpainting.
 
         Returns:
-            PIL.Image: The final outpainted image.
+            Image.Image: The final outpainted image.
         """
-        print("Loading and preprocessing image...")
-        resized_img, background_image = self.load_and_preprocess_image(image_url)
-
-        print("Generating depth map...")
+        print("Loading and preprocessing image")
+        resized_img, background_image = self.load_and_preprocess_image(image_path)
+        print("Generating depth map")
         image_zoe = self.generate_depth_map(background_image)
-
-        print("Generating initial image...")
-        temp_image = self.generate_image(initial_prompt, negative_prompt, background_image, image_zoe)
+        print("Generating initial image")
+        temp_image = self.generate_base_image(controlnet_prompt, controlnet_negative_prompt, background_image, image_zoe,
+                                              controlnet_guidance_scale, controlnet_num_inference_steps, controlnet_conditioning_scale, controlnet_guidance_end)
         x = (self.target_size[0] - resized_img.width) // 2
         y = (self.target_size[1] - resized_img.height) // 2
         temp_image.paste(resized_img, (x, y), resized_img)
-        print("Creating mask for outpainting...")
+        print("Creating mask for outpainting")
         final_mask = self.create_mask(temp_image, "facebook/sam-vit-large", "yolov8l")
         mask_blurred = self.inpaint_pipeline.mask_processor.blur(final_mask, blur_factor=20)
-        print("Generating final outpainted image...")
-        final_image = self.generate_outpainting(final_prompt, negative_prompt, temp_image, mask_blurred)
+        print("Generating final outpainted image")
+        final_image = self.generate_outpainting(inpainting_prompt, inpainting_negative_prompt, temp_image, mask_blurred,
+                                                inpainting_guidance_scale, inpainting_strength, inpainting_num_inference_steps)
         final_image.paste(resized_img, (x, y), resized_img)
         return final_image
 
+
 def main():
-    processor = OutpaintingProcessor(target_size=(1024, 1024))  # Set to 720p resolution
-    result = processor.process(
-        "https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/outpainting/BMW_i8_Safety_Car_Front.png?download=true",
-        "a car on the highway",
-        "high quality photo of a car on the highway, shadows, highly detailed")
+    processor = ControlNetZoeDepthOutpainting(target_size=(1024, 1024))
+    result = processor.run_pipeline("/home/PicPilot/sample_data/example1.jpg",
+                                    "product in the kitchen",
+                                    "low resolution, Bad Resolution",
+                                    0.9,
+                                    7.5,
+                                    50,
+                                    0.6,
+                                    "Editorial Photography of the Pot in the kitchen",
+                                    "low Resolution, Bad Resolution",
+                                    8,
+                                    0.7,
+                                    30)
     result.save("outpainted_result.png")
     print("Outpainting complete. Result saved as 'outpainted_result.png'")
 
+
 if __name__ == "__main__":
     main()
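
run_pipeline takes twelve parameters and main() passes them all positionally, which is easy to get wrong. A minimal sketch of the same call with keyword arguments, using the sample values from main() above, makes the controlnet and inpainting parameter groups explicit:

# Same call as main(), but with the parameter names spelled out.
processor = ControlNetZoeDepthOutpainting(target_size=(1024, 1024))
result = processor.run_pipeline(
    "/home/PicPilot/sample_data/example1.jpg",
    controlnet_prompt="product in the kitchen",
    controlnet_negative_prompt="low resolution, Bad Resolution",
    controlnet_conditioning_scale=0.9,
    controlnet_guidance_scale=7.5,
    controlnet_num_inference_steps=50,
    controlnet_guidance_end=0.6,
    inpainting_prompt="Editorial Photography of the Pot in the kitchen",
    inpainting_negative_prompt="low Resolution, Bad Resolution",
    inpainting_guidance_scale=8,
    inpainting_strength=0.7,
    inpainting_num_inference_steps=30,
)
result.save("outpainted_result.png")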