"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from carvekit.ml.wrap.fba_matting import FBAMatting
from typing import Union, List
from PIL import Image
from pathlib import Path
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.generator import TrimapGenerator
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing
from carvekit.utils.image_utils import load_image, convert_image

__all__ = ["MattingMethod"]


class MattingMethod:
    """
    Refines the edges of an object mask by combining a trimap generator
    with a matting neural network.

    The matting network locates object edges precisely with the help of a
    trimap: a map that marks the already known object area, the background,
    and an unknown area that is scanned for the boundary.
    """

    def __init__(
        self,
        matting_module: Union[FBAMatting],
        trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator],
        device="cpu",
    ):
        """
        Stores the collaborators used by the matting pipeline.

        Args:
            matting_module: Initialized matting neural network class
            trimap_generator: Initialized trimap generator class
            device: Processing device passed to apply_mask when the
                computed alpha is applied to an image
        """
        self.matting_module = matting_module
        self.trimap_generator = trimap_generator
        self.device = device

    def __call__(
        self,
        images: List[Union[str, Path, Image.Image]],
        masks: List[Union[str, Path, Image.Image]],
    ):
        """
        Runs the full pipeline: loads the inputs, builds one trimap per
        image/mask pair, feeds images and trimaps to the matting module and
        applies each resulting alpha to its image with apply_mask.

        Args:
            images: list of images
            masks: list of masks, one per image

        Returns:
            list with the apply_mask result for every input image

        Raises:
            ValueError: if images and masks differ in length
        """
        if len(images) != len(masks):
            raise ValueError("Images and Masks lists should have same length!")

        def _prepare_image(source):
            # Load the image and convert it with convert_image's default mode.
            return convert_image(load_image(source))

        def _prepare_mask(source):
            # Masks are converted to single-channel "L" mode.
            return convert_image(load_image(source), mode="L")

        loaded_images = thread_pool_processing(_prepare_image, images)
        loaded_masks = thread_pool_processing(_prepare_mask, masks)

        def _build_trimap(idx):
            # Pair each image with the mask at the same position.
            return self.trimap_generator(
                original_image=loaded_images[idx], mask=loaded_masks[idx]
            )

        trimaps = thread_pool_processing(_build_trimap, range(len(loaded_images)))
        alpha = self.matting_module(images=loaded_images, trimaps=trimaps)

        return [
            apply_mask(image=picture, mask=alpha[idx], device=self.device)
            for idx, picture in enumerate(loaded_images)
        ]