Spaces:
Sleeping
Sleeping
| import io | |
| import base64 | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
def postprocess_mask(mask):
    """Smooth a binary mask and keep only its largest connected blob.

    The mask is first opened (removes small isolated specks) and then closed
    (fills small holes), both with a 5x5 elliptical structuring element.
    If any foreground survives, every connected component except the one
    with the largest pixel area is discarded.

    Args:
        mask: Single-channel 8-bit mask (as required by OpenCV's
            connected-components routine); non-zero pixels are foreground.

    Returns:
        An array with the same shape and dtype as *mask*. If no foreground
        component remains, the smoothed (all-background) mask is returned
        unchanged; otherwise pixels of the largest component are 255 and
        all others are 0.
    """
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, ellipse)
    smoothed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, ellipse)

    count, label_map, component_stats, _centroids = cv2.connectedComponentsWithStats(
        smoothed, connectivity=8
    )
    # Label 0 is the background, so more than one label means there is at
    # least one foreground component to choose from.
    if count > 1:
        areas = component_stats[1:, cv2.CC_STAT_AREA]
        biggest = int(areas.argmax()) + 1  # +1 re-aligns after skipping label 0
        keep = np.zeros_like(smoothed)
        keep[label_map == biggest] = 255
        return keep
    return smoothed
def _decode_b64_mask(payload):
    """Decode a base64 image payload (optionally a ``data:`` URL) to a uint8 grayscale array."""
    # Front-ends commonly send "data:image/png;base64,<payload>". b64decode
    # would silently drop ':', ';' and ',' and decode the header as garbage,
    # so strip everything up to and including the first comma.
    if isinstance(payload, str) and payload.startswith('data:'):
        _, _, payload = payload.partition(',')
    raw = base64.b64decode(payload)
    with Image.open(io.BytesIO(raw)) as img:
        return np.array(img.convert('L'))


def _coerce_array_mask(mask_data):
    """Convert an array-like mask into a 2-D uint8 array on a 0-255 scale."""
    arr = np.asarray(mask_data)
    # Drop a trailing singleton channel axis: (h, w, 1) -> (h, w).
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]
    # Boolean or 0/1-valued masks would all fall below the 127 threshold
    # applied later and disappear entirely; rescale them onto 0-255.
    if arr.dtype == np.bool_ or (arr.size and arr.max() <= 1):
        arr = arr.astype(np.float64) * 255
    # Clip instead of letting out-of-range values wrap around in uint8.
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    if arr.ndim == 3:
        # Multi-channel (e.g. RGB/RGBA) arrays: reduce to luminance, matching
        # the convert('L') used for base64-encoded masks.
        arr = np.array(Image.fromarray(arr).convert('L'))
    return arr


def parse_mask_data(mask_data, h, w):
    """Decode an incoming mask, fit it to (h, w), binarize and clean it.

    Args:
        mask_data: One of
            * a base64-encoded image string, optionally prefixed as a
              ``data:<mime>;base64,`` URL;
            * a dict whose ``'data'`` entry holds such a string;
            * an array-like mask (nested lists or ndarray). Boolean or
              0/1-valued masks are scaled to 0/255. A trailing singleton
              channel is dropped, and RGB/RGBA arrays are converted to
              grayscale.
        h: Target mask height in pixels.
        w: Target mask width in pixels.

    Returns:
        A uint8 array of shape (h, w) whose values are 0 or 255, as produced
        by ``postprocess_mask`` (only the largest foreground blob is kept).

    Raises:
        Decoding errors from ``base64``/PIL propagate unchanged for
        malformed string or dict payloads.
    """
    if isinstance(mask_data, str):
        mask_np = _decode_b64_mask(mask_data)
    elif isinstance(mask_data, dict) and 'data' in mask_data:
        mask_np = _decode_b64_mask(mask_data['data'])
    else:
        mask_np = _coerce_array_mask(mask_data)

    if mask_np.shape != (h, w):
        # Lanczos produces intermediate grey values along edges; the 127
        # threshold below re-binarizes them. Fall back to NEAREST if this
        # Pillow build does not expose the LANCZOS alias.
        resample = getattr(Image, 'LANCZOS', getattr(Image, 'NEAREST', None))
        # PIL sizes are (width, height), numpy shapes are (height, width).
        resized = Image.fromarray(mask_np).resize((w, h), resample)
        mask_np = np.array(resized)

    binary_mask = np.where(mask_np > 127, 255, 0).astype(np.uint8)
    return postprocess_mask(binary_mask)
def render_overlay(image, mask):
    """Tint the masked region cyan, outline it, and return the result as base64 PNG.

    Args:
        image: Colour image in OpenCV BGR channel order (the tint and contour
            colours below assume BGR). 2-D grayscale and 4-channel BGRA
            inputs are converted to BGR first.
        mask: Single-channel mask with the same height and width as *image*.
            Non-zero pixels mark the region to highlight.

    Returns:
        A UTF-8 ``str`` with the base64-encoded PNG of the overlay.

    Raises:
        ValueError: If OpenCV fails to encode the overlay as PNG.
    """
    # The 3-tuple tint assignment below only broadcasts over 3-channel images.
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)

    # Work on a fresh 8-bit 0/255 copy. findContours needs 8-bit input, and
    # OpenCV before 3.2 modified its source array in place. Using the same
    # array for the tint and the contours keeps both on "non-zero = foreground".
    binary = np.where(mask > 0, 255, 0).astype(np.uint8)
    region = binary == 255

    # (255, 255, 0) is cyan in BGR order: blend 25% of it into the region.
    cyan_tint = np.zeros_like(image)
    cyan_tint[:, :] = (255, 255, 0)
    blended = cv2.addWeighted(image, 0.75, cyan_tint, 0.25, 0)
    overlay = image.copy()
    overlay[region] = blended[region]

    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in 2.4/4.x; index -2 is the contours in all of them.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # (0, 255, 255) is yellow in BGR; outline thickness is 3 px.
    cv2.drawContours(overlay, contours, -1, (0, 255, 255), 3)

    ok, buffer = cv2.imencode('.png', overlay)
    if not ok:
        raise ValueError('failed to encode overlay as PNG')
    return base64.b64encode(buffer).decode('utf-8')