Spaces:
Runtime error
Runtime error
import os | |
import time | |
import sys | |
import cv2 | |
import hashlib | |
import requests | |
import numpy as np | |
from typing import Union | |
from PIL import Image | |
from tqdm import tqdm | |
def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
    """
    Return ``image`` as either a PIL image or a numpy array.

    image: a file path, a ``PIL.Image.Image`` or a ``numpy.ndarray``.
    return_type: ``'pil'`` or ``'numpy'``.

    An input that already has the requested type is handed back untouched
    (no copy and no RGBA handling). Any other input is routed through PIL;
    RGBA images are converted to RGB on that path.

    Raises NotImplementedError for any other ``return_type`` (only after the
    image has been loaded/converted).
    """
    already_pil = isinstance(image, Image.Image) and return_type == 'pil'
    already_numpy = isinstance(image, np.ndarray) and return_type == 'numpy'
    if already_pil or already_numpy:
        return image

    # PIL is the common intermediate representation.
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Drop the alpha channel so downstream code always sees three channels.
    if image.mode == "RGBA":
        image = image.convert("RGB")

    if return_type == 'pil':
        return image
    if return_type == 'numpy':
        return np.asarray(image)
    raise NotImplementedError()
def image_resize(image: Image.Image, res=1024):
    """
    Downscale ``image`` so that its longest side is at most ``res`` pixels.

    image: PIL image.
    res: maximum allowed length of the longest side.

    The aspect ratio is preserved and images that already fit are returned
    unchanged (never upscaled). Returns the (possibly resized) PIL image.
    """
    width, height = org_size = image.size
    ratio = min(1.0 * res / max(width, height), 1.0)
    if ratio < 1.0:
        # int() truncation can yield 0 for very elongated images (e.g. 1 x 5000),
        # which resize() rejects, so keep every side at least one pixel long.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        image = image.resize(new_size)
        print('Scaling image from {} to {}'.format(org_size, image.size))
    return image
def xywh_to_x1y1x2y2(bbox):
    """Convert a box given as (x, y, width, height) into corner form (x1, y1, x2, y2)."""
    left, top, box_w, box_h = bbox
    right = left + box_w
    bottom = top + box_h
    return left, top, right, bottom
def x1y1x2y2_to_xywh(bbox):
    """Convert a box given by its corners (x1, y1, x2, y2) into (x, y, width, height)."""
    left, top, right, bottom = bbox
    box_w = right - left
    box_h = bottom - top
    return left, top, box_w, box_h
def get_image_shape(image):
    """
    Return the size information of ``image``.

    image: a file path, a ``numpy.ndarray`` or a ``PIL.Image.Image``.

    NOTE: the result is not uniform across input types. Paths and PIL images
    yield PIL's ``size`` tuple (width, height), whereas numpy arrays yield
    ``array.shape`` (rows first, plus any channel axis). Callers must account
    for that difference.

    Raises NotImplementedError for any other input type.
    """
    if isinstance(image, str):
        # Close the file handle as soon as the header has been read.
        with Image.open(image) as opened:
            return opened.size
    elif isinstance(image, np.ndarray):
        return image.shape
    elif isinstance(image, Image.Image):
        return image.size
    else:
        raise NotImplementedError
def is_platform_win():
    """Tell whether the interpreter reports ``"win32"`` as its platform."""
    platform_name = sys.platform
    return platform_name == "win32"
def colormap(rgb=True):
    """
    Build the colour palette used by the painters.

    rgb: when True the channels of each row are ordered as listed below;
         when False every row is returned with its channel order reversed.

    Returns a float32 array of shape (N, 3) with values scaled to [0, 255].
    Index 0 is black and index 1 is white.
    """
    palette_rows = (
        (0.000, 0.000, 0.000),
        (1.000, 1.000, 1.000),
        (1.000, 0.498, 0.313),
        (0.392, 0.581, 0.929),
        (0.000, 0.447, 0.741),
        (0.850, 0.325, 0.098),
        (0.929, 0.694, 0.125),
        (0.494, 0.184, 0.556),
        (0.466, 0.674, 0.188),
        (0.301, 0.745, 0.933),
        (0.635, 0.078, 0.184),
        (0.300, 0.300, 0.300),
        (0.600, 0.600, 0.600),
        (1.000, 0.000, 0.000),
        (1.000, 0.500, 0.000),
        (0.749, 0.749, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.333, 0.333, 0.000),
        (0.333, 0.667, 0.000),
        (0.333, 1.000, 0.000),
        (0.667, 0.333, 0.000),
        (0.667, 0.667, 0.000),
        (0.667, 1.000, 0.000),
        (1.000, 0.333, 0.000),
        (1.000, 0.667, 0.000),
        (1.000, 1.000, 0.000),
        (0.000, 0.333, 0.500),
        (0.000, 0.667, 0.500),
        (0.000, 1.000, 0.500),
        (0.333, 0.000, 0.500),
        (0.333, 0.333, 0.500),
        (0.333, 0.667, 0.500),
        (0.333, 1.000, 0.500),
        (0.667, 0.000, 0.500),
        (0.667, 0.333, 0.500),
        (0.667, 0.667, 0.500),
        (0.667, 1.000, 0.500),
        (1.000, 0.000, 0.500),
        (1.000, 0.333, 0.500),
        (1.000, 0.667, 0.500),
        (1.000, 1.000, 0.500),
        (0.000, 0.333, 1.000),
        (0.000, 0.667, 1.000),
        (0.000, 1.000, 1.000),
        (0.333, 0.000, 1.000),
        (0.333, 0.333, 1.000),
        (0.333, 0.667, 1.000),
        (0.333, 1.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.667, 0.333, 1.000),
        (0.667, 0.667, 1.000),
        (0.667, 1.000, 1.000),
        (1.000, 0.000, 1.000),
        (1.000, 0.333, 1.000),
        (1.000, 0.667, 1.000),
        (0.167, 0.000, 0.000),
        (0.333, 0.000, 0.000),
        (0.500, 0.000, 0.000),
        (0.667, 0.000, 0.000),
        (0.833, 0.000, 0.000),
        (1.000, 0.000, 0.000),
        (0.000, 0.167, 0.000),
        (0.000, 0.333, 0.000),
        (0.000, 0.500, 0.000),
        (0.000, 0.667, 0.000),
        (0.000, 0.833, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 0.167),
        (0.000, 0.000, 0.333),
        (0.000, 0.000, 0.500),
        (0.000, 0.000, 0.667),
        (0.000, 0.000, 0.833),
        (0.000, 0.000, 1.000),
        (0.143, 0.143, 0.143),
        (0.286, 0.286, 0.286),
        (0.429, 0.429, 0.429),
        (0.571, 0.571, 0.571),
        (0.714, 0.714, 0.714),
        (0.857, 0.857, 0.857),
    )
    # Unit-range triples -> float32 rows scaled to the 0..255 range.
    palette = np.array(palette_rows, dtype=np.float32) * 255
    if rgb:
        return palette
    # Reverse the channel axis of every row.
    return palette[:, ::-1]
# Module-level palette: a plain list of [c0, c1, c2] uint8 triples, indexed by colour id.
color_list = colormap().astype('uint8').tolist()
def vis_add_mask(image, mask, color, alpha, kernel_size):
    """
    Blend ``color`` into ``image`` wherever the Gaussian-blurred ``mask`` is low.

    image: array indexed as image[:, :, channel]; its first three channels
           are overwritten IN PLACE.
    mask: array on a 0..255 scale; pixels at 255 keep the original image,
          pixels at 0 receive the full ``alpha`` blend towards ``color``.
    color: sequence with at least three channel values.
    alpha: blend strength.
    kernel_size: used both as the Gaussian kernel side and as its sigma.

    Returns the same ``image`` object that was passed in.
    """
    tint = np.array(color)
    # Soften the mask edges, then rescale it from 0..255 to 0..alpha.
    soft_mask = mask.astype('float').copy()
    soft_mask = cv2.GaussianBlur(soft_mask, (kernel_size, kernel_size), kernel_size) / 255. * alpha
    keep_weight = 1 - alpha + soft_mask
    tint_weight = alpha - soft_mask
    for channel in range(3):
        image[:, :, channel] = image[:, :, channel] * keep_weight + tint[channel] * tint_weight
    return image
def vis_add_mask_wo_blur(image, mask, color, alpha):
    """
    Blend ``color`` into ``image`` using ``mask`` directly (no blurring, no rescaling).

    image: array indexed as image[:, :, channel]; its first three channels
           are overwritten IN PLACE.
    mask: per-pixel weight; a value equal to ``alpha`` leaves the pixel
          unchanged, a value of 0 applies the full ``alpha`` blend.
    color: sequence with at least three channel values.
    alpha: blend strength.

    Returns the same ``image`` object that was passed in.
    """
    tint = np.array(color)
    weights = mask.astype('float')
    keep_weight = 1 - alpha + weights
    tint_weight = alpha - weights
    for channel in range(3):
        image[:, :, channel] = image[:, :, channel] * keep_weight + tint[channel] * tint_weight
    return image
def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha,
                             contour_alpha):
    """
    Apply a background tint and then a contour tint using precomputed soft masks.

    image: array indexed as image[:, :, channel]; its first three channels
           are overwritten IN PLACE before the uint8 copy is made.
    background_mask / contour_mask: per-pixel weights; 1 keeps the image,
           0 applies the full blend with the corresponding alpha.
    background_color / contour_color: sequences with at least three channel values.
    background_alpha / contour_alpha: blend strengths.

    Returns a new uint8 array holding the painted image.
    """
    bg_tint = np.array(background_color)
    edge_tint = np.array(contour_color)

    bg_keep = 1 - background_alpha + background_mask * background_alpha
    bg_mix = background_alpha - background_mask * background_alpha
    edge_keep = 1 - contour_alpha + contour_mask * contour_alpha
    edge_mix = contour_alpha - contour_mask * contour_alpha

    for channel in range(3):
        # The background blend is stored first so that the contour blend is
        # computed from the already-tinted channel.
        image[:, :, channel] = image[:, :, channel] * bg_keep + bg_tint[channel] * bg_mix
        image[:, :, channel] = image[:, :, channel] * edge_keep + edge_tint[channel] * edge_mix
    return image.astype('uint8')
def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3,
                 contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
    """
    Add a colour overlay to the background (or foreground) area and outline the mask.

    input_image: numpy array of shape (H, W, C)
    input_mask: numpy array of shape (H, W); every value > 0 counts as foreground
    background_alpha: strength of the overlay, [0, 1], 1: fully painted, 0: do nothing
    background_blur_radius: Gaussian kernel size used to soften the overlay edge, must be odd
    contour_width: width of the mask contour, must be odd
    contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
    contour_alpha: strength of the contour, [0, 1], if 0: no contour highlighted
    background_color: color index (in color map) of the overlay
    paint_foreground: True to paint on the foreground, False for the background. Default: False

    Output:
    painted_image: numpy array (a new array; neither input_image nor
                   input_mask is modified)
    """
    assert input_image.shape[:2] == input_mask.shape, 'different shape'
    assert background_blur_radius % 2 == 1 and contour_width % 2 == 1, \
        'background_blur_radius and contour_width must be ODD'

    # Work on private copies: the blending helper writes into the array it is
    # given, so painting directly on the caller's image would corrupt it for
    # any later use (e.g. painting the same frame again).
    painted_image = input_image.copy()
    # Binarise into a fresh uint8 mask (0: background, 255: foreground). This
    # also accepts bool/int/float masks, which cv2.Canny would otherwise reject.
    binary_mask = np.where(input_mask > 0, 255, 0).astype(np.uint8)

    # The blending helper tints wherever its mask is 0, so the mask is
    # inverted when the foreground is the area to be painted.
    region_mask = 255 - binary_mask if paint_foreground else binary_mask
    painted_image = vis_add_mask(painted_image, region_mask, color_list[background_color], background_alpha,
                                 background_blur_radius)

    # Contour: extract the mask edges, then widen them to contour_width.
    contour_mask = cv2.Canny(binary_mask, 100, 200)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
    contour_mask = cv2.dilate(contour_mask, kernel)
    painted_image = vis_add_mask(painted_image, 255 - contour_mask, color_list[contour_color], contour_alpha,
                                 contour_width)
    return painted_image
def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7,
                                contour_width=3, contour_color=3, contour_alpha=1):
    """
    Paint a colour overlay on the foreground area of every mask.

    input_image: numpy array of shape (H, W, C)
    input_masks: list of masks, each a numpy array of shape (H, W)
    background_alpha: strength of the overlay, [0, 1], 1: fully painted, 0: do nothing
    background_blur_radius: Gaussian kernel size used to soften the overlay edge, must be odd
    contour_width: width of the mask contour, must be odd
    contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
    contour_alpha: strength of the contour, [0, 1], if 0: no contour highlighted

    The masks are painted one after another, each call receiving the result of
    the previous one. Overlay colours are taken from the color map starting at
    index 2 (skipping black and white) and wrap around when there are more
    masks than remaining colours.

    Output:
    painted_image: numpy array (input_image itself when input_masks is empty)
    """
    # Indices 0 and 1 (black / white) are not used as object colours.
    n_object_colors = len(color_list) - 2
    for i, input_mask in enumerate(input_masks):
        # Wrap around instead of running past the end of the palette.
        color_index = 2 + i % n_object_colors
        input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
                                   contour_color, contour_alpha, background_color=color_index, paint_foreground=True)
    return input_image
def mask_generator_00(mask, background_radius, contour_radius):
    """
    Mode '00': hard background mask and hard-thresholded contour mask.

    mask: single-channel mask with 1 for foreground and 0 for background
          (``1 - mask`` is used as its complement).
    background_radius: unused in this mode.
    contour_radius: half-width of the contour band; 2 is added internally.

    Returns (mask, contour_mask). ``mask`` is passed through unchanged;
    ``contour_mask`` holds the normalised absolute distance to the mask
    boundary, with every value above 0.5 set to 1.
    """
    # Signed distance to the boundary: positive inside, negative outside.
    inside_dist = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    outside_dist = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    signed_dist = inside_dist - outside_dist

    # ...:::!!!:::...
    band = contour_radius + 2
    contour_mask = np.abs(np.clip(signed_dist, -band, band))
    contour_mask = contour_mask / np.max(contour_mask)
    # Binarise: only pixels close to the boundary stay below 1.
    contour_mask[contour_mask > 0.5] = 1.
    return mask, contour_mask
def mask_generator_01(mask, background_radius, contour_radius):
    """
    Mode '01': hard background mask and soft (graded) contour mask.

    mask: single-channel mask with 1 for foreground and 0 for background
          (``1 - mask`` is used as its complement).
    background_radius: unused in this mode.
    contour_radius: half-width of the contour band; 2 is added internally.

    Returns (mask, contour_mask). ``mask`` is passed through unchanged;
    ``contour_mask`` holds the absolute distance to the mask boundary,
    clipped to the band and normalised by its maximum.
    """
    # Signed distance to the boundary: positive inside, negative outside.
    inside_dist = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    outside_dist = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    signed_dist = inside_dist - outside_dist

    # ...:::!!!:::...
    band = contour_radius + 2
    contour_mask = np.abs(np.clip(signed_dist, -band, band))
    contour_mask = contour_mask / np.max(contour_mask)
    return mask, contour_mask
def mask_generator_10(mask, background_radius, contour_radius):
    """
    Mode '10': soft (graded) background mask and hard-thresholded contour mask.

    mask: single-channel mask with 1 for foreground and 0 for background
          (``1 - mask`` is used as its complement).
    background_radius: half-width of the transition band of the background mask.
    contour_radius: half-width of the contour band; 2 is added internally.

    Returns (background_mask, contour_mask). ``background_mask`` ramps from 0
    (outside) to 1 (inside) across the boundary; ``contour_mask`` holds the
    normalised absolute distance to the boundary with values above 0.5 set to 1.
    """
    # Signed distance to the boundary: positive inside, negative outside.
    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    dist_map = dist_transform_fore - dist_transform_back

    # .....:::::!!!!!
    background_mask = np.clip(dist_map, -background_radius, background_radius)
    background_mask = background_mask - np.min(background_mask)
    span = np.max(background_mask)
    if span > 0:
        background_mask = background_mask / span
    else:
        # The clipped map is constant (mask entirely background or entirely
        # foreground, or background_radius == 0). Dividing would give 0/0 = NaN,
        # so fall back to the hard mask instead.
        background_mask = mask.astype(background_mask.dtype)

    # ...:::!!!:::...
    contour_radius += 2
    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
    contour_mask = contour_mask / np.max(contour_mask)
    contour_mask[contour_mask > 0.5] = 1.
    return background_mask, contour_mask
def mask_generator_11(mask, background_radius, contour_radius):
    """
    Mode '11': soft (graded) background mask and soft (graded) contour mask.

    mask: single-channel mask with 1 for foreground and 0 for background
          (``1 - mask`` is used as its complement).
    background_radius: half-width of the transition band of the background mask.
    contour_radius: half-width of the contour band; 2 is added internally.

    Returns (background_mask, contour_mask). ``background_mask`` ramps from 0
    (outside) to 1 (inside) across the boundary; ``contour_mask`` holds the
    absolute distance to the boundary, clipped to the band and normalised by
    its maximum.
    """
    # Signed distance to the boundary: positive inside, negative outside.
    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    dist_map = dist_transform_fore - dist_transform_back

    # .....:::::!!!!!
    background_mask = np.clip(dist_map, -background_radius, background_radius)
    background_mask = background_mask - np.min(background_mask)
    span = np.max(background_mask)
    if span > 0:
        background_mask = background_mask / span
    else:
        # The clipped map is constant (mask entirely background or entirely
        # foreground, or background_radius == 0). Dividing would give 0/0 = NaN,
        # so fall back to the hard mask instead.
        background_mask = mask.astype(background_mask.dtype)

    # ...:::!!!:::...
    contour_radius += 2
    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
    contour_mask = contour_mask / np.max(contour_mask)
    return background_mask, contour_mask
def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3,
                             contour_color=3, contour_alpha=1, mode='11'):
    """
    Darken the background and outline the mask using distance-transform masks
    instead of Gaussian blurring.

    Input:
    input_image: numpy array, indexed as (rows, cols, channels)
    input_mask: numpy array, indexed as (rows, cols); values > 0 are foreground
    background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
    background_blur_radius: radius of background blur, must be odd number
    contour_width: width of mask contour, must be odd number
    contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
    contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
    mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both

    Output:
    painted_image: numpy array (uint8). Image and mask are first resized so
                   the longest side is at most 1024 px, so the output can be
                   smaller than the input.
    """
    assert input_image.shape[:2] == input_mask.shape, 'different shape'
    assert ((background_blur_radius % 2) * contour_width) % 2 > 0, \
        'background_blur_radius and contour_width must be ODD'
    assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'

    # Downsample image and mask; never upsample (scale is capped at 1).
    n_rows, n_cols = input_image.shape[0], input_image.shape[1]
    max_side = 1024
    scale = min(1.0 * max_side / max(n_rows, n_cols), 1.0)
    target_size = (int(n_cols * scale), int(n_rows * scale))  # cv2.resize takes (width, height)
    input_image = cv2.resize(input_image, target_size)
    input_mask = cv2.resize(input_mask, target_size)

    # 0: background, 1: foreground
    binary_mask = np.clip(input_mask, 0, 1)

    # Kernel sizes -> half-widths expected by the mask generators.
    background_radius = (background_blur_radius - 1) // 2
    contour_radius = (contour_width - 1) // 2

    generators = {
        '00': mask_generator_00,
        '01': mask_generator_01,
        '10': mask_generator_10,
        '11': mask_generator_11,
    }
    make_masks = generators[mode]
    background_mask, contour_mask = make_masks(binary_mask, background_radius, contour_radius)

    # Paint: palette entry 0 (black) for the background, contour_color for the outline.
    return vis_add_mask_wo_gaussian(input_image, background_mask, contour_mask, color_list[0],
                                    color_list[contour_color], background_alpha, contour_alpha)
# User-facing segmenter size -> SAM backbone identifier.
seg_model_map = dict(base='vit_b', large='vit_l', huge='vit_h')
# SAM backbone identifier -> download URL of its checkpoint.
ckpt_url_map = dict(
    vit_b='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
    vit_l='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    vit_h='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
)
# SAM backbone identifier -> expected SHA-256 hex digest of its checkpoint file.
expected_sha256_map = dict(
    vit_b='ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
    vit_l='3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
    vit_h='a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e',
)
def prepare_segmenter(segmenter="huge", download_root: str = None):
    """
    Resolve the segmenter model and make sure its checkpoint is available locally.

    segmenter: one of the keys of ``seg_model_map`` ('base', 'large', 'huge').
    download_root: folder for the checkpoint; defaults to ``~/.cache/SAM``.

    Also creates a ``result`` directory in the current working directory.

    Returns: (model name from 'vit_b', 'vit_l', 'vit_h', path to the checkpoint file).
    Raises KeyError for an unknown ``segmenter``.
    """
    os.makedirs('result', exist_ok=True)

    model_key = seg_model_map[segmenter]
    url = ckpt_url_map[model_key]
    cache_dir = download_root if download_root else os.path.expanduser("~/.cache/SAM")
    checkpoint_name = os.path.basename(url)

    checkpoint_path = download_checkpoint(url, cache_dir, checkpoint_name, expected_sha256_map[model_key])
    return model_key, checkpoint_path
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at ``path``, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_checkpoint(url, folder, filename, expected_sha256):
    """
    Return the path of a verified checkpoint, downloading it when necessary.

    url: address to download the checkpoint from.
    folder: destination directory (created if missing).
    filename: file name inside ``folder``.
    expected_sha256: hex SHA-256 digest the file must have.

    An existing file whose digest matches is reused without any network
    access; otherwise the file is (re)downloaded and verified.

    Returns the path ``folder/filename``.
    Raises RuntimeError when the downloaded file fails the checksum.
    """
    os.makedirs(folder, exist_ok=True)
    download_target = os.path.join(folder, filename)

    # Checkpoints are gigabytes in size, so they are hashed in chunks rather
    # than loaded into memory, and the file handle is always closed.
    if os.path.isfile(download_target) and _sha256_of_file(download_target) == expected_sha256:
        return download_target

    print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
    with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
        total_bytes = int(response.headers.get('content-length', 0))
        # The context manager closes the bar even if the transfer fails midway.
        with tqdm(total=total_bytes, unit='B', unit_scale=True) as progress:
            for data in response.iter_content(chunk_size=1 << 20):
                progress.update(output.write(data))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
    return download_target
if __name__ == '__main__':
    # Benchmark: time the Gaussian painter against the four distance-transform
    # modes on a test image, then save the distance-transform outputs.
    background_alpha = 0.7  # transparency of background 1: all black, 0: do nothing
    background_blur_radius = 31  # radius of background blur, must be odd number
    contour_width = 11  # contour width, must be odd number
    contour_color = 3  # id in color map, 0: black, 1: white, >1: others
    contour_alpha = 1  # transparency of contour, 0: no contour highlighted

    # load input image and mask
    input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
    input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))

    painter_args = (input_image, input_mask, background_alpha, background_blur_radius, contour_width,
                    contour_color, contour_alpha)

    # Variants in the order they are executed within one benchmark round.
    variants = (
        ('00', lambda: mask_painter_wo_gaussian(*painter_args, mode='00')),
        ('10', lambda: mask_painter_wo_gaussian(*painter_args, mode='10')),
        ('gaussian', lambda: mask_painter(*painter_args)),
        ('01', lambda: mask_painter_wo_gaussian(*painter_args, mode='01')),
        ('11', lambda: mask_painter_wo_gaussian(*painter_args, mode='11')),
    )
    total_seconds = {label: 0 for label, _ in variants}
    outputs = {}

    # paint
    for _ in range(50):
        for label, paint in variants:
            started = time.time()
            outputs[label] = paint()
            total_seconds[label] += time.time() - started

    print(f'average time w gaussian: {total_seconds["gaussian"] / 50}')
    print(f'average time w/o gaussian00: {total_seconds["00"] / 50}')
    print(f'average time w/o gaussian10: {total_seconds["10"] / 50}')
    print(f'average time w/o gaussian01: {total_seconds["01"] / 50}')
    print(f'average time w/o gaussian11: {total_seconds["11"] / 50}')

    # save (results of the last round)
    save_targets = (
        ('00', './test_images/painter_output_image_00.png'),
        ('10', './test_images/painter_output_image_10.png'),
        ('01', './test_images/painter_output_image_01.png'),
        ('11', './test_images/painter_output_image_11.png'),
    )
    for label, out_path in save_targets:
        Image.fromarray(outputs[label]).save(out_path)