Bingsu's picture
Upload files: v0.2.1
cd267d9
raw
history blame contribute delete
No virus
5.72 kB
from __future__ import annotations
import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Mapping, Optional
from diffusers.utils import logging
from PIL import Image
from asdff.utils import (
ADOutput,
bbox_padding,
composite,
mask_dilate,
mask_gaussian_blur,
)
from asdff.yolo import yolo_detector
# Logger obtained from diffusers' logging helper under the "diffusers" name.
logger = logging.get_logger("diffusers")
# A detector is a callable that receives an image and returns either a list
# of mask images or None (treated below as "nothing detected").
DetectorType = Callable[[Image.Image], Optional[List[Image.Image]]]
def ordinal(n: int) -> str:
    """Return *n* followed by its English ordinal suffix.

    Examples: ``1 -> "1st"``, ``2 -> "2nd"``, ``11 -> "11th"``, ``21 -> "21st"``.

    The suffix is derived from the magnitude of ``n``; Python's ``%`` on a
    negative number would otherwise pick the wrong suffix (``-1 % 10 == 9``),
    so ``-1`` yields ``"-1st"`` rather than ``"-1th"``.
    """
    magnitude = abs(n)
    # 11, 12 and 13 (also 111, 212, ...) are irregular and always take "th".
    if 11 <= magnitude % 100 <= 13:
        return f"{n}th"
    suffixes = {1: "st", 2: "nd", 3: "rd"}
    return f"{n}{suffixes.get(magnitude % 10, 'th')}"
class AdPipelineBase(ABC):
    """Abstract base that adds a detect-then-inpaint pass to an image pipeline.

    Calling an instance either generates images through ``txt2img_class`` or
    takes caller-supplied images, runs every detector on each image, and for
    each returned mask inpaints the padded bounding box of the mask through
    ``inpaint_pipeline`` and composites the result back onto the image.

    Subclasses must provide ``inpaint_pipeline`` and ``txt2img_class``.
    """

    @property
    @abstractmethod
    def inpaint_pipeline(self) -> Callable:
        """Callable invoked with the keyword arguments built for inpainting."""
        raise NotImplementedError

    @property
    @abstractmethod
    def txt2img_class(self) -> type:
        """Class whose ``__call__`` is invoked with ``self`` to generate images."""
        raise NotImplementedError

    def __call__(  # noqa: C901
        self,
        common: Mapping[str, Any] | None = None,
        txt2img_only: Mapping[str, Any] | None = None,
        inpaint_only: Mapping[str, Any] | None = None,
        images: Image.Image | Iterable[Image.Image] | None = None,
        detectors: DetectorType | Iterable[DetectorType] | None = None,
        mask_dilation: int = 4,
        mask_blur: int = 4,
        mask_padding: int = 32,
    ):
        """Generate or accept images, then inpaint every detected region.

        Args:
            common: Keyword arguments passed to both the txt2img call and the
                inpaint call.
            txt2img_only: Extra keyword arguments for the txt2img call only.
                Ignored (with a warning) when ``images`` is given.
            inpaint_only: Extra keyword arguments for the inpaint call only.
                ``"strength"`` defaults to ``0.4`` when absent.
            images: A single image or an iterable of images to process. When
                ``None``, images are produced by ``process_txt2img``.
            detectors: A single detector or an iterable of detectors. When
                ``None``, ``default_detector`` is used.
            mask_dilation: Value forwarded to ``mask_dilate``.
            mask_blur: Value forwarded to ``mask_gaussian_blur``.
            mask_padding: Value forwarded to ``bbox_padding``.

        Returns:
            ``ADOutput`` with ``init_images`` holding a copy of every input
            image and ``images`` holding the final image of each input for
            which at least one region was inpainted. Inputs with no inpainted
            region are omitted from ``images``, so the two lists may differ
            in length.
        """
        if common is None:
            common = {}
        if txt2img_only is None:
            txt2img_only = {}
        if inpaint_only is None:
            inpaint_only = {}
        if "strength" not in inpaint_only:
            # Build a new dict rather than mutating the caller's mapping.
            inpaint_only = {**inpaint_only, "strength": 0.4}

        detectors = self._resolve_detectors(detectors)

        if images is None:
            txt2img_output = self.process_txt2img(common, txt2img_only)
            # NOTE(review): presumably element 0 of the pipeline output is the
            # list of generated images -- confirm against txt2img_class.
            txt2img_images = txt2img_output[0]
        else:
            if txt2img_only:
                msg = "Both `images` and `txt2img_only` are specified. if `images` is specified, `txt2img_only` is ignored."
                logger.warning(msg)
            txt2img_images = [images] if not isinstance(images, Iterable) else images

        init_images = []
        final_images = []

        for i, init_image in enumerate(txt2img_images):
            init_images.append(init_image.copy())
            final_image = None

            for j, detector in enumerate(detectors):
                masks = detector(init_image)
                if masks is None:
                    logger.info(
                        f"No object detected on {ordinal(i + 1)} image with {ordinal(j + 1)} detector."
                    )
                    continue

                for k, mask in enumerate(masks):
                    mask = mask.convert("L")
                    mask = mask_dilate(mask, mask_dilation)
                    bbox = mask.getbbox()
                    if bbox is None:
                        logger.info(f"No object in {ordinal(k + 1)} mask.")
                        continue
                    mask = mask_gaussian_blur(mask, mask_blur)
                    bbox_padded = bbox_padding(bbox, init_image.size, mask_padding)

                    inpaint_output = self.process_inpainting(
                        common,
                        inpaint_only,
                        init_image,
                        mask,
                        bbox_padded,
                    )
                    # NOTE(review): presumably output[0] is the image list and
                    # [0][0] its single image (num_images_per_prompt is 1).
                    inpaint_image = inpaint_output[0][0]

                    final_image = composite(
                        init_image,
                        mask,
                        inpaint_image,
                        bbox_padded,
                    )
                    # Later masks and detectors work on the already-fixed image.
                    init_image = final_image

            if final_image is not None:
                final_images.append(final_image)

        return ADOutput(images=final_images, init_images=init_images)

    def _resolve_detectors(
        self, detectors: DetectorType | Iterable[DetectorType] | None
    ) -> list[DetectorType]:
        """Normalise the ``detectors`` argument of ``__call__`` to a list.

        ``None`` becomes ``[self.default_detector]``, a non-iterable value is
        wrapped in a one-element list, and any iterable is copied into a list.
        """
        if detectors is None:
            return [self.default_detector]
        if not isinstance(detectors, Iterable):
            return [detectors]
        # The detectors are iterated once per image; a one-shot iterator
        # (e.g. a generator) would be exhausted after the first image.
        return list(detectors)

    @property
    def default_detector(self) -> Callable[..., list[Image.Image] | None]:
        """Detector used when ``__call__`` receives ``detectors=None``."""
        return yolo_detector

    def _get_txt2img_args(
        self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
    ):
        """Merge ``common`` and ``txt2img_only`` and force ``output_type="pil"``.

        Keys in ``txt2img_only`` override the same keys in ``common``.
        """
        return {**common, **txt2img_only, "output_type": "pil"}

    def _get_inpaint_args(
        self, common: Mapping[str, Any], inpaint_only: Mapping[str, Any]
    ):
        """Build the keyword arguments for the inpaint call.

        When the signature of ``inpaint_pipeline`` has a ``control_image``
        parameter and ``common`` holds ``"image"`` but not ``"control_image"``,
        the ``"image"`` entry is moved to ``"control_image"``. ``common`` itself
        is not modified. ``num_images_per_prompt`` is always ``1`` and
        ``output_type`` is always ``"pil"``, overriding caller-supplied values.
        """
        common = dict(common)
        sig = inspect.signature(self.inpaint_pipeline)
        if (
            "control_image" in sig.parameters
            and "control_image" not in common
            and "image" in common
        ):
            common["control_image"] = common.pop("image")
        return {
            **common,
            **inpaint_only,
            "num_images_per_prompt": 1,
            "output_type": "pil",
        }

    def process_txt2img(
        self, common: Mapping[str, Any], txt2img_only: Mapping[str, Any]
    ):
        """Call ``txt2img_class.__call__`` on ``self`` with the merged arguments."""
        txt2img_args = self._get_txt2img_args(common, txt2img_only)
        return self.txt2img_class.__call__(self, **txt2img_args)

    def process_inpainting(
        self,
        common: Mapping[str, Any],
        inpaint_only: Mapping[str, Any],
        init_image: Image.Image,
        mask: Image.Image,
        bbox_padded: tuple[int, int, int, int],
    ):
        """Run ``inpaint_pipeline`` on the ``bbox_padded`` crop of image and mask.

        The cropped image and mask are passed as ``image`` and ``mask_image``.
        If a ``control_image`` argument is present it is resized (not cropped)
        to the crop's size.
        """
        crop_image = init_image.crop(bbox_padded)
        crop_mask = mask.crop(bbox_padded)
        inpaint_args = self._get_inpaint_args(common, inpaint_only)
        inpaint_args["image"] = crop_image
        inpaint_args["mask_image"] = crop_mask

        if "control_image" in inpaint_args:
            # NOTE(review): assumes control_image exposes ``resize(size)`` like
            # a PIL image; a list or tensor would fail here -- confirm callers.
            inpaint_args["control_image"] = inpaint_args["control_image"].resize(
                crop_image.size
            )
        return self.inpaint_pipeline(**inpaint_args)