curt-park committed
Commit: c4dfe50
Parent: 2708ecd

Add quantization script

.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sam_decoder.onnx filter=lfs diff=lfs merge=lfs -text
+ sam_decoder_uint8.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
convert_onnx.py ADDED
@@ -0,0 +1,238 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import urllib.request
+ import warnings
+ from typing import Tuple
+
+ import onnx
+ import torch
+ import torch.nn as nn
+ from onnxruntime.quantization import QuantType
+ from onnxruntime.quantization.quantize import quantize_dynamic
+ from segment_anything import sam_model_registry
+ from segment_anything.modeling import Sam
+ from segment_anything.utils.amg import calculate_stability_score
+ from torch.nn import functional as F
+
+ CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
+ CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
+ CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+ MODEL_TYPE = "default"
+
+
+ class SamOnnxModel(nn.Module):
+     """
+     This model should not be called directly, but is used in ONNX export.
+     It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+     with some functions modified to enable model tracing. It also supports extra
+     options controlling what information is returned. See the ONNX export script
+     for details.
+     """
+
+     def __init__(
+         self,
+         model: Sam,
+         return_single_mask: bool,
+         use_stability_score: bool = False,
+         return_extra_metrics: bool = False,
+     ) -> None:
+         super().__init__()
+         self.mask_decoder = model.mask_decoder
+         self.model = model
+         self.img_size = model.image_encoder.img_size
+         self.return_single_mask = return_single_mask
+         self.use_stability_score = use_stability_score
+         self.stability_score_offset = 1.0
+         self.return_extra_metrics = return_extra_metrics
+
+     @staticmethod
+     def resize_longest_image_size(
+         input_image_size: torch.Tensor, longest_side: int
+     ) -> torch.Tensor:
+         input_image_size = input_image_size.to(torch.float32)
+         scale = longest_side / torch.max(input_image_size)
+         transformed_size = scale * input_image_size
+         transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+         return transformed_size
+
+     def _embed_points(
+         self, point_coords: torch.Tensor, point_labels: torch.Tensor
+     ) -> torch.Tensor:
+         point_coords = point_coords + 0.5
+         point_coords = point_coords / self.img_size
+         point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+         point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+         point_embedding = point_embedding * (point_labels != -1)
+         point_embedding = (
+             point_embedding
+             + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
+         )
+
+         for i in range(self.model.prompt_encoder.num_point_embeddings):
+             point_embedding = (
+                 point_embedding
+                 + self.model.prompt_encoder.point_embeddings[i].weight
+                 * (point_labels == i)
+             )
+
+         return point_embedding
+
+     def _embed_masks(
+         self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
+     ) -> torch.Tensor:
+         mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
+             input_mask
+         )
+         mask_embedding = mask_embedding + (
+             1 - has_mask_input
+         ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+         return mask_embedding
+
+     def mask_postprocessing(
+         self, masks: torch.Tensor, orig_im_size: torch.Tensor
+     ) -> torch.Tensor:
+         masks = F.interpolate(
+             masks,
+             size=(self.img_size, self.img_size),
+             mode="bilinear",
+             align_corners=False,
+         )
+
+         prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
+             torch.int64
+         )
+         masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore
+
+         orig_im_size = orig_im_size.to(torch.int64)
+         h, w = orig_im_size[0], orig_im_size[1]
+         masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+         return masks
+
+     def select_masks(
+         self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Determine if we should return the multiclick mask
+         # or not from the number of points.
+         # The reweighting is used to avoid control flow.
+         score_reweight = torch.tensor(
+             [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+         ).to(iou_preds.device)
+         score = iou_preds + (num_points - 2.5) * score_reweight
+         best_idx = torch.argmax(score, dim=1)
+         masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+         iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+         return masks, iou_preds
+
+     @torch.no_grad()
+     def forward(
+         self,
+         image_embeddings: torch.Tensor,
+         point_coords: torch.Tensor,
+         point_labels: torch.Tensor,
+         mask_input: torch.Tensor,
+         has_mask_input: torch.Tensor,
+         orig_im_size: torch.Tensor,
+     ):
+         sparse_embedding = self._embed_points(point_coords, point_labels)
+         dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+         masks, scores = self.model.mask_decoder.predict_masks(
+             image_embeddings=image_embeddings,
+             image_pe=self.model.prompt_encoder.get_dense_pe(),
+             sparse_prompt_embeddings=sparse_embedding,
+             dense_prompt_embeddings=dense_embedding,
+         )
+
+         if self.use_stability_score:
+             scores = calculate_stability_score(
+                 masks, self.model.mask_threshold, self.stability_score_offset
+             )
+
+         if self.return_single_mask:
+             masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+         upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+         if self.return_extra_metrics:
+             stability_scores = calculate_stability_score(
+                 upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+             )
+             areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+             return upscaled_masks, scores, stability_scores, areas, masks
+
+         return upscaled_masks, scores, masks
+
+
+ def load_model(
+     checkpoint_path: str = CHECKPOINT_PATH,
+     checkpoint_name: str = CHECKPOINT_NAME,
+     checkpoint_url: str = CHECKPOINT_URL,
+     model_type: str = MODEL_TYPE,
+ ) -> Sam:
+     if not os.path.exists(checkpoint_path):
+         os.makedirs(checkpoint_path)
+     checkpoint = os.path.join(checkpoint_path, checkpoint_name)
+     if not os.path.exists(checkpoint):
+         print("Downloading the model weights...")
+         urllib.request.urlretrieve(checkpoint_url, checkpoint)
+         print(f"The model weights were saved to {checkpoint}")
+     print(f"Loading the model weights from {checkpoint}")
+     return sam_model_registry[model_type](checkpoint=checkpoint)
+
+
+ if __name__ == "__main__":
+     sam = load_model()
+     onnx_model = SamOnnxModel(sam, return_single_mask=True)
+
+     dynamic_axes = {
+         "point_coords": {1: "num_points"},
+         "point_labels": {1: "num_points"},
+     }
+
+     embed_dim = sam.prompt_encoder.embed_dim
+     embed_size = sam.prompt_encoder.image_embedding_size
+     mask_input_size = [4 * x for x in embed_size]
+     dummy_inputs = {
+         "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
+         "point_coords": torch.randint(
+             low=0, high=1024, size=(1, 5, 2), dtype=torch.float
+         ),
+         "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
+         "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
+         "has_mask_input": torch.tensor([1], dtype=torch.float),
+         "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
+     }
+     output_names = ["masks", "iou_predictions", "low_res_masks"]
+
+     with warnings.catch_warnings():
+         warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+         warnings.filterwarnings("ignore", category=UserWarning)
+         torch.onnx.export(
+             onnx_model,
+             tuple(dummy_inputs.values()),
+             "sam_decoder.onnx",
+             export_params=True,
+             opset_version=17,
+             do_constant_folding=True,
+             input_names=list(dummy_inputs.keys()),
+             output_names=output_names,
+             dynamic_axes=dynamic_axes,
+         )
+
+     quantize_dynamic(
+         model_input="sam_decoder.onnx",
+         model_output="sam_decoder_uint8.onnx",
+         optimize_model=True,
+         per_channel=False,
+         reduce_range=False,
+         weight_type=QuantType.QUInt8,
+     )
+
+     # Validate
+     onnx.checker.check_model("sam_decoder_uint8.onnx")
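
For context, here is a minimal inference sketch (not included in this commit) of how the quantized decoder could be driven with onnxruntime. It assumes the image embedding is computed separately with SamPredictor from segment_anything, since the image encoder is not exported by convert_onnx.py; "example.jpg" is a placeholder image path, Pillow is used only to load it, and the prompt layout (a single click padded with a label -1 dummy point) follows the official SAM ONNX example. Input and output names match the dummy_inputs and output_names defined above.

import os

import numpy as np
import onnxruntime as ort
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

# Compute the image embedding with the (non-exported) SAM image encoder.
checkpoint = os.path.join(os.path.expanduser("~"), ".cache", "SAM", "sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam_model_registry["default"](checkpoint=checkpoint))
image = np.asarray(Image.open("example.jpg").convert("RGB"))  # placeholder input image
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()  # (1, 256, 64, 64) for ViT-H

# One foreground click, padded with a dummy point (label -1) as in the official
# SAM ONNX example; coordinates are mapped into the 1024x1024 input frame.
onnx_coord = np.array([[[500.0, 375.0], [0.0, 0.0]]])
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
onnx_label = np.array([[1, -1]], dtype=np.float32)

session = ort.InferenceSession("sam_decoder_uint8.onnx")
masks, iou_predictions, low_res_masks = session.run(
    None,
    {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    },
)
binary_mask = masks[0, 0] > 0.0  # SAM's default mask_threshold is 0.0

Per the LFS pointers below, dynamic uint8 quantization roughly halves the decoder file, from about 16.5 MB (sam_decoder.onnx) to about 8.7 MB (sam_decoder_uint8.onnx), possibly at a small cost in mask quality.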
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ git+https://github.com/facebookresearch/segment-anything.git
+ torch == 2.0.0
+ onnx == 1.13.1
+ onnxruntime == 1.14.1
sam_decoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:00711579c4577ac65da750e04fcea58786963eb2229b7ce431c5806a4eaacecc
+ size 16501359
sam_quantized_uint8.onnx → sam_decoder_uint8.onnx RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:a236d7d85ae3da39569cd1d46689a13741709b757b54390cd773b1dcaddfdc61
- size 8743733
+ oid sha256:22440f13a3c4ebcc1d7c0f018c1ad913c1851502bd61a52b37acd55a0569d25f
+ size 8743672