Spaces:
Running
Running
Commit
•
080c0ac
1
Parent(s):
c9c9773
Refactor code to use shared BaseModel for Painting and InpaintingRequest classes
Browse filesFormer-commit-id: 4484440577c279f069f9e6c8c11d9ab9685f26e8 [formerly b52b26e6289548705fa59307a387beb19edec4f6]
Former-commit-id: 60b102b660610226910f419923c877a820891aab
- api/__pycache__/endpoints.cpython-310.pyc +0 -0
- api/endpoints.py +1 -2
- api/routers/__pycache__/batch_painting.cpython-310.pyc +0 -0
- api/routers/__pycache__/painting.cpython-310.pyc +0 -0
- api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc +0 -0
- api/routers/batch_painting.py +0 -0
- api/routers/painting.py +132 -34
- api/routers/sdxl_text_to_image.py +2 -2
- scripts/__pycache__/inpainting_pipeline.cpython-310.pyc +0 -0
- scripts/inpainting_pipeline.py +2 -2
api/__pycache__/endpoints.cpython-310.pyc
CHANGED
Binary files a/api/__pycache__/endpoints.cpython-310.pyc and b/api/__pycache__/endpoints.cpython-310.pyc differ
|
|
api/endpoints.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
-
from routers import (sdxl_text_to_image,painting
|
4 |
import logfire
|
5 |
import uvicorn
|
6 |
|
@@ -26,7 +26,6 @@ app.add_middleware(
|
|
26 |
|
27 |
app.include_router(sdxl_text_to_image.router, prefix='/api/v1/product-diffusion')
|
28 |
app.include_router(painting.router,prefix='/api/v1/product-diffusion')
|
29 |
-
app.include_router(batch_painting.router,prefix='/api/v1/product-diffusion')
|
30 |
logfire.instrument_fastapi(app)
|
31 |
|
32 |
|
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
from routers import (sdxl_text_to_image,painting)
|
4 |
import logfire
|
5 |
import uvicorn
|
6 |
|
|
|
26 |
|
27 |
app.include_router(sdxl_text_to_image.router, prefix='/api/v1/product-diffusion')
|
28 |
app.include_router(painting.router,prefix='/api/v1/product-diffusion')
|
|
|
29 |
logfire.instrument_fastapi(app)
|
30 |
|
31 |
|
api/routers/__pycache__/batch_painting.cpython-310.pyc
CHANGED
Binary files a/api/routers/__pycache__/batch_painting.cpython-310.pyc and b/api/routers/__pycache__/batch_painting.cpython-310.pyc differ
|
|
api/routers/__pycache__/painting.cpython-310.pyc
CHANGED
Binary files a/api/routers/__pycache__/painting.cpython-310.pyc and b/api/routers/__pycache__/painting.cpython-310.pyc differ
|
|
api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc
CHANGED
Binary files a/api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc and b/api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc differ
|
|
api/routers/batch_painting.py
DELETED
File without changes
|
api/routers/painting.py
CHANGED
@@ -2,25 +2,57 @@ import sys
|
|
2 |
sys.path.append('../scripts')
|
3 |
import os
|
4 |
import uuid
|
5 |
-
from typing import List, Tuple, Any
|
6 |
-
from fastapi import APIRouter, File, UploadFile, HTTPException, Form
|
|
|
7 |
from PIL import Image
|
8 |
import lightning.pytorch as pl
|
9 |
from utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
|
10 |
from inpainting_pipeline import AutoPaintingPipeline, load_pipeline
|
11 |
from hydra import compose, initialize
|
12 |
from async_batcher.batcher import AsyncBatcher
|
|
|
|
|
|
|
|
|
13 |
|
14 |
pl.seed_everything(42)
|
15 |
router = APIRouter()
|
16 |
|
17 |
-
# Initialize
|
18 |
with initialize(version_base=None, config_path="../../configs"):
|
19 |
cfg = compose(config_name="inpainting")
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
async def save_image(image: UploadFile) -> str:
|
23 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
file_name = f"{uuid.uuid4()}.png"
|
25 |
file_path = os.path.join("/tmp", file_name)
|
26 |
with open(file_path, "wb") as f:
|
@@ -28,6 +60,20 @@ async def save_image(image: UploadFile) -> str:
|
|
28 |
return file_path
|
29 |
|
30 |
def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
image = Image.open(image_path)
|
32 |
image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
|
33 |
image = image_augmentation.extend_image(image)
|
@@ -35,17 +81,22 @@ def augment_image(image_path, target_width, target_height, roi_scale, segmentati
|
|
35 |
inverted_mask = image_augmentation.invert_mask(mask)
|
36 |
return image, inverted_mask
|
37 |
|
38 |
-
def run_inference(cfg,
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
49 |
image, mask_image = augment_image(image_path,
|
50 |
cfg['target_width'],
|
51 |
cfg['target_height'],
|
@@ -63,37 +114,84 @@ def run_inference(cfg,
|
|
63 |
target_height=cfg['target_height'],
|
64 |
target_width=cfg['target_width']
|
65 |
)
|
66 |
-
output = painting_pipeline.run_inference(prompt=prompt,
|
67 |
-
negative_prompt=negative_prompt,
|
68 |
-
num_inference_steps=num_inference_steps,
|
69 |
-
strength=strength,
|
70 |
-
guidance_scale=guidance_scale)
|
71 |
-
if mode == "s3_json":
|
72 |
return pil_to_s3_json(output, file_name="output.png")
|
73 |
-
elif mode == "b64_json":
|
74 |
return pil_to_b64_json(output)
|
75 |
else:
|
76 |
raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
@router.post("/inpainting")
|
81 |
async def inpainting_inference(
|
82 |
image: UploadFile = File(...),
|
83 |
-
|
84 |
-
negative_prompt: str = Form(...),
|
85 |
-
num_inference_steps: int = Form(...),
|
86 |
-
strength: float = Form(...),
|
87 |
-
guidance_scale: float = Form(...),
|
88 |
-
mode: str = Form(...),
|
89 |
-
num_images: int = Form(1),
|
90 |
-
|
91 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
try:
|
93 |
image_path = await save_image(image)
|
94 |
-
|
|
|
|
|
95 |
return result
|
96 |
except Exception as e:
|
97 |
raise HTTPException(status_code=500, detail=str(e))
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
sys.path.append('../scripts')
|
3 |
import os
|
4 |
import uuid
|
5 |
+
from typing import List, Tuple, Any,Dict
|
6 |
+
from fastapi import APIRouter, File, UploadFile, HTTPException, Form, Depends, Body
|
7 |
+
from pydantic import BaseModel, Field
|
8 |
from PIL import Image
|
9 |
import lightning.pytorch as pl
|
10 |
from utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
|
11 |
from inpainting_pipeline import AutoPaintingPipeline, load_pipeline
|
12 |
from hydra import compose, initialize
|
13 |
from async_batcher.batcher import AsyncBatcher
|
14 |
+
from concurrent.futures import Executor
|
15 |
+
import json
|
16 |
+
import asyncio
|
17 |
+
from functools import lru_cache
|
18 |
|
19 |
pl.seed_everything(42)
|
20 |
router = APIRouter()
|
21 |
|
22 |
+
# Initialize Hydra configuration
|
23 |
with initialize(version_base=None, config_path="../../configs"):
|
24 |
cfg = compose(config_name="inpainting")
|
25 |
+
|
26 |
+
# Load the inpainting pipeline
|
27 |
+
@lru_cache(maxsize=1)
|
28 |
+
def load_pipeline_wrapper():
|
29 |
+
pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
|
30 |
+
return pipeline
|
31 |
+
inpainting_pipeline = load_pipeline_wrapper()
|
32 |
+
|
33 |
+
class InpaintingRequest(BaseModel):
|
34 |
+
prompt: str = Field(..., description="Prompt text for inference")
|
35 |
+
negative_prompt: str = Field(..., description="Negative prompt text for inference")
|
36 |
+
num_inference_steps: int = Field(..., description="Number of inference steps")
|
37 |
+
strength: float = Field(..., description="Strength of the inference")
|
38 |
+
guidance_scale: float = Field(..., description="Guidance scale for inference")
|
39 |
+
mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
|
40 |
+
num_images: int = Field(..., description="Number of images to generate")
|
41 |
+
use_augmentation: bool = Field(True, description="Whether to use image augmentation")
|
42 |
+
|
43 |
+
class InpaintingBatchRequestModel(BaseModel):
|
44 |
+
requests: List[InpaintingRequest]
|
45 |
|
46 |
async def save_image(image: UploadFile) -> str:
|
47 |
+
"""
|
48 |
+
Save an uploaded image to a temporary file and return the file path.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
image (UploadFile): The uploaded image file.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
str: File path where the image is saved.
|
55 |
+
"""
|
56 |
file_name = f"{uuid.uuid4()}.png"
|
57 |
file_path = os.path.join("/tmp", file_name)
|
58 |
with open(file_path, "wb") as f:
|
|
|
60 |
return file_path
|
61 |
|
62 |
def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
|
63 |
+
"""
|
64 |
+
Augment an image by extending its dimensions and generating masks.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
image_path (str): Path to the image file.
|
68 |
+
target_width (int): Target width for augmentation.
|
69 |
+
target_height (int): Target height for augmentation.
|
70 |
+
roi_scale (float): Scale factor for region of interest.
|
71 |
+
segmentation_model_name (str): Name of the segmentation model.
|
72 |
+
detection_model_name (str): Name of the detection model.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Tuple[Image.Image, Image.Image]: Augmented image and inverted mask.
|
76 |
+
"""
|
77 |
image = Image.open(image_path)
|
78 |
image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
|
79 |
image = image_augmentation.extend_image(image)
|
|
|
81 |
inverted_mask = image_augmentation.invert_mask(mask)
|
82 |
return image, inverted_mask
|
83 |
|
84 |
+
def run_inference(cfg, image_path: str, request: InpaintingRequest):
|
85 |
+
"""
|
86 |
+
Run inference using an inpainting pipeline on an image.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
cfg (dict): Configuration dictionary.
|
90 |
+
image_path (str): Path to the image file.
|
91 |
+
request (InpaintingRequest): Pydantic model containing inference parameters.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
dict: Resulting image in the specified mode ('b64_json' or 's3_json').
|
95 |
+
|
96 |
+
Raises:
|
97 |
+
ValueError: If an invalid mode is provided.
|
98 |
+
"""
|
99 |
+
if request.use_augmentation:
|
100 |
image, mask_image = augment_image(image_path,
|
101 |
cfg['target_width'],
|
102 |
cfg['target_height'],
|
|
|
114 |
target_height=cfg['target_height'],
|
115 |
target_width=cfg['target_width']
|
116 |
)
|
117 |
+
output = painting_pipeline.run_inference(prompt=request.prompt,
|
118 |
+
negative_prompt=request.negative_prompt,
|
119 |
+
num_inference_steps=request.num_inference_steps,
|
120 |
+
strength=request.strength,
|
121 |
+
guidance_scale=request.guidance_scale)
|
122 |
+
if request.mode == "s3_json":
|
123 |
return pil_to_s3_json(output, file_name="output.png")
|
124 |
+
elif request.mode == "b64_json":
|
125 |
return pil_to_b64_json(output)
|
126 |
else:
|
127 |
raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
|
128 |
+
|
129 |
+
class InpaintingBatcher(AsyncBatcher):
|
130 |
+
async def process_batch(self, batch: Tuple[List[str], List[InpaintingRequest]]) -> List[Dict[str, Any]]:
|
131 |
+
image_paths, requests = batch
|
132 |
+
results = []
|
133 |
+
for image_path, request in zip(image_paths, requests):
|
134 |
+
result = run_inference(cfg, image_path, request)
|
135 |
+
results.append(result)
|
136 |
+
return results
|
137 |
|
138 |
@router.post("/inpainting")
|
139 |
async def inpainting_inference(
|
140 |
image: UploadFile = File(...),
|
141 |
+
request_data: str = Form(...),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
):
|
143 |
+
"""
|
144 |
+
Handle POST request for inpainting inference.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
image (UploadFile): Uploaded image file.
|
148 |
+
request_data (str): JSON string of the request parameters.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
dict: Resulting image in the specified mode ('b64_json' or 's3_json').
|
152 |
+
|
153 |
+
Raises:
|
154 |
+
HTTPException: If there is an error during image processing.
|
155 |
+
"""
|
156 |
try:
|
157 |
image_path = await save_image(image)
|
158 |
+
request_dict = json.loads(request_data)
|
159 |
+
request = InpaintingRequest(**request_dict)
|
160 |
+
result = run_inference(cfg, image_path, request)
|
161 |
return result
|
162 |
except Exception as e:
|
163 |
raise HTTPException(status_code=500, detail=str(e))
|
164 |
|
165 |
+
@router.post("/inpainting/batch")
|
166 |
+
async def inpainting_batch_inference(
|
167 |
+
images: List[UploadFile] = File(...),
|
168 |
+
request_data: str = Form(...),
|
169 |
+
):
|
170 |
+
"""
|
171 |
+
Handle POST request for batch inpainting inference.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
images (List[UploadFile]): List of uploaded image files.
|
175 |
+
request_data (str): JSON string of the request parameters.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
List[dict]: List of resulting images in the specified mode ('b64_json' or 's3_json').
|
179 |
|
180 |
+
Raises:
|
181 |
+
HTTPException: If there is an error during image processing.
|
182 |
+
"""
|
183 |
+
try:
|
184 |
+
request_dict = json.loads(request_data)
|
185 |
+
batch_request = InpaintingBatchRequestModel(**request_dict)
|
186 |
+
requests = batch_request.requests
|
187 |
+
|
188 |
+
if len(images) != len(requests):
|
189 |
+
raise HTTPException(status_code=400, detail="The number of images and requests must match.")
|
190 |
+
|
191 |
+
batcher = InpaintingBatcher(max_batch_size=64)
|
192 |
+
image_paths = [await save_image(image) for image in images]
|
193 |
+
results = await batcher.process_batch((image_paths, requests))
|
194 |
+
|
195 |
+
return results
|
196 |
+
except Exception as e:
|
197 |
+
raise HTTPException(status_code=500, detail=str(e))
|
api/routers/sdxl_text_to_image.py
CHANGED
@@ -32,7 +32,7 @@ router = APIRouter()
|
|
32 |
|
33 |
|
34 |
# Load the diffusion pipeline
|
35 |
-
|
36 |
def load_pipeline(model_name, adapter_name,enable_compile:bool):
|
37 |
"""
|
38 |
Load the diffusion pipeline with the specified model and adapter names.
|
@@ -182,7 +182,7 @@ async def sdxl_v0_lora_inference(data: InputFormat):
|
|
182 |
|
183 |
@router.post("/sdxl_v0_lora_inference/batch")
|
184 |
async def sdxl_v0_lora_inference_batch(data: List[InputFormat]):
|
185 |
-
batcher = SDXLLoraBatcher(max_batch_size
|
186 |
try:
|
187 |
predictions = await batcher.process_batch(data)
|
188 |
return predictions
|
|
|
32 |
|
33 |
|
34 |
# Load the diffusion pipeline
|
35 |
+
|
36 |
def load_pipeline(model_name, adapter_name,enable_compile:bool):
|
37 |
"""
|
38 |
Load the diffusion pipeline with the specified model and adapter names.
|
|
|
182 |
|
183 |
@router.post("/sdxl_v0_lora_inference/batch")
|
184 |
async def sdxl_v0_lora_inference_batch(data: List[InputFormat]):
|
185 |
+
batcher = SDXLLoraBatcher(max_batch_size=-1)
|
186 |
try:
|
187 |
predictions = await batcher.process_batch(data)
|
188 |
return predictions
|
scripts/__pycache__/inpainting_pipeline.cpython-310.pyc
CHANGED
Binary files a/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc and b/scripts/__pycache__/inpainting_pipeline.cpython-310.pyc differ
|
|
scripts/inpainting_pipeline.py
CHANGED
@@ -7,8 +7,8 @@ from omegaconf import DictConfig
|
|
7 |
from PIL import Image
|
8 |
from functools import lru_cache
|
9 |
|
10 |
-
|
11 |
-
def load_pipeline(model_name: str, device, enable_compile: bool =
|
12 |
pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float16)
|
13 |
if enable_compile:
|
14 |
pipeline.unet.to(memory_format=torch.channels_last)
|
|
|
7 |
from PIL import Image
|
8 |
from functools import lru_cache
|
9 |
|
10 |
+
|
11 |
+
def load_pipeline(model_name: str, device, enable_compile: bool = True):
|
12 |
pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float16)
|
13 |
if enable_compile:
|
14 |
pipeline.unet.to(memory_format=torch.channels_last)
|