import gc

import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import gradio as gr
import cv2

from demo.mask_utils import *


class SAM_Inference:
    def __init__(self, model_type='vit_b', device='cuda') -> None:
        models = {
            'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
            'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
            'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
        }

        sam = sam_model_registry[model_type](checkpoint=models[model_type])
        sam = sam.to(device)

        self.predictor = SamPredictor(sam)
        self.mask_generator = SamAutomaticMaskGenerator(model=sam)

    def img_select_point(self, original_img: np.ndarray, evt: gr.SelectData):
        img = original_img.copy()
        sel_pix = [(evt.index, 1)]  # the clicked pixel is used as a single foreground point

        masks = self.run_inference(original_img, sel_pix)

        # draw the selected point on the display image
        for point, label in sel_pix:
            cv2.circle(img, point, 5, (240, 240, 240), -1, 0)
            cv2.circle(img, point, 5, (30, 144, 255), 2, 0)

        mask = masks[0][0]
        colored_mask = mask_foreground(mask)
        res = img_add_masks(original_img, colored_mask, mask)
        return img, process_mask_to_show(mask), res, mask

    def gen_box_seg(self, inp):
        if inp is None:
            raise gr.Error("Please upload an image first!")

        image = inp['image']
        if len(inp['boxes']) == 0:
            raise gr.Error("Please clear the raw boxes and draw a box first!")

        # use the most recently drawn box as the prompt
        boxes = inp['boxes'][-1]
        input_box = np.array([boxes[0], boxes[1], boxes[2], boxes[3]]).astype(int)

        masks = self.predict_box(image, input_box)

        mask = masks[0][0]
        colored_mask = mask_foreground(mask)
        res = img_add_masks(image, colored_mask, mask)

        return process_mask_to_show(mask), res, mask

    def run_inference(self, input_x, selected_points):
        if len(selected_points) == 0:
            return []

        self.predictor.set_image(input_x)

        points = torch.Tensor(
            [p for p, _ in selected_points]
        ).to(self.predictor.device).unsqueeze(0)

        labels = torch.Tensor(
            [int(l) for _, l in selected_points]
        ).to(self.predictor.device).unsqueeze(0)

        transformed_points = self.predictor.transform.apply_coords_torch(
            points, input_x.shape[:2])

        # predict a segmentation mask from the point prompts
        masks, scores, logits = self.predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=labels,
            multimask_output=False,
        )
        masks = masks.cpu().detach().numpy()

        # release cached GPU memory between requests
        gc.collect()
        torch.cuda.empty_cache()

        return masks

    def predict_box(self, input_x, input_box):
        self.predictor.set_image(input_x)

        input_boxes = torch.tensor(input_box[None, :], device=self.predictor.device)
        transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, input_x.shape[:2])

        # predict a segmentation mask from the box prompt
        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )
        masks = masks.cpu().detach().numpy()

        # release cached GPU memory between requests
        gc.collect()
        torch.cuda.empty_cache()

        return masks
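

# Minimal usage sketch (not part of the original demo wiring): instantiate the class
# and run point-prompted segmentation on an image loaded with OpenCV. The image path
# and click coordinates are placeholders; the checkpoint paths above are assumed to exist.
if __name__ == '__main__':
    inference = SAM_Inference(model_type='vit_b', device='cuda')

    image = cv2.imread('./examples/demo.jpg')       # hypothetical example image
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # SamPredictor expects RGB input

    # a single foreground click at (x, y), mirroring what img_select_point builds from a Gradio event
    selected_points = [((200, 150), 1)]
    masks = inference.run_inference(image, selected_points)
    print('mask shape:', masks[0][0].shape)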