# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import math from typing import Any, List, Tuple, Type import torch import torch.nn.functional as F from torch import nn, Tensor from .efficient_sam_decoder import MaskDecoder, PromptEncoder from .efficient_sam_encoder import ImageEncoderViT from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer class EfficientSam(nn.Module): mask_threshold: float = 0.0 image_format: str = "RGB" def __init__( self, image_encoder: ImageEncoderViT, prompt_encoder: PromptEncoder, decoder_max_num_input_points: int, mask_decoder: MaskDecoder, pixel_mean: List[float] = [0.485, 0.456, 0.406], pixel_std: List[float] = [0.229, 0.224, 0.225], ) -> None: """ SAM predicts object masks from an image and input prompts. Arguments: image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for efficient mask prediction. prompt_encoder (PromptEncoder): Encodes various types of input prompts. mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. pixel_mean (list(float)): Mean values for normalizing pixels in the input image. pixel_std (list(float)): Std values for normalizing pixels in the input image. """ super().__init__() self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.decoder_max_num_input_points = decoder_max_num_input_points self.mask_decoder = mask_decoder self.register_buffer( "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False ) self.register_buffer( "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False ) @torch.jit.export def predict_masks( self, image_embeddings: torch.Tensor, batched_points: torch.Tensor, batched_point_labels: torch.Tensor, multimask_output: bool, input_h: int, input_w: int, output_h: int = -1, output_w: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predicts masks given image embeddings and prompts. This only runs the decoder. Arguments: image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] batched_points: A tensor of shape [B, max_num_queries, num_pts, 2] batched_point_labels: A tensor of shape [B, max_num_queries, num_pts] Returns: A tuple of two tensors: low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores """ batch_size, max_num_queries, num_pts, _ = batched_points.shape num_pts = batched_points.shape[2] rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w) if num_pts > self.decoder_max_num_input_points: rescaled_batched_points = rescaled_batched_points[ :, :, : self.decoder_max_num_input_points, : ] batched_point_labels = batched_point_labels[ :, :, : self.decoder_max_num_input_points ] elif num_pts < self.decoder_max_num_input_points: rescaled_batched_points = F.pad( rescaled_batched_points, (0, 0, 0, self.decoder_max_num_input_points - num_pts), value=-1.0, ) batched_point_labels = F.pad( batched_point_labels, (0, self.decoder_max_num_input_points - num_pts), value=-1.0, ) sparse_embeddings = self.prompt_encoder( rescaled_batched_points.reshape( batch_size * max_num_queries, self.decoder_max_num_input_points, 2 ), batched_point_labels.reshape( batch_size * max_num_queries, self.decoder_max_num_input_points ), ) sparse_embeddings = sparse_embeddings.view( batch_size, max_num_queries, sparse_embeddings.shape[1], sparse_embeddings.shape[2], ) low_res_masks, iou_predictions = self.mask_decoder( image_embeddings, self.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, multimask_output=multimask_output, ) _, num_predictions, low_res_size, _ = low_res_masks.shape if output_w > 0 and output_h > 0: output_masks = F.interpolate( low_res_masks, (output_h, output_w), mode="bicubic" ) output_masks = torch.reshape( output_masks, (batch_size, max_num_queries, num_predictions, output_h, output_w), ) else: output_masks = torch.reshape( low_res_masks, ( batch_size, max_num_queries, num_predictions, low_res_size, low_res_size, ), ) iou_predictions = torch.reshape( iou_predictions, (batch_size, max_num_queries, num_predictions) ) sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True) iou_predictions = torch.take_along_dim(iou_predictions, sorted_ids, dim=2) output_masks = torch.take_along_dim( output_masks, sorted_ids[..., None, None], dim=2 ) return output_masks, iou_predictions def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int): return torch.stack( [ torch.where( batched_points[..., 0] >= 0, batched_points[..., 0] * self.image_encoder.img_size / input_w, -1.0, ), torch.where( batched_points[..., 1] >= 0, batched_points[..., 1] * self.image_encoder.img_size / input_h, -1.0, ), ], dim=-1, ) @torch.jit.export def get_image_embeddings(self, batched_images) -> torch.Tensor: """ Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using SamPredictor is recommended over calling the model directly. Arguments: batched_images: A tensor of shape [B, 3, H, W] Returns: List of image embeddings each of of shape [B, C(i), H(i), W(i)]. The last embedding corresponds to the final layer. """ batched_images = self.preprocess(batched_images) return self.image_encoder(batched_images) def forward( self, batched_images: torch.Tensor, batched_points: torch.Tensor, batched_point_labels: torch.Tensor, scale_to_original_image_size: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using SamPredictor is recommended over calling the model directly. Arguments: batched_images: A tensor of shape [B, 3, H, W] batched_points: A tensor of shape [B, num_queries, max_num_pts, 2] batched_point_labels: A tensor of shape [B, num_queries, max_num_pts] Returns: A list tuples of two tensors where the ith element is by considering the first i+1 points. low_res_mask: A tensor of shape [B, 256, 256] of predicted masks iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores """ batch_size, _, input_h, input_w = batched_images.shape image_embeddings = self.get_image_embeddings(batched_images) return self.predict_masks( image_embeddings, batched_points, batched_point_labels, multimask_output=True, input_h=input_h, input_w=input_w, output_h=input_h if scale_to_original_image_size else -1, output_w=input_w if scale_to_original_image_size else -1, ) def preprocess(self, x: torch.Tensor) -> torch.Tensor: """Normalize pixel values and pad to a square input.""" if ( x.shape[2] != self.image_encoder.img_size or x.shape[3] != self.image_encoder.img_size ): x = F.interpolate( x, (self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear", ) return (x - self.pixel_mean) / self.pixel_std def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None): img_size = 1024 encoder_patch_size = 16 encoder_depth = 12 encoder_mlp_ratio = 4.0 encoder_neck_dims = [256, 256] decoder_max_num_input_points = 6 decoder_transformer_depth = 2 decoder_transformer_mlp_dim = 2048 decoder_num_heads = 8 decoder_upscaling_layer_dims = [64, 32] num_multimask_outputs = 3 iou_head_depth = 3 iou_head_hidden_dim = 256 activation = "gelu" normalization_type = "layer_norm" normalize_before_activation = False assert activation == "relu" or activation == "gelu" if activation == "relu": activation_fn = nn.ReLU else: activation_fn = nn.GELU image_encoder = ImageEncoderViT( img_size=img_size, patch_size=encoder_patch_size, in_chans=3, patch_embed_dim=encoder_patch_embed_dim, normalization_type=normalization_type, depth=encoder_depth, num_heads=encoder_num_heads, mlp_ratio=encoder_mlp_ratio, neck_dims=encoder_neck_dims, act_layer=activation_fn, ) image_embedding_size = image_encoder.image_embedding_size encoder_transformer_output_dim = image_encoder.transformer_output_dim sam = EfficientSam( image_encoder=image_encoder, prompt_encoder=PromptEncoder( embed_dim=encoder_transformer_output_dim, image_embedding_size=(image_embedding_size, image_embedding_size), input_image_size=(img_size, img_size), ), decoder_max_num_input_points=decoder_max_num_input_points, mask_decoder=MaskDecoder( transformer_dim=encoder_transformer_output_dim, transformer=TwoWayTransformer( depth=decoder_transformer_depth, embedding_dim=encoder_transformer_output_dim, num_heads=decoder_num_heads, mlp_dim=decoder_transformer_mlp_dim, activation=activation_fn, normalize_before_activation=normalize_before_activation, ), num_multimask_outputs=num_multimask_outputs, activation=activation_fn, normalization_type=normalization_type, normalize_before_activation=normalize_before_activation, iou_head_depth=iou_head_depth - 1, iou_head_hidden_dim=iou_head_hidden_dim, upscaling_layer_dims=decoder_upscaling_layer_dims, ), pixel_mean=[0.485, 0.456, 0.406], pixel_std=[0.229, 0.224, 0.225], ) if checkpoint is not None: with open(checkpoint, "rb") as f: state_dict = torch.load(f, map_location="cpu") sam.load_state_dict(state_dict["model"]) return sam