import torch
import torch.nn as nn
from einops import rearrange  # third-party dependency: pip install einops

def conv_1x1_bn(inp, oup):
    """Pointwise conv -> BatchNorm -> SiLU."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    """n x n conv -> BatchNorm -> SiLU. Padding of kernel_size // 2 keeps the
    spatial size for odd kernels (the original hard-coded padding=1, which is
    only correct for kernel_size=3)."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

class PreNorm(nn.Module):
    """Apply LayerNorm before the wrapped function (pre-norm transformer convention)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP with SiLU activation and dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head self-attention over the patch axis.

    Expects input of shape (batch, pixels_per_patch, num_patches, dim), as
    produced by MobileViTBlock's unfold rearrange: each intra-patch pixel
    position attends across all patches independently.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of pre-norm attention and feed-forward layers with residual connections."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class MV2Block(nn.Module):
    """MobileNetV2 inverted residual block (expand -> depthwise -> project)."""

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        assert stride in (1, 2)
        self.stride = stride

        hidden_dim = int(inp * expansion)
        # Residual connection only when the block preserves resolution and width.
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # depthwise conv
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pointwise linear projection
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pointwise expansion
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # depthwise conv
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pointwise linear projection
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        return self.conv(x)

class MobileViTBlock(nn.Module):
    """Local conv features -> unfold into patches -> transformer -> fold back -> fuse with input."""

    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # Local representation
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representation: unfold into non-overlapping ph x pw patches,
        # attend across patches, then fold back to a feature map.
        _, _, h, w = x.shape
        assert h % self.ph == 0 and w % self.pw == 0, \
            'feature map size must be divisible by the patch size'
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)

        # Fusion: project back to `channel`, concatenate with the block input
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x
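
# A quick shape round trip for MobileViTBlock, as a sketch; the values below
# are arbitrary, chosen so h and w are divisible by the 2x2 patch size:
#   blk = MobileViTBlock(dim=96, depth=2, channel=32, kernel_size=3,
#                        patch_size=(2, 2), mlp_dim=192)
#   blk(torch.randn(1, 32, 32, 32)).shape  # -> torch.Size([1, 32, 32, 32])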

class MobileViTv3_Small(nn.Module):
    """MobileViT small variant: MV2 stages interleaved with MobileViT blocks.

    Note: five stride-2 stages precede the last MobileViT block, so each image
    dimension must be divisible by 64 for the 2x2 patch rearranges to hold.
    224 is not (224 / 32 = 7), so the default is 256 as in the MobileViT paper.
    """

    def __init__(self, image_size=(256, 256), num_classes=10):
        super().__init__()
        ih, iw = image_size
        ph, pw = 2, 2
        assert ih % (32 * ph) == 0 and iw % (32 * pw) == 0

        dims = [144, 192, 240]
        channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, 4))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, 4))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, 4))
        self.mv2.append(MV2Block(channels[3], channels[4], 2, 4))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], 2, channels[5], 3, (ph, pw), int(dims[0] * 2)))

        self.mv2_2 = nn.ModuleList([])
        self.mv2_2.append(MV2Block(channels[5], channels[6], 2, 4))

        self.mvit_2 = nn.ModuleList([])
        self.mvit_2.append(MobileViTBlock(dims[1], 4, channels[7], 3, (ph, pw), int(dims[1] * 2)))

        self.mv2_3 = nn.ModuleList([])
        self.mv2_3.append(MV2Block(channels[7], channels[8], 2, 4))

        self.mvit_3 = nn.ModuleList([])
        self.mvit_3.append(MobileViTBlock(dims[2], 3, channels[9], 3, (ph, pw), int(dims[2] * 2)))

        self.conv2 = conv_1x1_bn(channels[9], channels[10])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(channels[10], num_classes)

    def forward(self, x):
        x = self.conv1(x)

        for conv in self.mv2:
            x = conv(x)
        for m in self.mvit:
            x = m(x)
        for conv in self.mv2_2:
            x = conv(x)
        for m in self.mvit_2:
            x = m(x)
        for conv in self.mv2_3:
            x = conv(x)
        for m in self.mvit_3:
            x = m(x)

        x = self.conv2(x)
        x = self.pool(x).flatten(1)  # (b, 640, 1, 1) -> (b, 640)
        return self.fc(x)
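

if __name__ == "__main__":
    # Minimal smoke test, a sketch not part of the original listing: feed a
    # batch of two 256x256 RGB images through the model and check the output.
    model = MobileViTv3_Small(image_size=(256, 256), num_classes=10)
    x = torch.randn(2, 3, 256, 256)
    logits = model(x)
    print(logits.shape)  # torch.Size([2, 10])
    print(f'parameters: {sum(p.numel() for p in model.parameters()):,}')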