import io
from enum import Enum
from typing import List, Optional, Union

import numpy as np
from cv2 import (
    BORDER_DEFAULT,
    MORPH_ELLIPSE,
    MORPH_OPEN,
    GaussianBlur,
    getStructuringElement,
    morphologyEx,
)
from PIL import Image
from PIL.Image import Image as PILImage
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from scipy.ndimage import binary_erosion

from .session_base import BaseSession
from .session_factory import new_session

# 3x3 elliptical structuring element used by post_process() for the
# morphological opening step.
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))


class ReturnType(Enum):
    """Output container chosen by remove() to mirror the type of its input."""

    BYTES = 0
    PILLOW = 1
    NDARRAY = 2


def alpha_matting_cutout(
    img: PILImage,
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> PILImage:
    """Cut the foreground out of ``img`` using alpha matting guided by ``mask``.

    A trimap is built from the mask: pixels strictly above
    ``foreground_threshold`` are marked foreground (255), pixels strictly
    below ``background_threshold`` are marked background (0), and everything
    else is left unknown (128). Both regions are eroded before being written
    into the trimap. The trimap then drives ``estimate_alpha_cf`` and
    ``estimate_foreground_ml``.

    Args:
        img: Source image. RGBA and CMYK inputs are converted to RGB first;
            other modes are passed through unchanged.
        mask: Mask whose pixel values are compared against the thresholds.
        foreground_threshold: Mask values above this are sure foreground.
        background_threshold: Mask values below this are sure background.
        erode_structure_size: Side length of the square erosion structure.
            When not positive, ``structure=None`` is passed and scipy falls
            back to its default structuring element (erosion still happens).

    Returns:
        The estimated foreground stacked with the estimated alpha, as a PIL
        image built from a uint8 array.
    """
    if img.mode in ("RGBA", "CMYK"):
        img = img.convert("RGB")

    img_array = np.asarray(img)
    mask_array = np.asarray(mask)

    is_foreground = mask_array > foreground_threshold
    is_background = mask_array < background_threshold

    structure = None
    if erode_structure_size > 0:
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=np.uint8
        )

    is_foreground = binary_erosion(is_foreground, structure=structure)
    # border_value=1 treats everything outside the image as background, so
    # the background region is not eroded away at the image border.
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0

    img_normalized = img_array / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(cutout)


def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """Composite ``img`` over a zero-filled RGBA canvas using ``mask``.

    Args:
        img: Source image.
        mask: Mask passed to ``Image.composite`` to select pixels of ``img``.

    Returns:
        The composited cutout image.
    """
    empty = Image.new("RGBA", (img.size), 0)
    return Image.composite(img, empty, mask)


def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
    """Stack all images in ``imgs`` vertically, top to bottom.

    The input list is left untouched (it is only read, never popped from).

    Args:
        imgs: Non-empty list of images to stack, in order.

    Returns:
        The first image itself when the list has one element, otherwise a
        new RGBA image containing every input stacked vertically.

    Raises:
        IndexError: If ``imgs`` is empty.
    """
    pivot = imgs[0]
    for im in imgs[1:]:
        pivot = get_concat_v(pivot, im)
    return pivot


def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
    """Return a new RGBA image with ``img2`` pasted directly below ``img1``.

    The canvas takes its width from ``img1`` and its height from the sum of
    both image heights.
    """
    dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
    dst.paste(img1, (0, 0))
    dst.paste(img2, (0, img1.height))
    return dst


def post_process(mask: np.ndarray) -> np.ndarray:
    """Smooth the boundary of a mask with morphological operations.

    The mask is opened with the module-level ``kernel``, blurred with a 5x5
    Gaussian (sigma 2), and then thresholded back to a binary mask.

    Based on the paper:
    https://www.sciencedirect.com/science/article/pii/S2352914821000757

    Args:
        mask: Binary numpy mask.

    Returns:
        A uint8 mask containing only the values 0 and 255.
    """
    mask = morphologyEx(mask, MORPH_OPEN, kernel)
    mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
    # The blur introduces intermediate grey levels; convert again to binary.
    mask = np.where(mask < 127, 0, 255).astype(np.uint8)
    return mask


def remove(
    data: Union[bytes, PILImage, np.ndarray],
    alpha_matting: bool = False,
    alpha_matting_foreground_threshold: int = 240,
    alpha_matting_background_threshold: int = 10,
    alpha_matting_erode_size: int = 10,
    session: Optional[BaseSession] = None,
    only_mask: bool = False,
    post_process_mask: bool = False,
) -> Union[bytes, PILImage, np.ndarray]:
    """Remove the background from an image.

    The result is returned in the same kind of container as the input:
    bytes in gives PNG-encoded bytes out, a PIL image gives a PIL image, and
    a numpy array gives a numpy array.

    Args:
        data: Input image as encoded bytes, a PIL image, or a numpy array.
        alpha_matting: Use alpha matting for the cutout. If matting raises
            ``ValueError``, the naive cutout is used instead.
        alpha_matting_foreground_threshold: Foreground threshold forwarded
            to ``alpha_matting_cutout``.
        alpha_matting_background_threshold: Background threshold forwarded
            to ``alpha_matting_cutout``.
        alpha_matting_erode_size: Erosion size forwarded to
            ``alpha_matting_cutout``.
        session: Prediction session; a new ``"u2net"`` session is created
            when omitted.
        only_mask: Return the (optionally post-processed) mask instead of
            the cutout.
        post_process_mask: Run ``post_process`` on each predicted mask.

    Returns:
        The cutout (or mask). When the session yields several masks, the
        per-mask results are stacked vertically; when it yields none, the
        input image is returned unchanged.

    Raises:
        ValueError: If ``data`` is not bytes, a PIL image, or a numpy array.
    """
    if isinstance(data, PILImage):
        return_type = ReturnType.PILLOW
        img = data
    elif isinstance(data, bytes):
        return_type = ReturnType.BYTES
        img = Image.open(io.BytesIO(data))
    elif isinstance(data, np.ndarray):
        return_type = ReturnType.NDARRAY
        img = Image.fromarray(data)
    else:
        raise ValueError("Input type {} is not supported.".format(type(data)))

    if session is None:
        session = new_session("u2net")

    masks = session.predict(img)
    cutouts = []

    for mask in masks:
        if post_process_mask:
            mask = Image.fromarray(post_process(np.array(mask)))

        if only_mask:
            cutout = mask
        elif alpha_matting:
            try:
                cutout = alpha_matting_cutout(
                    img,
                    mask,
                    alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold,
                    alpha_matting_erode_size,
                )
            except ValueError:
                # Best-effort: fall back to the simple composite when
                # matting rejects the input.
                cutout = naive_cutout(img, mask)
        else:
            cutout = naive_cutout(img, mask)

        cutouts.append(cutout)

    # With no masks there is nothing to cut out, so hand back the input.
    cutout = get_concat_v_multi(cutouts) if cutouts else img

    if return_type is ReturnType.PILLOW:
        return cutout

    if return_type is ReturnType.NDARRAY:
        return np.asarray(cutout)

    bio = io.BytesIO()
    cutout.save(bio, "PNG")
    return bio.getvalue()