import functools | |
from dataclasses import dataclass | |
import PIL | |
from PIL.Image import Image | |
import numpy as np | |
from typing import Union, Tuple, List, Optional, Callable | |
from sklearn.decomposition import PCA | |
import supervision as sv | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torchvision | |
import torchvision.transforms as T | |
from segment_anything.utils.transforms import ResizeLongestSide | |
from segment_anything.predictor import preprocess, postprocess_masks | |
from segment_anything import build_sam, load_mobile_sam | |
from sam_extension.utils import add_prompts_tag, get_empty_detections, transform_coords | |
from sam_extension.pipeline.base import Pipeline, Output | |
from sam_extension.pipeline.groundingdino import GroundingDinoPipeline | |
from sam_extension.distillation_models.sam import load_distillation_sam, load_sam | |
from sam_extension.distillation_models import * | |
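# SAM's default input resolution (longest side of 1024 px) and its pixel normalization statistics;
# PREPROCESS / POSTPROCESS_MASKS are module-level partials with these defaults baked in.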
ORIGINAL_SAM_IMG_SIZE: int = 1024 | |
PIXEL_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) | |
PIXEL_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) | |
PREPROCESS = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN, PIXEL_STD) | |
POSTPROCESS_MASKS = functools.partial(postprocess_masks, ORIGINAL_SAM_IMG_SIZE) | |
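# Typed output containers (built on sam_extension.pipeline.base.Output) returned by the
# encoder and decoder pipelines defined below.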
class SAMEncoderOutput(Output): | |
features: torch.Tensor | |
interm_features: List[torch.Tensor] | |
original_size: Tuple | |
input_size: Tuple | |
class SAMEncoderProcesImgOutput(Output): | |
input_image: torch.Tensor | |
original_size: Tuple | |
input_size: Tuple | |
class SAMDecoderPredictOutput(Output): | |
masks_np: np.ndarray | |
iou_predictions_np: np.ndarray | |
low_res_masks_np: np.ndarray | |
class SAMDecoderPredictTorchOutput(Output): | |
masks: torch.Tensor | |
iou_predictions: torch.Tensor | |
low_res_masks: torch.Tensor | |
class SAMEncoderPipeline(Pipeline): | |
def __init__(self, | |
encoder: nn.Module, | |
input_img_size: Tuple, | |
multi_output: bool, | |
preprocess: Callable, | |
transform: ResizeLongestSide, | |
device: str, | |
*args, | |
**kwargs): | |
super(SAMEncoderPipeline, self).__init__(*args, **kwargs) | |
self.encoder = encoder | |
self.input_img_size = input_img_size | |
self.multi_output = multi_output | |
self.preprocess = preprocess | |
self.transform = transform | |
self.device = device | |
    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        sam_version = kwargs.get('sam_version', 'sam')
sam = load_sam(ckpt_path, sam_version, device) | |
encoder = sam.image_encoder | |
encoder_type = encoder.__class__.__name__ | |
if encoder_type in ['TinyViT', 'FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']: | |
multi_output = False | |
if encoder_type in ['FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']: | |
input_img_size = (encoder.img_size, encoder.img_size) | |
if encoder_type == 'DINOSAMViT': | |
encoder = encoder.dino | |
else: | |
input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE) | |
else: | |
multi_output = True | |
input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE) | |
if sam.adaptor is None: | |
transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE) | |
preprocess_ = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN.to(device), PIXEL_STD.to(device)) | |
else: | |
transform = T.Compose([ | |
T.Resize(input_img_size), | |
T.ToTensor(), | |
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) | |
]) | |
preprocess_ = None | |
pipeline = cls(encoder=encoder, | |
input_img_size=input_img_size, | |
multi_output=multi_output, | |
preprocess=preprocess_, | |
transform=transform, | |
device=device) | |
del sam, encoder | |
torch.cuda.empty_cache() | |
return pipeline | |
def process_img(self, img: Union[Image, np.ndarray]) -> SAMEncoderProcesImgOutput: | |
if self.preprocess is not None: | |
if isinstance(img, Image): | |
img = np.uint8(img) | |
input_image = self.transform.apply_image(img) | |
input_image_torch = torch.as_tensor(input_image, device=self.device) | |
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] | |
original_size = tuple(img.shape[:2]) | |
input_size = tuple(input_image_torch.shape[-2:]) | |
input_image = F.interpolate(self.preprocess(input_image_torch), size=self.input_img_size, mode='bilinear') | |
else: | |
if isinstance(img, np.ndarray): | |
img = PIL.Image.fromarray(img) | |
original_size = (img.size[1], img.size[0]) | |
            if original_size[0] > original_size[1]:
                input_h = ORIGINAL_SAM_IMG_SIZE
                input_w = int((ORIGINAL_SAM_IMG_SIZE / original_size[0]) * original_size[1])
            else:
                input_w = ORIGINAL_SAM_IMG_SIZE
                input_h = int((ORIGINAL_SAM_IMG_SIZE / original_size[1]) * original_size[0])
input_size = (input_h, input_w) | |
input_image = self.transform(img)[None, ...].to(self.device) | |
return SAMEncoderProcesImgOutput(input_image, original_size, input_size) | |
    def get_visual_feature(self, x: Optional[Union[torch.Tensor, Image, np.ndarray]] = None, **kwargs):
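        # Projects the high-dimensional SAM features to 3 channels with PCA, rescales them per
        # pixel to [0, 1] and returns the result as an RGB PIL image.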
pca_rgb = PCA(n_components=3) | |
if 'sam_feature' in kwargs.keys() and 'original_size' in kwargs.keys(): | |
sam_feature = kwargs['sam_feature'] | |
original_size = kwargs['original_size'] | |
else: | |
            assert x is not None, 'please provide x of type Union[torch.Tensor, Image, np.ndarray]!'
sam_encoder_output = self.forward(x, **kwargs) | |
sam_feature = sam_encoder_output.features | |
original_size = sam_encoder_output.original_size | |
assert original_size is not None, 'please give original_size!' | |
sam_feature = F.interpolate(sam_feature, size=original_size, mode='bilinear').permute(0, 2, 3, 1) | |
b, h, w, c = sam_feature.shape | |
sam_feature = sam_feature.view(-1, c).cpu().numpy() | |
sam_feature = pca_rgb.fit_transform(sam_feature) | |
sam_feature = torch.Tensor(sam_feature.reshape(h, w, 3)) | |
min_f, _ = sam_feature.min(-1) | |
max_f, _ = sam_feature.max(-1) | |
sam_feature = (sam_feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None]) | |
sam_feature = sam_feature.cpu().numpy() | |
sam_feature_image = PIL.Image.fromarray((sam_feature * 255).astype(np.uint8)) | |
return sam_feature_image | |
def forward(self, x: Union[torch.Tensor, Image, np.ndarray], **kwargs) -> SAMEncoderOutput: | |
if isinstance(x, (Image, np.ndarray)): | |
process_img_output = self.process_img(x) | |
x = process_img_output.input_image | |
original_size = process_img_output.original_size | |
input_size = process_img_output.input_size | |
else: | |
original_size = kwargs.pop('original_size') if 'original_size' in kwargs.keys() else None | |
input_size = x.shape[-2:] | |
with torch.no_grad(): | |
if self.multi_output: | |
features, interm_features = self.encoder(x, **kwargs) | |
else: | |
features = self.encoder(x, **kwargs) | |
if self.encoder.__class__.__name__ == 'DINO': | |
features = features.permute(0, 3, 1, 2) | |
interm_features = None | |
return SAMEncoderOutput(features, interm_features, original_size, input_size) | |
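
# A minimal usage sketch for the SAMEncoderPipeline defined above. The checkpoint path and the
# image file are placeholders, not assets shipped with this module:
#
#   encoder_pipeline = SAMEncoderPipeline.from_pretrained('path/to/sam_checkpoint.pth', device='cuda')
#   image = PIL.Image.open('example.jpg').convert('RGB')
#   encoder_output = encoder_pipeline.forward(image)
#   features = encoder_output.features          # (1, C, H', W') image embedding
#   pca_image = encoder_pipeline.get_visual_feature(sam_feature=features,
#                                                   original_size=encoder_output.original_size)
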
class SAMDecoderPipeline(Pipeline): | |
def __init__(self, | |
prompt_encoder: nn.Module, | |
mask_decoder: nn.Module, | |
adaptor: nn.Module, | |
mask_threshold: float, | |
transform: ResizeLongestSide, | |
postprocess_masks: Callable, | |
img_size: int, | |
device: str, | |
*args, | |
**kwargs): | |
super(SAMDecoderPipeline, self).__init__(*args, **kwargs) | |
self.prompt_encoder = prompt_encoder | |
self.mask_decoder = mask_decoder | |
self.adaptor = adaptor | |
self.mask_threshold = mask_threshold | |
self.transform = transform | |
self.postprocess_masks = postprocess_masks | |
self.img_size = img_size | |
self.device = device | |
    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        sam_version = kwargs.get('sam_version', 'sam')
sam = load_sam(ckpt_path, sam_version, device) | |
if sam.image_encoder.__class__.__name__ == 'DINOSAMViT': | |
adaptor = sam.image_encoder.adaptor | |
elif sam.adaptor is not None: | |
adaptor = sam.adaptor | |
else: | |
adaptor = None | |
img_size = sam.image_encoder.img_size | |
prompt_encoder = sam.prompt_encoder | |
mask_decoder = sam.mask_decoder | |
transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE) | |
pipeline = cls(prompt_encoder=prompt_encoder, | |
mask_decoder=mask_decoder, | |
adaptor=adaptor, | |
mask_threshold=sam.mask_threshold, | |
transform=transform, | |
postprocess_masks=POSTPROCESS_MASKS, | |
img_size=img_size, | |
device=device) | |
del sam, prompt_encoder, mask_decoder | |
torch.cuda.empty_cache() | |
return pipeline | |
def visualize_prompt(self, | |
img: Union[Image, np.ndarray], | |
des_img: Union[Image, np.ndarray] = None, | |
point_labels: Union[List[int], np.ndarray] = None, | |
point_coords: Union[List[List[int]], np.ndarray] = None, | |
boxes: Union[List[List[int]], np.ndarray] = None, | |
pil: bool = False | |
) -> Union[Image, np.ndarray]: | |
if des_img is not None: | |
if isinstance(des_img, np.ndarray): | |
des_shape = tuple(des_img.shape[:2]) | |
else: | |
des_shape = (des_img.size[1], des_img.size[0]) | |
src_shape = (img.size[1], img.size[0]) | |
point_coords, boxes = transform_coords(src_shape, des_shape, point_coords, boxes) | |
return add_prompts_tag(des_img, point_labels, point_coords, boxes, pil) | |
else: | |
return add_prompts_tag(img, point_labels, point_coords, boxes, pil) | |
def visualize_results(self, | |
img: Union[Image, np.ndarray], | |
des_img: Union[Image, np.ndarray] = None, | |
sam_encoder_output: Optional[SAMEncoderOutput] = None, | |
features: Optional[torch.Tensor] = None, | |
interm_features: Optional[List[torch.Tensor]] = None, | |
original_size: Optional[Tuple] = None, | |
input_size: Optional[Tuple] = None, | |
point_coords: Optional[np.ndarray] = None, | |
point_labels: Optional[np.ndarray] = None, | |
boxes: Optional[np.ndarray] = None, | |
texts: Optional[List] = None, | |
grounding_dino_pipeline: GroundingDinoPipeline = None, | |
box_threshold: float = 0.25, | |
text_threshold: float = 0.25, | |
nms_threshold: float = 0.8, | |
detections: Optional[sv.Detections] = None, | |
multimask_output: bool = True, | |
visualize_promts: bool = True, | |
pil: bool = False): | |
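        # Runs the decoder for point / box / text prompts and returns three parallel lists:
        # annotated images, per-object cutouts on a white background, and raw boolean masks.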
if isinstance(img, Image): | |
img = np.uint8(img) | |
if des_img is not None: | |
if isinstance(des_img, np.ndarray): | |
des_shape = tuple(des_img.shape[:2]) | |
else: | |
des_shape = (des_img.size[1], des_img.size[0]) | |
src_shape = img.shape[:2] | |
if point_coords is not None or boxes is not None: | |
des_point_coords, des_boxes = transform_coords(src_shape, des_shape, point_coords, boxes) | |
else: | |
des_point_coords = None | |
des_boxes = None | |
else: | |
des_point_coords = None | |
des_boxes = None | |
src_shape = None | |
des_shape = None | |
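        # When des_img is given, the encoder features are assumed to come from des_img: prompts drawn
        # on img (src_shape) are remapped to des_shape before prediction, and the resulting masks are
        # interpolated back to src_shape for visualization.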
detections = get_empty_detections() if detections is None else detections | |
mask_annotator = sv.MaskAnnotator() | |
result_list = [] | |
mask_result_list = [] | |
mask_list = [] | |
        if (boxes is None and point_coords is None and point_labels is None and texts is None) or \
                (point_coords is not None and point_labels is not None and point_coords.shape[0] != point_labels.shape[0]):
print('no prompt given!') | |
result_list.append(img) | |
return result_list | |
# if boxes is not None and point_coords is not None and point_labels is not None: | |
# multimask_output = False | |
def get_annotated_image(mask_annotator, | |
detections, | |
img, | |
point_labels=None, | |
point_coords=None, | |
boxes=None, | |
visualize_promts=True, | |
pil=False): | |
annotated_image = mask_annotator.annotate(scene=img.copy(), detections=detections) | |
if visualize_promts: | |
annotated_image = add_prompts_tag(annotated_image, point_labels, point_coords, boxes=boxes, pil=pil) | |
else: | |
if pil: | |
annotated_image = PIL.Image.fromarray(annotated_image) | |
return annotated_image | |
def get_masked_image(img, | |
masks, | |
pil=True): | |
masked_image_list = [] | |
for i in range(masks.shape[0]): | |
object_rgb = img * (masks[i].reshape(img.shape[0], img.shape[1], 1)) | |
object_rgb = object_rgb.astype(np.uint8) | |
bkgd_mask = np.where(object_rgb == 0, 1, 0) | |
bkgd_mask *= 255 | |
bkgd_mask = bkgd_mask.astype(np.uint8) | |
object_rgb += bkgd_mask | |
if pil: | |
masked_image_list.append(PIL.Image.fromarray(object_rgb)) | |
else: | |
masked_image_list.append(object_rgb) | |
return masked_image_list | |
def interpolate_mask(mask_np, des_shape): | |
mask_tensor = torch.tensor(mask_np, dtype=torch.float32).unsqueeze(0) | |
mask_interpolate = F.interpolate(mask_tensor, size=des_shape, mode='bilinear') | |
mask_interpolate = (mask_interpolate+0.5).long() | |
mask_np = mask_interpolate.squeeze(0).numpy().astype(bool) | |
return mask_np | |
if point_coords is not None and point_labels is not None: | |
if src_shape is not None: | |
point_result = self.forward(sam_encoder_output, | |
features, | |
interm_features, | |
original_size, | |
input_size, | |
des_point_coords, | |
point_labels) | |
masks_np = interpolate_mask(point_result.masks_np, src_shape) | |
else: | |
point_result = self.forward(sam_encoder_output, | |
features, | |
interm_features, | |
original_size, | |
input_size, | |
point_coords, | |
point_labels) | |
masks_np = point_result.masks_np | |
if multimask_output: | |
for i in range(masks_np.shape[0]): | |
detections.mask = masks_np[i][None, ...] | |
mask_list.append(masks_np[i]) | |
result_list.append(get_annotated_image(mask_annotator, | |
detections, | |
img, | |
point_labels=point_labels, | |
point_coords=point_coords, | |
visualize_promts=visualize_promts, | |
pil=pil)) | |
mask_result_list += get_masked_image(img, | |
detections.mask, | |
pil=pil) | |
else: | |
index = np.argmax(point_result.iou_predictions_np) | |
detections.mask = masks_np[index][None, ...] | |
mask_list.append(masks_np[index]) | |
result_list.append(get_annotated_image(mask_annotator, | |
detections, | |
img, | |
point_labels=point_labels, | |
point_coords=point_coords, | |
visualize_promts=visualize_promts, | |
pil=pil)) | |
mask_result_list += get_masked_image(img, | |
detections.mask, | |
pil=pil) | |
if boxes is not None: | |
result_masks = [] | |
if src_shape is not None: | |
boxes_ = des_boxes | |
else: | |
boxes_ = boxes | |
if boxes_.shape[0] > 1: | |
for i in range(len(boxes)): | |
box_result = self.forward(sam_encoder_output, | |
features, | |
interm_features, | |
original_size, | |
input_size, | |
box=boxes_[i]) | |
index = np.argmax(box_result.iou_predictions_np) | |
result_masks.append(box_result.masks_np[index]) | |
mask = np.array(result_masks) | |
if src_shape is not None: | |
masks_np = interpolate_mask(mask, src_shape) | |
else: | |
masks_np = mask | |
mask_list.append(masks_np) | |
detections.mask = masks_np | |
result_list.append(get_annotated_image(mask_annotator, | |
detections, | |
img, | |
boxes=boxes, | |
visualize_promts=visualize_promts, | |
pil=pil)) | |
mask_result_list += get_masked_image(img, | |
detections.mask, | |
pil=pil) | |
else: | |
box_result = self.forward(sam_encoder_output, | |
features, | |
interm_features, | |
original_size, | |
input_size, | |
box=boxes_) | |
if src_shape is not None: | |
masks_np = interpolate_mask(box_result.masks_np, src_shape) | |
else: | |
masks_np = box_result.masks_np | |
if multimask_output: | |
for i in range(masks_np.shape[0]): | |
detections.mask = masks_np[i][None, ...] | |
mask_list.append(masks_np[i]) | |
result_list.append(get_annotated_image(mask_annotator, | |
detections, | |
img, | |
boxes=boxes, | |
visualize_promts=visualize_promts, | |
pil=pil)) | |
mask_result_list += get_masked_image(img, | |
detections.mask, | |
pil=pil) | |
else: | |
index = np.argmax(box_result.iou_predictions_np) | |
detections.mask = masks_np[index][None, ...] | |
mask_list.append(masks_np[index]) | |
                    result_list.append(get_annotated_image(mask_annotator,
                                                           detections,
                                                           img,
                                                           boxes=boxes,
                                                           visualize_promts=visualize_promts,
                                                           pil=pil))
mask_result_list += get_masked_image(img, | |
detections.mask, | |
pil=pil) | |
if texts is not None and grounding_dino_pipeline is not None: | |
detections = grounding_dino_pipeline(img[:, :, ::-1], texts, box_threshold, text_threshold) | |
box_annotator = sv.BoxAnnotator() | |
nms_idx = torchvision.ops.nms( | |
torch.from_numpy(detections.xyxy), | |
torch.from_numpy(detections.confidence), | |
nms_threshold | |
).numpy().tolist() | |
detections.xyxy = detections.xyxy[nms_idx] | |
detections.confidence = detections.confidence[nms_idx] | |
detections.class_id = detections.class_id[nms_idx] | |
labels = [ | |
f"{texts[class_id]} {confidence:0.2f}" | |
for _, _, confidence, class_id, _ | |
in detections] | |
result_masks = [] | |
if src_shape is not None: | |
_, boxes_ = transform_coords(src_shape, des_shape, boxes=detections.xyxy) | |
else: | |
boxes_ = detections.xyxy | |
for box in boxes_: | |
box_result = self.forward(sam_encoder_output, | |
features, | |
interm_features, | |
original_size, | |
input_size, | |
box=box) | |
index = np.argmax(box_result.iou_predictions_np) | |
result_masks.append(box_result.masks_np[index]) | |
mask = np.array(result_masks) | |
if src_shape is not None: | |
detections.mask = interpolate_mask(mask, src_shape) | |
else: | |
detections.mask = mask | |
for i in range(detections.mask.shape[0]): | |
mask_list.append(detections.mask[i, ...]) | |
if visualize_promts: | |
annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections) | |
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels) | |
else: | |
annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections) | |
if pil: | |
result_list.append(PIL.Image.fromarray(annotated_image[:, :, ::-1])) | |
else: | |
result_list.append(annotated_image[:, :, ::-1]) | |
mask_result_list += get_masked_image(img, | |
detections.mask, | |
pil=pil) | |
return result_list, mask_result_list, mask_list | |
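    # Illustrative call of visualize_results (decoder_pipeline, image_np and the box prompt are
    # placeholders; texts plus a GroundingDinoPipeline could be passed instead of boxes):
    #
    #   results, cutouts, masks = decoder_pipeline.visualize_results(
    #       img=image_np,
    #       sam_encoder_output=encoder_output,
    #       boxes=np.array([[100, 100, 400, 400]]),
    #       pil=True)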
def predict( | |
self, | |
features: torch.Tensor, | |
interm_features: List[torch.Tensor], | |
original_size: Tuple, | |
input_size: Tuple, | |
point_coords: Optional[np.ndarray] = None, | |
point_labels: Optional[np.ndarray] = None, | |
box: Optional[np.ndarray] = None, | |
mask_input: Optional[np.ndarray] = None, | |
multimask_output: bool = True, | |
return_logits: bool = False, | |
hq_token_only: bool = False, | |
) -> SAMDecoderPredictOutput: | |
""" | |
Predict masks for the given input prompts, using the currently set image. | |
Arguments: | |
point_coords (np.ndarray or None): A Nx2 array of point prompts to the | |
model. Each point is in (X,Y) in pixels. | |
point_labels (np.ndarray or None): A length N array of labels for the | |
point prompts. 1 indicates a foreground point and 0 indicates a | |
background point. | |
          box (np.ndarray or None): A length 4 array giving a box prompt to the
            model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically | |
coming from a previous prediction iteration. Has form 1xHxW, where | |
for SAM, H=W=256. | |
multimask_output (bool): If true, the model will return three masks. | |
For ambiguous input prompts (such as a single click), this will often | |
produce better masks than a single prediction. If only a single | |
mask is needed, the model's predicted quality score can be used | |
to select the best mask. For non-ambiguous prompts, such as multiple | |
input prompts, multimask_output=False can give better results. | |
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.
Returns: | |
(np.ndarray): The output masks in CxHxW format, where C is the | |
number of masks, and (H, W) is the original image size. | |
(np.ndarray): An array of length C containing the model's | |
predictions for the quality of each mask. | |
(np.ndarray): An array of shape CxHxW, where C is the number | |
of masks and H=W=256. These low resolution logits can be passed to | |
a subsequent iteration as mask input. | |
""" | |
# Transform input prompts | |
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None | |
if point_coords is not None: | |
assert ( | |
point_labels is not None | |
), "point_labels must be supplied if point_coords is supplied." | |
point_coords = self.transform.apply_coords(point_coords, original_size) | |
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) | |
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) | |
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] | |
if box is not None: | |
box = self.transform.apply_boxes(box, original_size) | |
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) | |
box_torch = box_torch[None, :] | |
if mask_input is not None: | |
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) | |
mask_input_torch = mask_input_torch[None, :, :, :] | |
sam_decoder_predict_torch_output = self.predict_torch( | |
features, | |
interm_features, | |
original_size, | |
input_size, | |
coords_torch, | |
labels_torch, | |
box_torch, | |
mask_input_torch, | |
multimask_output, | |
return_logits=return_logits, | |
hq_token_only=hq_token_only, | |
) | |
masks_np = sam_decoder_predict_torch_output.masks[0].detach().cpu().numpy() | |
iou_predictions_np = sam_decoder_predict_torch_output.iou_predictions[0].detach().cpu().numpy() | |
low_res_masks_np = sam_decoder_predict_torch_output.low_res_masks[0].detach().cpu().numpy() | |
return SAMDecoderPredictOutput(masks_np, iou_predictions_np, low_res_masks_np) | |
def predict_torch( | |
self, | |
features: torch.Tensor, | |
interm_features: List[torch.Tensor], | |
original_size: Tuple, | |
input_size: Tuple, | |
point_coords: Optional[torch.Tensor], | |
point_labels: Optional[torch.Tensor], | |
boxes: Optional[torch.Tensor] = None, | |
mask_input: Optional[torch.Tensor] = None, | |
multimask_output: bool = True, | |
return_logits: bool = False, | |
hq_token_only: bool = False, | |
) -> SAMDecoderPredictTorchOutput: | |
""" | |
Predict masks for the given input prompts, using the currently set image. | |
Input prompts are batched torch tensors and are expected to already be | |
transformed to the input frame using ResizeLongestSide. | |
Arguments: | |
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the | |
model. Each point is in (X,Y) in pixels. | |
point_labels (torch.Tensor or None): A BxN array of labels for the | |
point prompts. 1 indicates a foreground point and 0 indicates a | |
background point. | |
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
multimask_output (bool): If true, the model will return three masks. | |
For ambiguous input prompts (such as a single click), this will often | |
produce better masks than a single prediction. If only a single | |
mask is needed, the model's predicted quality score can be used | |
to select the best mask. For non-ambiguous prompts, such as multiple | |
input prompts, multimask_output=False can give better results. | |
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.
Returns: | |
(torch.Tensor): The output masks in BxCxHxW format, where C is the | |
number of masks, and (H, W) is the original image size. | |
(torch.Tensor): An array of shape BxC containing the model's | |
predictions for the quality of each mask. | |
(torch.Tensor): An array of shape BxCxHxW, where C is the number | |
of masks and H=W=256. These low res logits can be passed to | |
a subsequent iteration as mask input. | |
""" | |
if point_coords is not None: | |
points = (point_coords, point_labels) | |
else: | |
points = None | |
# Embed prompts | |
sparse_embeddings, dense_embeddings = self.prompt_encoder( | |
points=points, | |
boxes=boxes, | |
masks=mask_input, | |
) | |
# Predict masks | |
low_res_masks, iou_predictions = self.mask_decoder( | |
image_embeddings=features, | |
image_pe=self.prompt_encoder.get_dense_pe(), | |
sparse_prompt_embeddings=sparse_embeddings, | |
dense_prompt_embeddings=dense_embeddings, | |
multimask_output=multimask_output, | |
hq_token_only=hq_token_only, | |
interm_embeddings=interm_features, | |
) | |
# Upscale the masks to the original image resolution | |
# masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) | |
masks = self.postprocess_masks(low_res_masks, input_size, original_size) | |
if not return_logits: | |
masks = masks > self.mask_threshold | |
return SAMDecoderPredictTorchOutput(masks, iou_predictions, low_res_masks) | |
def forward(self, | |
sam_encoder_output: Optional[SAMEncoderOutput]=None, | |
features: Optional[torch.Tensor]=None, | |
interm_features: Optional[List[torch.Tensor]]=None, | |
original_size: Optional[Tuple]=None, | |
input_size: Optional[Tuple]=None, | |
point_coords: Optional[np.ndarray] = None, | |
point_labels: Optional[np.ndarray] = None, | |
box: Optional[np.ndarray] = None, | |
mask_input: Optional[np.ndarray] = None, | |
multimask_output: bool = True, | |
return_logits: bool = False, | |
hq_token_only: bool = False, | |
dino: bool = False | |
) -> SAMDecoderPredictOutput: | |
        assert sam_encoder_output or (features is not None and original_size is not None and input_size is not None), \
            'either sam_encoder_output or all of features, original_size and input_size must be given!'
if sam_encoder_output: | |
features = sam_encoder_output.features | |
interm_features = sam_encoder_output.interm_features | |
original_size = sam_encoder_output.original_size | |
input_size = sam_encoder_output.input_size | |
if self.adaptor is not None: | |
if dino: | |
features = F.interpolate(F.normalize(features, dim=1), size=(64, 64), mode='bilinear').permute(0, 2, 3, 1) | |
features = self.adaptor(features) | |
# | |
# else: | |
# features = self.adaptor(features, original_size) | |
return self.predict(features, | |
interm_features, | |
original_size, | |
input_size, | |
point_coords, | |
point_labels, | |
box, | |
mask_input, | |
multimask_output, | |
return_logits, | |
hq_token_only) | |
''' | |
class SAMPipeline(Pipeline): | |
@classmethod | |
def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs): | |
sam_encoder_pipeline = SAMEncoderPipeline(ckpt_path, device, *args, **kwargs) | |
sam_decoder_pipeline = SAMDecoderPipeline(ckpt_path, device, *args, **kwargs) | |
pipeline = cls(**dict(sam_encoder_pipeline=sam_encoder_pipeline, | |
sam_decoder_pipeline=sam_decoder_pipeline, | |
device=device)) | |
return pipeline | |
''' | |
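

# A minimal end-to-end sketch combining the encoder and decoder pipelines. The checkpoint path,
# the image file and the box prompt are placeholders; adjust them before running this module directly.
if __name__ == '__main__':
    encoder_pipeline = SAMEncoderPipeline.from_pretrained('path/to/sam_checkpoint.pth', device='cuda')
    decoder_pipeline = SAMDecoderPipeline.from_pretrained('path/to/sam_checkpoint.pth', device='cuda')

    image = PIL.Image.open('example.jpg').convert('RGB')
    encoder_output = encoder_pipeline.forward(image)

    # A single box prompt in XYXY pixel coordinates of the original image.
    box = np.array([100, 100, 400, 400])
    decoder_output = decoder_pipeline.forward(sam_encoder_output=encoder_output, box=box)
    print('masks:', decoder_output.masks_np.shape,
          'iou predictions:', decoder_output.iou_predictions_np)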