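"""Export the SAM image encoder to an ONNX model.

Optionally bakes preprocessing (resizing and standardization) into the exported
graph, replaces GELU ops with their tanh approximation, and quantizes the
result with onnxruntime's dynamic quantization. See the command-line arguments
below for details.
"""
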
import torch
import numpy as np
from mobile_sam import sam_model_registry
from .onnx_image_encoder import ImageEncoderOnnxModel
import os
import argparse
import warnings
try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM image encoder to an ONNX model."
)
parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM model checkpoint.",
)
parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)
parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)
parser.add_argument(
    "--use-preprocess",
    action="store_true",
    help="Whether to preprocess the image by resizing, standardizing, etc.",
)
parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)
parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)
parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)

def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    use_preprocess: bool,
    opset: int,
    gelu_approximate: bool = False,
):
print("Loading model...")
sam = sam_model_registry[model_type](checkpoint=checkpoint)
onnx_model = ImageEncoderOnnxModel(
model=sam,
use_preprocess=use_preprocess,
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
if gelu_approximate:
for n, m in onnx_model.named_modules():
if isinstance(m, torch.nn.GELU):
m.approximate = "tanh"
image_size = sam.image_encoder.img_size
if use_preprocess:
dummy_input = {
"input_image": torch.randn((image_size, image_size, 3), dtype=torch.float)
}
dynamic_axes = {
"input_image": {0: "image_height", 1: "image_width"},
}
else:
dummy_input = {
"input_image": torch.randn(
(1, 3, image_size, image_size), dtype=torch.float
)
}
dynamic_axes = None
_ = onnx_model(**dummy_input)
output_names = ["image_embeddings"]
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
print(f"Exporting onnx model to {output}...")
if model_type == "vit_h":
output_dir, output_file = os.path.split(output)
os.makedirs(output_dir, mode=0o777, exist_ok=True)
torch.onnx.export(
onnx_model,
tuple(dummy_input.values()),
output,
export_params=True,
verbose=False,
opset_version=opset,
do_constant_folding=True,
input_names=list(dummy_input.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)
else:
with open(output, "wb") as f:
torch.onnx.export(
onnx_model,
tuple(dummy_input.values()),
f,
export_params=True,
verbose=False,
opset_version=opset,
do_constant_folding=True,
input_names=list(dummy_input.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)
if onnxruntime_exists:
ort_inputs = {k: to_numpy(v) for k, v in dummy_input.items()}
providers = ["CPUExecutionProvider"]
if model_type == "vit_h":
session_option = onnxruntime.SessionOptions()
ort_session = onnxruntime.InferenceSession(output, providers=providers)
param_file = os.listdir(output_dir)
param_file.remove(output_file)
for i, layer in enumerate(param_file):
with open(os.path.join(output_dir, layer), "rb") as fp:
weights = np.frombuffer(fp.read(), dtype=np.float32)
weights = onnxruntime.OrtValue.ortvalue_from_numpy(weights)
session_option.add_initializer(layer, weights)
else:
ort_session = onnxruntime.InferenceSession(output, providers=providers)
_ = ort_session.run(None, ort_inputs)
print("Model has successfully been run with ONNXRuntime.")

def to_numpy(tensor):
    return tensor.cpu().numpy()

if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        use_preprocess=args.use_preprocess,
        opset=args.opset,
        gelu_approximate=args.gelu_approximate,
    )

    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")