import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionEmbs(nn.Module):
    """Adds a learnable positional embedding (one slot per patch plus the class token)."""

    def __init__(self, num_patches, emb_dim, dropout_rate=0.1):
        super(PositionEmbs, self).__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x):
        out = x + self.pos_embedding
        if self.dropout:
            out = self.dropout(out)
        return out


class MlpBlock(nn.Module):
    """Transformer Feed-Forward Block"""

    def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.1):
        super(MlpBlock, self).__init__()

        # init layers
        self.fc1 = nn.Linear(in_dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, out_dim)
        self.act = nn.GELU()
        if dropout_rate > 0.0:
            self.dropout1 = nn.Dropout(dropout_rate)
            self.dropout2 = nn.Dropout(dropout_rate)
        else:
            self.dropout1 = None
            self.dropout2 = None

    def forward(self, x):
        out = self.fc1(x)
        out = self.act(out)
        if self.dropout1:
            out = self.dropout1(out)
        out = self.fc2(out)
        if self.dropout2:
            out = self.dropout2(out)
        return out


class LinearGeneral(nn.Module):
    """Linear layer over arbitrary trailing dimensions, applied via tensordot."""

    def __init__(self, in_dim=(768,), feat_dim=(12, 64)):
        super(LinearGeneral, self).__init__()

        weight = torch.randn(*in_dim, *feat_dim)
        weight.normal_(0, 0.02)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(*feat_dim))

    def forward(self, x, dims):
        a = torch.tensordot(x, self.weight, dims=dims) + self.bias
        return a


class SelfAttention(nn.Module):
    def __init__(self, in_dim, heads=8, dropout_rate=0.1):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.head_dim = in_dim // heads
        self.scale = self.head_dim ** 0.5

        self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,))

        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x, vis_attn=False):
        b, n, _ = x.shape

        q = self.query(x, dims=([2], [0]))
        k = self.key(x, dims=([2], [0]))
        v = self.value(x, dims=([2], [0]))

        # (b, n, heads, head_dim) -> (b, heads, n, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropout:
            # apply attention dropout to the softmax weights
            attn_weights = self.dropout(attn_weights)
        out = torch.matmul(attn_weights, v)
        out = out.permute(0, 2, 1, 3)

        # merge heads back to (b, n, in_dim)
        out = self.out(out, dims=([2, 3], [0, 1]))

        if not vis_attn:
            return out
        else:
            return out, attn_weights


class EncoderBlock(nn.Module):
    def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1,
                 attn_dropout_rate=0.1, normalize='layer_norm'):
        super(EncoderBlock, self).__init__()

        if normalize == 'layer_norm':
            self.norm1 = nn.LayerNorm(in_dim)
            self.norm2 = nn.LayerNorm(in_dim)
        elif normalize == 'group_norm':
            self.norm1 = Normalize(in_dim)
            self.norm2 = Normalize(in_dim)
        self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate)
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None
        self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate)

    def forward(self, x, vis_attn=False):
        # attention sub-block: pre-norm, attention, dropout, residual
        residual = x
        out = self.norm1(x)
        if vis_attn:
            out, attn_weights = self.attn(out, vis_attn)
        else:
            out = self.attn(out, vis_attn)
        if self.dropout:
            out = self.dropout(out)
        out += residual

        # MLP sub-block: pre-norm, feed-forward, residual
        residual = out
        out = self.norm2(out)
        out = self.mlp(out)
        out += residual

        if vis_attn:
            return out, attn_weights
        else:
            return out
class Encoder(nn.Module):
    def __init__(self, num_patches, emb_dim, mlp_dim, num_layers=12, num_heads=12,
                 dropout_rate=0.1, attn_dropout_rate=0.0):
        super(Encoder, self).__init__()

        # positional embedding
        self.pos_embedding = PositionEmbs(num_patches, emb_dim, dropout_rate)

        # encoder blocks
        in_dim = emb_dim
        self.encoder_layers = nn.ModuleList()
        for i in range(num_layers):
            layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate, attn_dropout_rate)
            self.encoder_layers.append(layer)
        self.norm = nn.LayerNorm(in_dim)

    def forward(self, x):
        out = self.pos_embedding(x)

        for layer in self.encoder_layers:
            out = layer(out)

        out = self.norm(out)
        return out


class VisionTransformer(nn.Module):
    """Vision Transformer"""

    def __init__(self,
                 image_size=(256, 256),
                 patch_size=(16, 16),
                 emb_dim=768,
                 mlp_dim=3072,
                 num_heads=12,
                 num_layers=12,
                 num_classes=1000,
                 attn_dropout_rate=0.0,
                 dropout_rate=0.1,
                 feat_dim=None):
        super(VisionTransformer, self).__init__()
        h, w = image_size

        # patch embedding layer
        fh, fw = patch_size
        gh, gw = h // fh, w // fw
        num_patches = gh * gw
        self.embedding = nn.Conv2d(3, emb_dim, kernel_size=(fh, fw), stride=(fh, fw))

        # class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))

        # transformer
        self.transformer = Encoder(
            num_patches=num_patches,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attn_dropout_rate=attn_dropout_rate)

        # classifier
        self.classifier = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        emb = self.embedding(x)        # (n, c, gh, gw)
        emb = emb.permute(0, 2, 3, 1)  # (n, gh, gw, c)
        b, h, w, c = emb.shape
        emb = emb.reshape(b, h * w, c)

        # prepend class token
        cls_token = self.cls_token.repeat(b, 1, 1)
        emb = torch.cat([cls_token, emb], dim=1)

        # transformer
        feat = self.transformer(emb)

        # classifier on the class-token representation
        logits = self.classifier(feat[:, 0])
        return logits


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b, hw, c
        k = k.reshape(b, c, h * w)  # b, c, hw
        w_ = torch.bmm(q, k)        # b, hw, hw; w_[b, i, j] = sum_c q[b, i, c] * k[b, c, j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b, hw, hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)       # b, c, hw (hw of q); h_[b, c, j] = sum_i v[b, c, i] * w_[b, i, j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


if __name__ == '__main__':
    model = VisionTransformer(num_layers=2)
    x = torch.randn((2, 3, 256, 256))
    out = model(x)
    print(out.shape)  # torch.Size([2, 1000])