import pickle

import torch
import torchvision
from torch import nn
from torchvision import datasets, transforms

# ViT-Base/16 defaults. IMG_SIZE = 224 is assumed here; its original
# definition is not shown in this file.
IMG_SIZE = 224
PATCH_SIZE = 16
class PatchEmbeddings(nn.Module):
    """Splits an image into patches and projects each patch to an embedding vector."""
    def __init__(self, in_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        # A strided convolution extracts non-overlapping patches and embeds them in one step.
        self.generate_patches = nn.Conv2d(in_channels=in_channels,
                                          out_channels=embedding_dim,
                                          kernel_size=patch_size,
                                          stride=patch_size,
                                          padding=0)
        # Flatten the spatial grid of patches into a single sequence dimension.
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x: torch.Tensor):
        image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, "Image size must be divisible by patch size!"
        # (B, C, H, W) -> (B, embedding_dim, H/P, W/P) -> (B, embedding_dim, N) -> (B, N, embedding_dim)
        return self.flatten(self.generate_patches(x)).permute(0, 2, 1)
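
# Quick shape sketch for the defaults above (assuming a 224x224 RGB input):
# (1, 3, 224, 224) -> Conv2d -> (1, 768, 14, 14) -> flatten -> (1, 768, 196)
# -> permute -> (1, 196, 768), i.e. a sequence of 196 patch embeddings.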
class MultiheadSelfAttention(nn.Module):
    """Pre-norm multi-head self-attention block."""
    def __init__(self, embedding_dim: int = 768,
                 num_heads: int = 12,
                 attn_dropout: float = 0.0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x: torch.Tensor):
        x = self.layer_norm(x)
        # Self-attention: queries, keys and values all come from the same sequence.
        attn_output, _ = self.multihead_attn(query=x, key=x, value=x,
                                             need_weights=False)
        return attn_output
class MLPBlock(nn.Module):
    """Pre-norm feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""
    def __init__(self, embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 dropout: float = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, out_features=embedding_dim),
            nn.Dropout(p=dropout)
        )

    def forward(self, x: torch.Tensor):
        return self.mlp(self.layer_norm(x))
class TransformerEncoderBlock(nn.Module):
    """One ViT encoder block: MSA and MLP sub-blocks, each wrapped in a residual connection."""
    def __init__(self, embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 mlp_dropout: float = 0.1,
                 attn_dropout: float = 0.0):
        super().__init__()
        self.msa_block = MultiheadSelfAttention(embedding_dim=embedding_dim,
                                                num_heads=num_heads,
                                                attn_dropout=attn_dropout)
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                                  mlp_size=mlp_size,
                                  dropout=mlp_dropout)

    def forward(self, x: torch.Tensor):
        x = self.msa_block(x) + x  # residual connection around attention
        x = self.mlp_block(x) + x  # residual connection around the MLP
        return x
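
# Side note (not used below): with PyTorch >= 1.10, a block like the one above
# is roughly equivalent to the built-in pre-norm encoder layer, e.g.
#   nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072,
#                              dropout=0.1, activation="gelu",
#                              batch_first=True, norm_first=True)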
class VisionTransformer(nn.Module):
    """ViT-Base/16-style classifier: patch, class and position embeddings, encoder stack, linear head."""
    def __init__(self, img_size: int = IMG_SIZE,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 attn_dropout: float = 0.0,
                 mlp_dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 num_classes: int = 38):
        super().__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size!"
        self.num_patches = (img_size * img_size) // patch_size**2
        # Learnable [class] token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)
        # Learnable position embeddings for the class token plus every patch.
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
                                               requires_grad=True)
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)
        self.patch_embeddings = PatchEmbeddings(in_channels=in_channels,
                                                patch_size=patch_size,
                                                embedding_dim=embedding_dim)
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                           num_heads=num_heads,
                                                                           mlp_size=mlp_size,
                                                                           mlp_dropout=mlp_dropout,
                                                                           attn_dropout=attn_dropout)
                                                   for _ in range(num_transformer_layers)])
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes)
        )

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # Expand the class token across the batch and prepend it to the patch embeddings.
        class_token = self.class_embedding.expand(batch_size, -1, -1)
        x = self.patch_embeddings(x)
        x = torch.cat((class_token, x), dim=1)
        x = self.position_embedding + x
        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)
        # Classify from the class token's final representation only.
        x = self.classifier(x[:, 0])
        return x
# Load the class names saved during training and build a matching model.
with open("class_names.ob", "rb") as fp:
    class_names = pickle.load(fp)

vision_transformer = VisionTransformer(num_classes=len(class_names))
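
# Minimal usage sketch (assumes 224x224 RGB inputs and that "class_names.ob"
# contains the class names used during training):
if __name__ == "__main__":
    dummy_batch = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)  # (batch, channels, height, width)
    logits = vision_transformer(dummy_batch)             # (1, len(class_names))
    print(logits.shape)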