|
from functools import partial |
|
from typing import Optional, Tuple, Type |
|
|
|
import torch |
|
import torch.nn as nn |
|
from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer |
|
from segment_anything.modeling.common import LayerNorm2d |
|
from segment_anything.modeling.image_encoder import ( |
|
Block, |
|
PatchEmbed, |
|
window_partition, |
|
window_unpartition, |
|
) |
|
|
|
|
|
class CustomBlock(Block):
    """``Block`` subclass with a re-structured ``forward``.

    Construction is delegated unchanged to ``Block``. ``forward`` performs
    window partitioning, attention and window un-partitioning inside one
    branch, so every local it uses is defined on the path that reads it.
    NOTE(review): presumably arranged this way for TorchScript (the builder
    in this file is named ``_build_sam_torchscript``) -- confirm.
    """

    def __init__(self, **kwargs) -> None:
        """Pass every keyword argument straight through to ``Block.__init__``."""
        super().__init__(**kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply normalized attention and the MLP, each with a residual add.

        Args:
            x: Input tensor; dimensions 1 and 2 are read as the height and
                width that windows are restored to after attention.

        Returns:
            The tensor after the attention residual and the MLP residual.
        """
        residual = x
        normed = self.norm1(x)

        if self.window_size > 0:
            # Remember the pre-partition size so the windows can be merged back.
            height, width = normed.shape[1], normed.shape[2]
            windows, pad_hw = window_partition(normed, self.window_size)
            attended = window_unpartition(
                self.attn(windows), self.window_size, pad_hw, (height, width)
            )
        else:
            # window_size == 0: attention runs over the un-partitioned input.
            attended = self.attn(normed)

        out = residual + attended
        return out + self.mlp(self.norm2(out))
|
|
|
|
|
class CustomImageEncoderViT(nn.Module):
    """Image encoder built from ``PatchEmbed``, a stack of ``CustomBlock`` and a conv neck."""

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """Create the patch embedding, positional embedding, blocks and neck.

        Args:
            img_size: Stored on the module; together with ``patch_size`` it
                fixes the side length of the patch grid.
            patch_size: Kernel size and stride handed to ``PatchEmbed``.
            in_chans: Input channel count handed to ``PatchEmbed``.
            embed_dim: Embedding width used by the patch embedding, the
                positional embedding, every block and the neck input.
            depth: Number of ``CustomBlock`` instances to stack.
            num_heads: Forwarded to every block.
            mlp_ratio: Forwarded to every block.
            out_chans: Channel count produced by the neck.
            qkv_bias: Forwarded to every block.
            norm_layer: Forwarded to every block.
            act_layer: Forwarded to every block.
            use_abs_pos: When true, a zero-initialized positional embedding
                parameter is created; otherwise ``pos_embed`` stays ``None``.
            use_rel_pos: Forwarded to every block.
            rel_pos_zero_init: Forwarded to every block.
            window_size: Window size given to blocks whose index is not in
                ``global_attn_indexes``.
            global_attn_indexes: Indexes of blocks that receive
                ``window_size=0`` instead.
        """
        super().__init__()
        self.img_size = img_size

        # Side length of the square patch grid, shared by several sub-modules.
        grid_side = img_size // patch_size
        patch_shape = (patch_size, patch_size)

        self.patch_embed = PatchEmbed(
            kernel_size=patch_shape,
            stride=patch_shape,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, grid_side, grid_side, embed_dim)
            )

        # Blocks are created in index order, exactly ``depth`` of them.
        self.blocks = nn.ModuleList(
            [
                CustomBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_rel_pos=use_rel_pos,
                    rel_pos_zero_init=rel_pos_zero_init,
                    window_size=0 if layer_idx in global_attn_indexes else window_size,
                    input_size=(grid_side, grid_side),
                )
                for layer_idx in range(depth)
            ]
        )

        # Neck: 1x1 conv -> norm -> 3x3 conv (padding 1) -> norm, all bias-free convs.
        reduce_conv = nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False)
        reduce_norm = LayerNorm2d(out_chans)
        refine_conv = nn.Conv2d(
            out_chans, out_chans, kernel_size=3, padding=1, bias=False
        )
        refine_norm = LayerNorm2d(out_chans)
        self.neck = nn.Sequential(reduce_conv, reduce_norm, refine_conv, refine_norm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches, run every block, then apply the neck.

        Args:
            x: Input passed to ``patch_embed``.

        Returns:
            Output of the neck, which is fed the block output with its last
            dimension moved to position 1.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        # Move the trailing dimension to position 1 before the conv neck.
        return self.neck(x.permute(0, 3, 1, 2))
|
|
|
|
|
def _build_sam_torchscript(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Assemble a ``Sam`` model around ``CustomImageEncoderViT``.

    Args:
        encoder_embed_dim: Embedding width passed to the image encoder.
        encoder_depth: Number of encoder blocks.
        encoder_num_heads: Attention head count passed to the image encoder.
        encoder_global_attn_indexes: Block indexes the encoder builds with
            ``window_size=0``.
        checkpoint: Optional path to a serialized state dict. When given,
            it is loaded into the model after construction.

    Returns:
        The ``Sam`` instance, switched to eval mode, with the checkpoint
        weights loaded when ``checkpoint`` was provided.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size

    sam = Sam(
        image_encoder=CustomImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()

    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            # map_location="cpu": the model above lives on CPU, so deserialize
            # there too. This lets a checkpoint saved from CUDA tensors load on
            # a CPU-only host and avoids a transient GPU allocation; the values
            # copied into the model by load_state_dict are unchanged.
            # NOTE(security): torch.load unpickles the file -- only pass
            # checkpoints from a trusted source.
            state_dict = torch.load(f, map_location="cpu")
        sam.load_state_dict(state_dict)
    return sam
|
|
|
|
|
def build_sam_vit_h_torchscript(checkpoint=None):
    """Build the model with this function's fixed ``vit_h`` encoder settings.

    Args:
        checkpoint: Optional checkpoint path, forwarded unchanged to
            ``_build_sam_torchscript``.

    Returns:
        Whatever ``_build_sam_torchscript`` returns for these settings.
    """
    encoder_settings = {
        "encoder_embed_dim": 1280,
        "encoder_depth": 32,
        "encoder_num_heads": 16,
        "encoder_global_attn_indexes": [7, 15, 23, 31],
    }
    return _build_sam_torchscript(checkpoint=checkpoint, **encoder_settings)
|
|