File size: 1,787 Bytes
9856e13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from grounded_sam_demo import grounded_sam_demo
import numpy as np
from PIL import Image
from scipy.ndimage import convolve
from scipy.ndimage import binary_dilation
def get_sd_mask(color_mask_pil, target=(72, 4, 84), tolerance=50):
    """Turn a colored segmentation image into a black/white mask.

    Every pixel whose channels are all within ``tolerance`` of ``target`` is
    set to black (0 in every channel, including alpha when present). Every
    other pixel is set to white (255 in every channel).

    Args:
        color_mask_pil: Colored mask image (PIL image or array-like).
            PIL images in modes other than "RGB"/"RGBA" (e.g. "L", "P",
            "LA") are converted first: to "RGBA" when they carry an alpha
            band, otherwise to "RGB". This lets them be compared against
            an RGB ``target``. A bare 2-D array is treated as grayscale and
            replicated to three channels.
        target: Color to match. When the image has more channels than
            ``target`` has entries, ``target`` is padded with 255. An RGB
            target therefore matches only fully opaque RGBA pixels.
            NOTE(review): presumably the background color that
            grounded_sam_demo renders; confirm.
        tolerance: Maximum allowed absolute difference on each channel.

    Returns:
        PIL.Image.Image with the same shape and dtype as the (possibly
        converted) input array.

    Raises:
        ValueError: If the image has fewer channels than ``target`` has
            entries.
    """
    if (isinstance(color_mask_pil, Image.Image)
            and color_mask_pil.mode not in ("RGB", "RGBA")):
        # Palette images store indices, and single-channel images have no
        # channel axis. Convert so a per-channel color test is meaningful.
        has_alpha = "A" in color_mask_pil.getbands()
        color_mask_pil = color_mask_pil.convert("RGBA" if has_alpha else "RGB")
    image_array = np.array(color_mask_pil)
    if image_array.ndim == 2:
        # Bare 2-D array: give it a channel axis instead of letting the
        # width be mistaken for the channel count.
        image_array = np.stack([image_array] * 3, axis=-1)

    n_channels = image_array.shape[-1]
    if n_channels < len(target):
        raise ValueError(
            f"image has {n_channels} channel(s) but target has "
            f"{len(target)} entries"
        )
    # Pad the target with 255 for any extra channels (e.g. alpha).
    target_array = np.array(list(target) + [255] * (n_channels - len(target)))

    # target_array is a signed integer array, so the subtraction is promoted
    # out of uint8 and cannot wrap around.
    mask = np.all(np.abs(image_array - target_array) <= tolerance, axis=-1)

    new_image_array = np.full_like(image_array, 255)  # start with white
    new_image_array[mask] = 0  # black wherever the target color matched
    return Image.fromarray(new_image_array)
def expand_white_pixels(input_pil, expand_by=1):
    """Grow the pure-white region of an image by ``expand_by`` pixels.

    A pixel counts as white when every channel equals 255. For a
    single-channel (2-D) image, that means its value is 255. The white
    region is dilated with a square (2*expand_by + 1) x (2*expand_by + 1)
    structuring element. Each white pixel therefore whitens neighbours up
    to ``expand_by`` pixels away horizontally, vertically and diagonally.
    Newly covered pixels are set to 255 in every channel; all other pixels
    keep their original values.

    Args:
        input_pil: Image (PIL image or array-like) to process, e.g. the
            output of get_sd_mask.
        expand_by: Dilation radius in pixels. 0 returns an unchanged copy.

    Returns:
        A new uint8 PIL.Image.Image.

    Raises:
        ValueError: If ``expand_by`` is negative.
    """
    if expand_by < 0:
        raise ValueError(f"expand_by must be >= 0, got {expand_by}")
    img_array = np.array(input_pil)
    if img_array.ndim == 2:
        # Single-channel image: there is no channel axis to reduce over.
        is_white = img_array == 255
    else:
        is_white = np.all(img_array == 255, axis=-1)

    size = 2 * expand_by + 1
    kernel = np.ones((size, size), dtype=bool)
    expanded_white = binary_dilation(is_white, structure=kernel)
    if img_array.ndim == 3:
        # Broadcast the 2-D mask across the channel axis.
        expanded_white = expanded_white[..., None]

    expanded_array = np.where(expanded_white, 255, img_array)
    return Image.fromarray(expanded_array.astype(np.uint8))
# Model files handed to grounded_sam_demo by just_get_sd_mask. These are
# relative paths; presumably they resolve against the current working
# directory (depends on how grounded_sam_demo opens them; confirm).
config_file, grounded_checkpoint, sam_checkpoint = (
    # GroundingDINO model config file.
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    # GroundingDINO checkpoint.
    "groundingdino_swint_ogc.pth",
    # SAM checkpoint.
    "sam_hq_vit_h.pth",
)
def just_get_sd_mask(input_pil, text_prompt, padding):
    """Build a padded black/white mask for the objects named by a prompt.

    The pipeline has three steps:
      1. Run grounded_sam_demo with the module-level config and checkpoint
         paths.
      2. Turn its output into a black/white mask with get_sd_mask, using
         the default target color and tolerance.
      3. Grow the white region by ``padding`` pixels with
         expand_white_pixels.

    A progress message is printed to stdout before each step.

    Args:
        input_pil: Image passed to grounded_sam_demo.
        text_prompt: Prompt passed to grounded_sam_demo (presumably it
            names the objects to segment; confirm).
        padding: Dilation radius forwarded to expand_white_pixels.

    Returns:
        The PIL image produced by expand_white_pixels.
    """
    print("Doing sam")
    segmentation = grounded_sam_demo(
        input_pil,
        config_file,
        grounded_checkpoint,
        sam_checkpoint,
        text_prompt,
    )

    print("doing to white")
    binary_mask = get_sd_mask(segmentation)

    print("expanding white pixels")
    return expand_white_pixels(binary_mask, padding)
|