import os
import argparse
import warnings

import numpy as np
import torch

from mobile_sam import sam_model_registry

from .onnx_image_encoder import ImageEncoderOnnxModel

try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM image encoder to an ONNX model."
)
parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM model checkpoint.",
)
parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)
parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)
parser.add_argument(
    "--use-preprocess",
    action="store_true",
    help="Whether to preprocess the image by resizing, standardizing, etc.",
)
parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)
parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from "
        "onnxruntime.quantization.quantize."
    ),
)
parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)


def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    use_preprocess: bool,
    opset: int,
    gelu_approximate: bool = False,
):
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = ImageEncoderOnnxModel(
        model=sam,
        use_preprocess=use_preprocess,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    if gelu_approximate:
        # Swap exact GELU (erf-based) for the tanh approximation for runtimes
        # with slow or missing erf support.
        for n, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    image_size = sam.image_encoder.img_size
    if use_preprocess:
        # Preprocessing is baked into the graph: the input is an HWC float image
        # of arbitrary size, so height and width are exported as dynamic axes.
        dummy_input = {
            "input_image": torch.randn((image_size, image_size, 3), dtype=torch.float)
        }
        dynamic_axes = {
            "input_image": {0: "image_height", 1: "image_width"},
        }
    else:
        # No preprocessing: the input is an already-normalized NCHW tensor at the
        # encoder's fixed resolution, so all axes are static.
        dummy_input = {
            "input_image": torch.randn(
                (1, 3, image_size, image_size), dtype=torch.float
            )
        }
        dynamic_axes = None

    # Run the model once to make sure the dummy input traces cleanly.
    _ = onnx_model(**dummy_input)

    output_names = ["image_embeddings"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        print(f"Exporting onnx model to {output}...")
        if model_type == "vit_h":
            # vit_h exceeds the 2 GB protobuf limit, so it is exported to a path
            # (not a file handle) and its weights end up as external data files
            # alongside the .onnx file.
            output_dir, output_file = os.path.split(output)
            os.makedirs(output_dir, mode=0o777, exist_ok=True)
            torch.onnx.export(
                onnx_model,
                tuple(dummy_input.values()),
                output,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_input.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
        else:
            with open(output, "wb") as f:
                torch.onnx.export(
                    onnx_model,
                    tuple(dummy_input.values()),
                    f,
                    export_params=True,
                    verbose=False,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=list(dummy_input.keys()),
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                )

    if onnxruntime_exists:
        # Smoke-test the exported model with ONNX Runtime on CPU.
        ort_inputs = {k: to_numpy(v) for k, v in dummy_input.items()}
        providers = ["CPUExecutionProvider"]

        if model_type == "vit_h":
            session_option = onnxruntime.SessionOptions()
            ort_session = onnxruntime.InferenceSession(output, providers=providers)
            # Read the external weight files written next to the .onnx file and
            # register them as initializers.
            param_file = os.listdir(output_dir)
            param_file.remove(output_file)
            for i, layer in enumerate(param_file):
                with open(os.path.join(output_dir, layer), "rb") as fp:
                    weights = np.frombuffer(fp.read(), dtype=np.float32)
                    weights = onnxruntime.OrtValue.ortvalue_from_numpy(weights)
                    session_option.add_initializer(layer, weights)
        else:
            ort_session = onnxruntime.InferenceSession(output, providers=providers)

        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
    return tensor.cpu().numpy()


if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        use_preprocess=args.use_preprocess,
        opset=args.opset,
        gelu_approximate=args.gelu_approximate,
    )

    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
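

# A minimal usage sketch. Because of the relative import of
# `onnx_image_encoder`, the script has to be run as a module; the package path
# (`scripts.export_onnx_image_encoder`) and the checkpoint/output filenames
# below are illustrative assumptions, not part of this script.
#
#   python -m scripts.export_onnx_image_encoder \
#       --checkpoint ./weights/sam_vit_b.pth \
#       --model-type vit_b \
#       --output ./output/sam_image_encoder.onnx \
#       --use-preprocess \
#       --quantize-out ./output/sam_image_encoder_quant.onnx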