from typing import Tuple, List, Union

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from sam_extension.distillation_models.fastervit import FasterViTLayer
from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv


class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, out_dim, activation):
        super().__init__()

        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        # These output dims keep the spatial resolution; all others are
        # downsampled by 2 in the depthwise conv below.
        stride_c = 1 if out_dim in (320, 448, 576) else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        if x.ndim == 3:
            H, W = self.input_resolution
            B = len(x)
            # (B, H*W, C) -> (B, C, H, W)
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)

        x = self.conv1(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        return x


class ConvLayer(nn.Module):
    def __init__(self, dim, input_resolution, depth,
                 activation,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 out_dim=None,
                 conv_expand_ratio=4.,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            MBConv(dim, dim, conv_expand_ratio, activation,
                   drop_path[i] if isinstance(drop_path, list) else drop_path,
                   )
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # `checkpoint` is imported directly from torch.utils.checkpoint,
                # so it is called as a function here (the original
                # `checkpoint.checkpoint(blk, x)` would raise an AttributeError).
                x = checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class FasterTinyViT(nn.Module):
    def __init__(self, img_size=224, in_chans=3, out_chans=256,
                 embed_dims=[96, 192, 384, 768],
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_sizes=[7, 7, 14, 7],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 ct_size=2,
                 conv=False,
                 multi_scale=False,
                 output_shape=None,
                 ):
        super().__init__()
        self.img_size = img_size
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # kwargs_0 configures the MBConv stage (used only when i_layer == 0);
            # the last stage reuses the previous stage's resolution, hence the
            # i_layer == 3 special case.
            kwargs_0 = dict(dim=embed_dims[i_layer],
                            input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
                                              patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
                            # input_resolution=(patches_resolution[0] // (2 ** i_layer),
                            #                   patches_resolution[1] // (2 ** i_layer)),
                            depth=depths[i_layer],
                            drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                            downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                            use_checkpoint=use_checkpoint,
                            out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
                            activation=activation,
                            )
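            # kwargs_1 configures the FasterViT stages (i_layer > 0). Unlike
            # kwargs_0 it passes a scalar input_resolution and a boolean
            # downsample flag, matching how FasterViTLayer is invoked below
            # (an observation from this file, not from FasterViTLayer's docs).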
            kwargs_1 = dict(dim=embed_dims[i_layer],
                            out_dim=embed_dims[i_layer + 1] if (i_layer < self.num_layers - 1) else embed_dims[i_layer],
                            input_resolution=patches_resolution[0] // (2 ** i_layer),
                            depth=depths[i_layer],
                            drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                            downsample=True if (i_layer < self.num_layers - 1) else False,
                            ct_size=ct_size,
                            conv=conv,
                            )
            if i_layer == 0:
                layer = ConvLayer(
                    conv_expand_ratio=mbconv_expand_ratio,
                    **kwargs_0,
                )
            else:
                layer = FasterViTLayer(
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                    **kwargs_1)
            self.layers.append(layer)

        # init weights
        self.apply(self._init_weights)

        # In multi-scale mode the neck consumes the concatenation of the patch
        # embedding and every stage output; otherwise only the last stage.
        self.neck = nn.Sequential(
            nn.Conv2d(
                sum(embed_dims) + embed_dims[-1] if self.multi_scale and self.output_shape else embed_dims[-1],
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def forward_features(self, x):
        # x: (N, C, H, W)
        if self.multi_scale and self.output_shape:
            output_list = []
            x = self.patch_embed(x)
            output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
            for layer in self.layers:
                x = layer(x)
                output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
            x = self.neck(torch.cat(output_list, dim=1))
        else:
            x = self.patch_embed(x)
            for layer in self.layers:
                x = layer(x)
            x = self.neck(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x


if __name__ == '__main__':
    from distillation.utils import get_parameter_number

    x = torch.randn(1, 3, 1024, 1024).cuda()
    fastertinyvit = FasterTinyViT(img_size=1024,
                                  in_chans=3,
                                  embed_dims=[64, 128, 256],
                                  depths=[1, 2, 1],
                                  num_heads=[2, 4, 8],
                                  window_sizes=[8, 8, 8],
                                  mlp_ratio=4.,
                                  drop_rate=0.,
                                  drop_path_rate=0.0,
                                  use_checkpoint=False,
                                  mbconv_expand_ratio=4.0,
                                  multi_scale=False,
                                  output_shape=None).cuda()

    print(fastertinyvit(x).shape)
    print(get_parameter_number(fastertinyvit))
    # torch.save(fastertinyvit, 'fastertinyvit.pt')
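    # Hedged sketch of the multi-scale path (not part of the original demo):
    # with multi_scale=True and an output_shape, forward_features resizes the
    # patch embedding and every stage output to output_shape and concatenates
    # them, so the neck's 1x1 conv expects sum(embed_dims) + embed_dims[-1]
    # = 64 + 128 + 256 + 256 = 704 input channels here. output_shape=(64, 64)
    # matches the final stage resolution for img_size=1024 (1024 / 4 / 2 / 2),
    # assuming FasterViTLayer halves the resolution when downsample=True.
    fastertinyvit_ms = FasterTinyViT(img_size=1024,
                                     in_chans=3,
                                     embed_dims=[64, 128, 256],
                                     depths=[1, 2, 1],
                                     num_heads=[2, 4, 8],
                                     window_sizes=[8, 8, 8],
                                     multi_scale=True,
                                     output_shape=(64, 64)).cuda()
    # Expected: torch.Size([1, 256, 64, 64]), same as the single-scale variant.
    print(fastertinyvit_ms(x).shape)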