Spaces:
Runtime error
Runtime error
import io | |
from enum import Enum | |
from typing import List, Optional, Union | |
import numpy as np | |
from cv2 import ( | |
BORDER_DEFAULT, | |
MORPH_ELLIPSE, | |
MORPH_OPEN, | |
GaussianBlur, | |
getStructuringElement, | |
morphologyEx, | |
) | |
from PIL import Image | |
from PIL.Image import Image as PILImage | |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf | |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml | |
from pymatting.util.util import stack_images | |
from scipy.ndimage.morphology import binary_erosion | |
from .session_base import BaseSession | |
from .session_factory import new_session | |
# 3x3 elliptical structuring element, built once at import time and reused by
# post_process() for the morphological opening of predicted masks.
kernel = getStructuringElement(shape=MORPH_ELLIPSE, ksize=(3, 3))
class ReturnType(Enum):
    """Form in which ``remove`` hands back its result; it mirrors the input type."""

    BYTES = 0  # input was raw bytes -> result is PNG-encoded bytes
    PILLOW = 1  # input was a PIL image -> result is a PIL image
    NDARRAY = 2  # input was a numpy array -> result is a numpy array
def alpha_matting_cutout(
    img: PILImage,
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> PILImage:
    """
    Cut out the foreground of ``img`` with closed-form alpha matting.

    A trimap is built from ``mask``. Pixels above ``foreground_threshold`` are
    definite foreground (255), pixels below ``background_threshold`` are
    definite background (0), and everything else is unknown (128). Both
    definite regions are eroded first, which widens the unknown band around
    the object boundary that the matting solver refines.

    args:
        img: Source image. Any PIL mode is accepted; it is converted to RGB
            because the matting step works on three colour channels.
        mask: Segmentation mask. It is converted to 8-bit greyscale ("L") when
            it is in another mode, because the thresholds are 0..255 values.
        foreground_threshold: Mask values strictly above this are foreground.
        background_threshold: Mask values strictly below this are background.
        erode_structure_size: Side length of the square erosion structure.
            When <= 0, ``binary_erosion`` falls back to its default structuring
            element, so some erosion still happens.

    returns:
        The estimated foreground colours stacked with the estimated alpha,
        as a PIL image.
    """
    # Normalise every colour mode (RGBA, CMYK, L, LA, P, ...) to RGB. Previously
    # only RGBA/CMYK were converted, so e.g. greyscale or LA inputs reached the
    # solver with the wrong number of channels.
    if img.mode != "RGB":
        img = img.convert("RGB")
    # The thresholds are 8-bit intensities; a mode "1" mask would otherwise
    # yield booleans that never exceed the foreground threshold.
    if mask.mode != "L":
        mask = mask.convert("L")

    img_array = np.asarray(img)
    mask_array = np.asarray(mask)

    is_foreground = mask_array > foreground_threshold
    is_background = mask_array < background_threshold

    structure = None
    if erode_structure_size > 0:
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=np.uint8
        )

    is_foreground = binary_erosion(is_foreground, structure=structure)
    # border_value=1 treats pixels outside the image as background, so the
    # background region is not eroded away from the image edges.
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0

    # The matting routines operate on floats in [0, 1].
    img_normalized = img_array / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(cutout)
def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """
    Cut out ``img`` by using ``mask`` directly as the blend mask.

    Pixels come from ``img`` where the mask is set and from a fully
    transparent RGBA canvas of the same size elsewhere. No matting is done.
    """
    transparent_canvas = Image.new("RGBA", img.size, 0)
    return Image.composite(img, transparent_canvas, mask)
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
    """
    Stack ``imgs`` vertically, first image on top, into a single image.

    The caller's list is not modified. A one-element list returns that
    element itself.

    raises:
        IndexError: if ``imgs`` is empty.
    """
    # Index rather than pop(0): popping mutated the caller's list as a side effect.
    pivot = imgs[0]
    for im in imgs[1:]:
        pivot = get_concat_v(pivot, im)
    return pivot
def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
    """
    Place ``img2`` directly below ``img1`` on a new RGBA canvas.

    The canvas is as wide as the wider input, so neither image is cropped. A
    narrower image leaves the canvas's default fill (colour 0) on its right.
    For equal-width inputs the result is the same as before.
    """
    # Previously the width came from img1 only, which silently cropped a wider img2.
    width = max(img1.width, img2.width)
    dst = Image.new("RGBA", (width, img1.height + img2.height))
    dst.paste(img1, (0, 0))
    dst.paste(img2, (0, img1.height))
    return dst
def post_process(mask: np.ndarray) -> np.ndarray:
    """
    Smooth a binary mask's boundary with morphological operations.

    The mask is opened with the module-level elliptical ``kernel`` to remove
    small specks, blurred with a 5x5 Gaussian (sigma 2) to soften jagged
    edges, and then thresholded at 127 back to a 0/255 uint8 mask.
    Approach based on: https://www.sciencedirect.com/science/article/pii/S2352914821000757

    args:
        mask: Binary Numpy Mask
    """
    opened = morphologyEx(mask, MORPH_OPEN, kernel)
    blurred = GaussianBlur(
        opened, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT
    )
    # The blur introduces intermediate grey levels; snap them back to binary.
    return np.where(blurred >= 127, 255, 0).astype(np.uint8)
def remove(
    data: Union[bytes, PILImage, np.ndarray],
    alpha_matting: bool = False,
    alpha_matting_foreground_threshold: int = 240,
    alpha_matting_background_threshold: int = 10,
    alpha_matting_erode_size: int = 10,
    session: Optional[BaseSession] = None,
    only_mask: bool = False,
    post_process_mask: bool = False,
) -> Union[bytes, PILImage, np.ndarray]:
    """
    Remove the background from an image.

    args:
        data: Input image as encoded bytes, a PIL image, or a numpy array.
        alpha_matting: Refine each cutout with alpha matting. If matting
            raises ValueError, the naive cutout is used instead.
        alpha_matting_foreground_threshold: Trimap foreground threshold.
        alpha_matting_background_threshold: Trimap background threshold.
        alpha_matting_erode_size: Erosion structure size for the trimap.
        session: Session whose ``predict`` yields the masks. When None, a
            new "u2net" session is created on every call.
        only_mask: Return the predicted mask(s) instead of cutouts.
        post_process_mask: Smooth each mask with ``post_process`` first.

    returns:
        The result in the same form as ``data``: a PIL image, a numpy array,
        or PNG-encoded bytes. Several masks produce results stacked
        vertically; no masks returns the input image itself.

    raises:
        ValueError: if ``data`` is not bytes, a PIL image, or a numpy array.
    """
    # Decode the input and remember its kind so the output can mirror it.
    if isinstance(data, PILImage):
        img, return_type = data, ReturnType.PILLOW
    elif isinstance(data, bytes):
        img, return_type = Image.open(io.BytesIO(data)), ReturnType.BYTES
    elif isinstance(data, np.ndarray):
        img, return_type = Image.fromarray(data), ReturnType.NDARRAY
    else:
        raise ValueError("Input type {} is not supported.".format(type(data)))

    if session is None:
        session = new_session("u2net")

    cutouts = []
    for mask in session.predict(img):
        if post_process_mask:
            mask = Image.fromarray(post_process(np.array(mask)))

        if only_mask:
            cutouts.append(mask)
            continue

        if alpha_matting:
            try:
                cutout = alpha_matting_cutout(
                    img,
                    mask,
                    alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold,
                    alpha_matting_erode_size,
                )
            except ValueError:
                # Best effort: fall back to the plain mask cutout.
                cutout = naive_cutout(img, mask)
        else:
            cutout = naive_cutout(img, mask)
        cutouts.append(cutout)

    result = get_concat_v_multi(cutouts) if cutouts else img

    if return_type is ReturnType.PILLOW:
        return result
    if return_type is ReturnType.NDARRAY:
        return np.asarray(result)

    with io.BytesIO() as buffer:
        result.save(buffer, "PNG")
        return buffer.getvalue()