# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Export the SAM prompt encoder and mask decoder to an ONNX model."""

import argparse
import warnings

import torch

from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
from segment_anything.utils.onnx import SamOnnxModel

try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    # onnxruntime is optional: without it the exported model is not smoke-tested
    # and quantization is unavailable.
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    default="default",
    help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
)

parser.add_argument(
    "--return-single-mask",
    action="store_true",
    help=(
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    ),
)

parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)

parser.add_argument(
    "--use-stability-score",
    action="store_true",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0. "
    ),
)

parser.add_argument(
    "--return-extra-metrics",
    action="store_true",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
)

# ONNX output names, in the order SamOnnxModel returns its outputs.
_DEFAULT_OUTPUT_NAMES = ["masks", "iou_predictions", "low_res_masks"]
# Order mirrors the five results documented for --return-extra-metrics:
# (masks, scores, stability_scores, areas, low_res_logits).
_EXTRA_METRICS_OUTPUT_NAMES = [
    "masks",
    "iou_predictions",
    "stability_scores",
    "areas",
    "low_res_masks",
]

_MIN_OPSET = 11


def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, moving it to the CPU first."""
    return tensor.cpu().numpy()


def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Build a SAM model, wrap it in SamOnnxModel and export it to ONNX.

    Args:
        model_type: "vit_b" or "vit_l" select those builders; any other value
            (e.g. "default") uses ``build_sam``.
        checkpoint: Path to the SAM checkpoint passed to the builder.
        output: Destination filename for the ONNX model.
        opset: ONNX opset version; must be at least 11.
        return_single_mask: Forwarded to SamOnnxModel.
        gelu_approximate: If True, every ``torch.nn.GELU`` module is switched
            to its tanh approximation before export.
        use_stability_score: Forwarded to SamOnnxModel.
        return_extra_metrics: Forwarded to SamOnnxModel; also selects the
            five-entry list of ONNX output names.

    Raises:
        ValueError: If ``opset`` is below 11.

    If onnxruntime is installed, the exported file is loaded and run once on
    the dummy inputs as a smoke test.
    """
    if opset < _MIN_OPSET:
        raise ValueError(f"The ONNX opset version must be >= {_MIN_OPSET}, got {opset}.")

    print("Loading model...")
    if model_type == "vit_b":
        sam = build_sam_vit_b(checkpoint)
    elif model_type == "vit_l":
        sam = build_sam_vit_l(checkpoint)
    else:
        sam = build_sam(checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # The number of prompt points may vary between calls at inference time.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    # The mask prompt is 4x the spatial size of the image embedding.
    mask_input_size = [4 * x for x in embed_size]
    # Insertion order matters: the values are passed positionally to the
    # exporter and the keys become the ONNX input names.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Run once eagerly so shape/logic errors surface before tracing.
    _ = onnx_model(**dummy_inputs)

    # Names must line up one-to-one with the model's outputs; with extra
    # metrics enabled there are five outputs, not three.
    output_names = list(
        _EXTRA_METRICS_OUTPUT_NAMES if return_extra_metrics else _DEFAULT_OUTPUT_NAMES
    )

    with warnings.catch_warnings():
        # Tracing emits many benign warnings; silence them for the export only.
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {name: to_numpy(tensor) for name, tensor in dummy_inputs.items()}
        # onnxruntime builds that ship several execution providers require
        # `providers` to be given explicitly; the CPU provider is always present.
        ort_session = onnxruntime.InferenceSession(
            output, providers=["CPUExecutionProvider"]
        )
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


if __name__ == "__main__":
    args = parser.parse_args()

    # Check the quantization prerequisite up front instead of after the
    # (slow) export; parser.error also survives `python -O`, unlike assert.
    if args.quantize_out is not None and not onnxruntime_exists:
        parser.error("onnxruntime is required to quantize the model.")

    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")