Spaces:
Running
Running
import os | |
import functools | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from huggingface_hub import hf_hub_download | |
from typing import Optional, List, Union, Tuple, Type | |
from segment_anything import build_sam | |
from segment_anything.mobile_encoder.tiny_vit_sam import TinyViT | |
from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer | |
from segment_anything.modeling.image_encoder import ImageEncoderViT, LayerNorm2d, PatchEmbed, Block, Attention | |
from segment_anything.mobile_encoder.setup_mobile_sam import load_mobile_sam | |
from segment_anything.modeling.sam import Sam | |
from sam_extension.distillation_models.fastertinyvit import FasterTinyViT | |
from sam_extension.distillation_models.dino import DINO | |
# from sam_extension.distillation_models.flashvision_transformer import FlashVisionTransformer | |
# Hugging Face Hub repository that load_sam downloads missing checkpoints from.
SAM_REPO_ID = 'YouLiXiya/YL-SAM'

# hf_hub_download with the repository (and symlink option) pre-bound; callers
# only supply per-call arguments such as filename and local_dir.
hf_sam_download = functools.partial(
    hf_hub_download,
    repo_id=SAM_REPO_ID,
    local_dir_use_symlinks=True,
)
class SAMImageEncoder(nn.Module):
    """Wrap only the image encoder of a SAM model built with ``build_sam``.

    Args:
        sam_checkpoint_path: Checkpoint path passed to ``build_sam``.
        device: Device the full SAM model is moved to before extraction.
    """

    def __init__(self, sam_checkpoint_path, device='cuda'):
        super().__init__()
        full_model = build_sam(sam_checkpoint_path).to(device)
        self.image_encoder = full_model.image_encoder
        # Drop the local reference to the full model so only the encoder stays
        # referenced here, then ask PyTorch to release cached CUDA blocks.
        del full_model
        torch.cuda.empty_cache()

    def forward(self, x):
        """Encode ``x`` with the wrapped SAM image encoder."""
        return self.image_encoder(x)
class MobileSAMImageEncoder(nn.Module):
    """Wrap only the image encoder of a MobileSAM model from ``load_mobile_sam``.

    Args:
        sam_checkpoint_path: Checkpoint path passed to ``load_mobile_sam``.
        device: Device passed to ``load_mobile_sam``.
    """

    def __init__(self, sam_checkpoint_path, device='cuda'):
        super().__init__()
        full_model = load_mobile_sam(sam_checkpoint_path, device)
        self.image_encoder = full_model.image_encoder
        # Keep just the encoder; drop the local reference to the rest of the
        # model and release cached CUDA blocks.
        del full_model
        torch.cuda.empty_cache()

    def forward(self, x):
        """Encode ``x`` with the wrapped MobileSAM image encoder."""
        return self.image_encoder(x)
class SAMEncoderViT(nn.Module):
    """ViT image encoder built from SAM's ``PatchEmbed`` / ``Block`` / ``LayerNorm2d``.

    The image is patchified, optionally offset by an absolute positional
    embedding, passed through ``depth`` transformer blocks and projected to
    ``out_chans`` channels by a two-conv "neck".

    When ``multi_scale`` is true and ``output_shape`` is set, every block's
    output is bilinearly resized to ``output_shape`` and all of them are
    concatenated on the channel axis before the neck.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
        multi_scale: bool = False,
        output_shape: Union[Tuple, List] = None
    ) -> None:
        """
        Args:
            img_size (int): Input image size (square).
            patch_size (int): Patch size; the token grid is img_size // patch_size per side.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Number of transformer blocks.
            num_heads (int): Number of attention heads in each block.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window-attention blocks.
            global_attn_indexes (tuple | list): Indexes of blocks that use global attention.
            multi_scale (bool): If True (and output_shape is set), the neck receives
                the concatenated outputs of all blocks.
            output_shape (tuple | list | None): (H, W) every block output is resized
                to in multi-scale mode; a falsy value disables multi-scale fusion.
        """
        super().__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        grid = img_size // patch_size
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        # Absolute positional embedding sized for the pretraining resolution.
        self.pos_embed: Optional[nn.Parameter] = (
            nn.Parameter(torch.zeros(1, grid, grid, embed_dim)) if use_abs_pos else None
        )
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    use_rel_pos=use_rel_pos,
                    rel_pos_zero_init=rel_pos_zero_init,
                    # Blocks listed in global_attn_indexes get window 0 (global attention).
                    window_size=0 if idx in global_attn_indexes else window_size,
                    input_size=(grid, grid),
                )
                for idx in range(depth)
            ]
        )

        # In multi-scale mode the neck sees all `depth` block outputs stacked on channels.
        fuse_all_blocks = self.multi_scale and self.output_shape
        neck_in_chans = embed_dim * depth if fuse_all_blocks else embed_dim
        self.neck = nn.Sequential(
            nn.Conv2d(neck_in_chans, out_chans, kernel_size=1, bias=False),
            LayerNorm2d(out_chans),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch and return the neck output (channel-first)."""
        tokens = self.patch_embed(x)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed

        if not (self.multi_scale and self.output_shape):
            for blk in self.blocks:
                tokens = blk(tokens)
            # Blocks work channel-last; the conv neck needs channel-first.
            return self.neck(tokens.permute(0, 3, 1, 2))

        per_block = []
        for blk in self.blocks:
            tokens = blk(tokens)
            per_block.append(
                F.interpolate(tokens.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear')
            )
        return self.neck(torch.cat(per_block, dim=1))
class SAMEncoderAdaptor(nn.Module):
    """Transformer adaptor mapping an existing feature grid to SAM-style embeddings.

    Unlike ``SAMEncoderViT`` there is no patch embedding: the input is already a
    feature grid of size ``input_size`` with ``embed_dim`` channels. It goes
    through ``depth`` transformer ``Block``s and a two-conv neck that outputs
    ``out_chans`` channels.

    The expected input layout depends on ``original_size`` (see ``forward``):
      * without ``original_size``: channel-last ``(B, H, W, embed_dim)``;
      * with ``original_size``: channel-first ``(B, embed_dim, H, W)``. Each map
        is resized to its original aspect ratio, zero-padded bottom/right back
        to ``(H, W)``, then converted to channel-last.
    """

    def __init__(self,
                 img_size: int,
                 input_size: Optional[Tuple[int, int]],
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 out_chans: int = 256,
                 qkv_bias: bool = True,
                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                 act_layer: Type[nn.Module] = nn.GELU,
                 use_abs_pos: bool = True,
                 use_rel_pos: bool = False,
                 rel_pos_zero_init: bool = True,
                 window_size: int = 0,
                 global_attn_indexes: Tuple[int, ...] = (),
                 multi_scale: bool = False,
                 output_shape: Union[Tuple, List] = None):
        """
        Args:
            img_size (int): Stored as ``self.img_size``; not otherwise used here.
            input_size (tuple): (H, W) of the incoming feature grid; sizes the
                absolute positional embedding and each block's ``input_size``.
            embed_dim (int): Channel dimension of the incoming features.
            depth (int): Number of transformer blocks.
            num_heads (int): Attention heads per block.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            out_chans (int): Channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, add a learnable absolute positional embedding.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window-attention blocks.
            global_attn_indexes (tuple | list): Indexes of blocks using global attention.
            multi_scale (bool): If True (and output_shape is set), the neck receives
                the concatenated outputs of all blocks.
            output_shape (tuple | list | None): (H, W) each block output is resized
                to in multi-scale mode; a falsy value disables multi-scale fusion.
        """
        super(SAMEncoderAdaptor, self).__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None
        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with the input grid size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, input_size[0], input_size[1], embed_dim)
            )
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=input_size,
            )
            self.blocks.append(block)
        self.neck = nn.Sequential(
            nn.Conv2d(
                # Multi-scale mode concatenates all `depth` block outputs on channels.
                embed_dim * depth if self.multi_scale and self.output_shape else embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    @staticmethod
    def _resize_and_pad(x: torch.Tensor, original_size) -> torch.Tensor:
        """Resize each ``x[i]`` to the aspect ratio of ``original_size[i]`` and zero-pad it.

        ``x`` is channel-first ``(B, C, H, W)``. The longer original side is
        mapped to the matching side of ``(H, W)``, and padding goes on the
        bottom/right. Returns a channel-last ``(N, H, W, C)`` tensor with one
        entry per row of ``original_size``.
        """
        sizes = torch.as_tensor(original_size, dtype=torch.long)
        if sizes.ndim == 1:
            sizes = sizes[None, ...]
        out_h, out_w = x.shape[-2:]
        padded = []
        for i, (h, w) in enumerate(sizes.tolist()):
            if h > w:
                new_h = out_h
                new_w = int(w * new_h / h)
            else:
                new_w = out_w
                new_h = int(h * new_w / w)
            # Pair every feature map with its own size (x[i], not always x[0]).
            resized = F.interpolate(x[i].unsqueeze(0), size=(new_h, new_w), mode='bilinear')
            padded.append(F.pad(resized, (0, out_w - new_w, 0, out_h - new_h)))
        return torch.cat(padded, dim=0).permute(0, 2, 3, 1)

    def forward(self, x: torch.Tensor, original_size: Union[Tuple, List] = None) -> torch.Tensor:
        """Run the adaptor.

        Args:
            x: Feature map. Channel-last ``(B, H, W, C)`` when ``original_size``
                is not given; channel-first ``(B, C, H, W)`` when it is.
            original_size: Optional ``(h, w)`` or per-image ``[(h, w), ...]``
                (list, tuple or tensor). When given, each ``x[i]`` is rescaled to
                the aspect ratio of ``original_size[i]`` and zero-padded.

        Returns:
            torch.Tensor: Channel-first neck output with ``out_chans`` channels.
            Its spatial size is ``output_shape`` in multi-scale mode, and the
            spatial size of the block outputs otherwise.
        """
        # Plain truthiness is ambiguous for multi-element tensors.
        if isinstance(original_size, torch.Tensor):
            has_sizes = original_size.numel() > 0
        else:
            has_sizes = bool(original_size)
        if has_sizes:
            x = self._resize_and_pad(x, original_size)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        if self.multi_scale and self.output_shape:
            output_list = []
            for blk in self.blocks:
                x = blk(x)
                output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
            x = self.neck(torch.cat(output_list, dim=1))
        else:
            for blk in self.blocks:
                x = blk(x)
            x = self.neck(x.permute(0, 3, 1, 2))
        return x
class DINOSAMViT(nn.Module):
    """DINO feature extractor followed by a ``SAMEncoderAdaptor``.

    Args:
        dino_model_type: Model identifier forwarded to ``DINO``.
        device: Device forwarded to ``DINO``; the adaptor is moved to it as well.
        pca_dim: Optional PCA dimension forwarded to ``DINO``. Any falsy value
            (e.g. ``0``) is treated as ``None``. When set, it is also the
            adaptor's ``embed_dim``; otherwise ``self.dino.model.embed_dim`` is used.
        **kwargs: Keyword arguments for ``SAMEncoderAdaptor``. They must contain
            ``img_size`` and ``output_shape``. ``input_size`` (set to
            ``output_shape``) and ``embed_dim`` are filled in here.
    """

    def __init__(self,
                 dino_model_type,
                 device='cuda',
                 pca_dim=None,
                 **kwargs
                 ):
        super().__init__()
        self.img_size = kwargs['img_size']
        pca_dim = pca_dim or None
        self.dino = DINO(dino_model_type, device, self.img_size, pca_dim)
        # The adaptor runs on a grid equal to the requested output shape.
        self.input_size = tuple(kwargs['output_shape'])
        if pca_dim is None:
            embed_dim = self.dino.model.embed_dim
        else:
            embed_dim = pca_dim
        adaptor_kwargs = dict(kwargs, input_size=self.input_size, embed_dim=embed_dim)
        self.adaptor = SAMEncoderAdaptor(**adaptor_kwargs).to(device)

    def extract_dino_features(self, x, transform=False, size = None):
        """Return DINO features for ``x`` (delegates to ``DINO.extract_features``)."""
        return self.dino.extract_features(x, transform, size)

    def forward(self, x, transform=False, size = None):
        """Extract DINO features, normalise them, resample to ``input_size`` and adapt."""
        feats = self.extract_dino_features(x, transform, size)
        # L2-normalise along dim 3; the permute below treats dim 3 as channels,
        # so features are assumed channel-last (B, H, W, C).
        feats = F.normalize(feats, dim=3)
        # interpolate works on the trailing two dims, so go channel-first and back.
        feats = F.interpolate(feats.permute(0, 3, 1, 2), size=self.input_size, mode='bilinear')
        return self.adaptor(feats.permute(0, 2, 3, 1))
def setup_model(model_config):
    """Build a ``Sam`` model whose image path is described by ``model_config``.

    Args:
        model_config (dict): Must contain ``'type'``, the name of an
            ``nn.Module`` subclass visible in this module (for example
            ``'SAMEncoderViT'``, ``'SAMEncoderAdaptor'`` or ``'DINOSAMViT'``).
            All other entries are passed to that class as keyword arguments.
            The dict itself is not modified.

    Returns:
        Sam: A SAM with a 256-dim prompt encoder and mask decoder for 1024x1024
        inputs. If the configured model is a ``SAMEncoderAdaptor``, it becomes
        the Sam ``adaptor`` and MobileSAM's image encoder (loaded on CPU from
        ``weights/sam/mobile_sam.pt``) is used as the image encoder. Otherwise
        the configured model is the image encoder and there is no adaptor.

    Raises:
        KeyError: If ``'type'`` is missing.
        ValueError: If ``'type'`` does not name an ``nn.Module`` subclass here.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size

    # Work on a copy so the caller's config (e.g. a loaded checkpoint) keeps 'type'.
    model_kwargs = dict(model_config)
    model_type = model_kwargs.pop('type')
    # Resolve the class by name instead of eval(). The config comes from a
    # checkpoint file, and eval() would execute any expression stored there.
    model_cls = globals().get(model_type)
    if not (isinstance(model_cls, type) and issubclass(model_cls, nn.Module)):
        raise ValueError(f'unknown model type in model_config: {model_type!r}')
    model = model_cls(**model_kwargs)

    if isinstance(model, SAMEncoderAdaptor):
        adaptor = model
        image_encoder = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cpu').image_encoder
    else:
        adaptor = None
        image_encoder = model
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        adaptor=adaptor,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return sam
def load_distillation_sam(distillation_sam_ckpt_path,
                          device='cuda'):
    """Load a distilled SAM from a checkpoint dict.

    The checkpoint must hold ``'model_config'`` (passed to ``setup_model``) and
    ``'model'`` (the state dict loaded into the resulting ``Sam``).

    Args:
        distillation_sam_ckpt_path: Path of the checkpoint file.
        device: Device the loaded model is moved to.

    Returns:
        Sam: The restored model on ``device``.
    """
    # map_location='cpu' lets checkpoints saved from GPU load on CPU-only hosts.
    # The model is moved to `device` at the end anyway.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(distillation_sam_ckpt_path, map_location='cpu')
    sam = setup_model(ckpt['model_config'])
    sam.load_state_dict(ckpt['model'])
    return sam.to(device)
def load_sam(sam_ckpt_path, sam_version, device):
    """Load a SAM variant, downloading its checkpoint from ``SAM_REPO_ID`` if missing.

    Args:
        sam_ckpt_path: Local checkpoint path. If it does not exist, the file of
            the same basename is downloaded from the Hub into its directory
            (the current directory for a bare filename).
        sam_version: One of ``'sam'``, ``'mobile_sam'`` or ``'distillation_sam'``.
        device: Device the model is loaded onto / moved to.

    Returns:
        The loaded model.

    Raises:
        ValueError: If ``sam_version`` is not supported. This is checked before
            any download is attempted.
    """
    if sam_version not in ('sam', 'mobile_sam', 'distillation_sam'):
        raise ValueError('sam version error, please give sam version in [sam, mobile_sam, distillation_sam]')

    if not os.path.exists(sam_ckpt_path):
        # dirname() is '' for a bare filename, and os.makedirs('') raises,
        # so fall back to the current directory.
        parent_dir = os.path.dirname(sam_ckpt_path) or '.'
        os.makedirs(parent_dir, exist_ok=True)
        hf_sam_download(filename=os.path.basename(sam_ckpt_path), local_dir=parent_dir)

    if sam_version == 'sam':
        return build_sam(sam_ckpt_path).to(device)
    if sam_version == 'mobile_sam':
        return load_mobile_sam(sam_ckpt_path, device)
    return load_distillation_sam(sam_ckpt_path, device)
if __name__ == '__main__':
    # Smoke test: build a small SAMEncoderViT on the GPU, push one random
    # 512x512 image through it, and print the output shape and what
    # get_parameter_number reports for the model.
    from distillation.utils import get_parameter_number

    demo_cfg = dict(
        depth=3,
        embed_dim=256,
        img_size=512,
        mlp_ratio=4,
        num_heads=16,
        patch_size=8,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=[1],
        window_size=16,
        out_chans=256,
        multi_scale=False,
        output_shape='',
    )
    encoder = SAMEncoderViT(**demo_cfg).cuda()
    dummy = torch.randn((1, 3, 512, 512)).cuda()
    print(encoder(dummy).shape)
    print(get_parameter_number(encoder))