codebase without model
b54146b
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class DecoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by number of heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Self-attention: Q, K, V from decoder input
        self.self_attn_proj = nn.Linear(d_model, 3 * d_model)
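        # (One fused linear produces Q, K and V in a single matmul; they are
        # split into per-head tensors inside forward().)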
        # Cross-attention: Q from decoder input, K/V from encoder output
        self.cross_attn_q = nn.Linear(d_model, d_model)
        self.cross_attn_kv = nn.Linear(d_model, 2 * d_model)
        # Output projections
        self.self_out = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)
        # Feedforward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )
        # LayerNorms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        """
        x: (B, T, D) - decoder input embeddings
        enc_out: (B, N, D) - encoder outputs (image patch representations)
        Returns: (B, T, D)
        """
        B, T, D = x.shape
        _, N, _ = enc_out.shape
        # Masked Self-Attention
        x_norm = self.norm1(x)
        qkv = self.self_attn_proj(x_norm).reshape(B, T, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, T)
        # Causal mask: prevent attention to future positions
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
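        # Scores at masked positions become -inf, so softmax assigns them zero
        # weight: position t can only attend to positions <= t.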
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = attn_weights @ v  # (B, n_heads, T, head_dim)
        attn_out = attn_out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.self_out(attn_out)
        x = x + attn_out  # Residual
        # Cross-Attention
        x_norm = self.norm2(x)
        q = self.cross_attn_q(x_norm).reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
        kv = self.cross_attn_kv(enc_out).reshape(B, N, 2, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # (B, n_heads, N, head_dim)
        cross_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, N)
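        # No causal mask here: every target position may attend to all N
        # encoder patches, since the image is fully available at every step.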
        cross_weights = F.softmax(cross_scores, dim=-1)
        cross_out = cross_weights @ v  # (B, n_heads, T, head_dim)
        cross_out = cross_out.transpose(1, 2).reshape(B, T, D)
        cross_out = self.cross_out(cross_out)
        x = x + cross_out  # Residual
        # Feedforward
        x_norm = self.norm3(x)
        x = x + self.mlp(x_norm)  # Residual
        return x
# Full decoder: token/positional embeddings, a stack of DecoderLayers, and an output head
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size=13, max_len=5, d_model=64, n_heads=4, ff_dim=128, depth=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))  # (1, 5, 64)
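        # Learned positional embeddings: trained jointly with the model and
        # sliced to the current sequence length in forward().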
        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])
        self.output_proj = nn.Linear(d_model, vocab_size)  # Final projection to logits

    def forward(self, decoder_input_ids, encoder_output):
        """
        decoder_input_ids: (B, T) token IDs
        encoder_output: (B, N, d_model) from image encoder
        returns: logits over vocab, shape (B, T, vocab_size)
        """
        x = self.token_embedding(decoder_input_ids)  # (B, T, d_model)
        x = x + self.pos_embedding[:, :x.size(1), :]  # Add positional embedding
        for layer in self.layers:
            x = layer(x, encoder_output)  # (B, T, d_model)
        logits = self.output_proj(x)  # (B, T, vocab_size)
        return logits
# quick test
if __name__ == "__main__":
    decoder = TransformerDecoder()
    decoder_input = torch.randint(0, 13, (4, 5))  # (B=4, T=5)
    encoder_out = torch.randn(4, 16, 64)  # (B=4, N=16, D=64)
    logits = decoder(decoder_input, encoder_out)
    print("Logits shape:", logits.shape)  # (4, 5, 13)