import numpy as np import gradio as gr import cv2 from copy import deepcopy import torch from torchvision import transforms from PIL import Image, ImageDraw, ImageFont from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits from src.utils.utils import resize_numpy_image sam = build_efficient_sam_vits() def show_point_or_box(image, global_points): # for point if len(global_points) == 1: image = cv2.circle(image, global_points[0], 10, (0, 0, 255), -1) # for box if len(global_points) == 2: p1 = global_points[0] p2 = global_points[1] image = cv2.rectangle(image,(int(p1[0]),int(p1[1])),(int(p2[0]),int(p2[1])),(0,0,255),2) return image def segment_with_points( image, original_image, global_points, global_point_label, evt: gr.SelectData, img_direction, save_dir = "./tmp" ): if original_image is None: original_image = image else: image = original_image if img_direction is None: img_direction = original_image x, y = evt.index[0], evt.index[1] image_path = None mask_path = None if len(global_points) == 0: global_points.append([x, y]) global_point_label.append(2) image_with_point= show_point_or_box(image.copy(), global_points) return image_with_point, original_image, None, global_points, global_point_label elif len(global_points) == 1: global_points.append([x, y]) global_point_label.append(3) x1, y1 = global_points[0] x2, y2 = global_points[1] if x1 < x2 and y1 >= y2: global_points[0][0] = x1 global_points[0][1] = y2 global_points[1][0] = x2 global_points[1][1] = y1 elif x1 >= x2 and y1 < y2: global_points[0][0] = x2 global_points[0][1] = y1 global_points[1][0] = x1 global_points[1][1] = y2 elif x1 >= x2 and y1 >= y2: global_points[0][0] = x2 global_points[0][1] = y2 global_points[1][0] = x1 global_points[1][1] = y1 image_with_point = show_point_or_box(image.copy(), global_points) # data process input_point = np.array(global_points) input_label = np.array(global_point_label) pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2]) pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1]) img_tensor = transforms.ToTensor()(image) # sam predicted_logits, predicted_iou = sam( img_tensor[None, ...], pts_sampled, pts_labels, ) mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy() mask_image = (mask*255.).astype(np.uint8) return image_with_point, original_image, mask_image, global_points, global_point_label else: global_points=[[x, y]] global_point_label=[2] image_with_point= show_point_or_box(image.copy(), global_points) return image_with_point, original_image, None, global_points, global_point_label def segment_with_points_paste( image, original_image, global_points, global_point_label, image_b, evt: gr.SelectData, dx, dy, resize_scale ): if original_image is None: original_image = image else: image = original_image x, y = evt.index[0], evt.index[1] if len(global_points) == 0: global_points.append([x, y]) global_point_label.append(2) image_with_point= show_point_or_box(image.copy(), global_points) return image_with_point, original_image, None, global_points, global_point_label, None elif len(global_points) == 1: global_points.append([x, y]) global_point_label.append(3) x1, y1 = global_points[0] x2, y2 = global_points[1] if x1 < x2 and y1 >= y2: global_points[0][0] = x1 global_points[0][1] = y2 global_points[1][0] = x2 global_points[1][1] = y1 elif x1 >= x2 and y1 < y2: global_points[0][0] = x2 global_points[0][1] = y1 global_points[1][0] = x1 global_points[1][1] = y2 elif x1 >= x2 and y1 >= y2: global_points[0][0] = x2 global_points[0][1] = y2 global_points[1][0] = x1 global_points[1][1] = y1 image_with_point = show_point_or_box(image.copy(), global_points) # data process input_point = np.array(global_points) input_label = np.array(global_point_label) pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2]) pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1]) img_tensor = transforms.ToTensor()(image) # sam predicted_logits, predicted_iou = sam( img_tensor[None, ...], pts_sampled, pts_labels, ) mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy() mask_uint8 = (mask*255.).astype(np.uint8) return image_with_point, original_image, paste_with_mask_and_offset(image, image_b, mask_uint8, dx, dy, resize_scale), global_points, global_point_label, mask_uint8 else: global_points=[[x, y]] global_point_label=[2] image_with_point= show_point_or_box(image.copy(), global_points) return image_with_point, original_image, None, global_points, global_point_label, None def paste_with_mask_and_offset(image_a, image_b, mask, x_offset=0, y_offset=0, delta=1): try: numpy_mask = np.array(mask) y_coords, x_coords = np.nonzero(numpy_mask) x_min = x_coords.min() x_max = x_coords.max() y_min = y_coords.min() y_max = y_coords.max() target_center_x = int((x_min + x_max) / 2) target_center_y = int((y_min + y_max) / 2) image_a = Image.fromarray(image_a) image_b = Image.fromarray(image_b) mask = Image.fromarray(mask) if image_a.size != mask.size: mask = mask.resize(image_a.size) cropped_image = Image.composite(image_a, Image.new('RGBA', image_a.size, (0, 0, 0, 0)), mask) x_b = int(target_center_x * (image_b.width / cropped_image.width)) y_b = int(target_center_y * (image_b.height / cropped_image.height)) x_offset = x_offset - int((delta - 1) * x_b) y_offset = y_offset - int((delta - 1) * y_b) cropped_image = cropped_image.resize(image_b.size) new_size = (int(cropped_image.width * delta), int(cropped_image.height * delta)) cropped_image = cropped_image.resize(new_size) image_b.putalpha(128) result_image = Image.new('RGBA', image_b.size, (0, 0, 0, 0)) result_image.paste(image_b, (0, 0)) result_image.paste(cropped_image, (x_offset, y_offset), mask=cropped_image) return result_image except: return None def upload_image_move(img, original_image): if original_image is not None: return original_image else: return img def fun_clear(*args): result = [] for arg in args: if isinstance(arg, list): result.append([]) else: result.append(None) return tuple(result) def clear_points(img): image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = image.copy() return [], masked_img def get_point(img, sel_pix, evt: gr.SelectData): sel_pix.append(evt.index) points = [] for idx, point in enumerate(sel_pix): if idx % 2 == 0: cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) else: cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) points.append(tuple(point)) if len(points) == 2: cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) points = [] return img if isinstance(img, np.ndarray) else np.array(img) def calculate_translation_percentage(ori_shape, selected_points): dx = selected_points[1][0] - selected_points[0][0] dy = selected_points[1][1] - selected_points[0][1] dx_percentage = dx / ori_shape[1] dy_percentage = dy / ori_shape[0] return dx_percentage, dy_percentage def get_point_move(original_image, img, sel_pix, evt: gr.SelectData): if original_image is not None: img = original_image.copy() else: original_image = img.copy() if len(sel_pix)<2: sel_pix.append(evt.index) else: sel_pix = [evt.index] points = [] dx, dy = 0, 0 for idx, point in enumerate(sel_pix): if idx % 2 == 0: cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) else: cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) points.append(tuple(point)) if len(points) == 2: cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) ori_shape = original_image.shape dx, dy = calculate_translation_percentage(original_image.shape, sel_pix) points = [] img = np.array(img) return img, original_image, sel_pix, dx, dy def store_img(img): image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = image.copy() return image, masked_img, mask def store_img_move(img, mask=None): if mask is not None: image = img["image"] return image, None, mask image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = image.copy() return image, masked_img, (mask*255.).astype(np.uint8) def mask_image(image, mask, color=[255,0,0], alpha=0.5, max_resolution=None): """ Overlay mask on image for visualization purpose. Args: image (H, W, 3) or (H, W): input image mask (H, W): mask to be overlaid color: the color of overlaid mask alpha: the transparency of the mask """ if max_resolution is not None: image, _ = resize_numpy_image(image, max_resolution*max_resolution) mask = cv2.resize(mask, (image.shape[1], image.shape[0]),interpolation=cv2.INTER_NEAREST) out = deepcopy(image) img = deepcopy(image) img[mask == 1] = color out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out) contours = cv2.findContours(np.uint8(deepcopy(mask)), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] return out