| |
| """Public-facing TMLM-Haiku interactive CLI. |
| |
| Pulls models from the CompactAI-O HuggingFace collection: |
| https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series |
| """ |
| from __future__ import annotations |
|
|
|
|
|
|
| import hashlib |
| import json |
| import math |
| import os |
| import string |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Iterator, List, Optional, Sequence, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint |
|
|
|
|
| HUGGINGFACE_MODELS = { |
| "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1", |
| "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3", |
| "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2", |
| "TMLM-Haiku-2.3": "CompactAI-O/TMLM-Haiku-2.3", |
| "Glint-1": "CompactAI-O/Glint-1", |
| } |
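| |
| # Hedged sketch (not part of the original flow): fetching one of the repos listed |
| # above with huggingface_hub before pointing the CLI at the local copy. Whether |
| # each repo ships model.pt / tokenizer.json is an assumption, not verified here. |
| def _example_download_checkpoint(name: str = "TMLM-Haiku-1") -> str: |
|     from huggingface_hub import snapshot_download  # optional dependency |
|     return snapshot_download(repo_id=HUGGINGFACE_MODELS[name]) |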
|
|
|
|
| |
| |
| |
|
|
|
|
| @dataclass |
| class ModelConfig: |
| dim: int = 128 |
| n_unique_layers: int = 8 |
| n_logical_layers: int = 16 |
| n_heads: int = 4 |
| n_kv_heads: int = 2 |
| ffn_dim: int = 224 |
| dropout: float = 0.0 |
| seq_len: int = 2048 |
| sliding_window_size: int = 512 |
| mtp_horizons: Tuple[int, ...] = (2, 3, 4) |
| rope_fraction: float = 0.5 |
| embed_scale: bool = True |
| logit_soft_cap: float = -1.0 |
| quantization: str = "nvfp4" |
|
|
| @property |
| def head_dim(self) -> int: |
| return self.dim // self.n_heads |
|
|
|
|
| model_config = ModelConfig() |
|
|
| MODEL_SERIES = { |
| "haiku": { |
| "dim": 64, |
| "n_unique_layers": 12, |
| "n_logical_layers": 24, |
| "n_heads": 4, |
| "n_kv_heads": 2, |
| "ffn_dim": 384, |
| "dropout": 0.0, |
| "seq_len": 2048, |
| "sliding_window_size": 2048, |
| "mtp_horizons": (), |
| "rope_fraction": 0.5, |
| "engram_dim": 8, |
| "engram_heads": 2, |
| "engram_table_size": 64, |
| "engram_max_ngram": 2, |
| "mhc_expansion": 2, |
| "sleep_gate_cap": 0, |
| "sleep_gate_heads": 4, |
| "latent_think_layers": 0, |
| "prelude_layers": 0, |
| "coda_layers": 0, |
| "recurrent_loops": 0, |
| "recurrent_act_threshold": 0.9, |
| "recurrent_lora_rank": 0, |
| "recurrent_loop_embed_dim": 0, |
| }, |
| "sonnet": { |
| "dim": 1024, |
| "n_unique_layers": 20, |
| "n_logical_layers": 40, |
| "n_heads": 16, |
| "n_kv_heads": 4, |
| "ffn_dim": 4096, |
| "dropout": 0.0, |
| "seq_len": 2048, |
| "mtp_horizons": (2,), |
| "engram_dim": 32, |
| "engram_heads": 8, |
| "engram_table_size": 4096, |
| "engram_max_ngram": 2, |
| "mhc_expansion": 2, |
| "sleep_gate_cap": 0, |
| "sleep_gate_heads": 8, |
| "latent_think_layers": 0, |
| "prelude_layers": 0, |
| "coda_layers": 0, |
| "recurrent_loops": 0, |
| "recurrent_act_threshold": 0.99, |
| "recurrent_lora_rank": 0, |
| "recurrent_loop_embed_dim": 0, |
| }, |
| "opus": { |
| "dim": 1536, |
| "n_unique_layers": 18, |
| "n_logical_layers": 36, |
| "n_heads": 16, |
| "n_kv_heads": 4, |
| "ffn_dim": 5888, |
| "dropout": 0.0, |
| "seq_len": 2048, |
| "mtp_horizons": (2,), |
| "engram_dim": 64, |
| "engram_heads": 8, |
| "engram_table_size": 8192, |
| "engram_max_ngram": 2, |
| "mhc_expansion": 4, |
| "sleep_gate_cap": 0, |
| "sleep_gate_heads": 8, |
| "latent_think_layers": 0, |
| "prelude_layers": 0, |
| "coda_layers": 0, |
| "recurrent_loops": 0, |
| "recurrent_act_threshold": 0.99, |
| "recurrent_lora_rank": 0, |
| "recurrent_loop_embed_dim": 0, |
| }, |
| } |
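| |
| # Note: these per-series presets are merged with values inferred from checkpoint |
| # weights in load_local_model(); any key missing from a preset (for example |
| # "sliding_window_size" for "sonnet"/"opus") falls back to the ModelConfig |
| # defaults above. |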
|
|
|
|
| |
| |
| |
|
|
| FORMAT_TOKENS = [ |
| "<|user|>", |
| "<|assistant|>", |
| "<|system|>", |
| "<|start_header_id|>", |
| "<|end_header_id|>", |
| "<|begin_of_thought|>", |
| "<|end_of_thought|>", |
| "<|begin_of_solution|>", |
| "<|end_of_solution|>", |
| ] |
|
|
|
|
| class WordTokenizer: |
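| """Character-level tokenizer with multi-character special tokens. |
| |
| Despite the name, ordinary text is split into single characters; only the |
| core specials (<PAD>/<BOS>/<EOS>/<UNK>) and the chat-format markers in |
| FORMAT_TOKENS are matched as multi-character tokens during encoding. |
| """ |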
| def __init__( |
| self, extra_chars: str = "", format_tokens: Optional[List[str]] = None |
| ) -> None: |
| base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r" |
| fallback_chars = sorted(set(base + extra_chars)) |
| self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"] |
| self.format_tokens = ( |
| list(format_tokens) if format_tokens else list(FORMAT_TOKENS) |
| ) |
| self.special = list(self.core_special) + list(self.format_tokens) |
| self.id_to_token: List[str] = ( |
| list(self.core_special) + self.format_tokens + fallback_chars |
| ) |
| self.token_to_id: Dict[str, int] = { |
| t: i for i, t in enumerate(self.id_to_token) |
| } |
| self.special_multi_tokens = sorted( |
| [t for t in self.special if len(t) > 1], key=len, reverse=True |
| ) |
| self.multi_char_tokens = self.special_multi_tokens |
| self.dynamic_additions = 0 |
|
|
| @property |
| def pad_id(self) -> int: |
| return self.token_to_id["<PAD>"] |
|
|
| @property |
| def bos_id(self) -> int: |
| return self.token_to_id["<BOS>"] |
|
|
| @property |
| def eos_id(self) -> int: |
| return self.token_to_id["<EOS>"] |
|
|
| @property |
| def unk_id(self) -> int: |
| return self.token_to_id["<UNK>"] |
|
|
| @property |
| def vocab_size(self) -> int: |
| return len(self.id_to_token) |
|
|
| def maybe_add_char(self, ch: str) -> bool: |
| if ch in self.token_to_id: |
| return False |
| self.token_to_id[ch] = len(self.id_to_token) |
| self.id_to_token.append(ch) |
| self.dynamic_additions += 1 |
| return True |
|
|
| def iter_lexical_tokens(self, text: str) -> Iterator[str]: |
| i = 0 |
| n = len(text) |
| while i < n: |
| matched_special = False |
| for token in self.special_multi_tokens: |
| if text.startswith(token, i): |
| yield token |
| i += len(token) |
| matched_special = True |
| break |
| if matched_special: |
| continue |
| yield text[i] |
| i += 1 |
|
|
| def encode( |
| self, text: str, add_bos: bool = False, add_eos: bool = False |
| ) -> List[int]: |
| out: List[int] = [] |
| if add_bos: |
| out.append(self.bos_id) |
| unk = self.unk_id |
| t2i = self.token_to_id |
| for tok in self.iter_lexical_tokens(text): |
| out.append(t2i.get(tok, unk)) |
| if add_eos: |
| out.append(self.eos_id) |
| return out |
|
|
| def decode(self, ids: Sequence[int], skip_special: bool = True) -> str: |
| pieces: List[str] = [] |
| for idx in ids: |
| if int(idx) < 0 or int(idx) >= len(self.id_to_token): |
| continue |
| tok = self.id_to_token[int(idx)] |
| if skip_special and tok in self.special: |
| continue |
| pieces.append(tok) |
| return "".join(pieces) |
|
|
| @classmethod |
| def load(cls, path: Path) -> WordTokenizer: |
| with path.open("r", encoding="utf-8") as f: |
| data = json.load(f) |
| format_tokens = data.get("format_tokens", FORMAT_TOKENS) |
| tokenizer = cls(extra_chars="", format_tokens=format_tokens) |
| tokenizer.id_to_token = data["id_to_token"] |
| tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)} |
| tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens) |
| tokenizer.special_multi_tokens = sorted( |
| [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True |
| ) |
| tokenizer.multi_char_tokens = tokenizer.special_multi_tokens |
| return tokenizer |
|
|
|
|
| LetterTokenizer = WordTokenizer |
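| |
| # Minimal usage sketch (illustrative only): encode/decode round-trip with the |
| # character-level tokenizer; format markers are stripped again on decode. |
| def _example_tokenizer_roundtrip() -> str: |
|     tok = WordTokenizer() |
|     ids = tok.encode("<|user|>\nHi!\n<|assistant|>\n", add_bos=True) |
|     return tok.decode(ids)  # -> "\nHi!\n\n" (special tokens are skipped) |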
|
|
|
|
| |
| |
| |
|
|
|
|
| class RMSNorm(nn.Module): |
| def __init__(self, dim: int, eps: float = 1e-6) -> None: |
| super().__init__() |
| self.weight = nn.Parameter(torch.ones(dim)) |
| self.eps = eps |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| if hasattr(torch.nn.functional, "rms_norm"): |
| return torch.nn.functional.rms_norm( |
| x, self.weight.shape, self.weight, self.eps |
| ) |
| x_fp = x.float() |
| rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps) |
| return (x_fp * rms).to(dtype=x.dtype) * self.weight |
|
|
|
|
| class RotaryEmbedding(nn.Module): |
| def __init__(self, dim: int, base: float = 10000.0) -> None: |
| super().__init__() |
| inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) |
| self.register_buffer("inv_freq", inv, persistent=False) |
|
|
| def cos_sin( |
| self, seq_len: int, device: torch.device, dtype: torch.dtype |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) |
| freqs = torch.outer(t, self.inv_freq) |
| emb = torch.cat([freqs, freqs], dim=-1) |
| cos = emb.cos()[None, None, :, :].to(dtype=dtype) |
| sin = emb.sin()[None, None, :, :].to(dtype=dtype) |
| return cos, sin |
|
|
|
|
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| class CausalSelfAttention(nn.Module): |
| def __init__( |
| self, |
| dim: int, |
| n_heads: int, |
| n_kv_heads: int, |
| head_dim: int, |
| dropout: float, |
| sliding_window: int, |
| rope_fraction: float, |
| ) -> None: |
| super().__init__() |
| self.dim = dim |
| self.n_heads = n_heads |
| self.n_kv_heads = n_kv_heads |
| self.head_dim = head_dim |
| self.n_rep = n_heads // n_kv_heads |
| self.dropout = dropout |
| self.sliding_window = sliding_window |
|
|
| self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) |
| self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) |
| self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) |
| self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) |
|
|
| self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2) |
| self.rope = RotaryEmbedding(self.rope_dim) |
|
|
| self.q_norm = RMSNorm(head_dim) |
| self.k_norm = RMSNorm(head_dim) |
|
|
| self.output_gate = nn.Parameter(torch.ones(n_heads)) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| is_global: bool, |
| past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| use_cache: bool = False, |
| ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
| B, T, _ = x.shape |
|
|
| q = self.wq(x).view(B, T, self.n_heads, self.head_dim) |
| k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) |
| v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) |
|
|
| q = self.q_norm(q) |
| k = self.k_norm(k) |
|
|
| q = q.transpose(1, 2) |
| k = k.transpose(1, 2) |
| v = v.transpose(1, 2) |
|
|
| past_len = past_kv[0].shape[2] if past_kv is not None else 0 |
| cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype) |
| cos_slice = cos[:, :, past_len : past_len + T, :] |
| sin_slice = sin[:, :, past_len : past_len + T, :] |
|
|
| q_rope = q[..., : self.rope_dim] |
| q_pass = q[..., self.rope_dim :] |
| k_rope = k[..., : self.rope_dim] |
| k_pass = k[..., self.rope_dim :] |
|
|
| q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice) |
| k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice) |
|
|
| q = torch.cat([q_rope, q_pass], dim=-1) |
| k = torch.cat([k_rope, k_pass], dim=-1) |
|
|
| if past_kv is not None: |
| k = torch.cat([past_kv[0], k], dim=2) |
| v = torch.cat([past_kv[1], v], dim=2) |
|
|
| new_kv = (k, v) if use_cache else None |
|
|
| S = k.shape[2] |
| if self.n_rep > 1: |
| k = ( |
| k[:, :, None, :, :] |
| .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) |
| .reshape(B, self.n_heads, S, self.head_dim) |
| ) |
| v = ( |
| v[:, :, None, :, :] |
| .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) |
| .reshape(B, self.n_heads, S, self.head_dim) |
| ) |
|
|
| drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0 |
|
|
| if is_global: |
| if past_kv is None and T > 1: |
| out = F.scaled_dot_product_attention( |
| q, k, v, is_causal=True, dropout_p=drop_p |
| ) |
| else: |
| out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p) |
| else: |
| T_q = q.shape[2] |
| q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1) |
| k_pos = torch.arange(S, device=q.device).unsqueeze(0) |
| mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window) |
| out = F.scaled_dot_product_attention( |
| q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p |
| ) |
|
|
| gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1) |
| out = out * gate |
|
|
| out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) |
| out = self.wo(out) |
|
|
| return out, new_kv |
|
|
|
|
| class SwiGLU(nn.Module): |
| def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None: |
| super().__init__() |
| self.gate = nn.Linear(dim, hidden_dim, bias=False) |
| self.up = nn.Linear(dim, hidden_dim, bias=False) |
| self.down = nn.Linear(hidden_dim, dim, bias=False) |
| self.drop = nn.Dropout(dropout) |
|
|
| nn.init.normal_(self.gate.weight, std=dim**-0.5) |
| nn.init.normal_(self.up.weight, std=dim**-0.5) |
| nn.init.normal_(self.down.weight, std=hidden_dim**-0.5) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| h = F.silu(self.gate(x)) * self.up(x) |
| out = self.down(h) |
| if self.training and torch.is_grad_enabled(): |
| out = self.drop(out) |
| return out |
|
|
|
|
| def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor: |
| if loop_dim <= 0: |
| return h |
| loop_dim = min(loop_dim, h.shape[-1]) |
| if loop_dim % 2 == 1: |
| loop_dim -= 1 |
| if loop_dim <= 0: |
| return h |
| inv_freq = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim)) |
| phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq |
| loop_embed = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, loop_dim) |
| out = h.clone() |
| out[..., :loop_dim] = out[..., :loop_dim] + loop_embed |
| return out |
|
|
|
|
| class DepthLoRAAdapter(nn.Module): |
| def __init__(self, dim: int, rank: int, max_loops: int) -> None: |
| super().__init__() |
| self.rank = max(0, rank) |
| if self.rank <= 0: |
| self.down = None |
| self.B = None |
| self.scale = None |
| return |
| self.down = nn.Linear(dim, self.rank, bias=False) |
| self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02) |
| self.scale = nn.Embedding(max(1, max_loops), self.rank) |
| nn.init.zeros_(self.scale.weight) |
|
|
| def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor: |
| if self.rank <= 0 or self.down is None or self.B is None or self.scale is None: |
| return torch.zeros_like(x) |
| t_idx = min(loop_t, self.scale.num_embeddings - 1) |
| scale = self.scale(torch.tensor(t_idx, device=x.device)) |
| return (self.down(x) * scale) @ self.B |
|
|
|
|
| class StableRecurrentInjection(nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.log_A = nn.Parameter(torch.full((dim,), -2.0)) |
| self.log_dt = nn.Parameter(torch.full((dim,), -2.0)) |
| self.input_gate = nn.Parameter(torch.zeros(dim)) |
|
|
| def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor: |
| A = torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))).view(1, 1, -1) |
| B = torch.sigmoid(self.input_gate).view(1, 1, -1) |
| return A * h + B * e + transformer_out |
|
|
|
|
| class AdaptiveHalting(nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.halt = nn.Linear(dim, 1, bias=True) |
| nn.init.zeros_(self.halt.weight) |
| nn.init.constant_(self.halt.bias, -2.0) |
|
|
| def forward(self, h: torch.Tensor) -> torch.Tensor: |
| return torch.sigmoid(self.halt(h)).squeeze(-1) |
|
|
|
|
| class EngramBlock(nn.Module): |
| """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup. |
| |
| Stores common token-pair/triplet patterns in an embedding table and |
| retrieves them with multi-head hashing. A context-aware gate (using the |
| current hidden state as query) decides how much of the retrieved memory |
| to inject into the residual stream. |
| |
| Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025). |
| """ |
|
|
| def __init__( |
| self, |
| dim: int, |
| engram_dim: int, |
| n_heads: int = 4, |
| table_size: int = 8192, |
| max_ngram: int = 3, |
| ) -> None: |
| super().__init__() |
| self.dim = dim |
| self.engram_dim = engram_dim |
| self.n_heads = n_heads |
| self.table_size = table_size |
| self.max_ngram = max_ngram |
|
|
| |
| self.embeddings = nn.ParameterDict() |
| for n in range(2, max_ngram + 1): |
| for k in range(n_heads): |
| self.embeddings[f"{n}_{k}"] = nn.Parameter( |
| torch.randn(table_size, engram_dim) * (engram_dim**-0.5) |
| ) |
|
|
| |
| for n in range(2, max_ngram + 1): |
| for k in range(n_heads): |
| seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16) |
| rng = torch.Generator().manual_seed(seed) |
| a = torch.randint(1, 2**31, (1,), generator=rng).item() |
| b = torch.randint(0, 2**31, (1,), generator=rng).item() |
| self.register_buffer( |
| f"hash_a_{n}_{k}", torch.tensor(a), persistent=False |
| ) |
| self.register_buffer( |
| f"hash_b_{n}_{k}", torch.tensor(b), persistent=False |
| ) |
|
|
| |
| total_branch_dim = engram_dim * n_heads * (max_ngram - 1) |
| self.branch_conv = nn.Conv1d( |
| total_branch_dim, |
| total_branch_dim, |
| kernel_size=4, |
| dilation=max_ngram, |
| padding=0, |
| groups=total_branch_dim, |
| bias=True, |
| ) |
| nn.init.zeros_(self.branch_conv.weight) |
| nn.init.zeros_(self.branch_conv.bias) |
|
|
| |
| self.gate_query = nn.Linear(dim, engram_dim, bias=False) |
| self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False) |
| self.gate_value = nn.Linear(total_branch_dim, dim, bias=False) |
| self.gate_scale = engram_dim**-0.5 |
|
|
| def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor: |
| """Hash n-gram token sequences into table indices. |
| |
| Args: |
| token_ids: (B, T) token IDs |
| n: n-gram order (2 = bigram, 3 = trigram) |
| k: hash head index |
| Returns: |
| indices: (B, T) integer indices into embedding table |
| """ |
| a = getattr(self, f"hash_a_{n}_{k}") |
| b = getattr(self, f"hash_b_{n}_{k}") |
| B, T = token_ids.shape |
|
|
| |
| padded = F.pad(token_ids, (n - 1, 0), value=0) |
|
|
| |
| combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device) |
| for i in range(n): |
| combined = combined * 31 + padded[:, i : i + T].long() |
|
|
| indices = ((a * combined) ^ b) % self.table_size |
| return indices |
|
|
| def forward( |
| self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None |
| ) -> torch.Tensor: |
| """Forward pass. |
| |
| Args: |
| hidden: (B, T, dim) current hidden state |
| token_ids: (B, T) input token IDs for n-gram hashing. |
| If None, a coarse proxy derived from the mean hidden activation is used. |
| Returns: |
| output: (B, T, dim) memory injection for residual stream |
| """ |
| B, T, _ = hidden.shape |
|
|
| if token_ids is None: |
| # Fallback proxy when raw token IDs are unavailable: bucket the mean |
| # hidden activation into the hash-table index range. |
| token_ids = hidden.mean(dim=-1).long() % self.table_size |
|
|
| |
| branch_outputs = [] |
| for n in range(2, self.max_ngram + 1): |
| for k in range(self.n_heads): |
| indices = self._hash_ngram(token_ids, n, k) |
| table = self.embeddings[f"{n}_{k}"] |
| retrieved = table[indices] |
| branch_outputs.append(retrieved) |
|
|
| |
| memory = torch.cat(branch_outputs, dim=-1) |
|
|
| |
| |
| conv_in = memory.transpose(1, 2) |
| conv_in = F.pad( |
| conv_in, |
| ((self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0], 0), |
| ) |
| conv_out = self.branch_conv(conv_in) |
| memory = conv_out.transpose(1, 2) |
|
|
| |
| query = self.gate_query(hidden) |
| key = self.gate_key(memory) |
| gate = torch.sigmoid( |
| (query * key).sum(dim=-1, keepdim=True) * self.gate_scale |
| ) |
| value = self.gate_value(memory) |
|
|
| return gate * value |
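| |
| # Hedged sketch: exercising the hashed n-gram memory on dummy inputs to show |
| # the expected shapes (dims here are arbitrary and for illustration only). |
| def _example_engram_shapes() -> torch.Size: |
|     block = EngramBlock(dim=32, engram_dim=8, n_heads=2, table_size=64, max_ngram=2) |
|     hidden = torch.zeros(1, 5, 32) |
|     token_ids = torch.randint(0, 100, (1, 5)) |
|     return block(hidden, token_ids=token_ids).shape  # -> torch.Size([1, 5, 32]) |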
|
|
|
|
| class SleepGate(nn.Module): |
| """Persistent memory + periodic consolidation gate. |
| |
| Pooled summaries of recent sequences are written into a fixed-size ring |
| buffer and read back through attention, down-weighted by a learned |
| retention factor and the age of each entry. |
| """ |
|
|
| def __init__( |
| self, |
| dim: int, |
| cap: int = 128, |
| n_heads: int = 4, |
| retention_enabled: bool = True, |
| retention_hidden: int = 0, |
| ) -> None: |
| super().__init__() |
| self.dim = dim |
| self.cap = cap |
| self.n_heads = n_heads |
| self.head_dim = dim // n_heads |
| self.scale = self.head_dim ** -0.5 |
| self.retention_enabled = retention_enabled |
|
|
| self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16)) |
| self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long)) |
| self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32)) |
| self.register_buffer("mem_count", torch.zeros((), dtype=torch.long)) |
| self.register_buffer("mem_head", torch.zeros((), dtype=torch.long)) |
| self.register_buffer("global_step", torch.zeros((), dtype=torch.long)) |
|
|
| self.q_proj = nn.Linear(dim, dim, bias=False) |
| self.k_proj = nn.Linear(dim, dim, bias=False) |
| self.v_proj = nn.Linear(dim, dim, bias=False) |
| self.o_proj = nn.Linear(dim, dim, bias=False) |
| nn.init.zeros_(self.o_proj.weight) |
| self.gate_scale = nn.Parameter(torch.zeros(())) |
|
|
| if retention_enabled: |
| if retention_hidden > 0: |
| self.retention_gate: Optional[nn.Module] = nn.Sequential( |
| nn.Linear(dim, retention_hidden, bias=False), |
| nn.GELU(), |
| nn.Linear(retention_hidden, 1, bias=True), |
| ) |
| nn.init.constant_(self.retention_gate[-1].bias, 2.2) |
| else: |
| self.retention_gate = nn.Linear(dim, 1, bias=True) |
| nn.init.constant_(self.retention_gate.bias, 2.2) |
| else: |
| self.retention_gate = None |
|
|
| self._last_beta: Optional[torch.Tensor] = None |
|
|
| def write(self, hidden: torch.Tensor) -> None: |
| B, T, _ = hidden.shape |
| tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1) |
| if self.retention_gate is not None: |
| beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1)) |
| self._last_beta = beta_live if self.training else None |
| beta_store = beta_live.detach().float() |
| else: |
| self._last_beta = None |
| beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32) |
| tail = tail_full.to(self.mem_emb.dtype).detach() |
| with torch.no_grad(): |
| head = int(self.mem_head.item()) |
| count = int(self.mem_count.item()) |
| step = int(self.global_step.item()) |
| for b in range(B): |
| self.mem_emb[head] = tail[b] |
| self.mem_age[head] = step |
| self.mem_beta[head] = beta_store[b] |
| head = (head + 1) % self.cap |
| if count < self.cap: |
| count += 1 |
| self.mem_head.fill_(head) |
| self.mem_count.fill_(count) |
|
|
| def read(self, x: torch.Tensor) -> torch.Tensor: |
| count = int(self.mem_count.item()) |
| if count == 0: |
| return torch.zeros_like(x) |
| B, T, D = x.shape |
| mem = self.mem_emb[:count].clone().to(x.dtype) |
| q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) |
| k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) |
| v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) |
| attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale |
| attn = F.softmax(attn, dim=-1) |
| if self.retention_enabled: |
| step = int(self.global_step.item()) |
| ages = self.mem_age[:count].to(x.device) |
| delta = (step - ages).clamp(min=0).to(x.dtype) |
| betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0) |
| weights = betas.pow(delta) |
| attn = attn * weights.view(1, 1, 1, count) |
| attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9) |
| out = torch.einsum("bhtm,hmd->bhtd", attn, v) |
| out = out.transpose(1, 2).contiguous().view(B, T, D) |
| out = self.o_proj(out) |
| return torch.sigmoid(self.gate_scale) * out |
|
|
| @torch.no_grad() |
| def reset(self) -> None: |
| self.mem_emb.zero_() |
| self.mem_age.zero_() |
| self.mem_beta.fill_(1.0) |
| self.mem_count.zero_() |
| self.mem_head.zero_() |
| self.global_step.zero_() |
| self._last_beta = None |
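| |
| # Hedged sketch: one write/read cycle through the persistent memory (sizes are |
| # arbitrary). In the model, write() is only called during training; see |
| # TinyMemoryLM.forward() below. |
| def _example_sleep_gate_cycle() -> torch.Size: |
|     gate = SleepGate(dim=16, cap=4, n_heads=2) |
|     hidden = torch.randn(2, 8, 16) |
|     gate.write(hidden)  # stores a pooled tail summary per sequence |
|     return gate.read(hidden).shape  # -> torch.Size([2, 8, 16]) |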
|
|
|
|
| def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor: |
| M = torch.exp(logits.clamp(-10, 10)) |
| for _ in range(n_iters): |
| M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10) |
| M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10) |
| return M |
|
|
|
|
| class ManifoldHyperConnection(nn.Module): |
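| """Multi-stream residual mixing ("manifold hyper-connection"). |
| |
| The residual stream is duplicated into `expansion` parallel streams, mixed by |
| input-conditioned pre/post/residual mappings (the residual mixing matrix is |
| pushed towards doubly-stochastic form via Sinkhorn-Knopp), and collapsed back |
| to a single stream by averaging. |
| """ |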
| def __init__(self, dim: int, expansion: int = 2) -> None: |
| super().__init__() |
| self.dim = dim |
| self.expansion = expansion |
| n = expansion |
|
|
| self.expand_fn = "duplicate" |
| self.collapse_fn = "mean" |
|
|
| self.bias_pre = nn.Parameter(torch.zeros(1, n)) |
| self.bias_post = nn.Parameter(torch.zeros(1, n)) |
| self.bias_res = nn.Parameter(torch.zeros(n, n)) |
|
|
| self.theta_pre = nn.Linear(n * dim, n, bias=False) |
| self.theta_post = nn.Linear(n * dim, n, bias=False) |
| self.theta_res = nn.Linear(n * dim, n * n, bias=False) |
|
|
| self.alpha_pre = nn.Parameter(torch.tensor(0.0)) |
| self.alpha_post = nn.Parameter(torch.tensor(0.0)) |
| self.alpha_res = nn.Parameter(torch.tensor(0.0)) |
|
|
| def _compute_mappings( |
| self, x_expanded: torch.Tensor |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| B, T, _ = x_expanded.shape |
| n = self.expansion |
|
|
| x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]]) |
|
|
| d_pre = torch.tanh(self.theta_pre(x_norm)) |
| d_post = torch.tanh(self.theta_post(x_norm)) |
| d_res = self.theta_res(x_norm) |
|
|
| H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre) |
| H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post) |
| H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape( |
| B, T, n, n |
| ) |
|
|
| H_res = _sinkhorn_knopp(H_res_raw) |
|
|
| return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res |
|
|
| def expand_stream(self, x: torch.Tensor) -> torch.Tensor: |
| return x.repeat(1, 1, self.expansion) |
|
|
| def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor: |
| B, T, _ = x_expanded.shape |
| n = self.expansion |
| C = self.dim |
| return x_expanded.view(B, T, n, C).mean(dim=-2) |
|
|
| def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor: |
| B, T, _ = x_expanded.shape |
| n = self.expansion |
| x_streams = x_expanded.view(B, T, n, self.dim) |
| return (H_pre @ x_streams).squeeze(-2) |
|
|
| def post_res_mix( |
| self, |
| layer_output: torch.Tensor, |
| x_expanded: torch.Tensor, |
| H_post: torch.Tensor, |
| H_res: torch.Tensor, |
| ) -> torch.Tensor: |
| B, T, _ = x_expanded.shape |
| n = self.expansion |
| C = self.dim |
|
|
| x_streams = x_expanded.view(B, T, n, C) |
| mixed = torch.matmul(H_res, x_streams) |
| post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2)) |
|
|
| result = mixed + post_out |
| return result.reshape(B, T, n * C) |
|
|
|
|
| class TransformerBlock(nn.Module): |
| def __init__( |
| self, |
| dim: int, |
| n_heads: int, |
| n_kv_heads: int, |
| head_dim: int, |
| ffn_dim: int, |
| dropout: float, |
| sliding_window: int, |
| rope_fraction: float, |
| engram_dim: int = 0, |
| engram_heads: int = 4, |
| engram_table_size: int = 8192, |
| engram_max_ngram: int = 3, |
| mhc_expansion: int = 1, |
| ) -> None: |
| super().__init__() |
| self.norm1 = RMSNorm(dim) |
| self.attn = CausalSelfAttention( |
| dim=dim, |
| n_heads=n_heads, |
| n_kv_heads=n_kv_heads, |
| head_dim=head_dim, |
| dropout=dropout, |
| sliding_window=sliding_window, |
| rope_fraction=rope_fraction, |
| ) |
| self.norm2 = RMSNorm(dim) |
| self.ffn = SwiGLU(dim, ffn_dim, dropout) |
| self.use_engram = engram_dim > 0 |
| if self.use_engram: |
| self.engram = EngramBlock( |
| dim=dim, |
| engram_dim=engram_dim, |
| n_heads=engram_heads, |
| table_size=engram_table_size, |
| max_ngram=engram_max_ngram, |
| ) |
| self.engram_norm = RMSNorm(dim) |
| self.use_mhc = mhc_expansion > 1 |
| if self.use_mhc: |
| self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion) |
| self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| is_global: bool, |
| past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
| use_cache: bool = False, |
| token_ids: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
| if self.use_mhc: |
| x_exp = self.mhc_attn.expand_stream(x) |
| H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp) |
| attn_in = self.mhc_attn.pre_mix(x_exp, H_pre) |
| attn_out, new_kv = self.attn( |
| self.norm1(attn_in), is_global, past_kv, use_cache |
| ) |
| x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res) |
| if self.use_engram: |
| collapsed = self.mhc_attn.collapse_stream(x_exp) |
| collapsed = collapsed + self.engram( |
| self.engram_norm(collapsed), token_ids=token_ids |
| ) |
| x_exp = self.mhc_attn.expand_stream(collapsed) |
| H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp) |
| ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2) |
| ffn_out = self.ffn(self.norm2(ffn_in)) |
| x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2) |
| x = self.mhc_attn.collapse_stream(x_exp) |
| else: |
| attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache) |
| x = x + attn_out |
| if self.use_engram: |
| x = x + self.engram(self.engram_norm(x), token_ids=token_ids) |
| x = x + self.ffn(self.norm2(x)) |
| return x, new_kv |
|
|
|
|
| class RecurrentDepthBlock(nn.Module): |
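| """Weight-shared depth recurrence with adaptive halting. |
| |
| A single TransformerBlock is applied for up to `n_loops` iterations; a |
| per-token halting head accumulates halt probability (ACT-style), the loop |
| output is the halting-weighted mixture of intermediate states, and optional |
| per-loop LoRA adapters and loop-index embeddings specialise each iteration. |
| """ |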
| def __init__( |
| self, |
| dim: int, |
| n_heads: int, |
| n_kv_heads: int, |
| head_dim: int, |
| ffn_dim: int, |
| dropout: float, |
| sliding_window: int, |
| rope_fraction: float, |
| n_loops: int, |
| act_threshold: float, |
| lora_rank: int, |
| loop_embed_dim: int, |
| ) -> None: |
| super().__init__() |
| self.n_loops = max(1, n_loops) |
| self.act_threshold = act_threshold |
| self.loop_embed_dim = max(0, loop_embed_dim) |
| self.norm = RMSNorm(dim) |
| self.block = TransformerBlock( |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, |
| ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, |
| rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, |
| ) |
| self.injection = StableRecurrentInjection(dim) |
| self.act = AdaptiveHalting(dim) |
| self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops) |
|
|
| def forward( |
| self, |
| h: torch.Tensor, |
| e: torch.Tensor, |
| token_ids: Optional[torch.Tensor] = None, |
| past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, |
| use_cache: bool = False, |
| n_loops: Optional[int] = None, |
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: |
| loops = max(1, n_loops or self.n_loops) |
| B, T, _ = h.shape |
| halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) |
| cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype) |
| output = torch.zeros_like(h) |
| new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None |
| current = h |
| final_halt = None |
|
|
| for t in range(loops): |
| h_loop = loop_index_embedding(current, t, self.loop_embed_dim) |
| combined = self.norm(h_loop + e) |
| past_kv = None |
| if past_key_values is not None and t < len(past_key_values): |
| past_kv = past_key_values[t] |
| trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids) |
| trans_out = trans_out + self.lora(trans_out, t) |
| next_h = self.injection(current, e, trans_out) |
| p = self.act(next_h) |
| p = p * (~halted).to(p.dtype) |
| final_halt = p |
| should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold) |
| update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p) |
| output = output + next_h * update_weight.unsqueeze(-1) |
| cumulative_p = cumulative_p + update_weight |
| current = torch.where(halted.unsqueeze(-1), current, next_h) |
| halted = halted | should_halt |
| if new_past is not None: |
| new_past.append(layer_kv) |
| if not use_cache and bool(halted.all()): |
| break |
|
|
| remainder = (1.0 - cumulative_p).clamp(min=0.0) |
| output = output + current * remainder.unsqueeze(-1) |
| aux: Dict[str, torch.Tensor] = {} |
| if final_halt is not None: |
| aux["recurrent_halt_mean"] = final_halt.mean() |
| return output, aux, new_past |
|
|
|
|
| class TinyMemoryLM(nn.Module): |
| def __init__( |
| self, |
| vocab_size: int, |
| dim: int, |
| n_unique_layers: int, |
| n_logical_layers: int, |
| n_heads: int, |
| n_kv_heads: int, |
| ffn_dim: int, |
| dropout: float, |
| mtp_horizons: Sequence[int], |
| grad_checkpoint: bool, |
| sliding_window: int = 512, |
| rope_fraction: float = 0.5, |
| embed_scale: bool = True, |
| engram_dim: int = 0, |
| engram_heads: int = 4, |
| engram_table_size: int = 8192, |
| engram_max_ngram: int = 3, |
| mhc_expansion: int = 1, |
| sleep_gate_cap: int = 0, |
| sleep_gate_heads: int = 4, |
| sleep_retention_enabled: bool = True, |
| sleep_retention_hidden: int = 0, |
| latent_think_layers: int = 0, |
| prelude_layers: int = 0, |
| coda_layers: int = 0, |
| recurrent_loops: int = 0, |
| recurrent_act_threshold: float = 0.99, |
| recurrent_lora_rank: int = 0, |
| recurrent_loop_embed_dim: int = 0, |
| ) -> None: |
| super().__init__() |
| self.dim = dim |
| self.n_unique_layers = n_unique_layers |
| self.n_logical_layers = n_logical_layers |
| self.grad_checkpoint = grad_checkpoint |
| self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0 |
| head_dim = dim // n_heads |
|
|
| self.embed_tokens = nn.Embedding(vocab_size, dim) |
| self.head = nn.Linear(dim, vocab_size, bias=False) |
| self.head.weight = self.embed_tokens.weight |
| self.output_bias = nn.Parameter(torch.zeros(vocab_size)) |
|
|
| self.use_recurrent_depth = recurrent_loops > 0 |
| self.prelude_layers = max(0, prelude_layers) |
| self.coda_layers = max(0, coda_layers) |
| self.recurrent_loops = max(0, recurrent_loops) |
|
|
| self.blocks: Optional[nn.ModuleList] = None |
| self.prelude: Optional[nn.ModuleList] = None |
| self.recurrent: Optional[RecurrentDepthBlock] = None |
| self.coda: Optional[nn.ModuleList] = None |
|
|
| def _make_blocks(n: int) -> nn.ModuleList: |
| return nn.ModuleList([ |
| TransformerBlock( |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, |
| ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, |
| rope_fraction=rope_fraction, engram_dim=engram_dim, |
| engram_heads=engram_heads, engram_table_size=engram_table_size, |
| engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion, |
| ) |
| for _ in range(n) |
| ]) |
|
|
| if self.use_recurrent_depth: |
| if self.prelude_layers > 0: |
| self.prelude = _make_blocks(self.prelude_layers) |
| self.recurrent = RecurrentDepthBlock( |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, |
| ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, |
| rope_fraction=rope_fraction, n_loops=self.recurrent_loops, |
| act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank, |
| loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8), |
| ) |
| if self.coda_layers > 0: |
| self.coda = _make_blocks(self.coda_layers) |
| else: |
| self.blocks = _make_blocks(max(1, n_unique_layers)) |
|
|
| self.norm = RMSNorm(dim) |
|
|
| self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1}) |
| self.mtp_adapters = nn.ModuleDict( |
| {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons} |
| ) |
| self.mtp_norms = nn.ModuleDict( |
| {str(h): RMSNorm(dim) for h in self.mtp_horizons} |
| ) |
|
|
| res_scale = (2 * max(1, n_logical_layers)) ** -0.5 |
| for group in (self.blocks, self.prelude, self.coda): |
| if group is None: |
| continue |
| for block in group: |
| block.attn.wo.weight.data.mul_(res_scale) |
| block.ffn.down.weight.data.mul_(res_scale) |
| if self.recurrent is not None: |
| self.recurrent.block.attn.wo.weight.data.mul_(res_scale) |
| self.recurrent.block.ffn.down.weight.data.mul_(res_scale) |
|
|
| self.sleep_gate: Optional[SleepGate] = None |
| if sleep_gate_cap > 0: |
| self.sleep_gate = SleepGate( |
| dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads, |
| retention_enabled=sleep_retention_enabled, |
| retention_hidden=sleep_retention_hidden, |
| ) |
|
|
| self.think_blocks: Optional[nn.ModuleList] = None |
| self.think_norm: Optional[RMSNorm] = None |
| if latent_think_layers > 0: |
| self.think_blocks = nn.ModuleList([ |
| TransformerBlock( |
| dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, |
| ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048, |
| rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, |
| ) |
| for _ in range(latent_think_layers) |
| ]) |
| self.think_norm = RMSNorm(dim) |
|
|
| def resize_token_embeddings(self, new_vocab_size: int) -> None: |
| old_vocab_size = self.embed_tokens.num_embeddings |
| if new_vocab_size == old_vocab_size: |
| return |
| device = self.embed_tokens.weight.device |
| old_embed_weight = self.embed_tokens.weight.data.clone() |
| self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device) |
| self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device) |
| self.head.weight = self.embed_tokens.weight |
| old_bias = self.output_bias.data.clone() |
| self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device)) |
| copy_size = min(old_vocab_size, new_vocab_size) |
| self.output_bias.data[:copy_size] = old_bias[:copy_size] |
| self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size] |
|
|
| def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]: |
| if self.blocks is None: |
| return [] |
| blocks_list = list(self.blocks) |
| full_sequence = blocks_list + blocks_list |
| return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])] |
|
|
| def forward( |
| self, |
| ids: torch.Tensor, |
| use_cache: bool = False, |
| past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, |
| return_hidden: bool = False, |
| ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: |
| B, T = ids.shape |
| x = self.embed_tokens(ids) * self.embed_scale_factor |
| new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None |
| aux: Dict[str, torch.Tensor] = {} |
|
|
| if self.use_recurrent_depth: |
| offset = 0 |
| if self.prelude is not None: |
| for block in self.prelude: |
| past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None |
| x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) |
| if new_past_key_values is not None: |
| new_past_key_values.append(layer_kv) |
| offset += 1 |
| encoded = x |
| recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None |
| x, recurrent_aux, recurrent_kv = self.recurrent( |
| x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache, |
| ) |
| aux.update(recurrent_aux) |
| if new_past_key_values is not None and recurrent_kv is not None: |
| new_past_key_values.extend(recurrent_kv) |
| offset += self.recurrent_loops |
| if self.coda is not None: |
| for block in self.coda: |
| past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None |
| x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) |
| if new_past_key_values is not None: |
| new_past_key_values.append(layer_kv) |
| offset += 1 |
| else: |
| logical_layers = self._build_logical_layers() |
| last_logical_idx = len(logical_layers) - 1 |
| for layer_idx, (block, logical_idx) in enumerate(logical_layers): |
| is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx |
| past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None |
| if self.grad_checkpoint and self.training and not use_cache: |
| x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True) |
| else: |
| x, layer_kv = block(x, is_global, past_kv, use_cache, ids) |
| if new_past_key_values is not None: |
| new_past_key_values.append(layer_kv) |
|
|
| x = self.norm(x) |
|
|
| if self.sleep_gate is not None: |
| x = x + self.sleep_gate.read(x) |
| if self.training: |
| self.sleep_gate.write(x) |
|
|
| if self.think_blocks is not None: |
| for think_block in self.think_blocks: |
| x, _ = think_block(x, is_global=True) |
| x = self.think_norm(x) |
|
|
| h_out = x if return_hidden else None |
| logits = self.head(x) |
| if self.embed_scale_factor != 1.0: |
| logits = logits / self.embed_scale_factor |
| logits = logits + self.output_bias |
|
|
| mtp: Dict[int, torch.Tensor] = {} |
| if self.mtp_horizons and self.training: |
| for horizon in self.mtp_horizons: |
| if horizon > 1 and horizon <= T - 1: |
| shifted_h = x[:, :-horizon, :] |
| adapted_h = self.mtp_adapters[str(horizon)](shifted_h) |
| adapted_h = self.mtp_norms[str(horizon)](adapted_h) |
| mtp_logits = self.head(adapted_h) |
| if self.embed_scale_factor != 1.0: |
| mtp_logits = mtp_logits / self.embed_scale_factor |
| mtp_logits = mtp_logits + self.output_bias |
| mtp[horizon] = mtp_logits |
|
|
| return logits, mtp, aux, h_out, new_past_key_values |
|
|
|
|
| |
| |
| |
|
|
|
|
| def build_stop_token_ids(tokenizer: WordTokenizer) -> set: |
| stop_tokens = {tokenizer.eos_id} |
| for tok in ("<|user|>", "<|system|>", "<|assistant|>"): |
| tid = tokenizer.token_to_id.get(tok) |
| if tid is not None: |
| stop_tokens.add(int(tid)) |
| return stop_tokens |
|
|
|
|
| def apply_no_repeat_ngram( |
| logits: torch.Tensor, |
| token_history: Sequence[int], |
| ngram_size: int, |
| ) -> torch.Tensor: |
| if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1): |
| return logits |
| prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple() |
| banned: set = set() |
| for i in range(len(token_history) - ngram_size + 1): |
| if tuple(token_history[i : i + ngram_size - 1]) == prefix: |
| banned.add(int(token_history[i + ngram_size - 1])) |
| if not banned: |
| return logits |
| out = logits.clone() |
| banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long) |
| out[banned_ids] = float("-inf") |
| return out |
|
|
|
|
| def apply_loop_penalty( |
| logits: torch.Tensor, |
| tokenizer: WordTokenizer, |
| generated_text: str, |
| penalty: float = 5.0, |
| ) -> torch.Tensor: |
| """Detect repeated substring loops and penalise continuation tokens.""" |
| if len(generated_text) < 16: |
| return logits |
| out = logits.clone() |
| for span_len in [24, 16, 12, 8]: |
| if len(generated_text) < span_len * 2: |
| continue |
| suffix = generated_text[-span_len:] |
| prev = generated_text[:-span_len].rfind(suffix) |
| if prev == -1: |
| continue |
| next_pos = prev + span_len |
| if next_pos < len(generated_text): |
| next_char = generated_text[next_pos] |
| tid = tokenizer.token_to_id.get(next_char) |
| if tid is not None: |
| out[tid] -= penalty |
| break |
| return out |
|
|
|
|
| def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor: |
| """Filter tokens below min_p fraction of the top token probability.""" |
| if min_p <= 0.0: |
| return logits |
| probs = torch.softmax(logits, dim=-1) |
| threshold = probs.max() * min_p |
| out = logits.clone() |
| out[probs < threshold] = float("-inf") |
| return out |
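| |
| # Worked example (illustrative): with history [5, 7, 5] and ngram_size=2 the |
| # current prefix is (5,), so token 7 is banned because the bigram (5, 7) has |
| # already occurred; apply_min_p then only removes tokens far below the best one. |
| def _example_sampling_filters() -> torch.Tensor: |
|     logits = torch.zeros(10) |
|     logits = apply_no_repeat_ngram(logits, [5, 7, 5], ngram_size=2) |
|     return apply_min_p(logits, min_p=0.5)  # index 7 stays at -inf |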
|
|
|
|
| def generate( |
| model: TinyMemoryLM, |
| tokenizer: WordTokenizer, |
| prompt: str, |
| max_new_tokens: int = 256, |
| temperature: float = 0.8, |
| top_k: int = 16, |
| top_p: float = 0.95, |
| repetition_penalty: float = 1.0, |
| device: str = "cuda", |
| sft_mode: bool = True, |
| stream: bool = True, |
| no_repeat_ngram_size: int = 0, |
| context_window: int = 2048, |
| logit_soft_cap: float = 15.0, |
| min_p: float = 0.05, |
| loop_penalty: float = 5.0, |
| ) -> str: |
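| """Sample a completion for *prompt* and return the visible (non-special) text.""" |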
| if sft_mode: |
| full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" |
| else: |
| full_prompt = prompt |
| input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) |
| input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) |
| visible_tokens: List[str] = [] |
| stop_token_ids = build_stop_token_ids(tokenizer) |
| generated_text = "" |
|
|
| generated_ids: List[int] = [] |
| |
| full_ids_history: List[int] = list(input_ids) |
|
|
| with torch.no_grad(): |
| for _ in range(max_new_tokens): |
| ctx_ids = ( |
| input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t |
| ) |
| logits, *_ = model(ctx_ids) |
| next_logits = logits[0, -1, :].clone() |
|
|
| |
| if logit_soft_cap > 0: |
| next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) |
|
|
| raw_next_logits = next_logits.clone() |
|
|
| |
| if repetition_penalty != 1.0 and generated_ids: |
| for tok_id in set(generated_ids): |
| if next_logits[tok_id] > 0: |
| next_logits[tok_id] /= repetition_penalty |
| else: |
| next_logits[tok_id] *= repetition_penalty |
|
|
| |
| if no_repeat_ngram_size > 0 and generated_ids: |
| next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) |
|
|
| |
| next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) |
|
|
| |
| if temperature != 1.0: |
| next_logits = next_logits / max(temperature, 1e-6) |
|
|
| |
| if min_p > 0: |
| next_logits = apply_min_p(next_logits, min_p) |
|
|
| |
| if top_k > 0: |
| v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) |
| next_logits[next_logits < v[-1]] = float("-inf") |
|
|
| |
| if 0.0 < top_p < 1.0: |
| sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) |
| sorted_probs = torch.softmax(sorted_logits, dim=-1) |
| cumulative_probs = torch.cumsum(sorted_probs, dim=-1) |
| remove_mask = cumulative_probs > top_p |
| remove_mask[0] = False |
| indices_to_remove = sorted_indices[remove_mask] |
| next_logits[indices_to_remove] = float("-inf") |
|
|
| |
| if not torch.isfinite(next_logits).any(): |
| next_logits = raw_next_logits |
| if temperature != 1.0: |
| next_logits = next_logits / max(temperature, 1e-6) |
|
|
| if temperature == 0: |
| next_id = torch.argmax(next_logits).item() |
| else: |
| probs = torch.softmax(next_logits, dim=-1) |
| next_id = torch.multinomial(probs, num_samples=1).item() |
| if next_id in stop_token_ids: |
| break |
| token_str = ( |
| tokenizer.id_to_token[next_id] |
| if next_id < len(tokenizer.id_to_token) |
| else "" |
| ) |
| generated_ids.append(next_id) |
| full_ids_history.append(next_id) |
| if token_str not in tokenizer.special: |
| visible_tokens.append(token_str) |
| generated_text += token_str |
| if stream: |
| print(token_str, end="", flush=True) |
| input_ids_t = torch.cat( |
| [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 |
| ) |
| if stream: |
| print() |
| return "".join(visible_tokens) |
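| |
| # Hedged usage sketch: run a prompt through a freshly initialised (untrained) |
| # model as a smoke test; real use goes through discover_models() / |
| # load_local_model() further below. |
| def _example_generate_smoke_test(prompt: str = "Hello") -> str: |
|     tok = WordTokenizer() |
|     model = TinyMemoryLM( |
|         vocab_size=tok.vocab_size, dim=64, n_unique_layers=2, n_logical_layers=4, |
|         n_heads=4, n_kv_heads=2, ffn_dim=128, dropout=0.0, mtp_horizons=(), |
|         grad_checkpoint=False, |
|     ).eval() |
|     return generate(model, tok, prompt, max_new_tokens=8, device="cpu", stream=False) |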
|
|
|
|
| def generate_stream( |
| model: TinyMemoryLM, |
| tokenizer: WordTokenizer, |
| prompt: str, |
| max_new_tokens: int = 256, |
| temperature: float = 0.8, |
| top_k: int = 16, |
| top_p: float = 0.95, |
| repetition_penalty: float = 1.0, |
| device: str = "cpu", |
| sft_mode: bool = True, |
| no_repeat_ngram_size: int = 0, |
| context_window: int = 2048, |
| logit_soft_cap: float = 15.0, |
| min_p: float = 0.05, |
| loop_penalty: float = 5.0, |
| ) -> "Iterator[str]": |
| """Yield the accumulated response string after each new token (for Gradio streaming).""" |
| if sft_mode: |
| full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" |
| else: |
| full_prompt = prompt |
| input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) |
| input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) |
| stop_token_ids = build_stop_token_ids(tokenizer) |
| generated_ids: list = [] |
| generated_text = "" |
|
|
| with torch.no_grad(): |
| for _ in range(max_new_tokens): |
| ctx_ids = input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t |
| logits, *_ = model(ctx_ids) |
| next_logits = logits[0, -1, :].clone() |
|
|
| if logit_soft_cap > 0: |
| next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) |
|
|
| raw_next_logits = next_logits.clone() |
|
|
| if repetition_penalty != 1.0 and generated_ids: |
| for tok_id in set(generated_ids): |
| if next_logits[tok_id] > 0: |
| next_logits[tok_id] /= repetition_penalty |
| else: |
| next_logits[tok_id] *= repetition_penalty |
|
|
| if no_repeat_ngram_size > 0 and generated_ids: |
| next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) |
|
|
| next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) |
|
|
| if temperature != 1.0: |
| next_logits = next_logits / max(temperature, 1e-6) |
| if min_p > 0: |
| next_logits = apply_min_p(next_logits, min_p) |
| if top_k > 0: |
| v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) |
| next_logits[next_logits < v[-1]] = float("-inf") |
| if 0.0 < top_p < 1.0: |
| sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) |
| sorted_probs = torch.softmax(sorted_logits, dim=-1) |
| cumulative_probs = torch.cumsum(sorted_probs, dim=-1) |
| remove_mask = cumulative_probs > top_p |
| remove_mask[0] = False |
| next_logits[sorted_indices[remove_mask]] = float("-inf") |
| if not torch.isfinite(next_logits).any(): |
| next_logits = raw_next_logits |
| if temperature != 1.0: |
| next_logits = next_logits / max(temperature, 1e-6) |
|
|
| if temperature == 0: |
| next_id = int(torch.argmax(next_logits).item()) |
| else: |
| probs = torch.softmax(next_logits, dim=-1) |
| next_id = int(torch.multinomial(probs, num_samples=1).item()) |
|
|
| if next_id in stop_token_ids: |
| break |
|
|
| token_str = tokenizer.id_to_token[next_id] if next_id < len(tokenizer.id_to_token) else "" |
| generated_ids.append(next_id) |
| if token_str not in tokenizer.special: |
| generated_text += token_str |
| yield generated_text |
|
|
| input_ids_t = torch.cat( |
| [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 |
| ) |
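| |
| # Hedged sketch: consuming the streaming generator (e.g. from a Gradio handler); |
| # each yield is the full response so far, so a UI simply re-renders it. |
| def _example_stream_consumer(model: TinyMemoryLM, tok: WordTokenizer, prompt: str) -> str: |
|     text = "" |
|     for text in generate_stream(model, tok, prompt, max_new_tokens=32, device="cpu"): |
|         pass |
|     return text |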
|
|
|
|
| |
| |
| |
|
|
|
|
| def series_from_name(name: str) -> str | None: |
| lower = (name or "").lower() |
| if "haiku" in lower: |
| return "Haiku" |
| if "sonnet" in lower: |
| return "Sonnet" |
| if "opus" in lower: |
| return "Opus" |
| return None |
|
|
|
|
| def series_config(series: str) -> dict[str, object]: |
| return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"]) |
|
|
|
|
| def discover_models(runs_dir: Path) -> List[dict]: |
| models = [] |
| if not runs_dir.is_dir(): |
| return models |
| for child in sorted(runs_dir.iterdir()): |
| if not child.is_dir(): |
| continue |
| tokenizer_path = child / "tokenizer.json" |
| if not tokenizer_path.exists(): |
| continue |
| name = child.name |
| series = None |
| for ckpt_name in ("model.pt", "pretrain.pt"): |
| ckpt_path = child / ckpt_name |
| if ckpt_path.exists(): |
| series = _fast_series_from_checkpoint(ckpt_path) |
| break |
| if series is None: |
| series = series_from_name(name) or "Sonnet" |
| found = False |
| for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"): |
| ckpt_path = child / ckpt_name |
| if ckpt_path.exists(): |
| models.append( |
| { |
| "name": name, |
| "checkpoint": ckpt_name, |
| "series": series, |
| "model_path": ckpt_path, |
| "tokenizer_path": tokenizer_path, |
| } |
| ) |
| found = True |
| if not found: |
| step_ckpts = sorted( |
| child.glob("checkpoint_step_*.pt"), |
| key=lambda p: int(p.stem.rsplit("_", 1)[-1]), |
| ) |
| if step_ckpts: |
| ckpt_path = step_ckpts[-1] |
| models.append( |
| { |
| "name": name, |
| "checkpoint": ckpt_path.name, |
| "series": series, |
| "model_path": ckpt_path, |
| "tokenizer_path": tokenizer_path, |
| } |
| ) |
| return models |
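| |
| # Expected layout, per the checks above: runs_dir/<run_name>/ containing |
| # tokenizer.json plus one of model.pt, model_rep.pt, pretrain.pt, or a |
| # checkpoint_step_<N>.pt (the highest step is used as a fallback). |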
|
|
|
|
| def _detect_engram(state_dict): |
| for key in state_dict: |
| if ".engram." in key: |
| if ".embeddings." in key: |
| return state_dict[key].shape[-1] |
| return 0 |
|
|
|
|
| def _detect_mhc(state_dict): |
| for key, val in state_dict.items(): |
| if ".mhc_attn.bias_pre" in key and val.dim() == 2: |
| return val.shape[-1] |
| return 1 |
|
|
|
|
| def _detect_sleep_gate(state_dict) -> Tuple[int, int]: |
| for key, val in state_dict.items(): |
| if key == "sleep_gate.mem_emb" and val.dim() == 2: |
| cap = val.shape[0] |
| return cap, 4 |
| return 0, 4 |
|
|
|
|
| def _detect_latent_think(state_dict) -> int: |
| indices = { |
| int(k.split(".")[1]) |
| for k in state_dict |
| if k.startswith("think_blocks.") and k.split(".")[1].isdigit() |
| } |
| return max(indices) + 1 if indices else 0 |
|
|
|
|
| def _detect_prelude_layers(state_dict) -> int: |
| indices = { |
| int(k.split(".")[1]) |
| for k in state_dict |
| if k.startswith("prelude.") and k.split(".")[1].isdigit() |
| } |
| return max(indices) + 1 if indices else 0 |
|
|
|
|
| def _detect_coda_layers(state_dict) -> int: |
| indices = { |
| int(k.split(".")[1]) |
| for k in state_dict |
| if k.startswith("coda.") and k.split(".")[1].isdigit() |
| } |
| return max(indices) + 1 if indices else 0 |
|
|
|
|
| def _detect_recurrent_loops(state_dict) -> int: |
| if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict: |
| if "recurrent.lora.scale.weight" in state_dict: |
| return state_dict["recurrent.lora.scale.weight"].shape[0] |
| return 1 |
| return 0 |
|
|
|
|
| def _detect_recurrent_lora_rank(state_dict) -> int: |
| for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): |
| if key in state_dict: |
| shape = state_dict[key].shape |
| if len(shape) == 2: |
| return int(shape[0]) |
| return 0 |
|
|
|
|
| def _infer_series_from_lora_rank(rank: int) -> str | None: |
| if rank == 0: |
| return None |
| if rank <= 8: |
| return "haiku" |
| if rank <= 16: |
| return "sonnet" |
| return "opus" |
|
|
|
|
| def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None: |
| try: |
| cp = torch.load(ckpt_path, map_location="cpu", weights_only=False) |
| sd = cp.get("model_state", cp.get("state_dict", {})) |
| rank = 0 |
| for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): |
| if key in sd: |
| rank = int(sd[key].shape[0]) |
| break |
| if rank == 0: |
| return None |
| if rank <= 8: |
| return "Haiku" |
| if rank <= 16: |
| return "Sonnet" |
| return "Opus" |
| except Exception: |
| pass |
| return None |
|
|
|
|
| def _infer_arch_from_state_dict(state_dict, cfg): |
| """Infer architecture hyper-parameters directly from checkpoint weights, |
| falling back to *cfg* (series config) when a key is not found.""" |
| overrides = {} |
|
|
| has_prelude = any(k.startswith("prelude.") for k in state_dict) |
| has_blocks = any(k.startswith("blocks.") for k in state_dict) |
| has_recurrent = any(k.startswith("recurrent.") for k in state_dict) |
| uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks |
|
|
| |
| if "embed_tokens.weight" in state_dict: |
| overrides["dim"] = state_dict["embed_tokens.weight"].shape[1] |
|
|
| if uses_recurrent_arch: |
| if "prelude.0.ffn.gate.weight" in state_dict: |
| overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0] |
| overrides["n_unique_layers"] = 0 |
| src = "prelude.0" |
| else: |
| if "blocks.0.ffn.gate.weight" in state_dict: |
| overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0] |
| block_ids = { |
| int(k.split(".")[1]) |
| for k in state_dict |
| if k.startswith("blocks.") and k.split(".")[1].isdigit() |
| } |
| if block_ids: |
| overrides["n_unique_layers"] = max(block_ids) + 1 |
| src = "blocks.0" |
|
|
| dim = overrides.get("dim", int(cfg.get("dim", model_config.dim))) |
| if f"{src}.attn.wq.weight" in state_dict: |
| wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0] |
| if f"{src}.attn.q_norm.weight" in state_dict: |
| head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0] |
| overrides["n_heads"] = wq_rows // head_dim |
| if f"{src}.attn.wk.weight" in state_dict: |
| wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0] |
| if f"{src}.attn.k_norm.weight" in state_dict: |
| head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0] |
| overrides["n_kv_heads"] = wk_rows // head_dim |
|
|
| |
| for key, val in state_dict.items(): |
| if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2: |
| overrides["engram_table_size"] = val.shape[0] |
| overrides["engram_dim"] = val.shape[1] |
| break |
| engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0))) |
| engram_max_ngram = int(cfg.get("engram_max_ngram", 2)) |
| if engram_dim > 0: |
| for key, val in state_dict.items(): |
| if ".engram.branch_conv.weight" in key and val.dim() == 3: |
| total_branch_dim = val.shape[0] |
| denom = engram_dim * (engram_max_ngram - 1) |
| if denom > 0 and total_branch_dim % denom == 0: |
| overrides["engram_heads"] = total_branch_dim // denom |
| break |
|
|
| merged = dict(cfg) |
| merged.update(overrides) |
| return merged |
|
|
|
|
| def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict: |
| tokenizer = WordTokenizer.load(tokenizer_path) |
| ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False) |
| cfg = series_config(series) |
| vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size)) |
|
|
| state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt |
|
|
| cfg = _infer_arch_from_state_dict(state_dict, cfg) |
|
|
| engram_dim = int(cfg.get("engram_dim", 0)) |
| if _detect_engram(state_dict) == 0: |
| engram_dim = 0 |
|
|
| mhc_expansion = _detect_mhc(state_dict) |
| if mhc_expansion == 1: |
| mhc_expansion = int(cfg.get("mhc_expansion", 1)) |
|
|
| ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict) |
| sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0)) |
| sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4)) |
| sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True)) |
| sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0)) |
|
|
| latent_think_layers = _detect_latent_think(state_dict) |
| if latent_think_layers == 0: |
| latent_think_layers = int(cfg.get("latent_think_layers", 0)) |
|
|
| prelude_layers = _detect_prelude_layers(state_dict) |
| coda_layers = _detect_coda_layers(state_dict) |
| recurrent_loops = _detect_recurrent_loops(state_dict) |
|
|
| ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict) |
| if ckpt_lora_rank > 0: |
| inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank) |
| if inferred_series and inferred_series != series.lower(): |
| series = inferred_series.capitalize() |
| cfg = _infer_arch_from_state_dict(state_dict, series_config(series)) |
| recurrent_lora_rank = ckpt_lora_rank |
| else: |
| recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0)) |
|
|
| recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99)) |
| recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0)) |
|
|
| n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers)) |
|
|
| model = TinyMemoryLM( |
| vocab_size=vocab_size, |
| dim=int(cfg.get("dim", model_config.dim)), |
| n_unique_layers=n_unique, |
| n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)), |
| n_heads=int(cfg.get("n_heads", model_config.n_heads)), |
| n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)), |
| ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)), |
| dropout=float(cfg.get("dropout", model_config.dropout)), |
| mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)), |
| grad_checkpoint=False, |
| sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))), |
| rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))), |
| embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))), |
| engram_dim=engram_dim, |
| engram_heads=int(cfg.get("engram_heads", 4)), |
| engram_table_size=int(cfg.get("engram_table_size", 8192)), |
| engram_max_ngram=int(cfg.get("engram_max_ngram", 3)), |
| mhc_expansion=mhc_expansion, |
| sleep_gate_cap=sleep_gate_cap, |
| sleep_gate_heads=sleep_gate_heads, |
| sleep_retention_enabled=sleep_retention_enabled, |
| sleep_retention_hidden=sleep_retention_hidden, |
| latent_think_layers=latent_think_layers, |
| prelude_layers=prelude_layers, |
| coda_layers=coda_layers, |
| recurrent_loops=recurrent_loops, |
| recurrent_act_threshold=recurrent_act_threshold, |
| recurrent_lora_rank=recurrent_lora_rank, |
| recurrent_loop_embed_dim=recurrent_loop_embed_dim, |
| ) |
| model.load_state_dict(state_dict, strict=False) |
| model.eval() |
| if tokenizer.vocab_size > vocab_size: |
| model.resize_token_embeddings(tokenizer.vocab_size) |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| model = model.to(device) |
| return { |
| "model": model, |
| "tokenizer": tokenizer, |
| "device": device, |
| "series": series, |
| "sft_mode": ckpt.get("sft_mode", None), |
| "phase": ckpt.get("phase", None), |
| } |
|
|
|
|
| |
| |
| |
|
|
| def download_huggingface_model(hf_id: str, cache_dir: Path) -> dict | None: |
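| """Snapshot-download *hf_id* into *cache_dir* and locate its checkpoint and tokenizer; returns None on failure.""" |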
| try: |
| from huggingface_hub import snapshot_download |
| except ImportError: |
| print("huggingface_hub not installed. Install with: pip install huggingface_hub") |
| sys.exit(1) |
|
|
| print(f"Downloading {hf_id}...") |
| try: |
| local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) |
| except Exception as e: |
| print(f"Failed to download {hf_id}: {e}") |
| return None |
|
|
| print(f"Using cached {hf_id} from {local_dir}") |
|
|
| |
| if (local_dir / "models").exists(): |
| model_dir = local_dir / "models" |
| elif (local_dir / "model").exists(): |
| model_dir = local_dir / "model" |
| else: |
| model_dir = local_dir |
| model_path = model_dir / "model.pt" |
| pretrain_path = model_dir / "pretrain.pt" |
| tokenizer_path = model_dir / "tokenizer.json" |
|
|
| ckpt_path = None |
| for p in [model_path, pretrain_path]: |
| if p.exists(): |
| ckpt_path = p |
| break |
|
|
| if ckpt_path is None or not tokenizer_path.exists(): |
| print(f"Missing model files in {model_dir}") |
| print(f" model.pt exists: {model_path.exists()}") |
| print(f" pretrain.pt exists: {pretrain_path.exists()}") |
| print(f" tokenizer.json exists: {tokenizer_path.exists()}") |
| return None |
|
|
| return { |
| "model_path": ckpt_path, |
| "tokenizer_path": tokenizer_path, |
| "model_name": ckpt_path.stem, |
| } |
|
|
|
|
| def load_huggingface_model(hf_id: str, cache_dir: Path) -> dict | None: |
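| """Download a HuggingFace repo and load it as a local model bundle (series starts as "Haiku" and may be corrected from the checkpoint).""" |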
| files = download_huggingface_model(hf_id, cache_dir) |
| if files is None: |
| return None |
|
|
| return load_local_model(files["model_path"], files["tokenizer_path"], "Haiku") |
|
|
|
|
| |
| |
| |
|
|
| _hf_model_cache: Dict[str, dict] = {} |
|
|
|
|
| def prefetch_huggingface_models() -> None: |
| root = Path(__file__).resolve().parent |
| cache_dir = root / "cache" / "huggingface" |
| cache_dir.mkdir(parents=True, exist_ok=True) |
|
|
| print("Downloading/preparing HuggingFace models...") |
| for name, hf_id in HUGGINGFACE_MODELS.items(): |
| print(f" {name}...") |
| bundle = load_huggingface_model(hf_id, cache_dir) |
| if bundle: |
| _hf_model_cache[name] = bundle |
| print(f"Prepared {len(_hf_model_cache)} HuggingFace models") |
|
|
|
|
| def compare_all_models(prompt: str, cfg: dict) -> None: |
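| """Run the same prompt through every matching local model plus any prefetched HF models and print the outputs side by side.""" |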
| root = Path(__file__).resolve().parent |
| runs_dir = root / "runs" |
| all_models = discover_models(runs_dir) |
|
|
| is_pretrain = not cfg.get("sft_mode", True) |
| local_models = [ |
| m for m in all_models |
| if ("pretrain" in m["checkpoint"]) == is_pretrain |
| ] |
|
|
| if not local_models: |
| print("No models found matching mode.") |
| return |
|
|
| results: List[dict] = [] |
|
|
| for m in local_models: |
| print(f"\n{'='*60}") |
| print(f"Loading local {m['name']}/{m['checkpoint']}...") |
| try: |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) |
| except Exception as e: |
| print(f"Failed to load {m['name']}: {e}") |
| continue |
|
|
| model = bundle["model"] |
| tokenizer = bundle["tokenizer"] |
| device = bundle["device"] |
|
|
| print(f"Generating on '{prompt}'...") |
| output = generate( |
| model=model, |
| tokenizer=tokenizer, |
| prompt=prompt, |
| max_new_tokens=cfg["max_new_tokens"], |
| temperature=cfg["temperature"], |
| top_k=cfg["top_k"], |
| top_p=cfg["top_p"], |
| min_p=cfg["min_p"], |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], |
| repetition_penalty=cfg["repetition_penalty"], |
| logit_soft_cap=cfg["logit_soft_cap"], |
| loop_penalty=cfg["loop_penalty"], |
| device=str(device), |
| sft_mode=cfg["sft_mode"], |
| stream=True, |
| context_window=cfg["context_window"], |
| ) |
|
|
| results.append({ |
| "name": f"[LOCAL] {m['name']}/{m['checkpoint']}", |
| "output": output, |
| "device": device, |
| }) |
|
|
| for name, bundle in _hf_model_cache.items(): |
| print(f"\n{'='*60}") |
| print(f"Loading {name} (cached)...") |
|
|
| model = bundle["model"] |
| tokenizer = bundle["tokenizer"] |
| device = bundle["device"] |
|
|
| print(f"Generating on '{prompt}'...") |
| output = generate( |
| model=model, |
| tokenizer=tokenizer, |
| prompt=prompt, |
| max_new_tokens=cfg["max_new_tokens"], |
| temperature=cfg["temperature"], |
| top_k=cfg["top_k"], |
| top_p=cfg["top_p"], |
| min_p=cfg["min_p"], |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], |
| repetition_penalty=cfg["repetition_penalty"], |
| logit_soft_cap=cfg["logit_soft_cap"], |
| loop_penalty=cfg["loop_penalty"], |
| device=str(device), |
| sft_mode=cfg["sft_mode"], |
| stream=True, |
| context_window=cfg["context_window"], |
| ) |
|
|
| results.append({ |
| "name": name, |
| "output": output, |
| "device": device, |
| }) |
|
|
| print(f"\n{'='*60}") |
| print("=" * 60) |
| print("SIDE-BY-SIDE COMPARISON") |
| print("=" * 60) |
| for r in results: |
| print(f"\n--- {r['name']} ---") |
| print(r["output"]) |
| print(f"\n{'='*60}") |
|
|
|
|
| |
| |
| |
|
|
| BENCHMARKS = { |
| "blimp": { |
| "label": "BLiMP", |
| "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.", |
| "hf_dataset": ("nyu-mll/blimp", None), |
| "metric": "accuracy", |
| }, |
| "wikitext2": { |
| "label": "WikiText-2", |
| "desc": "LM perplexity on Wikipedia test split. Lower is better.", |
| "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"), |
| "metric": "perplexity", |
| }, |
| "arc_easy": { |
| "label": "ARC-Easy", |
| "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.", |
| "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"), |
| "metric": "accuracy", |
| }, |
| } |
|
|
|
|
| def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float: |
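| """Mean next-token negative log-likelihood of *text* under the model (NaN if the text is too short to score).""" |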
| ids = tokenizer.encode(text, add_bos=True, add_eos=False) |
| if len(ids) < 2: |
| return float("nan") |
| ids_t = torch.tensor([ids], dtype=torch.long, device=device) |
| with torch.no_grad(): |
| logits, *_ = model(ids_t) |
| log_probs = F.log_softmax(logits[0], dim=-1) |
| targets = ids_t[0, 1:] |
| nll = -log_probs[range(len(targets)), targets].mean().item() |
| return nll |
|
|
|
|
| def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float: |
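| """Mean NLL of the *completion* tokens given *context*; used for perplexity-based multiple-choice selection.""" |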
| full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False) |
| ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False) |
| n_ctx = len(ctx_ids) |
| n_ref = len(full_ids) - n_ctx |
| if n_ref <= 0: |
| return float("nan") |
| ids_t = torch.tensor([full_ids], dtype=torch.long, device=device) |
| with torch.no_grad(): |
| logits, *_ = model(ids_t) |
| log_probs = F.log_softmax(logits[0], dim=-1) |
| targets = ids_t[0, 1:] |
| ref_start = n_ctx - 1 |
| ref_end = min(ref_start + n_ref, log_probs.shape[0]) |
| if ref_start >= ref_end: |
| return float("nan") |
| nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item() |
| return nll |
|
|
|
|
| BLIMP_PARADIGMS = [ |
| "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement", |
| "animate_subject_passive", "animate_subject_trans", "causative", |
| "complex_NP_island", "coordinate_structure_constraint_complex_left_branch", |
| "coordinate_structure_constraint_object_extraction", |
| "determiner_noun_agreement_1", "determiner_noun_agreement_2", |
| "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2", |
| "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1", |
| "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1", |
| "distractor_agreement_relational_noun", "distractor_agreement_relative_clause", |
| "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2", |
| "existential_there_object_raising", "existential_there_quantifiers_1", |
| "existential_there_quantifiers_2", "existential_there_subject_raising", |
| "expletive_it_object_raising", "inchoative", "intransitive", |
| "irregular_past_participle_adjectives", "irregular_past_participle_verbs", |
| "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2", |
| "left_branch_island_echo_question", "left_branch_island_simple_question", |
| "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2", |
| "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2", |
| "principle_A_c_command", "principle_A_case_1", "principle_A_case_2", |
| "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3", |
| "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1", |
| "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present", |
| "sentential_negation_npi_scope", "sentential_subject_island", |
| "superlative_quantifiers_1", "superlative_quantifiers_2", |
| "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island", |
| "wh_questions_object_gap", "wh_questions_subject_gap", |
| "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap", |
| "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap", |
| "wh_vs_that_with_gap_long_distance", |
| ] |
|
|
|
|
| def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]: |
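| """Per-paradigm BLiMP accuracy: a pair counts as correct when the grammatical sentence scores a lower NLL than the ungrammatical one.""" |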
| from datasets import load_dataset |
| accuracies: List[float] = [] |
| for paradigm in BLIMP_PARADIGMS: |
| try: |
| ds = load_dataset("nyu-mll/blimp", paradigm, split="train") |
| except Exception as e: |
| print(f" {paradigm}: skip ({e})") |
| accuracies.append(float("nan")) |
| continue |
| items = list(ds)[:n_samples] |
| correct = 0 |
| for ex in items: |
| good_nll = _score_text(model, tokenizer, ex["sentence_good"], device) |
| bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device) |
| if math.isnan(good_nll) or math.isnan(bad_nll): |
| continue |
| if good_nll < bad_nll: |
| correct += 1 |
| acc = correct / len(items) if items else float("nan") |
| accuracies.append(acc) |
| print(f" {paradigm:50s} acc={acc:.3f}") |
| return BLIMP_PARADIGMS, accuracies |
|
|
|
|
| def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]: |
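| """Character-chunked WikiText-2 perplexity; returns one perplexity value per chunk.""" |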
| from datasets import load_dataset |
| ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") |
| full_text = "\n".join(ex["text"] for ex in ds if ex["text"].strip()) |
| chunks = [full_text[i:i + chunk_chars] for i in range(0, len(full_text), chunk_chars)] |
| chunks = [c for c in chunks if len(c) > 20][:max_chunks] |
| labels: List[str] = [] |
| ppls: List[float] = [] |
| for i, chunk in enumerate(chunks): |
| nll = _score_text(model, tokenizer, chunk, device) |
| ppl = math.exp(nll) if not math.isnan(nll) else float("nan") |
| labels.append(f"chunk {i + 1}") |
| ppls.append(ppl) |
| if (i + 1) % 10 == 0: |
| valid = [v for v in ppls if not math.isnan(v)] |
| mean = sum(valid) / len(valid) if valid else float("nan") |
| print(f" chunk {i + 1}/{len(chunks)} running mean ppl={mean:.2f}") |
| return labels, ppls |
|
|
|
|
| def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]: |
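| """ARC-Easy accuracy: predict the answer choice whose completion has the lowest NLL after the question.""" |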
| from datasets import load_dataset |
| ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test") |
| items = list(ds)[:max_samples] |
| labels: List[str] = [] |
| scores: List[float] = [] |
| for i, ex in enumerate(items): |
| question = ex["question"] |
| choices = ex["choices"]["text"] |
| choice_labels = ex["choices"]["label"] |
| answer_key = ex["answerKey"] |
| context = f"Question: {question}\nAnswer:" |
| nlls = [_score_completion(model, tokenizer, context, f" {c}", device) for c in choices] |
| if all(math.isnan(v) for v in nlls): |
| scores.append(float("nan")) |
| else: |
| best_idx = min(range(len(nlls)), key=lambda j: nlls[j] if not math.isnan(nlls[j]) else float("inf")) |
| predicted = choice_labels[best_idx] |
| scores.append(1.0 if predicted == answer_key else 0.0) |
| labels.append(f"Q{i + 1}") |
| n_valid = sum(1 for s in scores if not math.isnan(s)) |
| acc = sum(s for s in scores if not math.isnan(s)) / n_valid if n_valid else float("nan") |
| print(f" {n_valid} questions evaluated, accuracy={acc:.3f}") |
| return labels, scores |
|
|
|
|
| def run_benchmark_mode() -> None: |
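| """Interactive benchmark runner: pick a benchmark and a set of models, evaluate, and save a comparison bar chart.""" |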
| try: |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| except ImportError: |
| print("matplotlib not installed. pip install matplotlib") |
| return |
|
|
| bench_keys = list(BENCHMARKS.keys()) |
| print("\nBenchmarks:") |
| for i, k in enumerate(bench_keys): |
| b = BENCHMARKS[k] |
| print(f" [{i + 1}] {b['label']} β {b['desc']}") |
| print("Select benchmark [1]:", end=" ", flush=True) |
| try: |
| b_choice = input().strip() or "1" |
| except (EOFError, KeyboardInterrupt): |
| print() |
| return |
| if not (b_choice.isdigit() and 1 <= int(b_choice) <= len(bench_keys)): |
| print("Invalid selection.") |
| return |
| bench_key = bench_keys[int(b_choice) - 1] |
| bench = BENCHMARKS[bench_key] |
| print(f"Benchmark: {bench['label']}") |
|
|
| root = Path(__file__).resolve().parent |
| runs_dir = root / "runs" |
| all_models = discover_models(runs_dir) |
|
|
| model_entries: List[dict] = [] |
| for m in all_models: |
| model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m}) |
| for hf_name, hf_id in HUGGINGFACE_MODELS.items(): |
| model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name}) |
|
|
| if not model_entries: |
| print("No models found.") |
| return |
|
|
| print("\nAvailable models:") |
| for i, e in enumerate(model_entries): |
| print(f" [{i + 1}] {e['label']}") |
| print(" [a] All models") |
| print("Select models (comma-separated or 'a'):", end=" ", flush=True) |
| try: |
| raw = input().strip() |
| except (EOFError, KeyboardInterrupt): |
| print() |
| return |
|
|
| if raw.lower() == "a": |
| selected = list(range(len(model_entries))) |
| else: |
| selected = [] |
| for tok in raw.split(","): |
| tok = tok.strip() |
| if tok.isdigit() and 1 <= int(tok) <= len(model_entries): |
| selected.append(int(tok) - 1) |
| if not selected: |
| print("No valid selection.") |
| return |
|
|
| all_results: List[dict] = [] |
| shared_x_labels: Optional[List[str]] = None |
|
|
| for idx in selected: |
| entry = model_entries[idx] |
| print(f"\n{'='*60}\nLoading {entry['label']}...") |
| try: |
| if entry["type"] == "local": |
| m = entry["meta"] |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) |
| else: |
| bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache") |
| except Exception as e: |
| print(f" Failed: {e}") |
| continue |
|
|
| model = bundle["model"] |
| tokenizer = bundle["tokenizer"] |
| device = str(bundle["device"]) |
| model.eval() |
|
|
| if bench_key == "blimp": |
| x_labels, y_vals = _run_blimp(model, tokenizer, device) |
| elif bench_key == "wikitext2": |
| x_labels, y_vals = _run_wikitext2(model, tokenizer, device) |
| else: |
| x_labels, y_vals = _run_arc_easy(model, tokenizer, device) |
|
|
| if shared_x_labels is None: |
| shared_x_labels = x_labels |
|
|
| valid = [v for v in y_vals if not math.isnan(v)] |
| summary = sum(valid) / len(valid) if valid else float("nan") |
| all_results.append({"label": entry["label"], "y": y_vals, "summary": summary}) |
|
|
| if not all_results or shared_x_labels is None: |
| print("No results to plot.") |
| return |
|
|
| metric = bench["metric"] |
| paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), |
| reverse=(metric != "perplexity")) |
| summaries, model_labels = zip(*paired) if paired else ([], []) |
| n = len(summaries) |
| colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] |
|
|
| fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6)) |
| bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") |
| for bar, val in zip(bars, summaries): |
| ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, |
| f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") |
|
|
| ylabel = "Mean Perplexity (β better)" if metric == "perplexity" else "Mean Accuracy (β better)" |
| ax.set_ylabel(ylabel) |
| ax.set_title(f"{bench['label']} Benchmark β Model Comparison") |
| ax.set_xticks(range(n)) |
| ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9) |
| if metric == "accuracy": |
| ax.set_ylim(0, 1.05) |
| ax.grid(True, axis="y", alpha=0.3) |
| plt.tight_layout() |
|
|
| out_path = root / f"benchmark_{bench_key}.png" |
| plt.savefig(str(out_path), dpi=150) |
| print(f"\nChart saved to {out_path}") |
| try: |
| import subprocess |
| subprocess.Popen(["xdg-open", str(out_path)]) |
| except Exception: |
| pass |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _pick_series(detected: str) -> str: |
| series_list = list(MODEL_SERIES.keys()) |
| detected_lower = detected.lower() |
| default_idx = next( |
| (i + 1 for i, s in enumerate(series_list) if s == detected_lower), 1 |
| ) |
|
|
| |
| if len(series_list) == 1: |
| return series_list[0].capitalize() |
|
|
| print("Series:") |
| for i, s in enumerate(series_list): |
| marker = " (detected)" if s == detected_lower else "" |
| print(f" [{i + 1}] {s.capitalize()}{marker}") |
| while True: |
| try: |
| choice = input(f"Select series [{default_idx}]: ").strip() |
| except (EOFError, KeyboardInterrupt): |
| print() |
| sys.exit(0) |
| if not choice: |
| choice = str(default_idx) |
| if choice.isdigit() and 1 <= int(choice) <= len(series_list): |
| return series_list[int(choice) - 1].capitalize() |
| print(f"Enter a number 1-{len(series_list)}") |
|
|
|
|
| def pick_model(runs_dir: Path) -> tuple[dict, str]: |
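| """List checkpoints discovered under *runs_dir* and load whichever one the user selects.""" |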
| models = discover_models(runs_dir) |
| if not models: |
| print(f"No models found in {runs_dir}") |
| print("Expected layout: runs/<name>/model.pt (or pretrain.pt) + tokenizer.json") |
| sys.exit(1) |
|
|
| if len(models) == 1: |
| m = models[0] |
| print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) |
| return bundle, m["checkpoint"] |
|
|
| print("Available models:") |
| for i, m in enumerate(models): |
| print(f" [{i + 1}] {m['name']}/{m['checkpoint']} ({m['series']})") |
| while True: |
| try: |
| choice = input("Select model [1]: ").strip() |
| except (EOFError, KeyboardInterrupt): |
| print() |
| sys.exit(0) |
| if not choice: |
| choice = "1" |
| if choice.isdigit() and 1 <= int(choice) <= len(models): |
| m = models[int(choice) - 1] |
| print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") |
| bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) |
| return bundle, m["checkpoint"] |
| print(f"Enter a number 1-{len(models)}") |
|
|
|
|
| |
| |
| |
|
|
| MODES = { |
| "chat-coherent": { |
| "label": "Chat β Coherent", |
| "desc": "structured, consistent, strong repetition control", |
| "sft_mode": "chat", |
| "temperature": 0.35, |
| "top_k": 20, |
| "top_p": 0.88, |
| "min_p": 0.10, |
| "no_repeat_ngram_size": 4, |
| "repetition_penalty": 1.22, |
| "logit_soft_cap": 20.0, |
| "loop_penalty": 20.0, |
| "max_new_tokens": 4096, |
| "context_window": 2048, |
| }, |
| "chat-variants": { |
| "label": "Chat β Variants", |
| "desc": "creative, diverse, more surprising outputs", |
| "sft_mode": "chat", |
| "temperature": 0.65, |
| "top_k": 60, |
| "top_p": 0.92, |
| "min_p": 0.05, |
| "no_repeat_ngram_size": 4, |
| "repetition_penalty": 1.12, |
| "logit_soft_cap": 20.0, |
| "loop_penalty": 14.0, |
| "max_new_tokens": 4096, |
| "context_window": 2048, |
| }, |
| "pretrain-coherent": { |
| "label": "Pretrain β Coherent", |
| "desc": "grounded continuation, low temperature, tight sampling", |
| "sft_mode": False, |
| "temperature": 0.3, |
| "top_k": 20, |
| "top_p": 0.85, |
| "min_p": 0.10, |
| "no_repeat_ngram_size": 4, |
| "repetition_penalty": 1.2, |
| "logit_soft_cap": 20.0, |
| "loop_penalty": 20.0, |
| "max_new_tokens": 4096, |
| "context_window": 2048, |
| }, |
| "pretrain-variants": { |
| "label": "Pretrain β Variants", |
| "desc": "free-form continuation, higher temperature, more exploration", |
| "sft_mode": False, |
| "temperature": 0.7, |
| "top_k": 60, |
| "top_p": 0.93, |
| "min_p": 0.04, |
| "no_repeat_ngram_size": 4, |
| "repetition_penalty": 1.12, |
| "logit_soft_cap": 20.0, |
| "loop_penalty": 12.0, |
| "max_new_tokens": 4096, |
| "context_window": 2048, |
| }, |
| } |
|
|
| _MODE_LIST = list(MODES.keys()) |
|
|
|
|
| def pick_mode(is_pretrain: bool) -> dict: |
| """Prompt the user to choose a generation mode. Returns a config dict.""" |
| |
| candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain] |
| print("\nGeneration mode:") |
| for i, key in enumerate(candidates): |
| cfg = MODES[key] |
| print(f" [{i + 1}] {cfg['label']} β {cfg['desc']}") |
| while True: |
| try: |
| choice = input("Select mode [1]: ").strip() |
| except (EOFError, KeyboardInterrupt): |
| print() |
| sys.exit(0) |
| if not choice: |
| choice = "1" |
| if choice.isdigit() and 1 <= int(choice) <= len(candidates): |
| key = candidates[int(choice) - 1] |
| cfg = MODES[key] |
| print(f"Mode: {cfg['label']}") |
| return cfg |
| print(f"Enter a number 1-{len(candidates)}") |
|
|
|
|
| def _run_loop(bundle: dict, cfg: dict) -> None: |
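| """Interactive generation loop: read prompts, stream completions, and handle the /quit, /help and /mode commands.""" |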
| model = bundle["model"] |
| tokenizer = bundle["tokenizer"] |
| device = bundle["device"] |
| sft = cfg["sft_mode"] |
| prompt_label = "You" if sft else "Prompt" |
| print(f"\nModel ready on {device}. Type your message, or /quit to exit.") |
| print(f" temp={cfg['temperature']} top_k={cfg['top_k']} top_p={cfg['top_p']}") |
| print(f" min_p={cfg['min_p']} ng={cfg['no_repeat_ngram_size']} rp={cfg['repetition_penalty']}") |
| print(f" cap={cfg['logit_soft_cap']} loop_penalty={cfg['loop_penalty']}\n") |
| while True: |
| try: |
| prompt = input(f"{prompt_label}: ").strip() |
| except (EOFError, KeyboardInterrupt): |
| print() |
| break |
| if not prompt: |
| continue |
| if prompt in ("/quit", "/exit", "/q"): |
| break |
| if prompt == "/help": |
| print("Commands: /quit /exit /q /help /mode") |
| if sft: |
| print("Anything else is sent as a chat prompt.") |
| else: |
| print("Anything else is sent as a raw continuation prompt.") |
| continue |
| if prompt == "/mode": |
| print(f"Current: {cfg['label']} β {cfg['desc']}") |
| continue |
| print("AI: ", end="", flush=True) |
| generate( |
| model=model, |
| tokenizer=tokenizer, |
| prompt=prompt, |
| max_new_tokens=cfg["max_new_tokens"], |
| temperature=cfg["temperature"], |
| top_k=cfg["top_k"], |
| top_p=cfg["top_p"], |
| min_p=cfg["min_p"], |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], |
| repetition_penalty=cfg["repetition_penalty"], |
| logit_soft_cap=cfg["logit_soft_cap"], |
| loop_penalty=cfg["loop_penalty"], |
| device=str(device), |
| sft_mode=cfg["sft_mode"], |
| stream=True, |
| context_window=cfg["context_window"], |
| ) |
|
|
|
|
|
|
|
|
| |
| |
| |
|
|
| _COLLECTION_SLUG = "CompactAI-O/tmlm-haiku-series" |
| _AUTHOR = "CompactAI-O" |
| _SEARCH = "TMLM-Haiku" |
|
|
| _FALLBACK_COLLECTION = [ |
| {"version": "TMLM-Haiku-2.3", "hf_id": "CompactAI-O/TMLM-Haiku-2.3"}, |
| {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"}, |
| {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"}, |
| {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"}, |
| {"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"}, |
| ] |
|
|
| _EXTRA_REPOS = ["CompactAI-O/Glint-1"] |
|
|
|
|
| def _probe_repo(hf_id: str) -> dict | None: |
| """Return entry dict for one repo, or None if no usable checkpoints found.""" |
| from huggingface_hub import list_repo_files |
|
|
| try: |
| files = set(list_repo_files(hf_id)) |
| except Exception: |
| return None |
|
|
| |
| subdir: str | None = None |
| for candidate in ("models", "model"): |
| if any(f.startswith(f"{candidate}/") for f in files): |
| subdir = candidate |
| break |
|
|
| prefix = f"{subdir}/" if subdir else "" |
|
|
| |
| pt_files = sorted( |
| f[len(prefix):] for f in files |
| if f.startswith(prefix) and f.endswith(".pt") |
| ) |
|
|
| _LABELS = { |
| "model.pt": ("Chat (SFT)", False), |
| "model_rep.pt": ("Chat (anti-repetition)", False), |
| "pretrain.pt": ("Pretrain (base)", True), |
| } |
|
|
| checkpoints = [] |
| for fname in pt_files: |
| label, is_pretrain = _LABELS.get(fname, (fname.removesuffix(".pt"), "pretrain" in fname)) |
| checkpoints.append((label, fname, is_pretrain)) |
|
|
| if not checkpoints: |
| return None |
|
|
| return { |
| "version": hf_id.split("/")[-1], |
| "hf_id": hf_id, |
| "subdir": subdir, |
| "checkpoints": checkpoints, |
| "desc": "", |
| } |
|
|
|
|
| def fetch_collection() -> list[dict]: |
| """Query HF for all CompactAI-O TMLM-Haiku models, newest first.""" |
| from huggingface_hub import HfApi |
|
|
| print("Checking HuggingFace collection for available models...") |
| try: |
| api = HfApi() |
| infos = list( |
| api.list_models( |
| author=_AUTHOR, |
| search=_SEARCH, |
| sort="lastModified", |
| ) |
| ) |
| infos.sort(key=lambda m: getattr(m, "lastModified", ""), reverse=True) |
| except Exception as exc: |
| print(f" Could not reach HuggingFace ({exc}); using fallback list.") |
| infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION] |
|
|
| entries = [] |
| seen_ids: set = set() |
| for info in infos: |
| repo_id = info.id |
| if _SEARCH.lower() not in repo_id.lower(): |
| continue |
| entry = _probe_repo(repo_id) |
| if entry: |
| entries.append(entry) |
| seen_ids.add(repo_id) |
|
|
| |
| for repo_id in _EXTRA_REPOS: |
| if repo_id not in seen_ids: |
| entry = _probe_repo(repo_id) |
| if entry: |
| entries.append(entry) |
| seen_ids.add(repo_id) |
|
|
| if not entries: |
| print(" No models found; using fallback list.") |
| for fb in _FALLBACK_COLLECTION: |
| e = _probe_repo(fb["hf_id"]) |
| if e: |
| entries.append(e) |
|
|
| return entries |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _download_version(entry: dict, cache_dir: Path) -> Path: |
| """Download full repo snapshot; return the directory containing model files.""" |
| try: |
| from huggingface_hub import snapshot_download |
| except ImportError: |
| print("huggingface_hub not installed. Run: pip install huggingface_hub") |
| sys.exit(1) |
|
|
| hf_id = entry["hf_id"] |
| print(f"Fetching {hf_id} ...") |
| try: |
| local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) |
| except Exception as exc: |
| print(f"Download failed: {exc}") |
| sys.exit(1) |
|
|
| subdir = entry.get("subdir") |
| model_dir = (local_dir / subdir) if subdir else local_dir |
| if not model_dir.exists(): |
| |
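| # Fall back to the repo root if the expected subdirectory is missing from the snapshot. |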
| model_dir = local_dir |
| return model_dir |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int: |
| while True: |
| try: |
| raw = input(f"{prompt} [{default}]: ").strip() |
| except (EOFError, KeyboardInterrupt): |
| print() |
| sys.exit(0) |
| if not raw: |
| return default |
| if raw.isdigit() and lo <= int(raw) <= hi: |
| return int(raw) |
| print(f" Enter a number {lo}β{hi}.") |
|
|
|
|
| def pick_version(collection: list[dict]) -> dict: |
| print("\nTMLM-Haiku series (CompactAI-O)\n") |
| for i, entry in enumerate(collection): |
| desc = f" β {entry['desc']}" if entry["desc"] else "" |
| print(f" [{i + 1}] {entry['version']}{desc}") |
| idx = _prompt_int("Select version", 1, len(collection)) |
| return collection[idx - 1] |
|
|
|
|
| def pick_checkpoint(entry: dict) -> tuple[str, bool]: |
| """Return (filename, is_pretrain).""" |
| ckpts = entry["checkpoints"] |
| if len(ckpts) == 1: |
| label, fname, is_pretrain = ckpts[0] |
| print(f" Using: {label} ({fname})") |
| return fname, is_pretrain |
|
|
| print(f"\nCheckpoints for {entry['version']}:") |
| for i, (label, fname, _) in enumerate(ckpts): |
| print(f" [{i + 1}] {label} ({fname})") |
| idx = _prompt_int("Select checkpoint", 1, len(ckpts)) |
| label, fname, is_pretrain = ckpts[idx - 1] |
| return fname, is_pretrain |
|
|
|
|
| |
| |
| |
|
|
|
|
|
|
| |
| |
| |
|
|
| import gradio as gr |
|
|
| _CACHE_DIR = Path(__file__).parent / ".hf_cache" |
| _CACHE_DIR.mkdir(parents=True, exist_ok=True) |
|
|
| _collection_cache: list = [] |
| _model_cache: dict = {} |
|
|
|
|
| def _get_collection() -> list: |
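| """Fetch and memoise the collection listing, falling back to the static list if the query fails.""" |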
| global _collection_cache |
| if not _collection_cache: |
| try: |
| _collection_cache = fetch_collection() |
| except Exception as e: |
| print(f"Warning: fetch_collection failed ({e}); using fallback.") |
| _collection_cache = [ |
| _probe_repo(e["hf_id"]) or {"version": e["version"], "hf_id": e["hf_id"], |
| "subdir": None, "checkpoints": [("Chat (SFT)", "model.pt", False)], "desc": ""} |
| for e in _FALLBACK_COLLECTION |
| ] |
| return _collection_cache |
|
|
|
|
| def _collection_versions() -> list[str]: |
| return [e["version"] for e in _get_collection()] |
|
|
|
|
| def _checkpoints_for(version: str) -> list[tuple[str, str, bool]]: |
| for e in _get_collection(): |
| if e["version"] == version: |
| return e["checkpoints"] |
| return [] |
|
|
|
|
| def _ckpt_labels(version: str) -> list[str]: |
| return [label for label, _, _ in _checkpoints_for(version)] |
|
|
|
|
| def _ckpt_is_pretrain(version: str, label: str) -> bool: |
| for lbl, _, is_pt in _checkpoints_for(version): |
| if lbl == label: |
| return is_pt |
| return False |
|
|
|
|
| def _ckpt_fname(version: str, label: str) -> str: |
| for lbl, fname, _ in _checkpoints_for(version): |
| if lbl == label: |
| return fname |
| return "model.pt" |
|
|
|
|
| def _load_bundle(version: str, ckpt_label: str) -> dict: |
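| """Download (if needed) and cache the model bundle for a version/checkpoint pair.""" |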
| key = f"{version}/{ckpt_label}" |
| if key not in _model_cache: |
| fname = _ckpt_fname(version, ckpt_label) |
| for entry in _get_collection(): |
| if entry["version"] == version: |
| model_dir = _download_version(entry, _CACHE_DIR) |
| model_path = model_dir / fname |
| tokenizer_path = model_dir / "tokenizer.json" |
| _model_cache[key] = load_local_model(model_path, tokenizer_path, "Haiku") |
| break |
| return _model_cache[key] |
|
|
|
|
| def _build_conversation_prompt(history: list[dict], new_message: str) -> str: |
| """Flatten Gradio messages-format history + new turn into a raw prompt.""" |
| parts = [] |
| |
| i = 0 |
| while i < len(history) - 1: |
| u = history[i] |
| a = history[i + 1] |
| if u["role"] == "user" and a["role"] == "assistant": |
| parts.append(f"<|user|>\n{u['content']}\n<|assistant|>\n{a['content']}") |
| i += 2 |
| parts.append(f"<|user|>\n{new_message}\n<|assistant|>\n") |
| return "".join(parts) |
|
|
|
|
| |
|
|
| def _on_version_change(version): |
| labels = _ckpt_labels(version) |
| return gr.update(choices=labels, value=labels[0] if labels else None) |
|
|
|
|
| def _chat_submit(message, history): |
| history = history or [] |
| history.append({"role": "user", "content": message}) |
| history.append({"role": "assistant", "content": ""}) |
| return "", history |
|
|
|
|
| def _chat_stream(history, version, ckpt_label, mode_key, use_custom, |
| temperature, top_k, top_p, min_p, rep_penalty, |
| ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode): |
| if not history or history[-1]["role"] != "assistant": |
| yield history |
| return |
| try: |
| bundle = _load_bundle(version, ckpt_label) |
| except Exception as e: |
| history[-1]["content"] = f"[Error loading model: {e}]" |
| yield history |
| return |
|
|
| prior_msgs = history[:-2] |
| new_msg = history[-2]["content"] |
|
|
| if use_custom: |
| cfg = { |
| "sft_mode": not raw_mode, |
| "temperature": temperature, "top_k": top_k, "top_p": top_p, |
| "min_p": min_p, "repetition_penalty": rep_penalty, |
| "no_repeat_ngram_size": ngram_size, "logit_soft_cap": soft_cap, |
| "loop_penalty": loop_pen, "max_new_tokens": max_tokens, |
| "context_window": ctx_win, |
| } |
| else: |
| cfg = dict(MODES[mode_key]) |
| |
| cfg["max_new_tokens"] = int(max_tokens) |
| cfg["context_window"] = int(ctx_win) |
|
|
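| # Multi-turn chats are flattened into one raw prompt with the chat format tokens, so extra SFT wrapping is disabled. |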
| if prior_msgs: |
| prompt = _build_conversation_prompt(prior_msgs, new_msg) |
| sft = False |
| else: |
| prompt = new_msg |
| sft = cfg["sft_mode"] |
|
|
| for partial in generate_stream( |
| model=bundle["model"], tokenizer=bundle["tokenizer"], |
| prompt=prompt, device=str(bundle["device"]), |
| sft_mode=sft, |
| temperature=cfg["temperature"], top_k=cfg["top_k"], |
| top_p=cfg["top_p"], min_p=cfg["min_p"], |
| repetition_penalty=cfg["repetition_penalty"], |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], |
| logit_soft_cap=cfg["logit_soft_cap"], |
| loop_penalty=cfg["loop_penalty"], |
| max_new_tokens=cfg["max_new_tokens"], |
| context_window=cfg["context_window"], |
| ): |
| history[-1]["content"] = partial |
| yield history |
|
|
|
|
| |
|
|
| def _compare_fn(prompt, selected_versions, mode_key, use_custom, |
| temperature, top_k, top_p, min_p, rep_penalty, |
| ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode): |
| if use_custom: |
| cfg = { |
| "sft_mode": not raw_mode, |
| "temperature": temperature, "top_k": top_k, "top_p": top_p, |
| "min_p": min_p, "repetition_penalty": rep_penalty, |
| "no_repeat_ngram_size": ngram_size, "logit_soft_cap": soft_cap, |
| "loop_penalty": loop_pen, "max_new_tokens": max_tokens, |
| "context_window": ctx_win, |
| } |
| else: |
| cfg = dict(MODES[mode_key]) |
| |
| cfg["max_new_tokens"] = int(max_tokens) |
| cfg["context_window"] = int(ctx_win) |
|
|
| |
| |
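| # One output box exists per known version; run the selected ones in release order and stream text into their boxes. |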
| all_versions = _sort_oldest_to_newest(_collection_versions()) |
| selected = set(selected_versions or []) |
| state = {v: ("⏳ Queued…" if v in selected else "") for v in all_versions} |
|
|
| def _emit(): |
| return [state[v] for v in all_versions] |
|
|
| yield _emit() |
|
|
| for version in all_versions: |
| if version not in selected: |
| continue |
| labels = _ckpt_labels(version) |
| ckpt_label = labels[0] if labels else None |
| if not ckpt_label: |
| state[version] = "[No checkpoint found]" |
| yield _emit() |
| continue |
|
|
| state[version] = "⏳ Loading…" |
| yield _emit() |
|
|
| try: |
| bundle = _load_bundle(version, ckpt_label) |
| except Exception as e: |
| state[version] = f"[Load error: {e}]" |
| yield _emit() |
| continue |
|
|
| state[version] = "" |
| yield _emit() |
|
|
| try: |
| for partial in generate_stream( |
| model=bundle["model"], tokenizer=bundle["tokenizer"], |
| prompt=prompt, device=str(bundle["device"]), |
| sft_mode=cfg["sft_mode"], |
| temperature=cfg["temperature"], top_k=cfg["top_k"], |
| top_p=cfg["top_p"], min_p=cfg["min_p"], |
| repetition_penalty=cfg["repetition_penalty"], |
| no_repeat_ngram_size=cfg["no_repeat_ngram_size"], |
| logit_soft_cap=cfg["logit_soft_cap"], |
| loop_penalty=cfg["loop_penalty"], |
| max_new_tokens=cfg["max_new_tokens"], |
| context_window=cfg["context_window"], |
| ): |
| state[version] = partial |
| yield _emit() |
| except Exception as e: |
| state[version] = f"[Generation error: {e}]" |
| yield _emit() |
|
|
|
|
| |
|
|
| def _benchmark_fn(bench_key, selected_versions, max_samples, |
| progress=gr.Progress(track_tqdm=True)): |
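| """Run the chosen benchmark over the selected versions; returns (progress log text, path to the results chart).""" |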
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| if not selected_versions: |
| return "No models selected.", None |
|
|
| bench = BENCHMARKS[bench_key] |
| all_results = [] |
| log_lines = [f"Benchmark: {bench['label']}", ""] |
|
|
| for version in progress.tqdm(selected_versions, desc="Benchmarking"): |
| log_lines.append(f"--- {version} ---") |
| labels = _ckpt_labels(version) |
| ckpt_label = labels[0] if labels else None |
| if not ckpt_label: |
| log_lines.append(" (no checkpoint)") |
| continue |
| try: |
| bundle = _load_bundle(version, ckpt_label) |
| model, tokenizer, device = bundle["model"], bundle["tokenizer"], str(bundle["device"]) |
| model.eval() |
| if bench_key == "blimp": |
| _, y = _run_blimp(model, tokenizer, device, n_samples=max_samples) |
| elif bench_key == "wikitext2": |
| _, y = _run_wikitext2(model, tokenizer, device, max_chunks=max_samples) |
| else: |
| _, y = _run_arc_easy(model, tokenizer, device, max_samples=max_samples) |
| valid = [v for v in y if not math.isnan(v)] |
| summary = sum(valid) / len(valid) if valid else float("nan") |
| all_results.append({"label": version, "summary": summary}) |
| log_lines.append(f" score: {summary:.4f}") |
| except Exception as e: |
| log_lines.append(f" error: {e}") |
|
|
| if not all_results: |
| return "\n".join(log_lines), None |
|
|
| metric = bench["metric"] |
| paired = sorted( |
| zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), |
| reverse=(metric != "perplexity"), |
| ) |
| summaries, labels_ = zip(*paired) |
| n = len(summaries) |
| colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] |
| fig, ax = plt.subplots(figsize=(max(6, n * 1.6), 5)) |
| bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") |
| for bar, val in zip(bars, summaries): |
| ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, |
| f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") |
| ylabel = "Mean Perplexity (β better)" if metric == "perplexity" else "Mean Accuracy (β better)" |
| ax.set_ylabel(ylabel) |
| ax.set_title(f"{bench['label']} β Model Comparison") |
| ax.set_xticks(range(n)) |
| ax.set_xticklabels(labels_, rotation=20, ha="right", fontsize=9) |
| if metric == "accuracy": |
| ax.set_ylim(0, 1.05) |
| ax.grid(True, axis="y", alpha=0.3) |
| plt.tight_layout() |
| out_path = "/tmp/benchmark_result.png" |
| plt.savefig(out_path, dpi=150) |
| plt.close(fig) |
| log_lines += ["", "Done."] |
| return "\n".join(log_lines), out_path |
|
|
|
|
| |
|
|
| def _advanced_block(): |
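| """Build the shared accordion of sampling controls and return the components in the order callers unpack them.""" |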
| with gr.Accordion("Advanced parameters", open=False): |
| use_custom = gr.Checkbox(label="Override preset with custom values below", value=False) |
| raw_mode = gr.Checkbox(label="Raw / pretrain mode (no <|user|> wrapping)", value=False) |
| with gr.Row(): |
| temperature = gr.Slider(0.0, 2.0, value=0.5, step=0.01, label="Temperature") |
| top_k = gr.Slider(0, 200, value=20, step=1, label="Top-k") |
| with gr.Row(): |
| top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p") |
| min_p = gr.Slider(0.0, 0.5, value=0.05, step=0.01, label="Min-p") |
| with gr.Row(): |
| rep_penalty = gr.Slider(1.0, 2.0, value=1.15, step=0.01, label="Repetition penalty") |
| ngram_size = gr.Slider(0, 8, value=4, step=1, label="No-repeat n-gram size") |
| with gr.Row(): |
| soft_cap = gr.Slider(0.0, 50.0, value=20.0, step=0.5, label="Logit soft cap") |
| loop_pen = gr.Slider(0.0, 50.0, value=15.0, step=0.5, label="Loop penalty") |
| with gr.Row(): |
| max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max new tokens") |
| ctx_win = gr.Slider(128, 4096, value=2048, step=128, label="Context window") |
| return use_custom, temperature, top_k, top_p, min_p, rep_penalty, ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode |
|
|
|
|
| |
|
|
| def _sort_oldest_to_newest(versions: list[str]) -> list[str]: |
| """Sort versions oldestβnewest using HUGGINGFACE_MODELS key order.""" |
| order = {name: i for i, name in enumerate(HUGGINGFACE_MODELS)} |
| return sorted( |
| versions, |
| key=lambda v: (order.get(v, len(order)), versions.index(v)), |
| ) |
|
|
|
|
| _initial_versions = _sort_oldest_to_newest(_collection_versions()) |
| _initial_version = _initial_versions[0] if _initial_versions else None |
| _initial_ckpt_labels = _ckpt_labels(_initial_version) if _initial_version else [] |
| _mode_keys = list(MODES.keys()) |
|
|
| |
| _HF_THEME = gr.themes.Default( |
| primary_hue=gr.themes.Color( |
| c50="#FFFBEA", |
| c100="#FFF3C4", |
| c200="#FCE588", |
| c300="#FADB5F", |
| c400="#F7C948", |
| c500="#FFD21E", |
| c600="#F0B429", |
| c700="#CB6E17", |
| c800="#B44D12", |
| c900="#8D2B0B", |
| c950="#5C1A04", |
| ), |
| secondary_hue="orange", |
| neutral_hue="slate", |
| font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"], |
| font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"], |
| ).set( |
| body_background_fill="#FFFFFF", |
| body_background_fill_dark="#0B0F19", |
| background_fill_primary="#FFFFFF", |
| background_fill_primary_dark="#0B0F19", |
| background_fill_secondary="#F5F6F8", |
| background_fill_secondary_dark="#1B1F2A", |
| block_background_fill="#FFFFFF", |
| block_background_fill_dark="#11151F", |
| block_border_color="#E5E7EB", |
| block_border_color_dark="#22273A", |
| block_label_background_fill="#FFFBEA", |
| block_label_background_fill_dark="#1B1F2A", |
| block_label_text_color="#5C1A04", |
| block_label_text_color_dark="#FFD21E", |
| block_title_text_color="#0B0F19", |
| block_title_text_color_dark="#F5F6F8", |
| button_primary_background_fill="#FFD21E", |
| button_primary_background_fill_hover="#F0B429", |
| button_primary_text_color="#0B0F19", |
| button_primary_text_color_hover="#0B0F19", |
| button_secondary_background_fill="#F5F6F8", |
| button_secondary_background_fill_dark="#1B1F2A", |
| button_secondary_text_color="#0B0F19", |
| button_secondary_text_color_dark="#F5F6F8", |
| border_color_accent="#FFD21E", |
| border_color_primary="#E5E7EB", |
| border_color_primary_dark="#22273A", |
| color_accent_soft="#FFFBEA", |
| color_accent_soft_dark="#1B1F2A", |
| ) |
|
|
| with gr.Blocks(title="CompactAI Models", theme=_HF_THEME) as demo: |
| gr.Markdown( |
| "# CompactAI β TinyMemoryLM\n" |
| "Tiny recurrent-depth language models from [CompactAI-O](https://huggingface.co/CompactAI-O)." |
| ) |
|
|
| |
| with gr.Tab("Chat"): |
| with gr.Row(): |
| with gr.Column(scale=1, min_width=240): |
| chat_version = gr.Dropdown( |
| choices=_initial_versions, |
| value=_initial_version, |
| label="Model version", |
| ) |
| chat_ckpt = gr.Dropdown( |
| choices=_initial_ckpt_labels, |
| value=_initial_ckpt_labels[0] if _initial_ckpt_labels else None, |
| label="Checkpoint", |
| ) |
| chat_mode = gr.Radio( |
| choices=_mode_keys, |
| value="chat-coherent", |
| label="Mode preset", |
| info="Ignored when 'Override preset' is checked.", |
| ) |
| c_use_custom, c_temp, c_topk, c_topp, c_minp, c_rep, c_ng, c_cap, c_lp, c_maxt, c_ctx, c_raw = _advanced_block() |
|
|
| with gr.Column(scale=3): |
| chatbot = gr.Chatbot(label="Conversation", height=500, type="messages") |
| with gr.Row(): |
| msg_box = gr.Textbox(placeholder="Type a message…", show_label=False, scale=5) |
| send_btn = gr.Button("Send", variant="primary", scale=1) |
| clear_btn = gr.Button("Clear") |
|
|
| chat_version.change(_on_version_change, chat_version, chat_ckpt) |
|
|
| _chat_adv = [chat_version, chat_ckpt, chat_mode, |
| c_use_custom, c_temp, c_topk, c_topp, c_minp, |
| c_rep, c_ng, c_cap, c_lp, c_maxt, c_ctx, c_raw] |
|
|
| msg_box.submit(_chat_submit, [msg_box, chatbot], [msg_box, chatbot], queue=False).then( |
| _chat_stream, [chatbot] + _chat_adv, chatbot |
| ) |
| send_btn.click(_chat_submit, [msg_box, chatbot], [msg_box, chatbot], queue=False).then( |
| _chat_stream, [chatbot] + _chat_adv, chatbot |
| ) |
| clear_btn.click(lambda: [], None, chatbot, queue=False) |
|
|
| |
| with gr.Tab("Compare All Models"): |
| gr.Markdown("Run the same prompt on every selected model. Outputs stream live one model at a time.") |
| with gr.Row(): |
| cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=4, scale=3) |
| with gr.Column(scale=1): |
| cmp_models = gr.CheckboxGroup( |
| choices=_initial_versions, value=_initial_versions, label="Models to run" |
| ) |
| cmp_mode = gr.Dropdown( |
| choices=_mode_keys, value="chat-coherent", label="Mode preset" |
| ) |
| cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw = _advanced_block() |
| cmp_run = gr.Button("▶ Run comparison", variant="primary") |
|
|
| |
| cmp_outputs = [] |
| for row_start in range(0, len(_initial_versions), 2): |
| with gr.Row(): |
| for v in _initial_versions[row_start:row_start + 2]: |
| cmp_outputs.append(gr.Textbox(label=v, lines=10, interactive=False)) |
|
|
| cmp_run.click( |
| _compare_fn, |
| inputs=[cmp_prompt, cmp_models, cmp_mode, |
| cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, |
| cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw], |
| outputs=cmp_outputs, |
| ) |
|
|
| |
| with gr.Tab("Benchmark"): |
| gr.Markdown( |
| "Evaluate models on standard benchmarks.\n\n" |
| "- **BLiMP** β grammaticality minimal pairs (accuracy)\n" |
| "- **WikiText-2** β LM perplexity (lower = better)\n" |
| "- **ARC-Easy** β multiple-choice science QA (accuracy)" |
| ) |
| with gr.Row(): |
| bench_type = gr.Radio( |
| choices=list(BENCHMARKS.keys()), value="arc_easy", label="Benchmark" |
| ) |
| bench_models = gr.CheckboxGroup( |
| choices=_initial_versions, |
| value=[_initial_versions[0]] if _initial_versions else [], |
| label="Models", |
| ) |
| bench_samples = gr.Slider(10, 500, value=100, step=10, label="Max samples (fewer = faster)") |
| bench_run = gr.Button("Run benchmark", variant="primary") |
| with gr.Row(): |
| bench_log = gr.Textbox(label="Progress log", lines=12, interactive=False) |
| bench_plot = gr.Image(label="Results chart", type="filepath") |
|
|
| bench_run.click( |
| _benchmark_fn, |
| inputs=[bench_type, bench_models, bench_samples], |
| outputs=[bench_log, bench_plot], |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| demo.launch() |
|
|