import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

from timm.models.layers import trunc_normal_

from sam_extension.distillation_models.fastervit import FasterViTLayer
from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv


class PatchMerging(nn.Module):
    """Patch merging: a 1x1 projection to out_dim, a depthwise 3x3 conv
    (stride 2 to downsample, stride 1 for the widest stages), then a 1x1 conv."""

    def __init__(self, input_resolution, dim, out_dim, activation):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        # The widest stages keep their spatial resolution.
        stride_c = 1 if out_dim in (320, 448, 576) else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        if x.ndim == 3:
            # Tokens arrive flattened as (B, H*W, C); restore (B, C, H, W).
            H, W = self.input_resolution
            B = len(x)
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)

        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        return x
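
# Illustrative shape check (a sketch, not part of the model): with out_dim not
# in (320, 448, 576) the depthwise conv uses stride 2, halving the resolution.
# Assumes nn.GELU as the activation, as FasterTinyViT uses below.
#   pm = PatchMerging(input_resolution=(64, 64), dim=64, out_dim=128, activation=nn.GELU)
#   pm(torch.randn(1, 64, 64, 64)).shape  # -> torch.Size([1, 128, 32, 32])
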
class ConvLayer(nn.Module):
    """A stack of MBConv blocks, optionally followed by a downsample module."""

    def __init__(self, dim, input_resolution, depth, activation,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 out_dim=None, conv_expand_ratio=4.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            MBConv(dim, dim, conv_expand_ratio, activation,
                   drop_path[i] if isinstance(drop_path, list) else drop_path)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        else:
            self.downsample = None
    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # `checkpoint` is imported directly from torch.utils.checkpoint.
                x = checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
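
# Illustrative usage (a sketch, not part of the model): one MBConv stage that
# downsamples via PatchMerging, mirroring how FasterTinyViT builds stage 0.
#   layer = ConvLayer(dim=64, input_resolution=(256, 256), depth=1,
#                     activation=nn.GELU, downsample=PatchMerging, out_dim=128)
#   layer(torch.randn(1, 64, 256, 256)).shape  # -> torch.Size([1, 128, 128, 128])
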
class FasterTinyViT(nn.Module):
    """TinyViT-style image encoder: a convolutional first stage (ConvLayer)
    followed by FasterViT stages, with a SAM-style neck that projects the
    final feature map to out_chans channels."""

def __init__(self, img_size=224,
in_chans=3,
out_chans=256,
embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
ct_size=2,
conv=False,
multi_scale=False,
output_shape=None,
):
super().__init__()
self.img_size = img_size
self.depths = depths
self.num_layers = len(depths)
self.mlp_ratio = mlp_ratio
self.multi_scale = multi_scale
self.output_shape = tuple(output_shape) if output_shape else None
activation = nn.GELU
self.patch_embed = PatchEmbed(in_chans=in_chans,
embed_dim=embed_dims[0],
resolution=img_size,
activation=activation)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
sum(depths))] # stochastic depth decay rule
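        # e.g. depths=[1, 2, 1] and drop_path_rate=0.1 give dpr ≈ [0.0, 0.033, 0.067, 0.1]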
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
            # Stage-0 (ConvLayer) arguments; input_resolution is an (H, W) tuple.
            kwargs_0 = dict(dim=embed_dims[i_layer],
                            input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
                                              patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (
i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(
i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
            # FasterViT-stage arguments; FasterViTLayer takes a single int
            # input_resolution and a boolean downsample flag.
            kwargs_1 = dict(dim=embed_dims[i_layer],
out_dim=embed_dims[i_layer+1] if (
i_layer < self.num_layers - 1) else embed_dims[i_layer],
input_resolution=patches_resolution[0] // (2 ** i_layer),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=True if (i_layer < self.num_layers - 1) else False,
ct_size=ct_size,
conv=conv,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs_0,
)
else:
layer = FasterViTLayer(
num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
**kwargs_1)
self.layers.append(layer)
# init weights
self.apply(self._init_weights)
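
        # In multi-scale mode the neck consumes the channel-wise concatenation of
        # the patch embedding and every stage output; with the configs used here
        # that totals sum(embed_dims) + embed_dims[-1] channels, since the last
        # stage keeps its width. Otherwise the neck sees embed_dims[-1] channels.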
self.neck = nn.Sequential(
nn.Conv2d(
sum(embed_dims)+embed_dims[-1] if self.multi_scale and self.output_shape else embed_dims[-1],
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'attention_biases'}

    def forward_features(self, x):
        # x: (B, C, H, W)
        if self.multi_scale and self.output_shape:
            # Fuse the patch embedding and every stage output: resample each to
            # output_shape and concatenate along channels before the neck.
            output_list = []
            x = self.patch_embed(x)
            output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
            for layer in self.layers:
                x = layer(x)
                output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
            x = self.neck(torch.cat(output_list, dim=1))
        else:
            x = self.patch_embed(x)
            for layer in self.layers:
                x = layer(x)
            x = self.neck(x)
        return x
def forward(self, x):
x = self.forward_features(x)
return x


if __name__ == '__main__':
from distillation.utils import get_parameter_number
x = torch.randn(1, 3, 1024, 1024).cuda()
    fastertinyvit = FasterTinyViT(img_size=1024, in_chans=3,
                                  embed_dims=[64, 128, 256],
                                  depths=[1, 2, 1],
                                  num_heads=[2, 4, 8],
                                  window_sizes=[8, 8, 8],
                                  mlp_ratio=4.,
                                  drop_rate=0.,
                                  drop_path_rate=0.0,
                                  use_checkpoint=False,
                                  mbconv_expand_ratio=4.0,
                                  multi_scale=False,
                                  output_shape=None).cuda()
print(fastertinyvit(x).shape)
print(get_parameter_number(fastertinyvit))
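
    # Multi-scale variant (sketch; assumes output_shape is the common (H, W) that
    # every intermediate feature map is resized to before concatenation):
    # fastertinyvit_ms = FasterTinyViT(img_size=1024, in_chans=3,
    #                                  embed_dims=[64, 128, 256],
    #                                  depths=[1, 2, 1],
    #                                  num_heads=[2, 4, 8],
    #                                  window_sizes=[8, 8, 8],
    #                                  multi_scale=True,
    #                                  output_shape=(64, 64)).cuda()
    # print(fastertinyvit_ms(x).shape)  # expected: torch.Size([1, 256, 64, 64])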
# torch.save(fastertinyvit, 'fastertinyvit.pt')