import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder block: multi-head self-attention + MLP, each with a residual connection."""

    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Attention projections: a single fused linear layer produces q, k, and v in one matmul.
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

        # Position-wise feedforward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model),
        )

        # LayerNorms, applied before each sublayer (pre-norm)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        B, N, D = x.shape

        # === Multi-head self-attention ===
        x_norm = self.norm1(x)
        qkv = self.qkv_proj(x_norm)
        # Split the fused projection into per-head q, k, v: (3, B, n_heads, N, head_dim)
        qkv = qkv.reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, N, N)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v  # (B, n_heads, N, head_dim)

        # Merge the heads back into the model dimension and project out
        attn_output = attn_output.transpose(1, 2).reshape(B, N, D)  # (B, N, D)
        attn_output = self.out_proj(attn_output)
        x = x + attn_output  # residual connection

        # === Feedforward ===
        x_norm = self.norm2(x)
        x = x + self.mlp(x_norm)  # residual connection
        return x


class TransformerEncoder(nn.Module):
    """Stack of EncoderLayers with a learned positional embedding added to the input."""

    def __init__(self, depth=4, d_model=64, n_heads=4, ff_dim=128, num_patches=16):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, d_model))  # leading 1 broadcasts over the batch
        self.layers = nn.ModuleList([
            EncoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

    def forward(self, x):
        """
        x: Tensor of shape (B, num_patches, d_model)
        returns: Tensor of the same shape (B, num_patches, d_model)
        """
        x = x + self.pos_embedding
        for layer in self.layers:
            x = layer(x)
        return x


# Simple test of input/output dimensions
if __name__ == "__main__":
    B = 4   # batch size
    N = 16  # number of patches
    D = 64  # embedding dim
    dummy_input = torch.randn(B, N, D)

    print("Testing EncoderLayer...")
    layer = EncoderLayer(d_model=D, n_heads=4, ff_dim=128)
    out = layer(dummy_input)
    print("EncoderLayer output shape:", out.shape)  # torch.Size([4, 16, 64])

    print("Testing TransformerEncoder...")
    encoder = TransformerEncoder(depth=3, d_model=D, n_heads=4, ff_dim=128, num_patches=N)
    out = encoder(dummy_input)
    print("TransformerEncoder output shape:", out.shape)  # torch.Size([4, 16, 64])
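
    # --- Optional equivalence check (a minimal sketch; assumes PyTorch >= 2.0,
    # where F.scaled_dot_product_attention is available). SDPA computes the same
    # softmax(q @ k^T / sqrt(head_dim)) @ v as the manual attention above, so
    # re-deriving q, k, v from the layer's own weights should reproduce it.
    with torch.no_grad():
        x_norm = layer.norm1(dummy_input)
        qkv = layer.qkv_proj(x_norm)
        qkv = qkv.reshape(B, N, 3, layer.n_heads, layer.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        manual = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(layer.head_dim), dim=-1) @ v
        fused = F.scaled_dot_product_attention(q, k, v)  # same math, fused kernel
        print("Manual vs. fused attention match:", torch.allclose(manual, fused, atol=1e-5))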