# LeafDoc / model.py
# Source snapshot: commit 36fc972 ("initial commit") by TexR6, 5.18 kB.
import pickle

import torch
import torchvision
from torch import nn
from torchvision import datasets, transforms
PATCH_SIZE = 16
# Default input resolution used by VisionTransformer(img_size=...).
# NOTE(review): 224 is the standard ViT-B/16 resolution and matches the other
# ViT-B/16 defaults in this file (embedding_dim=768, mlp_size=3072, 12 heads,
# 12 layers); confirm it matches the resolution the weights were trained at.
IMG_SIZE = 224


class PatchEmbeddings(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    The projection is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``; its spatial output grid is then flattened into a sequence.

    Args:
        in_channels: Number of input image channels.
        patch_size: Side length (in pixels) of each square patch.
        embedding_dim: Size of the embedding produced for each patch.
    """

    def __init__(self, in_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 768):
        super().__init__()
        # Stored so forward() validates against THIS instance's patch size,
        # not the module-level PATCH_SIZE constant. (Plain int attribute, so
        # it does not appear in state_dict.)
        self.patch_size = patch_size
        self.generate_patches = nn.Conv2d(in_channels=in_channels,
                                          out_channels=embedding_dim,
                                          kernel_size=patch_size,
                                          stride=patch_size, padding=0)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` patch-wise.

        Args:
            x: Image batch, assumed (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches, embedding_dim) where
            num_patches = (height // patch_size) * (width // patch_size).

        Raises:
            AssertionError: If height or width is not divisible by patch_size.
        """
        height, width = x.shape[-2], x.shape[-1]
        # Both spatial dims must be checked: a strided Conv2d would otherwise
        # silently drop the trailing rows/columns that don't fill a patch.
        assert height % self.patch_size == 0 and width % self.patch_size == 0, (
            f"Image size must be divisible by patch size! "
            f"Got {height}x{width} with patch size {self.patch_size}."
        )
        # (B, E, H/P, W/P) -> flatten -> (B, E, N) -> permute -> (B, N, E)
        return self.flatten(self.generate_patches(x)).permute(0, 2, 1)
class MultiheadSelfAttention(nn.Module):
    """Pre-norm multi-head self-attention sub-layer.

    Applies LayerNorm to the input, then ``nn.MultiheadAttention`` with the
    normalised tensor as query, key and value. No residual connection is
    added here; the output is the attention result only.

    Args:
        embedding_dim: Size of each token embedding.
        num_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``embedding_dim``.
        attn_dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embedding_dim: int = 768,
                 num_heads: int = 12, attn_dropout: float = 0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        # batch_first=True: inputs/outputs are (batch, seq, embedding_dim).
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return self-attention over the layer-normalised ``x``.

        Args:
            x: Token sequence, (batch, seq, embedding_dim) per batch_first.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.layer_norm(x)
        # Attention weights are discarded, so don't request them; this also
        # lets PyTorch take its faster attention path.
        attn_output, _ = self.multihead_attn(query=x, key=x, value=x,
                                             need_weights=False)
        return attn_output
class MLPBlock(nn.Module):
    """Pre-norm feed-forward sub-layer of a Transformer encoder block.

    LayerNorm -> Linear(embedding_dim -> mlp_size) -> GELU -> Dropout
    -> Linear(mlp_size -> embedding_dim) -> Dropout. No residual connection
    is added here.

    Args:
        embedding_dim: Size of each token embedding (input and output size).
        mlp_size: Hidden width of the MLP.
        dropout: Dropout probability used after each Linear's activation path.
    """

    def __init__(self, embedding_dim: int = 768,
                 mlp_size: int = 3072, dropout: float = 0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, out_features=embedding_dim),
            nn.Dropout(p=dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply LayerNorm then the MLP; output has the same shape as ``x``."""
        return self.mlp(self.layer_norm(x))
class TransformerEncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block.

    Computes ``x = x + MSA(x)`` followed by ``x = x + MLP(x)``; each
    sub-layer applies its own LayerNorm to its input before transforming it.

    Args:
        embedding_dim: Size of each token embedding.
        mlp_size: Hidden width of the feed-forward MLP.
        num_heads: Number of attention heads.
        mlp_dropout: Dropout probability inside the MLP.
        attn_dropout: Dropout probability on the attention weights.
    """

    def __init__(self, embedding_dim: int = 768,
                 mlp_size: int = 3072, num_heads: int = 12,
                 mlp_dropout: float = 0.1, attn_dropout: float = 0):
        super().__init__()
        self.msa_block = MultiheadSelfAttention(embedding_dim=embedding_dim,
                                                num_heads=num_heads,
                                                attn_dropout=attn_dropout)
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                                  mlp_size=mlp_size, dropout=mlp_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual sub-layers; output has the same shape as ``x``."""
        x = self.msa_block(x) + x  # residual around attention
        x = self.mlp_block(x) + x  # residual around the MLP
        return x
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable class token -> add
    learnable position embeddings -> dropout -> stack of pre-norm Transformer
    encoder blocks -> LayerNorm + Linear head applied to the class token.

    Args:
        img_size: Side length of the (square) input images. The position
            embedding is sized from it, so inputs must have this resolution.
        in_channels: Number of input image channels.
        patch_size: Side length of each square patch; must divide img_size.
        num_transformer_layers: Number of encoder blocks.
        embedding_dim: Token embedding size.
        mlp_size: Hidden width of each encoder block's MLP.
        num_heads: Attention heads per encoder block.
        attn_dropout: Dropout probability on attention weights in every block.
        mlp_dropout: Dropout probability inside every block's MLP.
        embedding_dropout: Dropout probability applied after adding the
            position embeddings.
        num_classes: Number of output logits.
    """

    def __init__(self, img_size: int = IMG_SIZE,
                 in_channels: int = 3, patch_size: int = 16,
                 num_transformer_layers: int = 12, embedding_dim: int = 768,
                 mlp_size: int = 3072, num_heads: int = 12,
                 attn_dropout: float = 0, mlp_dropout: float = 0.1,
                 embedding_dropout: float = 0.1, num_classes: int = 38):
        super().__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size!"
        self.num_patches = (img_size * img_size) // patch_size**2
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)
        # +1 position for the class token prepended in forward().
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches + 1, embedding_dim),
                                               requires_grad=True)
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)
        self.patch_embeddings = PatchEmbeddings(in_channels=in_channels,
                                                patch_size=patch_size,
                                                embedding_dim=embedding_dim)
        self.transformer_encoder = nn.Sequential(*[
            TransformerEncoderBlock(embedding_dim=embedding_dim,
                                    num_heads=num_heads,
                                    mlp_size=mlp_size,
                                    mlp_dropout=mlp_dropout,
                                    # Must be forwarded explicitly; otherwise the
                                    # constructor's attn_dropout is silently ignored.
                                    attn_dropout=attn_dropout)
            for _ in range(num_transformer_layers)
        ])
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: Image batch, assumed (batch, in_channels, img_size, img_size).

        Returns:
            Logits of shape (batch, num_classes).
        """
        batch_size = x.shape[0]
        # One shared class-token parameter, broadcast (no copy) across the batch.
        class_token = self.class_embedding.expand(batch_size, -1, -1)
        x = self.patch_embeddings(x)
        x = torch.cat((class_token, x), dim=1)  # class token at sequence index 0
        x = self.position_embedding + x
        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)
        # Classify from the class token's final representation only.
        return self.classifier(x[:, 0])
# Load the list of class labels; its length sizes the classifier head.
# `pickle` must be imported at the top of the module (it was missing).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only ship/load a class_names.ob that comes from a trusted source.
# Path is relative to the current working directory, not to this module.
with open("class_names.ob", "rb") as fp:
    class_names = pickle.load(fp)

# Model instance built at import time; presumably its weights are loaded by
# the importing code — TODO confirm.
vision_transformer = VisionTransformer(num_classes=len(class_names))