"""Image encoders and loaders for SAM, MobileSAM and distilled SAM variants."""
import os
import functools
from typing import Optional, List, Union, Tuple, Type

import torch
from torch import nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from segment_anything import build_sam
from segment_anything.mobile_encoder.tiny_vit_sam import TinyViT
from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer
from segment_anything.modeling.image_encoder import ImageEncoderViT, LayerNorm2d, PatchEmbed, Block, Attention
from segment_anything.mobile_encoder.setup_mobile_sam import load_mobile_sam
from segment_anything.modeling.sam import Sam
from sam_extension.distillation_models.fastertinyvit import FasterTinyViT
from sam_extension.distillation_models.dino import DINO
# from sam_extension.distillation_models.flashvision_transformer import FlashVisionTransformer

SAM_REPO_ID = 'YouLiXiya/YL-SAM'
# Downloader pre-bound to the project's Hugging Face repository.
hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir_use_symlinks=True)


def _build_blocks(
    depth: int,
    embed_dim: int,
    num_heads: int,
    mlp_ratio: float,
    qkv_bias: bool,
    norm_layer: Type[nn.Module],
    act_layer: Type[nn.Module],
    use_rel_pos: bool,
    rel_pos_zero_init: bool,
    window_size: int,
    global_attn_indexes: Tuple[int, ...],
    input_size: Tuple[int, int],
) -> nn.ModuleList:
    """Build ``depth`` transformer blocks; indexes in ``global_attn_indexes`` get window_size 0."""
    blocks = nn.ModuleList()
    for i in range(depth):
        blocks.append(
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=input_size,
            )
        )
    return blocks


def _build_neck(in_chans: int, out_chans: int) -> nn.Sequential:
    """Build the 1x1 conv -> LayerNorm2d -> 3x3 conv -> LayerNorm2d projection neck."""
    return nn.Sequential(
        nn.Conv2d(in_chans, out_chans, kernel_size=1, bias=False),
        LayerNorm2d(out_chans),
        nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
        LayerNorm2d(out_chans),
    )


def _run_blocks_and_neck(
    blocks: nn.ModuleList,
    neck: nn.Module,
    x: torch.Tensor,
    multi_scale: bool,
    output_shape: Optional[Tuple],
) -> torch.Tensor:
    """Run channels-last ``x`` through ``blocks`` and project it with ``neck``.

    In multi-scale mode the output of every block is resized to
    ``output_shape`` and all of them are concatenated along the channel axis
    before the neck; otherwise only the last block's output is used.
    """
    if multi_scale and output_shape:
        per_block = []
        for blk in blocks:
            x = blk(x)
            # Blocks work channels-last; interpolate/conv need channels-first.
            per_block.append(
                F.interpolate(x.permute(0, 3, 1, 2), size=output_shape, mode='bilinear')
            )
        return neck(torch.cat(per_block, dim=1))
    for blk in blocks:
        x = blk(x)
    return neck(x.permute(0, 3, 1, 2))


class SAMImageEncoder(nn.Module):
    """Wrap only the image encoder of a SAM checkpoint built with ``build_sam``."""

    def __init__(self, sam_checkpoint_path, device='cuda'):
        super().__init__()
        sam = build_sam(sam_checkpoint_path).to(device)
        self.image_encoder = sam.image_encoder
        # Drop the rest of the model and release cached CUDA memory.
        del sam
        torch.cuda.empty_cache()

    def forward(self, x):
        """Return ``self.image_encoder(x)``."""
        return self.image_encoder(x)


class MobileSAMImageEncoder(nn.Module):
    """Wrap only the image encoder of a checkpoint loaded with ``load_mobile_sam``."""

    def __init__(self, sam_checkpoint_path, device='cuda'):
        super().__init__()
        sam = load_mobile_sam(sam_checkpoint_path, device)
        self.image_encoder = sam.image_encoder
        # Drop the rest of the model and release cached CUDA memory.
        del sam
        torch.cuda.empty_cache()

    def forward(self, x):
        """Return ``self.image_encoder(x)``."""
        return self.image_encoder(x)


class SAMEncoderViT(nn.Module):
    """ViT image encoder with an optional multi-scale neck."""

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
        multi_scale: bool = False,
        output_shape: Optional[Union[Tuple, List]] = None
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
            multi_scale (bool): If True (and ``output_shape`` is set), feed the
                concatenated outputs of all blocks to the neck.
            output_shape (tuple | list | None): Spatial size every block output is
                resized to in multi-scale mode; any falsy value disables it.
        """
        super().__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None
        grid = img_size // patch_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(torch.zeros(1, grid, grid, embed_dim))

        self.blocks = _build_blocks(
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            window_size=window_size,
            global_attn_indexes=global_attn_indexes,
            input_size=(grid, grid),
        )

        # Multi-scale mode concatenates one feature map per block.
        use_multi_scale = self.multi_scale and self.output_shape
        self.neck = _build_neck(embed_dim * depth if use_multi_scale else embed_dim, out_chans)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the image, add the positional embedding, run blocks and neck."""
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        return _run_blocks_and_neck(self.blocks, self.neck, x, self.multi_scale, self.output_shape)


class SAMEncoderAdaptor(nn.Module):
    """Transformer adaptor applied to an existing feature map (no patch embedding)."""

    def __init__(self,
                 img_size: int,
                 input_size: Optional[Tuple[int, int]],
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 out_chans: int = 256,
                 qkv_bias: bool = True,
                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                 act_layer: Type[nn.Module] = nn.GELU,
                 use_abs_pos: bool = True,
                 use_rel_pos: bool = False,
                 rel_pos_zero_init: bool = True,
                 window_size: int = 0,
                 global_attn_indexes: Tuple[int, ...] = (),
                 multi_scale: bool = False,
                 output_shape: Optional[Union[Tuple, List]] = None):
        """
        Args:
            img_size (int): Stored on the module; not otherwise used here.
            input_size (tuple): (height, width) of the incoming feature grid; sizes
                the positional embedding and is passed to every block.
            embed_dim (int): Channel dimension of the incoming features.
            depth (int): Number of transformer blocks.
            num_heads (int): Number of attention heads in each block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
            multi_scale (bool): If True (and ``output_shape`` is set), feed the
                concatenated outputs of all blocks to the neck.
            output_shape (tuple | list | None): Spatial size every block output is
                resized to in multi-scale mode; any falsy value disables it.
        """
        super().__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Positional embedding matches the incoming feature grid.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, input_size[0], input_size[1], embed_dim)
            )

        self.blocks = _build_blocks(
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            window_size=window_size,
            global_attn_indexes=global_attn_indexes,
            input_size=input_size,
        )

        # Multi-scale mode concatenates one feature map per block.
        use_multi_scale = self.multi_scale and self.output_shape
        self.neck = _build_neck(embed_dim * depth if use_multi_scale else embed_dim, out_chans)

    @staticmethod
    def _fit_to_original_aspect(x: torch.Tensor, original_size) -> torch.Tensor:
        """Shrink each channels-first map to its image's aspect ratio and zero-pad.

        Every sample is resized so its longer side fills the feature grid, then
        padded on the right/bottom back to the grid size. The result is
        returned channels-last.
        """
        sizes = torch.as_tensor(original_size, dtype=torch.long)
        if sizes.ndim == 1:
            # A single (h, w) pair: treat it as a batch of one.
            sizes = sizes[None, ...]
        grid_h, grid_w = x.shape[-2:]
        # Use each sample's own features when batch sizes agree. The previous
        # code always read x[0]; that behaviour is kept only when the batch
        # sizes differ (one feature map shared by several sizes).
        per_sample = x.shape[0] == sizes.shape[0]

        fitted = []
        for i, (h, w) in enumerate(sizes.tolist()):
            if h > w:
                new_h = grid_h
                new_w = int(w * new_h / h)
            else:
                new_w = grid_w
                new_h = int(h * new_w / w)
            feature = x[i if per_sample else 0].unsqueeze(0)
            feature = F.interpolate(feature, size=(new_h, new_w), mode='bilinear')
            # F.pad order for the last two dims: (left, right, top, bottom).
            feature = F.pad(feature, (0, grid_w - new_w, 0, grid_h - new_h))
            fitted.append(feature)
        return torch.cat(fitted, dim=0).permute(0, 2, 3, 1)

    def forward(self, x: torch.Tensor, original_size: Optional[Union[Tuple, List]] = None) -> torch.Tensor:
        """Run the adaptor on a feature map.

        Args:
            x: Feature map. Treated as channels-last when ``original_size`` is
                omitted, and as channels-first when it is given (it is then
                resized/padded and permuted to channels-last internally).
            original_size: Optional ``(h, w)`` pair, or a sequence of such pairs
                (one per sample). ``None`` or an empty sequence skips resizing.

        Returns:
            The neck output.
        """
        # Explicit None/empty check: plain truthiness raises for multi-element
        # tensors and arrays.
        if original_size is not None and len(original_size) > 0:
            x = self._fit_to_original_aspect(x, original_size)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        return _run_blocks_and_neck(self.blocks, self.neck, x, self.multi_scale, self.output_shape)


class DINOSAMViT(nn.Module):
    """DINO feature extractor followed by a ``SAMEncoderAdaptor``."""

    def __init__(self,
                 dino_model_type,
                 device='cuda',
                 pca_dim=None,
                 **kwargs
                 ):
        """
        Args:
            dino_model_type: Passed through to ``DINO``.
            device: Device for the DINO model and the adaptor.
            pca_dim: Feature dimension passed to ``DINO``; any falsy value is
                treated as ``None``, in which case the DINO model's own
                ``embed_dim`` is used for the adaptor.
            **kwargs: ``SAMEncoderAdaptor`` arguments. Must contain ``img_size``
                and ``output_shape``; ``input_size`` and ``embed_dim`` are
                overwritten here.
        """
        super().__init__()
        self.img_size = kwargs['img_size']
        if not pca_dim:
            pca_dim = None
        self.dino = DINO(dino_model_type, device, self.img_size, pca_dim)
        # The adaptor's input grid is the requested output shape.
        self.input_size = tuple(kwargs['output_shape'])
        embed_dim = pca_dim if pca_dim is not None else self.dino.model.embed_dim
        kwargs.update({'input_size': self.input_size, 'embed_dim': embed_dim})
        self.adaptor = SAMEncoderAdaptor(**kwargs).to(device)

    def extract_dino_features(self, x, transform=False, size=None):
        """Return ``self.dino.extract_features(x, transform, size)``."""
        return self.dino.extract_features(x, transform, size)

    def forward(self, x, transform=False, size=None):
        """Extract DINO features, L2-normalise them, resize and run the adaptor."""
        # Normalise over dim 3, which is treated as the channel axis (the
        # tensor is permuted from channels-last right below).
        dino_feature = F.normalize(self.extract_dino_features(x, transform, size), dim=3)
        adaptor_input = F.interpolate(
            dino_feature.permute(0, 3, 1, 2), size=self.input_size, mode='bilinear'
        ).permute(0, 2, 3, 1)
        return self.adaptor(adaptor_input)


def setup_model(model_config):
    """Build a ``Sam`` model whose image encoder is described by ``model_config``.

    Args:
        model_config: Mapping with a ``'type'`` key naming a class visible in this
            module; the remaining items are passed to it as keyword arguments.
            The mapping is not modified.

    Returns:
        A ``Sam`` instance. If the configured model is a ``SAMEncoderAdaptor`` it
        is attached as ``adaptor`` and the MobileSAM image encoder from
        ``weights/sam/mobile_sam.pt`` is used as the image encoder; otherwise the
        configured model itself is the image encoder.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size

    # Work on a copy so the caller's config (e.g. the checkpoint dict) keeps
    # its 'type' entry.
    config = dict(model_config)
    model_type = config.pop('type')
    # NOTE(security): eval() executes whatever string the config carries. Only
    # load configs/checkpoints from trusted sources.
    model = eval(model_type)(**config)

    if isinstance(model, SAMEncoderAdaptor):
        adaptor = model
        image_encoder = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cpu').image_encoder
    else:
        adaptor = None
        image_encoder = model

    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        adaptor=adaptor,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return sam


def load_distillation_sam(distillation_sam_ckpt_path, device='cuda'):
    """Load a distilled SAM checkpoint and move the model to ``device``.

    The checkpoint must contain ``'model_config'`` (see ``setup_model``) and
    ``'model'`` (a state dict).
    """
    # Load onto CPU first so checkpoints saved from CUDA tensors also open on
    # hosts without a GPU; the model is moved to ``device`` afterwards.
    ckpt = torch.load(distillation_sam_ckpt_path, map_location='cpu')
    sam = setup_model(ckpt['model_config'])
    sam.load_state_dict(ckpt['model'])
    return sam.to(device)


def load_sam(sam_ckpt_path, sam_version, device):
    """Load a SAM variant, downloading the checkpoint first if it is missing.

    Args:
        sam_ckpt_path: Local checkpoint path. If it does not exist, a file with
            the same basename is downloaded from ``SAM_REPO_ID`` into its
            directory.
        sam_version: One of ``'sam'``, ``'mobile_sam'``, ``'distillation_sam'``.
        device: Device the model is placed on.

    Raises:
        ValueError: If ``sam_version`` is not one of the supported values.
    """
    if not os.path.exists(sam_ckpt_path):
        # A bare filename has an empty dirname, which os.makedirs rejects;
        # fall back to the current directory.
        parent_dir = os.path.dirname(sam_ckpt_path) or '.'
        os.makedirs(parent_dir, exist_ok=True)
        hf_sam_download(filename=os.path.basename(sam_ckpt_path), local_dir=parent_dir)

    if sam_version == 'sam':
        sam = build_sam(sam_ckpt_path).to(device)
    elif sam_version == 'mobile_sam':
        sam = load_mobile_sam(sam_ckpt_path, device)
    elif sam_version == 'distillation_sam':
        sam = load_distillation_sam(sam_ckpt_path, device)
    else:
        raise ValueError('sam version error, please give sam version in [sam, mobile_sam, distillation_sam]')
    return sam


if __name__ == '__main__':
    # Smoke test: build a small encoder and print output shape and parameter count.
    from distillation.utils import get_parameter_number

    vit = SAMEncoderViT(depth=3,
                        embed_dim=256,
                        img_size=512,
                        mlp_ratio=4,
                        num_heads=16,
                        patch_size=8,
                        qkv_bias=True,
                        use_rel_pos=True,
                        global_attn_indexes=[1],
                        window_size=16,
                        out_chans=256,
                        multi_scale=False,
                        output_shape='').cuda()
    x = torch.randn((1, 3, 512, 512)).cuda()
    print(vit(x).shape)
    print(get_parameter_number(vit))