"""
Shared building blocks for Circuit Transformer architectures.

Components:
- RMSNorm: Root Mean Square Layer Normalization
- build_word_start_table / compute_word_positions: per-token position within a word
- WordPositionRoPE: RoPE subspace driven by position within a word
- RotaryEmbedding: Rotary Position Embedding (RoPE)
- CausalAttention: Multi-head causal attention with RoPE, optional GQA,
  optional sliding window, and KV cache
- SwiGLU: Gated feed-forward network
"""
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import math
|
| from functools import lru_cache
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Each vector along the last dimension is divided by its root-mean-square
    (computed in float32), cast back to the input dtype, and then scaled by a
    learned per-channel gain that starts at ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # added to the mean square before the reciprocal sqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reduce in float32 for numerical stability regardless of x's dtype.
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (x32 * inv_rms).type_as(x)
        return normalized * self.weight
|
|
|
|
|
def build_word_start_table(tokenizer, vocab_size: int) -> torch.BoolTensor:
    """Build a boolean table marking which token IDs start a new word.

    Token strings are taken, in order of preference, from
    ``tokenizer.convert_ids_to_tokens``, ``tokenizer.sp.IdToPiece`` or
    ``tokenizer.decode([id])``. A token is marked as a word start when it:

    - has a ``Ġ`` prefix (GPT-2/BPE style),
    - has a ``▁`` prefix (SentencePiece style),
    - starts with ``<`` (special tokens),
    - starts with a whitespace character (tabs/newlines, and the leading
      space that ``decode`` returns for word-initial tokens).

    ID 0 is always marked. ``None`` and empty-string tokens are left unmarked.

    Args:
        tokenizer: object exposing one of the methods listed above.
        vocab_size: number of token IDs to cover.

    Returns:
        Bool tensor of shape [vocab_size].
    """
    table = torch.zeros(vocab_size, dtype=torch.bool)
    ids = list(range(vocab_size))

    if hasattr(tokenizer, 'convert_ids_to_tokens'):
        tokens = tokenizer.convert_ids_to_tokens(ids)
    elif hasattr(tokenizer, 'sp'):
        tokens = [tokenizer.sp.IdToPiece(i) for i in ids]
    else:
        tokens = [tokenizer.decode([i]) for i in ids]

    for idx, tok in enumerate(tokens):
        if not tok:  # None (unknown id) or empty string
            continue
        # isspace() covers '\n', '\r', '\t' and also the plain leading space
        # that decoded (rather than raw-token) strings use for word starts.
        if tok.startswith(('Ġ', '▁', '<')) or tok[0].isspace():
            table[idx] = True

    # NOTE(review): byte-level BPE vocabularies (GPT-2 style) typically render
    # newline/tab as 'Ċ'/'ĉ' rather than '\n'/'\t', so such tokens are not
    # marked here — confirm whether that is intended.

    # ID 0 is always marked; guarded so an empty vocabulary does not raise.
    if vocab_size > 0:
        table[0] = True

    return table
|
|
|
|
|
def compute_word_positions(input_ids: torch.Tensor, word_start_table: torch.Tensor) -> torch.Tensor:
    """Compute position-within-word for each token. Vectorized, no loops.

    The first token of every sequence is always treated as a word start.

    Args:
        input_ids: [B, L] token IDs
        word_start_table: [vocab_size] bool tensor from build_word_start_table;
            it is moved to input_ids' device if it lives elsewhere.

    Returns:
        [B, L] float32 tensor: 0, 1, 2, 0, 1, 0, ... (resets at each word boundary)
    """
    B, L = input_ids.shape
    device = input_ids.device
    if L == 0:
        # Nothing to index; avoid the IndexError from writing column 0.
        return torch.zeros(B, 0, device=device, dtype=torch.float32)

    # Advanced indexing returns a fresh tensor, so the in-place write below
    # never mutates the shared table. Moving the table first lets a CPU table
    # (torch.zeros default in build_word_start_table) serve inputs on any device.
    is_word_start = word_start_table.to(device)[input_ids]
    is_word_start[:, 0] = True

    positions = torch.arange(L, device=device, dtype=torch.float32).expand(B, -1)

    # Absolute index at word starts, -1 elsewhere; a running max then yields the
    # index of the most recent word start at every position.
    fill = torch.where(is_word_start, positions, positions.new_tensor(-1.0))
    running_start, _ = fill.cummax(dim=1)

    return positions - running_start
|
|
|
|
|
class WordPositionRoPE(nn.Module):
    """RoPE encoding for position-within-word.

    Dedicates a small subspace of head dimensions to word-internal position,
    using a separate frequency base (``word_base``, default 10) instead of the
    standard RoPE base.

    The standard cos/sin tables are laid out as ``cat([freqs, freqs])`` and
    ``rotate_half`` pairs column ``i`` with column ``i + head_dim // 2``. For
    each such pair to stay a proper 2-D rotation, both members must share one
    angle. The word frequencies therefore overwrite the last
    ``len(word_inv_freq)`` columns of *each* half, not one contiguous tail span.
    When ``word_dims == head_dim`` the two layouts coincide.
    """

    def __init__(self, word_dims: int, word_base: float = 10.0):
        super().__init__()
        self.word_dims = word_dims
        word_inv_freq = 1.0 / (word_base ** (torch.arange(0, word_dims, 2).float() / word_dims))
        self.register_buffer("word_inv_freq", word_inv_freq)

    def forward(
        self, cos: torch.Tensor, sin: torch.Tensor, word_positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Override the word subspace of cos/sin with word-position-derived values.

        Args:
            cos, sin: [L, head_dim] from standard RotaryEmbedding
            word_positions: [B, L] float tensor (position within word)

        Returns:
            cos, sin: [B, L, head_dim] copies with the word dims overridden

        Raises:
            ValueError: if the word subspace needs more rotation pairs than
                head_dim // 2.
        """
        B = word_positions.size(0)
        head_dim = cos.size(-1)
        half = head_dim // 2

        # [B, L, n]: one angle per word-frequency rotation pair.
        angles = word_positions.unsqueeze(-1) * self.word_inv_freq
        n = angles.size(-1)
        if n > half:
            raise ValueError(
                f"word subspace needs {n} rotation pairs but head_dim ({head_dim}) only has {half}"
            )

        # Materialize per-batch copies so the shared RoPE cache is never mutated.
        cos = cos.unsqueeze(0).expand(B, -1, -1).clone()
        sin = sin.unsqueeze(0).expand(B, -1, -1).clone()
        if n == 0:
            return cos, sin

        word_cos = angles.cos()
        word_sin = angles.sin()

        # Last n columns of the first half, and their rotate_half partners in
        # the second half, receive the same word angle.
        for start in (half - n, head_dim - n):
            cos[..., start:start + n] = word_cos
            sin[..., start:start + n] = word_sin

        return cos, sin
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Precomputes cos/sin tables of shape [seq_len, dim], laid out as
    ``cat([freqs, freqs])``, and grows the cache when a longer sequence is
    requested.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build the cos/sin tables for positions 0 .. seq_len - 1."""
        # Compute positions and angles in at least float32. After e.g.
        # module.to(torch.bfloat16), inv_freq is low precision, and promoting
        # integer positions to it would round positions above 256 (bf16),
        # giving wrong angles. Results are cast back to inv_freq's dtype so a
        # rebuilt cache keeps the same dtype as before.
        compute_dtype = torch.promote_types(self.inv_freq.dtype, torch.float32)
        inv_freq = self.inv_freq.to(compute_dtype)
        t = torch.arange(seq_len, device=inv_freq.device, dtype=compute_dtype)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        out_dtype = self.inv_freq.dtype
        self.register_buffer("cos_cached", emb.cos().to(out_dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(out_dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the first ``seq_len`` rows of the cos/sin tables (x is unused)."""
        if seq_len > self.cos_cached.size(0):
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the two halves (a, b) of the last dimension to (-b, a).

    This is the partner term RoPE multiplies by ``sin`` for the
    ``cat([freqs, freqs])`` cos/sin layout.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the given RoPE angles.

    Args:
        q, k: [B, H, L, D] query and key tensors.
        cos, sin: either [L, D] (shared across the batch) or [B, L, D]
            (per-sample). A 3-D table gets a singleton head axis inserted at
            dim 1 so it broadcasts over H.

    Returns:
        (rotated q, rotated k), with the same shapes as q and k.
    """
    if cos.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return t * cos + rotate_half(t) * sin

    return _rotate(q), _rotate(k)
|
|
|
|
|
class CausalAttention(nn.Module):
    """Multi-head attention with causal mask, RoPE, and optional GQA.

    Supports Grouped Query Attention (GQA) where num_kv_heads < num_heads.
    Each KV head serves (num_heads // num_kv_heads) query heads.
    KV cache stored at kv_heads granularity for memory efficiency.

    Optional features:
    - window_size: sliding-window attention; the query at absolute position p
      attends only to keys at positions in (p - window_size, p].
    - word_rope_dims > 0: a WordPositionRoPE subspace of each head is driven by
      the ``word_positions`` argument of ``forward`` when it is supplied.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = hidden_size // num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.dropout = dropout
        self.window_size = window_size

        # forward() views num_heads * head_dim back as hidden_size, so the
        # split must be exact; fail here rather than deep inside forward.
        assert hidden_size % num_heads == 0, \
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
        if word_rope_dims > 0:
            assert word_rope_dims <= self.head_dim, \
                f"word_rope_dims ({word_rope_dims}) must be <= head_dim ({self.head_dim})"
            assert word_rope_dims % 2 == 0, \
                f"word_rope_dims ({word_rope_dims}) must be even"

        self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim, max_seq_len)

        self.word_rope = WordPositionRoPE(word_rope_dims, word_rope_base) if word_rope_dims > 0 else None

        # Precomputed 0/1 mask (1 = may attend): lower-triangular, additionally
        # banded to the last window_size keys when a window is set.
        # _attention_mask builds the same mask on the fly for longer sequences.
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
        if window_size is not None:
            band = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=-(window_size - 1))
            mask = mask * band
        self.register_buffer(
            "causal_mask",
            mask.view(1, 1, max_seq_len, max_seq_len),
            persistent=False,
        )

    def _expand_kv(self, kv: torch.Tensor) -> torch.Tensor:
        """Expand KV heads to match Q heads for GQA. No-op if num_kv_heads == num_heads."""
        if self.num_kv_groups == 1:
            return kv
        B, H_kv, L, D = kv.shape
        return kv.unsqueeze(2).expand(B, H_kv, self.num_kv_groups, L, D).reshape(B, self.num_heads, L, D)

    def _attention_mask(
        self, offset: int, q_len: int, kv_len: int, device: torch.device
    ) -> torch.Tensor:
        """Return a bool [q_len, kv_len] mask, True where attention is allowed.

        Query i sits at absolute position offset + i and key j at position j.
        The precomputed buffer is used when it is large enough; otherwise the
        mask is built on the fly, so sequences longer than max_seq_len still get
        causal (and window) masking.
        """
        size = self.causal_mask.size(-1)
        if offset + q_len <= size and kv_len <= size:
            return self.causal_mask[0, 0, offset:offset + q_len, :kv_len] != 0
        q_pos = torch.arange(offset, offset + q_len, device=device).unsqueeze(-1)
        k_pos = torch.arange(kv_len, device=device).unsqueeze(0)
        allowed = k_pos <= q_pos
        if self.window_size is not None:
            allowed = allowed & (k_pos > q_pos - self.window_size)
        return allowed

    def forward(
        self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        """Attend from the new tokens in x over cached plus new keys/values.

        Args:
            x: [B, L, hidden_size] input for the new positions.
            use_cache: if True, also return the concatenated (k, v) for reuse.
            past_kv: optional (k, v) from a previous call, each shaped
                [B, num_kv_heads, offset, head_dim]; the new tokens occupy
                absolute positions offset .. offset + L - 1.
            word_positions: optional [B, L] position-within-word; only used
                when the module was built with word_rope_dims > 0.

        Returns:
            (output [B, L, hidden_size], (k, v) if use_cache else None)
        """
        B, L, _ = x.shape

        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE angles for the absolute positions of the new tokens.
        offset = past_kv[0].size(2) if past_kv is not None else 0
        cos, sin = self.rotary(x, offset + L)
        cos = cos[offset:offset + L]
        sin = sin[offset:offset + L]

        if self.word_rope is not None and word_positions is not None:
            cos, sin = self.word_rope(cos, sin, word_positions)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        new_kv = (k, v) if use_cache else None

        dropout_p = self.dropout if self.training else 0.0
        use_gqa = self.num_kv_groups > 1
        kv_len = k.size(2)

        if self.window_size is not None:
            # Explicit path: SDPA's is_causal cannot express a sliding window.
            k_expanded = self._expand_kv(k)
            v_expanded = self._expand_kv(v)
            attn = torch.matmul(q, k_expanded.transpose(-2, -1)) / math.sqrt(self.head_dim)
            mask = self._attention_mask(offset, L, kv_len, x.device)
            attn = attn.masked_fill(~mask, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            if dropout_p > 0:
                attn = F.dropout(attn, p=dropout_p)
            out = torch.matmul(attn, v_expanded)
        else:
            sdpa_kwargs = {"dropout_p": dropout_p}
            if use_gqa:
                # Only pass enable_gqa when needed, so MHA also runs on PyTorch
                # versions whose SDPA lacks this argument.
                sdpa_kwargs["enable_gqa"] = True
            if past_kv is None:
                # Fresh sequence: q_len == kv_len, so SDPA's built-in causal
                # mask is exact. A single token needs no mask.
                sdpa_kwargs["is_causal"] = L > 1
            elif L > 1:
                # Several new tokens on top of a cache: they must not see later
                # new tokens. is_causal aligns its mask to the top-left corner
                # (wrong when kv_len > q_len), so pass an offset-aware one.
                sdpa_kwargs["attn_mask"] = self._attention_mask(offset, L, kv_len, x.device)
            out = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)

        out = out.transpose(1, 2).contiguous().view(B, L, self.hidden_size)

        return self.o_proj(out), new_kv
|
|
|
|
|
| class SwiGLU(nn.Module):
|
| """SwiGLU feed-forward network."""
|
|
|
| def __init__(self, hidden_size: int, intermediate_size: int | None = None):
|
| super().__init__()
|
| intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
|
| intermediate_size = ((intermediate_size + 63) // 64) * 64
|
|
|
| self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
|
|