import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionEmbs(nn.Module):
    """Learned position embeddings added to the patch (+ class) tokens."""

    def __init__(self, num_patches, emb_dim, dropout_rate=0.1):
        super(PositionEmbs, self).__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x):
        out = x + self.pos_embedding
        if self.dropout:
            out = self.dropout(out)
        return out


class MlpBlock(nn.Module):
    """Transformer Feed-Forward Block"""

    def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.1):
        super(MlpBlock, self).__init__()

        # init layers
        self.fc1 = nn.Linear(in_dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, out_dim)
        self.act = nn.GELU()
        if dropout_rate > 0.0:
            self.dropout1 = nn.Dropout(dropout_rate)
            self.dropout2 = nn.Dropout(dropout_rate)
        else:
            self.dropout1 = None
            self.dropout2 = None

    def forward(self, x):
        out = self.fc1(x)
        out = self.act(out)
        if self.dropout1:
            out = self.dropout1(out)

        out = self.fc2(out)
        if self.dropout2:
            out = self.dropout2(out)
        return out


class LinearGeneral(nn.Module):
    """Linear layer over arbitrary input/output axes, implemented with tensordot."""

    def __init__(self, in_dim=(768,), feat_dim=(12, 64)):
        super(LinearGeneral, self).__init__()

        self.weight = nn.Parameter(torch.randn(*in_dim, *feat_dim))
        self.bias = nn.Parameter(torch.zeros(*feat_dim))

    def forward(self, x, dims):
        a = torch.tensordot(x, self.weight, dims=dims) + self.bias
        return a


class SelfAttention(nn.Module):
    def __init__(self, in_dim, heads=8, dropout_rate=0.1):
        super(SelfAttention, self).__init__()

        self.heads = heads
        self.head_dim = in_dim // heads
        self.scale = self.head_dim ** 0.5

        self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,))

        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

        # Projects the flattened per-token queries down to 8 "cluster" queries of size 100.
        # NOTE: the sizes are hard-coded; the flattened query must have exactly
        # 256 * 100 = 25600 features per sample (i.e. seq_len * heads * head_dim == 25600),
        # and the scaled dot product below requires head_dim == 100.
        self.cluster_mlp = nn.Sequential(
            nn.Linear(256 * 100, 64 * 100),
            nn.LeakyReLU(0.2),
            nn.Linear(64 * 100, 8 * 100),
        )

    def forward(self, x):
        b, n, _ = x.shape

        q = self.query(x, dims=([2], [0]))
        # compress all per-token queries into 8 cluster queries (see NOTE in __init__)
        q = self.cluster_mlp(q.view(b, -1)).view(b, 8, 1, 100)
        k = self.key(x, dims=([2], [0]))
        v = self.value(x, dims=([2], [0]))

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropout:
            # apply dropout to the attention weights
            attn_weights = self.dropout(attn_weights)
        out = torch.matmul(attn_weights, v)
        out = out.permute(0, 2, 1, 3)

        out = self.out(out, dims=([2, 3], [0, 1]))

        return out


class EncoderBlock(nn.Module):
    def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1):
        super(EncoderBlock, self).__init__()

        self.norm1 = nn.LayerNorm(in_dim)
        self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None
        self.norm2 = nn.LayerNorm(in_dim)
        self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate)

    def forward(self, x):
        residual = x
        out = self.norm1(x)
        out = self.attn(out)
        if self.dropout:
            out = self.dropout(out)
        # out += residual  # attention residual disabled: the clustered attention
        # outputs a fixed sequence length of 8, so the input residual would not match
        residual = out

        out = self.norm2(out)
        out = self.mlp(out)
        out += residual
        return out


class Encoder(nn.Module):
    def __init__(self, num_patches, emb_dim, mlp_dim, num_layers=12, num_heads=12,
                 dropout_rate=0.1, attn_dropout_rate=0.0):
        super(Encoder, self).__init__()

        # positional embedding
        self.pos_embedding = PositionEmbs(num_patches, emb_dim, dropout_rate)

        # encoder blocks
        in_dim = emb_dim
        self.encoder_layers = nn.ModuleList()
        for i in range(num_layers):
            layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate, attn_dropout_rate)
            self.encoder_layers.append(layer)
        self.norm = nn.LayerNorm(in_dim)

    def forward(self, x):
        out = self.pos_embedding(x)

        for layer in self.encoder_layers:
            out = layer(out)

        out = self.norm(out)
        return out


class VisionTransformer(nn.Module):
    """Vision Transformer"""

    def __init__(self,
                 image_size=(256, 256),
                 patch_size=(16, 16),
                 emb_dim=768,
                 mlp_dim=3072,
                 num_heads=12,
                 num_layers=12,
                 num_classes=1000,
                 attn_dropout_rate=0.0,
                 dropout_rate=0.1,
                 feat_dim=None):
        super(VisionTransformer, self).__init__()
        h, w = image_size

        # embedding layer
        fh, fw = patch_size
        gh, gw = h // fh, w // fw
        num_patches = gh * gw
        self.embedding = nn.Conv2d(3, emb_dim, kernel_size=(fh, fw), stride=(fh, fw))

        # class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))

        # transformer
        self.transformer = Encoder(
            num_patches=num_patches,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attn_dropout_rate=attn_dropout_rate)

        # classifier
        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        emb = self.embedding(x)        # (n, c, gh, gw)
        emb = emb.permute(0, 2, 3, 1)  # (n, gh, gw, c)
        b, h, w, c = emb.shape
        emb = emb.reshape(b, h * w, c)

        # prepend class token
        cls_token = self.cls_token.repeat(b, 1, 1)
        emb = torch.cat([cls_token, emb], dim=1)

        # transformer
        feat = self.transformer(emb)

        # classifier on the class-token representation
        logits = self.classifier(feat[:, 0])
        return logits


if __name__ == '__main__':
    # Smoke test. NOTE: with these defaults (257 tokens, 12 heads of 64 dims) the
    # hard-coded cluster_mlp sizes in SelfAttention are not satisfied, so this forward
    # pass raises a shape error; the configuration must match the 25600-feature
    # assumption documented in SelfAttention.__init__.
    model = VisionTransformer(num_layers=2)
    x = torch.randn((2, 3, 256, 256))
    out = model(x)
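

# Illustrative shape check (a minimal sketch, not part of the model above): shows how
# LinearGeneral uses torch.tensordot to project token embeddings into per-head
# queries/keys/values and back. The sizes here (batch 2, 257 tokens, 768-dim
# embeddings, 12 heads of 64 dims) are assumptions mirroring the ViT-Base defaults.
if __name__ == '__main__':
    qkv_proj = LinearGeneral(in_dim=(768,), feat_dim=(12, 64))   # emb -> (heads, head_dim)
    out_proj = LinearGeneral(in_dim=(12, 64), feat_dim=(768,))   # (heads, head_dim) -> emb

    tokens = torch.randn(2, 257, 768)
    # contract the embedding axis of `tokens` with the first weight axis:
    # (2, 257, 768) x (768, 12, 64) -> (2, 257, 12, 64)
    per_head = qkv_proj(tokens, dims=([2], [0]))
    # contract (heads, head_dim) back into a single embedding axis:
    # (2, 257, 12, 64) x (12, 64, 768) -> (2, 257, 768)
    merged = out_proj(per_head, dims=([2, 3], [0, 1]))
    print(per_head.shape, merged.shape)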