|
"""Export SAM2's image encoder and its prompt encoder + mask decoder to ONNX."""

import argparse
import pathlib

import torch
from torch import nn

from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base
|
|
|
|
|
class SAM2ImageEncoder(nn.Module): |
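    """Wraps the SAM2 image encoder for export as a standalone ONNX graph.

    The forward pass returns the two high-resolution feature maps and the
    image embedding that SAM2ImageDecoder consumes.
    """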
|
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed
|
    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        backbone_out = self.image_encoder(x)
        # Apply the mask decoder's high-resolution projections up front so the
        # decoder graph can consume the projected features directly.
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(
            backbone_out["backbone_fpn"][0]
        )
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(
            backbone_out["backbone_fpn"][1]
        )

        feature_maps = backbone_out["backbone_fpn"][
            -self.model.num_feature_levels :
        ]
        vision_pos_embeds = backbone_out["vision_pos_enc"][
            -self.model.num_feature_levels :
        ]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        # Flatten each (N, C, H, W) map to (HW, N, C); the lowest-resolution
        # level also receives the no-memory embedding, matching SAM2's
        # single-image inference path.
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        # Restore (1, C, H, W) layout for every level.
        feats = [
            feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
        ][::-1]

        # high_res_feats_0, high_res_feats_1, image_embed
        return feats[0], feats[1], feats[2]
|
|
|
|
|
class SAM2ImageDecoder(nn.Module): |
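    """Wraps SAM2's prompt encoder and mask decoder for ONNX export.

    Point and mask prompts are embedded here with purely tensor-based ops
    (_embed_points, _embed_masks) so the exported graph stays traceable.
    """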
|
    def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.img_size = sam_model.image_size
        self.multimask_output = multimask_output
|
|
|
    @torch.no_grad()
    def forward(
        self,
        image_embed: torch.Tensor,
        high_res_feats_0: torch.Tensor,
        high_res_feats_1: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        orig_im_size: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # orig_im_size is unused here; it is kept so the exported ONNX input
        # signature stays stable for downstream consumers.
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)
        high_res_feats = [high_res_feats_0, high_res_feats_1]

        masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embed,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
            repeat_image=False,
            high_res_features=high_res_feats,
        )

        if self.multimask_output:
            # Channel 0 is the single-mask output; keep the three multimask
            # channels.
            masks = masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        else:
            # Let the decoder choose between the single-mask output and the
            # most stable multimask output.
            masks, iou_predictions = (
                self.mask_decoder._dynamic_multimask_via_stability(
                    masks, iou_predictions
                )
            )

        # Clamp mask logits to a fixed range so downstream runtimes behave
        # consistently.
        masks = torch.clamp(masks, -32.0, 32.0)

        return masks, iou_predictions
|
|
|
    def _embed_points(
        self, point_coords: torch.Tensor, point_labels: torch.Tensor
    ) -> torch.Tensor:
        # Shift coordinates to pixel centers, as the prompt encoder expects.
        point_coords = point_coords + 0.5

        # Pad with a dummy point whose label -1 marks "not a point".
        padding_point = torch.zeros(
            (point_coords.shape[0], 1, 2), device=point_coords.device
        )
        padding_label = -torch.ones(
            (point_labels.shape[0], 1), device=point_labels.device
        )
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        # Normalize coordinates to [0, 1] relative to the model's input size.
        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Zero out padding points, then add the learned embedding for each
        # label via masked sums rather than data-dependent indexing, keeping
        # the graph traceable.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = (
            point_embedding
            + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        )

        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = (
                point_embedding
                + self.prompt_encoder.point_embeddings[i].weight
                * (point_labels == i)
            )

        return point_embedding
|
|
|
    def _embed_masks(
        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
    ) -> torch.Tensor:
        # Blend the downscaled mask prompt with the learned no-mask embedding,
        # gated by has_mask_input so both cases share one traceable graph.
        mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(
            input_mask
        )
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding
|
|
|
|
|
if __name__ == "__main__": |
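    # Example invocation (script and file names are illustrative):
    #   python export_onnx.py \
    #       --checkpoint checkpoints/sam2.1_hiera_large.pt \
    #       --model_type sam2.1_hiera_large \
    #       --output_encoder models/sam2.1_encoder.onnx \
    #       --output_decoder models/sam2.1_decoder.onnx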
|
    parser = argparse.ArgumentParser(
        description=(
            "Export the SAM2 image encoder and the prompt encoder/mask "
            "decoder to separate ONNX models."
        )
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="The path to the SAM2 model checkpoint.",
    )
    parser.add_argument(
        "--output_encoder",
        type=str,
        required=True,
        help="The filename to save the encoder ONNX model to.",
    )
    parser.add_argument(
        "--output_decoder",
        type=str,
        required=True,
        help="The filename to save the decoder ONNX model to.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        choices=[
            "sam2.1_hiera_tiny",
            "sam2.1_hiera_small",
            "sam2.1_hiera_base_plus",
            "sam2.1_hiera_large",
        ],
        help="In the form of sam2.1_hiera_{tiny, small, base_plus, large}.",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="The ONNX opset version to use. Must be >= 11.",
    )

    args = parser.parse_args()
|
|
|
    input_size = (1024, 1024)
    multimask_output = False

    model_cfgs = {
        "sam2.1_hiera_tiny": "configs/sam2.1/sam2.1_hiera_t.yaml",
        "sam2.1_hiera_small": "configs/sam2.1/sam2.1_hiera_s.yaml",
        "sam2.1_hiera_base_plus": "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_large": "configs/sam2.1/sam2.1_hiera_l.yaml",
    }
    # argparse's choices= guarantees the model type is one of the keys above.
    model_cfg = model_cfgs[args.model_type]

    sam2_model = build_sam2(model_cfg, args.checkpoint, device="cpu")
    img = torch.randn(1, 3, input_size[0], input_size[1]).cpu()
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    # Run the encoder once to get example outputs for the decoder export below.
    high_res_feats_0, high_res_feats_1, image_embed = sam2_encoder(img)
|
|
|
    pathlib.Path(args.output_encoder).parent.mkdir(parents=True, exist_ok=True)
    torch.onnx.export(
        sam2_encoder,
        img,
        args.output_encoder,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=["image"],
        output_names=["high_res_feats_0", "high_res_feats_1", "image_embed"],
    )
    print("Saved encoder to", args.output_encoder)
|
|
|
    sam2_decoder = SAM2ImageDecoder(
        sam2_model, multimask_output=multimask_output
    ).cpu()

    embed_dim = sam2_model.sam_prompt_encoder.embed_dim
    embed_size = (
        sam2_model.image_size // sam2_model.backbone_stride,
        sam2_model.image_size // sam2_model.backbone_stride,
    )
    # The mask prompt is expected at 4x the embedding resolution
    # (e.g. 256x256 for a 64x64 embedding).
    mask_input_size = [4 * x for x in embed_size]
    print(embed_dim, embed_size, mask_input_size)

    # Dummy prompts used only to trace the decoder graph.
    point_coords = torch.randint(
        low=0, high=input_size[1], size=(1, 5, 2), dtype=torch.float
    )
    # randint's upper bound is exclusive; high=2 samples both negative (0)
    # and positive (1) labels.
    point_labels = torch.randint(low=0, high=2, size=(1, 5), dtype=torch.float)
    mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)
    has_mask_input = torch.tensor([1], dtype=torch.float)
    orig_im_size = torch.tensor([input_size[0], input_size[1]], dtype=torch.int)
|
|
|
    pathlib.Path(args.output_decoder).parent.mkdir(parents=True, exist_ok=True)
    torch.onnx.export(
        sam2_decoder,
        (
            image_embed,
            high_res_feats_0,
            high_res_feats_1,
            point_coords,
            point_labels,
            orig_im_size,
            mask_input,
            has_mask_input,
        ),
        args.output_decoder,
        export_params=True,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=[
            "image_embed",
            "high_res_feats_0",
            "high_res_feats_1",
            "point_coords",
            "point_labels",
            "orig_im_size",
            "mask_input",
            "has_mask_input",
        ],
        output_names=["masks", "iou_predictions"],
        dynamic_axes={
            "point_coords": {0: "num_labels", 1: "num_points"},
            "point_labels": {0: "num_labels", 1: "num_points"},
            "mask_input": {0: "num_labels"},
            "has_mask_input": {0: "num_labels"},
        },
    )
    print("Saved decoder to", args.output_decoder)