Bingsu committed on
Commit
9c47f7f
1 Parent(s): 36f0b3c

Upload files: v0.1.0

Browse files
Files changed (6) hide show
  1. asdff/__init__.py +9 -0
  2. asdff/__version__.py +1 -0
  3. asdff/sd.py +122 -0
  4. asdff/utils.py +70 -0
  5. asdff/yolo.py +73 -0
  6. pipeline.py +1 -0
asdff/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .__version__ import __version__
2
+ from .sd import AdPipeline
3
+ from .yolo import yolo_detector
4
+
5
# Names exported by ``from asdff import *``.
__all__ = ["AdPipeline", "yolo_detector", "__version__"]
asdff/__version__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.0"
asdff/sd.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from functools import cached_property
4
+ from typing import Any, Callable, Iterable, List, Optional
5
+
6
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
7
+ from diffusers.utils import logging
8
+ from PIL import Image
9
+
10
+ from asdff.utils import (
11
+ ADOutput,
12
+ bbox_padding,
13
+ composite,
14
+ mask_dilate,
15
+ mask_gaussian_blur,
16
+ )
17
+ from asdff.yolo import yolo_detector
18
+
19
# Messages from this module go through the shared "diffusers" logger.
logger = logging.get_logger("diffusers")


# A detector receives one PIL image and returns a list of mask images,
# or None when it found nothing.
DetectorType = Callable[[Image.Image], Optional[List[Image.Image]]]
23
+
24
+
25
def ordinal(n: int) -> str:
    """Return ``n`` followed by its English ordinal suffix, e.g. ``1st`` or ``12th``."""
    # 11, 12 and 13 (within any hundred) take "th" even though they end in 1, 2, 3.
    if 11 <= n % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return str(n) + suffix
28
+
29
+
30
class AdPipeline(StableDiffusionPipeline):
    """Text-to-image generation followed by detector-guided inpainting.

    Each generated image is handed to one or more detectors. Every mask a
    detector returns is dilated, blurred, cropped with padding, inpainted
    with ``StableDiffusionInpaintPipeline`` and composited back into the
    image.
    """

    @cached_property
    def inpaine_pipeline(self):
        """Inpainting pipeline that reuses this pipeline's own components.

        Built on first access and cached on the instance.

        NOTE: the attribute name is misspelled ("inpaine"); it is kept
        unchanged so existing callers keep working.
        """
        return StableDiffusionInpaintPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            scheduler=self.scheduler,
            safety_checker=self.safety_checker,
            feature_extractor=self.feature_extractor,
            requires_safety_checker=self.config.requires_safety_checker,
        )

    def __call__(  # noqa: C901
        self,
        common: dict[str, Any] | None = None,
        txt2img_only: dict[str, Any] | None = None,
        inpaint_only: dict[str, Any] | None = None,
        detectors: DetectorType | Iterable[DetectorType] | None = None,
        mask_dilation: int = 4,
        mask_blur: int = 4,
        mask_padding: int = 32,
    ):
        """Generate images, then inpaint every region the detectors report.

        Parameters
        ----------
        common: dict[str, Any] | None
            Keyword arguments passed to both the text-to-image call and
            every inpainting call.
        txt2img_only: dict[str, Any] | None
            Keyword arguments for the text-to-image call only. They take
            precedence over ``common``.
        inpaint_only: dict[str, Any] | None
            Keyword arguments for the inpainting calls only. They take
            precedence over ``common``. ``strength`` defaults to 0.4 when
            neither dict provides it. The caller's dict is not modified.
        detectors: DetectorType | Iterable[DetectorType] | None
            One detector or an iterable of detectors. ``None`` selects
            ``self.default_detector``.
        mask_dilation: int
            Kernel size handed to ``mask_dilate``.
        mask_blur: int
            Radius handed to ``mask_gaussian_blur``.
        mask_padding: int
            Padding handed to ``bbox_padding``.

        Returns
        -------
        ADOutput
            ``init_images`` holds a copy of every text-to-image result.
            ``images`` holds the inpainted result only for images where at
            least one mask was processed, so it can be shorter than
            ``init_images``.
        """
        common = common or {}

        # Merge into fresh dicts: the caller's dicts stay untouched, the more
        # specific dict wins on duplicate keys, and the values this method
        # relies on (PIL output, one image per inpaint call) are always set.
        txt2img_kwargs = {**common, **(txt2img_only or {}), "output_type": "pil"}
        inpaint_kwargs = {
            "strength": 0.4,
            **common,
            **(inpaint_only or {}),
            "num_images_per_prompt": 1,
            "output_type": "pil",
        }

        if detectors is None:
            detectors = [self.default_detector]
        elif callable(detectors):
            detectors = [detectors]
        else:
            # Materialize once: a one-shot iterable (e.g. a generator) would
            # otherwise be exhausted after the first image.
            detectors = list(detectors)

        txt2img_output = super().__call__(**txt2img_kwargs)
        txt2img_images: list[Image.Image] = txt2img_output[0]

        init_images = []
        final_images = []

        for i, init_image in enumerate(txt2img_images):
            init_images.append(init_image.copy())
            final_image = None

            for j, detector in enumerate(detectors):
                masks = detector(init_image)
                if masks is None:
                    logger.info(
                        f"No object detected on {ordinal(i + 1)} image with {ordinal(j + 1)} detector."
                    )
                    continue

                for k, mask in enumerate(masks):
                    mask = mask.convert("L")
                    mask = mask_dilate(mask, mask_dilation)
                    bbox = mask.getbbox()
                    if bbox is None:
                        logger.info(f"No object in {ordinal(k + 1)} mask.")
                        continue
                    mask = mask_gaussian_blur(mask, mask_blur)
                    bbox_padded = bbox_padding(bbox, init_image.size, mask_padding)

                    crop_image = init_image.crop(bbox_padded)
                    crop_mask = mask.crop(bbox_padded)

                    inpaint_output = self.inpaine_pipeline(
                        **inpaint_kwargs,
                        image=crop_image,
                        mask_image=crop_mask,
                    )
                    inpaint_image: Image.Image = inpaint_output[0][0]
                    final_image = composite(
                        init=init_image,
                        mask=mask,
                        gen=inpaint_image,
                        bbox_padded=bbox_padded,
                    )
                    # Later masks and detectors work on the already-refined image.
                    init_image = final_image

            if final_image is not None:
                final_images.append(final_image)

        return ADOutput(images=final_images, init_images=init_images)

    @property
    def default_detector(self) -> Callable[..., list[Image.Image] | None]:
        """Detector used when ``__call__`` receives ``detectors=None``."""
        return yolo_detector
asdff/utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from diffusers.utils import BaseOutput
8
+ from PIL import Image, ImageFilter, ImageOps
9
+
10
+
11
@dataclass
class ADOutput(BaseOutput):
    """Result container returned by ``AdPipeline.__call__``."""

    # Images after the detected regions were inpainted.
    images: list[Image.Image]
    # Copies of the text-to-image results the inpainting started from.
    init_images: list[Image.Image]
15
+
16
+
17
def mask_dilate(image: Image.Image, value: int = 4) -> Image.Image:
    """Dilate ``image`` with a ``value`` x ``value`` rectangular kernel.

    A non-positive ``value`` disables dilation; ``image`` itself is returned.
    """
    if value <= 0:
        return image

    pixels = np.array(image)
    square = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    return Image.fromarray(cv2.dilate(pixels, square, iterations=1))
25
+
26
+
27
def mask_gaussian_blur(image: Image.Image, value: int = 4) -> Image.Image:
    """Apply a Gaussian blur of radius ``value`` to ``image``.

    A non-positive ``value`` disables blurring; ``image`` itself is returned.
    """
    if value <= 0:
        return image

    return image.filter(ImageFilter.GaussianBlur(value))
33
+
34
+
35
def bbox_padding(
    bbox: tuple[int, int, int, int], image_size: tuple[int, int], value: int = 32
) -> tuple[int, int, int, int]:
    """
    Grow a bounding box by ``value`` on every side, clamped to the image.

    Parameters
    ----------
        bbox: tuple[int, int, int, int]
            (x1, y1, x2, y2) box

        image_size: tuple[int, int]
            (width, height) used as the upper clamp; the lower clamp is 0

        value: int
            padding added on each side; a non-positive value returns
            ``bbox`` unchanged

    Returns
    -------
        padded: tuple[int, int, int, int]
            the padded box as built-in Python numbers
    """
    if value <= 0:
        return bbox

    # Row 0 is the top-left corner, row 1 the bottom-right corner.
    corners = np.array(bbox).reshape(2, 2)
    corners[0] -= value
    corners[1] += value
    corners = np.clip(corners, (0, 0), image_size)
    # tolist() converts NumPy scalars to built-in numbers so the result
    # matches the annotated return type.
    return tuple(corners.flatten().tolist())
46
+
47
+
48
def composite(
    init: Image.Image,
    mask: Image.Image,
    gen: Image.Image,
    bbox_padded: tuple[int, int, int, int],
) -> Image.Image:
    """Blend the inpainted crop ``gen`` back into ``init`` and return an RGB image.

    ``gen`` is resized to the size of ``bbox_padded`` and pasted at that box.
    ``init``, pasted through the inverted ``mask``, is then alpha-composited
    on top.
    """
    # Layer holding ``init`` pasted through the inverted mask.
    background = Image.new("RGBa", init.size)
    background.paste(
        init.convert("RGBA").convert("RGBa"),
        mask=ImageOps.invert(mask),
    )
    background = background.convert("RGBA")

    # Bring the generated crop back to the size of the padded box.
    width = bbox_padded[2] - bbox_padded[0]
    height = bbox_padded[3] - bbox_padded[1]
    patch = gen.resize((width, height))

    canvas = Image.new("RGBA", init.size)
    canvas.paste(patch, bbox_padded)
    canvas.alpha_composite(background)
    return canvas.convert("RGB")
asdff/yolo.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+ from PIL import Image, ImageDraw
7
+ from torchvision.transforms.functional import to_pil_image
8
+ from ultralytics import YOLO
9
+
10
+
11
def create_mask_from_bbox(
    bboxes: np.ndarray, shape: tuple[int, int]
) -> list[Image.Image]:
    """
    Draw one black mask with a white filled rectangle per bounding box.

    Parameters
    ----------
        bboxes: np.ndarray
            array (or any sequence) of [x1, y1, x2, y2] bounding boxes

        shape: tuple[int, int]
            shape of the image (width, height)

    Returns
    -------
        masks: list[Image.Image]
            A list of "L" mode masks, one per bounding box, in input order

    """
    masks = []
    for bbox in bboxes:
        mask = Image.new("L", shape, "black")
        # Hand Pillow plain Python floats instead of a raw NumPy row so the
        # coordinates are read the same way whatever the array's dtype is.
        coords = [float(v) for v in bbox]
        ImageDraw.Draw(mask).rectangle(coords, fill="white")
        masks.append(mask)
    return masks
36
+
37
+
38
def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
    """
    Convert a stack of masks into "L" mode PIL images resized to ``shape``.

    Parameters
    ----------
        masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
            The device can be CUDA, but `to_pil_image` takes care of that.

        shape: tuple[int, int]
            (width, height) of the original image

    Returns
    -------
        images: list[Image.Image]
            one image per entry along the first dimension of ``masks``
    """
    images = []
    for index in range(masks.shape[0]):
        pil_mask = to_pil_image(masks[index], mode="L")
        images.append(pil_mask.resize(shape))
    return images
54
+
55
+
56
def yolo_detector(
    image: Image.Image, model_path: str | None = None, confidence: float = 0.3
) -> list[Image.Image] | None:
    """
    Run a YOLO model on ``image`` and return one mask per detection.

    Parameters
    ----------
        image: Image.Image
            image to run detection on

        model_path: str | None
            path of the YOLO weights; when empty or None,
            "face_yolov8n.pt" is fetched from the "Bingsu/adetailer" hub repo

        confidence: float
            passed to the model as ``conf``

    Returns
    -------
        masks: list[Image.Image] | None
            None when no box was predicted; otherwise the predicted masks,
            or rectangle masks built from the boxes when the prediction
            carries no masks
    """
    # NOTE(review): the weights are resolved and the model is constructed on
    # every call; callers looping over many images may want to reuse a model.
    weights = model_path or hf_hub_download("Bingsu/adetailer", "face_yolov8n.pt")
    detector = YOLO(weights)
    result = detector(image, conf=confidence)[0]

    boxes = result.boxes.xyxy.cpu().numpy()
    if boxes.size == 0:
        return None

    if result.masks is None:
        return create_mask_from_bbox(boxes, image.size)
    return mask_to_pil(result.masks.data, image.size)
pipeline.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from asdff import AdPipeline # noqa: F401