import functools
from dataclasses import dataclass
from typing import Union, Tuple, List, Optional, Callable

import numpy as np
import PIL
from PIL.Image import Image
from sklearn.decomposition import PCA
import supervision as sv
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

from segment_anything.utils.transforms import ResizeLongestSide
from segment_anything.predictor import preprocess, postprocess_masks
from segment_anything import build_sam, load_mobile_sam

from sam_extension.utils import add_prompts_tag, get_empty_detections, transform_coords
from sam_extension.pipeline.base import Pipeline, Output
from sam_extension.pipeline.groundingdino import GroundingDinoPipeline
from sam_extension.distillation_models.sam import load_distillation_sam, load_sam
from sam_extension.distillation_models import *

ORIGINAL_SAM_IMG_SIZE: int = 1024
PIXEL_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
PIXEL_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
PREPROCESS = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN, PIXEL_STD)
POSTPROCESS_MASKS = functools.partial(postprocess_masks, ORIGINAL_SAM_IMG_SIZE)


@dataclass(repr=True)
class SAMEncoderOutput(Output):
    features: torch.Tensor
    interm_features: List[torch.Tensor]
    original_size: Tuple
    input_size: Tuple


@dataclass(repr=True)
class SAMEncoderProcesImgOutput(Output):
    input_image: torch.Tensor
    original_size: Tuple
    input_size: Tuple


@dataclass(repr=True)
class SAMDecoderPredictOutput(Output):
    masks_np: np.ndarray
    iou_predictions_np: np.ndarray
    low_res_masks_np: np.ndarray


@dataclass(repr=True)
class SAMDecoderPredictTorchOutput(Output):
    masks: torch.Tensor
    iou_predictions: torch.Tensor
    low_res_masks: torch.Tensor


class SAMEncoderPipeline(Pipeline):
    def __init__(self,
                 encoder: nn.Module,
                 input_img_size: Tuple,
                 multi_output: bool,
                 preprocess: Callable,
                 transform: ResizeLongestSide,
                 device: str,
                 *args,
                 **kwargs):
        super(SAMEncoderPipeline, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.input_img_size = input_img_size
        self.multi_output = multi_output
        self.preprocess = preprocess
        self.transform = transform
        self.device = device

    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        if 'sam_version' not in kwargs.keys():
            sam_version = 'sam'
        else:
            sam_version = kwargs['sam_version']
        sam = load_sam(ckpt_path, sam_version, device)
        encoder = sam.image_encoder
        encoder_type = encoder.__class__.__name__
        if encoder_type in ['TinyViT', 'FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
            multi_output = False
            if encoder_type in ['FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
                input_img_size = (encoder.img_size, encoder.img_size)
                if encoder_type == 'DINOSAMViT':
                    encoder = encoder.dino
            else:
                input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
        else:
            multi_output = True
            input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
        if sam.adaptor is None:
            transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
            preprocess_ = functools.partial(preprocess,
                                            ORIGINAL_SAM_IMG_SIZE,
                                            PIXEL_MEAN.to(device),
                                            PIXEL_STD.to(device))
        else:
            transform = T.Compose([
                T.Resize(input_img_size),
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            preprocess_ = None
        pipeline = cls(encoder=encoder,
                       input_img_size=input_img_size,
                       multi_output=multi_output,
                       preprocess=preprocess_,
                       transform=transform,
                       device=device)
        del sam, encoder
        torch.cuda.empty_cache()
        return pipeline
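
    # Usage sketch (illustrative only; the checkpoint and image paths below are
    # assumptions, not files shipped with this repo):
    #   encoder_pipeline = SAMEncoderPipeline.from_pretrained('weights/sam_vit_h.pth', device='cuda')
    #   encoder_output = encoder_pipeline.forward(PIL.Image.open('example.jpg').convert('RGB'))
    #   encoder_output.features  # image embeddings consumed by SAMDecoderPipeline
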
    def process_img(self, img: Union[Image, np.ndarray]) -> SAMEncoderProcesImgOutput:
        if self.preprocess is not None:
            # Original SAM preprocessing: resize the longest side, pad and normalize.
            if isinstance(img, Image):
                img = np.uint8(img)
            input_image = self.transform.apply_image(img)
            input_image_torch = torch.as_tensor(input_image, device=self.device)
            input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
            original_size = tuple(img.shape[:2])
            input_size = tuple(input_image_torch.shape[-2:])
            input_image = F.interpolate(self.preprocess(input_image_torch),
                                        size=self.input_img_size,
                                        mode='bilinear')
        else:
            # Adaptor models use a torchvision transform that resizes directly to the encoder input size.
            if isinstance(img, np.ndarray):
                img = PIL.Image.fromarray(img)
            original_size = (img.size[1], img.size[0])
            if original_size[0] > original_size[1]:
                input_h = 1024
                input_w = int((1024 / original_size[0]) * original_size[1])
            else:
                input_w = 1024
                input_h = int((1024 / original_size[1]) * original_size[0])
            input_size = (input_h, input_w)
            input_image = self.transform(img)[None, ...].to(self.device)
        return SAMEncoderProcesImgOutput(input_image, original_size, input_size)

    @torch.no_grad()
    def get_visual_feature(self, x: Union[torch.Tensor, Image, np.ndarray] = None, **kwargs):
        # Project the encoder features to 3 channels with PCA and render them as an RGB image.
        pca_rgb = PCA(n_components=3)
        if 'sam_feature' in kwargs.keys() and 'original_size' in kwargs.keys():
            sam_feature = kwargs['sam_feature']
            original_size = kwargs['original_size']
        else:
            assert x is not None, 'please provide x of type Union[torch.Tensor, Image, np.ndarray]!'
            sam_encoder_output = self.forward(x, **kwargs)
            sam_feature = sam_encoder_output.features
            original_size = sam_encoder_output.original_size
        assert original_size is not None, 'please provide original_size!'
        sam_feature = F.interpolate(sam_feature, size=original_size, mode='bilinear').permute(0, 2, 3, 1)
        b, h, w, c = sam_feature.shape
        sam_feature = sam_feature.view(-1, c).cpu().numpy()
        sam_feature = pca_rgb.fit_transform(sam_feature)
        sam_feature = torch.Tensor(sam_feature.reshape(h, w, 3))
        min_f, _ = sam_feature.min(-1)
        max_f, _ = sam_feature.max(-1)
        sam_feature = (sam_feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None])
        sam_feature = sam_feature.cpu().numpy()
        sam_feature_image = PIL.Image.fromarray((sam_feature * 255).astype(np.uint8))
        return sam_feature_image

    def forward(self, x: Union[torch.Tensor, Image, np.ndarray], **kwargs) -> SAMEncoderOutput:
        if isinstance(x, (Image, np.ndarray)):
            process_img_output = self.process_img(x)
            x = process_img_output.input_image
            original_size = process_img_output.original_size
            input_size = process_img_output.input_size
        else:
            original_size = kwargs.pop('original_size') if 'original_size' in kwargs.keys() else None
            input_size = x.shape[-2:]
        with torch.no_grad():
            if self.multi_output:
                features, interm_features = self.encoder(x, **kwargs)
            else:
                features = self.encoder(x, **kwargs)
                if self.encoder.__class__.__name__ == 'DINO':
                    features = features.permute(0, 3, 1, 2)
                interm_features = None
        return SAMEncoderOutput(features, interm_features, original_size, input_size)


class SAMDecoderPipeline(Pipeline):
    def __init__(self,
                 prompt_encoder: nn.Module,
                 mask_decoder: nn.Module,
                 adaptor: nn.Module,
                 mask_threshold: float,
                 transform: ResizeLongestSide,
                 postprocess_masks: Callable,
                 img_size: int,
                 device: str,
                 *args,
                 **kwargs):
        super(SAMDecoderPipeline, self).__init__(*args, **kwargs)
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.adaptor = adaptor
        self.mask_threshold = mask_threshold
        self.transform = transform
        self.postprocess_masks = postprocess_masks
        self.img_size = img_size
        self.device = device

    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        if 'sam_version' not in kwargs.keys():
            sam_version = 'sam'
        else:
            sam_version = kwargs['sam_version']
        sam = load_sam(ckpt_path, sam_version, device)
        if sam.image_encoder.__class__.__name__ == 'DINOSAMViT':
            adaptor = sam.image_encoder.adaptor
        elif sam.adaptor is not None:
            adaptor = sam.adaptor
        else:
            adaptor = None
        img_size = sam.image_encoder.img_size
        prompt_encoder = sam.prompt_encoder
        mask_decoder = sam.mask_decoder
        transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
        pipeline = cls(prompt_encoder=prompt_encoder,
                       mask_decoder=mask_decoder,
                       adaptor=adaptor,
                       mask_threshold=sam.mask_threshold,
                       transform=transform,
                       postprocess_masks=POSTPROCESS_MASKS,
                       img_size=img_size,
                       device=device)
        del sam, prompt_encoder, mask_decoder
        torch.cuda.empty_cache()
        return pipeline

    def visualize_prompt(self,
                         img: Union[Image, np.ndarray],
                         des_img: Union[Image, np.ndarray] = None,
                         point_labels: Union[List[int], np.ndarray] = None,
                         point_coords: Union[List[List[int]], np.ndarray] = None,
                         boxes: Union[List[List[int]], np.ndarray] = None,
                         pil: bool = False) -> Union[Image, np.ndarray]:
        if des_img is not None:
            if isinstance(des_img, np.ndarray):
                des_shape = tuple(des_img.shape[:2])
            else:
                des_shape = (des_img.size[1], des_img.size[0])
            src_shape = (img.size[1], img.size[0])
            point_coords, boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
            return add_prompts_tag(des_img, point_labels, point_coords, boxes, pil)
        else:
            return add_prompts_tag(img, point_labels, point_coords, boxes, pil)

    def visualize_results(self,
                          img: Union[Image, np.ndarray],
                          des_img: Union[Image, np.ndarray] = None,
                          sam_encoder_output: Optional[SAMEncoderOutput] = None,
                          features: Optional[torch.Tensor] = None,
                          interm_features: Optional[List[torch.Tensor]] = None,
                          original_size: Optional[Tuple] = None,
                          input_size: Optional[Tuple] = None,
                          point_coords: Optional[np.ndarray] = None,
                          point_labels: Optional[np.ndarray] = None,
                          boxes: Optional[np.ndarray] = None,
                          texts: Optional[List] = None,
                          grounding_dino_pipeline: GroundingDinoPipeline = None,
                          box_threshold: float = 0.25,
                          text_threshold: float = 0.25,
                          nms_threshold: float = 0.8,
                          detections: Optional[sv.Detections] = None,
                          multimask_output: bool = True,
                          visualize_promts: bool = True,
                          pil: bool = False):
        if isinstance(img, Image):
            img = np.uint8(img)
        if des_img is not None:
            if isinstance(des_img, np.ndarray):
                des_shape = tuple(des_img.shape[:2])
            else:
                des_shape = (des_img.size[1], des_img.size[0])
            src_shape = img.shape[:2]
            if point_coords is not None or boxes is not None:
                des_point_coords, des_boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
            else:
                des_point_coords = None
                des_boxes = None
        else:
            des_point_coords = None
            des_boxes = None
            src_shape = None
            des_shape = None
        detections = get_empty_detections() if detections is None else detections
        mask_annotator = sv.MaskAnnotator()
        result_list = []
        mask_result_list = []
        mask_list = []
        if (boxes is None and point_coords is None and point_labels is None and texts is None) or \
                (point_coords is not None and point_labels is not None and
                 point_coords.shape[0] != point_labels.shape[0]):
            print('no prompt given, or point_coords and point_labels have mismatched lengths!')
            result_list.append(img)
            return result_list, mask_result_list, mask_list
        # if boxes is not None and point_coords is not None and point_labels is not None:
        #     multimask_output = False

        def get_annotated_image(mask_annotator, detections, img,
                                point_labels=None, point_coords=None, boxes=None,
                                visualize_promts=True, pil=False):
            annotated_image = mask_annotator.annotate(scene=img.copy(), detections=detections)
            if visualize_promts:
                annotated_image = add_prompts_tag(annotated_image, point_labels, point_coords,
                                                  boxes=boxes, pil=pil)
            elif pil:
                annotated_image = PIL.Image.fromarray(annotated_image)
            return annotated_image
        def get_masked_image(img, masks, pil=True):
            # Cut each mask out of the image and paste it onto a white background.
            masked_image_list = []
            for i in range(masks.shape[0]):
                object_rgb = img * (masks[i].reshape(img.shape[0], img.shape[1], 1))
                object_rgb = object_rgb.astype(np.uint8)
                bkgd_mask = np.where(object_rgb == 0, 1, 0)
                bkgd_mask *= 255
                bkgd_mask = bkgd_mask.astype(np.uint8)
                object_rgb += bkgd_mask
                if pil:
                    masked_image_list.append(PIL.Image.fromarray(object_rgb))
                else:
                    masked_image_list.append(object_rgb)
            return masked_image_list

        def interpolate_mask(mask_np, des_shape):
            # Resize boolean masks to the source image shape (round, then cast back to bool).
            mask_tensor = torch.tensor(mask_np, dtype=torch.float32).unsqueeze(0)
            mask_interpolate = F.interpolate(mask_tensor, size=des_shape, mode='bilinear')
            mask_interpolate = (mask_interpolate + 0.5).long()
            mask_np = mask_interpolate.squeeze(0).numpy().astype(bool)
            return mask_np

        # Point prompts
        if point_coords is not None and point_labels is not None:
            if src_shape is not None:
                point_result = self.forward(sam_encoder_output, features, interm_features,
                                            original_size, input_size, des_point_coords, point_labels)
                masks_np = interpolate_mask(point_result.masks_np, src_shape)
            else:
                point_result = self.forward(sam_encoder_output, features, interm_features,
                                            original_size, input_size, point_coords, point_labels)
                masks_np = point_result.masks_np
            if multimask_output:
                for i in range(masks_np.shape[0]):
                    detections.mask = masks_np[i][None, ...]
                    mask_list.append(masks_np[i])
                    result_list.append(get_annotated_image(mask_annotator, detections, img,
                                                           point_labels=point_labels,
                                                           point_coords=point_coords,
                                                           visualize_promts=visualize_promts,
                                                           pil=pil))
                    mask_result_list += get_masked_image(img, detections.mask, pil=pil)
            else:
                index = np.argmax(point_result.iou_predictions_np)
                detections.mask = masks_np[index][None, ...]
                mask_list.append(masks_np[index])
                result_list.append(get_annotated_image(mask_annotator, detections, img,
                                                       point_labels=point_labels,
                                                       point_coords=point_coords,
                                                       visualize_promts=visualize_promts,
                                                       pil=pil))
                mask_result_list += get_masked_image(img, detections.mask, pil=pil)

        # Box prompts
        if boxes is not None:
            result_masks = []
            if src_shape is not None:
                boxes_ = des_boxes
            else:
                boxes_ = boxes
            if boxes_.shape[0] > 1:
                for i in range(len(boxes_)):
                    box_result = self.forward(sam_encoder_output, features, interm_features,
                                              original_size, input_size, box=boxes_[i])
                    index = np.argmax(box_result.iou_predictions_np)
                    result_masks.append(box_result.masks_np[index])
                mask = np.array(result_masks)
                if src_shape is not None:
                    masks_np = interpolate_mask(mask, src_shape)
                else:
                    masks_np = mask
                mask_list.append(masks_np)
                detections.mask = masks_np
                result_list.append(get_annotated_image(mask_annotator, detections, img, boxes=boxes,
                                                       visualize_promts=visualize_promts, pil=pil))
                mask_result_list += get_masked_image(img, detections.mask, pil=pil)
            else:
                box_result = self.forward(sam_encoder_output, features, interm_features,
                                          original_size, input_size, box=boxes_)
                if src_shape is not None:
                    masks_np = interpolate_mask(box_result.masks_np, src_shape)
                else:
                    masks_np = box_result.masks_np
                if multimask_output:
                    for i in range(masks_np.shape[0]):
                        detections.mask = masks_np[i][None, ...]
                        mask_list.append(masks_np[i])
                        result_list.append(get_annotated_image(mask_annotator, detections, img, boxes=boxes,
                                                               visualize_promts=visualize_promts, pil=pil))
                        mask_result_list += get_masked_image(img, detections.mask, pil=pil)
                else:
                    index = np.argmax(box_result.iou_predictions_np)
                    detections.mask = masks_np[index][None, ...]
                    mask_list.append(masks_np[index])
                    result_list.append(get_annotated_image(mask_annotator, detections, img, boxes=boxes,
                                                           visualize_promts=visualize_promts, pil=pil))
                    mask_result_list += get_masked_image(img, detections.mask, pil=pil)

        # Text prompts: GroundingDINO boxes followed by box-prompted SAM
        if texts is not None and grounding_dino_pipeline is not None:
            detections = grounding_dino_pipeline(img[:, :, ::-1], texts, box_threshold, text_threshold)
            box_annotator = sv.BoxAnnotator()
            nms_idx = torchvision.ops.nms(
                torch.from_numpy(detections.xyxy),
                torch.from_numpy(detections.confidence),
                nms_threshold
            ).numpy().tolist()
            detections.xyxy = detections.xyxy[nms_idx]
            detections.confidence = detections.confidence[nms_idx]
            detections.class_id = detections.class_id[nms_idx]
            labels = [
                f"{texts[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, _ in detections]
            result_masks = []
            if src_shape is not None:
                _, boxes_ = transform_coords(src_shape, des_shape, boxes=detections.xyxy)
            else:
                boxes_ = detections.xyxy
            for box in boxes_:
                box_result = self.forward(sam_encoder_output, features, interm_features,
                                          original_size, input_size, box=box)
                index = np.argmax(box_result.iou_predictions_np)
                result_masks.append(box_result.masks_np[index])
            mask = np.array(result_masks)
            if src_shape is not None:
                detections.mask = interpolate_mask(mask, src_shape)
            else:
                detections.mask = mask
            for i in range(detections.mask.shape[0]):
                mask_list.append(detections.mask[i, ...])
            if visualize_promts:
                annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
                annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
            else:
                annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
            if pil:
                result_list.append(PIL.Image.fromarray(annotated_image[:, :, ::-1]))
            else:
                result_list.append(annotated_image[:, :, ::-1])
            mask_result_list += get_masked_image(img, detections.mask, pil=pil)
        return result_list, mask_result_list, mask_list

    def predict(
        self,
        features: torch.Tensor,
        interm_features: List[torch.Tensor],
        original_size: Tuple,
        input_size: Tuple,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        hq_token_only: bool = False,
    ) -> SAMDecoderPredictOutput:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits instead
            of a binary mask.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        # Transform input prompts
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
            mask_input_torch = mask_input_torch[None, :, :, :]

        sam_decoder_predict_torch_output = self.predict_torch(
            features,
            interm_features,
            original_size,
            input_size,
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
            hq_token_only=hq_token_only,
        )

        masks_np = sam_decoder_predict_torch_output.masks[0].detach().cpu().numpy()
        iou_predictions_np = sam_decoder_predict_torch_output.iou_predictions[0].detach().cpu().numpy()
        low_res_masks_np = sam_decoder_predict_torch_output.low_res_masks[0].detach().cpu().numpy()
        return SAMDecoderPredictOutput(masks_np, iou_predictions_np, low_res_masks_np)

    @torch.no_grad()
    def predict_torch(
        self,
        features: torch.Tensor,
        interm_features: List[torch.Tensor],
        original_size: Tuple,
        input_size: Tuple,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        hq_token_only: bool = False,
    ) -> SAMDecoderPredictTorchOutput:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits instead
            of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=features,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            hq_token_only=hq_token_only,
            interm_embeddings=interm_features,
        )

        # Upscale the masks to the original image resolution
        # masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
        masks = self.postprocess_masks(low_res_masks, input_size, original_size)

        if not return_logits:
            masks = masks > self.mask_threshold

        return SAMDecoderPredictTorchOutput(masks, iou_predictions, low_res_masks)

    def forward(self,
                sam_encoder_output: Optional[SAMEncoderOutput] = None,
                features: Optional[torch.Tensor] = None,
                interm_features: Optional[List[torch.Tensor]] = None,
                original_size: Optional[Tuple] = None,
                input_size: Optional[Tuple] = None,
                point_coords: Optional[np.ndarray] = None,
                point_labels: Optional[np.ndarray] = None,
                box: Optional[np.ndarray] = None,
                mask_input: Optional[np.ndarray] = None,
                multimask_output: bool = True,
                return_logits: bool = False,
                hq_token_only: bool = False,
                dino: bool = False) -> SAMDecoderPredictOutput:
        assert sam_encoder_output or (features is not None and original_size is not None and input_size is not None), \
            'either sam_encoder_output or all of features, original_size and input_size must be given!'
        if sam_encoder_output:
            features = sam_encoder_output.features
            interm_features = sam_encoder_output.interm_features
            original_size = sam_encoder_output.original_size
            input_size = sam_encoder_output.input_size
        if self.adaptor is not None:
            if dino:
                # Resize normalized DINO features to the SAM embedding grid before the adaptor.
                features = F.interpolate(F.normalize(features, dim=1),
                                         size=(64, 64),
                                         mode='bilinear').permute(0, 2, 3, 1)
            features = self.adaptor(features)
            # else:
            #     features = self.adaptor(features, original_size)
        return self.predict(features,
                            interm_features,
                            original_size,
                            input_size,
                            point_coords,
                            point_labels,
                            box,
                            mask_input,
                            multimask_output,
                            return_logits,
                            hq_token_only)


'''
class SAMPipeline(Pipeline):
    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        sam_encoder_pipeline = SAMEncoderPipeline(ckpt_path, device, *args, **kwargs)
        sam_decoder_pipeline = SAMDecoderPipeline(ckpt_path, device, *args, **kwargs)
        pipeline = cls(**dict(sam_encoder_pipeline=sam_encoder_pipeline,
                              sam_decoder_pipeline=sam_decoder_pipeline,
                              device=device))
        return pipeline
'''
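

# Minimal end-to-end usage sketch (illustrative only). The checkpoint path, image
# path, device and prompt coordinates below are assumptions, not values shipped
# with this module.
if __name__ == '__main__':
    ckpt_path = 'weights/sam_vit_h.pth'                     # hypothetical checkpoint path
    image = PIL.Image.open('example.jpg').convert('RGB')    # hypothetical test image

    encoder_pipeline = SAMEncoderPipeline.from_pretrained(ckpt_path, device='cuda')
    decoder_pipeline = SAMDecoderPipeline.from_pretrained(ckpt_path, device='cuda')

    # Encode the image once, then decode with a single foreground point prompt.
    encoder_output = encoder_pipeline.forward(image)
    point_coords = np.array([[200, 150]])                   # (X, Y) in original-image pixels
    point_labels = np.array([1])                            # 1 = foreground, 0 = background
    decoder_output = decoder_pipeline.forward(sam_encoder_output=encoder_output,
                                              point_coords=point_coords,
                                              point_labels=point_labels)
    print(decoder_output.masks_np.shape, decoder_output.iou_predictions_np)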