|
import os |
|
from functools import lru_cache |
|
from typing import List |
|
|
|
import cv2 |
|
import numpy as np |
|
from diffusers.utils import load_image |
|
from PIL import Image, ImageChops, ImageFilter |
|
from ultralytics import YOLO |
|
from .utils import * |
|
|
|
|
|
def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
    """Grow, slightly shrink, and feather a mask.

    The mask is dilated with a square kernel, eroded with a (smaller) square
    kernel, then blurred twice: once with OpenCV and once with Pillow.

    Parameters:
    - mask: a PIL image or numpy array holding the mask, or ``None``.
    - dilate_factor (int): side length in pixels of the square dilation
      kernel. Values below 1 skip dilation.
    - blur_radius (int): radius for both Gaussian blur passes. The OpenCV
      kernel size is ``2 * blur_radius + 1``.
    - erosion_factor (int): side length in pixels of the square erosion
      kernel. Values below 1 skip erosion.

    Returns:
    - A mode ``"L"`` PIL image with the smoothed mask, or ``None`` when
      ``mask`` is ``None``.
    """
    # `if not mask` raises "truth value of an array is ambiguous" for numpy
    # arrays, so test for None explicitly.
    if mask is None:
        return None

    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    mask = mask.astype(np.uint8)

    # OpenCV falls back to a default 3x3 structuring element when it is given
    # an empty kernel. Without these guards a factor of 0 would still
    # dilate/erode, so 0 would not mean "disabled".
    if dilate_factor >= 1:
        dilate_kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
        mask = cv2.dilate(mask, dilate_kernel, iterations=1)
    if erosion_factor >= 1:
        erode_kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
        mask = cv2.erode(mask, erode_kernel, iterations=1)

    # GaussianBlur needs an odd kernel size; sigma 0 lets OpenCV derive it.
    ksize = 2 * blur_radius + 1
    mask = cv2.GaussianBlur(mask, (ksize, ksize), 0)

    smoothed_mask = Image.fromarray(mask).convert("L")
    return smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|
|
|
|
|
@lru_cache(maxsize=1)
def get_model(model_id):
    """Return a ``YOLO`` model constructed from ``model_id``.

    Only the most recently requested model is memoized (``maxsize=1``).
    Repeated calls with the same id return the same instance, while
    switching ids builds a new one.
    """
    return YOLO(model=model_id)
|
|
|
|
|
def combine_masks(
    masks: List[dict], labels: List[str], is_label=True
) -> "Image.Image | None":
    """Union the masks selected by label into a single grayscale mask.

    Parameters:
    - masks (List[dict]): entries holding an image under ``"mask"`` and its
      label under ``"label"``.
    - labels (List[str]): labels used to select masks.
    - is_label (bool): when True, keep masks whose label IS in ``labels``.
      When False, keep masks whose label is NOT in ``labels``.

    Returns:
    - The pixel-wise maximum (``ImageChops.lighter``) of the selected masks,
      each converted to mode ``"L"``, or ``None`` if no mask is selected.
    """
    # Build the lookup set once so each membership test is O(1).
    labels_set = set(labels)

    mask_images = [
        entry["mask"].convert("L")
        for entry in masks
        if (entry["label"] in labels_set) == is_label
    ]
    if not mask_images:
        return None

    combined_mask = mask_images[0]
    for mask in mask_images[1:]:
        combined_mask = ImageChops.lighter(combined_mask, mask)
    return combined_mask
|
|
|
|
|
# Default ``labels`` argument of ``BodyMask``.
body_labels = [
    "hair",
    "face",
    "arm",
    "hand",
    "leg",
    "foot",
    "outfit",
]
|
|
|
|
|
class BodyMask:
    """Run a YOLO segmentation model on an image and derive a body mask/box.

    Construction does all the work and populates:

    - ``image``: the loaded (optionally resized) input image.
    - ``masks``, ``boxes``, ``scores``, ``phrases``: outputs of ``unload``.
    - ``results``: formatted detections (dicts read via ``"label"`` and
      ``"mask"`` keys), after ``filter_highest_score``.
    - ``body_mask``: union of the selected masks, smoothed by ``dilate_mask``.
    - ``box``: ``get_bounding_box(body_mask)``.
    - ``body_box``: ``get_bounding_box_mask(body_mask, ...)``, optionally
      passed through ``remove_overlap``.
    - ``overlay``: red overlay of ``body_box`` or ``body_mask`` on ``image``.
    """

    def __init__(
        self,
        image_path,
        model_id,
        labels=None,
        overlay="mask",
        widen_box=0,
        elongate_box=0,
        resize_to=640,
        dilate_factor=0,
        is_label=False,
        resize_to_nearest_eight=False,
        verbose=True,
        remove_overlap=True,
    ):
        """
        Parameters:
        - image_path: forwarded to ``load_image``; a path/URL or an
          already-loaded image.
        - model_id: forwarded to ``get_model`` to obtain the YOLO model.
        - labels: labels used to select masks. ``None`` means a copy of the
          module-level ``body_labels``.
        - overlay: ``"box"`` overlays ``body_box``; any other value overlays
          ``body_mask``.
        - widen_box, elongate_box: forwarded to ``get_bounding_box_mask``.
        - resize_to: size passed to ``resize_preserve_aspect_ratio``; falsy
          disables that resize.
        - dilate_factor: dilation kernel size forwarded to ``dilate_mask``.
        - is_label: True -> body mask is the union of masks whose label is in
          ``labels``; False -> the union of masks whose label is NOT.
        - resize_to_nearest_eight: apply ``resize_image_to_nearest_eight``.
        - verbose: forwarded to the model call.
        - remove_overlap: apply ``remove_overlap`` to the body box.
        """
        self.image_path = image_path
        self.image = self.get_image(
            resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
        )
        # Resolve the default at call time and copy it, so instances never
        # share (and can never mutate) the module-level list.
        self.labels = list(body_labels) if labels is None else labels
        self.is_label = is_label
        self.model_id = model_id
        self.model = get_model(self.model_id)
        self.model_labels = self.model.names
        self.verbose = verbose
        self.results = self.get_results()
        self.dilate_factor = dilate_factor
        self.body_mask = self.get_body_mask()
        self.box = get_bounding_box(self.body_mask)
        self.body_box = self.get_body_box(
            remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
        )
        overlay_target = self.body_box if overlay == "box" else self.body_mask
        self.overlay = overlay_mask(
            self.image, overlay_target, opacity=0.9, color="red"
        )

    def get_image(self, resize_to, resize_to_nearest_eight):
        """Load ``self.image_path`` and apply the requested resizes in order."""
        image = load_image(self.image_path)
        if resize_to:
            image = resize_preserve_aspect_ratio(image, resize_to)
        if resize_to_nearest_eight:
            image = resize_image_to_nearest_eight(image)
        return image

    def get_body_mask(self):
        """Union the label-selected masks and smooth them with ``dilate_mask``."""
        # Bare name resolves to the module-level combine_masks, not the method.
        body_mask = combine_masks(self.results, self.labels, self.is_label)
        return dilate_mask(body_mask, self.dilate_factor)

    def get_results(self):
        """Run the model on ``self.image`` and return formatted detections.

        Side effect: sets ``self.masks``, ``self.boxes``, ``self.scores`` and
        ``self.phrases`` from ``unload``. The formatted results are passed
        through ``filter_highest_score`` for "hair", "face" and "phone"
        (presumably keeping only the best-scoring instance of each; confirm
        in ``utils``).
        """
        # Predict at the image's longest side; retina_masks asks ultralytics
        # for high-resolution masks.
        imgsz = max(self.image.size)
        prediction = self.model(
            self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
        )[0]
        self.masks, self.boxes, self.scores, self.phrases = unload(
            prediction, self.model_labels
        )
        results = format_results(
            self.masks,
            self.boxes,
            self.scores,
            self.phrases,
            self.model_labels,
            person_masks_only=False,
        )
        return filter_highest_score(results, ["hair", "face", "phone"])

    def display_results(self):
        """Show the image with its detected masks, using at most four columns."""
        # NOTE(review): with zero masks this passes cols=0 to the helper;
        # confirm display_image_with_masks tolerates that.
        display_image_with_masks(
            self.image, self.results, cols=min(len(self.masks), 4)
        )

    def get_mask(self, mask_label):
        """Return every result entry whose label equals ``mask_label``.

        Raises AssertionError if ``mask_label`` is not among ``self.phrases``.
        """
        assert mask_label in self.phrases, "Mask label not found in results"
        return [f for f in self.results if f.get("label") == mask_label]

    def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
        """Union result masks selected by label.

        With ``is_label=True`` the masks whose label is in ``mask_labels`` are
        combined. With ``is_label=False`` the masks for every detected label
        (``self.phrases``) NOT in ``mask_labels`` are combined instead.
        Masks are merged with ``ImageChops.lighter`` (pixel-wise maximum).
        ``no_labels`` is unused and kept only for call compatibility.

        Returns the combined mask, or ``None`` when nothing matches.
        """
        if not is_label:
            mask_labels = [
                phrase for phrase in self.phrases if phrase not in mask_labels
            ]
        masks = [
            row.get("mask") for row in self.results if row.get("label") in mask_labels
        ]
        if not masks:
            return None
        combined_mask = masks[0]
        for mask in masks[1:]:
            combined_mask = ImageChops.lighter(combined_mask, mask)
        return combined_mask

    def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
        """Box mask around ``self.body_mask``, optionally with overlaps cleared."""
        body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
        if remove_overlap:
            body_box = self.remove_overlap(body_box)
        return body_box

    def remove_overlap(self, body_box):
        """Zero the pixels of ``body_box`` covered by any ``self.labels`` mask.

        The masks are always selected with ``is_label=True``, independent of
        ``self.is_label``. NOTE(review): with ``is_label=True`` the body mask
        is built from those same labels, so this clears the body itself out
        of the box; confirm that is intended.

        Returns a new PIL image built from the (possibly modified) box array.
        """
        box_array = np.array(body_box)
        mask = self.combine_masks(mask_labels=self.labels, is_label=True)
        if mask is None:
            # Nothing to subtract. Handle it explicitly instead of relying on
            # np.array(None) == 255 evaluating to a False scalar.
            return Image.fromarray(box_array)

        # Normalize to 8-bit grayscale (as the module-level combine_masks
        # does) so "== 255" means "fully covered" and the mask is 2-D.
        mask_array = np.array(mask.convert("L"))
        box_array[mask_array == 255] = 0
        return Image.fromarray(box_array)
|
|
|
|
|
if __name__ == "__main__":
    # Demo: build a box mask around everything except face/hair/phone/hand in
    # a sample photo and save it to the current working directory under the
    # image's file name.

    # BodyMask requires a YOLO segmentation model id as its second positional
    # argument; take it from the environment rather than hard-coding weights.
    model_id = os.environ.get("BODY_MASK_MODEL_ID")
    if not model_id:
        raise SystemExit(
            "Set BODY_MASK_MODEL_ID to the YOLO segmentation model to use."
        )

    url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
    image_name = url.split("/")[-1]
    labels = ["face", "hair", "phone", "hand"]

    # Load once here; BodyMask forwards its first argument to load_image.
    image = load_image(url)

    body_mask = BodyMask(
        image,
        model_id,
        overlay="box",
        labels=labels,
        widen_box=50,
        elongate_box=10,
        dilate_factor=0,
        resize_to=640,
        is_label=False,
        remove_overlap=True,
        verbose=False,
    )

    body_mask.body_box.save(image_name)
|
|