ttengwang
support "segment everything in a paragraph"
ccb14a3
import os
import time
import sys
import cv2
import hashlib
import requests
import numpy as np
from typing import Union
from PIL import Image
from tqdm import tqdm
def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
    """
    Convert an image given as a file path, a PIL image or a numpy array into
    the representation selected by ``return_type`` ('pil' or 'numpy').

    An input that is already in the requested representation is returned
    as-is (same object, no RGBA handling). Anything else is routed through
    PIL: RGBA images are flattened to RGB, then the PIL image is returned for
    'pil' or wrapped with ``np.asarray`` for 'numpy'.
    Raises NotImplementedError for any other ``return_type``.
    """
    already_pil = isinstance(image, Image.Image) and return_type == 'pil'
    already_numpy = isinstance(image, np.ndarray) and return_type == 'numpy'
    if already_pil or already_numpy:
        return image

    # Normalise every input to a PIL image first.
    if isinstance(image, str):
        pil_image = Image.open(image)
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        pil_image = image
    if pil_image.mode == "RGBA":
        pil_image = pil_image.convert("RGB")

    if return_type == 'pil':
        return pil_image
    if return_type == 'numpy':
        return np.asarray(pil_image)
    raise NotImplementedError()
def image_resize(image: Image.Image, res=1024):
    """
    Shrink a PIL image so that its longer side is at most ``res`` pixels.

    The aspect ratio is kept and new side lengths are truncated to int. Images
    that already fit are returned unchanged (never upscaled). A message with the
    old and new size is printed whenever scaling happens.
    """
    org_w, org_h = image.size
    scale = min(1.0 * res / max(org_w, org_h), 1.0)
    if scale < 1.0:
        resized = image.resize((int(org_w * scale), int(org_h * scale)))
        print('Scaling image from {} to {}'.format((org_w, org_h), resized.size))
        return resized
    return image
def xywh_to_x1y1x2y2(bbox):
    """Convert an (x, y, w, h) box into corner form (x1, y1, x2, y2), returned as a tuple."""
    left, top, width, height = bbox
    right, bottom = left + width, top + height
    return left, top, right, bottom
def x1y1x2y2_to_xywh(bbox):
    """Convert a corner-form box (x1, y1, x2, y2) into (x, y, w, h), returned as a tuple."""
    left, top, right, bottom = bbox
    return left, top, right - left, bottom - top
def get_image_shape(image):
    """
    Return the size of an image given as a file path, numpy array or PIL image.

    NOTE: the convention depends on the input type. Numpy arrays return
    ``image.shape`` (rows, cols[, channels]), while paths and PIL images return
    PIL's ``size``, which is (width, height).

    Raises:
        NotImplementedError: for any other input type.
    """
    if isinstance(image, str):
        # Image.open keeps the file open until the image is closed; the context
        # manager releases the handle once the header-derived size is read.
        with Image.open(image) as img:
            return img.size
    if isinstance(image, np.ndarray):
        return image.shape
    if isinstance(image, Image.Image):
        return image.size
    raise NotImplementedError
def is_platform_win():
    """Return True when the interpreter reports the Windows platform ("win32")."""
    platform_name = sys.platform
    return platform_name == "win32"
def colormap(rgb=True):
    """
    Return the fixed 81-entry colour palette as a float32 array of shape (81, 3).

    Entries are listed below in [0, 1] per channel and scaled to [0, 255].
    Index 0 is black and index 1 is white. With ``rgb=False`` the channel order
    is reversed (BGR) via a reversed-column view.
    """
    palette = [
        (0.000, 0.000, 0.000),  # 0: black
        (1.000, 1.000, 1.000),  # 1: white
        (1.000, 0.498, 0.313),
        (0.392, 0.581, 0.929),
        (0.000, 0.447, 0.741),
        (0.850, 0.325, 0.098),
        (0.929, 0.694, 0.125),
        (0.494, 0.184, 0.556),
        (0.466, 0.674, 0.188),
        (0.301, 0.745, 0.933),
        (0.635, 0.078, 0.184),
        (0.300, 0.300, 0.300),
        (0.600, 0.600, 0.600),
        (1.000, 0.000, 0.000),
        (1.000, 0.500, 0.000),
        (0.749, 0.749, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.333, 0.333, 0.000),
        (0.333, 0.667, 0.000),
        (0.333, 1.000, 0.000),
        (0.667, 0.333, 0.000),
        (0.667, 0.667, 0.000),
        (0.667, 1.000, 0.000),
        (1.000, 0.333, 0.000),
        (1.000, 0.667, 0.000),
        (1.000, 1.000, 0.000),
        (0.000, 0.333, 0.500),
        (0.000, 0.667, 0.500),
        (0.000, 1.000, 0.500),
        (0.333, 0.000, 0.500),
        (0.333, 0.333, 0.500),
        (0.333, 0.667, 0.500),
        (0.333, 1.000, 0.500),
        (0.667, 0.000, 0.500),
        (0.667, 0.333, 0.500),
        (0.667, 0.667, 0.500),
        (0.667, 1.000, 0.500),
        (1.000, 0.000, 0.500),
        (1.000, 0.333, 0.500),
        (1.000, 0.667, 0.500),
        (1.000, 1.000, 0.500),
        (0.000, 0.333, 1.000),
        (0.000, 0.667, 1.000),
        (0.000, 1.000, 1.000),
        (0.333, 0.000, 1.000),
        (0.333, 0.333, 1.000),
        (0.333, 0.667, 1.000),
        (0.333, 1.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.667, 0.333, 1.000),
        (0.667, 0.667, 1.000),
        (0.667, 1.000, 1.000),
        (1.000, 0.000, 1.000),
        (1.000, 0.333, 1.000),
        (1.000, 0.667, 1.000),
        (0.167, 0.000, 0.000),
        (0.333, 0.000, 0.000),
        (0.500, 0.000, 0.000),
        (0.667, 0.000, 0.000),
        (0.833, 0.000, 0.000),
        (1.000, 0.000, 0.000),
        (0.000, 0.167, 0.000),
        (0.000, 0.333, 0.000),
        (0.000, 0.500, 0.000),
        (0.000, 0.667, 0.000),
        (0.000, 0.833, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 0.167),
        (0.000, 0.000, 0.333),
        (0.000, 0.000, 0.500),
        (0.000, 0.000, 0.667),
        (0.000, 0.000, 0.833),
        (0.000, 0.000, 1.000),
        (0.143, 0.143, 0.143),
        (0.286, 0.286, 0.286),
        (0.429, 0.429, 0.429),
        (0.571, 0.571, 0.571),
        (0.714, 0.714, 0.714),
        (0.857, 0.857, 0.857),
    ]
    table = np.array(palette, dtype=np.float32) * 255
    return table if rgb else table[:, ::-1]
# Module-wide palette as nested Python lists of uint8 RGB triples (fractional
# values are truncated by astype); indexed by the painters' colour-index arguments.
color_list = colormap().astype('uint8').tolist()
def vis_add_mask(image, mask, color, alpha, kernel_size):
    """
    Tint ``image`` towards ``color`` where ``mask`` is low, with a Gaussian-softened edge.

    The mask is converted to float and blurred with a (kernel_size x kernel_size)
    Gaussian (kernel_size is also passed as sigmaX). It is then divided by 255 and
    scaled by ``alpha``, giving a per-pixel weight w (in [0, alpha] for a mask in
    [0, 255]). Channels 0..2 become ``pixel * (1 - alpha + w) + color * (alpha - w)``:
    w == alpha keeps the pixel, and w == 0 blends in ``color`` with weight alpha.

    ``image`` is written in place (results are cast to its dtype) and returned.
    """
    rgb = np.array(color)
    weight = mask.astype('float')
    weight = cv2.GaussianBlur(weight, (kernel_size, kernel_size), kernel_size) / 255. * alpha
    keep = 1 - alpha + weight
    tint = alpha - weight
    for channel in range(3):
        image[:, :, channel] = image[:, :, channel] * keep + rgb[channel] * tint
    return image
def vis_add_mask_wo_blur(image, mask, color, alpha):
    """
    Blend ``color`` into channels 0..2 of ``image`` using ``mask`` directly as the weight.

    Each channel becomes ``pixel * (1 - alpha + m) + color * (alpha - m)``. Unlike
    vis_add_mask, the mask is neither blurred nor divided by 255 nor scaled by alpha.
    Both weights are non-negative only when ``alpha - 1 <= m <= alpha``.
    NOTE(review): callers presumably pass a mask already in [0, alpha]; confirm.

    ``image`` is written in place (results are cast to its dtype) and returned.
    The caller's ``mask`` is not modified.
    """
    rgb = np.array(color)
    weight = mask.astype('float')
    keep = 1 - alpha + weight
    tint = alpha - weight
    for channel in range(3):
        image[:, :, channel] = image[:, :, channel] * keep + rgb[channel] * tint
    return image
def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha,
                             contour_alpha):
    """
    Paint a background colour and then a contour colour into channels 0..2 of ``image``.

    Each mask is a per-pixel "keep" weight: 1 leaves the pixel untouched, and 0
    blends in the corresponding colour with its alpha. For each channel the
    background blend is written into ``image`` first (cast to its dtype), and the
    contour blend is then applied on top of that stored value.

    ``image`` is modified in place; a uint8 copy of the result is returned.
    """
    bg_rgb = np.array(background_color)
    ct_rgb = np.array(contour_color)
    bg_keep = 1 - background_alpha + background_mask * background_alpha
    bg_tint = background_alpha - background_mask * background_alpha
    ct_keep = 1 - contour_alpha + contour_mask * contour_alpha
    ct_tint = contour_alpha - contour_mask * contour_alpha
    for ch in range(3):
        image[:, :, ch] = image[:, :, ch] * bg_keep + bg_rgb[ch] * bg_tint
        image[:, :, ch] = image[:, :, ch] * ct_keep + ct_rgb[ch] * ct_tint
    return image.astype('uint8')
def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3,
                 contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
    """
    Add a colour mask to the background (or foreground) area and highlight the mask contour.

    input_image: numpy array of shape (H, W, C) with C >= 3; its first two dims must
        equal input_mask.shape. It is NOT modified; painting happens on a copy.
    input_mask: numpy array of shape (H, W); any value > 0 is foreground. It is NOT
        modified. cv2.Canny expects an 8-bit image, so presumably uint8 (confirm with callers).
    background_alpha: strength of the background colour, [0, 1]; 1: fully replaced
        by the colour, 0: do nothing
    background_blur_radius: Gaussian kernel size (also used as sigma) that softens
        the mask edge; must be an odd number
    contour_width: width of the mask contour in pixels, must be an odd number
    contour_color: colour index (in color_list) of the contour, 0: black, 1: white, >1: others
    contour_alpha: strength of the contour colour, [0, 1]; 0: no contour highlighted
    background_color: colour index of the painted area
    paint_foreground: True to paint the foreground, False to paint the background. Default: False
    Output:
    painted_image: numpy array with the same shape and dtype as input_image
    """
    assert input_image.shape[:2] == input_mask.shape, 'different shape'
    assert (background_blur_radius % 2) * (contour_width % 2) > 0, \
        'background_blur_radius and contour_width must be ODD'
    # Work on copies. vis_add_mask writes into the image in place and the mask is
    # binarised below, so without copies the caller's arrays would be modified.
    # This also fails for read-only arrays such as np.asarray(PIL image).
    painted_image = input_image.copy()
    input_mask = input_mask.copy()
    # 0: background, 255: foreground
    input_mask[input_mask > 0] = 255
    # vis_add_mask keeps pixels where its mask is 255 and colours pixels where it is 0,
    # so invert the mask when the foreground should be painted.
    area_mask = 255 - input_mask if paint_foreground else input_mask
    painted_image = vis_add_mask(painted_image, area_mask, color_list[background_color], background_alpha,
                                 background_blur_radius)
    # mask contour: edge extraction, then widen the edge to contour_width
    contour_mask = cv2.Canny(input_mask, 100, 200)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
    contour_mask = cv2.dilate(contour_mask, kernel)
    painted_image = vis_add_mask(painted_image, 255 - contour_mask, color_list[contour_color], contour_alpha,
                                 contour_width)
    return painted_image
def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7,
                                contour_width=3, contour_color=3, contour_alpha=1):
    """
    Paint every mask's foreground area, each with its own palette colour.

    input_image: numpy array of shape (H, W, C)
    input_masks: iterable of masks, each a numpy array of shape (H, W)
    background_alpha: strength of the fill colour, [0, 1]; 0: do nothing
    background_blur_radius: Gaussian kernel size that softens the mask edges, must be odd
    contour_width: width of each mask contour, must be an odd number
    contour_color: colour index (in color_list) of the contours, 0: black, 1: white, >1: others
    contour_alpha: strength of the contour colour, [0, 1]; 0: no contour highlighted
    Fill colours start at palette index 2 and cycle through the rest of the palette,
    so any number of masks is supported. An empty input_masks returns input_image unchanged.
    Output:
    painted_image: numpy array
    """
    # Indices 0 and 1 are black and white; cycle through the remaining entries so
    # more masks than palette colours no longer raises IndexError.
    num_fill_colors = len(color_list) - 2
    for i, input_mask in enumerate(input_masks):
        input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
                                   contour_color, contour_alpha, background_color=i % num_fill_colors + 2,
                                   paint_foreground=True)
    return input_image
def mask_generator_00(mask, background_radius, contour_radius):
    """
    Build the mask pair for painting mode '00': hard background and hard contour.

    ``mask`` is expected to be a 0/1 array (``1 - mask`` is used as its inverse)
    and is returned unchanged as the background keep-weight. The contour weight
    is the absolute signed distance to the mask boundary, clipped to
    ``contour_radius + 2`` pixels and normalised by its maximum. Values above 0.5
    are then snapped to 1, leaving a sharp band near the boundary.
    ``background_radius`` is unused here (all generators share one signature).
    """
    inside = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    outside = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    signed_dist = inside - outside  # > 0 inside the mask, < 0 outside
    band = contour_radius + 2
    contour_weight = np.abs(np.clip(signed_dist, -band, band))
    contour_weight = contour_weight / contour_weight.max()
    contour_weight[contour_weight > 0.5] = 1.
    return mask, contour_weight
def mask_generator_01(mask, background_radius, contour_radius):
    """
    Build the mask pair for painting mode '01': hard background, soft contour.

    ``mask`` is expected to be a 0/1 array (``1 - mask`` is used as its inverse)
    and is returned unchanged as the background keep-weight. The contour weight
    is the absolute signed distance to the mask boundary, clipped to
    ``contour_radius + 2`` pixels and normalised by its maximum. It is not
    thresholded, so it rises gradually away from the boundary (soft contour).
    ``background_radius`` is unused here (all generators share one signature).
    """
    inside = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    outside = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    signed_dist = inside - outside  # > 0 inside the mask, < 0 outside
    band = contour_radius + 2
    contour_weight = np.abs(np.clip(signed_dist, -band, band))
    return mask, contour_weight / contour_weight.max()
def mask_generator_10(mask, background_radius, contour_radius):
    """
    Build the mask pair for painting mode '10': soft background, hard contour.

    ``mask`` is expected to be a 0/1 array (``1 - mask`` is used as its inverse).
    Background keep-weight: the signed distance to the boundary (positive inside
    the mask) is clipped to +-background_radius and rescaled to [0, 1], ramping from
    0 outside to 1 inside. When that ramp is constant (e.g. background_radius == 0)
    the hard 0/1 mask (inside -> 1) is used instead of dividing by zero.
    Contour keep-weight: absolute signed distance clipped to ``contour_radius + 2``,
    normalised by its maximum, with values above 0.5 snapped to 1 (sharp band).
    """
    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    dist_map = dist_transform_fore - dist_transform_back
    # .....:::::!!!!!  background ramp
    background_mask = np.clip(dist_map, -background_radius, background_radius)
    background_mask = background_mask - np.min(background_mask)
    peak = np.max(background_mask)
    if peak > 0:
        background_mask = background_mask / peak
    else:
        # A constant ramp would make the division produce NaN everywhere.
        background_mask = (dist_map > 0).astype(background_mask.dtype)
    # ...:::!!!:::...  contour band
    contour_radius += 2
    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
    contour_mask = contour_mask / np.max(contour_mask)
    contour_mask[contour_mask > 0.5] = 1.
    return background_mask, contour_mask
def mask_generator_11(mask, background_radius, contour_radius):
    """
    Build the mask pair for painting mode '11': soft background and soft contour.

    ``mask`` is expected to be a 0/1 array (``1 - mask`` is used as its inverse).
    Background keep-weight: the signed distance to the boundary (positive inside
    the mask) is clipped to +-background_radius and rescaled to [0, 1], ramping from
    0 outside to 1 inside. When that ramp is constant (e.g. background_radius == 0)
    the hard 0/1 mask (inside -> 1) is used instead of dividing by zero.
    Contour keep-weight: absolute signed distance clipped to ``contour_radius + 2``
    and normalised by its maximum, without thresholding (soft band).
    """
    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    dist_map = dist_transform_fore - dist_transform_back
    # .....:::::!!!!!  background ramp
    background_mask = np.clip(dist_map, -background_radius, background_radius)
    background_mask = background_mask - np.min(background_mask)
    peak = np.max(background_mask)
    if peak > 0:
        background_mask = background_mask / peak
    else:
        # A constant ramp would make the division produce NaN everywhere.
        background_mask = (dist_map > 0).astype(background_mask.dtype)
    # ...:::!!!:::...  contour band
    contour_radius += 2
    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
    contour_mask = contour_mask / np.max(contour_mask)
    return background_mask, contour_mask
def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3,
                             contour_color=3, contour_alpha=1, mode='11', res=1024):
    """
    Paint the background (black) and the mask contour using distance-transform ramps instead of Gaussian blur.

    Input:
    input_image: numpy array of shape (H, W, C), C >= 3
    input_mask: numpy array of shape (H, W); values are clipped to [0, 1], so for an
        integer mask any positive value is foreground. cv2.distanceTransform needs an
        8-bit mask, so presumably uint8 (confirm with callers).
    background_alpha: strength of the background colour, [0, 1], 1: all black, 0: do nothing
    background_blur_radius: width of the soft background ramp, must be an odd number
    contour_width: width of the mask contour, must be an odd number
    contour_color: colour index (in color_list) of the contour, 0: black, 1: white, >1: others
    contour_alpha: strength of the contour colour, [0, 1]; 0: no contour highlighted
    mode: painting mode, '00' no blur, '01' only blur contour, '10' only blur background, '11' blur both
    res: the inputs are downscaled so their longer side is at most ``res`` pixels (never upscaled)
    Output:
    painted_image: uint8 numpy array at the downscaled size (not necessarily the input size)
    """
    assert input_image.shape[:2] == input_mask.shape, 'different shape'
    assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
    assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
    # downsample input image and mask (cv2 dsize is (width, height)); cv2.resize
    # always returns a new array, so the caller's inputs are never painted on
    height, width = input_image.shape[:2]
    ratio = min(1.0 * res / max(height, width), 1.0)
    new_size = (int(width * ratio), int(height * ratio))
    input_image = cv2.resize(input_image, new_size)
    # Nearest-neighbour keeps the mask a label map. Bilinear interpolation creates
    # intermediate boundary values that the clip below turns into foreground.
    input_mask = cv2.resize(input_mask, new_size, interpolation=cv2.INTER_NEAREST)
    # 0: background, 1: foreground
    msk = np.clip(input_mask, 0, 1)
    # generate keep-weights for background and contour pixels
    background_radius = (background_blur_radius - 1) // 2
    contour_radius = (contour_width - 1) // 2
    generator_dict = {'00': mask_generator_00, '01': mask_generator_01, '10': mask_generator_10,
                      '11': mask_generator_11}
    background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
    # paint: black (color_list[0]) for background, then the contour colour
    painted_image = vis_add_mask_wo_gaussian(input_image, background_mask, contour_mask, color_list[0],
                                             color_list[contour_color], background_alpha, contour_alpha)
    return painted_image
# Size names accepted by prepare_segmenter -> SAM backbone identifiers.
seg_model_map = dict(base='vit_b', large='vit_l', huge='vit_h')
# Download URL of the checkpoint for each backbone identifier.
ckpt_url_map = {
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
}
# Expected SHA-256 hex digest of each checkpoint, compared after download.
expected_sha256_map = {
    'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
    'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
    'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e',
}
def prepare_segmenter(segmenter="huge", download_root: str = None):
    """
    Resolve a SAM size name to its backbone id and ensure the checkpoint is on disk.

    Args:
        segmenter: a key of ``seg_model_map`` ('base', 'large' or 'huge');
            any other value raises KeyError.
        download_root: checkpoint directory; falls back to ``~/.cache/SAM`` when falsy.
    Returns:
        (model name from 'vit_b' / 'vit_l' / 'vit_h', path to the verified checkpoint file)
    Side effect: creates a ``result`` directory in the current working directory.
    """
    os.makedirs('result', exist_ok=True)
    model_name = seg_model_map[segmenter]
    url = ckpt_url_map[model_name]
    cache_dir = download_root if download_root else os.path.expanduser("~/.cache/SAM")
    checkpoint_path = download_checkpoint(url, cache_dir, os.path.basename(url), expected_sha256_map[model_name])
    return model_name, checkpoint_path
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at ``path``, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_checkpoint(url, folder, filename, expected_sha256):
    """
    Ensure ``folder/filename`` exists with the expected SHA-256, downloading it from ``url`` if needed.

    An existing file with a matching digest is reused without any network access.
    Otherwise the file is streamed to ``<target>.part``, verified, and atomically
    moved into place, so an interrupted or corrupt download never replaces the target.

    Returns:
        Path of the verified checkpoint file.
    Raises:
        RuntimeError: if the server answers with an HTTP error status, or if the
            downloaded data does not match ``expected_sha256``.
    """
    os.makedirs(folder, exist_ok=True)
    download_target = os.path.join(folder, filename)
    # Hash in chunks: checkpoints are multi-GB, and reading them whole wastes memory.
    if os.path.isfile(download_target) and _sha256_of_file(download_target) == expected_sha256:
        return download_target
    print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
    partial_target = download_target + '.part'
    with requests.get(url, stream=True) as response:
        if not response.ok:
            raise RuntimeError(f'Failed to download {url}: HTTP {response.status_code}')
        total = int(response.headers.get('content-length', 0))
        with open(partial_target, "wb") as output, \
                tqdm(total=total, unit='B', unit_scale=True) as progress:
            for data in response.iter_content(chunk_size=1 << 20):
                progress.update(output.write(data))
    if _sha256_of_file(partial_target) != expected_sha256:
        os.remove(partial_target)
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
    os.replace(partial_target, download_target)
    return download_target
if __name__ == '__main__':
    background_alpha = 0.7  # transparency of background 1: all black, 0: do nothing
    background_blur_radius = 31  # radius of background blur, must be odd number
    contour_width = 11  # contour width, must be odd number
    contour_color = 3  # id in color map, 0: black, 1: white, >1: others
    contour_alpha = 1  # transparency of contour, 0: no contour highlighted
    num_iters = 50  # benchmark repetitions; reported times are averages over these
    # load input image and mask
    input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
    input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))
    # Modes in the order their timings are reported and their outputs saved.
    modes = ('00', '10', '01', '11')
    time_w_gaussian = 0.0
    time_wo_gaussian = {mode: 0.0 for mode in modes}
    painted_images = {}
    for _ in range(num_iters):
        for mode in modes:
            start = time.perf_counter()
            painted_images[mode] = mask_painter_wo_gaussian(input_image, input_mask, background_alpha,
                                                            background_blur_radius, contour_width, contour_color,
                                                            contour_alpha, mode=mode)
            time_wo_gaussian[mode] += time.perf_counter() - start
        # Pass copies so that any in-place writes by mask_painter (mask binarisation,
        # painting via vis_add_mask) cannot leak into later iterations or other modes.
        image_copy, mask_copy = input_image.copy(), input_mask.copy()
        start = time.perf_counter()
        mask_painter(image_copy, mask_copy, background_alpha, background_blur_radius, contour_width,
                     contour_color, contour_alpha)
        time_w_gaussian += time.perf_counter() - start
    print(f'average time w gaussian: {time_w_gaussian / num_iters}')
    for mode in modes:
        print(f'average time w/o gaussian{mode}: {time_wo_gaussian[mode] / num_iters}')
    # save
    for mode in modes:
        Image.fromarray(painted_images[mode]).save(f'./test_images/painter_output_image_{mode}.png')