yolo-human-parse / yolo /BodyMask.py
MnLgt's picture
"hi"
6706230
raw
history blame
8.09 kB
import os
from functools import lru_cache
from typing import List
import cv2
import numpy as np
from diffusers.utils import load_image
from PIL import Image, ImageChops, ImageFilter
from ultralytics import YOLO
from .utils import *
def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
    """
    Grow a mask, trim it back slightly, and feather its edges.

    Parameters:
    - mask: A PIL image or array-like mask. None or an empty mask yields None.
    - dilate_factor (int): Side length of the square dilation kernel.
    - blur_radius (int): Radius used by both blur passes; the OpenCV pass uses a
      (2 * blur_radius + 1) square kernel.
    - erosion_factor (int): Side length of the square erosion kernel.

    Returns:
    - The smoothed mask as a PIL image in mode "L", or None when there is no mask.
    """
    # Test for "no mask" explicitly: the truth value of a multi-element NumPy
    # array is ambiguous and raises ValueError, so `not mask` cannot be used.
    if mask is None:
        return None

    # np.asarray accepts PIL images as well as arrays and other array-likes.
    mask_array = np.asarray(mask)
    if mask_array.size == 0:
        return None

    # OpenCV morphology below operates on uint8 data.
    mask_array = mask_array.astype(np.uint8)

    # NOTE(review): a factor of 0 builds an empty kernel. OpenCV documents that
    # an empty structuring element falls back to a 3x3 rectangle -- confirm this
    # is the intended behaviour for callers that pass dilate_factor=0.
    dilation_kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    dilated = cv2.dilate(mask_array, dilation_kernel, iterations=1)

    # Erode afterwards to pull the grown outline back in slightly.
    erosion_kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
    eroded = cv2.erode(dilated, erosion_kernel, iterations=1)

    # First smoothing pass (OpenCV); the kernel size must be odd.
    kernel_size = 2 * blur_radius + 1
    blurred = cv2.GaussianBlur(eroded, (kernel_size, kernel_size), 0)

    # Second smoothing pass (PIL) on the single-channel image.
    smoothed_mask = Image.fromarray(blurred).convert("L")
    return smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
@lru_cache(maxsize=1)
def get_model(model_id):
    """Build the YOLO model for `model_id`; the most recent result is cached."""
    return YOLO(model=model_id)
def combine_masks(masks: List[dict], labels: List[str], is_label=True) -> Image.Image:
    """
    Merge the selected masks into one image, keeping the lighter pixel at each position.

    Parameters:
    - masks (List[dict]): Entries holding a mask image under 'mask' and its label under 'label'.
    - labels (List[str]): Labels used to select masks.
    - is_label (bool): When True, masks whose label is in `labels` are selected;
      when False, masks whose label is NOT in `labels` are selected.

    Returns:
    - Image.Image: The merged mask, or None when no mask is selected.
    """
    wanted = frozenset(labels)  # hashed membership test per entry

    # Collect every selected mask, converted to single-channel, before merging.
    selected = []
    for entry in masks:
        if (entry["label"] in wanted) != is_label:
            continue
        selected.append(entry["mask"].convert("L"))

    # Nothing matched the selection.
    if not selected:
        return None

    # Start from the first mask and fold the rest in one at a time.
    merged, *others = selected
    for layer in others:
        merged = ImageChops.lighter(merged, layer)
    return merged
# Default labels used by BodyMask when the caller does not supply any.
body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]


class BodyMask:
    """
    Run a YOLO segmentation model on an image and derive mask images from it.

    Construction loads the image, runs the model, and stores:
    - image: the loaded (optionally resized) image.
    - results: formatted per-detection entries, read via their 'label' and 'mask' keys.
    - body_mask: the combined (and dilated) mask for the label selection.
    - box: the result of get_bounding_box(body_mask).
    - body_box: the bounding-box mask, optionally with overlapping regions cleared.
    - overlay: the image with either body_box or body_mask drawn over it.

    The helpers load_image-independent helpers used here (get_bounding_box,
    get_bounding_box_mask, overlay_mask, unload, format_results,
    filter_highest_score, display_image_with_masks and the resize helpers)
    are not defined in this file -- presumably they come from `.utils`.
    """

    def __init__(
        self,
        image_path,
        model_id,
        labels=body_labels,
        overlay="mask",
        widen_box=0,
        elongate_box=0,
        resize_to=640,
        dilate_factor=0,
        is_label=False,
        resize_to_nearest_eight=False,
        verbose=True,
        remove_overlap=True,
    ):
        """
        Parameters:
        - image_path: Passed to load_image.
        - model_id: Passed to get_model to obtain the YOLO model.
        - labels: Labels used to select masks (see is_label).
        - overlay: "box" overlays body_box on the image; anything else overlays body_mask.
        - widen_box / elongate_box: Forwarded to get_bounding_box_mask.
        - resize_to: When truthy, forwarded to resize_preserve_aspect_ratio.
        - dilate_factor: Forwarded to dilate_mask.
        - is_label: True selects masks whose label is in `labels`; False selects the others.
        - resize_to_nearest_eight: When truthy, applies resize_image_to_nearest_eight.
        - verbose: Forwarded to the model call.
        - remove_overlap: When truthy, clears body_box where the labelled masks are white.
        """
        self.image_path = image_path
        self.image = self.get_image(
            resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
        )
        # Copy so edits to self.labels never leak into the shared module-level
        # default list (or into the caller's own list).
        self.labels = list(labels)
        self.is_label = is_label
        self.model_id = model_id
        self.model = get_model(self.model_id)
        self.model_labels = self.model.names
        self.verbose = verbose
        self.results = self.get_results()
        self.dilate_factor = dilate_factor
        self.body_mask = self.get_body_mask()
        self.box = get_bounding_box(self.body_mask)
        self.body_box = self.get_body_box(
            remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
        )
        # Pick which mask is drawn over the image.
        overlay_source = self.body_box if overlay == "box" else self.body_mask
        self.overlay = overlay_mask(
            self.image, overlay_source, opacity=0.9, color="red"
        )

    def get_image(self, resize_to, resize_to_nearest_eight):
        """Load the image from self.image_path and apply the requested resizes."""
        image = load_image(self.image_path)
        if resize_to:
            image = resize_preserve_aspect_ratio(image, resize_to)
        if resize_to_nearest_eight:
            image = resize_image_to_nearest_eight(image)
        return image

    def get_body_mask(self):
        """Combine the selected result masks and pass them through dilate_mask."""
        body_mask = combine_masks(self.results, self.labels, self.is_label)
        return dilate_mask(body_mask, self.dilate_factor)

    def get_results(self):
        """
        Run the model on self.image and return the formatted, filtered results.

        Also stores self.masks, self.boxes, self.scores and self.phrases as
        returned by unload().
        """
        # Run inference at the size of the image's longest side.
        imgsz = max(self.image.size)
        results = self.model(
            self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
        )[0]
        self.masks, self.boxes, self.scores, self.phrases = unload(
            results, self.model_labels
        )
        results = format_results(
            self.masks,
            self.boxes,
            self.scores,
            self.phrases,
            self.model_labels,
            person_masks_only=False,
        )
        # filter out lower score results
        # (presumably keeps only the best-scoring entry per listed label -- verify
        # against filter_highest_score)
        results = filter_highest_score(results, ["hair", "face", "phone"])
        return results

    def display_results(self):
        """Display the image with its result masks, using at most four columns."""
        cols = min(len(self.masks), 4)
        display_image_with_masks(self.image, self.results, cols=cols)

    def get_mask(self, mask_label):
        """
        Return the result entries whose 'label' equals mask_label.

        Raises:
        - AssertionError: If mask_label is not among self.phrases.
        """
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O`; the exception type is unchanged for existing callers.
        if mask_label not in self.phrases:
            raise AssertionError("Mask label not found in results")
        return [f for f in self.results if f.get("label") == mask_label]

    def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
        """
        Combine the masks included in the labels list or all of the masks not in the list.

        Parameters:
        - mask_labels: Labels to combine (is_label=True) or to exclude (is_label=False).
        - no_labels: Unused; kept so existing callers keep working.
        - is_label: When False, every phrase in self.phrases that is not in
          mask_labels is combined instead.

        Returns the combined mask, or None when no result entry matches.
        """
        if not is_label:
            mask_labels = [
                phrase for phrase in self.phrases if phrase not in mask_labels
            ]
        masks = [
            row.get("mask") for row in self.results if row.get("label") in mask_labels
        ]
        if not masks:
            return None
        combined_mask = masks[0]
        for mask in masks[1:]:
            combined_mask = ImageChops.lighter(combined_mask, mask)
        return combined_mask

    def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
        """Build the bounding-box mask for body_mask, optionally clearing overlaps."""
        body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
        if remove_overlap:
            body_box = self.remove_overlap(body_box)
        return body_box

    def remove_overlap(self, body_box):
        """
        Black out the parts of body_box that are white (255) in the combined
        mask of the labels in self.labels.

        NOTE(review): self.labels is always combined with is_label=True here,
        regardless of self.is_label -- confirm that is intended when
        is_label=True was passed to the constructor.
        """
        # convert mask to numpy array
        box_array = np.array(body_box)
        # combine the masks for those labels
        mask = self.combine_masks(mask_labels=self.labels, is_label=True)
        # With no labelled masks there is nothing to remove.
        if mask is not None:
            mask_array = np.array(mask)
            # where the mask array is white set the box array to black
            box_array[mask_array == 255] = 0
        # convert the box array to an image
        return Image.fromarray(box_array)
if __name__ == "__main__":
    import sys

    # BodyMask requires a model id as its second positional argument; the
    # original call omitted it and failed with a TypeError. Take it from the
    # command line instead of hard-coding a weights file.
    if len(sys.argv) < 2:
        raise SystemExit("usage: python -m <package>.BodyMask MODEL_ID")
    model_id = sys.argv[1]

    url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
    # Save the output under the same file name as the source image.
    image_name = url.split("/")[-1]
    labels = ["face", "hair", "phone", "hand"]
    image = load_image(url)

    # Create body mask
    body_mask = BodyMask(
        image,
        model_id,
        overlay="box",
        labels=labels,
        widen_box=50,
        elongate_box=10,
        dilate_factor=0,
        resize_to=640,
        is_label=False,
        remove_overlap=True,
        verbose=False,
    )

    # NOTE(review): the saved box mask is at the resized (resize_to) resolution.
    # The original script resized a copy of the image back to its original size
    # but never used the result; resize body_box here if the original
    # resolution is wanted.
    body_mask.body_box.save(image_name)