Spaces: Runtime error
import cv2 | |
import numpy as np | |
from PIL import Image | |
from typing import Any, Dict, List | |
def load_img_to_array(img_p):
    """Read an image file into a NumPy array.

    Args:
        img_p: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A NumPy array holding the image's pixels. Its shape and dtype follow
        the PIL image mode (e.g. HxW for "L", HxWx3 for "RGB").
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager closes it deterministically. np.array() copies the pixel data,
    # so the returned array stays valid after the file is closed.
    with Image.open(img_p) as img:
        return np.array(img)
def save_array_to_img(img_arr, img_p):
    """Write an array to ``img_p`` as an image file via PIL.

    The array is cast to uint8 first. This is a plain NumPy cast, so values
    are not clipped or rescaled: fractional parts are truncated and
    out-of-range values wrap. The output format is chosen by PIL from
    ``img_p``.
    """
    pixels_u8 = img_arr.astype(np.uint8)
    pil_image = Image.fromarray(pixels_u8)
    pil_image.save(img_p)
def dilate_mask(mask, dilate_factor=15):
    """Dilate a mask once with a square all-ones kernel.

    Args:
        mask: Array-like with an ``astype`` method; it is cast to uint8.
        dilate_factor: Side length, in pixels, of the square kernel.

    Returns:
        The uint8 array produced by a single ``cv2.dilate`` pass. Each output
        pixel is the maximum over its kernel window.
    """
    kernel = np.ones((dilate_factor, dilate_factor), dtype=np.uint8)
    mask_u8 = mask.astype(np.uint8)
    return cv2.dilate(mask_u8, kernel, iterations=1)
def erode_mask(mask, dilate_factor=15):
    """Erode a mask once with a square all-ones kernel.

    Args:
        mask: Array-like with an ``astype`` method; it is cast to uint8.
        dilate_factor: Side length, in pixels, of the square kernel. The name
            mirrors ``dilate_mask`` and is kept for keyword callers.

    Returns:
        The uint8 array produced by a single ``cv2.erode`` pass. Each output
        pixel is the minimum over its kernel window.
    """
    kernel = np.ones((dilate_factor, dilate_factor), dtype=np.uint8)
    mask_u8 = mask.astype(np.uint8)
    return cv2.erode(mask_u8, kernel, iterations=1)
def show_mask(ax, mask: np.ndarray, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent single-colour RGBA overlay.

    Args:
        ax: Object with an ``imshow`` method (e.g. a matplotlib Axes).
        mask: Array whose last two dimensions are (height, width). Any leading
            dimensions must have size 1 so it can be reshaped to (h, w, 1).
            It is cast to uint8 first, so bool/0-1 masks become 0/1 and
            fractional float values are truncated.
        random_color: If True, use a random RGB colour with alpha 0.6
            instead of the default blue.
    """
    mask = mask.astype(np.uint8)
    # Treat any mask with values above 1 as 0..255-scaled, not only those
    # whose maximum is exactly 255 (e.g. resized/antialiased masks). This
    # keeps the overlay within imshow's [0, 1] float range. The size check
    # avoids np.max raising on an empty array.
    if mask.size and mask.max() > 1:
        mask = mask / 255
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the per-pixel mask value against the RGBA colour.
    mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_img)
def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
    """Scatter-plot labelled points on ``ax`` as white-edged star markers.

    Points labelled 0 are drawn red and points labelled 1 green. Points with
    any other label are not drawn. ``ax.scatter`` is called once per colour,
    even when that group is empty.

    Args:
        ax: Object with a ``scatter`` method (e.g. a matplotlib Axes).
        coords: Sequence of (x, y) pairs.
        labels: One integer label per entry in ``coords``.
        size: Marker size, forwarded to ``scatter`` as ``s``.
    """
    coords = np.array(coords)
    labels = np.array(labels)
    if coords.size == 0:
        # np.array([]) is 1-D, so the column slicing below would raise
        # IndexError; give the empty input an explicit (0, 2) shape.
        coords = coords.reshape(0, 2)
    color_table = {0: 'red', 1: 'green'}
    for label_value, color in color_table.items():
        points = coords[labels == label_value]
        ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
                   s=size, edgecolor='white', linewidth=1.25)