Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
class DecoderLayer(nn.Module):
    """Pre-norm Transformer decoder block.

    Applies, in order and each with a residual connection:
    causally-masked multi-head self-attention over the decoder sequence,
    multi-head cross-attention from the decoder sequence to the encoder
    output, and a two-layer GELU feed-forward network.

    Args:
        d_model: Embedding width of both decoder and encoder features.
        n_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.
        ff_dim: Hidden width of the feed-forward network.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        # Validate before dividing; an explicit raise (unlike ``assert``)
        # is not stripped when Python runs with -O.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be a positive integer (got {n_heads})")
        if d_model % n_heads != 0:
            raise ValueError(
                "d_model must be divisible by number of heads "
                f"(got d_model={d_model}, n_heads={n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Self-attention: Q, K, V all come from the decoder input (fused projection).
        self.self_attn_proj = nn.Linear(d_model, 3 * d_model)
        # Cross-attention: Q from the decoder input, K/V from the encoder output.
        self.cross_attn_q = nn.Linear(d_model, d_model)
        self.cross_attn_kv = nn.Linear(d_model, 2 * d_model)
        # Output projections applied after the heads are merged back together.
        self.self_out = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)
        # Position-wise feed-forward network.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model),
        )
        # One LayerNorm in front of each of the three sub-layers.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def _attend(self, q, k, v, blocked=None):
        """Scaled dot-product attention followed by merging the heads.

        q: (B, n_heads, Tq, head_dim); k, v: (B, n_heads, Tk, head_dim).
        blocked: optional boolean tensor broadcastable to (B, n_heads, Tq, Tk);
            True marks query/key pairs that must receive zero weight.
        Returns: (B, Tq, d_model)
        """
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if blocked is not None:
            scores = scores.masked_fill(blocked, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        context = weights @ v  # (B, n_heads, Tq, head_dim)
        batch, _, q_len, _ = context.shape
        return context.transpose(1, 2).reshape(batch, q_len, self.d_model)

    def forward(self, x, enc_out):
        """
        x: (B, T, D) - decoder input embeddings
        enc_out: (B, N, D) - encoder outputs (e.g. image patch representations)
        Returns: (B, T, D)
        """
        B, T, _ = x.shape
        _, N, _ = enc_out.shape

        # --- Masked self-attention ---
        qkv = self.self_attn_proj(self.norm1(x))
        qkv = qkv.reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (B, n_heads, T, head_dim)
        # Causal mask: entry (i, j) is True when key j lies after query i.
        # Every row keeps its diagonal, so softmax never sees an all -inf row.
        positions = torch.arange(T, device=x.device)
        future = positions.unsqueeze(0) > positions.unsqueeze(1)  # (T, T) bool
        x = x + self.self_out(self._attend(q, k, v, blocked=future))  # Residual

        # --- Cross-attention (unmasked: every query sees all encoder positions) ---
        q = self.cross_attn_q(self.norm2(x))
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        kv = self.cross_attn_kv(enc_out)
        kv = kv.reshape(B, N, 2, self.n_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4)  # each (B, n_heads, N, head_dim)
        x = x + self.cross_out(self._attend(q, k, v))  # Residual

        # --- Feed-forward ---
        x = x + self.mlp(self.norm3(x))  # Residual
        return x
# Full decoder: token and positional embeddings, stacked DecoderLayer blocks, and a projection to vocabulary logits
class TransformerDecoder(nn.Module):
    """Stack of decoder layers mapping token IDs to vocabulary logits.

    Token IDs are embedded, summed with a learned positional embedding,
    passed through ``depth`` ``DecoderLayer`` blocks that also receive the
    encoder output, and finally projected to ``vocab_size`` logits.

    Args:
        vocab_size: Number of distinct token IDs.
        max_len: Longest decoder sequence the positional embedding covers.
        d_model: Embedding width (must match the encoder output width).
        n_heads: Attention heads per layer.
        ff_dim: Hidden width of each layer's feed-forward network.
        depth: Number of stacked decoder layers.
    """

    def __init__(self, vocab_size=13, max_len=5, d_model=64, n_heads=4, ff_dim=128, depth=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embedding, one vector per position: (1, max_len, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))
        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])
        self.output_proj = nn.Linear(d_model, vocab_size)  # Final projection to logits

    def forward(self, decoder_input_ids, encoder_output):
        """
        decoder_input_ids: (B, T) token IDs, with T <= max_len
        encoder_output: (B, N, d_model) from image encoder
        returns: logits over vocab, shape (B, T, vocab_size)

        Raises:
            ValueError: if T exceeds the number of learned positions.
        """
        x = self.token_embedding(decoder_input_ids)  # (B, T, d_model)
        seq_len = x.size(1)
        max_len = self.pos_embedding.size(1)
        # Slicing past the end of pos_embedding silently yields fewer rows,
        # which then either fails with an opaque broadcasting error or, when
        # max_len == 1, broadcasts one position over the whole sequence.
        if seq_len > max_len:
            raise ValueError(
                f"decoder sequence length {seq_len} exceeds max_len={max_len}"
            )
        x = x + self.pos_embedding[:, :seq_len, :]  # Add positional embedding
        for layer in self.layers:
            x = layer(x, encoder_output)  # (B, T, d_model)
        return self.output_proj(x)  # (B, T, vocab_size)
# quick test | |
if __name__ == "__main__":
    # Smoke test: push one random batch through a default-configured decoder
    # and report the shape of the resulting logits.
    batch_size, target_len, num_patches, width = 4, 5, 16, 64
    model = TransformerDecoder()
    token_ids = torch.randint(0, 13, (batch_size, target_len))  # (B=4, T=5)
    memory = torch.randn(batch_size, num_patches, width)  # (B=4, N=16, D=64)
    result = model(token_ids, memory)
    print("Logits shape:", result.shape)  # (4, 5, 13)