File size: 2,368 Bytes
a3d6c18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import PIL.Image
import cv2
import numpy as np
class CV2TrimapGenerator:
    """Build a trimap from a predicted object mask using OpenCV morphology."""

    def __init__(self, kernel_size: int = 30, erosion_iters: int = 1):
        """
        Store the morphology settings used when the generator is called.

        Args:
            kernel_size: Offset, in pixels, from the object mask used when
                forming the unknown area of the trimap. The dilation kernel
                is a square of ones with side ``2 * kernel_size + 1``.
            erosion_iters: Number of 3x3 erosion iterations applied to the
                object mask before the unknown area is formed. Erosion is
                skipped when this value is not greater than zero.
        """
        self.kernel_size = kernel_size
        self.erosion_iters = erosion_iters

    def _shrink_mask(self, mask_array):
        """Return the mask with every pixel removed by erosion set to zero."""
        if self.erosion_iters > 0:
            small_kernel = np.ones((3, 3), np.uint8)
            eroded = cv2.erode(
                mask_array, small_kernel, iterations=self.erosion_iters
            )
            # Pixels that survive erosion keep their original mask value;
            # the eroded values themselves are only used as a keep/drop test.
            return np.where(eroded != 0, mask_array, 0)
        return mask_array.copy()

    def __call__(
        self, original_image: PIL.Image.Image, mask: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Generate a trimap from a predicted object mask.

        The mask is optionally eroded, then dilated with a square kernel.
        For a binary (0/255) mask the result is 255 inside the eroded
        object, 127 in the dilated band around it and 0 everywhere else.

        Args:
            original_image: Original image; only its size is compared
                against the mask.
            mask: Predicted object mask in "L" mode.

        Returns:
            The generated trimap as an "L" mode image.

        Raises:
            ValueError: If the mask is not in "L" mode, or if the mask and
                the original image differ in size.
        """
        if mask.mode != "L":
            raise ValueError("Input mask has wrong color mode.")
        if mask.size != original_image.size:
            raise ValueError("Sizes of input image and predicted mask doesn't equal")

        # noinspection PyTypeChecker
        mask_array = np.array(mask)
        side = 2 * self.kernel_size + 1
        dilation_kernel = np.ones((side, side), np.uint8)

        core = self._shrink_mask(mask_array)
        grown = cv2.dilate(core, dilation_kernel, iterations=1)

        # Fully white dilated pixels become the gray "unknown" value.
        grown = np.where(grown == 255, 127, grown)
        # 200 is a temporary marker for foreground so that it cannot be
        # confused with the gray value while the remaining levels are cleaned.
        trimap = np.where(core > 127, 200, grown)
        # Everything below gray or above the marker is treated as background.
        trimap[(trimap < 127) | (trimap > 200)] = 0
        # Promote the temporary marker to white foreground.
        trimap[trimap == 200] = 255
        return PIL.Image.fromarray(trimap).convert("L")
|