# Author: DigitalDaimyo
# Uploaded as training.py with huggingface_hub (commit 8f95e75, verified)
"""
Addressed State Attention (ASA) - Training Harness
Efficient implementation optimized for language model training.
For mechanistic analysis and interventions, use asm_analysis.py instead.
Repository: https://github.com/DigitalDaimyo/AddressedStateAttention
Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/paper_drafts
"""
import math
from dataclasses import dataclass
from typing import Optional, Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# Names exported by ``from <this module> import *``.
__all__ = [
    "AddressedStateAttention",
    "ASMBlock",
    "ASMLanguageModel",
    "ASMTrainConfig",
    "build_model_from_cfg",
]
# -------------------------
# RoPE helper (rotate-half)
# -------------------------
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
class RotaryEmbedding(nn.Module):
    """
    Cache of RoPE cos/sin tables for a feature dimension ``dim``.

    ``get_cos_sin(T, device, dtype)`` returns two ``[1, 1, T, dim]`` tensors.
    Along the last axis the angle at position ``t`` and feature ``j`` is
    ``t * inv_freq[j % (dim // 2)]`` (layout ``cat([freqs, freqs])``).

    Tables are always computed in float32 and only cast to ``dtype`` on
    return. Therefore a low-precision module cast (``.half()`` /
    ``.bfloat16()``, which also converts the ``inv_freq`` buffer) does not
    round the integer positions. For example, bf16 cannot represent 299.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        assert dim % 2 == 0, "RoPE requires even dim"
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # float32 tables for positions [0, _t_cached) on _device_cached.
        self._cos_cached = None
        self._sin_cached = None
        self._t_cached = None
        self._device_cached = None

    def get_cos_sin(self, T: int, device, dtype):
        """
        Return ``(cos, sin)``, each ``[1, 1, T, dim]`` and cast to ``dtype``.

        The cache is rebuilt only when the device changes or ``T`` exceeds the
        cached length; shorter requests reuse a prefix slice of the cache.
        """
        cache_hit = (
            self._cos_cached is not None
            and self._device_cached == device
            and self._t_cached >= T
        )
        if not cache_hit:
            inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
            t = torch.arange(T, device=device, dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)  # [T, d/2]
            emb = torch.cat([freqs, freqs], dim=-1)  # [T, d]
            self._cos_cached = emb.cos()[None, None, :, :]  # [1,1,T,d]
            self._sin_cached = emb.sin()[None, None, :, :]  # [1,1,T,d]
            self._t_cached = T
            self._device_cached = device
        cos = self._cos_cached[:, :, :T, :]
        sin = self._sin_cached[:, :, :T, :]
        return cos.to(dtype=dtype), sin.to(dtype=dtype)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary position embedding: ``x * cos + _rotate_half(x) * sin``.

    ``cos`` and ``sin`` must broadcast against ``x``. For example, the
    ``[1, 1, T, dim]`` tables from ``RotaryEmbedding.get_cos_sin`` broadcast
    against ``[B, H, T, dim]`` features.
    """
    cos_term = x * cos
    sin_term = _rotate_half(x) * sin
    return cos_term + sin_term
# -------------------------
# ALiBi slopes helper
# -------------------------
def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """
    Return the per-head ALiBi slopes as a ``[num_heads]`` tensor.

    For a power-of-two ``n`` the slopes are the geometric sequence
    ``r, r**2, ..., r**n`` with ``r = 2 ** (-8 / n)``. For any other ``n``,
    the slopes of the largest power of two below ``n`` are used first. They
    are then topped up with every other slope of the next power of two.
    """

    def _geometric(n):
        # 2 ** -(2 ** -(log2(n) - 3)) == 2 ** (-8 / n)
        first = 2.0 ** (-(2.0 ** -(math.log2(n) - 3)))
        return [first * first ** i for i in range(n)]

    def _slopes(n):
        log_n = math.log2(n)
        if log_n.is_integer():
            return _geometric(n)
        lower_pow2 = 2 ** math.floor(log_n)
        extra = _slopes(2 * lower_pow2)[0::2][: n - lower_pow2]
        return _geometric(lower_pow2) + extra

    return torch.tensor(_slopes(num_heads), device=device, dtype=dtype)
def _inv_softplus(y: torch.Tensor) -> torch.Tensor:
return torch.log(torch.expm1(y))
class AddressedStateAttention(nn.Module):
"""
ASA with integral slotspace refine fused into the compiled chunk kernel.
Fixes included:
(1) pad slotspace RoPE cos/sin to CH (identity on padded positions)
(2) build valid_mask_c even when attention_mask is None (padding-only)
(3) pad write logits with -inf (so padded positions contribute zero to scan)
"""
def __init__(
self,
embed_dim: int,
num_heads: int = 12,
num_slots: int = 16,
dropout: float = 0.1,
# temps / numerics
read_temperature: float = 1.0,
write_temperature: float = 1.0,
state_fp32: bool = True,
slot_dropout: float = 0.0,
normalize_k: bool = False,
# write geometry
use_rope_keys: bool = True,
rope_base: float = 10000.0,
# write bias
use_alibi_write: bool = True,
alibi_strength_init: float = 0.1,
learn_alibi_strength: bool = True,
min_strength: float = 0.0,
# content read gamma
use_content_read: bool = True,
content_read_init: float = -4.0,
content_read_max_gamma: float = 3.0,
# slotspace refine (INTEGRAL)
use_slotspace_refine: bool = True, # compat only
slotspace_dim: int = 8,
slotspace_gate_init: float = -4.0,
slotspace_dropout: float = 0.05,
slotspace_signed_weights: bool = True,
# slotspace RoPE (Q/K only)
use_rope_slotspace: bool = True,
rope_base_slotspace: float = 100000.0,
# perf
write_chunk_size: int = 1024,
enable_compiled: bool = True,
):
super().__init__()
assert embed_dim % num_heads == 0
assert (slotspace_dim % 2) == 0, "slotspace_dim must be even if RoPE enabled"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_slots = num_slots
self.head_dim = embed_dim // num_heads
self.dropout = nn.Dropout(dropout)
self.read_temperature = float(read_temperature)
self.write_temperature = float(write_temperature)
self.state_fp32 = bool(state_fp32)
self.slot_dropout = float(slot_dropout)
self.normalize_k = bool(normalize_k)
self.use_rope_keys = bool(use_rope_keys)
self.use_alibi_write = bool(use_alibi_write)
self.learn_alibi_strength = bool(learn_alibi_strength)
self.min_strength = float(min_strength)
self.use_content_read = bool(use_content_read)
self.content_read_max_gamma = float(content_read_max_gamma)
self.slotspace_dim = int(slotspace_dim)
self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
self.slotspace_signed_weights = bool(slotspace_signed_weights)
self.use_rope_slotspace = bool(use_rope_slotspace)
self.write_chunk_size = int(write_chunk_size)
H, K, d = self.num_heads, self.num_slots, self.head_dim
M = self.slotspace_dim
self.slot_keys = nn.Parameter(torch.randn(H, K, d) / math.sqrt(d))
self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
self.Wq_read = nn.Linear(embed_dim, embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.rope = RotaryEmbedding(d, base=rope_base) if self.use_rope_keys else None
if self.use_alibi_write:
self.register_buffer("_alibi_slopes", alibi_slopes(H), persistent=False)
else:
self.register_buffer("_alibi_slopes", torch.zeros(H), persistent=False)
if self.use_alibi_write and self.learn_alibi_strength:
init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
else:
self._alibi_strength_param = None
self.alibi_strength = float(alibi_strength_init)
if self.use_content_read:
self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
else:
self._content_read_gamma_raw = None
self.slot_in = nn.Linear(K, M, bias=False)
self.slot_q = nn.Linear(M, M, bias=False)
self.slot_k = nn.Linear(M, M, bias=False)
self.slot_v = nn.Linear(M, M, bias=False)
self.slot_out = nn.Linear(M, K, bias=False)
self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))
self.rope_slotspace = RotaryEmbedding(M, base=float(rope_base_slotspace)) if self.use_rope_slotspace else None
self._compiled = None
if enable_compiled:
self.enable_compiled_kernel()
def enable_compiled_kernel(self):
if self._compiled is None:
self._compiled = torch.compile(self._asa_chunk_fused, dynamic=False, fullgraph=False)
def _alibi_strength(self, dtype, device) -> torch.Tensor:
if not (self.use_alibi_write and self.learn_alibi_strength):
return torch.tensor(getattr(self, "alibi_strength", 0.0), dtype=dtype, device=device)
return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device)
def _content_read_gamma(self, dtype, device) -> torch.Tensor:
if not self.use_content_read:
return torch.tensor(0.0, dtype=dtype, device=device)
g = F.softplus(self._content_read_gamma_raw)
if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0:
g = g.clamp(max=self.content_read_max_gamma)
return g.to(dtype=dtype, device=device)
def _slotspace_gate(self, dtype, device) -> torch.Tensor:
return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device)
@staticmethod
def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
diff = s - m
diff = diff.masked_fill(~torch.isfinite(m), float("-inf"))
return torch.exp(diff)
@staticmethod
def _phi(x: torch.Tensor) -> torch.Tensor:
return F.elu(x) + 1.0
@staticmethod
def _pad_time_slice(x: torch.Tensor, t0: int, L: int, CH: int, dim: int):
sl = x.narrow(dim, t0, L)
if L == CH:
return sl, None
pad_shape = list(sl.shape)
pad_shape[dim] = CH - L
pad = torch.zeros(pad_shape, device=sl.device, dtype=sl.dtype)
xpad = torch.cat([sl, pad], dim=dim)
mask = torch.zeros((CH,), device=sl.device, dtype=torch.bool)
mask[:L] = True
return xpad, mask
def _asa_chunk_fused(
self,
wlog_c: torch.Tensor, # [B,H,K,CH]
v_c: torch.Tensor, # [B,H,CH,d]
q_c: torch.Tensor, # [B,H,CH,d]
slot_keys_dk: torch.Tensor, # [1,H,d,K]
pos_cos_s: Optional[torch.Tensor], # [1,1,CH,M] or None
pos_sin_s: Optional[torch.Tensor], # [1,1,CH,M] or None
content_gamma: torch.Tensor,
rtemp_t: torch.Tensor,
gate_t: torch.Tensor,
m_state: torch.Tensor, # [B,H,K]
denom_state: torch.Tensor, # [B,H,K]
numer_state: torch.Tensor, # [B,H,K,d]
S_state: torch.Tensor, # [B,H,M,M]
Z_state: torch.Tensor, # [B,H,M]
valid_mask_c: Optional[torch.Tensor], # [B,1,CH,1] or None
do_dropout: bool,
dropout_p: float,
signed_slot_w: bool,
):
B, H, K, CH = wlog_c.shape
d = numer_state.shape[-1]
M = S_state.shape[-1]
inv_sqrt_d = 1.0 / math.sqrt(d)
# ----- WRITE prefix-softmax scan -----
m_c, _ = torch.cummax(wlog_c, dim=-1) # [B,H,K,CH]
m_new = torch.maximum(m_state.unsqueeze(-1), m_c) # [B,H,K,CH]
scale = torch.exp(m_state.unsqueeze(-1) - m_new) # [B,H,K,CH]
denom_c = denom_state.unsqueeze(-1) * scale # [B,H,K,CH]
numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1) # [B,H,K,CH,d]
w_new = self._safe_exp_sub_max(wlog_c, m_new) # [B,H,K,CH]
denom_c = denom_c + torch.cumsum(w_new, dim=-1) # [B,H,K,CH]
numer_c = numer_c + torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2) # [B,H,K,CH,d]
# ----- Routing logits -----
read_logits_key = torch.matmul(q_c, slot_keys_dk) * inv_sqrt_d # [B,H,CH,K]
if self.use_content_read:
numer_for_dot = numer_c.to(q_c.dtype).permute(0, 1, 3, 2, 4) # [B,H,CH,K,d]
denom_for_div = denom_c.to(q_c.dtype).permute(0, 1, 3, 2) # [B,H,CH,K]
read_logits_content = (q_c.unsqueeze(-2) * numer_for_dot).sum(dim=-1) * inv_sqrt_d
read_logits_content = read_logits_content / denom_for_div.clamp_min(1e-8)
read_logits = read_logits_key + content_gamma.to(read_logits_key.dtype) * read_logits_content
else:
read_logits = read_logits_key
read_w = torch.softmax(read_logits / rtemp_t, dim=-1) # [B,H,CH,K]
# ----- EXACT base output -----
inv_denom = (1.0 / denom_c.clamp_min(1e-8)).to(numer_c.dtype) # [B,H,K,CH]
w_scaled = read_w.to(numer_c.dtype).permute(0, 1, 3, 2) * inv_denom # [B,H,K,CH]
out_base = (w_scaled.unsqueeze(-1) * numer_c).sum(dim=2) # [B,H,CH,d]
# ----- Slotspace refine -----
u = self.slot_in(read_w.to(out_base.dtype)) # [B,H,CH,M]
q_s = self.slot_q(u)
k_s = self.slot_k(u)
v_s = self.slot_v(u)
if self.use_rope_slotspace and (pos_cos_s is not None) and (pos_sin_s is not None):
q_s = apply_rope(q_s, pos_cos_s, pos_sin_s)
k_s = apply_rope(k_s, pos_cos_s, pos_sin_s)
if valid_mask_c is not None:
q_s = q_s * valid_mask_c
k_s = k_s * valid_mask_c
v_s = v_s * valid_mask_c
qf = self._phi(q_s)
kf = self._phi(k_s)
kv = kf.unsqueeze(-1) * v_s.unsqueeze(-2) # [B,H,CH,M,M]
S_c = torch.cumsum(kv, dim=2) + S_state.unsqueeze(2) # [B,H,CH,M,M]
Z_c = torch.cumsum(kf, dim=2) + Z_state.unsqueeze(2) # [B,H,CH,M]
Z_c = Z_c.clamp_min(1e-8)
num = torch.matmul(qf.unsqueeze(-2), S_c).squeeze(-2) # [B,H,CH,M]
den = (qf * Z_c).sum(dim=-1, keepdim=True).clamp_min(1e-8) # [B,H,CH,1]
u2 = num / den # [B,H,CH,M]
S_state_new = S_c[:, :, -1, :, :]
Z_state_new = Z_c[:, :, -1, :]
if do_dropout and dropout_p > 0.0:
keep = (torch.rand_like(u2) > dropout_p).to(u2.dtype) / (1.0 - dropout_p)
u2 = u2 * keep
slot_w = self.slot_out(u2) # [B,H,CH,K]
if signed_slot_w:
slot_w = torch.tanh(slot_w)
else:
slot_w = torch.softmax(slot_w, dim=-1)
slot_w_scaled = slot_w.to(numer_c.dtype).permute(0, 1, 3, 2) * inv_denom
delta = (slot_w_scaled.unsqueeze(-1) * numer_c).sum(dim=2) # [B,H,CH,d]
out = out_base + gate_t.to(out_base.dtype) * delta
m_state_new = m_new[:, :, :, -1]
denom_state_new = denom_c[:, :, :, -1]
numer_state_new = numer_c[:, :, :, -1, :]
return out, read_w, m_state_new, denom_state_new, numer_state_new, S_state_new, Z_state_new
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
return_info: bool = False,
return_light_stats: bool = False,
) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
B, T, C = x.shape
H, K, d = self.num_heads, self.num_slots, self.head_dim
M = self.slotspace_dim
k_write = self.Wk_write(x).reshape(B, T, H, d).transpose(1, 2) # [B,H,T,d]
v_write = self.Wv_write(x).reshape(B, T, H, d).transpose(1, 2) # [B,H,T,d]
q_read = self.Wq_read(x).reshape(B, T, H, d).transpose(1, 2) # [B,H,T,d]
if self.normalize_k:
k_write = F.normalize(k_write, dim=-1, eps=1e-8)
if self.use_rope_keys:
cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype)
k_write = apply_rope(k_write, cos, sin)
slot_keys = self.slot_keys
if self.training and self.slot_dropout > 0.0:
drop = (torch.rand((H, K), device=x.device) < self.slot_dropout)
slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1)
slot_keys_dk = slot_keys.transpose(-1, -2).unsqueeze(0).to(q_read.dtype) # [1,H,d,K]
write_logits_raw = torch.matmul(k_write.to(q_read.dtype), slot_keys_dk).permute(0, 1, 3, 2) / math.sqrt(d)
state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype
write_logits = write_logits_raw.to(state_dtype)
wtemp = max(1e-6, self.write_temperature)
write_logits = write_logits / wtemp
if self.use_alibi_write:
strength = self._alibi_strength(dtype=state_dtype, device=x.device)
slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength
pos = torch.arange(T, device=x.device, dtype=state_dtype)
write_logits = write_logits + slopes.view(1, H, 1, 1) * pos.view(1, 1, 1, T)
valid = None
if attention_mask is not None:
valid = attention_mask.to(dtype=torch.bool)
write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf"))
content_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device)
rtemp_t = torch.tensor(max(1e-6, self.read_temperature), device=x.device, dtype=q_read.dtype)
gate_t = self._slotspace_gate(dtype=state_dtype, device=x.device)
denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)
S_state = torch.zeros((B, H, M, M), device=x.device, dtype=state_dtype)
Z_state = torch.zeros((B, H, M), device=x.device, dtype=state_dtype)
out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype)
if self.use_rope_slotspace:
cos_s_full, sin_s_full = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=state_dtype)
else:
cos_s_full = sin_s_full = None
CH = self.write_chunk_size
kernel = self._compiled if self._compiled is not None else self._asa_chunk_fused
do_dropout = bool(self.training and self.slotspace_dropout.p > 0.0)
dropout_p = float(self.slotspace_dropout.p)
signed_slot_w = bool(self.slotspace_signed_weights)
for t0 in range(0, T, CH):
t1 = min(T, t0 + CH)
L = t1 - t0
wlog_c, mask = self._pad_time_slice(write_logits, t0, L, CH, dim=3) # [B,H,K,CH]
v_c, _ = self._pad_time_slice(v_write.to(state_dtype), t0, L, CH, dim=2) # [B,H,CH,d]
q_c, _ = self._pad_time_slice(q_read, t0, L, CH, dim=2) # [B,H,CH,d]
# (3) ensure padded write logits contribute zero mass
if mask is not None:
wlog_c = wlog_c.clone()
wlog_c[:, :, :, L:] = float("-inf")
# (2) build valid_mask_c even when attention_mask is None (padding-only)
valid_mask_c = None
if (valid is not None) or (mask is not None):
if valid is None:
vm_pad = mask.view(1, CH).expand(B, CH) # [B,CH]
else:
if mask is None:
vm_pad = valid[:, t0:t1]
else:
vm = valid[:, t0:t1]
vm_pad = torch.zeros((B, CH), device=x.device, dtype=torch.bool)
vm_pad[:, :L] = vm
valid_mask_c = vm_pad.view(B, 1, CH, 1).to(state_dtype)
# (1) slotspace RoPE slice PADDED TO CH (identity on padded positions)
if self.use_rope_slotspace:
cos_slice = cos_s_full[:, :, t0:t1, :] # [1,1,L,M]
sin_slice = sin_s_full[:, :, t0:t1, :] # [1,1,L,M]
if L == CH:
cos_s, sin_s = cos_slice, sin_slice
else:
cos_s = torch.ones((1, 1, CH, M), device=x.device, dtype=state_dtype)
sin_s = torch.zeros((1, 1, CH, M), device=x.device, dtype=state_dtype)
cos_s[:, :, :L, :] = cos_slice
sin_s[:, :, :L, :] = sin_slice
else:
cos_s = sin_s = None
out_c, read_w_c, m_state, denom_state, numer_state, S_state, Z_state = kernel(
wlog_c, v_c, q_c, slot_keys_dk,
cos_s, sin_s,
content_gamma, rtemp_t, gate_t,
m_state, denom_state, numer_state,
S_state, Z_state,
valid_mask_c,
do_dropout, dropout_p,
signed_slot_w,
)
if mask is not None:
out_c = out_c * mask.view(1, 1, CH, 1).to(out_c.dtype)
out_h[:, :, t0:t1, :] = out_c[:, :, :L, :]
out = out_h.transpose(1, 2).reshape(B, T, C)
out = self.out_proj(out)
out = self.dropout(out)
info = None
if return_info or return_light_stats:
info = {
"content_read_gamma": content_gamma.detach().to(torch.float32).cpu(),
"slotspace_gate": gate_t.detach().to(torch.float32).cpu(),
}
return out, info
# ============================================================================
# Addressed State Models (ASM): Config + Block + LM
# - Naming aligned with paper: slots, read/write, slot-space refinement
# - No compatibility layer (fresh public tooling)
# ============================================================================
# ============================================================================
# Config
# ============================================================================
@dataclass
class ASMTrainConfig:
    """
    Flat configuration for an ASM language-model training run.

    ``build_model_from_cfg`` forwards the model-related fields (size, ASA
    numerics, positions, content read, slot-space refine, perf) one-to-one.
    The data, optimisation, analytics and IO fields are not used by it.
    """

    # ---- dataset / tokenizer / batching ----
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "gpt2"
    max_seq_len: int = 256
    stride_frac_val: float = 0.5
    seed: int = 1337
    micro_batch_size: int = 2
    grad_accum_steps: int = 8

    # ---- sample budgets ----
    train_samples_target: int = 100_000_000
    val_samples_target: int = 25_000

    # ---- optimisation schedule ----
    # NOTE(review): micro_batch_size * grad_accum_steps == 16, which differs
    # from batch_size == 64; confirm which value the training loop honours.
    batch_size: int = 64
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    grad_clip: float = 1.0
    warmup_steps: int = 1_000
    total_steps: int = 75_000
    eval_interval: int = 1_000
    log_interval: int = 100

    # ---- model size ----
    vocab_size: int = 50_257
    embed_dim: int = 384
    num_layers: int = 23
    num_heads: int = 8
    num_slots: int = 32
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    tie_weights: bool = True

    # ---- ASA numerics ----
    read_temperature: float = 1.0
    write_temperature: float = 1.0
    slot_dropout: float = 0.05
    state_fp32: bool = True
    normalize_k: bool = False

    # ---- positional schemes ----
    use_abs_pos: bool = False
    use_rope_keys: bool = True
    rope_base: float = 10_000.0
    use_alibi_write: bool = True
    alibi_strength_init: float = 0.1
    learn_alibi_strength: bool = True
    min_strength: float = 0.0

    # ---- content-conditioned read weight (gamma) ----
    use_content_read: bool = True
    content_read_init: float = -4.0
    content_read_max_gamma: float = 3.0

    # ---- slot-space refinement (formerly "k-space") ----
    # use_slotspace_refine is accepted for compatibility; the attention layer
    # does not read it.
    use_slotspace_refine: bool = True
    slotspace_dim: int = 64
    slotspace_gate_init: float = -4.0
    slotspace_dropout: float = 0.05
    slotspace_signed_weights: bool = True

    # ---- RoPE inside the slot-space matcher (Q/K only) ----
    use_rope_slotspace: bool = True
    rope_base_slotspace: float = 100_000.0

    # ---- performance knobs ----
    write_chunk_size: int = 128
    enable_compiled: bool = True

    # ---- analytics ----
    eval_max_batches: int = 150
    analytics_last_k: int = 4

    # ---- IO / caches (paths look like a Colab Google-Drive mount; adjust as needed) ----
    output_dir: str = "./drive/MyDrive/asm_outputs"
    tag: str = "asm_wikitext"
    cache_dir: str = "./drive/MyDrive/asm_caches/fineweb/1B"
    val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"
# ============================================================================
# Block
# ============================================================================
class ASMBlock(nn.Module):
    """
    Pre-norm residual block with Addressed State Attention instead of
    self-attention: ``x + ASA(LayerNorm(x))``, then ``x + MLP(LayerNorm(x))``.

    Every ASA-related keyword argument is forwarded unchanged to
    ``AddressedStateAttention``. The MLP is Linear -> GELU -> Dropout ->
    Linear -> Dropout with hidden width ``int(embed_dim * mlp_ratio)`` and no
    biases.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_slots: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,
        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        # ALiBi params
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        # optional slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        # chunk sizes
        write_chunk_size: int = 128,
        enable_compiled: bool = False,
    ):
        super().__init__()
        attn_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            enable_compiled=enable_compiled,
        )
        # Registration order (norm1, asa, norm2, mlp) is kept stable so that
        # parameter initialisation consumes the RNG in the same sequence.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.asa = AddressedStateAttention(**attn_kwargs)
        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        mlp_layers = [
            nn.Linear(embed_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim, bias=False),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, return_info: bool = False, return_light_stats: Optional[bool] = None):
        """Return ``(output, info)``, where ``info`` is whatever the ASA layer reports."""
        attn_out, info = self.asa(
            self.norm1(x),
            attention_mask=attention_mask,
            return_info=return_info,
            return_light_stats=return_light_stats,
        )
        hidden = x + attn_out
        output = hidden + self.mlp(self.norm2(hidden))
        return output, info
# ============================================================================
# LM
# ============================================================================
class ASMLanguageModel(nn.Module):
    """
    Decoder-style language model built from a stack of ``ASMBlock`` layers.

    Pipeline: token embedding (+ optional learned absolute positions) ->
    dropout -> ``num_layers`` ASM blocks -> final LayerNorm -> LM head. When
    ``tie_weights`` is set, the LM head shares its weight tensor with the token
    embedding. Linear and embedding weights are initialised from N(0, 0.02);
    LayerNorms are initialised to identity.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 384,
        num_layers: int = 6,
        num_heads: int = 8,
        num_slots: int = 8,
        max_seq_len: int = 1024,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.05,
        normalize_k: bool = False,
        tie_weights: bool = True,
        # LM-level abs pos
        use_abs_pos: bool = False,
        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        # ALiBi
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        # optional slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        # chunk sizes
        write_chunk_size: int = 128,
        enable_compiled: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.use_abs_pos = bool(use_abs_pos)

        self.tok = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
        self.drop = nn.Dropout(dropout)

        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            enable_compiled=enable_compiled,
        )
        self.blocks = nn.ModuleList([ASMBlock(**block_kwargs) for _ in range(num_layers)])

        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            self.lm_head.weight = self.tok.weight
        self.apply(self._init)

    def _init(self, m):
        """Per-module initialiser passed to ``nn.Module.apply``."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        return_light_stats: Optional[bool] = None,
    ):
        """
        Map ``[B, T]`` token ids to ``[B, T, vocab_size]`` logits.

        Returns ``logits``, or ``(logits, infos)`` when ``return_info`` is set,
        where ``infos`` holds one entry per block. Raises ``AssertionError``
        when ``T > max_seq_len``.
        """
        B, T = input_ids.shape
        assert T <= self.max_seq_len, f"T={T} exceeds max_seq_len={self.max_seq_len}"

        hidden = self.tok(input_ids)
        if self.use_abs_pos:
            positions = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
            hidden = hidden + self.pos(positions)
        hidden = self.drop(hidden)

        infos = []
        for block in self.blocks:
            hidden, block_info = block(
                hidden,
                attention_mask=attention_mask,
                return_info=return_info,
                return_light_stats=return_light_stats,
            )
            if return_info:
                infos.append(block_info)

        logits = self.lm_head(self.norm(hidden))
        if return_info:
            return logits, infos
        return logits
# ============================================================================
# Convenience: build model from config
# ============================================================================
def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
    """
    Instantiate an ``ASMLanguageModel`` from the model-related fields of ``cfg``.

    Each name below is both an ``ASMTrainConfig`` field and an
    ``ASMLanguageModel`` keyword argument, and is forwarded unchanged. Data,
    optimisation, analytics and IO fields are ignored.
    """
    model_field_names = (
        "vocab_size",
        "embed_dim",
        "num_layers",
        "num_heads",
        "num_slots",
        "max_seq_len",
        "mlp_ratio",
        "dropout",
        "read_temperature",
        "write_temperature",
        "state_fp32",
        "slot_dropout",
        "normalize_k",
        "tie_weights",
        "use_abs_pos",
        "use_rope_keys",
        "rope_base",
        "use_alibi_write",
        "alibi_strength_init",
        "learn_alibi_strength",
        "min_strength",
        "use_content_read",
        "content_read_init",
        "content_read_max_gamma",
        "use_slotspace_refine",
        "slotspace_dim",
        "slotspace_gate_init",
        "slotspace_dropout",
        "slotspace_signed_weights",
        "use_rope_slotspace",
        "rope_base_slotspace",
        "write_chunk_size",
        "enable_compiled",
    )
    model_kwargs = {name: getattr(cfg, name) for name in model_field_names}
    return ASMLanguageModel(**model_kwargs)