|
import gc |
|
from typing import List, Optional, Dict, Any |
|
|
|
import torch |
|
from pydantic import BaseModel |
|
from PIL import Image |
|
from diffusers.schedulers import * |
|
from controlnet_aux.processor import Processor |
|
|
|
|
|
class ControlNetReq(BaseModel):
    """ControlNet settings attached to a generation request.

    ``controlnets`` and ``control_images`` are consumed pairwise by
    ``get_controlnet_images``. ``controlnet_conditioning_scale`` is presumably
    aligned with them the same way. This model does not enforce equal lengths.
    """

    # ControlNet names; get_controlnet_images accepts "canny", "depth", "pose".
    controlnets: List[str]
    # Source images, one per entry in ``controlnets``.
    control_images: List[Image.Image]
    # Per-ControlNet conditioning scale (not consumed in this module).
    controlnet_conditioning_scale: List[float]

    class Config:
        # PIL images are not a pydantic-native type; allow them as field values.
        arbitrary_types_allowed = True
|
|
|
|
|
class BaseReq(BaseModel):
    """Common parameters shared by every generation request.

    Every field has a default, so an empty request is valid. Field order is
    significant for pydantic (schema/repr order) and is kept as declared.
    """

    # Model identifier; empty string by default.
    model: str = ""
    # Positive and negative text prompts.
    prompt: str = ""
    negative_prompt: Optional[str] = ""
    # Meaning defined by the consumer of this request — TODO document.
    fast_generation: Optional[bool] = True
    # NOTE(review): pydantic copies mutable defaults per instance, so this
    # literal [] is not shared; Field(default_factory=list) would be more explicit.
    loras: Optional[list] = []
    embeddings: Optional[list] = None
    # Presumably forwarded to resize_images ("resize_only", "crop_and_resize",
    # "resize_and_fill") — verify against the caller.
    resize_mode: Optional[str] = "resize_and_fill"
    # Scheduler name; the default "euler_fl" is resolved elsewhere.
    scheduler: Optional[str] = "euler_fl"
    # Output size in pixels.
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    clip_skip: Optional[int] = None
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    # Feature toggles; semantics are defined by the pipeline code.
    refiner: bool = False
    vae: bool = True
    # Optional ControlNet configuration (see ControlNetReq).
    controlnet_config: Optional[ControlNetReq] = None
    # Free-form extension point for extra options.
    custom_addons: Optional[Dict[Any, Any]] = None

    class Config:
        # Allows non-pydantic types (e.g. PIL images in subclasses) as fields.
        arbitrary_types_allowed = True
|
|
|
|
|
class BaseImg2ImgReq(BaseReq):
    """Image-to-image request: every ``BaseReq`` field plus a source image."""

    # Required input image.
    image: Image.Image
    # Strength of the transformation; defaults to 1.0 and is not range-checked here.
    strength: float = 1.0

    # No nested Config: ``arbitrary_types_allowed`` (needed for the PIL
    # ``image`` field) is inherited from BaseReq's Config by pydantic.
|
|
|
|
|
class BaseInpaintReq(BaseImg2ImgReq):
    """Inpainting request: an image-to-image request plus a mask image."""

    # Required mask image accompanying ``image``.
    mask_image: Image.Image

    # No nested Config: ``arbitrary_types_allowed`` is inherited through
    # BaseImg2ImgReq from BaseReq's Config by pydantic.
|
|
|
|
|
def resize_images(images: "List[Image.Image]", height: int, width: int, resize_mode: str) -> "List[Image.Image]":
    """Return a new list with each image adjusted to ``width`` x ``height``.

    Args:
        images: Source images. The input list and its images are not modified.
        height: Target height in pixels.
        width: Target width in pixels.
        resize_mode: One of:
            - ``"resize_only"``: ``image.resize((width, height))`` with no
              explicit resampling filter.
            - ``"crop_and_resize"``: ``image.crop((0, 0, width, height))``, i.e.
              the top-left ``width`` x ``height`` box.
            - ``"resize_and_fill"``: ``image.resize`` using LANCZOS resampling.
            Any other value leaves the images unchanged.

    Returns:
        A new list of processed images, in the same order as ``images``.
    """
    # NOTE(review): the mode names suggest aspect-ratio-aware behavior
    # (scale-then-crop / scale-then-pad), but the operations below only do a
    # plain resize or a top-left crop — confirm the intended semantics.
    processed = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            image = image.crop((0, 0, width, height))
        elif resize_mode == "resize_and_fill":
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        # The original discarded these results by only rebinding the loop
        # variable; keep each processed image.
        processed.append(image)
    return processed
|
|
|
|
|
def get_controlnet_images(controlnet_config: ControlNetReq, height: int, width: int, resize_mode: str):
    """Resize the control images and run each through its ControlNet preprocessor.

    Args:
        controlnet_config: Request holding ControlNet names (``controlnets``)
            and their source images (``control_images``), consumed pairwise.
        height: Target height forwarded to ``resize_images``.
        width: Target width forwarded to ``resize_images``.
        resize_mode: Resize strategy forwarded to ``resize_images``.

    Returns:
        A list with one preprocessed image (``to_pil=True``) per
        (controlnet, image) pair, in input order.

    Raises:
        ValueError: If ``controlnets`` and ``control_images`` differ in length,
            or if a ControlNet name is not ``"canny"``, ``"depth"`` or ``"pose"``.
    """
    # Request-level ControlNet name -> controlnet_aux processor id.
    processor_ids = {
        "canny": "canny",
        "depth": "depth_midas",
        "pose": "openpose_full",
    }

    controlnets = controlnet_config.controlnets
    source_images = controlnet_config.control_images

    # zip() would silently drop unmatched trailing entries; fail loudly instead.
    if len(controlnets) != len(source_images):
        raise ValueError(
            f"controlnets ({len(controlnets)}) and control_images "
            f"({len(source_images)}) must have the same length"
        )

    # Reject unknown names before doing any resizing/preprocessing work.
    for controlnet in controlnets:
        if controlnet not in processor_ids:
            raise ValueError(f"Invalid Controlnet: {controlnet}")

    control_images = resize_images(source_images, height, width, resize_mode)

    # Build each Processor at most once per call instead of once per image;
    # construction presumably loads detector weights — verify if this matters.
    processors = {}
    response_images = []
    for controlnet, image in zip(controlnets, control_images):
        processor = processors.get(controlnet)
        if processor is None:
            processor = Processor(processor_ids[controlnet])
            processors[controlnet] = processor
        response_images.append(processor(image, to_pil=True))

    return response_images
|
|
|
|
|
def cleanup(pipeline, loras=None):
    """Release resources after a generation run.

    Unloads LoRA weights from ``pipeline`` only when ``loras`` is truthy, then
    runs Python garbage collection and ``torch.cuda.empty_cache()``.

    Args:
        pipeline: Object exposing ``unload_lora_weights()``; untouched when
            ``loras`` is ``None`` or empty.
        loras: LoRAs loaded for this run; falsy means nothing to unload.
    """
    has_loras = bool(loras)
    if has_loras:
        pipeline.unload_lora_weights()

    # Collect unreachable Python objects first so their memory is unoccupied
    # before the CUDA caching allocator is asked to release cached blocks.
    gc.collect()
    torch.cuda.empty_cache()
|
|