"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
from PIL import Image
from carvekit.trimap.cv_gen import CV2TrimapGenerator
from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion


class TrimapGenerator(CV2TrimapGenerator):
    """
    Trimap generator layering probability-based post-processing on top of
    ``CV2TrimapGenerator``.

    The parent generator is always constructed with its own erosion disabled;
    erosion is instead applied as the last step of ``__call__`` via
    ``post_erosion``.
    """

    def __init__(
        self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5
    ):
        """
        Configure the generator.

        Args:
            prob_threshold: Threshold value forwarded to both ``prob_filter``
                and ``prob_as_unknown_area`` on every call.
            kernel_size: Forwarded unchanged to ``CV2TrimapGenerator``; per the
                upstream docs, the offset in pixels from the object mask used
                when the unknown area is marked.
            erosion_iters: Iteration count handed to ``post_erosion``, which is
                applied to the trimap after ``prob_as_unknown_area`` has run.
                The parent generator itself always receives ``erosion_iters=0``.
        """
        # Parent erosion is switched off; this class erodes at the end of __call__.
        super().__init__(kernel_size, erosion_iters=0)
        self.prob_threshold = prob_threshold
        # Double underscore => stored as _TrimapGenerator__erosion_iters, kept
        # distinct from whatever erosion setting the parent class keeps.
        self.__erosion_iters = erosion_iters

    def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image:
        """
        Produce a trimap from a predicted object mask, for refining its borders.

        Pipeline:
            1. ``prob_filter`` is applied to ``mask``.
            2. The parent's cv2-based ``__call__`` builds a trimap from the
               filtered mask.
            3. ``prob_as_unknown_area`` combines that trimap with the
               original (unfiltered) ``mask``.
            4. ``post_erosion`` is applied with the configured iteration count.

        Args:
            original_image: The source image.
            mask: The predicted object mask.

        Returns:
            The generated trimap image.
        """
        filtered_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold)
        base_trimap = super().__call__(original_image, filtered_mask)
        # Note: the *unfiltered* mask is used here, not filtered_mask.
        trimap_with_unknown = prob_as_unknown_area(
            trimap=base_trimap, mask=mask, prob_threshold=self.prob_threshold
        )
        return post_erosion(trimap_with_unknown, self.__erosion_iters)