import pickle

import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms

# Patch size used throughout (ViT-Base/16).
PATCH_SIZE = 16
# Assumed input resolution; the standard ViT-Base recipe uses 224x224 images.
IMG_SIZE = 224


class PatchEmbeddings(nn.Module):
    """Splits an image into non-overlapping patches and projects each patch to an embedding vector."""
    def __init__(self, in_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size

        # A strided convolution extracts the patches and embeds them in a single step.
        self.generate_patches = nn.Conv2d(in_channels=in_channels,
                                          out_channels=embedding_dim,
                                          kernel_size=patch_size,
                                          stride=patch_size,
                                          padding=0)

        # Flatten the spatial grid of patches into one sequence dimension.
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x: torch.Tensor):
        image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, (
            f"Image size ({image_resolution}) must be divisible by patch size ({self.patch_size})!"
        )

        # (batch, embedding_dim, H/P, W/P) -> (batch, num_patches, embedding_dim)
        return self.flatten(self.generate_patches(x)).permute(0, 2, 1)
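

# Quick shape check (a minimal sketch, assuming the 224x224 resolution above): with the
# default arguments, one RGB image yields 196 patch tokens of dimension 768.
# _patcher = PatchEmbeddings()
# _demo_image = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
# print(_patcher(_demo_image).shape)  # expected: torch.Size([1, 196, 768])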


class MultiheadSelfAttention(nn.Module):
    """Pre-norm multi-head self-attention block."""
    def __init__(self, embedding_dim: int = 768,
                 num_heads: int = 12,
                 attn_dropout: float = 0.0):
        super().__init__()

        # LayerNorm is applied before attention (pre-norm, as in the ViT paper).
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x: torch.Tensor):
        x = self.layer_norm(x)

        # Self-attention: query, key and value all come from the same sequence.
        attn_output, _ = self.multihead_attn(query=x, key=x, value=x,
                                             need_weights=False)
        return attn_output


class MLPBlock(nn.Module):
    """Pre-norm feed-forward (MLP) block with GELU activation and dropout."""
    def __init__(self, embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 dropout: float = 0.1):
        super().__init__()

        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size,
                      out_features=embedding_dim),
            nn.Dropout(p=dropout)
        )

    def forward(self, x: torch.Tensor):
        return self.mlp(self.layer_norm(x))


class TransformerEncoderBlock(nn.Module):
    """A single ViT encoder block: pre-norm MSA and MLP, each wrapped in a residual connection."""
    def __init__(self, embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 mlp_dropout: float = 0.1,
                 attn_dropout: float = 0.0):
        super().__init__()

        self.msa_block = MultiheadSelfAttention(embedding_dim=embedding_dim,
                                                num_heads=num_heads,
                                                attn_dropout=attn_dropout)

        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                                  mlp_size=mlp_size,
                                  dropout=mlp_dropout)

    def forward(self, x: torch.Tensor):
        # Residual (skip) connections around both sub-blocks.
        x = self.msa_block(x) + x
        x = self.mlp_block(x) + x
        return x
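

# For reference only (a sketch, not used below): PyTorch's built-in encoder layer is
# roughly equivalent to the block above when configured as pre-norm with GELU.
# _builtin_block = nn.TransformerEncoderLayer(d_model=768, nhead=12,
#                                             dim_feedforward=3072, dropout=0.1,
#                                             activation="gelu", batch_first=True,
#                                             norm_first=True)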


class VisionTransformer(nn.Module):
    """ViT-Base-style classifier: patch + position embeddings, a stack of encoder blocks,
    and a linear head on the class token."""
    def __init__(self, img_size: int = IMG_SIZE,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 attn_dropout: float = 0.0,
                 mlp_dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 num_classes: int = 38):
        super().__init__()

        assert img_size % patch_size == 0, (
            f"Image size ({img_size}) must be divisible by patch size ({patch_size})!"
        )

        self.num_patches = (img_size * img_size) // patch_size ** 2

        # Learnable class token, prepended to the patch sequence.
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        # Learnable position embedding for the class token plus every patch.
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
                                               requires_grad=True)

        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        self.patch_embeddings = PatchEmbeddings(in_channels=in_channels,
                                                patch_size=patch_size,
                                                embedding_dim=embedding_dim)

        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                           num_heads=num_heads,
                                                                           mlp_size=mlp_size,
                                                                           mlp_dropout=mlp_dropout,
                                                                           attn_dropout=attn_dropout)
                                                   for _ in range(num_transformer_layers)])

        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]

        # Expand the single class token across the batch.
        class_token = self.class_embedding.expand(batch_size, -1, -1)

        x = self.patch_embeddings(x)
        x = torch.cat((class_token, x), dim=1)

        # Add position information, then apply embedding dropout.
        x = self.position_embedding + x
        x = self.embedding_dropout(x)

        x = self.transformer_encoder(x)

        # Classify from the class token (index 0) only.
        x = self.classifier(x[:, 0])

        return x


# Load the saved list of class names and build a model with a matching output size.
with open("class_names.ob", "rb") as fp:
    class_names = pickle.load(fp)

vision_transformer = VisionTransformer(num_classes=len(class_names))
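
# Quick sanity check (a sketch, assuming the 224x224 default resolution): a dummy
# batch of two RGB images should produce one logit per class.
# _dummy_batch = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
# print(vision_transformer(_dummy_batch).shape)  # expected: (2, len(class_names))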