| """ |
| MonoidForCausalLM โ Causal Monoid Language Model (HuggingFace Compatible) |
| MonoidForCausalLM โ ๅนบๅ็พคๅ ๆ่ฏญ่จๆจกๅ (ๅ
ผๅฎน HuggingFace) |
| |
| Architecture / ๆถๆๆฆ่ฆ: |
| Replace softmax attention with a monoid parallel-scan recurrence. |
| ็จๅนบๅ็พคๅนถ่กๆซๆ้ๆจๆฟไปฃ softmax ๆณจๆๅใ |
| |
| Core idea / ๆ ธๅฟๆๆณ: |
| Softmax attention computes o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i |
| โ requires O(T) KV-cache per layer at inference. |
| Softmax ๆณจๆๅ่ฎก็ฎ o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i |
| โ ๆจ็ๆถๆฏๅฑ้่ฆ O(T) ็ KV ็ผๅญใ |
| |
| Monoid attention compresses the entire causal history into a |
| fixed-size state matrix S_t โ โ^{dรd} per head: |
| S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t (explicit causal recurrence) |
| o_t = q_t ยท S_t (state readout) |
| ๅนบๅ็พคๆณจๆๅๅฐๅฎๆดๅ ๆๅๅฒๅ็ผฉๅฐๆฏไธชๅคดไธไธชๅบๅฎๅคงๅฐ็็ถๆ็ฉ้ต S_t: |
| S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t (ๆพๅผๅ ๆ้ๆจ) |
| o_t = q_t ยท S_t (็ถๆ่ฏปๅบ) |
| |
| This is a monoid because the binary operator: |
| (log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, exp(log_ฮฒ)ยทS + X) |
| is associative โ enables parallel prefix scan for training, |
| and O(1) sequential update for inference. |
| ่ฟๆฏไธไธชๅนบๅ็พค๏ผๅ ไธบไบๅ
็ฎๅญ: |
| (log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, exp(log_ฮฒ)ยทS + X) |
| ๆปก่ถณ็ปๅๅพ โ ่ฎญ็ปๆถๅฏ็จๅนถ่กๅ็ผๆซๆ๏ผๆจ็ๆถ O(1) ้ๆญฅ้ๆจใ |
| |
| Key properties / ๅ
ณ้ฎ็นๆง: |
| โ Explicit causal modeling โ ฮฑ_t gate explicitly controls how fast |
| past information decays, making causality a first-class citizen. |
| ๆพๅผๅ ๆๅปบๆจก โ ฮฑ_t ่กฐๅ้จๆพๅผๆงๅถๅๅฒไฟกๆฏ็้ๅฟ้็๏ผ |
| ๅ ๆๆงๆฏไธ็ญๅ
ฌๆฐ่้้ mask ๆฝๅ ็็บฆๆใ |
| |
| โ Monoid state compression โ the full causal prefix x_{1:t} is |
| lossily compressed into a fixed-size (dรd) state matrix per head. |
| No O(T) KV-cache needed; inference is O(1) per token per layer. |
| ๅนบๅ็พค็ถๆๅ็ผฉ โ ๅฎๆดๅ ๆๅ็ผ x_{1:t} ่ขซๆๆๅ็ผฉๅฐๆฏไธชๅคด |
| ๅบๅฎๅคงๅฐ็ (dรd) ็ถๆ็ฉ้ตไธญใๆ ้ O(T) KV ็ผๅญ๏ผ |
| ๆจ็ๆถๆฏๅฑๆฏ token O(1)ใ |
| |
| โ Parallel training โ associativity of โ enables O(T) parallel |
| prefix scan (vs O(Tยฒ) for softmax attention). |
| ๅนถ่ก่ฎญ็ป โ โ ็็ปๅๅพไฝฟ O(T) ๅนถ่กๅ็ผๆซๆๆไธบๅฏ่ฝ |
| (ๅฏนๆฏ softmax ๆณจๆๅ็ O(Tยฒ))ใ |
| |
| Reuses LlamaMLP + LlamaRMSNorm from HuggingFace Transformers. |
| ๅค็จ HuggingFace Transformers ็ LlamaMLP + LlamaRMSNormใ |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
|
|
| from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm |
|
|
try:
    from monoid_scan_cuda import parallel_scan, parallel_scan_with_state
except ImportError:
    # Pure-PyTorch fallbacks, used whenever the fused CUDA kernel is not
    # installed. They realize the same recurrence
    #     S_t = exp(log_alpha_t) * S_{t-1} + kv_t
    # sequentially over the time axis, so results match the kernel exactly.

    def _decay_like(log_a_t: Tensor, like: Tensor) -> Tensor:
        """exp(log_a_t) padded with trailing singleton dims to broadcast over `like`."""
        gate = torch.exp(log_a_t)
        return gate.reshape(gate.shape + (1,) * (like.dim() - gate.dim()))

    def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
        """Prefix scan over time: S_t = exp(log_alpha_t)·S_{t-1} + kv_t.

        Args:
            log_alpha: [B, H, T, 1] log-space decay gates.
            kv:        [B, H, T, d, d] per-step outer products k_t ⊗ v_t.
        Returns:
            [B, H, T, d, d] — every intermediate state S_1..S_T.
        """
        batch, heads, steps, d1, d2 = kv.shape
        out = torch.empty_like(kv)
        S = kv.new_zeros(batch, heads, d1, d2)
        for t in range(steps):
            S = S * _decay_like(log_alpha[:, :, t], S) + kv[:, :, t]
            out[:, :, t] = S
        return out

    def parallel_scan_with_state(log_alpha: Tensor, kv: Tensor):
        """Like `parallel_scan`, but also returns the folded (log_decay_acc, S) pair."""
        batch, heads, steps, d1, d2 = kv.shape
        out = torch.empty_like(kv)
        S = kv.new_zeros(batch, heads, d1, d2)
        log_acc = log_alpha.new_zeros(batch, heads, 1)
        for t in range(steps):
            S = S * _decay_like(log_alpha[:, :, t], S) + kv[:, :, t]
            out[:, :, t] = S
            log_acc = log_acc + log_alpha[:, :, t]
        return out, (log_acc, S)
|
|
|
|
|
|
| |
| |
| |
|
|
class MonoidConfig(PretrainedConfig):
    """
    Configuration for the Monoid causal language model.

    Mirrors LlamaConfig for the shared components (MLP, RMSNorm, embedding)
    so that weights can be transferred directly from Llama checkpoints.

    Args:
        vocab_size: Vocabulary size of the token embedding / LM head.
        hidden_size: Residual-stream (model) width.
        intermediate_size: Hidden width of the LlamaMLP blocks.
        num_hidden_layers: Number of stacked decoder layers.
        num_attention_heads: Number of monoid attention heads.
        head_dim: Per-head dimension; each head keeps a head_dim × head_dim state.
        max_position_embeddings: Nominal context length (the recurrence itself
            uses no positional encoding).
        rms_norm_eps: Epsilon for all RMSNorm layers.
        hidden_act: Activation inside LlamaMLP.
        mlp_bias: Whether MLP linears carry a bias.
        attention_bias: Whether the q/k/v/o projections carry a bias.
        tie_word_embeddings: Tie lm_head weights to the input embedding.
        initializer_range: Std-dev for normal weight initialization.
        pad_token_id / bos_token_id / eos_token_id: Special token ids.
    """
    model_type = "monoid"

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 576,
        intermediate_size: int = 1536,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 9,
        head_dim: int = 64,
        max_position_embeddings: int = 2048,
        rms_norm_eps: float = 1e-5,
        hidden_act: str = "silu",
        mlp_bias: bool = False,
        attention_bias: bool = False,
        tie_word_embeddings: bool = True,
        initializer_range: float = 0.041666666666666664,
        pad_token_id: int | None = None,  # fixed: was annotated `int` with a None default
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.hidden_act = hidden_act
        self.mlp_bias = mlp_bias
        self.attention_bias = attention_bias
        self.initializer_range = initializer_range
|
|
|
|
| |
| |
| |
| |
|
|
| class MonoidCache: |
| """ |
| Per-layer monoid state cache for autoregressive inference. |
| ่ชๅๅฝๆจ็็้ๅฑๅนบๅ็พค็ถๆ็ผๅญใ |
| |
| Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory), |
| each layer here stores exactly ONE state tuple: |
| (log_decay_acc, S) where S โ โ^{B, H, d, d} |
| This is the monoid "sum" of all past (log_ฮฑ_i, k_iโv_i) via โ. |
| Memory is O(1) per layer regardless of sequence length. |
| |
| ไธๅไบ Transformer ็ KV-Cache (ๅญๅจๆๆ่ฟๅป็ key ๅ value, O(T) ๅ
ๅญ), |
| ่ฟ้ๆฏๅฑไป
ๅญๅจไธไธช็ถๆๅ
็ป: |
| (log_decay_acc, S) ๅ
ถไธญ S โ โ^{B, H, d, d} |
| ่ฟๆฏๆๆ่ฟๅป็ (log_ฮฑ_i, k_iโv_i) ้่ฟ โ ็ดฏ็งฏ็ๅนบๅ็พค "ๅ"ใ |
| ๆ ่ฎบๅบๅๅค้ฟ๏ผๆฏๅฑๅ
ๅญ O(1)ใ |
| """ |
|
|
| def __init__(self): |
| self.states: list[tuple[Tensor, Tensor] | None] = [] |
| self.seen_tokens: int = 0 |
|
|
| def get_seq_length(self, layer_idx: int = 0) -> int: |
| return self.seen_tokens |
|
|
| def update(self, layer_idx: int, state: tuple[Tensor, Tensor]): |
| """Store the accumulated monoid state for a given layer. |
| ๅญๅจๆๅฎๅฑ็็ดฏ็งฏๅนบๅ็พค็ถๆใ""" |
| while len(self.states) <= layer_idx: |
| self.states.append(None) |
| self.states[layer_idx] = state |
|
|
| def get_state(self, layer_idx: int) -> tuple[Tensor, Tensor] | None: |
| """Retrieve the accumulated monoid state for a given layer. |
| ่ทๅๆๅฎๅฑ็็ดฏ็งฏๅนบๅ็พค็ถๆใ""" |
| if layer_idx < len(self.states): |
| return self.states[layer_idx] |
| return None |
|
|
| def reorder_cache(self, beam_idx: torch.LongTensor): |
| """Reorder cache for beam search. ไธบ beam search ้ๆ็ผๅญใ""" |
| for i, state in enumerate(self.states): |
| if state is not None: |
| log_d, kv = state |
| self.states[i] = (log_d[beam_idx], kv[beam_idx]) |
|
|
|
|
| |
| |
| |
| |
|
|
def monoid_op(
    a: tuple[Tensor, Tensor],
    b: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
    """
    The associative binary operator ⊕ on (log-space decay, state matrix) pairs.

        (log_a, S) ⊕ (log_b, X) = (log_a + log_b, exp(log_b)·S + X)

    Why this is a monoid:
      • Associativity: (a ⊕ b) ⊕ c == a ⊕ (b ⊕ c), which lets training use
        a parallel prefix scan (reduce tree) and inference use an O(1)
        sequential left fold.
      • Identity element: e = (0, 0), since e ⊕ x = x ⊕ e = x.

    The decay factor lives in log-space so that alpha^T never underflows
    for long sequences. Causal semantics: the right operand's decay
    exp(log_b) shrinks the accumulated past state before the new
    contribution X is added.
    """
    log_a, state_a = a
    log_b, state_b = b

    # Broadcast exp(log_b) over the trailing state dims before blending.
    gate = torch.exp(log_b)
    gate = gate.reshape(gate.shape + (1,) * (state_a.dim() - gate.dim()))

    return log_a + log_b, state_a * gate + state_b
|
|
|
|
| |
| |
| |
| |
|
|
class MonoidAttention(nn.Module):
    """
    Monoid causal attention — a drop-in replacement for softmax attention.

    Differences from standard attention, as implemented below:
      - No RoPE / positional encoding: ordering is carried implicitly by
        the per-token decay gate alpha_t.
      - No KV-cache: inference keeps one (log_decay_acc, S) pair per layer
        in a MonoidCache, with S of shape [B, H, head_dim, head_dim].
      - No attention mask: causality is structural — S_t is built only
        from S_{t-1} and the current token.

    Computation:
        k_t = SiLU(k_norm(k_proj(x_t)))
        S_t = alpha_t * S_{t-1} + k_t ⊗ v_t   # recurrence
        o_t = q_t · S_t                       # state readout
    Training runs this via `parallel_scan` over the whole sequence;
    inference folds one token at a time into the cached state in O(1).

    A learned parameter `h0` provides the initial state S_0 for every
    sequence (decayed by the cumulative alpha product over time).
    """

    def __init__(self, config: MonoidConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        # Same 1/sqrt(d) scaling as softmax attention, folded into q.
        self.scaling = self.head_dim ** -0.5

        # Per-head q/k/v/o projections (no GQA: one k/v head per q head).
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)

        # One scalar decay gate per head per token. Its bias is set to 4.0
        # in MonoidPreTrainedModel._init_weights (sigmoid(4) ≈ 0.98, i.e.
        # slow decay / long memory at initialization).
        self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)

        # Per-head RMSNorm applied to q and k before use.
        self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Learned initial state S_0, broadcast across the batch.
        self.h0 = nn.Parameter(torch.zeros(1, self.num_heads, self.head_dim, self.head_dim))

    def forward(
        self,
        hidden_states: Tensor,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
        """
        Args:
            hidden_states: [B, T, hidden_size]
            monoid_cache: O(1) per-layer state cache for inference.
            use_cache: whether to read/update the cache.

        Returns:
            output: [B, T, hidden_size]
            final_state: (log_decay_acc, S) when use_cache, else None.
        """
        B, T, _ = hidden_states.shape
        H, d = self.num_heads, self.head_dim

        # Project and reshape to [B, H, T, d].
        q = self.q_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)

        # QK-norm; fold the 1/sqrt(d) scale into q once.
        q = self.q_norm(q) * self.scaling
        k = self.k_norm(k)

        # Gate the keys. NOTE(review): an earlier comment called SiLU keys
        # "non-negative", but SiLU dips slightly below zero for negative
        # inputs — confirm whether strict non-negativity (PSD state) is
        # actually required here.
        k = torch.nn.functional.silu(k)

        # Per-token, per-head decay gate alpha ∈ (0, 1), moved to
        # log-space with shape [B, H, T, 1]; the clamp avoids log(0).
        alpha = torch.sigmoid(self.decay_proj(hidden_states))
        alpha = alpha.transpose(1, 2).unsqueeze(-1)
        log_alpha = torch.log(alpha.clamp(min=1e-6))

        # ---- Path 1: single-token decode (O(1) state update). ----
        if use_cache and T == 1:
            # Rank-1 update k_t ⊗ v_t: [B, H, d, d].
            kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
            log_t = log_alpha[:, :, 0]

            prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
            if prev is None:
                # First token ever: seed from the learned initial state,
                # decayed by this token's gate.
                decay_t = torch.exp(log_t)
                while decay_t.dim() < self.h0.dim():
                    decay_t = decay_t.unsqueeze(-1)
                new_state = (log_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
            else:
                # Fold the new token into the running state via ⊕.
                new_state = monoid_op(prev, (log_t, kv_t))

            if monoid_cache is not None:
                monoid_cache.update(self.layer_idx, new_state)

            # Read out o = q_t · S_t, then merge heads back to [B, 1, H*d].
            o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
            o = o.contiguous().view(B, 1, -1)
            return self.o_proj(o), new_state

        # ---- Path 2: multi-token prefill with caching (sequential). ----
        # NOTE(review): this path always restarts from h0 and ignores any
        # state already present in monoid_cache, so it is only correct for
        # the FIRST chunk of a sequence — confirm callers never prefill
        # twice into the same cache.
        if use_cache:
            S = self.h0.expand(B, -1, -1, -1).clone()
            log_acc = torch.zeros(B, H, 1, device=hidden_states.device, dtype=q.dtype)
            o_parts = []
            for t in range(T):
                kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
                decay = torch.exp(log_alpha[:, :, t])
                while decay.dim() < S.dim():
                    decay = decay.unsqueeze(-1)
                S = S * decay + kv_t
                o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
                log_acc = log_acc + log_alpha[:, :, t]

            final_state = (log_acc, S)
            if monoid_cache is not None:
                monoid_cache.update(self.layer_idx, final_state)

            o = torch.stack(o_parts, dim=2)
            o = o.transpose(1, 2).contiguous().view(B, T, -1)
            return self.o_proj(o), final_state

        # ---- Path 3: training (parallel prefix scan, no cache). ----
        # All rank-1 updates at once: kv[b, h, t] = k_t ⊗ v_t.
        kv = torch.einsum('bhtd, bhte -> bhtde', k, v)

        # states[b, h, t] = Σ_{i≤t} (Π_{j=i+1..t} alpha_j) · kv_i.
        states = parallel_scan(log_alpha, kv)

        # Add the learned initial state decayed by the cumulative alpha
        # product — matches the S_0 = h0 seeding of the sequential paths.
        cum_log_alpha = torch.cumsum(log_alpha, dim=2)
        h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1)
        states = states + h0_decay * self.h0.unsqueeze(2)

        # Read out all positions at once, then merge heads.
        o = torch.einsum('bhtd, bhtde -> bhte', q, states)

        o = o.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(o), None
|
|
|
|
| |
| |
| |
| |
|
|
class MonoidDecoderLayer(nn.Module):
    """
    Pre-norm transformer block whose token mixer is MonoidAttention.

    Data flow:
        x → RMSNorm → MonoidAttention → +residual → RMSNorm → LlamaMLP → +residual

    The MLP and both norms are the stock Llama modules, so their weights
    can be copied straight from a Llama checkpoint; only the attention
    module is new.
    """
    gradient_checkpointing = False

    def __init__(self, config: MonoidConfig, layer_idx: int):
        super().__init__()
        self.self_attn = MonoidAttention(config, layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: Tensor,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> Tensor:
        # Attention sub-block (pre-norm + residual); the returned state
        # tuple is dropped here — the cache is updated in-place.
        attn_out, _ = self.self_attn(
            self.input_layernorm(hidden_states),
            monoid_cache=monoid_cache,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm + residual).
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
|
|
|
|
| |
| |
| |
| |
|
|
class MonoidPreTrainedModel(PreTrainedModel):
    """Base class wiring MonoidConfig into the HF PreTrainedModel machinery."""
    config_class = MonoidConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MonoidDecoderLayer"]

    def _init_weights(self, module: nn.Module):
        """Normal init for Linear/Embedding weights; positive bias for decay gates."""
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
        if isinstance(module, nn.Embedding) and module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
        if isinstance(module, MonoidAttention):
            # sigmoid(4.0) ≈ 0.98: start with slow decay (long memory).
            nn.init.constant_(module.decay_proj.bias, 4.0)
|
|
class MonoidModel(MonoidPreTrainedModel):
    """
    The bare monoid backbone: token embedding, a stack of
    MonoidDecoderLayers, and a final RMSNorm.

    Forward: embed_tokens → N × MonoidDecoderLayer → norm
    """

    def __init__(self, config: MonoidConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MonoidDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> BaseModelOutputWithPast:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Checkpointing only applies while training without a cache.
        checkpointing = self.gradient_checkpointing and self.training and not use_cache

        hidden_states = inputs_embeds
        for layer in self.layers:
            if checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    monoid_cache,
                    use_cache,
                )
            else:
                hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=monoid_cache,
        )
|
|
|
|
| |
| |
| |
| |
|
|
class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
    """
    Monoid-based causal language model with an LM head.

    The architecture in one sentence: "Llama body + monoid mind" —
    reuse Llama's proven MLP / norms / embeddings, and replace attention
    with monoid state compression so decoding needs only O(1) state per
    layer instead of an O(T) KV-cache.
    """
    # lm_head shares weights with the input embedding when
    # config.tie_word_embeddings is set (see MonoidConfig).
    _tied_weights_keys = ["lm_head.weight"]

    # Signals to HF generation utilities that this model carries
    # recurrent state (MonoidCache) rather than a standard KV-cache.
    _is_stateful = True

    def __init__(self, config: MonoidConfig):
        super().__init__(config)
        self.model = MonoidModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        past_key_values=None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        **kwargs,
    ) -> dict:
        """
        Called by GenerationMixin at each decoding step.

        HuggingFace may pass a DynamicCache; we intercept and drop it,
        since this model uses MonoidCache instead of a standard KV-cache
        (forward then allocates a fresh MonoidCache because
        use_cache=True).
        """
        # Discard any foreign cache object HF may have created.
        if past_key_values is not None and not isinstance(past_key_values, MonoidCache):
            past_key_values = None

        if past_key_values is not None and past_key_values.seen_tokens > 0:
            # Steady-state decoding: only the newest token is needed —
            # all earlier context already lives in the monoid state.
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "monoid_cache": past_key_values,
            "use_cache": True,
        }
        return model_inputs

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,  # accepted for HF compatibility; unused
        position_ids: Tensor | None = None,  # unused: no positional encoding
        past_key_values: MonoidCache | None = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        use_cache: bool | None = None,
        monoid_cache: MonoidCache | None = None,
        output_attentions: bool | None = None,  # accepted for HF compatibility; unused
        output_hidden_states: bool | None = None,  # accepted for HF compatibility; unused
        logits_to_keep: int | Tensor = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        # Accept the cache under either keyword (HF passes past_key_values,
        # our own prepare_inputs_for_generation passes monoid_cache).
        cache = monoid_cache or past_key_values

        # Discard foreign cache objects (e.g. a DynamicCache from HF).
        if cache is not None and not isinstance(cache, MonoidCache):
            cache = None

        if use_cache and cache is None:
            cache = MonoidCache()

        outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            monoid_cache=cache,
            use_cache=bool(use_cache),
        )

        hidden_states = outputs.last_hidden_state

        # Optionally project only the trailing `logits_to_keep` positions
        # (generation needs just the final position's logits).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # Standard next-token cross-entropy on shifted logits/labels.
        # NOTE(review): assumes logits were not sliced when labels are
        # given (i.e. logits_to_keep == 0 during training) — confirm callers.
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        # Advance the token counter so prepare_inputs_for_generation can
        # tell that prefill has already happened.
        if cache is not None:
            cache.seen_tokens += (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1])

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=cache,
        )
|
|
|
|
| |
| |
| |
|
|
| AutoConfig.register("monoid", MonoidConfig) |
| AutoModelForCausalLM.register(MonoidConfig, MonoidForCausalLM) |
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Smoke test: prefer Apple MPS when available, else CPU.
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f'Device: {device}')

    config = MonoidConfig(
        vocab_size=49152,
        hidden_size=576,
        intermediate_size=1536,
        num_hidden_layers=30,
        num_attention_heads=9,
        head_dim=64,
        rms_norm_eps=1e-5,
        hidden_act="silu",
        tie_word_embeddings=True,
    )
    model = MonoidForCausalLM(config).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Parameters: {n_params:,}')

    # Training path: parallel scan over the whole sequence (no cache).
    B, T = 2, 64
    ids = torch.randint(0, config.vocab_size, (B, T), device=device)
    out = model(ids, labels=ids)
    print(f'Train โ logits: {out.logits.shape}, loss: {out.loss:.4f}')

    # Inference path: multi-token prefill, then one O(1) decode step
    # reusing the same MonoidCache.
    prompt = torch.randint(0, config.vocab_size, (1, 8), device=device)
    cache = MonoidCache()
    prefill_out = model(prompt, use_cache=True, monoid_cache=cache)
    print(f'Prefill โ logits: {prefill_out.logits.shape}, cache seen: {cache.seen_tokens}')
    # Greedy-pick the next token and feed it back with the cache.
    next_tok = prefill_out.logits[:, -1:].argmax(dim=-1)
    step_out = model(next_tok, use_cache=True, monoid_cache=cache)
    print(f'Decode โ logits: {step_out.logits.shape}, cache seen: {cache.seen_tokens}')

    # Numerically verify associativity of the monoid operator on random pairs.
    print('\nMonoid associativity check / ๅนบๅ็พค็ปๅๅพ้ช่ฏ:')
    a = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    b = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    c = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    ab_c = monoid_op(monoid_op(a, b), c)
    a_bc = monoid_op(a, monoid_op(b, c))
    err = (ab_c[1] - a_bc[1]).abs().max().item()
    print(f'  |(aโb)โc - aโ(bโc)| = {err:.2e}')

    print('\nDone.')
|
|