Spaces:
Runtime error
Runtime error
VikramSingh178
committed on
Commit
•
6c850b1
1
Parent(s):
47b9b86
chore: Update import statement for InpaintingRequest in painting.py and refactor code to use shared BaseModel for Painting and InpaintingRequest classes
Browse files
Former-commit-id: a23f3f613d6636adcadd1b01d7f1d61993d91251 [formerly e5c83686f1238c8a9cb66b21590a5a1ee6597665]
Former-commit-id: 2252984afe4046eea2172d7ddd6d0096cf70b07b
- api/__pycache__/endpoints.cpython-310.pyc +0 -0
- api/endpoints.py +1 -1
- api/models/__pycache__/painting.cpython-310.pyc +0 -0
- api/models/painting.py +2 -0
- api/routers/__pycache__/painting.cpython-310.pyc +0 -0
- api/routers/painting.py +135 -71
- configs/inpainting.yaml +2 -2
- outputs/mask.jpg +0 -0
- outputs/output.jpg +0 -0
- scripts/__pycache__/config.cpython-310.pyc +0 -0
- scripts/__pycache__/inpainting_pipeline.cpython-310.pyc +0 -0
- scripts/config.py +1 -0
- scripts/inpainting_pipeline.py +54 -55
api/__pycache__/endpoints.cpython-310.pyc
CHANGED
Binary files a/api/__pycache__/endpoints.cpython-310.pyc and b/api/__pycache__/endpoints.cpython-310.pyc differ
|
|
api/endpoints.py
CHANGED
@@ -51,5 +51,5 @@ async def root():
|
|
51 |
def check_health():
|
52 |
return {"status": "ok"}
|
53 |
|
54 |
-
|
55 |
|
|
|
51 |
def check_health():
|
52 |
return {"status": "ok"}
|
53 |
|
54 |
+
|
55 |
|
api/models/__pycache__/painting.cpython-310.pyc
CHANGED
Binary files a/api/models/__pycache__/painting.cpython-310.pyc and b/api/models/__pycache__/painting.cpython-310.pyc differ
|
|
api/models/painting.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from pydantic import BaseModel
|
|
|
2 |
|
3 |
|
4 |
class InpaintingRequest(BaseModel):
|
@@ -7,3 +8,4 @@ class InpaintingRequest(BaseModel):
|
|
7 |
num_inference_steps: int
|
8 |
strength: float
|
9 |
guidance_scale: float
|
|
|
|
1 |
from pydantic import BaseModel
|
2 |
+
from fastapi import Form
|
3 |
|
4 |
|
5 |
class InpaintingRequest(BaseModel):
|
|
|
8 |
num_inference_steps: int
|
9 |
strength: float
|
10 |
guidance_scale: float
|
11 |
+
|
api/routers/__pycache__/painting.cpython-310.pyc
CHANGED
Binary files a/api/routers/__pycache__/painting.cpython-310.pyc and b/api/routers/__pycache__/painting.cpython-310.pyc differ
|
|
api/routers/painting.py
CHANGED
@@ -1,69 +1,34 @@
|
|
|
|
|
|
1 |
import sys
|
2 |
sys.path.append("../scripts")
|
3 |
-
from fastapi import APIRouter, File, UploadFile, HTTPException
|
4 |
-
from pydantic import BaseModel
|
5 |
-
from PIL import Image
|
6 |
-
from io import BytesIO
|
7 |
-
from models.painting import InpaintingRequest
|
8 |
import uuid
|
9 |
-
from inpainting_pipeline import AutoPaintingPipeline
|
10 |
-
from utils import pil_to_s3_json, ImageAugmentation
|
11 |
-
from hydra import compose, initialize
|
12 |
import lightning.pytorch as pl
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
18 |
|
19 |
-
#class InpaintingRequest(BaseModel):
|
20 |
-
# prompt: str
|
21 |
-
# negative_prompt: str
|
22 |
-
# num_inference_steps: int
|
23 |
-
# strength: float
|
24 |
-
# guidance_scale: float
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
Augments an image with a given prompt, model, and other parameters.
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
- target_height (int): The desired height of the augmented image.
|
34 |
-
- roi_scale (float): The scale factor for the region of interest.
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
- inverted_mask (PIL.Image.Image): The inverted mask generated from the augmented image.
|
39 |
-
"""
|
40 |
-
image = Image.open(image)
|
41 |
image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
|
42 |
image = image_augmentation.extend_image(image)
|
43 |
mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
|
44 |
inverted_mask = image_augmentation.invert_mask(mask)
|
45 |
return image, inverted_mask
|
46 |
|
47 |
-
def run_inference(cfg: dict, image_path: str, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float):
|
48 |
-
"""
|
49 |
-
Run inference using the provided configuration and input image.
|
50 |
-
|
51 |
-
Args:
|
52 |
-
cfg (dict): Configuration dictionary containing model parameters.
|
53 |
-
image_path (str): Path to the input image file.
|
54 |
-
prompt (str): Prompt for the inference process.
|
55 |
-
negative_prompt (str): Negative prompt for the inference process.
|
56 |
-
num_inference_steps (int): Number of inference steps to perform.
|
57 |
-
strength (float): Strength parameter for the inference.
|
58 |
-
guidance_scale (float): Guidance scale for the inference.
|
59 |
-
|
60 |
-
Returns:
|
61 |
-
dict: A JSON object containing the image ID and the signed URL.
|
62 |
-
|
63 |
-
Raises:
|
64 |
-
HTTPException: If an error occurs during the inference process.
|
65 |
-
|
66 |
-
"""
|
67 |
image, mask_image = augment_image(image_path,
|
68 |
cfg['target_width'],
|
69 |
cfg['target_height'],
|
@@ -71,25 +36,89 @@ def run_inference(cfg: dict, image_path: str, prompt: str, negative_prompt: str,
|
|
71 |
cfg['segmentation_model'],
|
72 |
cfg['detection_model'])
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
80 |
negative_prompt=negative_prompt,
|
81 |
num_inference_steps=num_inference_steps,
|
82 |
strength=strength,
|
83 |
guidance_scale=guidance_scale)
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
"""
|
94 |
Run the inpainting/outpainting inference pipeline.
|
95 |
|
@@ -100,6 +129,8 @@ async def inpainting_inference(image: UploadFile = File(...),
|
|
100 |
- num_inference_steps: int - The number of inference steps to perform during the inpainting/outpainting process.
|
101 |
- strength: float - The strength parameter for controlling the inpainting/outpainting process.
|
102 |
- guidance_scale: float - The guidance scale parameter for controlling the inpainting/outpainting process.
|
|
|
|
|
103 |
|
104 |
Returns:
|
105 |
- result: The result of the inpainting/outpainting process.
|
@@ -113,14 +144,47 @@ async def inpainting_inference(image: UploadFile = File(...),
|
|
113 |
with open(image_path, "wb") as f:
|
114 |
f.write(image_bytes)
|
115 |
|
|
|
|
|
|
|
116 |
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
-
return
|
123 |
except Exception as e:
|
124 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
125 |
|
126 |
|
|
|
1 |
+
from fastapi import APIRouter, File, UploadFile, HTTPException, Form
|
2 |
+
from PIL import Image
|
3 |
import sys
|
4 |
sys.path.append("../scripts")
|
|
|
|
|
|
|
|
|
|
|
5 |
import uuid
|
|
|
|
|
|
|
6 |
import lightning.pytorch as pl
|
7 |
+
from typing import List
|
8 |
+
from utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
|
9 |
+
from inpainting_pipeline import AutoPaintingPipeline, load_pipeline
|
10 |
+
from hydra import compose, initialize
|
11 |
+
from pydantic import BaseModel
|
12 |
+
from async_batcher.batcher import AsyncBatcher
|
13 |
+
from typing import Dict
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
router = APIRouter()
|
17 |
+
pl.seed_everything(42)
|
|
|
18 |
|
19 |
+
with initialize(version_base=None, config_path="../../configs"):
|
20 |
+
cfg = compose(config_name="inpainting")
|
21 |
+
inpainting_pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
|
|
|
|
|
22 |
|
23 |
+
def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
|
24 |
+
image = Image.open(image_path)
|
|
|
|
|
|
|
25 |
image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
|
26 |
image = image_augmentation.extend_image(image)
|
27 |
mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
|
28 |
inverted_mask = image_augmentation.invert_mask(mask)
|
29 |
return image, inverted_mask
|
30 |
|
31 |
+
def run_inference(cfg: dict, image_path: str, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float, mode: str, num_images: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
image, mask_image = augment_image(image_path,
|
33 |
cfg['target_width'],
|
34 |
cfg['target_height'],
|
|
|
36 |
cfg['segmentation_model'],
|
37 |
cfg['detection_model'])
|
38 |
|
39 |
+
painting_pipeline = AutoPaintingPipeline(
|
40 |
+
pipeline=inpainting_pipeline,
|
41 |
+
image=image,
|
42 |
+
mask_image=mask_image,
|
43 |
+
target_height=cfg['target_height'],
|
44 |
+
target_width=cfg['target_width']
|
45 |
+
)
|
46 |
+
output = painting_pipeline.run_inference(prompt=prompt,
|
47 |
negative_prompt=negative_prompt,
|
48 |
num_inference_steps=num_inference_steps,
|
49 |
strength=strength,
|
50 |
guidance_scale=guidance_scale)
|
51 |
+
if mode == "s3_json":
|
52 |
+
return pil_to_s3_json(output, file_name="output.png")
|
53 |
+
elif mode == "b64_json":
|
54 |
+
return pil_to_b64_json(output)
|
55 |
+
else:
|
56 |
+
raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
|
57 |
+
|
58 |
+
class InpaintingRequest(BaseModel):
|
59 |
+
prompt: str
|
60 |
+
negative_prompt: str
|
61 |
+
num_inference_steps: int
|
62 |
+
strength: float
|
63 |
+
guidance_scale: float
|
64 |
+
num_images: int = 1
|
65 |
+
|
66 |
+
class InpaintingBatcher(AsyncBatcher[List[Dict], dict]):
|
67 |
+
def __init__(self, pipeline, cfg):
|
68 |
+
self.pipeline = pipeline
|
69 |
+
self.cfg = cfg
|
70 |
+
|
71 |
+
def process_batch(self, batch: List[Dict], image_paths: List[str]) -> List[dict]:
|
72 |
+
results = []
|
73 |
+
for data, image_path in zip(batch, image_paths):
|
74 |
+
try:
|
75 |
+
image, mask_image = augment_image(
|
76 |
+
image_path,
|
77 |
+
self.cfg['target_width'],
|
78 |
+
self.cfg['target_height'],
|
79 |
+
self.cfg['roi_scale'],
|
80 |
+
self.cfg['segmentation_model'],
|
81 |
+
self.cfg['detection_model']
|
82 |
+
)
|
83 |
+
|
84 |
+
pipeline = AutoPaintingPipeline(
|
85 |
+
image=image,
|
86 |
+
mask_image=mask_image,
|
87 |
+
target_height=self.cfg['target_height'],
|
88 |
+
target_width=self.cfg['target_width']
|
89 |
+
)
|
90 |
+
output = pipeline.run_inference(
|
91 |
+
prompt=data['prompt'],
|
92 |
+
negative_prompt=data['negative_prompt'],
|
93 |
+
num_inference_steps=data['num_inference_steps'],
|
94 |
+
strength=data['strength'],
|
95 |
+
guidance_scale=data['guidance_scale']
|
96 |
+
)
|
97 |
+
|
98 |
+
if data['mode'] == "s3_json":
|
99 |
+
result = pil_to_s3_json(output, 'inpainting_image')
|
100 |
+
elif data['mode'] == "b64_json":
|
101 |
+
result = pil_to_b64_json(output)
|
102 |
+
else:
|
103 |
+
raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
|
104 |
+
|
105 |
+
results.append(result)
|
106 |
+
except Exception as e:
|
107 |
+
print(f"Error in process_batch: {e}")
|
108 |
+
raise HTTPException(status_code=500, detail="Batch inference failed")
|
109 |
+
return results
|
110 |
+
|
111 |
+
@router.post("/inpainting")
|
112 |
+
async def inpainting_inference(
|
113 |
+
image: UploadFile = File(...),
|
114 |
+
prompt: str = Form(...),
|
115 |
+
negative_prompt: str = Form(...),
|
116 |
+
num_inference_steps: int = Form(...),
|
117 |
+
strength: float = Form(...),
|
118 |
+
guidance_scale: float = Form(...),
|
119 |
+
mode: str = Form(...),
|
120 |
+
num_images: int = Form(1)
|
121 |
+
):
|
122 |
"""
|
123 |
Run the inpainting/outpainting inference pipeline.
|
124 |
|
|
|
129 |
- num_inference_steps: int - The number of inference steps to perform during the inpainting/outpainting process.
|
130 |
- strength: float - The strength parameter for controlling the inpainting/outpainting process.
|
131 |
- guidance_scale: float - The guidance scale parameter for controlling the inpainting/outpainting process.
|
132 |
+
- mode: str - The output mode, either "s3_json" or "b64_json".
|
133 |
+
- num_images: int - The number of images to generate.
|
134 |
|
135 |
Returns:
|
136 |
- result: The result of the inpainting/outpainting process.
|
|
|
144 |
with open(image_path, "wb") as f:
|
145 |
f.write(image_bytes)
|
146 |
|
147 |
+
result = run_inference(
|
148 |
+
cfg, image_path, prompt, negative_prompt, num_inference_steps, strength, guidance_scale, mode, num_images
|
149 |
+
)
|
150 |
|
151 |
+
return result
|
152 |
+
except Exception as e:
|
153 |
+
raise HTTPException(status_code=500, detail=str(e))
|
154 |
+
|
155 |
+
@router.post("/inpainting_batch")
|
156 |
+
async def inpainting_batch_inference(
|
157 |
+
batch: List[dict],
|
158 |
+
images: List[UploadFile] = File(...)
|
159 |
+
):
|
160 |
+
"""
|
161 |
+
Run batch inpainting/outpainting inference pipeline.
|
162 |
|
163 |
+
Parameters:
|
164 |
+
- batch: List[dict] - The batch of requests containing parameters for the inpainting/outpainting process.
|
165 |
+
- images: List[UploadFile] - The list of image files to be used for inpainting/outpainting.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
- results: The results of the inpainting/outpainting process for each request.
|
169 |
+
|
170 |
+
Raises:
|
171 |
+
- HTTPException: If an error occurs during the inpainting/outpainting process.
|
172 |
+
"""
|
173 |
+
try:
|
174 |
+
image_paths = []
|
175 |
+
for image in images:
|
176 |
+
image_bytes = await image.read()
|
177 |
+
image_path = f"/tmp/{uuid.uuid4()}.png"
|
178 |
+
with open(image_path, "wb") as f:
|
179 |
+
f.write(image_bytes)
|
180 |
+
image_paths.append(image_path)
|
181 |
+
|
182 |
+
batcher = InpaintingBatcher(pipeline, cfg)
|
183 |
+
results = batcher.process_batch(batch, image_paths)
|
184 |
|
185 |
+
return results
|
186 |
except Exception as e:
|
187 |
raise HTTPException(status_code=500, detail=str(e))
|
188 |
+
|
189 |
|
190 |
|
configs/inpainting.yaml
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
segmentation_model : 'facebook/sam-vit-
|
2 |
-
detection_model : '
|
3 |
model : 'kandinsky-community/kandinsky-2-2-decoder-inpaint'
|
4 |
target_width : 2560
|
5 |
target_height : 1472
|
|
|
1 |
+
segmentation_model : 'facebook/sam-vit-base'
|
2 |
+
detection_model : 'yolov8s'
|
3 |
model : 'kandinsky-community/kandinsky-2-2-decoder-inpaint'
|
4 |
target_width : 2560
|
5 |
target_height : 1472
|
outputs/mask.jpg
CHANGED
outputs/output.jpg
CHANGED
scripts/__pycache__/config.cpython-310.pyc
CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
|
|
scripts/__pycache__/inpainting_pipeline.cpython-310.pyc
CHANGED
Binary files a/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc and b/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc differ
|
|
scripts/config.py
CHANGED
@@ -9,6 +9,7 @@ CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
|
|
9 |
SEGMENTATION_MODEL_NAME = "facebook/sam-vit-large"
|
10 |
DETECTION_MODEL_NAME = "yolov8l"
|
11 |
ENABLE_COMPILE = False
|
|
|
12 |
|
13 |
|
14 |
|
|
|
9 |
SEGMENTATION_MODEL_NAME = "facebook/sam-vit-large"
|
10 |
DETECTION_MODEL_NAME = "yolov8l"
|
11 |
ENABLE_COMPILE = False
|
12 |
+
INPAINTING_MODEL_NAME = ''
|
13 |
|
14 |
|
15 |
|
scripts/inpainting_pipeline.py
CHANGED
@@ -1,81 +1,80 @@
|
|
1 |
import torch
|
2 |
-
from diffusers import AutoPipelineForInpainting
|
3 |
from diffusers.utils import load_image
|
4 |
-
from utils import
|
5 |
import hydra
|
6 |
from omegaconf import DictConfig
|
7 |
from PIL import Image
|
8 |
from functools import lru_cache
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class AutoPaintingPipeline:
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
model_name (str): The name of the pretrained inpainting model.
|
18 |
-
image (Image): The input image to be processed.
|
19 |
-
mask_image (Image): The mask image indicating the areas to be inpainted.
|
20 |
-
"""
|
21 |
-
|
22 |
-
def __init__(self, model_name: str, image: Image, mask_image: Image,target_width: int, target_height: int):
|
23 |
-
self.model_name = model_name
|
24 |
-
self.device = accelerator()
|
25 |
-
self.pipeline = AutoPipelineForInpainting.from_pretrained(self.model_name, torch_dtype=torch.float16)
|
26 |
-
self.image = load_image(image)
|
27 |
-
self.mask_image = load_image(mask_image)
|
28 |
self.target_width = target_width
|
29 |
self.target_height = target_height
|
30 |
-
|
31 |
-
self.pipeline.unet = torch.compile(self.pipeline.unet,mode='max-autotune')
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
def run_inference(self, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float):
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
47 |
return output
|
48 |
-
|
49 |
-
|
50 |
-
@hydra.main(version_base=None ,config_path="../configs", config_name="inpainting")
|
51 |
def inference(cfg: DictConfig):
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
cfg (DictConfig): The configuration file for the inpainting pipeline.
|
57 |
-
"""
|
58 |
augmenter = ImageAugmentation(target_width=cfg.target_width, target_height=cfg.target_height)
|
59 |
-
model_name = cfg.model
|
60 |
image_path = "../sample_data/example3.jpg"
|
61 |
image = Image.open(image_path)
|
62 |
extended_image = augmenter.extend_image(image)
|
63 |
mask_image = augmenter.generate_mask_from_bbox(extended_image, cfg.segmentation_model, cfg.detection_model)
|
64 |
mask_image = augmenter.invert_mask(mask_image)
|
65 |
-
prompt = cfg.prompt
|
66 |
-
negative_prompt = cfg.negative_prompt
|
67 |
-
num_inference_steps = cfg.num_inference_steps
|
68 |
-
strength = cfg.strength
|
69 |
-
guidance_scale = cfg.guidance_scale
|
70 |
-
pipeline = AutoPaintingPipeline(model_name=model_name, image = extended_image, mask_image=mask_image, target_height=cfg.target_height, target_width=cfg.target_width)
|
71 |
-
output = pipeline.run_inference(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, strength=strength, guidance_scale=guidance_scale)
|
72 |
-
output.save(f'{cfg.output_path}/output.jpg')
|
73 |
-
mask_image.save(f'{cfg.output_path}/mask.jpg')
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
|
|
|
|
|
|
|
|
76 |
if __name__ == "__main__":
|
77 |
inference()
|
78 |
|
79 |
-
|
80 |
-
|
81 |
|
|
|
1 |
import torch
|
2 |
+
from diffusers import AutoPipelineForInpainting
|
3 |
from diffusers.utils import load_image
|
4 |
+
from utils import accelerator, ImageAugmentation
|
5 |
import hydra
|
6 |
from omegaconf import DictConfig
|
7 |
from PIL import Image
|
8 |
from functools import lru_cache
|
9 |
|
10 |
+
@lru_cache(maxsize=1)
|
11 |
+
def load_pipeline(model_name: str, device, enable_compile: bool = True):
|
12 |
+
pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float16)
|
13 |
+
if enable_compile:
|
14 |
+
pipeline.unet.to(memory_format=torch.channels_last)
|
15 |
+
pipeline.unet = torch.compile(pipeline.unet, mode='reduce-overhead',fullgraph=True)
|
16 |
+
pipeline.to(device)
|
17 |
+
return pipeline
|
18 |
|
19 |
class AutoPaintingPipeline:
|
20 |
+
def __init__(self, pipeline, image: Image, mask_image: Image, target_width: int, target_height: int):
|
21 |
+
self.pipeline = pipeline
|
22 |
+
self.image = image
|
23 |
+
self.mask_image = mask_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
self.target_width = target_width
|
25 |
self.target_height = target_height
|
26 |
+
|
|
|
|
|
|
|
|
|
|
|
27 |
def run_inference(self, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float):
|
28 |
+
output = self.pipeline(
|
29 |
+
prompt=prompt,
|
30 |
+
negative_prompt=negative_prompt,
|
31 |
+
image=self.image,
|
32 |
+
mask_image=self.mask_image,
|
33 |
+
num_inference_steps=num_inference_steps,
|
34 |
+
strength=strength,
|
35 |
+
guidance_scale=guidance_scale,
|
36 |
+
height=self.target_height,
|
37 |
+
width=self.target_width
|
38 |
+
|
39 |
+
).images[0]
|
40 |
return output
|
41 |
+
|
42 |
+
@hydra.main(version_base=None, config_path="../configs", config_name="inpainting")
|
|
|
43 |
def inference(cfg: DictConfig):
|
44 |
+
# Load the pipeline once and cache it
|
45 |
+
pipeline = load_pipeline(cfg.model, accelerator(), True)
|
46 |
+
|
47 |
+
# Image augmentation and preparation
|
|
|
|
|
48 |
augmenter = ImageAugmentation(target_width=cfg.target_width, target_height=cfg.target_height)
|
|
|
49 |
image_path = "../sample_data/example3.jpg"
|
50 |
image = Image.open(image_path)
|
51 |
extended_image = augmenter.extend_image(image)
|
52 |
mask_image = augmenter.generate_mask_from_bbox(extended_image, cfg.segmentation_model, cfg.detection_model)
|
53 |
mask_image = augmenter.invert_mask(mask_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
# Create AutoPaintingPipeline instance with cached pipeline
|
56 |
+
painting_pipeline = AutoPaintingPipeline(
|
57 |
+
pipeline=pipeline,
|
58 |
+
image=extended_image,
|
59 |
+
mask_image=mask_image,
|
60 |
+
target_height=cfg.target_height,
|
61 |
+
target_width=cfg.target_width
|
62 |
+
)
|
63 |
+
|
64 |
+
# Run inference
|
65 |
+
output = painting_pipeline.run_inference(
|
66 |
+
prompt=cfg.prompt,
|
67 |
+
negative_prompt=cfg.negative_prompt,
|
68 |
+
num_inference_steps=cfg.num_inference_steps,
|
69 |
+
strength=cfg.strength,
|
70 |
+
guidance_scale=cfg.guidance_scale
|
71 |
+
)
|
72 |
|
73 |
+
# Save output and mask images
|
74 |
+
output.save(f'{cfg.output_path}/output.jpg')
|
75 |
+
mask_image.save(f'{cfg.output_path}/mask.jpg')
|
76 |
+
|
77 |
if __name__ == "__main__":
|
78 |
inference()
|
79 |
|
|
|
|
|
80 |
|