|
""" |
|
Source url: https://github.com/OPHoperHPO/image-background-remove-tool |
|
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. |
|
License: Apache License 2.0 |
|
""" |
|
from pathlib import Path |
|
from typing import Union, List, Optional |
|
|
|
from PIL import Image |
|
|
|
from carvekit.ml.wrap.basnet import BASNET |
|
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 |
|
from carvekit.ml.wrap.u2net import U2NET |
|
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 |
|
from carvekit.pipelines.preprocessing import PreprocessingStub |
|
from carvekit.pipelines.postprocessing import MattingMethod |
|
from carvekit.utils.image_utils import load_image |
|
from carvekit.utils.mask_utils import apply_mask |
|
from carvekit.utils.pool_utils import thread_pool_processing |
|
|
|
|
|
class Interface: |
|
def __init__( |
|
self, |
|
seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7], |
|
pre_pipe: Optional[Union[PreprocessingStub]] = None, |
|
post_pipe: Optional[Union[MattingMethod]] = None, |
|
device="cpu", |
|
): |
|
""" |
|
Initializes an object for interacting with pipelines and other components of the CarveKit framework. |
|
|
|
Args: |
|
pre_pipe: Initialized pre-processing pipeline object |
|
seg_pipe: Initialized segmentation network object |
|
post_pipe: Initialized postprocessing pipeline object |
|
device: The processing device that will be used to apply the masks to the images. |
|
""" |
|
self.device = device |
|
self.preprocessing_pipeline = pre_pipe |
|
self.segmentation_pipeline = seg_pipe |
|
self.postprocessing_pipeline = post_pipe |
|
|
|
def __call__( |
|
self, images: List[Union[str, Path, Image.Image]] |
|
) -> List[Image.Image]: |
|
""" |
|
Removes the background from the specified images. |
|
|
|
Args: |
|
images: list of input images |
|
|
|
Returns: |
|
List of images without background as PIL.Image.Image instances |
|
""" |
|
images = thread_pool_processing(load_image, images) |
|
if self.preprocessing_pipeline is not None: |
|
masks: List[Image.Image] = self.preprocessing_pipeline( |
|
interface=self, images=images |
|
) |
|
else: |
|
masks: List[Image.Image] = self.segmentation_pipeline(images=images) |
|
|
|
if self.postprocessing_pipeline is not None: |
|
images: List[Image.Image] = self.postprocessing_pipeline( |
|
images=images, masks=masks |
|
) |
|
else: |
|
images = list( |
|
map( |
|
lambda x: apply_mask( |
|
image=images[x], mask=masks[x], device=self.device |
|
), |
|
range(len(images)), |
|
) |
|
) |
|
return images |
|
|