# Uploaded by YouLiXiya ("Upload 22 files", commit 7dbe662; raw file size 14.4 kB).
import os
import functools
import torch
from torch import nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from typing import Optional, List, Union, Tuple, Type
from segment_anything import build_sam
from segment_anything.mobile_encoder.tiny_vit_sam import TinyViT
from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer
from segment_anything.modeling.image_encoder import ImageEncoderViT, LayerNorm2d, PatchEmbed, Block, Attention
from segment_anything.mobile_encoder.setup_mobile_sam import load_mobile_sam
from segment_anything.modeling.sam import Sam
from sam_extension.distillation_models.fastertinyvit import FasterTinyViT
from sam_extension.distillation_models.dino import DINO
# from sam_extension.distillation_models.flashvision_transformer import FlashVisionTransformer
# Hugging Face Hub repository that `load_sam` downloads missing checkpoints from.
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
# `hf_hub_download` with the repository pre-bound; callers pass `filename` and
# `local_dir`. NOTE(review): `local_dir_use_symlinks` is deprecated in newer
# huggingface_hub releases — confirm against the installed version.
hf_sam_download = functools.partial(
    hf_hub_download,
    repo_id=SAM_REPO_ID,
    local_dir_use_symlinks=True,
)
class SAMImageEncoder(nn.Module):
    """Keep only the image encoder of a SAM model built from a checkpoint.

    The full model is built with ``build_sam``, moved to ``device``, and every
    submodule except ``image_encoder`` is discarded.
    """

    def __init__(self,
                 sam_checkpoint_path,
                 device='cuda'):
        super().__init__()
        full_model = build_sam(sam_checkpoint_path).to(device)
        # Hold on to the encoder before dropping the reference to the full model.
        self.image_encoder = full_model.image_encoder
        del full_model
        # Ask PyTorch's CUDA caching allocator to release memory it no longer uses.
        torch.cuda.empty_cache()

    def forward(self, x):
        """Return the wrapped encoder's output for ``x``."""
        encoded = self.image_encoder(x)
        return encoded
class MobileSAMImageEncoder(nn.Module):
    """Keep only the image encoder of a MobileSAM model loaded from a checkpoint.

    The model is loaded with ``load_mobile_sam(path, device)`` and every
    submodule except ``image_encoder`` is discarded.
    """

    def __init__(self,
                 sam_checkpoint_path,
                 device='cuda'):
        super().__init__()
        full_model = load_mobile_sam(sam_checkpoint_path, device)
        # Hold on to the encoder before dropping the reference to the full model.
        self.image_encoder = full_model.image_encoder
        del full_model
        # Ask PyTorch's CUDA caching allocator to release memory it no longer uses.
        torch.cuda.empty_cache()

    def forward(self, x):
        """Return the wrapped encoder's output for ``x``."""
        encoded = self.image_encoder(x)
        return encoded
class SAMEncoderViT(nn.Module):
    """ViT image encoder with an optional multi-scale neck.

    Pipeline: ``PatchEmbed`` -> optional absolute positional embedding -> a
    stack of ``Block`` layers -> a neck (1x1 conv, LayerNorm2d, 3x3 conv,
    LayerNorm2d). When ``multi_scale`` is set and ``output_shape`` is given,
    every block's output is resized to ``output_shape`` and all of them are
    concatenated along the channel axis before the neck.
    """

    def __init__(
            self,
            img_size: int = 1024,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.0,
            out_chans: int = 256,
            qkv_bias: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            act_layer: Type[nn.Module] = nn.GELU,
            use_abs_pos: bool = True,
            use_rel_pos: bool = False,
            rel_pos_zero_init: bool = True,
            window_size: int = 0,
            global_attn_indexes: Tuple[int, ...] = (),
            multi_scale: bool = False,
            output_shape: Optional[Union[Tuple, List]] = None
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (tuple or list): Indexes of blocks built with
                ``window_size=0`` (global attention).
            multi_scale (bool): If True and ``output_shape`` is given, fuse the
                resized outputs of all blocks in the neck.
            output_shape (tuple or list, optional): Spatial size ``(H, W)`` every
                block output is resized to in multi-scale mode. Any falsy value
                (``None``, ``''``, ``()``) disables multi-scale mode.
        """
        super().__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)
        # Multi-scale mode concatenates all `depth` block outputs, so the first
        # neck conv must accept embed_dim * depth channels.
        neck_in_chans = embed_dim * depth if self.multi_scale and self.output_shape else embed_dim
        self.neck = nn.Sequential(
            nn.Conv2d(
                neck_in_chans,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        The ``patch_embed`` output is handled as channels-last: it is added to
        ``pos_embed`` of shape ``(1, H', W', embed_dim)`` and permuted with
        ``(0, 3, 1, 2)`` before the neck. The returned tensor is channels-first
        with ``out_chans`` channels.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        if self.multi_scale and self.output_shape:
            # Resize every intermediate output to a common grid so the neck can
            # fuse all depths.
            output_list = []
            for blk in self.blocks:
                x = blk(x)
                output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
            return self.neck(torch.cat(output_list, dim=1))
        for blk in self.blocks:
            x = blk(x)
        return self.neck(x.permute(0, 3, 1, 2))
class SAMEncoderAdaptor(nn.Module):
    """Transformer adaptor that turns a feature grid into SAM-style embeddings.

    There is no patch embedding: the input grid (spatial size ``input_size``)
    goes through optional absolute positional embedding, a stack of ``Block``
    layers and a neck (1x1 conv, LayerNorm2d, 3x3 conv, LayerNorm2d). When
    ``multi_scale`` is set and ``output_shape`` is given, every block's output
    is resized to ``output_shape`` and concatenated before the neck.
    """

    def __init__(self,
                 img_size: int,
                 input_size: Optional[Tuple[int, int]],
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 out_chans: int = 256,
                 qkv_bias: bool = True,
                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                 act_layer: Type[nn.Module] = nn.GELU,
                 use_abs_pos: bool = True,
                 use_rel_pos: bool = False,
                 rel_pos_zero_init: bool = True,
                 window_size: int = 0,
                 global_attn_indexes: Tuple[int, ...] = (),
                 multi_scale: bool = False,
                 output_shape: Optional[Union[Tuple, List]] = None):
        """
        Args:
            img_size (int): Stored as ``self.img_size``; not used inside this class.
            input_size (tuple): Spatial size ``(H, W)`` of the grid fed to the
                blocks; also the spatial size of ``pos_embed`` when ``use_abs_pos``.
            embed_dim (int): Channel dimension of the input grid.
            depth (int): Number of ``Block`` layers.
            num_heads (int): Number of attention heads in each block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, add a learnable absolute positional embedding.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (tuple or list): Indexes of blocks built with
                ``window_size=0`` (global attention).
            multi_scale (bool): If True and ``output_shape`` is given, fuse the
                resized outputs of all blocks in the neck.
            output_shape (tuple or list, optional): Spatial size every block output
                is resized to in multi-scale mode; a falsy value disables it.
        """
        super().__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None
        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding at the adaptor's input grid size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, input_size[0], input_size[1], embed_dim)
            )
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=input_size,
            )
            self.blocks.append(block)
        # Multi-scale mode concatenates all `depth` block outputs, so the first
        # neck conv must accept embed_dim * depth channels.
        neck_in_chans = embed_dim * depth if self.multi_scale and self.output_shape else embed_dim
        self.neck = nn.Sequential(
            nn.Conv2d(
                neck_in_chans,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    @staticmethod
    def _letterbox(x: torch.Tensor, original_size) -> torch.Tensor:
        """Resize each image to its aspect ratio and zero-pad bottom/right.

        Row ``i`` of ``original_size`` is applied to ``x[i]`` (channels-first).
        The result is channels-first, has ``x``'s spatial size, and contains one
        entry per row of ``original_size``.
        """
        sizes = torch.as_tensor(original_size, dtype=torch.long)
        if sizes.ndim == 1:
            # A single (h, w) pair describes a batch of one.
            sizes = sizes[None, ...]
        out_h, out_w = x.shape[-2:]
        letterboxed = []
        for i, (h, w) in enumerate(sizes.tolist()):
            # Fit the whole (h, w) box inside (out_h, out_w). For a square target
            # this scales the longer side to the target size. Integer floor
            # division matches truncation for positive sizes, and max(1, ...)
            # avoids a zero-sized resize for extreme aspect ratios.
            if out_h * w <= out_w * h:
                new_h, new_w = out_h, max(1, w * out_h // h)
            else:
                new_h, new_w = max(1, h * out_w // w), out_w
            resized = F.interpolate(x[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear')
            letterboxed.append(F.pad(resized, (0, out_w - new_w, 0, out_h - new_h)))
        return torch.cat(letterboxed, dim=0)

    def forward(self, x: torch.Tensor,
                original_size: Optional[Union[Tuple, List, torch.Tensor]] = None) -> torch.Tensor:
        """Run the adaptor.

        Args:
            x: Without ``original_size``, a channels-last grid ``(B, H, W, C)``
                matching ``input_size``. With ``original_size``, a channels-first
                batch ``(B, C, H, W)`` that is letterboxed per image first.
            original_size: One ``(h, w)`` pair, or a sequence/tensor of pairs
                (one per image to letterbox). ``None`` or empty skips letterboxing.

        Returns:
            Channels-first tensor with ``out_chans`` channels.
        """
        if original_size is not None and len(original_size) > 0:
            x = self._letterbox(x, original_size).permute(0, 2, 3, 1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        if self.multi_scale and self.output_shape:
            output_list = []
            for blk in self.blocks:
                x = blk(x)
                output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
            return self.neck(torch.cat(output_list, dim=1))
        for blk in self.blocks:
            x = blk(x)
        return self.neck(x.permute(0, 3, 1, 2))
class DINOSAMViT(nn.Module):
    """Student encoder: DINO features followed by a ``SAMEncoderAdaptor``.

    DINO features are L2-normalised over dim 3, bilinearly resized to
    ``self.input_size`` and passed through the adaptor.
    """

    def __init__(self,
                 dino_model_type,
                 device='cuda',
                 pca_dim=None,
                 **kwargs
                 ):
        """
        Args:
            dino_model_type: Forwarded to ``DINO`` as its model type.
            device: Device passed to ``DINO`` and used for the adaptor.
            pca_dim: Optional dimension forwarded to ``DINO``; any falsy value
                (e.g. ``0``) is treated as ``None``. When set it is also the
                adaptor's ``embed_dim``; otherwise ``self.dino.model.embed_dim`` is.
            **kwargs: ``SAMEncoderAdaptor`` arguments. Must contain ``img_size``
                and ``output_shape``; ``output_shape`` also becomes the adaptor's
                ``input_size``. ``input_size`` and ``embed_dim`` are overwritten.
        """
        super().__init__()
        self.img_size = kwargs['img_size']
        pca_dim = pca_dim or None
        self.dino = DINO(dino_model_type, device, self.img_size, pca_dim)
        # DINO features are resized to the adaptor's output grid before the adaptor.
        self.input_size = tuple(kwargs['output_shape'])
        embed_dim = pca_dim if pca_dim is not None else self.dino.model.embed_dim
        kwargs.update({'input_size': self.input_size, 'embed_dim': embed_dim})
        self.adaptor = SAMEncoderAdaptor(**kwargs).to(device)

    def extract_dino_features(self, x, transform=False, size=None):
        """Return ``self.dino.extract_features(x, transform, size)`` unchanged."""
        return self.dino.extract_features(x, transform, size)

    def forward(self, x, transform=False, size=None):
        """Embed ``x`` with DINO features and the adaptor.

        The permutes assume ``extract_dino_features`` returns channels-last
        ``(B, H, W, C)`` features — TODO confirm against ``DINO``.
        """
        features = F.normalize(self.extract_dino_features(x, transform, size), dim=3)
        grid = F.interpolate(features.permute(0, 3, 1, 2), size=self.input_size, mode='bilinear')
        return self.adaptor(grid.permute(0, 2, 3, 1))
def _resolve_model_class(type_name):
    """Resolve a (possibly dotted) name against this module's globals.

    Replaces ``eval`` so a checkpoint's config can only name objects already in
    scope here. An unknown leading name raises ``NameError``, as ``eval`` did.
    """
    head, *attrs = type_name.split('.')
    if head not in globals():
        raise NameError(f'unknown model type: {type_name!r}')
    obj = globals()[head]
    for attr in attrs:
        obj = getattr(obj, attr)
    return obj


def setup_model(model_config):
    """Build a ``Sam`` model around the image encoder described by ``model_config``.

    Args:
        model_config (dict): Must contain ``'type'``, the name of a class in
            scope in this module (e.g. ``'SAMEncoderViT'``). The remaining items
            are passed to that class as keyword arguments. The mapping itself
            is not modified.

    Returns:
        Sam: A model with a fresh prompt encoder and mask decoder. If the built
        model is a ``SAMEncoderAdaptor``, it is passed as ``adaptor`` and the
        MobileSAM image encoder (``weights/sam/mobile_sam.pt``, downloaded by
        ``load_sam`` when missing) is used as ``image_encoder``. Otherwise the
        built model is the image encoder and ``adaptor`` is ``None``.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    # Copy so the caller's config keeps its 'type' key and can be reused.
    config = dict(model_config)
    model = _resolve_model_class(config.pop('type'))(**config)
    if isinstance(model, SAMEncoderAdaptor):
        # NOTE(review): presumably Sam feeds this encoder's output through the
        # adaptor — confirm in segment_anything.modeling.sam.Sam.
        adaptor = model
        image_encoder = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cpu').image_encoder
    else:
        adaptor = None
        image_encoder = model
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        adaptor=adaptor,
        # ImageNet RGB mean/std on a 0-255 scale (0.485/0.456/0.406, 0.229/0.224/0.225 x 255).
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return sam
def load_distillation_sam(distillation_sam_ckpt_path,
                          device='cuda'):
    """Load a distilled SAM from a checkpoint file.

    The checkpoint must be a mapping with ``'model_config'`` (passed to
    ``setup_model``) and ``'model'`` (the state dict to load).

    Args:
        distillation_sam_ckpt_path (str): Path to the checkpoint.
        device: Device the model is moved to after loading.

    Returns:
        The loaded ``Sam`` model on ``device``.
    """
    # Deserialize onto the CPU so checkpoints saved from a GPU also load on
    # CPU-only hosts; the model is moved to `device` at the end.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(distillation_sam_ckpt_path, map_location='cpu')
    sam = setup_model(ckpt['model_config'])
    sam.load_state_dict(ckpt['model'])
    return sam.to(device)
def load_sam(sam_ckpt_path, sam_version, device):
    """Load a SAM variant, downloading its checkpoint first if it is missing.

    Args:
        sam_ckpt_path (str): Local checkpoint path. If the file does not exist,
            the file with the same basename is fetched from ``SAM_REPO_ID`` into
            the path's directory (the current directory when the path has none).
        sam_version (str): ``'sam'``, ``'mobile_sam'`` or ``'distillation_sam'``.
        device: Target device, passed to the selected loader.

    Returns:
        The loaded model.

    Raises:
        ValueError: If ``sam_version`` is unsupported. This is checked before any
            directory is created or any download is attempted.
    """
    if sam_version not in ('sam', 'mobile_sam', 'distillation_sam'):
        raise ValueError('sam version error, please give sam version in [sam, mobile_sam, distillation_sam]')
    if not os.path.exists(sam_ckpt_path):
        # dirname('x.pt') is '' and os.makedirs('') raises, so use the CWD instead.
        parent_dir = os.path.dirname(sam_ckpt_path) or '.'
        os.makedirs(parent_dir, exist_ok=True)
        hf_sam_download(filename=os.path.basename(sam_ckpt_path), local_dir=parent_dir)
    if sam_version == 'sam':
        return build_sam(sam_ckpt_path).to(device)
    if sam_version == 'mobile_sam':
        return load_mobile_sam(sam_ckpt_path, device)
    return load_distillation_sam(sam_ckpt_path, device)
if __name__ == '__main__':
    from distillation.utils import get_parameter_number
    # Smoke test: build a small SAMEncoderViT, print its output shape and the
    # result of get_parameter_number. Fall back to the CPU when CUDA is absent.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vit = SAMEncoderViT(depth=3,
                        embed_dim=256,
                        img_size=512,
                        mlp_ratio=4,
                        num_heads=16,
                        patch_size=8,
                        qkv_bias=True,
                        use_rel_pos=True,
                        global_attn_indexes=[1],
                        window_size=16,
                        out_chans=256,
                        multi_scale=False,
                        output_shape=None).to(device)
    x = torch.randn((1, 3, 512, 512), device=device)
    # Only the output shape is needed, so skip building the autograd graph.
    with torch.no_grad():
        print(vit(x).shape)
    print(get_parameter_number(vit))