import pickle

import torch
import torchvision
from torch import nn
from torchvision import datasets, transforms

# ViT-Base/16 style hyperparameters; 224 is the standard ViT input resolution
IMG_SIZE = 224
PATCH_SIZE = 16


class PatchEmbeddings(nn.Module):
    """Turns an image into a sequence of flattened patch embeddings."""
    def __init__(self, in_channels: int = 3, patch_size: int = 16, embedding_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        # A conv layer with kernel_size == stride == patch_size splits the image
        # into non-overlapping patches and projects each one to embedding_dim.
        self.generate_patches = nn.Conv2d(in_channels=in_channels,
                                          out_channels=embedding_dim,
                                          kernel_size=patch_size,
                                          stride=patch_size,
                                          padding=0)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x: torch.Tensor):
        image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, "Image size must be divisible by patch size!"
        # (B, C, H, W) -> (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
        return self.flatten(self.generate_patches(x)).permute(0, 2, 1)


class MultiheadSelfAttention(nn.Module):
    """Pre-norm multi-head self-attention block."""
    def __init__(self, embedding_dim: int = 768, num_heads: int = 12, attn_dropout: float = 0.0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x: torch.Tensor):
        x = self.layer_norm(x)
        attn_output, _ = self.multihead_attn(query=x, key=x, value=x, need_weights=False)
        return attn_output


class MLPBlock(nn.Module):
    """Pre-norm feed-forward (MLP) block."""
    def __init__(self, embedding_dim: int = 768, mlp_size: int = 3072, dropout: float = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, out_features=embedding_dim),
            nn.Dropout(p=dropout)
        )

    def forward(self, x: torch.Tensor):
        return self.mlp(self.layer_norm(x))


class TransformerEncoderBlock(nn.Module):
    """A single Transformer encoder layer: MSA + MLP, each wrapped in a residual connection."""
    def __init__(self, embedding_dim: int = 768, mlp_size: int = 3072, num_heads: int = 12,
                 mlp_dropout: float = 0.1, attn_dropout: float = 0.0):
        super().__init__()
        self.msa_block = MultiheadSelfAttention(embedding_dim=embedding_dim,
                                                num_heads=num_heads,
                                                attn_dropout=attn_dropout)
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                                  mlp_size=mlp_size,
                                  dropout=mlp_dropout)

    def forward(self, x: torch.Tensor):
        x = self.msa_block(x) + x  # residual connection around attention
        x = self.mlp_block(x) + x  # residual connection around MLP
        return x


class VisionTransformer(nn.Module):
    """ViT-Base style image classifier built from the blocks above."""
    def __init__(self, img_size: int = IMG_SIZE, in_channels: int = 3, patch_size: int = 16,
                 num_transformer_layers: int = 12, embedding_dim: int = 768, mlp_size: int = 3072,
                 num_heads: int = 12, attn_dropout: float = 0.0, mlp_dropout: float = 0.1,
                 embedding_dropout: float = 0.1, num_classes: int = 38):
        super().__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size!"
        # Number of patches, e.g. (224 * 224) // 16**2 = 196
        self.num_patches = (img_size * img_size) // patch_size**2
        # Learnable [class] token prepended to the patch sequence
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim), requires_grad=True)
        # Learnable position embeddings for the class token plus every patch
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
                                               requires_grad=True)
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)
        self.patch_embeddings = PatchEmbeddings(in_channels=in_channels,
                                                patch_size=patch_size,
                                                embedding_dim=embedding_dim)
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                           num_heads=num_heads,
                                                                           mlp_size=mlp_size,
                                                                           mlp_dropout=mlp_dropout,
                                                                           attn_dropout=attn_dropout)
                                                   for _ in range(num_transformer_layers)])
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes)
        )

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # Expand the single class token across the batch dimension
        class_token = self.class_embedding.expand(batch_size, -1, -1)
        x = self.patch_embeddings(x)
        x = torch.cat((class_token, x), dim=1)
        x = self.position_embedding + x
        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)
        # Classify from the class token's final representation only
        x = self.classifier(x[:, 0])
        return x


with open("class_names.ob", "rb") as fp:
    class_names = pickle.load(fp)

vision_transformer = VisionTransformer(num_classes=len(class_names))
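
# Quick sanity check (a sketch, not part of the original pipeline): pass a random
# batch of IMG_SIZE x IMG_SIZE RGB images through the model and confirm the output
# has shape (batch_size, num_classes). The batch size of 4 is arbitrary, chosen
# purely for illustration.
dummy_batch = torch.randn(4, 3, IMG_SIZE, IMG_SIZE)  # (batch, channels, height, width)
with torch.inference_mode():
    logits = vision_transformer(dummy_batch)
print(logits.shape)  # expected: torch.Size([4, len(class_names)])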