# tiny-gpt-shakespeare — src/attention.py
"""
Milestone 2: Single-head causal self-attention.
Implements scaled dot-product attention with:
- Separate Q, K, V linear projections
- Causal mask (lower-triangular) so each position can only attend to past tokens
- Dropout on the attention weights
Key formula:
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Head(nn.Module):
    """One head of masked (causal) self-attention.

    Projects the input into query/key/value spaces, computes scaled
    dot-product attention restricted to current-and-past positions, and
    returns the attention-weighted sum of the values.
    """

    def __init__(self, head_size: int, n_embd: int, block_size: int, dropout: float = 0.1):
        super().__init__()
        # Independent Q/K/V projections; no bias, per the usual transformer recipe.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular 1s mark which positions each token may attend to.
        # A buffer (not a Parameter) follows .to(device) but receives no gradients.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", causal)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to x of shape (B, T, n_embd).

        Returns a tensor of shape (B, T, head_size).
        """
        T = x.size(1)  # sequence length of this batch
        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)    # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        d_k = k.size(-1)
        # softmax(Q K^T / sqrt(d_k)): the 1/sqrt(d_k) scale keeps the logits
        # from growing with head size, which would saturate the softmax.
        att = q @ k.transpose(-2, -1) * d_k ** -0.5  # (B, T, T)
        # Future positions get -inf so softmax assigns them exactly zero weight.
        att = att.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        att = self.dropout(F.softmax(att, dim=-1))
        return att @ v  # (B, T, head_size)
# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
    from tokenizer import DEVICE, BLOCK_SIZE, get_batch

    EMBED_DIM = 32
    HEAD_DIM = 16
    N_BATCH = 4

    attn = Head(head_size=HEAD_DIM, n_embd=EMBED_DIM, block_size=BLOCK_SIZE).to(DEVICE)
    # No trained embedding table exists yet, so feed random vectors of the
    # right shape to exercise the attention math end to end.
    x = torch.randn(N_BATCH, BLOCK_SIZE, EMBED_DIM, device=DEVICE)
    out = attn(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {out.shape} (expected [4, {BLOCK_SIZE}, {HEAD_DIM}])")

    # Causality spot-check: the registered mask itself should be
    # lower-triangular, so position t never attends past itself.
    corner = attn.tril[:8, :8]
    print(f"\nCausal mask (8x8 top-left corner):")
    print(corner.int())
    print("\nMilestone 2 OK: single-head causal self-attention works.")