# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import urllib.request
import warnings
from typing import Tuple

import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam
from segment_anything.utils.amg import calculate_stability_score
from torch.nn import functional as F

# Defaults used by ``load_model``: local cache directory, file name and
# download URL of the checkpoint, and the ``sam_model_registry`` key.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned: a single best mask
    (``return_single_mask``), stability scores in place of the predicted IoU
    scores (``use_stability_score``), and extra per-mask metrics
    (``return_extra_metrics``). See the ONNX export script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        """Scale an (h, w) size so its longer side equals ``longest_side``.

        The scaled size is rounded half-up (``floor(x + 0.5)``) and returned
        as an int64 tensor.
        """
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(
        self, point_coords: torch.Tensor, point_labels: torch.Tensor
    ) -> torch.Tensor:
        """Build sparse prompt embeddings from point coordinates and labels.

        Label -1 selects ``not_a_point_embed`` (and zeroes the positional
        encoding). Label ``i`` in ``[0, num_point_embeddings)`` adds
        ``point_embeddings[i]``. Selection is done with multiplicative masks
        rather than indexing or branching, which keeps the graph traceable.
        """
        # Offset by half a pixel (presumably to sample at pixel centers),
        # then normalize by the model input size.
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = (
            point_embedding
            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = (
                point_embedding
                + self.model.prompt_encoder.point_embeddings[i].weight
                * (point_labels == i)
            )

        return point_embedding

    def _embed_masks(
        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
    ) -> torch.Tensor:
        """Blend the downscaled input mask with the learned "no mask" embedding.

        ``has_mask_input`` acts as a weight (expected to be 0 or 1 — TODO
        confirm with callers), which avoids a data-dependent branch.
        """
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
            input_mask
        )
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(
        self, masks: torch.Tensor, orig_im_size: torch.Tensor
    ) -> torch.Tensor:
        """Upscale decoder masks back to the original (h, w) image size.

        Masks are resized to the square model input size, cropped to the
        region covered by the resized (pre-padding) image, and then resized
        to ``orig_im_size``.
        """
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
            torch.int64
        )
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick one mask per batch element, keeping a singleton mask dimension.

        Returns the selected masks and their scores.
        """
        # Determine if we should return the multiclick mask
        # or not from the number of points.
        # The reweighting is used to avoid control flow.
        # Mask index 0 gets a bonus of (num_points - 2.5) * 1000: with 3+
        # points it dominates, and with 2 or fewer it is effectively excluded.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        """Run prompt encoding, mask decoding and postprocessing.

        ``point_coords.shape[1]`` is treated as the number of points.
        ``orig_im_size`` holds the target (h, w).

        Returns ``(upscaled_masks, scores, masks)``, where ``masks`` are the
        decoder's low-resolution masks. When ``return_extra_metrics`` is set,
        returns ``(upscaled_masks, scores, stability_scores, areas, masks)``
        instead.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            # Pixel count above the mask threshold, summed over the last two dims.
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks


def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    """Build a Sam model, downloading the checkpoint first if it is missing.

    Args:
        checkpoint_path: Directory the checkpoint is cached in (created if needed).
        checkpoint_name: File name of the checkpoint inside ``checkpoint_path``.
        checkpoint_url: URL the checkpoint is downloaded from when not cached.
        model_type: Key into ``sam_model_registry``.

    Returns:
        The model built by ``sam_model_registry[model_type]``.
    """
    # exist_ok avoids a race between checking for the directory and creating it.
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        # Download under a temporary name and only move it into place on
        # success, so an interrupted download never leaves a truncated file
        # that later runs would mistake for a complete checkpoint.
        partial = checkpoint + ".part"
        try:
            urllib.request.urlretrieve(checkpoint_url, partial)
            os.replace(partial, checkpoint)
        finally:
            if os.path.exists(partial):
                os.remove(partial)
        print(f"The model weights saved as {checkpoint}")
    print(f"Load the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)


if __name__ == "__main__":
    sam = load_model()

    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    # Let the number of prompt points vary at inference time.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    # Dummy inputs only drive tracing; their values are arbitrary but their
    # shapes and dtypes define the exported graph's inputs.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(
            low=0, high=1024, size=(1, 5, 2), dtype=torch.float
        ),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    # Matches the 3-tuple returned by forward() when return_extra_metrics is False.
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            "sam_decoder.onnx",
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    quantize_dynamic(
        model_input="sam_decoder.onnx",
        model_output="sam_decoder_uint8.onnx",
        optimize_model=True,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )

    # Validate
    onnx.checker.check_model("sam_decoder_uint8.onnx")