Towsif7's picture
first commit
59e40e1
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from pathlib import Path
from typing import Union, List, Optional
from PIL import Image
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.utils.image_utils import load_image
from carvekit.utils.mask_utils import apply_mask
from carvekit.utils.pool_utils import thread_pool_processing
class Interface:
    """
    Callable entry point that chains the CarveKit pipelines together.

    An instance holds an optional pre-processing pipeline, a segmentation
    network and an optional post-processing pipeline. Calling the instance
    with a list of images runs them through those stages and returns the
    resulting images.
    """

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[Union[PreprocessingStub]] = None,
        post_pipe: Optional[Union[MattingMethod]] = None,
        device="cpu",
    ):
        """
        Initializes an object for interacting with pipelines and other components of the CarveKit framework.

        Args:
            seg_pipe: Initialized segmentation network object
            pre_pipe: Initialized pre-processing pipeline object. When None,
                the segmentation network is called directly on the images.
            post_pipe: Initialized postprocessing pipeline object. When None,
                the masks are applied to the images with ``apply_mask``.
            device: The processing device that will be used to apply the masks to the images.
        """
        self.device = device
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Args:
            images: list of input images. Every entry is passed through
                ``load_image`` before any pipeline sees it.

        Returns:
            List of images without background as PIL.Image.Image instances
        """
        # Keep the loaded images under their own name: the parameter is typed
        # as "path or image", while everything below works on loaded images.
        loaded = thread_pool_processing(load_image, images)

        masks: List[Image.Image]
        if self.preprocessing_pipeline is not None:
            # The pre-processing pipeline is handed this interface together
            # with the images and is the one that returns the masks.
            masks = self.preprocessing_pipeline(interface=self, images=loaded)
        else:
            masks = self.segmentation_pipeline(images=loaded)

        if self.postprocessing_pipeline is not None:
            return self.postprocessing_pipeline(images=loaded, masks=masks)

        # No post-processing configured: apply each mask to its image as-is.
        # Masks are looked up by position so that a mask list shorter than the
        # image list still fails loudly (IndexError) instead of silently
        # dropping images.
        return [
            apply_mask(image=image, mask=masks[index], device=self.device)
            for index, image in enumerate(loaded)
        ]