Add quantization script
- .gitattributes +2 -0
- .gitignore +1 -0
- convert_onnx.py +238 -0
- requirements.txt +4 -0
- sam_decoder.onnx +3 -0
- sam_quantized_uint8.onnx → sam_decoder_uint8.onnx +2 -2
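
In short, with the pinned dependencies installed (pip install -r requirements.txt), running python convert_onnx.py downloads the SAM ViT-H checkpoint to ~/.cache/SAM, exports the prompt encoder and mask decoder to sam_decoder.onnx, and dynamically quantizes it into sam_decoder_uint8.onnx; both ONNX artifacts are tracked with Git LFS via the new .gitattributes entries.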
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+sam_decoder.onnx filter=lfs diff=lfs merge=lfs -text
+sam_decoder_uint8.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
convert_onnx.py
ADDED
@@ -0,0 +1,238 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import urllib.request
import warnings
from typing import Tuple

import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam
from segment_anything.utils.amg import calculate_stability_score
from torch.nn import functional as F

CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export script
    for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(
        self, point_coords: torch.Tensor, point_labels: torch.Tensor
    ) -> torch.Tensor:
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = (
            point_embedding
            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = (
                point_embedding
                + self.model.prompt_encoder.point_embeddings[i].weight
                * (point_labels == i)
            )

        return point_embedding

    def _embed_masks(
        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
    ) -> torch.Tensor:
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
            input_mask
        )
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(
        self, masks: torch.Tensor, orig_im_size: torch.Tensor
    ) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
            torch.int64
        )
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Determine if we should return the multiclick mask
        # or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks


def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        urllib.request.urlretrieve(checkpoint_url, checkpoint)
        print(f"The model weights were saved to {checkpoint}")
    print(f"Loading the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)


if __name__ == "__main__":
    sam = load_model()
    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(
            low=0, high=1024, size=(1, 5, 2), dtype=torch.float
        ),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            "sam_decoder.onnx",
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    quantize_dynamic(
        model_input="sam_decoder.onnx",
        model_output="sam_decoder_uint8.onnx",
        optimize_model=True,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )

    # Validate
    onnx.checker.check_model("sam_decoder_uint8.onnx")
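
As a usage reference (not part of the committed files), the sketch below shows one way to drive the exported decoder from onnxruntime. It assumes the image embedding is computed with SamPredictor from the segment_anything package, that the image path truck.jpg and the single foreground click are placeholders, and that an image loader such as Pillow is available (it is not pinned in requirements.txt); the input and output names follow the export above.

import numpy as np
import onnxruntime as ort
from PIL import Image  # placeholder loader; any RGB image loader works
from segment_anything import SamPredictor

from convert_onnx import load_model

# Compute the image embedding with the original PyTorch image encoder.
sam = load_model()
predictor = SamPredictor(sam)
image = np.array(Image.open("truck.jpg").convert("RGB"))  # placeholder image path
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()  # (1, 256, 64, 64) for ViT-H

# One positive click plus a padding point with label -1, as in the upstream SAM ONNX example.
point = np.array([[[500, 375], [0, 0]]], dtype=np.float32)  # placeholder click (x, y)
label = np.array([[1, -1]], dtype=np.float32)
point = predictor.transform.apply_coords(point, image.shape[:2]).astype(np.float32)

inputs = {
    "image_embeddings": image_embedding,
    "point_coords": point,
    "point_labels": label,
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),  # no prior low-res mask
    "has_mask_input": np.zeros(1, dtype=np.float32),
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}

session = ort.InferenceSession("sam_decoder_uint8.onnx", providers=["CPUExecutionProvider"])
masks, iou_predictions, low_res_masks = session.run(None, inputs)
binary_mask = masks[0, 0] > sam.mask_threshold  # logits -> boolean mask at original resolution
print(binary_mask.shape, float(iou_predictions[0, 0]))

The same input feed also works against the non-quantized sam_decoder.onnx, which makes it straightforward to compare the float and uint8 decoders on identical prompts.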
requirements.txt
ADDED
@@ -0,0 +1,4 @@
git+https://github.com/facebookresearch/segment-anything.git
torch == 2.0.0
onnx == 1.13.1
onnxruntime == 1.14.1
sam_decoder.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00711579c4577ac65da750e04fcea58786963eb2229b7ce431c5806a4eaacecc
size 16501359
sam_quantized_uint8.onnx → sam_decoder_uint8.onnx
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:22440f13a3c4ebcc1d7c0f018c1ad913c1851502bd61a52b37acd55a0569d25f
+size 8743672