codebase without model
b54146b
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class EncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Attention projections: one linear layer producing q, k, v in a single matmul
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

        # Position-wise feedforward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # LayerNorms (pre-norm: applied before attention and before the MLP)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        B, N, D = x.shape

        # === Multi-head self-attention ===
        x_norm = self.norm1(x)
        qkv = self.qkv_proj(x_norm)
        qkv = qkv.reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, N, N)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v  # (B, n_heads, N, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(B, N, D)  # (B, N, D)
        attn_output = self.out_proj(attn_output)
        x = x + attn_output  # Residual connection

        # === Feedforward ===
        x_norm = self.norm2(x)
        x = x + self.mlp(x_norm)  # Residual connection
        return x
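

# --- Optional sanity check (a sketch, not part of the original codebase) ---
# The manual attention above should agree with PyTorch's fused kernel,
# F.scaled_dot_product_attention (available in PyTorch >= 2.0), whose default
# scaling is likewise 1/sqrt(head_dim). The helper name below is hypothetical;
# it exists only to help debug the reshape/permute logic.
def _attention_matches_fused(layer, x, atol=1e-5):
    B, N, _ = x.shape
    x_norm = layer.norm1(x)
    qkv = layer.qkv_proj(x_norm).reshape(B, N, 3, layer.n_heads, layer.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    # Same math as EncoderLayer.forward
    manual = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(layer.head_dim), dim=-1) @ v
    fused = F.scaled_dot_product_attention(q, k, v)
    return torch.allclose(manual, fused, atol=atol)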


class TransformerEncoder(nn.Module):
    def __init__(self, depth=4, d_model=64, n_heads=4, ff_dim=128, num_patches=16):
        super().__init__()
        # Learned positional embedding, shape (1, num_patches, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, d_model))
        self.layers = nn.ModuleList([
            EncoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

    def forward(self, x):
        """
        x: Tensor of shape (B, num_patches, d_model)
        returns: Tensor of the same shape (B, num_patches, d_model)
        """
        x = x + self.pos_embedding
        for layer in self.layers:
            x = layer(x)
        return x
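

# --- Hypothetical patch-embedding front end (a sketch, not in this commit) ---
# TransformerEncoder expects already-patchified input of shape
# (B, num_patches, d_model). Assuming, purely for illustration, 32x32 RGB
# images split into 8x8 patches (16 patches total), a strided Conv2d
# projection in the style of ViT would produce that shape:
class PatchEmbed(nn.Module):
    def __init__(self, patch_size=8, in_chans=3, d_model=64):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, d_model, kernel_size=patch_size, stride=patch_size)

    def forward(self, img):
        # img: (B, 3, 32, 32) -> (B, d_model, 4, 4) -> (B, 16, d_model)
        x = self.proj(img)
        return x.flatten(2).transpose(1, 2)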


# Simple testing of dimensions
if __name__ == "__main__":
    B = 4   # batch size
    N = 16  # number of patches
    D = 64  # embedding dim
    dummy_input = torch.randn(B, N, D)

    print("Testing EncoderLayer...")
    layer = EncoderLayer(d_model=D, n_heads=4, ff_dim=128)
    out = layer(dummy_input)
    print("EncoderLayer output shape:", out.shape)  # torch.Size([4, 16, 64])

    print("Testing TransformerEncoder...")
    encoder = TransformerEncoder(depth=3, d_model=D, n_heads=4, ff_dim=128, num_patches=N)
    out = encoder(dummy_input)
    print("TransformerEncoder output shape:", out.shape)  # torch.Size([4, 16, 64])