import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class EncoderLayer(nn.Module):
    """A single pre-norm Transformer encoder block: multi-head self-attention followed by an MLP."""

    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Attention projections: a single linear layer projects to q, k, v in one pass
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

        # Feed-forward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # LayerNorms, applied before each sub-block (pre-norm)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        B, N, D = x.shape

        # === Multi-head self-attention ===
        x_norm = self.norm1(x)
        qkv = self.qkv_proj(x_norm)
        qkv = qkv.reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, N, N)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v                              # (B, n_heads, N, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(B, N, D)  # (B, N, D)
        attn_output = self.out_proj(attn_output)
        x = x + attn_output  # residual connection

        # === Feed-forward ===
        x_norm = self.norm2(x)
        x = x + self.mlp(x_norm)  # residual connection
        return x
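

# Optional alternative (a sketch, assuming PyTorch >= 2.0; not used by EncoderLayer above):
# the manual score/softmax/weighted-sum math in EncoderLayer.forward can be replaced by the
# fused F.scaled_dot_product_attention kernel, which computes
# softmax(q @ k.transpose(-2, -1) / sqrt(head_dim)) @ v in a single, memory-efficient call.
def fused_attention(q, k, v):
    # q, k, v: (B, n_heads, N, head_dim) -> (B, n_heads, N, head_dim)
    return F.scaled_dot_product_attention(q, k, v)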


class TransformerEncoder(nn.Module):
    def __init__(self, depth=4, d_model=64, n_heads=4, ff_dim=128, num_patches=16):
        super().__init__()
        # Learnable positional embedding, one vector per patch position: (1, num_patches, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, d_model))
        self.layers = nn.ModuleList([
            EncoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

    def forward(self, x):
        """
        x: Tensor of shape (B, num_patches, d_model)
        returns: Tensor of the same shape (B, num_patches, d_model)
        """
        x = x + self.pos_embedding
        for layer in self.layers:
            x = layer(x)
        return x


# Simple shape test
if __name__ == "__main__":
    B = 4   # batch size
    N = 16  # number of patches
    D = 64  # embedding dim
    dummy_input = torch.randn(B, N, D)

    print("Testing EncoderLayer...")
    layer = EncoderLayer(d_model=D, n_heads=4, ff_dim=128)
    out = layer(dummy_input)
    print("EncoderLayer output shape:", out.shape)        # torch.Size([4, 16, 64])

    print("Testing TransformerEncoder...")
    encoder = TransformerEncoder(depth=3, d_model=D, n_heads=4, ff_dim=128, num_patches=N)
    out = encoder(dummy_input)
    print("TransformerEncoder output shape:", out.shape)  # torch.Size([4, 16, 64])