"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
from typing import Union, List, Optional

from PIL import Image

from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.utils.image_utils import load_image
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing


class Interface:
    """
    Callable that chains the optional pre-processing pipeline, the
    segmentation network and the optional post-processing pipeline in order
    to remove the background from a batch of images.
    """

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[PreprocessingStub] = None,
        post_pipe: Optional[MattingMethod] = None,
        device="cpu",
    ):
        """
        Initializes an object for interacting with pipelines and other components of the CarveKit framework.

        Args:
            seg_pipe: Initialized segmentation network object
            pre_pipe: Initialized pre-processing pipeline object
            post_pipe: Initialized postprocessing pipeline object
            device: The processing device that will be used to apply the masks to the images.
        """
        self.device = device
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Every input is first passed through ``load_image`` (dispatched via
        ``thread_pool_processing``). Masks are then obtained from the
        pre-processing pipeline when one is configured, otherwise directly
        from the segmentation pipeline. Finally the masks are combined with
        the images either by the post-processing pipeline or, when none is
        configured, by calling ``apply_mask`` on each image/mask pair.

        Args:
            images: list of input images

        Returns:
            List of images without background as PIL.Image.Image instances
        """
        # Keep the loaded images under their own name instead of rebinding the
        # ``images`` parameter, whose declared element type is wider.
        loaded = thread_pool_processing(load_image, images)

        masks: List[Image.Image]
        if self.preprocessing_pipeline is not None:
            # The pre-processing pipeline receives this interface object
            # itself, alongside the loaded images.
            masks = self.preprocessing_pipeline(interface=self, images=loaded)
        else:
            masks = self.segmentation_pipeline(images=loaded)

        if self.postprocessing_pipeline is not None:
            return self.postprocessing_pipeline(images=loaded, masks=masks)

        # No post-processing configured: apply each mask to its image directly.
        # ``masks`` is indexed (rather than zipped) on purpose so that a mask
        # list shorter than the image list raises IndexError instead of
        # silently dropping the trailing images.
        return [
            apply_mask(image=image, mask=masks[index], device=self.device)
            for index, image in enumerate(loaded)
        ]